/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#pragma once

#include <ATen/cuda/CUDAContext.h>

#include "static_switch.h"
#include "flash.h"
#include "flash_bwd_preprocess_kernel.h"
#include "flash_bwd_kernel.h"

// Kernel entry point that forwards to flash::compute_dot_do_o.
// Clear_dQaccum is passed through unchanged as the first template argument; the host launcher
// in this file instantiates it with `true` on the non-deterministic path and `false` otherwise.
template<bool Clear_dQaccum=true, typename Kernel_traits>
__global__ void flash_bwd_dot_do_o_kernel(const Flash_bwd_params params) {
    flash::compute_dot_do_o<Clear_dQaccum, Kernel_traits>(params);
}

// Kernel entry point that forwards to flash::clear_dKVaccum.
// NOTE(review): no launch site for this kernel is visible in this file.
template<typename Kernel_traits>
__global__ void flash_bwd_clear_dkvaccum_kernel(const Flash_bwd_params params) {
    flash::clear_dKVaccum<Kernel_traits>(params);
}

// Kernel entry point that forwards to flash::compute_dq_dk_dv with the same template flags.
// `params` is taken as a __grid_constant__ kernel parameter.
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_M, bool Is_even_K>
__global__ void flash_bwd_dq_dk_dv_loop_kernel(__grid_constant__ const Flash_bwd_params params) {
    flash::compute_dq_dk_dv<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K>(params);
}

// Kernel entry point that forwards to flash::compute_dq_dk_dv_seqk_parallel with the same
// template flags. Is_causal and Is_local are mutually exclusive (enforced at compile time).
// `params` is taken as a __grid_constant__ kernel parameter.
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K>
__global__ void flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel(__grid_constant__ const Flash_bwd_params params) {
    static_assert(!(Is_causal && Is_local));  // If Is_local is true, Is_causal should be false
    flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K>(params);
}

// Kernel entry point that forwards to flash::convert_dQ.
// nsplits is passed through unchanged; the host launcher in this file passes 1 on the
// non-deterministic path and the x-dimension of the main backward grid otherwise.
template<typename Kernel_traits>
__global__ void flash_bwd_convert_dq_kernel(const Flash_bwd_params params, const int nsplits) {
    flash::convert_dQ<Kernel_traits>(params, nsplits);
}

// Kernel entry point that forwards to flash::convert_dKV.
// NOTE(review): no launch site for this kernel is visible in this file.
template<typename Kernel_traits>
__global__ void flash_bwd_convert_dkv_kernel(const Flash_bwd_params params) {
    flash::convert_dKV<Kernel_traits>(params);
}

// Launches the backward pass as three kernels on `stream`, each followed by a launch check:
//   1. flash_bwd_dot_do_o_kernel on a (num_m_block, b, h) grid,
//   2. flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel on a (gridDimx, b, h) grid,
//   3. flash_bwd_convert_dq_kernel on the (num_m_block, b, h) grid.
// gridDimx is the number of kBlockN-sized blocks covering seqlen_k, except when
// params.deterministic is set, where it is ceil(multiProcessorCount / (b * h)).
//
// params    : backward-pass parameters; read for sizes, flags and pointers, and passed by value
//             to each kernel.
// stream    : CUDA stream all three kernels are launched on.
// configure : not read by this function.
template<typename Kernel_traits, bool Is_dropout>
void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
    const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
    dim3 grid_m(num_m_block, params.b, params.h);
    const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
    int gridDimx = num_n_block;
    if (params.deterministic) {
        auto dprops = at::cuda::getCurrentDeviceProperties();
        gridDimx = (dprops->multiProcessorCount + params.b * params.h - 1) / (params.b * params.h);
    }
    dim3 grid_n(gridDimx, params.b, params.h);

    // Clear_dQaccum is true only on the non-deterministic path.
    if (!params.deterministic) {
        flash_bwd_dot_do_o_kernel<true, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);
    } else {
        flash_bwd_dot_do_o_kernel<false, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);
    }
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    // We want to specialize to is_even_MN and not just is_even_M, since in the case where N is not
    // a multiple of kBlockN, we'll need to apply mask in the loop.
    const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_q % Kernel_traits::kBlockM == 0 && params.seqlen_k % Kernel_traits::kBlockN == 0;
    const bool is_even_K = params.d == Kernel_traits::kHeadDim;
    constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize1colblock;
    // printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv);
    BOOL_SWITCH(params.is_causal, Is_causal, [&] {
        BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
            BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
                BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
                    BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                        // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                        // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
                        // If Is_local, set Is_causal to false
                        auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst>;
                        // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, false, Is_causal, false, false, true, true>;
                        // Dynamic shared memory of 48 KB or more must be opted into per kernel.
                        if (smem_size_dq_dk_dv >= 48 * 1024)  {
                            C10_CUDA_CHECK(cudaFuncSetAttribute(
                                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
                        }
                        kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
                        C10_CUDA_KERNEL_LAUNCH_CHECK();
                    });
                });
            });
        });
    });

    auto kernel_dq = &flash_bwd_convert_dq_kernel<Kernel_traits>;
    if (Kernel_traits::kSmemdQSize >= 48 * 1024)  {
        C10_CUDA_CHECK(cudaFuncSetAttribute(
            kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize));
    }
    // nsplits is 1 on the non-deterministic path, otherwise the x-extent of grid_n.
    kernel_dq<<<grid_m, Kernel_traits::kNThreads, Kernel_traits::kSmemdQSize, stream>>>(params, !params.deterministic ? 1 : gridDimx);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
}

// Entry point used by the per-head-dimension dispatchers below.
// Returns immediately without launching anything when `configure` is true; otherwise forwards
// all arguments to run_flash_bwd_seqk_parallel.
template<typename Kernel_traits, bool Is_dropout>
void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
    if (configure) return;
    run_flash_bwd_seqk_parallel<Kernel_traits, Is_dropout>(params, stream, configure);
}

// Backward dispatch for head dimension 32.
// Queries the current device's opt-in maximum shared memory per block and selects a
// Flash_bwd_kernel_traits configuration from it; Is_dropout is true when params.p_dropout < 1.f.
// Both CUDA runtime queries are error-checked via C10_CUDA_CHECK.
template<typename T>
void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
    constexpr static int Headdim = 32;
    int device;
    // Check the result: on failure `device` would otherwise be used uninitialised below.
    C10_CUDA_CHECK(cudaGetDevice(&device));
    int max_smem_per_block;
    C10_CUDA_CHECK(cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        if (max_smem_per_block >= 2 * ((3 * 128 + 2 * 128) * Headdim + 2 * 128 * 128)) { // 104 KB
            if constexpr(!Is_dropout) {  // We can afford more registers to keep V in registers
                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
            } else {
                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
            }
        } else {  // 96 KB
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
        }
    });
}

// Backward dispatch for head dimension 64.
// Queries the current device's opt-in maximum shared memory per block and selects a
// Flash_bwd_kernel_traits configuration from it; Is_dropout is true when params.p_dropout < 1.f.
// Both CUDA runtime queries are error-checked via C10_CUDA_CHECK.
// The commented-out instantiations are alternative configurations kept as tuning notes.
template<typename T>
void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
    constexpr static int Headdim = 64;
    int device;
    // Check the result: on failure `device` would otherwise be used uninitialised below.
    C10_CUDA_CHECK(cudaGetDevice(&device));
    int max_smem_per_block;
    C10_CUDA_CHECK(cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
    // printf("max_smem_per_block = %d\n", max_smem_per_block);
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        // Changing AtomLayoutMdQ from 2 to 4 takes the same time
        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>>(params, stream, configure);
        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>>(params, stream, configure);
        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 2, 4, 4, false, false, T>>(params, stream, configure);
        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream, configure);
        // This is slightly faster. We want to split M more so we need fewer registers to store LSE.
        if (max_smem_per_block >= 144 * 1024) {
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
            // This has a lot of register spilling
            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
        } else {
            // if (params.h == params.h_k) {
                // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
                // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream, configure);
                // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>, Is_dropout>(params, stream, configure);
            // } else {
            // }
        }
    });
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>>(params, stream, configure);
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, 2, 2, 2, true, false, T>>(params, stream, configure);
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 4, 1, 4, 1, false, false, T>>(params, stream, configure);
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 16, 128, 4, 1, 4, 1, false, false, T>>(params, stream, configure);
    // M=128, N=64 is quite slow, I think because we need to read/write dQaccum twice as many times
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 2, 2, 2, false, T>>(params, stream, configure);
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream, configure);
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, false, T>>(params, stream, configure);

    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 4, 4, 2, 4, false, false, T>>(params, stream, configure);
}

// Backward dispatch for head dimension 96.
// Queries the current device's opt-in maximum shared memory per block and selects a
// Flash_bwd_kernel_traits configuration from it; Is_dropout is true when params.p_dropout < 1.f.
// Both CUDA runtime queries are error-checked via C10_CUDA_CHECK.
template<typename T>
void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
    constexpr static int Headdim = 96;
    int device;
    // Check the result: on failure `device` would otherwise be used uninitialised below.
    C10_CUDA_CHECK(cudaGetDevice(&device));
    int max_smem_per_block;
    C10_CUDA_CHECK(cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
    // printf("max_smem_per_block = %d\n", max_smem_per_block);
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        if (max_smem_per_block >= 116 * 1024) {
            if constexpr(!Is_dropout) {  // 92KB
                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
            } else {  // 116 KB
                // This is faster for dropout since we don't have many registers to spare
                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
            }
        } else {
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
        }
    });
}

// Backward dispatch for head dimension 128.
// Queries the current device's opt-in maximum shared memory per block and selects a
// Flash_bwd_kernel_traits configuration from it; Is_dropout is true when params.p_dropout < 1.f.
// Both CUDA runtime queries are error-checked via C10_CUDA_CHECK.
// The commented-out instantiations are alternative configurations kept as tuning notes.
template<typename T>
void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
    constexpr static int Headdim = 128;
    int device;
    // Check the result: on failure `device` would otherwise be used uninitialised below.
    C10_CUDA_CHECK(cudaGetDevice(&device));
    int max_smem_per_block;
    C10_CUDA_CHECK(cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
    // printf("max_smem_per_block = %d\n", max_smem_per_block);
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 8, 2, 2, 2, false, false, T>>(params, stream, configure);
        // This is faster, in the case of sequence-parallel bwd (where we need fewer registers).
        // Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why.
        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 2, 2, false, false, T>>(params, stream, configure);
        if (max_smem_per_block >= 144 * 1024) {
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>, Is_dropout>(params, stream, configure);
            // run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
            // run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream, configure);
            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>, Is_dropout>(params, stream, configure);
            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream, configure);
            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream, configure);
        } else {
            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream, configure);
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream, configure);
        }
        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>>(params, stream, configure);

        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 4, 4, false, false, T>>(params, stream, configure);
    });
}

// Backward dispatch for head dimension 160.
// Queries the current device's opt-in maximum shared memory per block and selects a
// Flash_bwd_kernel_traits configuration from it; Is_dropout is true when params.p_dropout < 1.f.
// Both CUDA runtime queries are error-checked via C10_CUDA_CHECK.
template<typename T>
void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
    constexpr static int Headdim = 160;
    int device;
    // Check the result: on failure `device` would otherwise be used uninitialised below.
    C10_CUDA_CHECK(cudaGetDevice(&device));
    int max_smem_per_block;
    C10_CUDA_CHECK(cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        if (max_smem_per_block >= 116 * 1024) {
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
        } else {
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream, configure);
        }
    });
}

// Backward dispatch for head dimension 192.
// Queries the current device's opt-in maximum shared memory per block and selects a
// Flash_bwd_kernel_traits configuration from it; Is_dropout is true when params.p_dropout < 1.f.
// Both CUDA runtime queries are error-checked via C10_CUDA_CHECK.
template<typename T>
void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
    constexpr static int Headdim = 192;
    int device;
    // Check the result: on failure `device` would otherwise be used uninitialised below.
    C10_CUDA_CHECK(cudaGetDevice(&device));
    int max_smem_per_block;
    C10_CUDA_CHECK(cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        if (max_smem_per_block >= 136 * 1024) {
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream, configure);
        } else {
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, true, T>, Is_dropout>(params, stream, configure);
        }
    });
}

// Backward dispatch for head dimension 224.
// Uses a single Flash_bwd_kernel_traits configuration regardless of device; Is_dropout is true
// when params.p_dropout < 1.f.
template<typename T>
void run_mha_bwd_hdim224(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
    constexpr static int Headdim = 224;
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
    });
}

// Backward dispatch for head dimension 256.
// Queries the current device's opt-in maximum shared memory per block and selects a
// Flash_bwd_kernel_traits configuration from it; Is_dropout is true when params.p_dropout < 1.f.
// Both CUDA runtime queries are error-checked via C10_CUDA_CHECK.
template<typename T>
void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
    constexpr static int Headdim = 256;
    int device;
    // Check the result: on failure `device` would otherwise be used uninitialised below.
    C10_CUDA_CHECK(cudaGetDevice(&device));
    int max_smem_per_block;
    C10_CUDA_CHECK(cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        if (max_smem_per_block >= 176 * 1024) {  // H100
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream, configure);
        } else {  // A100, we don't do double buffering to save smem
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_dropout>(params, stream, configure);
        }
    });
}