combine.cu 9.82 KB
Newer Older
1
#include "combine.h"
2

zhanghj2's avatar
zhanghj2 committed
3
// #include <math_constants.h>
4
5
6
7
8
#include <cute/tensor.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>

9
10
#include <kerutils/kerutils.cuh>

11
12
13
14
15
#include "params.h"
#include "utils.h"

using namespace cute;

16
17
namespace smxx::decode {

18
19
template<typename ElementT, int HEAD_DIM_V, int BLOCK_SIZE_M, int MAX_SPLITS, int NUM_THREADS>
__global__ void __launch_bounds__(NUM_THREADS)
zhanghj2's avatar
zhanghj2 committed
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
flash_fwd_mla_combine_kernel(const CombineParams params) {
    // // grid_shape: [batch_size, s_q, h_q/BLOCK_SIZE_M]
    // // Each CTA gathers the activation of some heads from one batch, do scaling & accumulation, and save the result
    // static_assert(NUM_THREADS/32 == BLOCK_SIZE_M); // The number of warps == block_size_m
    // const int batch_idx = blockIdx.x;
    // const int s_q_idx = blockIdx.y;
    // const int h_block_idx = blockIdx.z;
    // const int warp_idx = threadIdx.x / 32;
    // const int lane_idx = threadIdx.x % 32;

    // int num_valid_heads = std::min(BLOCK_SIZE_M, params.h_q - BLOCK_SIZE_M*h_block_idx);
    // if (warp_idx >= num_valid_heads) {
    //     return;
    // }

    // const int start_split_idx = __ldg(params.num_splits_ptr + batch_idx);
    // const int end_split_idx = __ldg(params.num_splits_ptr + batch_idx + 1);
    // const int my_num_splits = end_split_idx - start_split_idx;
    // if (my_num_splits == 1) {
    //     return;
    // }
41
    
zhanghj2's avatar
zhanghj2 committed
42
    // FLASH_DEVICE_ASSERT(my_num_splits <= MAX_SPLITS);
43
    
zhanghj2's avatar
zhanghj2 committed
44
45
46
47
48
49
50
51
52
53
    // Tensor gLseAccum = make_tensor(
    //     make_gmem_ptr((float*)params.lse_accum + start_split_idx*params.stride_lse_accum_split + s_q_idx*params.stride_lse_accum_s_q + h_block_idx*BLOCK_SIZE_M),
    //     Shape<Int<MAX_SPLITS>, Int<BLOCK_SIZE_M>>{},
    //     make_stride(params.stride_lse_accum_split, _1{})
    // );
    // Tensor gLse = make_tensor(
    //     make_gmem_ptr((float*)params.lse + batch_idx*params.stride_lse_b + s_q_idx*params.stride_lse_s_q + h_block_idx*BLOCK_SIZE_M),
    //     Shape<Int<BLOCK_SIZE_M>>{},
    //     Stride<_1>{}
    // );
54
    
zhanghj2's avatar
zhanghj2 committed
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
    // __shared__ float smem_buf[BLOCK_SIZE_M][MAX_SPLITS];

    // // Wait for the previous kernel (the MLA kernel) to finish
    // cudaGridDependencySynchronize();

    // // Prefetch
    // static_assert(HEAD_DIM_V % (32*4) == 0);
    // constexpr int ELEMS_PER_THREAD = HEAD_DIM_V / (32*4);
    // float* oaccum_ptr = params.o_accum + start_split_idx*params.stride_o_accum_split + s_q_idx*params.stride_o_accum_s_q + (h_block_idx*BLOCK_SIZE_M + warp_idx)*params.stride_o_accum_h_q;
    // float4 datas[ELEMS_PER_THREAD];
    // CUTLASS_PRAGMA_UNROLL
    // for (int i = 0; i < ELEMS_PER_THREAD; ++i) {
    //     datas[i] = *(float4*)(oaccum_ptr + lane_idx*4 + i*128); // NOTE We don't use __ldg here since it is incompatible with PDL
    // }

    // // Warp #i gathers LseAccum for seq #i
    // {
    //     constexpr int NUM_LSE_PER_THREAD = cute::ceil_div(MAX_SPLITS, 32);
    //     float local_lse[NUM_LSE_PER_THREAD];
    //     CUTLASS_PRAGMA_UNROLL
    //     for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) {
    //         const int split_idx = i*32 + lane_idx;
    //         local_lse[i] = split_idx < my_num_splits ? gLseAccum(split_idx, warp_idx) : -INFINITY;
    //     }

    //     float max_lse = -INFINITY;
    //     CUTLASS_PRAGMA_UNROLL
    //     for (int i = 0; i < NUM_LSE_PER_THREAD; ++i)
    //         max_lse = max(max_lse, local_lse[i]);
    //     CUTLASS_PRAGMA_UNROLL
    //     for (int offset = 16; offset >= 1; offset /= 2)
    //         max_lse = max(max_lse, __shfl_xor_sync(uint32_t(-1), max_lse, offset));
    //     max_lse = max_lse == -INFINITY ? 0.0f : max_lse;  // In case all local LSEs are -inf

    //     float sum_lse = 0;
    //     CUTLASS_PRAGMA_UNROLL
    //     for (int i = 0; i < NUM_LSE_PER_THREAD; ++i)
    //         sum_lse = sum_lse + exp2f(local_lse[i] - max_lse);
    //     CUTLASS_PRAGMA_UNROLL
    //     for (int offset = 16; offset >= 1; offset /= 2)
    //         sum_lse = sum_lse + __shfl_xor_sync(uint32_t(-1), sum_lse, offset);

    //     float global_lse = (sum_lse == 0.f || sum_lse == -INFINITY) ? INFINITY : log2f(sum_lse) + max_lse;
    //     if (lane_idx == 0)
    //         gLse(warp_idx) = global_lse / (float)M_LOG2E;
100
        
zhanghj2's avatar
zhanghj2 committed
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
    //     if (params.attn_sink != nullptr) {
    //         int q_head_idx = h_block_idx*BLOCK_SIZE_M + warp_idx;
    //         float attn_sink = __ldg(params.attn_sink + q_head_idx);
    //         if (global_lse != INFINITY) {
    //             // If attn_sink is +inf, global_lse will be +inf and scale factors will be exp2f(local_lse - inf) = 0 (since local_lse never becomes +inf)
    //             // If attn_sink is -inf, this has no effect on global_lse
    //             global_lse += log2f(1 + exp2f(attn_sink*CUDART_L2E_F - global_lse));
    //         } else {
    //             // We have no tokens to attend, so global lse should be attn_sink*CUDART_L2E_F (+inf if it's -inf or +inf)
    //             global_lse = attn_sink == -INFINITY ? +INFINITY : attn_sink*CUDART_L2E_F;
    //         }
    //     }
    //     CUTLASS_PRAGMA_UNROLL
    //     for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) {
    //         const int split_idx = i*32 + lane_idx;
    //         smem_buf[warp_idx][split_idx] = exp2f(local_lse[i] - global_lse);
    //     }
    // }

    // __syncwarp();

    // // Warp #i accumulates activation for seq #i
    // {
    //     float4 result[ELEMS_PER_THREAD];
    //     CUTLASS_PRAGMA_UNROLL
    //     for (int i = 0; i < ELEMS_PER_THREAD; ++i)
    //         result[i] = {0.0f, 0.0f, 0.0f, 0.0f};

    //     #pragma unroll 1
    //     for (int split = 0; split < my_num_splits; ++split) {
    //         float lse_scale = smem_buf[warp_idx][split];
    //         // if (lse_scale != 0.f) {
    //         CUTLASS_PRAGMA_UNROLL
    //         for (int i = 0; i < ELEMS_PER_THREAD; ++i) {
    //             result[i].x += lse_scale * datas[i].x;
    //             result[i].y += lse_scale * datas[i].y;
    //             result[i].z += lse_scale * datas[i].z;
    //             result[i].w += lse_scale * datas[i].w;
    //             if (split != my_num_splits-1) {
    //                 datas[i] = *(float4*)(oaccum_ptr + (split+1)*params.stride_o_accum_split + lane_idx*4 + i*128);
    //             }
    //         }
    //         // }
    //     }
145
        
zhanghj2's avatar
zhanghj2 committed
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
    //     const int h_q_idx = h_block_idx*BLOCK_SIZE_M + warp_idx;
    //     ElementT* o_ptr = (ElementT*)params.out + batch_idx*params.stride_o_b + s_q_idx*params.stride_o_s_q + h_q_idx*params.stride_o_h_q;

    //     CUTLASS_PRAGMA_UNROLL
    //     for (int i = 0; i < ELEMS_PER_THREAD; ++i) {
    //         float4 data = result[i];
    //         ElementT data_converted[4];
    //         data_converted[0] = (ElementT)(data.x);
    //         data_converted[1] = (ElementT)(data.y);
    //         data_converted[2] = (ElementT)(data.z);
    //         data_converted[3] = (ElementT)(data.w);
    //         static_assert(sizeof(ElementT) == 2);
    //         *(uint64_t*)(o_ptr + lane_idx*4 + i*128) = *(uint64_t*)data_converted;
    //     }
    // }
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
}


#define MLA_NUM_SPLITS_SWITCH(NUM_SPLITS, NAME, ...)       \
    [&] {                                                  \
        if (NUM_SPLITS <= 32) {                            \
            constexpr static int NAME = 32;                \
            return __VA_ARGS__();                          \
        } else if (NUM_SPLITS <= 64) {                     \
            constexpr static int NAME = 64;                \
            return __VA_ARGS__();                          \
        } else if (NUM_SPLITS <= 96) {                     \
            constexpr static int NAME = 96;                \
            return __VA_ARGS__();                          \
        } else if (NUM_SPLITS <= 128) {                    \
            constexpr static int NAME = 128;               \
            return __VA_ARGS__();                          \
        } else if (NUM_SPLITS <= 160) {                    \
            constexpr static int NAME = 160;               \
            return __VA_ARGS__();                          \
        } else {                                           \
            FLASH_ASSERT(false);                           \
        }                                                  \
    }()


template<typename ElementT>
188
void run_flash_mla_combine_kernel(CombineParams &params) {
189
190
    static constexpr int HEAD_DIM_V = 512;  // Since only this head dimension is supported by Flash MLA
    FLASH_ASSERT(params.d_v == HEAD_DIM_V);
zhanghj2's avatar
zhanghj2 committed
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
    // MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, NUM_SPLITS, [&] {
    //     constexpr int BLOCK_SIZE_M = 8;
    //     constexpr int NUM_THREADS = BLOCK_SIZE_M*32;
    //     constexpr size_t smem_size = BLOCK_SIZE_M*(NUM_SPLITS+1)*sizeof(float);
    //     auto combine_kernel = &flash_fwd_mla_combine_kernel<ElementT, HEAD_DIM_V, BLOCK_SIZE_M, NUM_SPLITS, NUM_THREADS>;
    //     CHECK_CUDA(cudaFuncSetAttribute(combine_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
    //     // Use cudaLaunchKernelEx to enable PDL (Programmatic Dependent Launch)
    //     cudaLaunchAttribute attribute[1];
    //     attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    //     attribute[0].val.programmaticStreamSerializationAllowed = 1;
    //     cudaLaunchConfig_t combine_kernel_config = {
    //         dim3(params.b, params.s_q, ku::ceil_div(params.h_q, BLOCK_SIZE_M)),
    //         dim3(NUM_THREADS, 1, 1),
    //         0,
    //         params.stream,
    //         attribute,
    //         1
    //     };
    //     CHECK_CUDA(cudaLaunchKernelEx(&combine_kernel_config, combine_kernel, params));
    // });
211
212
213
    CHECK_CUDA_KERNEL_LAUNCH();
}

214
template void run_flash_mla_combine_kernel<cutlass::bfloat16_t>(CombineParams &params);
215
216

#ifndef FLASH_MLA_DISABLE_FP16
217
218
219
220
template void run_flash_mla_combine_kernel<cutlass::half_t>(CombineParams &params);
#endif

}