// traits.h
#pragma once

#include <cute/tensor.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_types.h>
#include <cutlass/barrier.h>

#include "config.h"

// Barrier type used for the TMA-related barriers declared in
// Traits::SharedMemoryPlan (barriers_K0, barriers_K1, barrier_Q):
// CUTLASS's cluster-scoped transaction barrier.
using TMABarrier = cutlass::arch::ClusterTransactionBarrier;
// NOTE(review): a namespace-scope using-directive in a header pulls all of
// `cute` into every translation unit that includes this file. Traits relies
// on unqualified cute names (Shape, Int, Layout, GMMA, make_tiled_mma,
// tile_to_shape, ...), so removing this requires qualifying those uses first.
using namespace cute;

template<typename InputT_>
struct Traits {
    using InputT = InputT_;
    
    static constexpr int BLOCK_SIZE_M = Config::BLOCK_SIZE_M;
    static constexpr int PAGE_BLOCK_SIZE = Config::PAGE_BLOCK_SIZE;
    static constexpr int HEAD_DIM_K = Config::HEAD_DIM_K;
    static constexpr int HEAD_DIM_V = Config::HEAD_DIM_V;

    static constexpr int NUM_THREADS = 256;

    static_assert(std::is_same_v<InputT, cutlass::bfloat16_t> || std::is_same_v<InputT, cutlass::half_t>);

    using TiledMMA_QK_sQ = decltype(make_tiled_mma(
        GMMA::ss_op_selector<InputT, InputT, float, Shape<Int<BLOCK_SIZE_M>, Int<PAGE_BLOCK_SIZE>, Int<HEAD_DIM_K>>, GMMA::Major::K, GMMA::Major::K>(),
        Layout<Shape<_1, _1, _1>>{}
    ));

    using TiledMMA_QK_rQ = decltype(make_tiled_mma(
        GMMA::rs_op_selector<InputT, InputT, float, Shape<Int<BLOCK_SIZE_M>, Int<PAGE_BLOCK_SIZE>, Int<HEAD_DIM_K>>, GMMA::Major::K, GMMA::Major::K>(),
        Layout<Shape<_1, _1, _1>>{}
    ));

    using TiledMMA_PV_LocalP = decltype(make_tiled_mma(
        GMMA::rs_op_selector<InputT, InputT, float, Shape<Int<BLOCK_SIZE_M>, Int<HEAD_DIM_V/2>, Int<PAGE_BLOCK_SIZE>>, GMMA::Major::K, GMMA::Major::MN>(),
        Layout<Shape<_1, _1, _1>>{}
    ));

    using TiledMMA_PV_RemoteP = decltype(make_tiled_mma(
        GMMA::ss_op_selector<InputT, InputT, float, Shape<Int<BLOCK_SIZE_M>, Int<HEAD_DIM_V/2>, Int<PAGE_BLOCK_SIZE>>, GMMA::Major::K, GMMA::Major::MN>(),
        Layout<Shape<_1, _1, _1>>{}
    ));

    using SmemLayoutQ = decltype(tile_to_shape(
        GMMA::Layout_K_SW128_Atom<InputT>{},
        Shape<Int<BLOCK_SIZE_M>, Int<HEAD_DIM_K>>{}
    ));

    using SmemLayoutK = decltype(tile_to_shape(
        GMMA::Layout_K_SW128_Atom<InputT>{},
        Shape<Int<PAGE_BLOCK_SIZE>, Int<HEAD_DIM_K>>{}
    ));

    using SmemLayoutV = decltype(composition(
        SmemLayoutK{},
        make_layout(Shape<Int<HEAD_DIM_V>, Int<PAGE_BLOCK_SIZE>>{}, GenRowMajor{})
    ));	// A transposed version of SmemLayoutK

    using SmemLayoutP0 = decltype(tile_to_shape(
        GMMA::Layout_K_SW128_Atom<InputT>{},
        Shape<Int<BLOCK_SIZE_M>, Int<PAGE_BLOCK_SIZE>>{}
    ));

    using rP0Layout = decltype(layout(partition_fragment_C(
        TiledMMA_QK_sQ{},
        Shape<Int<BLOCK_SIZE_M>, Int<PAGE_BLOCK_SIZE>>{}
    )));

    struct SharedMemoryPlan {
        cute::array_aligned<InputT, cosize_v<SmemLayoutQ>> smem_sQ;
        cute::array_aligned<InputT, cosize_v<SmemLayoutK>> smem_sK0;
        cute::array_aligned<InputT, cosize_v<SmemLayoutK>> smem_sK1;
        cute::array_aligned<InputT, cosize_v<SmemLayoutP0>> smem_sP0;
        cute::array_aligned<float, BLOCK_SIZE_M> smem_sM;
        cute::array_aligned<float, 2*BLOCK_SIZE_M> sL_reduction_wksp;
        cute::array_aligned<float, BLOCK_SIZE_M> smem_sScale0;
        cute::array_aligned<float, BLOCK_SIZE_M> smem_sScale1;
        TMABarrier barriers_K0[HEAD_DIM_K/64];
        TMABarrier barriers_K1[HEAD_DIM_K/64];
        TMABarrier barrier_Q;
    };

};

/// Aggregate bundling the shapes and TMA objects for Q, K and O.
///
/// Presumably handed to the kernel as a single parameter (TODO confirm at
/// the launch site). Member order is part of the interface because callers
/// may aggregate-initialize it: shape_Q, tma_Q, shape_K, tma_K, shape_O,
/// tma_O.
template <class ShapeQ, class TMA_Q,
          class ShapeK, class TMA_K,
          class ShapeO, class TMA_O>
struct TmaParams {
    // Query operand.
    ShapeQ shape_Q;
    TMA_Q  tma_Q;

    // Key operand.
    ShapeK shape_K;
    TMA_K  tma_K;

    // Output operand.
    ShapeO shape_O;
    TMA_O  tma_O;
};

/// IDs of the named barriers used to synchronize threads within the kernel.
///
/// Kept as an unscoped enum with an explicit `int` underlying type so the
/// enumerators convert implicitly to int (presumably required by the
/// barrier-ID APIs at the call sites — confirm before changing to
/// `enum class`). The values must stay distinct.
enum NamedBarriers : int {
    sScale0Ready = 0,
    sScale1Ready = 1,
    sP0Ready = 2,
    rO1sP0sV0RIssued = 3,
    sMInitialized = 4,
};