Commit 65f723bb authored by Tri Dao

Split bwd into more .cu files to speed up compilation

parent 5ca83a9c
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_bwd_launch_template.h"
template<>
void run_mha_bwd_<cutlass::half_t, 32, true>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim32<cutlass::half_t, true>(params, stream);
}
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_bwd_launch_template.h"
template<>
-void run_mha_bwd_<cutlass::half_t, 32>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim32<cutlass::half_t>(params, stream);
+void run_mha_bwd_<cutlass::half_t, 32, false>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim32<cutlass::half_t, false>(params, stream);
}
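Every generated file in this commit follows the same shape: one explicit specialization of run_mha_bwd_ per (data type, head dimension, is_causal) triple, so each combination becomes its own translation unit and the files can be compiled in parallel. A minimal sketch of the declaration this relies on, assuming (the header itself is not shown in this diff) that flash_bwd_launch_template.h only declares the primary template:

// Sketch under the stated assumption: the primary template is declared but not
// defined in the shared header; each auto-generated .cu file (like the ones in
// this commit) supplies exactly one explicit specialization of it.
template<typename T, int Headdim, bool Is_causal>
void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);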
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_bwd_launch_template.h"
template<>
void run_mha_bwd_<cutlass::bfloat16_t, 64, true>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim64<cutlass::bfloat16_t, true>(params, stream);
}
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_bwd_launch_template.h"
template<>
-void run_mha_bwd_<cutlass::bfloat16_t, 64>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim64<cutlass::bfloat16_t>(params, stream);
+void run_mha_bwd_<cutlass::bfloat16_t, 64, false>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim64<cutlass::bfloat16_t, false>(params, stream);
}
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_bwd_launch_template.h"
template<>
void run_mha_bwd_<cutlass::half_t, 64, true>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim64<cutlass::half_t, true>(params, stream);
}
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_bwd_launch_template.h"
template<>
-void run_mha_bwd_<cutlass::half_t, 64>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim64<cutlass::half_t>(params, stream);
+void run_mha_bwd_<cutlass::half_t, 64, false>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim64<cutlass::half_t, false>(params, stream);
}
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_bwd_launch_template.h"
template<>
void run_mha_bwd_<cutlass::bfloat16_t, 96, true>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim96<cutlass::bfloat16_t, true>(params, stream);
}
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_bwd_launch_template.h"
template<>
-void run_mha_bwd_<cutlass::bfloat16_t, 96>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim96<cutlass::bfloat16_t>(params, stream);
+void run_mha_bwd_<cutlass::bfloat16_t, 96, false>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim96<cutlass::bfloat16_t, false>(params, stream);
}
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_bwd_launch_template.h"
template<>
void run_mha_bwd_<cutlass::half_t, 96, true>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim96<cutlass::half_t, true>(params, stream);
}
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_bwd_launch_template.h"
template<>
-void run_mha_bwd_<cutlass::half_t, 96>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim96<cutlass::half_t>(params, stream);
+void run_mha_bwd_<cutlass::half_t, 96, false>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim96<cutlass::half_t, false>(params, stream);
}
...
@@ -65,7 +65,7 @@ __global__ void flash_bwd_convert_dkv_kernel(const Flash_bwd_params params) {
    flash::convert_dKV<Kernel_traits>(params);
}

-template<typename Kernel_traits, bool Is_dropout>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream) {
    const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
    dim3 grid_m(num_m_block, params.b, params.h);
...
@@ -90,24 +90,22 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream)
    const bool is_even_K = params.d == Kernel_traits::kHeadDim;
    constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize1colblock;
    // printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv);
-    BOOL_SWITCH(params.is_causal, Is_causal, [&] {
-        BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
-            EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
-                LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
-                    ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
-                        SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
-                            // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
-                            // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
-                            // If Is_local, set Is_causal to false
-                            auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap>;
-                            // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, false, Is_causal, false, false, true, true>;
-                            if (smem_size_dq_dk_dv >= 48 * 1024) {
-                                C10_CUDA_CHECK(cudaFuncSetAttribute(
-                                    kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
-                            }
-                            kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
-                            C10_CUDA_KERNEL_LAUNCH_CHECK();
-                        });
+    BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
+        EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
+            LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
+                ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
+                    SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
+                        // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
+                        // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
+                        // If Is_local, set Is_causal to false
+                        auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap>;
+                        // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, false, Is_causal, false, false, true, true>;
+                        if (smem_size_dq_dk_dv >= 48 * 1024) {
+                            C10_CUDA_CHECK(cudaFuncSetAttribute(
+                                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
+                        }
+                        kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
+                        C10_CUDA_KERNEL_LAUNCH_CHECK();
                     });
                 });
             });
...
@@ -123,14 +121,14 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream)
    C10_CUDA_KERNEL_LAUNCH_CHECK();
}

-template<typename Kernel_traits, bool Is_dropout>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
#ifndef FLASHATTENTION_DISABLE_BACKWARD
-    run_flash_bwd_seqk_parallel<Kernel_traits, Is_dropout>(params, stream);
+    run_flash_bwd_seqk_parallel<Kernel_traits, Is_dropout, Is_causal>(params, stream);
#endif
}
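The *_SWITCH macros in run_flash_bwd_seqk_parallel above turn runtime flags (dropout, even sequence lengths, local attention, ALiBi, softcap) into compile-time template parameters, which is what multiplies the number of kernel instantiations. A minimal sketch of the idiom, assuming a definition along the lines of the repository's static_switch.h (not part of this diff), with a hypothetical name so it is not mistaken for the real macro:

// Hypothetical illustration of a BOOL_SWITCH-style macro (assumption: the real
// one may differ in detail). A runtime condition selects between two lambda
// invocations, each seeing a different constexpr bool, so the body can
// instantiate a different kernel template per branch.
#define EXAMPLE_BOOL_SWITCH(COND, CONST_NAME, ...)      \
    [&] {                                               \
        if (COND) {                                     \
            constexpr static bool CONST_NAME = true;    \
            return __VA_ARGS__();                       \
        } else {                                        \
            constexpr static bool CONST_NAME = false;   \
            return __VA_ARGS__();                       \
        }                                               \
    }()

Before this commit, Is_causal was one more such runtime switch inside this nest; hoisting it into a template parameter of run_flash_bwd_seqk_parallel is what lets the causal and non-causal instantiations be generated in separate .cu files.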
-template<typename T>
+template<typename T, bool Is_causal>
void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 32;
    int device;
...
@@ -144,17 +142,17 @@ void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream) {
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        if (max_smem_per_block >= 2 * ((3 * 128 + 2 * 128) * Headdim + 2 * 128 * 128)) { // 104 KB
            if constexpr(!Is_dropout) {  // We can afford more registers to keep V in registers
-                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream);
+                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
            } else {
-                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
+                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            }
        } else {  // 96 KB
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
        }
    });
}
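Each run_mha_bwd_hdim* launcher chooses kernel traits from the shared-memory budget of the current device; the query itself sits in the collapsed "..." region above. A hedged sketch of how max_smem_per_block is typically obtained with the CUDA runtime (an assumption based on the surrounding code, not lines shown in this diff):

// Sketch (assumed, since the actual query lines are collapsed in this diff):
// read the opt-in per-block shared memory limit for the current device; the
// branches above compare this value against thresholds such as 104 KB.
int device, max_smem_per_block;
C10_CUDA_CHECK(cudaGetDevice(&device));
C10_CUDA_CHECK(cudaDeviceGetAttribute(
    &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));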
-template<typename T>
+template<typename T, bool Is_causal>
void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 64;
    int device;
...
@@ -174,13 +172,13 @@ void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream);
        // This is slightly faster. We want to split M more so we need fewer registers to store LSE.
        if (max_smem_per_block >= 144 * 1024) {
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            // This has a lot of register spilling
            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream);
        } else {
            // if (params.h == params.h_k) {
            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream);
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream);
            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>, Is_dropout>(params, stream);
            // } else {
...
@@ -199,7 +197,7 @@ void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 4, 4, 2, 4, false, false, T>>(params, stream);
}
-template<typename T>
+template<typename T, bool Is_causal>
void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 96;
    int device;
...
@@ -214,18 +212,18 @@ void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        if (max_smem_per_block >= 116 * 1024) {
            if constexpr(!Is_dropout) { // 92KB
-                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream);
+                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
            } else { // 116 KB
                // This is faster for dropout since we don't have many registers to spare
-                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream);
+                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            }
        } else {
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
        }
    });
}
-template<typename T>
+template<typename T, bool Is_causal>
void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 128;
    int device;
...
@@ -243,7 +241,7 @@ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
        // Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why.
        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 2, 2, false, false, T>>(params, stream);
        if (max_smem_per_block >= 144 * 1024) {
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>, Is_dropout>(params, stream);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>, Is_dropout, Is_causal>(params, stream);
            // run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
            // run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream);
            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>, Is_dropout>(params, stream);
...
@@ -251,7 +249,7 @@ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream);
        } else {
            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, false, T>, Is_dropout, Is_causal>(params, stream);
        }
        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>>(params, stream);
...
@@ -259,7 +257,7 @@ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
    });
}
-template<typename T>
+template<typename T, bool Is_causal>
void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 160;
    int device;
...
@@ -272,14 +270,14 @@ void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream) {
    }
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        if (max_smem_per_block >= 116 * 1024) {
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
        } else {
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, true, T>, Is_dropout, Is_causal>(params, stream);
        }
    });
}
-template<typename T>
+template<typename T, bool Is_causal>
void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 192;
    int device;
...
@@ -292,14 +290,14 @@ void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
    }
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        if (max_smem_per_block >= 136 * 1024) {
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout, Is_causal>(params, stream);
        } else {
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, true, T>, Is_dropout>(params, stream);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, true, T>, Is_dropout, Is_causal>(params, stream);
        }
    });
}
-template<typename T>
+template<typename T, bool Is_causal>
void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 256;
    int device;
...
@@ -312,12 +310,12 @@ void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
    }
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        if (max_smem_per_block >= 176 * 1024) { // H100
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout, Is_causal>(params, stream);
        } else if (max_smem_per_block >= 144 * 1024) { // A100, we don't do double buffering to save smem
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_dropout>(params, stream);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_dropout, Is_causal>(params, stream);
        } else { // sm86 and sm89, max smem is 99 KB. Only works without dropout. V in regs and no double buffering.
            if constexpr (!Is_dropout) {
-                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 32, 8, 4, 1, 2, true, true, T>, false>(params, stream);
+                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 32, 8, 4, 1, 2, true, true, T>, false, Is_causal>(params, stream);
            }
        }
    });
...
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
...
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
...
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
...
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
...
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
...
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
...
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
...
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
...
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
...