flash_bwd_launch_template.h 53.5 KB
Newer Older
zhangshao's avatar
zhangshao committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/

#pragma once

#include <ATen/cuda/CUDAContext.h>

#include "static_switch.h"
#include "flash.h"
#include "flash_bwd_preprocess_kernel.h"
#include "flash_bwd_kernel.h"

// Determine if the architecture supports FLASH and define a macro to handle parameter modifiers.
// ARCH_SUPPORTS_FLASH gates the real bodies of the kernel wrappers in this file.
// KERNEL_PARAM_MODIFIER is applied to every wrapper's `params` argument by DEFINE_FLASH_BACKWARD_KERNEL.
// It expands to __grid_constant__ only when __CUDA_ARCH__ is defined and >= 800; otherwise
// (including any compilation pass where __CUDA_ARCH__ is undefined) it expands to nothing.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#define ARCH_SUPPORTS_FLASH
#define KERNEL_PARAM_MODIFIER __grid_constant__
#else
#define KERNEL_PARAM_MODIFIER
#endif

// DCU_ASM builds enable the kernel bodies regardless of __CUDA_ARCH__.
// This does not touch KERNEL_PARAM_MODIFIER, which keeps whatever the block above selected.
// NOTE(review): DCU_ASM presumably marks a DCU/HIP build — confirm against the build system.
#if defined(DCU_ASM)
    #define ARCH_SUPPORTS_FLASH
#endif

// Define a macro for unsupported architecture handling to centralize the error message.
// Used as the whole kernel body when ARCH_SUPPORTS_FLASH is not defined.
// The expansion already ends in ';', so call sites write it without a trailing semicolon.
#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashAttention requires building with sm version sm80-sm90, but was built for < 8.0!");

// Use a macro to clean up kernel definitions. DEFINE_FLASH_BACKWARD_KERNEL(name, <extra template params>)
// expands to the declaration
//   template<typename Kernel_traits, <extra template params>>
//   __global__ void name(KERNEL_PARAM_MODIFIER const Flash_bwd_params params)
// The kernel body ({ ... }) follows the macro invocation. `params` is passed by value.
#define DEFINE_FLASH_BACKWARD_KERNEL(kernelName, ...) \
template<typename Kernel_traits, __VA_ARGS__> \
__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_bwd_params params)

// Backward kernel entry point: hands `params` and every template flag, unchanged,
// to flash::compute_dq_dk_dv. On builds without ARCH_SUPPORTS_FLASH it only prints the fatal message.
DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_M, bool Is_even_K) {
#if !defined(ARCH_SUPPORTS_FLASH)
    FLASH_UNSUPPORTED_ARCH
#else
    flash::compute_dq_dk_dv<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K>(params);
#endif
}

// Backward kernel entry point: forwards to flash::compute_dq_dk_dv_seqk_parallel.
// Launched by run_flash_bwd_seqk_parallel. Is_causal and Is_local may not both be set (compile-time check).
DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap) {
#if !defined(ARCH_SUPPORTS_FLASH)
    FLASH_UNSUPPORTED_ARCH
#else
    static_assert(!(Is_causal && Is_local), "Is_causal and Is_local are mutually exclusive: pass Is_local=false for causal attention");
    flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params);
#endif
}

// Backward kernel entry point: forwards to flash::compute_dk_dv_seqk_parallel.
// Launched by run_flash_bwd_separate_seqk_parallel when BWDTRANS is not defined.
// Is_causal and Is_local may not both be set (compile-time check).
DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dk_dv_loop_seqk_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap) {
#if !defined(ARCH_SUPPORTS_FLASH)
    FLASH_UNSUPPORTED_ARCH
#else
    static_assert(!(Is_causal && Is_local), "Is_causal and Is_local are mutually exclusive: pass Is_local=false for causal attention");
    flash::compute_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params);
#endif
}
// Backward kernel entry point: forwards to flash::compute_dk_dv_trans_seqk_parallel.
// Launched by run_flash_bwd_separate_seqk_parallel when BWDTRANS is defined.
// Is_causal and Is_local may not both be set (compile-time check).
DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dk_dv_trans_loop_seqk_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap) {
#if !defined(ARCH_SUPPORTS_FLASH)
    FLASH_UNSUPPORTED_ARCH
#else
    static_assert(!(Is_causal && Is_local), "Is_causal and Is_local are mutually exclusive: pass Is_local=false for causal attention");
    flash::compute_dk_dv_trans_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params);
#endif
}
// Backward kernel entry point: forwards to flash::compute_dk_dv_trans_seqk_parallel_16x64.
// Launched by run_flash_bwd_separate_seqk_parallel_trans. Is_causal and Is_local may not both be set.
DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dk_dv_trans_16x64_loop_seqk_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap) {
#if !defined(ARCH_SUPPORTS_FLASH)
    FLASH_UNSUPPORTED_ARCH
#else
    static_assert(!(Is_causal && Is_local), "Is_causal and Is_local are mutually exclusive: pass Is_local=false for causal attention");
    flash::compute_dk_dv_trans_seqk_parallel_16x64<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params);
#endif
}
// Backward kernel entry point: forwards to flash::compute_dk_dv_trans_16x64_prefetch.
// Launched by run_flash_bwd_separate_prefetch when Kernel_trans_traits::kHeadDim != 256.
// Is_causal and Is_local may not both be set (compile-time check).
DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dk_dv_trans_16x64_prefetch, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap) {
#if !defined(ARCH_SUPPORTS_FLASH)
    FLASH_UNSUPPORTED_ARCH
#else
    static_assert(!(Is_causal && Is_local), "Is_causal and Is_local are mutually exclusive: pass Is_local=false for causal attention");
    flash::compute_dk_dv_trans_16x64_prefetch<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params);
#endif
}
hly's avatar
hly committed
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
// Backward kernel entry point: forwards to flash::compute_dk_trans_16x64_prefetch.
// Launched by run_flash_bwd_separate_prefetch (after the dV kernel) when Kernel_trans_traits::kHeadDim == 256.
// Is_causal and Is_local may not both be set (compile-time check).
DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dk_trans_16x64_prefetch, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap) {
#if !defined(ARCH_SUPPORTS_FLASH)
    FLASH_UNSUPPORTED_ARCH
#else
    static_assert(!(Is_causal && Is_local), "Is_causal and Is_local are mutually exclusive: pass Is_local=false for causal attention");
    flash::compute_dk_trans_16x64_prefetch<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params);
#endif
}
// Backward kernel entry point: forwards to flash::compute_dv_trans_16x64_prefetch.
// Launched by run_flash_bwd_separate_prefetch (before the dK kernel) when Kernel_trans_traits::kHeadDim == 256.
// Is_causal and Is_local may not both be set (compile-time check).
DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dv_trans_16x64_prefetch, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap) {
#if !defined(ARCH_SUPPORTS_FLASH)
    FLASH_UNSUPPORTED_ARCH
#else
    static_assert(!(Is_causal && Is_local), "Is_causal and Is_local are mutually exclusive: pass Is_local=false for causal attention");
    flash::compute_dv_trans_16x64_prefetch<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params);
#endif
}
zhangshao's avatar
zhangshao committed
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
// Backward kernel entry point: forwards to flash::compute_dk_dv_trans_16x64_mla_prefetch.
// Is_causal and Is_local may not both be set (compile-time check).
DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dk_dv_trans_16x64_mla_prefetch, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap) {
#if !defined(ARCH_SUPPORTS_FLASH)
    FLASH_UNSUPPORTED_ARCH
#else
    static_assert(!(Is_causal && Is_local), "Is_causal and Is_local are mutually exclusive: pass Is_local=false for causal attention");
    flash::compute_dk_dv_trans_16x64_mla_prefetch<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params);
#endif
}
// Backward kernel entry point: forwards to flash::compute_dq_seqq_parallel.
// Launched by run_flash_bwd_separate_seqk_parallel. Is_causal and Is_local may not both be set.
DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_loop_seqq_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap) {
#if !defined(ARCH_SUPPORTS_FLASH)
    FLASH_UNSUPPORTED_ARCH
#else
    static_assert(!(Is_causal && Is_local), "Is_causal and Is_local are mutually exclusive: pass Is_local=false for causal attention");
    flash::compute_dq_seqq_parallel<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params);
#endif
}
// Backward kernel entry point: forwards to flash::compute_dq_seqq_parallel_16x64.
// Launched by run_flash_bwd_separate_seqk_parallel_trans. Is_causal and Is_local may not both be set.
DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_loop_16x64_seqq_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap) {
#if !defined(ARCH_SUPPORTS_FLASH)
    FLASH_UNSUPPORTED_ARCH
#else
    static_assert(!(Is_causal && Is_local), "Is_causal and Is_local are mutually exclusive: pass Is_local=false for causal attention");
    flash::compute_dq_seqq_parallel_16x64<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params);
#endif
}
// Backward kernel entry point: forwards to flash::compute_dq_seqq_parallel_16x64_prefetch.
// Launched by run_flash_bwd_separate_prefetch. Is_causal and Is_local may not both be set.
DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_loop_16x64_prefetch_seqq_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap) {
#if !defined(ARCH_SUPPORTS_FLASH)
    FLASH_UNSUPPORTED_ARCH
#else
    static_assert(!(Is_causal && Is_local), "Is_causal and Is_local are mutually exclusive: pass Is_local=false for causal attention");
    flash::compute_dq_seqq_parallel_16x64_prefetch<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params);
#endif
}


// First kernel of every backward launcher in this file: forwards `params` to flash::compute_dot_do_o.
// Clear_dQaccum defaults to true, but each launcher visible here instantiates it explicitly with false.
// Those launchers run it with no dynamic shared memory and Kernel_traits::kNThreads threads per block.
// Unlike the DEFINE_FLASH_BACKWARD_KERNEL wrappers, it has no ARCH_SUPPORTS_FLASH guard.
template<bool Clear_dQaccum=true, typename Kernel_traits>
__global__ void flash_bwd_dot_do_o_kernel(const Flash_bwd_params params) {
    flash::compute_dot_do_o<Clear_dQaccum, Kernel_traits>(params);
}

// Forwards `params` to flash::clear_dKVaccum. No ARCH_SUPPORTS_FLASH guard.
// Not launched by any launcher visible in this section of the file.
template<typename Kernel_traits>
__global__ void flash_bwd_clear_dkvaccum_kernel(const Flash_bwd_params params) {
    flash::clear_dKVaccum<Kernel_traits>(params);
}

// Forwards `params` and `nsplits` to flash::convert_dQ. No ARCH_SUPPORTS_FLASH guard.
// run_flash_bwd_seqk_parallel launches it last, with Kernel_traits::kSmemdQSize bytes of dynamic shared
// memory. It passes nsplits = 1 unless params.deterministic, in which case nsplits is the key-block count.
template<typename Kernel_traits>
__global__ void flash_bwd_convert_dq_kernel(const Flash_bwd_params params, const int nsplits) {
    flash::convert_dQ<Kernel_traits>(params, nsplits);
}

// Forwards `params` to flash::convert_dKV. No ARCH_SUPPORTS_FLASH guard.
// Not launched by any launcher visible in this section of the file.
template<typename Kernel_traits>
__global__ void flash_bwd_convert_dkv_kernel(const Flash_bwd_params params) {
    flash::convert_dKV<Kernel_traits>(params);
}

// Backward pass in which a single kernel produces dQ, dK and dV, parallelised over key blocks.
// Each launch below uses Kernel_traits::kNThreads threads, runs on `stream`, and is followed by
// C10_CUDA_KERNEL_LAUNCH_CHECK(). Launch order:
//   1. flash_bwd_dot_do_o_kernel<false>               grid (m-blocks, b, h), no dynamic smem
//   2. flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel   grid (n-blocks, b, h), smem kSmemSize1colblock
//   3. flash_bwd_convert_dq_kernel                    grid (m-blocks, b, h), smem kSmemdQSize,
//      nsplits = n-blocks when params.deterministic, else 1
// NOTE(review): step 1 always uses Clear_dQaccum=false, so any dQ accumulation buffer presumably
// has to be zero-initialised by the caller — confirm against the API layer.
template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream) {
#ifndef FLASHATTENTION_DISABLE_BACKWARD
    constexpr int kBlockM = Kernel_traits::kBlockM;
    constexpr int kBlockN = Kernel_traits::kBlockN;
    constexpr int kSmemDqDkDv = Kernel_traits::kSmemSize1colblock;

    // Ceil-div block counts along the query (M) and key (N) sequence axes.
    const int m_blocks = (params.seqlen_q + kBlockM - 1) / kBlockM;
    const int n_blocks = (params.seqlen_k + kBlockN - 1) / kBlockN;
    const dim3 grid_rows(m_blocks, params.b, params.h);
    const dim3 grid_cols(n_blocks, params.b, params.h);

    flash_bwd_dot_do_o_kernel<false, Kernel_traits><<<grid_rows, Kernel_traits::kNThreads, 0, stream>>>(params);
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    // Specialise on even M *and* N: when seqlen_k is not a multiple of kBlockN the key loop must mask.
    const bool even_mn = params.cu_seqlens_q == nullptr
                      && params.cu_seqlens_k == nullptr
                      && params.seqlen_q % kBlockM == 0
                      && params.seqlen_k % kBlockN == 0;
    const bool even_k = params.d == Kernel_traits::kHeadDim;

    BOOL_SWITCH(even_mn, IsEvenMNConst, [&] {
        EVENK_SWITCH(even_k, IsEvenKConst, [&] {
            LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
                ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                    SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
                        // Template-count reductions: dropout is disabled together with softcap, local
                        // implies non-causal, and the even-MN fast path additionally needs even K,
                        // no local window and head dim <= 128.
                        constexpr static bool kDropout = Is_dropout && !Is_softcap;
                        constexpr static bool kLocal = Is_local && !Is_causal;
                        constexpr static bool kEvenMN = IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128;

                        auto dq_dk_dv_kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<
                            Kernel_traits, kDropout, Is_causal, kLocal, Has_alibi, kEvenMN, IsEvenKConst, Is_softcap>;
                        dq_dk_dv_kernel<<<grid_cols, Kernel_traits::kNThreads, kSmemDqDkDv, stream>>>(params);
                        C10_CUDA_KERNEL_LAUNCH_CHECK();
                    });
                });
            });
        });
    });

    // In deterministic mode every key block contributes its own split, which convert_dQ reduces.
    const int dq_splits = params.deterministic ? n_blocks : 1;
    auto convert_dq_kernel = &flash_bwd_convert_dq_kernel<Kernel_traits>;
    convert_dq_kernel<<<grid_rows, Kernel_traits::kNThreads, Kernel_traits::kSmemdQSize, stream>>>(params, dq_splits);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
#endif
}


// Backward pass with separate dK/dV and dQ kernels. Launch order on `stream`:
//   1. flash_bwd_dot_do_o_kernel<false>           grid_m = (m-blocks, b, h), no dynamic smem
//   2. dK/dV kernel                               grid_n = (n-blocks, b, h)
//        BWDTRANS defined: flash_bwd_dk_dv_trans_loop_seqk_parallel_kernel<Kernel_trans_traits, ...>,
//                          smem Kernel_trans_traits::kSmemSizeTrans1colblock
//        otherwise:        flash_bwd_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, ...>,
//                          smem Kernel_traits::kSmemSize1colblock
//   3. flash_bwd_dq_loop_seqq_parallel_kernel     grid_m, smem Kernel_traits::kSmemSize1rowblock
// Each launch uses Kernel_traits::kNThreads threads and is followed by C10_CUDA_KERNEL_LAUNCH_CHECK().
// Kernel_trans_traits is only referenced when BWDTRANS is defined.
template<typename Kernel_traits, typename Kernel_trans_traits, bool Is_dropout, bool Is_causal>
void run_flash_bwd_separate_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream) {
#ifndef FLASHATTENTION_DISABLE_BACKWARD
    // Ceil-div count of query-row blocks (Kernel_traits::kBlockM): x extent of the dot(dO,O) and dQ grids.
    const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
    dim3 grid_m(num_m_block, params.b, params.h);
    // Ceil-div count of key-column blocks (Kernel_traits::kBlockN): x extent of the dK/dV grid.
    const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
    // The SM-count-based deterministic grid below is disabled, so gridDimx is always num_n_block.
    int gridDimx = num_n_block;
    // if (params.deterministic) {
    //     auto dprops = at::cuda::getCurrentDeviceProperties();
    //     gridDimx = (dprops->multiProcessorCount + params.b * params.h - 1) / (params.b * params.h);
    // }
    dim3 grid_n(gridDimx, params.b, params.h);
    // printf("run_flash_bwd_seqk_parallel: grid_m=%d, %d, %d, \n", grid_m.x, grid_m.y, grid_m.z);

    // Clear_dQaccum is always false here; the upstream-style true/false split is commented out.
    // if (!params.deterministic) {
    //     flash_bwd_dot_do_o_kernel<true, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);
    // } else {
        
    // }
    flash_bwd_dot_do_o_kernel<false, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    // printf("flash_bwd_dot_do_o_kernel done, params.deterministic=%d, params.seqlen_q=%d, params.seqlen_k=%d, \n", 
    //     params.deterministic, params.seqlen_q, params.seqlen_k);

    // We want to specialize to is_even_MN and not just is_even_M, since in the case where N is not
    // a multiple of kBlockN, we'll need to apply mask in the loop.
    // NOTE(review): is_even_MN is derived from Kernel_traits block sizes only. Under BWDTRANS it also
    // selects the Is_even_MN specialisation of the Kernel_trans_traits dK/dV kernel. The sibling
    // run_flash_bwd_separate_seqk_parallel_trans derives a separate flag from Kernel_trans_traits, so
    // confirm both traits share kBlockM/kBlockN (or use a dedicated flag) before enabling BWDTRANS.
    const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_q % Kernel_traits::kBlockM == 0 && params.seqlen_k % Kernel_traits::kBlockN == 0;
    const bool is_even_K = params.d == Kernel_traits::kHeadDim;
    
    // Dynamic shared memory (bytes) for the dK/dV launch follows whichever traits that kernel is instantiated on.
    #ifdef BWDTRANS
    constexpr int smem_size_dq_dk_dv = Kernel_trans_traits::kSmemSizeTrans1colblock;
    #else
    constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize1colblock;
    #endif
    // Dynamic shared memory (bytes) for the dQ launch.
    constexpr int smem_size_dq = Kernel_traits::kSmemSize1rowblock;
    // printf("smem_size_dq_dk_dv = %d smem_size_dq = %d\n", smem_size_dq_dk_dv, smem_size_dq);
    // printf("run_flash_bwd_seqk_parallel: grid_n=%d, %d, %d, \n", grid_n.x, grid_n.y, grid_n.z);
    BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
        EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
            LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
                ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                    SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
                        // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                        // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
                        // If Is_local, set Is_causal to false
                        // Dropout is also disabled whenever softcap is active (Is_dropout && !Is_softcap).
                        #ifdef BWDTRANS
                        auto kernel = &flash_bwd_dk_dv_trans_loop_seqk_parallel_kernel<
                            Kernel_trans_traits, 
                            Is_dropout && !Is_softcap, Is_causal, 
                            Is_local && !Is_causal, 
                            Has_alibi, 
                            IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, 
                            IsEvenKConst, 
                            Is_softcap>;
                        #else
                        auto kernel = &flash_bwd_dk_dv_loop_seqk_parallel_kernel<
                            Kernel_traits, 
                            Is_dropout && !Is_softcap, Is_causal, 
                            Is_local && !Is_causal, 
                            Has_alibi, 
                            IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, 
                            IsEvenKConst, 
                            Is_softcap>;
                        #endif
                        
                        // NOTE(review): the block size is Kernel_traits::kNThreads even when the kernel is
                        // instantiated on Kernel_trans_traits (BWDTRANS); this assumes equal thread counts.
                        kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
                        C10_CUDA_KERNEL_LAUNCH_CHECK();
                        // dQ is computed by its own kernel over query-row blocks, after dK/dV on the same stream.
                        auto kernel_dq = flash_bwd_dq_loop_seqq_parallel_kernel<
                            Kernel_traits, 
                            Is_dropout && !Is_softcap, Is_causal, 
                            Is_local && !Is_causal, 
                            Has_alibi, 
                            IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, 
                            IsEvenKConst, 
                            Is_softcap>;
                        kernel_dq<<<grid_m, Kernel_traits::kNThreads, smem_size_dq, stream>>>(params);
                        C10_CUDA_KERNEL_LAUNCH_CHECK();

                    });
                });
            });
        });
    });
#endif
}

// Backward pass with separate dK/dV and dQ kernels using the "16x64" variants. Launch order on `stream`:
//   1. flash_bwd_dot_do_o_kernel<false>                       grid (m-blocks, b, h), no dynamic smem
//   2. flash_bwd_dk_dv_trans_16x64_loop_seqk_parallel_kernel  grid (n-blocks, h, b),
//      instantiated on Kernel_trans_traits, smem Kernel_trans_traits::kSmemSizeTrans1colblock
//   3. flash_bwd_dq_loop_16x64_seqq_parallel_kernel           grid (m-blocks, h, b),
//      instantiated on Kernel_traits, smem Kernel_traits::kSmemSize1rowblock
// Block counts come from Kernel_traits::kBlockM / kBlockN. Every launch uses Kernel_traits::kNThreads
// threads and is followed by C10_CUDA_KERNEL_LAUNCH_CHECK(). Each kernel's Is_even_MN specialisation
// is decided from the block sizes of the traits it is instantiated on.
template<typename Kernel_traits, typename Kernel_trans_traits, bool Is_dropout, bool Is_causal>
void run_flash_bwd_separate_seqk_parallel_trans(Flash_bwd_params &params, cudaStream_t stream) {
#ifndef FLASHATTENTION_DISABLE_BACKWARD
    constexpr int kBlockM = Kernel_traits::kBlockM;
    constexpr int kBlockN = Kernel_traits::kBlockN;
    constexpr int kSmemDkDv = Kernel_trans_traits::kSmemSizeTrans1colblock;
    constexpr int kSmemDq = Kernel_traits::kSmemSize1rowblock;

    const int m_blocks = (params.seqlen_q + kBlockM - 1) / kBlockM;
    const int n_blocks = (params.seqlen_k + kBlockN - 1) / kBlockN;
    // Axis order differs: the dot(dO, O) grid is (block, b, h); the dK/dV and dQ grids are (block, h, b).
    const dim3 grid_dot_do_o(m_blocks, params.b, params.h);
    const dim3 grid_dq(m_blocks, params.h, params.b);
    const dim3 grid_dk_dv(n_blocks, params.h, params.b);

    flash_bwd_dot_do_o_kernel<false, Kernel_traits><<<grid_dot_do_o, Kernel_traits::kNThreads, 0, stream>>>(params);
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    // Even-MN requires fixed-length batches and both sequence lengths to be whole multiples of the tile.
    const bool fixed_len = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr;
    const bool even_mn = fixed_len
                      && params.seqlen_q % kBlockM == 0
                      && params.seqlen_k % kBlockN == 0;
    const bool even_mn_trans = fixed_len
                            && params.seqlen_q % Kernel_trans_traits::kBlockM == 0
                            && params.seqlen_k % Kernel_trans_traits::kBlockN == 0;
    const bool even_k = params.d == Kernel_traits::kHeadDim;

    BOOL_SWITCH(even_mn_trans, IsEvenMNTransConst, [&] {
        BOOL_SWITCH(even_mn, IsEvenMNConst, [&] {
            EVENK_SWITCH(even_k, IsEvenKConst, [&] {
                LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
                    ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                        SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
                            // Template-count reductions shared by both kernels: dropout is disabled with
                            // softcap, local implies non-causal, and the even-MN fast path also needs
                            // even K, no local window and head dim <= 128.
                            constexpr static bool kDropout = Is_dropout && !Is_softcap;
                            constexpr static bool kLocal = Is_local && !Is_causal;
                            constexpr static bool kEvenAllowed = IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128;

                            auto dk_dv_kernel = &flash_bwd_dk_dv_trans_16x64_loop_seqk_parallel_kernel<
                                Kernel_trans_traits, kDropout, Is_causal, kLocal, Has_alibi,
                                IsEvenMNTransConst && kEvenAllowed, IsEvenKConst, Is_softcap>;
                            dk_dv_kernel<<<grid_dk_dv, Kernel_traits::kNThreads, kSmemDkDv, stream>>>(params);
                            C10_CUDA_KERNEL_LAUNCH_CHECK();

                            auto dq_kernel = &flash_bwd_dq_loop_16x64_seqq_parallel_kernel<
                                Kernel_traits, kDropout, Is_causal, kLocal, Has_alibi,
                                IsEvenMNConst && kEvenAllowed, IsEvenKConst, Is_softcap>;
                            dq_kernel<<<grid_dq, Kernel_traits::kNThreads, kSmemDq, stream>>>(params);
                            C10_CUDA_KERNEL_LAUNCH_CHECK();
                        });
                    });
                });
            });
        });
    });
#endif
}

hly's avatar
hly committed
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
// Chooses how many (batch, head) pairs are folded onto grid.x for the causal head-dim-64 prefetch path.
// Returns the first of {13, 9, 8, 7, 6, 5, 4, 3, 2}, tried in that order, that divides b_h_num exactly.
// Returns 1 if none does. 10, 11 and 12 are never returned.
// Because the result always divides b_h_num, the caller can split b_h_num / result across grid.z exactly.
static inline int calc_se_balance_cnt(int b_h_num)
{
    static constexpr int kCandidates[] = {13, 9, 8, 7, 6, 5, 4, 3, 2};
    for (const int cnt : kCandidates) {
        if (b_h_num % cnt == 0) {
            return cnt;
        }
    }
    return 1;
}
zhangshao's avatar
zhangshao committed
402
403
404
405
406
407
408
409
410
411
412
413
414

// Backward pass using the "prefetch" dK/dV and dQ kernels. Launch order on `stream`:
//   1. flash_bwd_dot_do_o_kernel<false>   grid (ceil(seqlen_q / kBlockM), b, h), no dynamic smem.
//   2. dK/dV on grid_n with smem_size_dk_dv_total bytes of dynamic shared memory:
//        - Kernel_trans_traits::kHeadDim == 256: flash_bwd_dv_trans_16x64_prefetch, then
//          flash_bwd_dk_trans_16x64_prefetch (two launches);
//        - otherwise: flash_bwd_dk_dv_trans_16x64_prefetch (one launch).
//   3. flash_bwd_dq_loop_16x64_prefetch_seqq_parallel_kernel on grid_m, Kernel_traits::kSmemPrefetchSize bytes.
// Every launch uses Kernel_traits::kNThreads threads and is immediately followed by
// C10_CUDA_KERNEL_LAUNCH_CHECK(), so a failed launch is reported against the kernel that caused it.
//
// Grid layout: grid_m = (num_m_block, h, b), grid_n = (num_n_block, h, b).
//   - Unless NO_CAUSAL_OPT is defined, causal runs with head dim != 64 launch ceil(blocks / 2) CTAs
//     along x. NOTE(review): presumably each CTA then covers a pair of blocks — confirm in the kernels.
//   - Causal runs with head dim 64 write params.se_balance_cnt (see calc_se_balance_cnt). They then use
//     grid = (se_balance_cnt, blocks, b * h / se_balance_cnt) for both grid_m and grid_n.
template<typename Kernel_traits, typename Kernel_trans_traits, bool Is_dropout, bool Is_causal>
void run_flash_bwd_separate_prefetch(Flash_bwd_params &params, cudaStream_t stream) {
#ifndef FLASHATTENTION_DISABLE_BACKWARD
    // Ceil-div block counts along the query (M) and key (N) sequence axes.
    const int full_num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
    const int full_num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
#ifdef NO_CAUSAL_OPT
    const int num_m_block = full_num_m_block;
    const int num_n_block = full_num_n_block;
#else
    constexpr bool kHalveCausalGrid = Is_causal && Kernel_trans_traits::kHeadDim != 64;
    const int num_m_block = kHalveCausalGrid ? (full_num_m_block + 1) >> 1 : full_num_m_block;
    const int num_n_block = kHalveCausalGrid ? (full_num_n_block + 1) >> 1 : full_num_n_block;
#endif
    dim3 grid_m(num_m_block, params.h, params.b);
    // The dot(dO, O) preprocessing grid is never halved and uses (block, b, h) axis order.
    dim3 grid_m_do(full_num_m_block, params.b, params.h);
    dim3 grid_n(num_n_block, params.h, params.b);

    if constexpr (Kernel_trans_traits::kHeadDim == 64 && Is_causal) {
        const int b_h_num = params.b * params.h;
        params.se_balance_cnt = calc_se_balance_cnt(b_h_num);
        // calc_se_balance_cnt returns a divisor of b_h_num, so this split is exact.
        const int b_h_per_x = b_h_num / params.se_balance_cnt;
        grid_n = dim3(params.se_balance_cnt, num_n_block, b_h_per_x);
        grid_m = dim3(params.se_balance_cnt, num_m_block, b_h_per_x);
    }

    flash_bwd_dot_do_o_kernel<false, Kernel_traits><<<grid_m_do, Kernel_traits::kNThreads, 0, stream>>>(params);
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    // Even-MN needs fixed-length batches and both sequence lengths to be multiples of the tile sizes.
    // The dQ kernel tiles with Kernel_traits and the dK/dV kernels with Kernel_trans_traits.
    const bool is_fixed_len = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr;
    const bool is_even_MN = is_fixed_len
                         && params.seqlen_q % Kernel_traits::kBlockM == 0
                         && params.seqlen_k % Kernel_traits::kBlockN == 0;
    const bool is_even_MN_trans = is_fixed_len
                               && params.seqlen_q % Kernel_trans_traits::kBlockM == 0
                               && params.seqlen_k % Kernel_trans_traits::kBlockN == 0;
    const bool is_even_K = params.d == Kernel_traits::kHeadDim;

    // Extra dynamic shared memory (bytes) appended to the dK/dV prefetch buffers for head dims 64 and 128.
    // It is 4096 for head dim 64 and kBlockM * kBlockN otherwise.
    // NOTE(review): presumably dropout scratch, per the name — confirm.
    constexpr int smem_size_dropout = Kernel_trans_traits::kHeadDim == 64 ? 4096 : Kernel_trans_traits::kBlockM * Kernel_trans_traits::kBlockN;
    constexpr int smem_size_dk_dv = Kernel_trans_traits::kSmemPrefetchSize;
    constexpr int smem_size_dk_dv_total = (Kernel_trans_traits::kHeadDim == 128 || Kernel_trans_traits::kHeadDim == 64)
                                              ? (smem_size_dk_dv + smem_size_dropout)
                                              : smem_size_dk_dv;
    constexpr int smem_size_dq = Kernel_traits::kSmemPrefetchSize;

    BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
        EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
            LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
                ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                    SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
                        // Template-count reductions shared by all kernels below: dropout is disabled with
                        // softcap, local implies non-causal, and the even-MN fast path also needs even K,
                        // no local window and head dim <= 128.
                        constexpr static bool kDropout = Is_dropout && !Is_softcap;
                        constexpr static bool kLocal = Is_local && !Is_causal;
                        constexpr static bool kEvenMNAllowed = IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128;

                        BOOL_SWITCH(is_even_MN_trans, IsEvenMNTransConst, [&] {
                            constexpr static bool kEvenMNTrans = IsEvenMNTransConst && kEvenMNAllowed;
                            if constexpr (Kernel_trans_traits::kHeadDim == 256) {
                                // Head dim 256: dV and dK are produced by two separate launches.
                                auto kernel_dv = &flash_bwd_dv_trans_16x64_prefetch<
                                    Kernel_trans_traits, kDropout, Is_causal, kLocal, Has_alibi,
                                    kEvenMNTrans, IsEvenKConst, Is_softcap>;
                                kernel_dv<<<grid_n, Kernel_traits::kNThreads, smem_size_dk_dv_total, stream>>>(params);
                                // Check the dV launch on its own so its failure is not attributed to the dK launch.
                                C10_CUDA_KERNEL_LAUNCH_CHECK();
                                auto kernel_dk = &flash_bwd_dk_trans_16x64_prefetch<
                                    Kernel_trans_traits, kDropout, Is_causal, kLocal, Has_alibi,
                                    kEvenMNTrans, IsEvenKConst, Is_softcap>;
                                kernel_dk<<<grid_n, Kernel_traits::kNThreads, smem_size_dk_dv_total, stream>>>(params);
                                C10_CUDA_KERNEL_LAUNCH_CHECK();
                            } else {
                                auto kernel = &flash_bwd_dk_dv_trans_16x64_prefetch<
                                    Kernel_trans_traits, kDropout, Is_causal, kLocal, Has_alibi,
                                    kEvenMNTrans, IsEvenKConst, Is_softcap>;
                                kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dk_dv_total, stream>>>(params);
                                C10_CUDA_KERNEL_LAUNCH_CHECK();
                            }
                        });

                        auto kernel_dq = &flash_bwd_dq_loop_16x64_prefetch_seqq_parallel_kernel<
                            Kernel_traits, kDropout, Is_causal, kLocal, Has_alibi,
                            IsEvenMNConst && kEvenMNAllowed, IsEvenKConst, Is_softcap>;
                        kernel_dq<<<grid_m, Kernel_traits::kNThreads, smem_size_dq, stream>>>(params);
                        C10_CUDA_KERNEL_LAUNCH_CHECK();
                    });
                });
            });
        });
    });
#endif
}

// Backward pass for the asymmetric-headdim ("MLA") configuration using separate
// dK/dV and dQ kernels (prefetching variants).
//
// Launches three kernels in order on `stream`, checking each launch:
//   1. flash_bwd_dot_do_o_kernel<false, Kernel_traits>
//        grid (ceil(seqlen_q / Kernel_traits::kBlockM), b, h)
//   2. flash_bwd_dk_dv_trans_16x64_mla_prefetch<Kernel_trans_traits, ...>
//        grid (ceil(seqlen_k / Kernel_trans_traits::kBlockN), h, b)
//   3. flash_bwd_dq_loop_16x64_prefetch_seqq_parallel_kernel<Kernel_traits, ...>
//        grid (ceil(seqlen_q / Kernel_traits::kBlockM), h, b)
//
// Kernel_traits configures the preprocess and dQ kernels; Kernel_trans_traits
// configures the dK/dV kernel. Each kernel is launched with the thread count,
// tile size and shared-memory size of the traits it is instantiated with.
// Does nothing when FLASHATTENTION_DISABLE_BACKWARD is defined.
template<typename Kernel_traits, typename Kernel_trans_traits, bool Is_dropout, bool Is_causal>
void run_flash_bwd_separate_mla_prefetch(Flash_bwd_params &params, cudaStream_t stream) {
#ifndef FLASHATTENTION_DISABLE_BACKWARD
    const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
    // The dK/dV kernel tiles seqlen_k by its own kBlockN (see is_even_MN_trans
    // below), so its grid is sized from Kernel_trans_traits, not Kernel_traits.
    const int num_n_block = (params.seqlen_k + Kernel_trans_traits::kBlockN - 1) / Kernel_trans_traits::kBlockN;
    // Note: the preprocess grid orders its y/z dims as (b, h); the other two use (h, b).
    dim3 grid_m(num_m_block, params.h, params.b);
    dim3 grid_m_do(num_m_block, params.b, params.h);
    dim3 grid_n(num_n_block, params.h, params.b);

    flash_bwd_dot_do_o_kernel<false, Kernel_traits><<<grid_m_do, Kernel_traits::kNThreads, 0, stream>>>(params);
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    // "Even" tiling requires no cu_seqlens and sequence lengths that are exact
    // multiples of the block sizes of the traits in question.
    const bool no_cu_seqlens = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr;
    const bool is_even_MN = no_cu_seqlens
        && params.seqlen_q % Kernel_traits::kBlockM == 0
        && params.seqlen_k % Kernel_traits::kBlockN == 0;
    const bool is_even_MN_trans = no_cu_seqlens
        && params.seqlen_q % Kernel_trans_traits::kBlockM == 0
        && params.seqlen_k % Kernel_trans_traits::kBlockN == 0;
    const bool is_even_K = params.d == Kernel_traits::kHeadDim;
    constexpr int smem_size_dk_dv = Kernel_trans_traits::kSmemPrefetchSize;
    constexpr int smem_size_dq = Kernel_traits::kSmemPrefetchSize;
    BOOL_SWITCH(is_even_MN_trans, IsEvenMNTransConst, [&] {
        BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
            EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
                LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
                    ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                        SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
                            // dK/dV pass: one CTA per seqlen_k tile, sized by Kernel_trans_traits.
                            auto kernel = &flash_bwd_dk_dv_trans_16x64_mla_prefetch<
                                Kernel_trans_traits, 
                                Is_dropout && !Is_softcap, Is_causal, 
                                Is_local && !Is_causal, 
                                Has_alibi, 
                                IsEvenMNTransConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, 
                                IsEvenKConst, 
                                Is_softcap>;
                            kernel<<<grid_n, Kernel_trans_traits::kNThreads, smem_size_dk_dv, stream>>>(params);
                            C10_CUDA_KERNEL_LAUNCH_CHECK();
                            // dQ pass: one CTA per seqlen_q tile, sized by Kernel_traits.
                            auto kernel_dq = flash_bwd_dq_loop_16x64_prefetch_seqq_parallel_kernel<
                                Kernel_traits, 
                                Is_dropout && !Is_softcap, Is_causal, 
                                Is_local && !Is_causal, 
                                Has_alibi, 
                                IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, 
                                IsEvenKConst, 
                                Is_softcap>;
                            kernel_dq<<<grid_m, Kernel_traits::kNThreads, smem_size_dq, stream>>>(params);
                            C10_CUDA_KERNEL_LAUNCH_CHECK();
                        });
                    });
                });
            });
        });
    });
#endif
}


template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
    // Single-traits backward entry point: forwards to the seqk-parallel launcher.
    // Becomes a no-op when the backward pass is compiled out.
#ifdef FLASHATTENTION_DISABLE_BACKWARD
    (void)params;
    (void)stream;
#else
    run_flash_bwd_seqk_parallel<Kernel_traits, Is_dropout, Is_causal>(params, stream);
#endif
}

template<typename Kernel_dq_traits, typename Kernel_dkdv_traits, bool Is_dropout, bool Is_causal>
void run_flash_separate_bwd(Flash_bwd_params &params, cudaStream_t stream) {
    // Two-traits backward entry point: Kernel_dq_traits and Kernel_dkdv_traits
    // are passed through unchanged to run_flash_bwd_separate_seqk_parallel.
    // Becomes a no-op when the backward pass is compiled out.
#ifdef FLASHATTENTION_DISABLE_BACKWARD
    (void)params;
    (void)stream;
#else
    run_flash_bwd_separate_seqk_parallel<Kernel_dq_traits, Kernel_dkdv_traits, Is_dropout, Is_causal>(params, stream);
#endif
}

// Backward dispatch for head dimension 32.
//
// The configuration is fixed at compile time by BWDSEPARATE:
//   - defined:   separate dQ (Flash_bwd_kernel_dq_traits) and dK/dV
//                (Flash_bwd_kernel_trans_traits) kernels via run_flash_separate_bwd;
//   - undefined: one Flash_bwd_kernel_traits configuration via run_flash_bwd.
// Dropout is selected at runtime from params.p_dropout via DROPOUT_SWITCH.
// The #ifdef lives outside the DROPOUT_SWITCH macro argument because
// preprocessor directives inside macro arguments are undefined behavior.
template<typename T, bool Is_causal>
void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 32;
#ifdef BWDSEPARATE
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        using kernel_traits = Flash_bwd_kernel_dq_traits<Headdim, /*kBlockM_*/64, /*kBlockN_*/64, /*kNWarps_*/4,
            /*AtomLayoutMSdP_*/4, /*AtomLayoutNdKV*/4, /*AtomLayoutMdQ*/4, /*Is_V_in_regs_*/false, 
            /*No_double_buffer_*/true, /*Is_Q_in_regs_*/false, /*Share_Q_K_smem_*/true,  T>;
        using kernel_trans_traits = Flash_bwd_kernel_trans_traits<Headdim, /*kBlockM_*/32, /*kBlockN_*/64, /*kNWarps_*/4,
            /*AtomLayoutMSdP_*/4, /*AtomLayoutNdKV*/4, /*AtomLayoutMdQ*/1,
            /*Is_V_in_regs_*/true, /*No_double_buffer_*/true, T>;
        run_flash_separate_bwd<kernel_traits, kernel_trans_traits, Is_dropout, Is_causal>(params, stream);
    });
#else
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, /*kBlockM_*/64, /*kBlockN_*/32, /*kNWarps_*/4,
            /*AtomLayoutMSdP_*/2, /*AtomLayoutNdKV*/2, /*AtomLayoutMdQ*/1,
            /*Is_V_in_regs_*/true, /*No_double_buffer_*/true, T>, Is_dropout, Is_causal>(params, stream);
    });
#endif
}

// Backward dispatch for head dimension 64.
//
// On gfx936 / gfx938 / gfx92a this uses the prefetching separate dQ + dK/dV
// kernels (run_flash_bwd_separate_prefetch). On any other device it uses the
// non-prefetching 16x64 kernels (run_flash_bwd_separate_seqk_parallel_trans).
// Dropout is selected at runtime from params.p_dropout via DROPOUT_SWITCH.
template<typename T, bool Is_causal>
void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 64;
    // Query the device name once instead of once per comparison.
    const auto device_name = get_device_name();
    const bool use_prefetch = device_name == "gfx936" || device_name == "gfx938" || device_name == "gfx92a";
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        if (use_prefetch) {
            // The dK/dV tile previously read `Is_dropout ? 128 : 128`; both arms were 128.
            using kernel_trans_traits = Flash_bwd_kernel_trans_16x64_prefetch_traits<Headdim, /*kBlockM_*/64, /*kBlockN_*/128, /*kNWarps_*/4, T, 3>;
            using kernel_traits = Flash_bwd_kernel_dq_16x64_prefetch_traits<Headdim, /*kBlockM_*/128, /*kBlockN_*/64, /*kNWarps_*/4,
                /*AtomLayoutMSdP_*/4, /*AtomLayoutNdKV*/1, /*AtomLayoutMdQ*/4, /*Is_V_in_regs_*/false, 
                /*No_double_buffer_*/true, /*Is_Q_in_regs_*/false, /*Share_Q_K_smem_*/true,  T, 3>;
            run_flash_bwd_separate_prefetch<kernel_traits, kernel_trans_traits, Is_dropout, Is_causal>(params, stream);
        } else {
            using kernel_traits = Flash_bwd_kernel_dq_16x64_traits<Headdim, /*kBlockM_*/64, /*kBlockN_*/64, /*kNWarps_*/4,
                /*AtomLayoutMSdP_*/4, /*AtomLayoutNdKV*/1, /*AtomLayoutMdQ*/4, /*Is_V_in_regs_*/false, 
                /*No_double_buffer_*/true, /*Is_Q_in_regs_*/false, /*Share_Q_K_smem_*/true,  T>;
            using kernel_trans_traits = Flash_bwd_kernel_trans_16x64_traits<Headdim, /*kBlockM_*/64, /*kBlockN_*/128, /*kNWarps_*/4,
                /*AtomLayoutMSdP_*/4, /*AtomLayoutNdKV*/4, /*AtomLayoutMdQ*/1,
                /*Is_V_in_regs_*/true, /*No_double_buffer_*/true, T>;
            run_flash_bwd_separate_seqk_parallel_trans<kernel_traits, kernel_trans_traits, Is_dropout, Is_causal>(params, stream);
        }
    });
}

// Backward dispatch for head dimension 96.
//
// On gfx936 / gfx938 / gfx92a this uses the dim96 prefetching separate kernels
// (run_flash_bwd_separate_prefetch). Elsewhere the configuration is fixed at
// compile time by BWDSEPARATE (separate dQ/dK-dV kernels) or by the single
// Flash_bwd_kernel_traits path.
// Dropout is selected at runtime from params.p_dropout via DROPOUT_SWITCH.
// The #ifdef is kept outside the DROPOUT_SWITCH macro argument, because
// directives inside macro arguments are undefined behavior.
template<typename T, bool Is_causal>
void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 96;
    // Query the device name once instead of once per comparison.
    const auto device_name = get_device_name();
    const bool use_prefetch = device_name == "gfx936" || device_name == "gfx938" || device_name == "gfx92a";
    if (use_prefetch) {
        DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
            using kernel_trans_traits = Flash_bwd_kernel_trans_16x64_prefetch_traits_dim96<Headdim, /*kBlockM_*/64, /*kBlockN_*/64, /*kNWarps_*/4, T, 3>;
            using kernel_traits = Flash_bwd_kernel_dq_16x64_prefetch_traits_dim96<Headdim, /*kBlockM_*/128, /*kBlockN_*/64, /*kNWarps_*/4,
                /*AtomLayoutMSdP_*/4, /*AtomLayoutNdKV*/1, /*AtomLayoutMdQ*/4, T, 3>;
            run_flash_bwd_separate_prefetch<kernel_traits, kernel_trans_traits, Is_dropout, Is_causal>(params, stream);
        });
    } else {
#ifdef BWDSEPARATE
        DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
            using kernel_traits = Flash_bwd_kernel_dq_traits<Headdim, /*kBlockM_*/64, /*kBlockN_*/64, /*kNWarps_*/4,
                /*AtomLayoutMSdP_*/4, /*AtomLayoutNdKV*/4, /*AtomLayoutMdQ*/4, /*Is_V_in_regs_*/false, 
                /*No_double_buffer_*/true, /*Is_Q_in_regs_*/false, /*Share_Q_K_smem_*/true,  T>;
            using kernel_trans_traits = Flash_bwd_kernel_trans_traits<Headdim, /*kBlockM_*/32, /*kBlockN_*/64, /*kNWarps_*/4,
                /*AtomLayoutMSdP_*/4, /*AtomLayoutNdKV*/4, /*AtomLayoutMdQ*/1,
                /*Is_V_in_regs_*/true, /*No_double_buffer_*/true, T>;
            run_flash_separate_bwd<kernel_traits, kernel_trans_traits, Is_dropout, Is_causal>(params, stream);
        });
#else
        DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, /*kBlockM_*/64, /*kBlockN_*/32, /*kNWarps_*/4,
                /*AtomLayoutMSdP_*/2, /*AtomLayoutNdKV*/2, /*AtomLayoutMdQ*/1,
                /*Is_V_in_regs_*/true, /*No_double_buffer_*/true, T>, Is_dropout, Is_causal>(params, stream);
        });
#endif
    }
}

// Backward dispatch for head dimension 128.
//
// On gfx936 / gfx938 / gfx92a this uses the prefetching separate dQ + dK/dV
// kernels. The dQ tile height is 64 for causal attention with dropout and 128
// otherwise. On any other device it uses the non-prefetching 16x64 kernels,
// with a 64-row dQ tile for bfloat16 and a 128-row tile for other element types.
// Dropout is selected at runtime from params.p_dropout via DROPOUT_SWITCH.
template<typename T, bool Is_causal>
void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 128;
    // Query the device name once instead of once per comparison.
    const auto device_name = get_device_name();
    const bool use_prefetch = device_name == "gfx936" || device_name == "gfx938" || device_name == "gfx92a";
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        if (use_prefetch) {
            using kernel_trans_traits = Flash_bwd_kernel_trans_16x64_prefetch_traits<Headdim, /*kBlockM_*/64, /*kBlockN_*/64, /*kNWarps_*/4, T, 3>;
            constexpr int kBlockM_dq = (Is_dropout && Is_causal) ? 64 : 128;
            using kernel_traits = Flash_bwd_kernel_dq_16x64_prefetch_traits<Headdim, /*kBlockM_*/kBlockM_dq, /*kBlockN_*/64, /*kNWarps_*/4,
                /*AtomLayoutMSdP_*/4, /*AtomLayoutNdKV*/1, /*AtomLayoutMdQ*/4, /*Is_V_in_regs_*/false, 
                /*No_double_buffer_*/true, /*Is_Q_in_regs_*/false, /*Share_Q_K_smem_*/true,  T, 3>;
            run_flash_bwd_separate_prefetch<kernel_traits, kernel_trans_traits, Is_dropout, Is_causal>(params, stream);
        } else {
            using kernel_trans_traits = Flash_bwd_kernel_trans_16x64_traits<Headdim, /*kBlockM_*/64, /*kBlockN_*/64, /*kNWarps_*/4,
                /*AtomLayoutMSdP_*/4, /*AtomLayoutNdKV*/4, /*AtomLayoutMdQ*/1,
                /*Is_V_in_regs_*/true, /*No_double_buffer_*/true, T>;
            // Only the dQ tile height differs between bfloat16 and other element types.
            constexpr int kBlockM_dq = std::is_same_v<T, cutlass::bfloat16_t> ? 64 : 128;
            using kernel_traits = Flash_bwd_kernel_dq_16x64_traits<Headdim, /*kBlockM_*/kBlockM_dq, /*kBlockN_*/64, /*kNWarps_*/4,
                /*AtomLayoutMSdP_*/4, /*AtomLayoutNdKV*/1, /*AtomLayoutMdQ*/4, /*Is_V_in_regs_*/false, 
                /*No_double_buffer_*/true, /*Is_Q_in_regs_*/false, /*Share_Q_K_smem_*/true,  T>;
            run_flash_bwd_separate_seqk_parallel_trans<kernel_traits, kernel_trans_traits, Is_dropout, Is_causal>(params, stream);
        }
    });
}

// Backward dispatch for head dimension 160.
//
// The configuration is fixed at compile time by BWDSEPARATE:
//   - defined:   separate dQ (Flash_bwd_kernel_dq_traits, 2 warps) and dK/dV
//                (Flash_bwd_kernel_trans_traits) kernels via run_flash_separate_bwd;
//   - undefined: one Flash_bwd_kernel_traits configuration via run_flash_bwd.
// Dropout is selected at runtime from params.p_dropout via DROPOUT_SWITCH.
// The #ifdef is kept outside the DROPOUT_SWITCH macro argument, because
// directives inside macro arguments are undefined behavior.
template<typename T, bool Is_causal>
void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 160;
#ifdef BWDSEPARATE
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        using kernel_traits = Flash_bwd_kernel_dq_traits<Headdim, /*kBlockM_*/32, /*kBlockN_*/32, /*kNWarps_*/2,
            /*AtomLayoutMSdP_*/2, /*AtomLayoutNdKV*/2, /*AtomLayoutMdQ*/2, /*Is_V_in_regs_*/false, 
            /*No_double_buffer_*/true, /*Is_Q_in_regs_*/false, /*Share_Q_K_smem_*/true, T>;
        using kernel_trans_traits = Flash_bwd_kernel_trans_traits<Headdim, /*kBlockM_*/32, /*kBlockN_*/64, /*kNWarps_*/4,
            /*AtomLayoutMSdP_*/4, /*AtomLayoutNdKV*/4, /*AtomLayoutMdQ*/1,
            /*Is_V_in_regs_*/true, /*No_double_buffer_*/true, T>;
        run_flash_separate_bwd<kernel_traits, kernel_trans_traits, Is_dropout, Is_causal>(params, stream);
    });
#else
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, /*kBlockM_*/64, /*kBlockN_*/32, /*kNWarps_*/4,
            /*AtomLayoutMSdP_*/2, /*AtomLayoutNdKV*/2, /*AtomLayoutMdQ*/4,
            /*Is_V_in_regs_*/true, /*No_double_buffer_*/true, T>, Is_dropout, Is_causal>(params, stream);
    });
#endif
}

// Backward dispatch for the asymmetric configuration named "hdim192_hdim128":
// Headdim 192, plus a trailing template argument of 128 on both traits
// (presumably the V head dimension — TODO confirm against the traits).
//
// Only the prefetching MLA kernels (run_flash_bwd_separate_mla_prefetch)
// implement this configuration, and they are only selected on gfx936 / gfx938 /
// gfx92a. Any other device previously fell through an empty branch: no kernel
// was launched and the call returned silently without computing gradients.
// It now fails with TORCH_CHECK instead.
// Dropout is selected at runtime from params.p_dropout via DROPOUT_SWITCH.
template<typename T, bool Is_causal>
void run_mha_bwd_hdim192_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 192;
    // Query the device name once instead of once per comparison.
    const auto device_name = get_device_name();
    const bool use_prefetch = device_name == "gfx936" || device_name == "gfx938" || device_name == "gfx92a";
    TORCH_CHECK(use_prefetch,
                "FlashAttention backward with head dim 192/128 is only supported on gfx936, gfx938 and gfx92a, got ",
                device_name);
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        using kernel_trans_traits = Flash_bwd_kernel_trans_16x64_prefetch_mla_traits<Headdim, /*kBlockM_*/64, /*kBlockN_*/64, /*kNWarps_*/4, T, 3, 128>;
        using kernel_traits = Flash_bwd_kernel_dq_16x64_prefetch_traits<Headdim, /*kBlockM_*/64, /*kBlockN_*/64, /*kNWarps_*/4,
            /*AtomLayoutMSdP_*/4, /*AtomLayoutNdKV*/1, /*AtomLayoutMdQ*/4, /*Is_V_in_regs_*/false, 
            /*No_double_buffer_*/true, /*Is_Q_in_regs_*/false, /*Share_Q_K_smem_*/true,  T, 3, 128>;
        run_flash_bwd_separate_mla_prefetch<kernel_traits, kernel_trans_traits, Is_dropout, Is_causal>(params, stream);
    });
}

// Backward dispatch for head dimension 192.
//
// The configuration is fixed at compile time by BWDSEPARATE:
//   - defined:   separate dQ (Flash_bwd_kernel_dq_traits, 2 warps) and dK/dV
//                (Flash_bwd_kernel_trans_traits) kernels via run_flash_separate_bwd;
//   - undefined: one Flash_bwd_kernel_traits configuration via run_flash_bwd.
// Dropout is selected at runtime from params.p_dropout via DROPOUT_SWITCH.
// The #ifdef is kept outside the DROPOUT_SWITCH macro argument, because
// directives inside macro arguments are undefined behavior.
template<typename T, bool Is_causal>
void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 192;
#ifdef BWDSEPARATE
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        using kernel_traits = Flash_bwd_kernel_dq_traits<Headdim, /*kBlockM_*/32, /*kBlockN_*/32, /*kNWarps_*/2,
            /*AtomLayoutMSdP_*/2, /*AtomLayoutNdKV*/1, /*AtomLayoutMdQ*/2, /*Is_V_in_regs_*/false, 
            /*No_double_buffer_*/true, /*Is_Q_in_regs_*/false, /*Share_Q_K_smem_*/true,  T>;
        using kernel_trans_traits = Flash_bwd_kernel_trans_traits<Headdim, /*kBlockM_*/32, /*kBlockN_*/64, /*kNWarps_*/4,
            /*AtomLayoutMSdP_*/4, /*AtomLayoutNdKV*/4, /*AtomLayoutMdQ*/1,
            /*Is_V_in_regs_*/true, /*No_double_buffer_*/true, T>;
        run_flash_separate_bwd<kernel_traits, kernel_trans_traits, Is_dropout, Is_causal>(params, stream);
    });
#else
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, /*kBlockM_*/64, /*kBlockN_*/32, /*kNWarps_*/4,
            /*AtomLayoutMSdP_*/2, /*AtomLayoutNdKV*/2, /*AtomLayoutMdQ*/4,
            /*Is_V_in_regs_*/true, /*No_double_buffer_*/true, T>, Is_dropout, Is_causal>(params, stream);
    });
#endif
}

// Backward dispatch for head dimension 224.
//
// The configuration is fixed at compile time by BWDSEPARATE:
//   - defined:   separate dQ (Flash_bwd_kernel_dq_traits, 2 warps) and dK/dV
//                (Flash_bwd_kernel_trans_traits) kernels via run_flash_separate_bwd;
//   - undefined: one Flash_bwd_kernel_traits configuration via run_flash_bwd.
// Dropout is selected at runtime from params.p_dropout via DROPOUT_SWITCH.
// The #ifdef is kept outside the DROPOUT_SWITCH macro argument, because
// directives inside macro arguments are undefined behavior.
template<typename T, bool Is_causal>
void run_mha_bwd_hdim224(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 224;
#ifdef BWDSEPARATE
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        using kernel_traits = Flash_bwd_kernel_dq_traits<Headdim, /*kBlockM_*/32, /*kBlockN_*/32, /*kNWarps_*/2,
            /*AtomLayoutMSdP_*/2, /*AtomLayoutNdKV*/2, /*AtomLayoutMdQ*/2, /*Is_V_in_regs_*/false, 
            /*No_double_buffer_*/true, /*Is_Q_in_regs_*/false, /*Share_Q_K_smem_*/true,  T>;
        using kernel_trans_traits = Flash_bwd_kernel_trans_traits<Headdim, /*kBlockM_*/32, /*kBlockN_*/64, /*kNWarps_*/4,
            /*AtomLayoutMSdP_*/4, /*AtomLayoutNdKV*/4, /*AtomLayoutMdQ*/1,
            /*Is_V_in_regs_*/true, /*No_double_buffer_*/true, T>;
        run_flash_separate_bwd<kernel_traits, kernel_trans_traits, Is_dropout, Is_causal>(params, stream);
    });
#else
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, /*kBlockM_*/64, /*kBlockN_*/32, /*kNWarps_*/4,
            /*AtomLayoutMSdP_*/2, /*AtomLayoutNdKV*/2, /*AtomLayoutMdQ*/4,
            /*Is_V_in_regs_*/true, /*No_double_buffer_*/true, T>, Is_dropout, Is_causal>(params, stream);
    });
#endif
}

// Backward dispatch for head dimension 256.
//
// On gfx936 / gfx938 / gfx92a this uses the dim256 prefetching separate kernels
// (run_flash_bwd_separate_prefetch). Elsewhere the configuration is fixed at
// compile time by BWDSEPARATE (separate dQ/dK-dV kernels) or by the single
// Flash_bwd_kernel_traits path.
// Dropout is selected at runtime from params.p_dropout via DROPOUT_SWITCH.
// The #ifdef is kept outside the DROPOUT_SWITCH macro argument, because
// directives inside macro arguments are undefined behavior.
template<typename T, bool Is_causal>
void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 256;
    // Query the device name once instead of once per comparison.
    const auto device_name = get_device_name();
    const bool use_prefetch = device_name == "gfx936" || device_name == "gfx938" || device_name == "gfx92a";
    if (use_prefetch) {
        DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
            using kernel_traits = Flash_bwd_kernel_dq_16x64_prefetch_traits_dim256<Headdim, 64, 64, 4, T, 3>;
            using kernel_trans_traits = Flash_bwd_kernel_trans_16x64_prefetch_traits_dim256<Headdim, 64, 64, 4, T, 3>;
            run_flash_bwd_separate_prefetch<kernel_traits, kernel_trans_traits, Is_dropout, Is_causal>(params, stream);
        });
    } else {
#ifdef BWDSEPARATE
        DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
            using kernel_traits = Flash_bwd_kernel_dq_traits<Headdim, /*kBlockM_*/32, /*kBlockN_*/16, /*kNWarps_*/2,
                /*AtomLayoutMSdP_*/2, /*AtomLayoutNdKV*/1, /*AtomLayoutMdQ*/2, /*Is_V_in_regs_*/false, 
                /*No_double_buffer_*/true, /*Is_Q_in_regs_*/false, /*Share_Q_K_smem_*/true,  T>;
            using kernel_trans_traits = Flash_bwd_kernel_trans_traits<Headdim, /*kBlockM_*/32, /*kBlockN_*/64, /*kNWarps_*/4,
                /*AtomLayoutMSdP_*/4, /*AtomLayoutNdKV*/4, /*AtomLayoutMdQ*/1,
                /*Is_V_in_regs_*/true, /*No_double_buffer_*/true, T>;
            run_flash_separate_bwd<kernel_traits, kernel_trans_traits, Is_dropout, Is_causal>(params, stream);
        });
#else
        DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, /*kBlockM_*/64, /*kBlockN_*/32, /*kNWarps_*/4,
                /*AtomLayoutMSdP_*/2, /*AtomLayoutNdKV*/2, /*AtomLayoutMdQ*/4,
                /*Is_V_in_regs_*/true, /*No_double_buffer_*/true, T>, Is_dropout, Is_causal>(params, stream);
        });
#endif
    }
}

// Backward dispatch for head dimension 512.
//
// On gfx936 / gfx938 / gfx92a this uses the dim512 prefetching separate kernels
// (run_flash_bwd_separate_prefetch). Elsewhere the configuration is fixed at
// compile time by BWDSEPARATE (separate dQ/dK-dV kernels) or by the single
// Flash_bwd_kernel_traits path.
// Dropout is selected at runtime from params.p_dropout via DROPOUT_SWITCH.
// The #ifdef is kept outside the DROPOUT_SWITCH macro argument, because
// directives inside macro arguments are undefined behavior.
template<typename T, bool Is_causal>
void run_mha_bwd_hdim512(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 512;
    // Query the device name once instead of once per comparison.
    const auto device_name = get_device_name();
    const bool use_prefetch = device_name == "gfx936" || device_name == "gfx938" || device_name == "gfx92a";
    if (use_prefetch) {
        DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
            using kernel_traits = Flash_bwd_kernel_dq_16x64_prefetch_traits_dim512<Headdim, 64, 64, 4, T, 3>;
            using kernel_trans_traits = Flash_bwd_kernel_trans_16x64_prefetch_traits_dim512<Headdim, 64, 64, 4, T, 3>;
            run_flash_bwd_separate_prefetch<kernel_traits, kernel_trans_traits, Is_dropout, Is_causal>(params, stream);
        });
    } else {
#ifdef BWDSEPARATE
        DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
            using kernel_traits = Flash_bwd_kernel_dq_traits<Headdim, /*kBlockM_*/32, /*kBlockN_*/16, /*kNWarps_*/2,
                /*AtomLayoutMSdP_*/2, /*AtomLayoutNdKV*/1, /*AtomLayoutMdQ*/2, /*Is_V_in_regs_*/false, 
                /*No_double_buffer_*/true, /*Is_Q_in_regs_*/false, /*Share_Q_K_smem_*/true,  T>;
            using kernel_trans_traits = Flash_bwd_kernel_trans_traits<Headdim, /*kBlockM_*/32, /*kBlockN_*/64, /*kNWarps_*/4,
                /*AtomLayoutMSdP_*/4, /*AtomLayoutNdKV*/4, /*AtomLayoutMdQ*/1,
                /*Is_V_in_regs_*/true, /*No_double_buffer_*/true, T>;
            run_flash_separate_bwd<kernel_traits, kernel_trans_traits, Is_dropout, Is_causal>(params, stream);
        });
#else
        DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, /*kBlockM_*/64, /*kBlockN_*/32, /*kNWarps_*/4,
                /*AtomLayoutMSdP_*/2, /*AtomLayoutNdKV*/2, /*AtomLayoutMdQ*/4,
                /*Is_V_in_regs_*/true, /*No_double_buffer_*/true, T>, Is_dropout, Is_causal>(params, stream);
        });
#endif
    }
}