// combine.cu
#include "combine.h"
// #include <math_constants.h>
#include <cute/tensor.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>

#include <kerutils/kerutils.cuh>

#include "params.h"
#include "utils.h"
#define CUDART_L2E_F            1.442695041F

using namespace cute;

namespace smxx::decode {

template<typename ElementT, int HEAD_DIM_V, int BLOCK_SIZE_M, int MAX_SPLITS, int NUM_THREADS>
__global__ void __launch_bounds__(NUM_THREADS, 1)
flash_fwd_mla_combine_kernel(const CombineParams params) {
    // Merges the split-KV partial results of MLA decoding into the final output.
    //
    // Launch layout:
    //   grid  = [batch_size, s_q, ceil_div(h_q, BLOCK_SIZE_M)]
    //   block = NUM_THREADS threads = BLOCK_SIZE_M groups of 64 lanes ("warps" below; the shuffle
    //           reductions assume 64-lane wavefronts). Warp #w owns query head
    //           h_block_idx*BLOCK_SIZE_M + w of one (batch, s_q) pair.
    // Shared memory: only the static smem_buf below (BLOCK_SIZE_M x MAX_SPLITS floats).
    // Preconditions: my_num_splits <= MAX_SPLITS (NOT checked on device);
    //                HEAD_DIM_V == 512 (enforced by the static_asserts below).
    //
    // For its head, each warp
    //   1. reduces the per-split LSEs into a global LSE and writes it to params.lse,
    //   2. stores one weight per split, exp(local_lse - global_lse) * o_scale, in shared memory,
    //   3. accumulates the weighted o_accum rows of all splits and writes params.out as ElementT.
    // LSEs are natural-log based: e^x is evaluated as exp2(x * log2(e)).
    static_assert(NUM_THREADS/64 == BLOCK_SIZE_M); // The number of warps == block_size_m
    const int batch_idx = blockIdx.x;
    const int s_q_idx = blockIdx.y;
    const int h_block_idx = blockIdx.z;
    const int warp_idx = threadIdx.x / 64;
    const int lane_idx = threadIdx.x % 64;

    // The last head block may be partially filled; warps without a head leave immediately.
    // NOTE(review): the __syncthreads() below is then reached only by the remaining warps; this
    // relies on exited warps not being waited for by the barrier — confirm for the target hardware.
    int num_valid_heads = std::min(BLOCK_SIZE_M, params.h_q - BLOCK_SIZE_M*h_block_idx);
    if (warp_idx >= num_valid_heads) {
        return;
    }

    // num_splits_ptr holds cumulative split offsets: this batch owns splits [start, end).
    const int start_split_idx = __ldg(params.num_splits_ptr + batch_idx);
    const int end_split_idx = __ldg(params.num_splits_ptr + batch_idx + 1);
    const int my_num_splits = end_split_idx - start_split_idx;
    if (my_num_splits == 1) {
        // Nothing to combine; presumably the split kernel already wrote the final output — confirm.
        return;
    }

    // FLASH_DEVICE_ASSERT(my_num_splits <= MAX_SPLITS);

    Tensor gLseAccum = make_tensor(
        make_gmem_ptr((float*)params.lse_accum + start_split_idx*params.stride_lse_accum_split + s_q_idx*params.stride_lse_accum_s_q + h_block_idx*BLOCK_SIZE_M),
        Shape<Int<MAX_SPLITS>, Int<BLOCK_SIZE_M>>{},
        make_stride(params.stride_lse_accum_split, _1{})
    );
    Tensor gLse = make_tensor(
        make_gmem_ptr((float*)params.lse + batch_idx*params.stride_lse_b + s_q_idx*params.stride_lse_s_q + h_block_idx*BLOCK_SIZE_M),
        Shape<Int<BLOCK_SIZE_M>>{},
        Stride<_1>{}
    );

    // smem_buf[w][s] = weight of split s for warp w's head.
    __shared__ float smem_buf[BLOCK_SIZE_M][MAX_SPLITS];

    // Warp #i gathers LseAccum for head #i of this block and turns it into per-split weights.
    {
        constexpr int NUM_LSE_PER_THREAD = cute::ceil_div(MAX_SPLITS, 64);
        float local_lse[NUM_LSE_PER_THREAD];
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) {
            const int split_idx = i*64 + lane_idx;
            local_lse[i] = split_idx < my_num_splits ? gLseAccum(split_idx, warp_idx) : -INFINITY;
        }

        float max_lse = -INFINITY;
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < NUM_LSE_PER_THREAD; ++i)
            max_lse = max(max_lse, local_lse[i]);
        CUTLASS_PRAGMA_UNROLL
        for (int offset = 32; offset >= 1; offset /= 2)
            max_lse = max(max_lse, __shfl_xor(max_lse, offset));
        max_lse = max_lse == -INFINITY ? 0.0f : max_lse;  // In case all local LSEs are -inf

        float sum_lse = 0;
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < NUM_LSE_PER_THREAD; ++i)
            sum_lse = sum_lse + __builtin_amdgcn_exp2f((local_lse[i] - max_lse) * CUDART_L2E_F);
        CUTLASS_PRAGMA_UNROLL
        for (int offset = 32; offset >= 1; offset /= 2)
            sum_lse = sum_lse + __shfl_xor(sum_lse, offset);

        // sum_lse is a sum of exponentials, so it is either >= 0 or NaN (e.g. from +inf - +inf).
        // Both "nothing contributed" (0) and NaN are reported as +inf.
        float global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? INFINITY : logf(sum_lse) + max_lse;
        if (lane_idx == 0)
            gLse(warp_idx) = global_lse;  // Stored before (i.e. without) the attention-sink contribution.

        // Factor applied to every split weight; differs from 1 only when an attention sink is present.
        float o_scale = 1.0f;
        if (params.attn_sink != nullptr) {
            const int q_head_idx = h_block_idx*BLOCK_SIZE_M + warp_idx;
            const float attn_sink = __ldg(params.attn_sink + q_head_idx);
            if (flash::is_positive_infinity(attn_sink)) {
                // An infinitely heavy sink absorbs all of the attention mass.
                o_scale = 0.0f;
            } else if (!flash::is_positive_infinity(global_lse)) {
                // o_scale = exp(lse) / (exp(lse) + exp(sink)), evaluated as 1 / (1 + exp(sink - lse)).
                // Exponentiating lse and sink separately overflows (inf/inf) or underflows (0/0) to NaN
                // once either leaves float's exp range (|x| > ~88); this form saturates to 0 or 1 instead.
                o_scale = 1.0f / (1.0f + __builtin_amdgcn_exp2f((attn_sink - global_lse) * CUDART_L2E_F));
            }
            // Otherwise global_lse == +inf: every weight below is exp(local_lse - inf) == 0 (for
            // finite or -inf local LSEs), so o_scale does not matter.
        }

        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) {
            const int split_idx = i*64 + lane_idx;
            if (split_idx < my_num_splits) {
                smem_buf[warp_idx][split_idx] = __builtin_amdgcn_exp2f((local_lse[i] - global_lse) * CUDART_L2E_F) * o_scale;
            }
        }
    }

    __syncthreads();

    static_assert(HEAD_DIM_V % (64*4) == 0);
    constexpr int ELEMS_PER_THREAD = HEAD_DIM_V / (64*4);  // float4 vectors per lane
    static_assert(ELEMS_PER_THREAD == 2);
    constexpr int FLOATS_PER_LANE = ELEMS_PER_THREAD * 4;  // contiguous head-dim elements owned by one lane

    float* oaccum_ptr = params.o_accum + start_split_idx*params.stride_o_accum_split + s_q_idx*params.stride_o_accum_s_q + (h_block_idx*BLOCK_SIZE_M + warp_idx)*params.stride_o_accum_h_q;
    float4 datas[ELEMS_PER_THREAD];
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < ELEMS_PER_THREAD; ++i) {
        datas[i] = *(float4*)(oaccum_ptr + lane_idx*FLOATS_PER_LANE + i*4); // NOTE We don't use __ldg here since it is incompatible with PDL
    }

    // Warp #i accumulates the output of head #i
    {
        float4 result[ELEMS_PER_THREAD];
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < ELEMS_PER_THREAD; ++i)
            result[i] = {0.0f, 0.0f, 0.0f, 0.0f};

        #pragma unroll 1
        for (int split = 0; split < my_num_splits; ++split) {
            const float lse_scale = smem_buf[warp_idx][split];
            CUTLASS_PRAGMA_UNROLL
            for (int i = 0; i < ELEMS_PER_THREAD; ++i) {
                result[i].x += lse_scale * datas[i].x;
                result[i].y += lse_scale * datas[i].y;
                result[i].z += lse_scale * datas[i].z;
                result[i].w += lse_scale * datas[i].w;
                // Load the next split's row for the following iteration.
                if (split != my_num_splits-1) {
                    datas[i] = *(float4*)(oaccum_ptr + (split+1)*params.stride_o_accum_split + lane_idx*FLOATS_PER_LANE + i*4);
                }
            }
        }

        const int h_q_idx = h_block_idx*BLOCK_SIZE_M + warp_idx;
        ElementT* o_ptr = (ElementT*)params.out + batch_idx*params.stride_o_b + s_q_idx*params.stride_o_s_q + h_q_idx*params.stride_o_h_q + lane_idx * FLOATS_PER_LANE;
        using result_type = cutlass::Array<ElementT, 2>;
        for (int i = 0; i < ELEMS_PER_THREAD; ++i) {
            if constexpr(std::is_same_v<cutlass::bfloat16_t, ElementT>) {
                #if defined(__gfx938__)
                auto d0 = __builtin_hcu_cvt_pk_bf16_f32(0, result[i].x, 0, result[i].y, 0);
                auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, result[i].z, 0, result[i].w, 0);
                auto res0 = reinterpret_cast<result_type const &>(d0);
                auto res1 = reinterpret_cast<result_type const &>(d1);
                o_ptr[i * 4] = res0[0];
                o_ptr[i * 4 + 1] = res0[1];
                o_ptr[i * 4 + 2] = res1[0];
                o_ptr[i * 4 + 3] = res1[1];
                #else
                // Software float -> bf16: IEEE round-to-nearest-even. NaNs are truncated and forced
                // quiet; a plain "+0x8000" would round ties upward and could carry a NaN payload
                // into the sign bit (0x7FFFFFFF -> 0x8000 == -0.0).
                auto float2bf16 = [] (float s) -> uint16_t {
                    uint32_t x32 = reinterpret_cast<uint32_t const &>(s);
                    if ((x32 & 0x7fffffffu) > 0x7f800000u) {
                        return uint16_t((x32 >> 16) | 0x0040u);
                    }
                    x32 += 0x7fffu + ((x32 >> 16) & 1u);
                    return uint16_t(x32 >> 16);
                };
                const float4 data = result[i];
                o_ptr[i * 4].storage = float2bf16(data.x);
                o_ptr[i * 4 + 1].storage = float2bf16(data.y);
                o_ptr[i * 4 + 2].storage = float2bf16(data.z);
                o_ptr[i * 4 + 3].storage = float2bf16(data.w);
                #endif
            } else {
                // NOTE(review): "pkrtz" presumably packs two floats into fp16 with round-toward-zero;
                // confirm that is acceptable versus round-to-nearest for the fp16 output.
                auto d0 = __builtin_hcu_cvt_pkrtz(result[i].x, result[i].y);
                auto d1 = __builtin_hcu_cvt_pkrtz(result[i].z, result[i].w);
                auto res0 = reinterpret_cast<result_type const &>(d0);
                auto res1 = reinterpret_cast<result_type const &>(d1);
                o_ptr[i * 4] = res0[0];
                o_ptr[i * 4 + 1] = res0[1];
                o_ptr[i * 4 + 2] = res1[0];
                o_ptr[i * 4 + 3] = res1[1];
            }
        }
    }
}


// Dispatches a runtime split count to a compile-time constant.
//
// Expands to an immediately-invoked lambda that evaluates NUM_SPLITS exactly once (parenthesized,
// so expression arguments keep their meaning), rounds it up to the next supported bucket
// (32, 64, 96, 128 or 160), binds that bucket to a `constexpr static int` named NAME, and invokes
// the callable passed as the remaining argument(s), returning its result.
// Split counts above 160 reach FLASH_ASSERT(false).
#define MLA_NUM_SPLITS_SWITCH(NUM_SPLITS, NAME, ...)                       \
    [&] {                                                                  \
        const auto mla_num_splits_value_ = (NUM_SPLITS);                   \
        if (mla_num_splits_value_ <= 32) {                                 \
            constexpr static int NAME = 32;                                \
            return __VA_ARGS__();                                          \
        } else if (mla_num_splits_value_ <= 64) {                          \
            constexpr static int NAME = 64;                                \
            return __VA_ARGS__();                                          \
        } else if (mla_num_splits_value_ <= 96) {                          \
            constexpr static int NAME = 96;                                \
            return __VA_ARGS__();                                          \
        } else if (mla_num_splits_value_ <= 128) {                         \
            constexpr static int NAME = 128;                               \
            return __VA_ARGS__();                                          \
        } else if (mla_num_splits_value_ <= 160) {                         \
            constexpr static int NAME = 160;                               \
            return __VA_ARGS__();                                          \
        } else {                                                           \
            FLASH_ASSERT(false);                                           \
        }                                                                  \
    }()


template<typename ElementT>
void run_flash_mla_combine_kernel(CombineParams &params) {
    // Host launcher for flash_fwd_mla_combine_kernel on params.stream.
    //
    // MAX_SPLITS is picked from params.num_sm_parts by MLA_NUM_SPLITS_SWITCH. One CTA of
    // BLOCK_SIZE_M*64 threads is launched per (batch, s_q, block of BLOCK_SIZE_M query heads).
    // Only head dimension 512 is supported (checked with FLASH_ASSERT).
    static constexpr int HEAD_DIM_V = 512;  // Since only this head dimension is supported by Flash MLA
    FLASH_ASSERT(params.d_v == HEAD_DIM_V);

    MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, NUM_SPLITS, [&] {
        constexpr int BLOCK_SIZE_M = 4;
        constexpr int NUM_THREADS = BLOCK_SIZE_M*64;  // 64 lanes per query head
        auto combine_kernel = &flash_fwd_mla_combine_kernel<ElementT, HEAD_DIM_V, BLOCK_SIZE_M, NUM_SPLITS, NUM_THREADS>;
        // The kernel keeps its per-split weights in a statically sized __shared__ array, so no
        // dynamic shared memory is requested (previously BLOCK_SIZE_M*(NUM_SPLITS+1) floats were
        // reserved per CTA and never touched). This port uses a plain <<<>>> launch, without the
        // cudaLaunchKernelEx / PDL attributes.
        constexpr size_t dynamic_smem_size = 0;
        combine_kernel<<<dim3(params.b, params.s_q, ku::ceil_div(params.h_q, BLOCK_SIZE_M)),
                         NUM_THREADS,
                         dynamic_smem_size,
                         params.stream>>>(params);
    });
    CHECK_CUDA_KERNEL_LAUNCH();
}

// Explicit instantiations of the launcher for the supported output element types.
template void run_flash_mla_combine_kernel<cutlass::bfloat16_t>(CombineParams &params);

#ifndef FLASH_MLA_DISABLE_FP16
template void run_flash_mla_combine_kernel<cutlass::half_t>(CombineParams &params);
#endif

}