// config.h
#pragma once

#include <cutlass/numeric_types.h>
#include <cutlass/arch/barrier.h>
#include <cute/tensor.hpp>
#include <kerutils/kerutils.cuh>

#include "defines.h"
#include "params.h"

using namespace cute;

namespace sm90::decode::sparse_fp8 {

template<ModelType MODEL_TYPE, int NUM_HEADS>
class KernelTemplate {
public:

zhanghj2's avatar
zhanghj2 committed
19
20
21
22
23
static_assert(NUM_HEADS == 64 || NUM_HEADS == 128 || NUM_HEADS == 16);
// todo only support tp8
static constexpr int BLOCK_M = 16;
static constexpr int NUM_M_BLOCKS = NUM_HEADS / BLOCK_M;
static constexpr bool Is_causal = false;
24
25
26
27
28
29
30
31
static constexpr int HEAD_DIM_K = MODEL_TYPE == ModelType::V32 ? 576 : 512;
static constexpr int HEAD_DIM_V = 512;
static constexpr int HEAD_DIM_ROPE = 64;
static constexpr int HEAD_DIM_NOPE = HEAD_DIM_K - HEAD_DIM_ROPE;

static constexpr int QUANT_TILE_SIZE = MODEL_TYPE == ModelType::V32 ? 128 : 64;
static constexpr int NUM_SCALES = MODEL_TYPE == ModelType::V32 ? 4 : 8;  // For MODEL1: 7 fp8_e4m3 + 1 padding

zhanghj2's avatar
zhanghj2 committed
32
static constexpr int NUM_THREADS = 256;
33
static constexpr int TOPK_BLOCK_SIZE = 64;
zhanghj2's avatar
zhanghj2 committed
34
35
36
37
38
using elem_type = cutlass::bfloat16_t;
using MMA_Atom_Arch = std::conditional_t<
    std::is_same_v<elem_type, cutlass::half_t>,
    MMA_Atom<GFX928_16x16x64_F32F16F16F32_NT>,
    MMA_Atom<GFX928_16x16x64_F32BF16BF16F32_NT>
39
>;
zhanghj2's avatar
zhanghj2 committed
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
static constexpr int kNWarps = 4;
using ValLayoutMNK = Layout<Shape<_1, _1, _1>>;
using TiledMma = TiledMMA<
    MMA_Atom_Arch,
    Layout<Shape<_1, Int<kNWarps>, _1>>,  // 1x4x1 or 1x8x1 thread group
    ValLayoutMNK>;

using MMA_Atom_Arch_16_16_32 = std::conditional_t<
    std::is_same_v<elem_type, cutlass::half_t>,
    MMA_Atom<GFX928_16x16x32_F32F16F16F32_NN>,
    MMA_Atom<GFX928_16x16x32_F32BF16BF16F32_NN>
>;
using TiledMma_16_16_32 = TiledMMA<
    MMA_Atom_Arch_16_16_32,
    Layout<Shape<_1, Int<kNWarps>, _1>>,  // 1x4x1 or 1x8x1 thread group
    ValLayoutMNK>;

using MMA_Atom_Arch_16x32_NT = std::conditional_t<
    std::is_same_v<elem_type, cutlass::half_t>,
    MMA_Atom<GFX928_16x32x16_F32F16F16F32_NT>,
    MMA_Atom<GFX928_16x32x16_F32BF16BF16F32_NT>
>;
using TiledMma_O = TiledMMA<
    MMA_Atom_Arch_16x32_NT,
    Layout<Shape<_1, Int<kNWarps>, _1>>,  // 1x4x1 or 1x8x1 thread group
    ValLayoutMNK>;

using SmemLayoutAtomK = decltype(composition(
    Swizzle<3, 3, 3>{},
    Layout<Shape<Int<8>, Int<32>>, Stride<Int<32>, _1>>{}));
using SmemLayoutK = decltype(tile_to_shape(
    SmemLayoutAtomK{},
    Shape<Int<TOPK_BLOCK_SIZE>, Int<8 * 32>>{}));

using SmemLayoutAtomV = SmemLayoutAtomK;   
using SmemLayoutV = decltype(tile_to_shape(
    SmemLayoutAtomV{},
    Shape<Int<TOPK_BLOCK_SIZE>, Int<512>>{}));

using SmemLayoutVtransposed = decltype(
    composition(SmemLayoutV{}, make_layout(Shape<Int<512>, Int<TOPK_BLOCK_SIZE>>{}, GenRowMajor{})));
using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{}));

using SmemLayoutAtomP = Layout<Shape<Int<4*16*16>>, Stride<Int<1>>>;
using SmemLayoutP = decltype(tile_to_shape(
    SmemLayoutAtomP{},
    Shape<Int<4*16*16>>{}));
using SmemLayoutRow = Layout<Shape<_128>, Stride<_1>>; 

using Element = cutlass::bfloat16_t;
using ElementAccum = float;
91
struct SharedMemoryPlan {
zhanghj2's avatar
zhanghj2 committed
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
    union {
        struct {
            cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
        };
        struct {
            // cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutV_tmp>> smem_v_tmp;  // Double buffer
            cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
            cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutRow>> smem_row_sum;
            cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutRow>> smem_row_max;

        };
        // struct {
        //     cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutO>> smem_o;
        //     // cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutP>> smem_p;
        //     // cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_row_sum;
        //     // cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_row_max;
        // };
        // struct {
        //     cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutQ>> smem_q;
        // };
    };

zhanghj2's avatar
zhanghj2 committed
114
115
116
117
118
119
120
121
122
123
124
    // array_aligned<bf16, cosize_v<SmemLayoutQ>> q;
    // union {
    //     array_aligned<bf16, cosize_v<SmemLayoutK>> k[NUM_K_BUFS];
    //     array_aligned<bf16, cosize_v<SmemLayoutOBuf>> oBuf;
    //     array_aligned<float, cosize_v<SmemLayoutOAccumBuf>> oAccumBuf;
    // } u;
    // CUTE_ALIGNAS(1024) array_aligned<bf16, cosize_v<SmemLayoutS>> s;
    // bool is_kv_valid[NUM_K_BUFS][TOPK_BLOCK_SIZE];

    // float sM[BLOCK_M], sL[BLOCK_M], sScale[BLOCK_M], sOScale[BLOCK_M];
    // transac_bar_t bar_q, bar_k_local_ready[NUM_K_BUFS], bar_k_remote_ready[NUM_K_BUFS], bar_k_avail[NUM_K_BUFS];
125
126
};

zhanghj2's avatar
zhanghj2 committed
127
128
129
// template<
//     typename Shape_Q, typename TMA_Q
// >
130

zhanghj2's avatar
zhanghj2 committed
131
132
133
134
// using TiledMMA_QK = decltype(make_tiled_mma(
//     GMMA::MMA_64x64x16_F32BF16BF16_SS<GMMA::Major::K, GMMA::Major::K>{},
//     Layout<Shape<_1, _1, _1>>{}
// ));
135

zhanghj2's avatar
zhanghj2 committed
136
137
138
139
// using TiledMMA_QK_rQ = decltype(make_tiled_mma(
//     GMMA::MMA_64x64x16_F32BF16BF16_RS<GMMA::Major::K, GMMA::Major::K>{},
//     Layout<Shape<_1, _1, _1>>{}
// ));
140

zhanghj2's avatar
zhanghj2 committed
141
142
143
144
145
146
147
148
149
// using TiledMMA_PV_LocalP = decltype(make_tiled_mma(
//     GMMA::MMA_64x256x16_F32BF16BF16_RS<GMMA::Major::K, GMMA::Major::MN>{},
//     Layout<Shape<_1, _1, _1>>{}
// ));

// using TiledMMA_PV_RemoteP = decltype(make_tiled_mma(
//     GMMA::MMA_64x256x16_F32BF16BF16_SS<GMMA::Major::K, GMMA::Major::MN>{},
//     Layout<Shape<_1, _1, _1>>{}
// ));
150
151
152
153
154
155






zhanghj2's avatar
zhanghj2 committed
156
157
static __device__ __forceinline__ void
compute_attn_1rowblock_splitkv_sparse_mla_fp8(const SparseAttnDecodeParams &params, const DecodingSchedMeta& sched_meta, int batch_idx);
158
159

static __device__ __forceinline__ void
zhanghj2's avatar
zhanghj2 committed
160
devfunc(const SparseAttnDecodeParams &params);
161
162
163
164
165
166

static void run(const SparseAttnDecodeParams &params);

};

}