// flash_bwd_launch_template.h
// Copyright (c) 2023, Tri Dao.

#pragma once

#include <ATen/cuda/CUDAContext.h>

#include "static_switch.h"
#include "flash.h"
#include "flash_bwd_kernel.h"

// Global kernel wrapper: forwards `params` unchanged to flash::compute_dot_do_o,
// instantiated with the same Clear_dQaccum flag and Kernel_traits.
template <bool Clear_dQaccum = true, typename Kernel_traits>
__global__ void
flash_bwd_dot_do_o_kernel(Flash_bwd_params params)
{
    flash::compute_dot_do_o<Clear_dQaccum, Kernel_traits>(params);
}

// Global kernel wrapper: forwards `params` unchanged to flash::clear_dKVaccum
// for the given Kernel_traits.
template <typename Kernel_traits>
__global__ void
flash_bwd_clear_dkvaccum_kernel(Flash_bwd_params params)
{
    flash::clear_dKVaccum<Kernel_traits>(params);
}

// Global kernel wrapper: forwards `params` to flash::compute_dq_dk_dv with the same
// compile-time flags (Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K).
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_M, bool Is_even_K>
__global__ void flash_bwd_dq_dk_dv_loop_kernel(Flash_bwd_params params) {
    flash::compute_dq_dk_dv<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K>(params);
}

// Global kernel wrapper: forwards `params` to flash::compute_dq_dk_dv_seqk_parallel with the
// same compile-time flags. Is_causal and Is_local are mutually exclusive; this is enforced at
// compile time.
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K>
__global__ void flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel(Flash_bwd_params params) {
    static_assert(!(Is_causal && Is_local));  // If Is_local is true, Is_causal should be false
    flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K>(params);
}

// Global kernel wrapper: forwards `params` to flash::compute_dq_dk_dv_seqq_parallel with the
// same compile-time flags (Is_dropout, Is_causal, Has_alibi, Is_even_N, Is_even_K).
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_N, bool Is_even_K>
__global__ void flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel(Flash_bwd_params params) {
    flash::compute_dq_dk_dv_seqq_parallel<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_N, Is_even_K>(params);
}

// Global kernel wrapper: forwards `params` unchanged to flash::convert_dQ
// for the given Kernel_traits.
template <typename Kernel_traits>
__global__ void
flash_bwd_convert_dq_kernel(Flash_bwd_params params)
{
    flash::convert_dQ<Kernel_traits>(params);
}

// Global kernel wrapper: forwards `params` unchanged to flash::convert_dKV
// for the given Kernel_traits.
template <typename Kernel_traits>
__global__ void
flash_bwd_convert_dkv_kernel(Flash_bwd_params params)
{
    flash::convert_dKV<Kernel_traits>(params);
}

// Enqueues the seqlen_k-parallel backward pass on `stream` as three kernel launches:
//   1. flash_bwd_dot_do_o_kernel<true, ...> on grid_m = (ceil(seqlen_q / kBlockM), b, h);
//   2. flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel on grid_n = (ceil(seqlen_k / kBlockN), b, h),
//      specialized on causal / local / alibi / even-MN / even-K flags derived from `params`;
//   3. flash_bwd_convert_dq_kernel on grid_m.
// Each launch is followed by C10_CUDA_KERNEL_LAUNCH_CHECK(). Kernels that need 48 KB or more of
// dynamic shared memory are opted in via cudaFuncSetAttribute before launch.
// `configure` is part of the shared launcher signature and is not read in this function.
template<typename Kernel_traits, bool Is_dropout>
void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
    const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
    dim3 grid_m(num_m_block, params.b, params.h);
    const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
    dim3 grid_n(num_n_block, params.b, params.h);

    flash_bwd_dot_do_o_kernel<true, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    // We want to specialize to is_even_MN and not just is_even_M, since in the case where N is not
    // a multiple of kBlockN, we'll need to apply mask in the loop.
    const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_q % Kernel_traits::kBlockM == 0 && params.seqlen_k % Kernel_traits::kBlockN == 0;
    const bool is_even_K = params.d == Kernel_traits::kHeadDim;
    constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize1colblock;
    // printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv);
    BOOL_SWITCH(params.is_causal, Is_causal, [&] {
        BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
            BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
                BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
                    BOOL_SWITCH(params.has_alibi, Has_alibi, [&] {
                        // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                        // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
                        // If Is_local, set Is_causal to false
                        auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, Is_causal && !Is_local, Is_local, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst>;
                        // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst, true>;
                        if (smem_size_dq_dk_dv >= 48 * 1024)  {
                            C10_CUDA_CHECK(cudaFuncSetAttribute(
                                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
                        }
                        kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
                        C10_CUDA_KERNEL_LAUNCH_CHECK();
                    });
                });
            });
        });
    });

    auto kernel_dq = &flash_bwd_convert_dq_kernel<Kernel_traits>;
    if (Kernel_traits::kSmemdQSize >= 48 * 1024)  {
        C10_CUDA_CHECK(cudaFuncSetAttribute(
            kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize));
    }
    kernel_dq<<<grid_m, Kernel_traits::kNThreads, Kernel_traits::kSmemdQSize, stream>>>(params);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
}

// Enqueues the seqlen_q-parallel backward pass on `stream` as three kernel launches:
//   1. flash_bwd_clear_dkvaccum_kernel on grid_n = (ceil(seqlen_k / kBlockN), b, h_k);
//   2. flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel on grid_m = (ceil(seqlen_q / kBlockM), b, h),
//      specialized on causal / alibi / even-N / even-K flags derived from `params`;
//   3. flash_bwd_convert_dkv_kernel on grid_n.
// Each launch is followed by C10_CUDA_KERNEL_LAUNCH_CHECK(). Kernels that need 48 KB or more of
// dynamic shared memory are opted in via cudaFuncSetAttribute before launch.
// `configure` is part of the shared launcher signature and is not read in this function.
template<typename Kernel_traits, bool Is_dropout>
void run_flash_bwd_seqq_parallel(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
    const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
    dim3 grid_n(num_n_block, params.b, params.h_k);
    flash_bwd_clear_dkvaccum_kernel<Kernel_traits><<<grid_n, Kernel_traits::kNThreads, 0, stream>>>(params);
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
    dim3 grid_m(num_m_block, params.b, params.h);
    // We also use is_even_N to set Unpadded in the BlockInfo constructor, so we need to check
    // for cu_seqlens_k as well.
    const bool is_even_N = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0;
    const bool is_even_K = params.d == Kernel_traits::kHeadDim;
    constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize1rowblock;
    // printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv);
    BOOL_SWITCH(params.is_causal, Is_causal, [&] {
        BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
            BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
                BOOL_SWITCH(params.has_alibi, Has_alibi, [&] {
                    // If not IsEvenKConst, we also set IsEvenNConst to false to reduce number of templates.
                    auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, Has_alibi, IsEvenNConst && IsEvenKConst, IsEvenKConst>;
                    // auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, false, false, IsEvenNConst, IsEvenKConst>;
                    if (smem_size_dq_dk_dv >= 48 * 1024)  {
                        C10_CUDA_CHECK(cudaFuncSetAttribute(
                            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
                    }
                    kernel<<<grid_m, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
                    C10_CUDA_KERNEL_LAUNCH_CHECK();
                });
            });
        });
    });

    auto kernel_dkv = &flash_bwd_convert_dkv_kernel<Kernel_traits>;
    if (Kernel_traits::kSmemKVSize >= 48 * 1024)  {
        C10_CUDA_CHECK(cudaFuncSetAttribute(
            kernel_dkv, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemKVSize));
    }
    kernel_dkv<<<grid_n, Kernel_traits::kNThreads, Kernel_traits::kSmemKVSize, stream>>>(params);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
}

// Backward-pass dispatcher for a given Kernel_traits / Is_dropout combination.
// Returns immediately without launching anything when `configure` is true; otherwise
// delegates to run_flash_bwd_seqk_parallel with the same arguments.
template<typename Kernel_traits, bool Is_dropout>
void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
    if (configure) return;
    run_flash_bwd_seqk_parallel<Kernel_traits, Is_dropout>(params, stream, configure);
}

// Backward launcher for head dimension 32 and element type T.
// Queries the opt-in per-block shared-memory limit of the current device and selects the
// Flash_bwd_kernel_traits configuration accordingly. Is_dropout is true when
// params.p_dropout < 1. CUDA API failures are reported through C10_CUDA_CHECK.
template<typename T>
void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
    constexpr static int Headdim = 32;
    int device;
    // Check the return code: on failure `device` would be left uninitialized.
    C10_CUDA_CHECK(cudaGetDevice(&device));
    int max_smem_per_block;
    C10_CUDA_CHECK(cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        if (max_smem_per_block >= 2 * ((3 * 128 + 2 * 128) * Headdim + 2 * 128 * 128)) { // 104 KB
            if constexpr(!Is_dropout) {  // We can afford more registers to keep V in registers
                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
            } else {
                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
            }
        } else {  // 96 KB
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
        }
    });
}

// Backward launcher for head dimension 64 and element type T.
// Queries the opt-in per-block shared-memory limit of the current device and selects the
// Flash_bwd_kernel_traits configuration accordingly (threshold: 144 KB). Is_dropout is true
// when params.p_dropout < 1. CUDA API failures are reported through C10_CUDA_CHECK.
// The commented-out launches record alternative configurations that were tried.
template<typename T>
void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
    constexpr static int Headdim = 64;
    int device;
    // Check the return code: on failure `device` would be left uninitialized.
    C10_CUDA_CHECK(cudaGetDevice(&device));
    int max_smem_per_block;
    C10_CUDA_CHECK(cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
    // printf("max_smem_per_block = %d\n", max_smem_per_block);
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        // Changing AtomLayoutMdQ from 2 to 4 takes the same time
        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>>(params, stream, configure);
        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>>(params, stream, configure);
        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 2, 4, 4, false, false, T>>(params, stream, configure);
        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream, configure);
        // This is slightly faster. We want to split M more so we need fewer registers to store LSE.
        if (max_smem_per_block >= 144 * 1024) {
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
            // This has a lot of register spilling
            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
        } else {
            // if (params.h == params.h_k) {
                // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
                // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream, configure);
                // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>, Is_dropout>(params, stream, configure);
            // } else {
            //     run_flash_bwd_seqq_parallel<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream, configure);
            // }
        }
    });
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>>(params, stream, configure);
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, 2, 2, 2, true, false, T>>(params, stream, configure);
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 4, 1, 4, 1, false, false, T>>(params, stream, configure);
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 16, 128, 4, 1, 4, 1, false, false, T>>(params, stream, configure);
    // M=128, N=64 is quite slow, I think because we need to read/write dQaccum twice as many times
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 2, 2, 2, false, T>>(params, stream, configure);
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream, configure);
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, false, T>>(params, stream, configure);

    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 4, 4, 2, 4, false, false, T>>(params, stream, configure);
}

// Backward launcher for head dimension 96 and element type T.
// Queries the opt-in per-block shared-memory limit of the current device and selects the
// Flash_bwd_kernel_traits configuration accordingly (threshold: 116 KB). Is_dropout is true
// when params.p_dropout < 1. CUDA API failures are reported through C10_CUDA_CHECK.
template<typename T>
void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
    constexpr static int Headdim = 96;
    int device;
    // Check the return code: on failure `device` would be left uninitialized.
    C10_CUDA_CHECK(cudaGetDevice(&device));
    int max_smem_per_block;
    C10_CUDA_CHECK(cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
    // printf("max_smem_per_block = %d\n", max_smem_per_block);
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        // if (params.h == params.h_k) {
            if (max_smem_per_block >= 116 * 1024) {
                if constexpr(!Is_dropout) {  // 92KB
                    run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
                } else {  // 116 KB
                    // This is faster for dropout since we don't have many registers to spare
                    run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
                }
            } else {
                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
            }
        // } else {
            // run_flash_bwd_seqq_parallel<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 4, 4, false, false, T>>(params, stream, configure);
        // }
    });
}

// Backward launcher for head dimension 128 and element type T.
// Queries the opt-in per-block shared-memory limit of the current device and selects the
// Flash_bwd_kernel_traits configuration accordingly (threshold: 144 KB). Is_dropout is true
// when params.p_dropout < 1. CUDA API failures are reported through C10_CUDA_CHECK.
// The commented-out launches record alternative configurations that were tried.
template<typename T>
void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
    constexpr static int Headdim = 128;
    int device;
    // Check the return code: on failure `device` would be left uninitialized.
    C10_CUDA_CHECK(cudaGetDevice(&device));
    int max_smem_per_block;
    C10_CUDA_CHECK(cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
    // printf("max_smem_per_block = %d\n", max_smem_per_block);
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        // if (params.h == params.h_k) {
            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 8, 2, 2, 2, false, false, T>>(params, stream, configure);
            // This is faster, in the case of sequence-parallel bwd (where we need fewer registers).
            // Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why.
            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 2, 2, false, false, T>>(params, stream, configure);
            if (max_smem_per_block >= 144 * 1024) {
                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>, Is_dropout>(params, stream, configure);
                // run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
                // run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream, configure);
                // run_flash_bwd_seqq_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
                // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>, Is_dropout>(params, stream, configure);
                // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream, configure);
                // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream, configure);
            } else {
                // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream, configure);
                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream, configure);
            }
            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>>(params, stream, configure);

            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 4, 4, false, false, T>>(params, stream, configure);
        // } else {
            // run_flash_bwd_seqq_parallel<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 4, 4, false, false, T>>(params, stream, configure);
        // }
    });
}

// Backward launcher for head dimension 160 and element type T.
// Queries the opt-in per-block shared-memory limit of the current device and selects the
// Flash_bwd_kernel_traits configuration accordingly (threshold: 116 KB). Is_dropout is true
// when params.p_dropout < 1. CUDA API failures are reported through C10_CUDA_CHECK.
template<typename T>
void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
    constexpr static int Headdim = 160;
    int device;
    // Check the return code: on failure `device` would be left uninitialized.
    C10_CUDA_CHECK(cudaGetDevice(&device));
    int max_smem_per_block;
    C10_CUDA_CHECK(cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        if (max_smem_per_block >= 116 * 1024) {
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
        } else {
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream, configure);
        }
    });
}

// Backward launcher for head dimension 192 and element type T.
// Queries the opt-in per-block shared-memory limit of the current device and selects the
// Flash_bwd_kernel_traits configuration accordingly (threshold: 136 KB). Is_dropout is true
// when params.p_dropout < 1. CUDA API failures are reported through C10_CUDA_CHECK.
template<typename T>
void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
    constexpr static int Headdim = 192;
    int device;
    // Check the return code: on failure `device` would be left uninitialized.
    C10_CUDA_CHECK(cudaGetDevice(&device));
    int max_smem_per_block;
    C10_CUDA_CHECK(cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        if (max_smem_per_block >= 136 * 1024) {
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream, configure);
        } else {
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, true, T>, Is_dropout>(params, stream, configure);
        }
    });
}

// Backward launcher for head dimension 224 and element type T.
// Uses a single Flash_bwd_kernel_traits configuration regardless of device; only the
// Is_dropout flag (true when params.p_dropout < 1) is switched at runtime.
template<typename T>
void run_mha_bwd_hdim224(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
    constexpr static int Headdim = 224;
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
    });
}

// Backward launcher for head dimension 256 and element type T.
// Queries the opt-in per-block shared-memory limit of the current device and selects the
// Flash_bwd_kernel_traits configuration accordingly (threshold: 176 KB). Is_dropout is true
// when params.p_dropout < 1. CUDA API failures are reported through C10_CUDA_CHECK.
template<typename T>
void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
    constexpr static int Headdim = 256;
    int device;
    // Check the return code: on failure `device` would be left uninitialized.
    C10_CUDA_CHECK(cudaGetDevice(&device));
    int max_smem_per_block;
    C10_CUDA_CHECK(cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        if (max_smem_per_block >= 176 * 1024) {  // H100
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream, configure);
        } else {  // A100, we don't do double buffering to save smem
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_dropout>(params, stream, configure);
        }
    });
}