combine.cu 9.41 KB
Newer Older
1
#include "combine.h"
2

3
#include <math_constants.h>
4
5
6
7
8
#include <cute/tensor.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>

9
10
#include <kerutils/kerutils.cuh>

11
12
13
14
15
#include "params.h"
#include "utils.h"

using namespace cute;

16
17
namespace smxx::decode {

18
19
// Combine kernel for split-KV MLA decoding.
//
// Merges the per-split partial results of the attention kernel into the final
// output.  For batch b the splits are the global split indices
// [num_splits_ptr[b], num_splits_ptr[b+1]).  Each split contributes a partial
// fp32 output row (params.o_accum) and a partial log-sum-exp (params.lse_accum,
// in log2 units, as implied by the exp2f/log2f math below).  Per query head the
// kernel computes
//     global_lse = log2(sum_s exp2(lse_s))
//     out        = sum_s exp2(lse_s - global_lse) * o_accum_s
// optionally folding an attention sink into global_lse before the scaling.
// `out` is stored as ElementT, and the pre-sink global_lse is stored to
// params.lse converted to natural-log units.
//
// Launch layout:
//   grid   = [batch_size, s_q, ceil_div(h_q, BLOCK_SIZE_M)]
//   block  = NUM_THREADS = 32 * BLOCK_SIZE_M threads.  Warp w handles query head
//            h_block_idx*BLOCK_SIZE_M + w.  Lane l handles elements
//            [4l + 128i, 4l + 128i + 4) of the HEAD_DIM_V-wide row, for every i.
//   shared = static only: BLOCK_SIZE_M*MAX_SPLITS floats (one row of scale
//            factors per warp).  No dynamic shared memory is used.
//
// Preconditions: my_num_splits <= MAX_SPLITS (device-asserted); HEAD_DIM_V is a
// multiple of 128; sizeof(ElementT) == 2.  The float4 loads and 8-byte stores
// assume o_accum rows are 16-byte aligned and out rows are 8-byte aligned
// (not checked here).
template<typename ElementT, int HEAD_DIM_V, int BLOCK_SIZE_M, int MAX_SPLITS, int NUM_THREADS>
__global__ void __launch_bounds__(NUM_THREADS)
flash_fwd_mla_combine_kernel(__grid_constant__ const CombineParams params) {
    static_assert(NUM_THREADS/32 == BLOCK_SIZE_M); // The number of warps == block_size_m
    // Each warp collectively writes ceil_div(MAX_SPLITS, 32)*32 scale factors into its
    // smem_buf row below.  Any MAX_SPLITS that is not a multiple of 32 would make
    // the last iteration write past the end of that row.
    static_assert(MAX_SPLITS % 32 == 0, "MAX_SPLITS must be a multiple of the warp size");

    const int batch_idx = blockIdx.x;
    const int s_q_idx = blockIdx.y;
    const int h_block_idx = blockIdx.z;
    const int warp_idx = threadIdx.x / 32;
    const int lane_idx = threadIdx.x % 32;

    // The last head block may be partial: warps without a head exit immediately.
    // Whole warps exit together, so the warp-synchronous code below still sees full warps.
    int num_valid_heads = std::min(BLOCK_SIZE_M, params.h_q - BLOCK_SIZE_M*h_block_idx);
    if (warp_idx >= num_valid_heads) {
        return;
    }

    // NOTE(review): num_splits_ptr is read before cudaGridDependencySynchronize().
    // It therefore must not be produced by the kernel this launch overlaps with
    // under PDL — confirm.
    const int start_split_idx = __ldg(params.num_splits_ptr + batch_idx);
    const int end_split_idx = __ldg(params.num_splits_ptr + batch_idx + 1);
    const int my_num_splits = end_split_idx - start_split_idx;
    if (my_num_splits == 1) {
        // Nothing to combine.  Presumably the attention kernel wrote the final
        // output and LSE for this batch directly — TODO confirm against the caller.
        return;
    }

    FLASH_DEVICE_ASSERT(my_num_splits <= MAX_SPLITS);

    // Partial LSEs of this batch's splits for this CTA's heads, indexed [split, head].
    Tensor gLseAccum = make_tensor(
        make_gmem_ptr((float*)params.lse_accum + start_split_idx*params.stride_lse_accum_split + s_q_idx*params.stride_lse_accum_s_q + h_block_idx*BLOCK_SIZE_M),
        Shape<Int<MAX_SPLITS>, Int<BLOCK_SIZE_M>>{},
        make_stride(params.stride_lse_accum_split, _1{})
    );
    // Final LSE (natural-log units) for this CTA's heads.
    Tensor gLse = make_tensor(
        make_gmem_ptr((float*)params.lse + batch_idx*params.stride_lse_b + s_q_idx*params.stride_lse_s_q + h_block_idx*BLOCK_SIZE_M),
        Shape<Int<BLOCK_SIZE_M>>{},
        Stride<_1>{}
    );

    // smem_buf[w][s] = exp2(lse_s - global_lse): the scale factor of split s for warp w's head.
    // Each warp only touches its own row, so __syncwarp() is sufficient below.
    __shared__ float smem_buf[BLOCK_SIZE_M][MAX_SPLITS];

    // Wait for the previous kernel (the MLA kernel) to finish
    cudaGridDependencySynchronize();

    // Prefetch split 0 of o_accum for this warp's head.
    static_assert(HEAD_DIM_V % (32*4) == 0);
    constexpr int ELEMS_PER_THREAD = HEAD_DIM_V / (32*4);
    float* oaccum_ptr = params.o_accum + start_split_idx*params.stride_o_accum_split + s_q_idx*params.stride_o_accum_s_q + (h_block_idx*BLOCK_SIZE_M + warp_idx)*params.stride_o_accum_h_q;
    float4 datas[ELEMS_PER_THREAD];
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < ELEMS_PER_THREAD; ++i) {
        datas[i] = *(float4*)(oaccum_ptr + lane_idx*4 + i*128); // NOTE We don't use __ldg here since it is incompatible with PDL
    }

    // Warp #i reduces the partial LSEs of head #i (within this CTA's head block)
    // and stores the per-split scale factors into smem_buf[i].
    {
        constexpr int NUM_LSE_PER_THREAD = cute::ceil_div(MAX_SPLITS, 32);
        float local_lse[NUM_LSE_PER_THREAD];
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) {
            const int split_idx = i*32 + lane_idx;
            // Unused split slots get -inf, so they contribute exp2(-inf) = 0.
            local_lse[i] = split_idx < my_num_splits ? gLseAccum(split_idx, warp_idx) : -INFINITY;
        }

        // Warp-wide max via an xor-butterfly; every lane ends up holding the result.
        float max_lse = -INFINITY;
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < NUM_LSE_PER_THREAD; ++i)
            max_lse = max(max_lse, local_lse[i]);
        CUTLASS_PRAGMA_UNROLL
        for (int offset = 16; offset >= 1; offset /= 2)
            max_lse = max(max_lse, __shfl_xor_sync(uint32_t(-1), max_lse, offset));
        max_lse = max_lse == -INFINITY ? 0.0f : max_lse;  // In case all local LSEs are -inf

        // Warp-wide sum of exp2(lse - max); subtracting the max keeps exp2f in range.
        float sum_lse = 0;
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < NUM_LSE_PER_THREAD; ++i)
            sum_lse = sum_lse + exp2f(local_lse[i] - max_lse);
        CUTLASS_PRAGMA_UNROLL
        for (int offset = 16; offset >= 1; offset /= 2)
            sum_lse = sum_lse + __shfl_xor_sync(uint32_t(-1), sum_lse, offset);

        // If there is no mass at all, global_lse becomes +inf, which turns every
        // scale factor below into 0.
        float global_lse = (sum_lse == 0.f || sum_lse == -INFINITY) ? INFINITY : log2f(sum_lse) + max_lse;
        // The LSE written out is taken before the attention sink is folded in;
        // the sink only affects the output scaling below.
        if (lane_idx == 0)
            gLse(warp_idx) = global_lse / CUDART_L2E_F;  // log2 units -> natural-log units

        if (params.attn_sink != nullptr) {
            // Per-query-head sink value in natural-log units (hence the CUDART_L2E_F factor).
            int q_head_idx = h_block_idx*BLOCK_SIZE_M + warp_idx;
            float attn_sink = __ldg(params.attn_sink + q_head_idx);
            if (global_lse != INFINITY) {
                // If attn_sink is +inf, global_lse will be +inf and scale factors will be exp2f(local_lse - inf) = 0 (since local_lse never becomes +inf)
                // If attn_sink is -inf, this has no effect on global_lse
                global_lse += log2f(1.0f + exp2f(attn_sink*CUDART_L2E_F - global_lse));
            } else {
                // We have no tokens to attend, so global lse should be attn_sink*CUDART_L2E_F (+inf if it's -inf or +inf)
                global_lse = attn_sink == -INFINITY ? +INFINITY : attn_sink*CUDART_L2E_F;
            }
        }

        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) {
            const int split_idx = i*32 + lane_idx;
            smem_buf[warp_idx][split_idx] = exp2f(local_lse[i] - global_lse);
        }
    }

    // Make this warp's smem_buf row visible to all of its lanes.
    __syncwarp();

    // Warp #i accumulates the scaled partial outputs of head #i.
    {
        float4 result[ELEMS_PER_THREAD];
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < ELEMS_PER_THREAD; ++i)
            result[i] = {0.0f, 0.0f, 0.0f, 0.0f};

        // On entry to each iteration, datas[] holds split `split`.  As soon as an element
        // has been consumed it is replaced by the same element of split `split+1`, so the
        // next load overlaps with the remaining FMAs.
        #pragma unroll 1
        for (int split = 0; split < my_num_splits; ++split) {
            float lse_scale = smem_buf[warp_idx][split];
            CUTLASS_PRAGMA_UNROLL
            for (int i = 0; i < ELEMS_PER_THREAD; ++i) {
                result[i].x += lse_scale * datas[i].x;
                result[i].y += lse_scale * datas[i].y;
                result[i].z += lse_scale * datas[i].z;
                result[i].w += lse_scale * datas[i].w;
                if (split != my_num_splits-1) {
                    datas[i] = *(float4*)(oaccum_ptr + (split+1)*params.stride_o_accum_split + lane_idx*4 + i*128);
                }
            }
        }

        const int h_q_idx = h_block_idx*BLOCK_SIZE_M + warp_idx;
        ElementT* o_ptr = (ElementT*)params.out + batch_idx*params.stride_o_b + s_q_idx*params.stride_o_s_q + h_q_idx*params.stride_o_h_q;

        // Convert 4 floats to ElementT and store them with a single 8-byte write.
        static_assert(sizeof(ElementT) == 2);
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < ELEMS_PER_THREAD; ++i) {
            float4 data = result[i];
            alignas(8) ElementT data_converted[4];  // aligned for the uint64_t read below
            data_converted[0] = (ElementT)(data.x);
            data_converted[1] = (ElementT)(data.y);
            data_converted[2] = (ElementT)(data.z);
            data_converted[3] = (ElementT)(data.w);
            *(uint64_t*)(o_ptr + lane_idx*4 + i*128) = *(uint64_t*)data_converted;
        }
    }
}


// Dispatches a runtime split count to the smallest supported compile-time bucket
// (32, 64, 96, 128 or 160).
//
// Inside the callable passed as the variadic argument, NAME is visible as a
// `constexpr static int` holding the chosen bucket.  The callable is invoked with
// no arguments, and its result is returned.  NUM_SPLITS is evaluated exactly once.
// Counts above 160 trip FLASH_ASSERT.
#define MLA_NUM_SPLITS_SWITCH(NUM_SPLITS, NAME, ...)                      \
    [&] {                                                                 \
        const auto mla_switch_num_splits_ = (NUM_SPLITS);                 \
        if (mla_switch_num_splits_ <= 32) {                               \
            constexpr static int NAME = 32;                               \
            return __VA_ARGS__();                                         \
        } else if (mla_switch_num_splits_ <= 64) {                        \
            constexpr static int NAME = 64;                               \
            return __VA_ARGS__();                                         \
        } else if (mla_switch_num_splits_ <= 96) {                        \
            constexpr static int NAME = 96;                               \
            return __VA_ARGS__();                                         \
        } else if (mla_switch_num_splits_ <= 128) {                       \
            constexpr static int NAME = 128;                              \
            return __VA_ARGS__();                                         \
        } else if (mla_switch_num_splits_ <= 160) {                       \
            constexpr static int NAME = 160;                              \
            return __VA_ARGS__();                                         \
        } else {                                                          \
            /* No kernel is instantiated for more than 160 splits. */     \
            FLASH_ASSERT(mla_switch_num_splits_ <= 160);                  \
        }                                                                 \
    }()


// Host-side launcher for flash_fwd_mla_combine_kernel.
//
// Picks a compile-time MAX_SPLITS bucket from params.num_sm_parts.  It then launches
// one CTA per (batch, s_q index, block of BLOCK_SIZE_M heads) on params.stream.
// The launch sets cudaLaunchAttributeProgrammaticStreamSerialization, which lets
// the kernel start before the preceding kernel on the stream has fully completed.
// The kernel calls cudaGridDependencySynchronize() before reading that kernel's
// outputs.
//
// Requires params.d_v == 512.  ElementT must be a 2-byte type (static_assert in the kernel).
template<typename ElementT>
void run_flash_mla_combine_kernel(CombineParams &params) {
    static constexpr int HEAD_DIM_V = 512;  // Since only this head dimension is supported by Flash MLA
    FLASH_ASSERT(params.d_v == HEAD_DIM_V);
    MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, NUM_SPLITS, [&] {
        constexpr int BLOCK_SIZE_M = 8;
        constexpr int NUM_THREADS = BLOCK_SIZE_M*32;  // one warp per head
        auto combine_kernel = &flash_fwd_mla_combine_kernel<ElementT, HEAD_DIM_V, BLOCK_SIZE_M, NUM_SPLITS, NUM_THREADS>;
        // The kernel only uses statically sized shared memory (BLOCK_SIZE_M*NUM_SPLITS
        // floats, at most 5 KiB).  No dynamic shared memory is requested, so no
        // cudaFuncAttributeMaxDynamicSharedMemorySize opt-in is needed per launch.

        // Use cudaLaunchKernelEx to enable PDL (Programmatic Dependent Launch)
        cudaLaunchAttribute attribute[1] = {};
        attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
        attribute[0].val.programmaticStreamSerializationAllowed = 1;
        cudaLaunchConfig_t combine_kernel_config = {
            dim3(params.b, params.s_q, ku::ceil_div(params.h_q, BLOCK_SIZE_M)),
            dim3(NUM_THREADS, 1, 1),
            0,              // dynamic shared memory bytes
            params.stream,
            attribute,
            1
        };
        CHECK_CUDA(cudaLaunchKernelEx(&combine_kernel_config, combine_kernel, params));
    });
    CHECK_CUDA_KERNEL_LAUNCH();
}

// Explicit instantiations of the launcher for the supported output element types.
// The fp16 variant can be compiled out with FLASH_MLA_DISABLE_FP16.
template void run_flash_mla_combine_kernel<cutlass::bfloat16_t>(CombineParams &params);

#ifndef FLASH_MLA_DISABLE_FP16
template void run_flash_mla_combine_kernel<cutlass::half_t>(CombineParams &params);
#endif

}