/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#pragma once

#include <ATen/cuda/CUDAContext.h>

#include "static_switch.h"
#include "flash.h"
#include "flash_bwd_preprocess_kernel.h"
#include "flash_bwd_kernel.h"

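// Host-side launch templates for the FlashAttention backward pass: thin
// __global__ wrappers around the device code in flash_bwd_preprocess_kernel.h
// and flash_bwd_kernel.h, plus per-head-dim run_mha_bwd_hdim* launchers that
// pick kernel traits based on the shared memory available on the device.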
// Determine if the architecture supports FLASH and define a macro to handle parameter modifiers
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#define ARCH_SUPPORTS_FLASH
#define KERNEL_PARAM_MODIFIER __grid_constant__
#else
#define KERNEL_PARAM_MODIFIER
#endif

// Define a macro for unsupported architecture handling to centralize the error message
#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashAttention requires building with sm version sm80-sm90, but was built for < 8.0!");

// Use a macro to clean up kernel definitions
#define DEFINE_FLASH_BACKWARD_KERNEL(kernelName, ...) \
template<typename Kernel_traits, __VA_ARGS__> \
__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_bwd_params params)

DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_M, bool Is_even_K) {
    #if defined(ARCH_SUPPORTS_FLASH)
       flash::compute_dq_dk_dv<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K>(params);
    #else
        FLASH_UNSUPPORTED_ARCH
    #endif
}

DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K) {
    #if defined(ARCH_SUPPORTS_FLASH)
        static_assert(!(Is_causal && Is_local));  // If Is_local is true, Is_causal should be false
        flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K>(params);
    #else
        FLASH_UNSUPPORTED_ARCH
    #endif
}


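// The kernels below are thin wrappers around the preprocessing and conversion
// helpers (compute_dot_do_o, clear_dKVaccum, convert_dQ, convert_dKV) so they
// can be launched with the usual <<<grid, block, smem, stream>>> syntax.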
template<bool Clear_dQaccum=true, typename Kernel_traits>
__global__ void flash_bwd_dot_do_o_kernel(const Flash_bwd_params params) {
    flash::compute_dot_do_o<Clear_dQaccum, Kernel_traits>(params);
}

template<typename Kernel_traits>
__global__ void flash_bwd_clear_dkvaccum_kernel(const Flash_bwd_params params) {
    flash::clear_dKVaccum<Kernel_traits>(params);
}

template<typename Kernel_traits>
__global__ void flash_bwd_convert_dq_kernel(const Flash_bwd_params params, const int nsplits) {
    flash::convert_dQ<Kernel_traits>(params, nsplits);
}

template<typename Kernel_traits>
__global__ void flash_bwd_convert_dkv_kernel(const Flash_bwd_params params) {
    flash::convert_dKV<Kernel_traits>(params);
}

template<typename Kernel_traits, bool Is_dropout>
void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream) {
    const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
    dim3 grid_m(num_m_block, params.b, params.h);
    const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
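    // In deterministic mode, cap the n-block grid dimension at roughly
    // #SMs / (batch * heads) instead of num_n_block, presumably so that dQaccum
    // is accumulated by a fixed, reproducible set of blocks; the number of
    // accumulator splits is later passed to the convert-dQ kernel.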
    int gridDimx = num_n_block;
    if (params.deterministic) {
        auto dprops = at::cuda::getCurrentDeviceProperties();
        gridDimx = (dprops->multiProcessorCount + params.b * params.h - 1) / (params.b * params.h);
    }
    dim3 grid_n(gridDimx, params.b, params.h);

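    // Clear_dQaccum=true makes the dot_do_o preprocess kernel also zero dQaccum;
    // in deterministic mode the split accumulator is assumed to be zero-initialized
    // elsewhere, so the clearing is skipped here.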
    if (!params.deterministic) {
        flash_bwd_dot_do_o_kernel<true, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);
    } else {
        flash_bwd_dot_do_o_kernel<false, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);
    }
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    // We want to specialize to is_even_MN and not just is_even_M, since in the case where N is not
    // a multiple of kBlockN, we'll need to apply a mask in the loop.
    const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_q % Kernel_traits::kBlockM == 0 && params.seqlen_k % Kernel_traits::kBlockN == 0;
    const bool is_even_K = params.d == Kernel_traits::kHeadDim;
    constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize1colblock;
    // printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv);
    BOOL_SWITCH(params.is_causal, Is_causal, [&] {
        BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
            EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
                LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
                    ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                        // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce the number of templates.
                        // If head dim > 128, set IsEvenMNConst to false to reduce the number of templates.
                        // If Is_local, set Is_causal to false.
                        auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst>;
                        // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, false, Is_causal, false, false, true, true>;
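                        // Requesting more than 48 KB of dynamic shared memory requires
                        // an explicit opt-in via cudaFuncSetAttribute.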
                        if (smem_size_dq_dk_dv >= 48 * 1024)  {
                            C10_CUDA_CHECK(cudaFuncSetAttribute(
                                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
                        }
                        kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
                        C10_CUDA_KERNEL_LAUNCH_CHECK();
                    });
                });
            });
        });
    });

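    // Convert the fp32 dQaccum buffer into dQ in the output dtype. The second
    // argument is the number of accumulator splits to reduce over: 1 normally,
    // gridDimx in deterministic mode.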
    auto kernel_dq = &flash_bwd_convert_dq_kernel<Kernel_traits>;
    if (Kernel_traits::kSmemdQSize >= 48 * 1024)  {
        C10_CUDA_CHECK(cudaFuncSetAttribute(
            kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize));
    }
    kernel_dq<<<grid_m, Kernel_traits::kNThreads, Kernel_traits::kSmemdQSize, stream>>>(params, !params.deterministic ? 1 : gridDimx);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
}

template<typename Kernel_traits, bool Is_dropout>
void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
#ifndef FLASHATTENTION_DISABLE_BACKWARD
    run_flash_bwd_seqk_parallel<Kernel_traits, Is_dropout>(params, stream);
#endif
}

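// The per-head-dim launchers below query the device's opt-in shared-memory
// limit (cudaDevAttrMaxSharedMemoryPerBlockOptin) and pick block sizes and
// kernel-traits flags (e.g. keeping V in registers, disabling double buffering)
// that fit within it.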
template<typename T>
void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 32;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    if (status_ != cudaSuccess) {
      C10_CUDA_CHECK(status_);
    }
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
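        // 2 * ((3*128 + 2*128) * 32 + 2*128*128) = 106496 bytes, i.e. the 104 KB
        // noted below.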
        if (max_smem_per_block >= 2 * ((3 * 128 + 2 * 128) * Headdim + 2 * 128 * 128)) { // 104 KB
            if constexpr(!Is_dropout) {  // We can afford more registers to keep V in registers
                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream);
            } else {
                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
            }
        } else {  // 96 KB
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream);
        }
    });
}

template<typename T>
void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 64;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    if (status_ != cudaSuccess) {
      C10_CUDA_CHECK(status_);
    }
    // printf("max_smem_per_block = %d\n", max_smem_per_block);
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        // Changing AtomLayoutMdQ from 2 to 4 takes the same time
        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>>(params, stream);
        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>>(params, stream);
        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 2, 4, 4, false, false, T>>(params, stream);
        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream);
        // This is slightly faster. We want to split M more so we need fewer registers to store LSE.
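        // 144 KB of dynamic smem is available (after opt-in) on A100 (~163 KB)
        // and H100 (~227 KB), but not on sm86/sm89 (99 KB).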
        if (max_smem_per_block >= 144 * 1024) {
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
            // This has a lot of register spilling
            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream);
        } else {
            // if (params.h == params.h_k) {
                // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream);
                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream);
                // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream);
                // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>, Is_dropout>(params, stream);
            // } else {
            // }
        }
    });
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>>(params, stream);
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, 2, 2, 2, true, false, T>>(params, stream);
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 4, 1, 4, 1, false, false, T>>(params, stream);
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 16, 128, 4, 1, 4, 1, false, false, T>>(params, stream);
    // M=128, N=64 is quite slow, I think because we need to read/write dQaccum twice as many times
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 2, 2, 2, false, T>>(params, stream);
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream);
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, false, T>>(params, stream);

    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 4, 4, 2, 4, false, false, T>>(params, stream);
}

template<typename T>
void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 96;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    if (status_ != cudaSuccess) {
      C10_CUDA_CHECK(status_);
    }
    // printf("max_smem_per_block = %d\n", max_smem_per_block);
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        if (max_smem_per_block >= 116 * 1024) {
            if constexpr(!Is_dropout) {  // 92KB
                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream);
            } else {  // 116 KB
                // This is faster for dropout since we don't have many registers to spare
                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream);
            }
        } else {
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream);
        }
    });
}

template<typename T>
void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 128;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    if (status_ != cudaSuccess) {
      C10_CUDA_CHECK(status_);
    }
    // printf("max_smem_per_block = %d\n", max_smem_per_block);
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 8, 2, 2, 2, false, false, T>>(params, stream);
        // This is faster, in the case of sequence-parallel bwd (where we need fewer registers).
        // Out of these three, the 2nd one is slightly faster (2% faster than the first). Not sure why.
        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 2, 2, false, false, T>>(params, stream);
        if (max_smem_per_block >= 144 * 1024) {
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>, Is_dropout>(params, stream);
            // run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
            // run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream);
            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>, Is_dropout>(params, stream);
            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream);
        } else {
            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream);
        }
        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>>(params, stream);

        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 4, 4, false, false, T>>(params, stream);
    });
}

template<typename T>
void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 160;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    if (status_ != cudaSuccess) {
      C10_CUDA_CHECK(status_);
    }
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        if (max_smem_per_block >= 116 * 1024) {
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
        } else {
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream);
        }
    });
}

template<typename T>
void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 192;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    if (status_ != cudaSuccess) {
      C10_CUDA_CHECK(status_);
    }
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        if (max_smem_per_block >= 136 * 1024) {
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
        } else {
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, true, T>, Is_dropout>(params, stream);
        }
    });
}

template<typename T>
void run_mha_bwd_hdim224(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 224;
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
    });
}

template<typename T>
void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 256;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    if (status_ != cudaSuccess) {
      C10_CUDA_CHECK(status_);
    }
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        if (max_smem_per_block >= 176 * 1024) {  // H100
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
        } else if (max_smem_per_block >= 144 * 1024) {  // A100, we don't do double buffering to save smem
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_dropout>(params, stream);
        } else { // sm86 and sm89, max smem is 99 KB. Only works without dropout. V in regs and no double buffering.
            if constexpr (!Is_dropout) {
                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 32, 8, 4, 1, 2, true, true, T>, false>(params, stream);
            }
        }
    });
}