Commit 908511b2 authored by Tri Dao

Split into more .cu files to speed up compilation

parent 1d536d7d
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template<>
void run_mha_fwd_<cutlass::half_t, 256, true>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim256<cutlass::half_t, true>(params, stream);
}
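Note for readers of this diff: each of these generated .cu files pins the causal flag at compile time through the new third template argument of run_mha_fwd_. The sketch below is a hypothetical illustration, not code from this commit: the helper dispatch_fwd and the forward declaration of Flash_fwd_params are assumptions added only to show how a caller would branch on the runtime causal flag once, above the per-file specializations.

// Hypothetical dispatch sketch (illustration only; everything except
// run_mha_fwd_ and Flash_fwd_params is an assumption, not this repo's code).
#include <cuda_runtime.h>

struct Flash_fwd_params;  // defined by the library; fields elided here

// Declarations of the specializations that the generated .cu files provide.
template<typename T, int Headdim, bool Is_causal>
void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);

// Branch on the runtime causal flag exactly once, then call into the
// pre-compiled specialization that lives in its own translation unit.
template<typename T, int Headdim>
void dispatch_fwd(Flash_fwd_params &params, cudaStream_t stream, bool is_causal) {
    if (is_causal) {
        run_mha_fwd_<T, Headdim, true>(params, stream);
    } else {
        run_mha_fwd_<T, Headdim, false>(params, stream);
    }
}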
@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"

 template<>
-void run_mha_fwd_<cutlass::half_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim256<cutlass::half_t>(params, stream);
+void run_mha_fwd_<cutlass::half_t, 256, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim256<cutlass::half_t, false>(params, stream);
 }
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template<>
void run_mha_fwd_<cutlass::bfloat16_t, 32, true>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim32<cutlass::bfloat16_t, true>(params, stream);
}
@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"

 template<>
-void run_mha_fwd_<cutlass::bfloat16_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim32<cutlass::bfloat16_t>(params, stream);
+void run_mha_fwd_<cutlass::bfloat16_t, 32, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim32<cutlass::bfloat16_t, false>(params, stream);
 }
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template<>
void run_mha_fwd_<cutlass::half_t, 32, true>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim32<cutlass::half_t, true>(params, stream);
}
@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"

 template<>
-void run_mha_fwd_<cutlass::half_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim32<cutlass::half_t>(params, stream);
+void run_mha_fwd_<cutlass::half_t, 32, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim32<cutlass::half_t, false>(params, stream);
 }
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template<>
void run_mha_fwd_<cutlass::bfloat16_t, 64, true>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim64<cutlass::bfloat16_t, true>(params, stream);
}
@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"

 template<>
-void run_mha_fwd_<cutlass::bfloat16_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim64<cutlass::bfloat16_t>(params, stream);
+void run_mha_fwd_<cutlass::bfloat16_t, 64, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim64<cutlass::bfloat16_t, false>(params, stream);
 }
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template<>
void run_mha_fwd_<cutlass::half_t, 64, true>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim64<cutlass::half_t, true>(params, stream);
}
@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"

 template<>
-void run_mha_fwd_<cutlass::half_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim64<cutlass::half_t>(params, stream);
+void run_mha_fwd_<cutlass::half_t, 64, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim64<cutlass::half_t, false>(params, stream);
 }
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template<>
void run_mha_fwd_<cutlass::bfloat16_t, 96, true>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim96<cutlass::bfloat16_t, true>(params, stream);
}
@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"

 template<>
-void run_mha_fwd_<cutlass::bfloat16_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim96<cutlass::bfloat16_t>(params, stream);
+void run_mha_fwd_<cutlass::bfloat16_t, 96, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim96<cutlass::bfloat16_t, false>(params, stream);
 }
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template<>
void run_mha_fwd_<cutlass::half_t, 96, true>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim96<cutlass::half_t, true>(params, stream);
}
@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"

 template<>
-void run_mha_fwd_<cutlass::half_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim96<cutlass::half_t>(params, stream);
+void run_mha_fwd_<cutlass::half_t, 96, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim96<cutlass::half_t, false>(params, stream);
 }
@@ -95,7 +95,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     });
 }

-template<typename Kernel_traits>
+template<typename Kernel_traits, bool Is_causal>
 void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     static_assert(!Kernel_traits::Is_Q_in_regs, "SplitKV implementation does not support Is_Q_in_regs");
     static_assert(!Kernel_traits::Share_Q_K_smem, "SplitKV implementation does not support Share_Q_K_smem");
@@ -104,7 +104,6 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.b, params.num_splits > 1 ? params.b * params.h : params.h);
     const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
     const bool is_even_K = params.d == Kernel_traits::kHeadDim;
-    BOOL_SWITCH(params.is_causal, Is_causal, [&] {
     BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
         EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
             LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
@@ -131,7 +130,6 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
             });
         });
     });
-    });
     if (params.num_splits > 1) {
         // We want kBlockM to be as small as possible for more parallelism.
         // With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4.
@@ -159,31 +157,28 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     }
 }

-template<typename T, int Headdim>
+template<typename T, int Headdim, bool Is_causal>
 void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int kBlockM = 64; // Fixed for all head dimensions
     // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
     // and for headdim 192 with block size 64 x 128.
     // Also for headdim 160 with block size 64 x 128 after the rotary addition.
     constexpr static int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
-    run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>>(params, stream);
+    run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>, Is_causal>(params, stream);
 }

-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 32;
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
         run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
     });
-    });
 }

-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 64;
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
         if constexpr(!Is_dropout) {
             // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
             // Using block size (64 x 256) is 27% slower for seqlen=2k
@@ -198,16 +193,14 @@ void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
             // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
         }
     });
-    });
 }

-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 96;
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
         // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
         if (is_sm8x) {
             if constexpr(!Is_causal) {
@@ -224,16 +217,14 @@ void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
         // run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 128, 4, true, T>>(params, stream);
         // run_flash_fwd<Flash_fwd_kernel_traits<96, 64, 128, 4, true, T>>(params, stream);
     });
-    });
 }

-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 128;
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
         if constexpr(!Is_dropout) {
             // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
             // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM.
@@ -261,16 +252,14 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
             // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
         }
     });
-    });
 }

-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 160;
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
         // For A100, H100, 128 x 32 is the fastest.
         // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
         // and 128 x 64 with 8 warps is the fastest for non-causal.
@@ -291,14 +280,12 @@ void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
         // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream);
         // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
     });
-    });
 }

-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 192;
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
         if constexpr(!Is_dropout) {
             run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
         } else {
@@ -310,10 +297,9 @@ void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
         // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream);
         // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
     });
-    });
 }

-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 224;
     int device;
@@ -326,7 +312,6 @@ void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
     }
     // printf("max_smem_per_block = %d\n", max_smem_per_block);
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
         if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) { // 112 KB
             run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
         } else {
@@ -339,10 +324,9 @@ void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
         // is 8 elements. This means we can only use 128 threads and not 256 threads.
         // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
     });
-    });
 }

-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 256;
     int device;
@@ -357,7 +341,6 @@ void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
     }
     // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block);
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
         // For A100, we want to run with 128 x 64 (128KB smem).
         // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM.
         if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) {
@@ -370,5 +353,4 @@ void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
         // 96 KB
         // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
     });
-    });
 }
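The hunks above remove the BOOL_SWITCH on params.is_causal from every run_mha_fwd_hdim* launcher and make Is_causal a template parameter instead. For readers unfamiliar with the switch macros, the sketch below shows the usual shape of such a macro; it is an assumption about the static_switch.h pattern, not copied from this repository. Because each switch compiles its body for both values of the constant, every nested switch doubles the kernel instantiations in a translation unit; hoisting the causal switch out means each generated .cu file compiles only one of the two halves, and the halves can build in parallel.

// Sketch of a BOOL_SWITCH-style macro (assumed pattern; the exact definition
// in this repo may differ). It maps a runtime bool onto a compile-time
// constant by instantiating the lambda body twice, once per value.
#define BOOL_SWITCH_SKETCH(COND, CONST_NAME, ...)        \
    [&] {                                                \
        if (COND) {                                      \
            constexpr static bool CONST_NAME = true;     \
            return __VA_ARGS__();                        \
        } else {                                         \
            constexpr static bool CONST_NAME = false;    \
            return __VA_ARGS__();                        \
        }                                                \
    }()

// Usage (illustrative): BOOL_SWITCH_SKETCH(params.is_causal, Is_causal, [&] { launch<Is_causal>(params, stream); });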
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream);
@@ -4,4 +4,4 @@

 #include "flash_fwd_launch_template.h"

-template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 128>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream);
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream);
@@ -4,4 +4,4 @@

 #include "flash_fwd_launch_template.h"

-template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream);
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 160, true>(Flash_fwd_params &params, cudaStream_t stream);
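The split-KV files above use explicit instantiation definitions, so each (dtype, head dimension, causal) combination is compiled in exactly one translation unit. A common companion to this pattern, shown below as a hypothetical example rather than something taken from this diff, is an explicit instantiation declaration in a shared header, which keeps other translation units from implicitly re-instantiating the same specialization.

// Hypothetical companion declaration (not part of this commit):
extern template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 160, true>(
    Flash_fwd_params &params, cudaStream_t stream);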