// config.h
#pragma once

#include <cutlass/numeric_types.h>
#include <cutlass/arch/barrier.h>
#include <cute/tensor.hpp>
#include <kerutils/kerutils.cuh>

#include "defines.h"
#include "params.h"

using namespace cute;

namespace sm90::decode::sparse_fp8 {

/**
 * Compile-time configuration (constants and CuTe shared-memory layouts) for the
 * sparse FP8 decode kernel, parameterised by model type and head count.
 *
 * Template parameters:
 *   MODEL_TYPE - selects the K head dimension and the quantization tile size /
 *                scale count (ModelType::V32 vs. every other value).
 *   NUM_HEADS  - must be 64 or 128; determines NUM_M_BLOCKS and CLUSTER_SIZE.
 *
 * The class only declares `devfunc` and `run`; their definitions are not part
 * of this header.
 */
template<ModelType MODEL_TYPE, int NUM_HEADS>
class KernelTemplate {
public:

static_assert(NUM_HEADS == 64 || NUM_HEADS == 128, "NUM_HEADS must be 64 or 128");
// Number of 64-row blocks needed to cover all heads (1 or 2).
static constexpr int NUM_M_BLOCKS = NUM_HEADS / 64;
// The cluster size is tied to the number of M blocks.
static constexpr int CLUSTER_SIZE = NUM_M_BLOCKS;

// Head dimensions. HEAD_DIM_NOPE is what remains of HEAD_DIM_K after removing
// the HEAD_DIM_ROPE portion (512 for V32, 448 otherwise).
static constexpr int HEAD_DIM_K = MODEL_TYPE == ModelType::V32 ? 576 : 512;
static constexpr int HEAD_DIM_V = 512;
static constexpr int HEAD_DIM_ROPE = 64;
static constexpr int HEAD_DIM_NOPE = HEAD_DIM_K - HEAD_DIM_ROPE;

// Quantization parameters.
static constexpr int QUANT_TILE_SIZE = MODEL_TYPE == ModelType::V32 ? 128 : 64;
static constexpr int NUM_SCALES = MODEL_TYPE == ModelType::V32 ? 4 : 8;  // For MODEL1: 7 fp8_e4m3 + 1 padding

static_assert(HEAD_DIM_NOPE % QUANT_TILE_SIZE == 0,
              "HEAD_DIM_NOPE must be a whole number of quantization tiles");
static_assert(HEAD_DIM_NOPE / QUANT_TILE_SIZE <= NUM_SCALES,
              "NUM_SCALES must cover HEAD_DIM_NOPE / QUANT_TILE_SIZE quantization tiles");

// Launch / tiling parameters.
static constexpr int NUM_THREADS = 128*3;   // presumably three 128-thread warpgroups - TODO confirm against the kernel
static constexpr int BLOCK_M = 64;
static constexpr int TOPK_BLOCK_SIZE = 64;
static constexpr int NUM_K_BUFS = 2;

static_assert(NUM_M_BLOCKS * BLOCK_M == NUM_HEADS,
              "NUM_M_BLOCKS is computed with a literal 64 and must stay in sync with BLOCK_M");
static_assert(HEAD_DIM_K % 64 == 0 && HEAD_DIM_V % 128 == 0,
              "Q/K are tiled in 64-column tiles and V is additionally split into two halves");

// One BLOCK_M x 64 tile of Q: K-major, 128B-swizzled GMMA atom for bf16.
using SmemLayoutQTile = decltype(tile_to_shape(
    GMMA::Layout_SW128_Atom<bf16, GMMA::Major::K>{},
    Shape<Int<BLOCK_M>, Int<64>>{}
));

// NUM_TILES Q tiles laid side by side along the second mode
// (BLOCK_M x 64*NUM_TILES).
template<int NUM_TILES>
using SmemLayoutQTiles = decltype(tile_to_shape(
    SmemLayoutQTile{},
    Shape<Int<BLOCK_M>, Int<64*NUM_TILES>>{},
    Step<_1, _2>{}
));

// Full Q layout: BLOCK_M x HEAD_DIM_K.
using SmemLayoutQ = SmemLayoutQTiles<HEAD_DIM_K/64>;

// One TOPK_BLOCK_SIZE x 64 tile of K: K-major, non-swizzled (INTER) GMMA atom
// for bf16.
using SmemLayoutKTile = decltype(tile_to_shape(
    GMMA::Layout_INTER_Atom<bf16, GMMA::Major::K>{},
    Shape<Int<TOPK_BLOCK_SIZE>, _64>{},
    Step<_1, _2>{}
));

// NUM_TILES K tiles laid side by side along the second mode
// (TOPK_BLOCK_SIZE x 64*NUM_TILES).
template<int NUM_TILES>
using SmemLayoutKTiles = decltype(tile_to_shape(
    SmemLayoutKTile{},
    Shape<Int<TOPK_BLOCK_SIZE>, Int<64*NUM_TILES>>{},
    Step<_1, _2>{}
));

// The K tiles re-indexed with the two modes swapped
// ((64*NUM_TILES) x TOPK_BLOCK_SIZE), obtained by composing the K layout with
// a layout of that shape and strides (TOPK_BLOCK_SIZE, 1).
template<int NUM_TILES>
using SmemLayoutKTilesTransposed = decltype(composition(
    SmemLayoutKTiles<NUM_TILES>{},
    Layout<Shape<Int<64*NUM_TILES>, Int<TOPK_BLOCK_SIZE>>, Stride<Int<TOPK_BLOCK_SIZE>, _1>>{}
));

// Output buffer (bf16): BLOCK_M x HEAD_DIM_V tiled from a K-major SW128 atom.
static constexpr int OBUF_SW = 64;
using SmemLayoutOBufAtom = GMMA::Layout_K_SW128_Atom<bf16>;
using SmemLayoutOBuf = decltype(tile_to_shape(
    SmemLayoutOBufAtom{},
    Shape<Int<BLOCK_M>, Int<HEAD_DIM_V>>{},
    Step<_1, _2>{}
));

// Output accumulation buffer: plain row-major BLOCK_M x HEAD_DIM_V with a padded
// row stride of 520 elements (HEAD_DIM_V + 8).
static_assert(520 >= HEAD_DIM_V, "padded oAccum row stride must not be smaller than HEAD_DIM_V");
using SmemLayoutOAccumBuf = Layout<
    Shape<Int<BLOCK_M>, Int<HEAD_DIM_V>>,
    Stride<Int<520>, _1>    // We use stride = 520 here to avoid bank conflict
>;

// K, V and half-of-V layouts built from the tile helpers above. V (and half V)
// reuse the K tile storage through the transposed view.
using SmemLayoutK = SmemLayoutKTiles<HEAD_DIM_K/64>;
using SmemLayoutV = SmemLayoutKTilesTransposed<HEAD_DIM_V/64>;
using SmemLayoutHalfV = SmemLayoutKTilesTransposed<HEAD_DIM_V/64/2>;

// BLOCK_M x TOPK_BLOCK_SIZE bf16 layout tiled from a K-major SW128 atom.
using SmemLayoutS = decltype(tile_to_shape(
    GMMA::Layout_K_SW128_Atom<bf16>{},
    Shape<Int<BLOCK_M>, Int<TOPK_BLOCK_SIZE>>{}
));

// Shared-memory plan. NOTE: every member is currently commented out, so this
// struct is empty; the disabled members are kept below for reference.
struct SharedMemoryPlan {
    // array_aligned<bf16, cosize_v<SmemLayoutQ>> q;
    // union {
    //     array_aligned<bf16, cosize_v<SmemLayoutK>> k[NUM_K_BUFS];
    //     array_aligned<bf16, cosize_v<SmemLayoutOBuf>> oBuf;
    //     array_aligned<float, cosize_v<SmemLayoutOAccumBuf>> oAccumBuf;
    // } u;
    // CUTE_ALIGNAS(1024) array_aligned<bf16, cosize_v<SmemLayoutS>> s;
    // bool is_kv_valid[NUM_K_BUFS][TOPK_BLOCK_SIZE];

    // float sM[BLOCK_M], sL[BLOCK_M], sScale[BLOCK_M], sOScale[BLOCK_M];
    // transac_bar_t bar_q, bar_k_local_ready[NUM_K_BUFS], bar_k_remote_ready[NUM_K_BUFS], bar_k_avail[NUM_K_BUFS];
};

// The template header and tiled-MMA aliases below are disabled (commented out)
// in this version of the file; they are preserved verbatim.

// template<
//     typename Shape_Q, typename TMA_Q
// >

// using TiledMMA_QK = decltype(make_tiled_mma(
//     GMMA::MMA_64x64x16_F32BF16BF16_SS<GMMA::Major::K, GMMA::Major::K>{},
//     Layout<Shape<_1, _1, _1>>{}
// ));

// using TiledMMA_QK_rQ = decltype(make_tiled_mma(
//     GMMA::MMA_64x64x16_F32BF16BF16_RS<GMMA::Major::K, GMMA::Major::K>{},
//     Layout<Shape<_1, _1, _1>>{}
// ));

// using TiledMMA_PV_LocalP = decltype(make_tiled_mma(
//     GMMA::MMA_64x256x16_F32BF16BF16_RS<GMMA::Major::K, GMMA::Major::MN>{},
//     Layout<Shape<_1, _1, _1>>{}
// ));

// using TiledMMA_PV_RemoteP = decltype(make_tiled_mma(
//     GMMA::MMA_64x256x16_F32BF16BF16_SS<GMMA::Major::K, GMMA::Major::MN>{},
//     Layout<Shape<_1, _1, _1>>{}
// ));

// Device-side function taking the decode parameters. Declaration only; the
// definition is not in this header.
static __device__ __forceinline__ void
devfunc(const SparseAttnDecodeParams &params);

// Static function without an execution-space qualifier (i.e. host code) taking
// the decode parameters. Declaration only; presumably it launches the kernel -
// confirm against the definition.
static void run(const SparseAttnDecodeParams &params);

};

}