Commit ea8a25ca authored by Tri Dao

Remove configure in bwd kernel launch

parent af01244d
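
At a glance: this commit removes the configure flag that was threaded through the whole backward launch path (run_mha_bwd, run_flash_bwd, the run_mha_bwd_hdim* launchers, and every per-head-dim kernel instantiation) and replaces the manual head-dimension if/else chain in run_mha_bwd with a HEADDIM_SWITCH dispatch, matching the forward path. The sketch below shows one plausible shape for such a switch macro; it is an illustration under assumptions, not the repository's actual definition, and the middle branches are abbreviated.

// Illustrative only: a HEADDIM_SWITCH-style macro that maps the runtime head
// dimension onto a compile-time constant kHeadDim. The real macro in the repo
// may differ; branches for 96/128/160/192/224 are elided here for brevity.
#define HEADDIM_SWITCH(HEADDIM, ...)                   \
    [&] {                                              \
        if (HEADDIM <= 32) {                           \
            constexpr static int kHeadDim = 32;        \
            return __VA_ARGS__();                      \
        } else if (HEADDIM <= 64) {                    \
            constexpr static int kHeadDim = 64;        \
            return __VA_ARGS__();                      \
        /* ...analogous branches for 96, 128, 160, 192, 224... */ \
        } else {                                       \
            constexpr static int kHeadDim = 256;       \
            return __VA_ARGS__();                      \
        }                                              \
    }()

// Usage, mirroring the new run_mha_bwd body in the diff below:
//   HEADDIM_SWITCH(params.d, [&] {
//       run_mha_bwd_<elem_type, kHeadDim>(params, stream);
//   });

The point of the macro is that it turns the runtime value params.d into a compile-time constant kHeadDim inside the lambda, so a single call site can instantiate run_mha_bwd_<elem_type, kHeadDim> for every supported head dimension instead of spelling out each case by hand.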
/******************************************************************************
- * Copyright (c) 2023, Tri Dao.
+ * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/
// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.

@@ -204,7 +204,7 @@ void set_params_dgrad(Flash_bwd_params &params,
void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
    FP16_SWITCH(!params.is_bf16, [&] {
-        FWD_HEADDIM_SWITCH(params.d, [&] {
+        HEADDIM_SWITCH(params.d, [&] {
            if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0
                run_mha_fwd_<elem_type, kHeadDim>(params, stream);
            } else {
@@ -695,25 +695,11 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
    return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
}

-void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
+void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
    FP16_SWITCH(!params.is_bf16, [&] {
-        if (params.d <= 32) {
-            run_mha_bwd_<elem_type, 32>(params, stream, configure);
-        } else if (params.d <= 64) {
-            run_mha_bwd_<elem_type, 64>(params, stream, configure);
-        } else if (params.d <= 96) {
-            run_mha_bwd_<elem_type, 96>(params, stream, configure);
-        } else if (params.d <= 128) {
-            run_mha_bwd_<elem_type, 128>(params, stream, configure);
-        } else if (params.d <= 160) {
-            run_mha_bwd_<elem_type, 160>(params, stream, configure);
-        } else if (params.d <= 192) {
-            run_mha_bwd_<elem_type, 192>(params, stream, configure);
-        } else if (params.d <= 224) {
-            run_mha_bwd_<elem_type, 224>(params, stream, configure);
-        } else if (params.d <= 256) {
-            run_mha_bwd_<elem_type, 256>(params, stream, configure);
-        }
+        HEADDIM_SWITCH(params.d, [&] {
+            run_mha_bwd_<elem_type, kHeadDim>(params, stream);
+        });
    });
}
@@ -898,7 +884,6 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size
    params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
    auto launch = &run_mha_bwd;
-    // launch(params, stream, /*configure=*/true);
    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
        gen_, at::cuda::detail::getDefaultCUDAGenerator());

@@ -930,7 +915,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size
    }
    if (seqlen_q > 0) {
-        launch(params, stream, /*configure=*/false);
+        launch(params, stream);
    } else {
        // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
        dk_expanded.zero_();
@@ -1154,7 +1139,6 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
    params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
    auto launch = &run_mha_bwd;
-    // launch(params, stream, /*configure=*/true);
    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
        gen_, at::cuda::detail::getDefaultCUDAGenerator());

@@ -1186,7 +1170,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
    }
    if (max_seqlen_q > 0) {
-        launch(params, stream, /*configure=*/false);
+        launch(params, stream);
    } else {
        // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
        dk_expanded.zero_();
@@ -182,4 +182,4 @@ struct Flash_bwd_params : public Flash_fwd_params {
template<typename T, int Headdim> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
template<typename T, int Headdim> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
-template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream, const bool configure);
+template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
@@ -5,6 +5,6 @@
#include "flash_bwd_launch_template.h"
template<>
-void run_mha_bwd_<cutlass::bfloat16_t, 128>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim128<cutlass::bfloat16_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::bfloat16_t, 128>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim128<cutlass::bfloat16_t>(params, stream);
}

@@ -5,6 +5,6 @@
#include "flash_bwd_launch_template.h"
template<>
-void run_mha_bwd_<cutlass::half_t, 128>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim128<cutlass::half_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::half_t, 128>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim128<cutlass::half_t>(params, stream);
}

@@ -5,6 +5,6 @@
#include "flash_bwd_launch_template.h"
template<>
-void run_mha_bwd_<cutlass::bfloat16_t, 160>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim160<cutlass::bfloat16_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::bfloat16_t, 160>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim160<cutlass::bfloat16_t>(params, stream);
}

@@ -5,6 +5,6 @@
#include "flash_bwd_launch_template.h"
template<>
-void run_mha_bwd_<cutlass::half_t, 160>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim160<cutlass::half_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::half_t, 160>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim160<cutlass::half_t>(params, stream);
}

@@ -5,6 +5,6 @@
#include "flash_bwd_launch_template.h"
template<>
-void run_mha_bwd_<cutlass::bfloat16_t, 192>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim192<cutlass::bfloat16_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::bfloat16_t, 192>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim192<cutlass::bfloat16_t>(params, stream);
}

@@ -5,6 +5,6 @@
#include "flash_bwd_launch_template.h"
template<>
-void run_mha_bwd_<cutlass::half_t, 192>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim192<cutlass::half_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::half_t, 192>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim192<cutlass::half_t>(params, stream);
}

@@ -5,6 +5,6 @@
#include "flash_bwd_launch_template.h"
template<>
-void run_mha_bwd_<cutlass::bfloat16_t, 224>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim224<cutlass::bfloat16_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::bfloat16_t, 224>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim224<cutlass::bfloat16_t>(params, stream);
}

@@ -5,6 +5,6 @@
#include "flash_bwd_launch_template.h"
template<>
-void run_mha_bwd_<cutlass::half_t, 224>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim224<cutlass::half_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::half_t, 224>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim224<cutlass::half_t>(params, stream);
}

@@ -5,6 +5,6 @@
#include "flash_bwd_launch_template.h"
template<>
-void run_mha_bwd_<cutlass::bfloat16_t, 256>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim256<cutlass::bfloat16_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::bfloat16_t, 256>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim256<cutlass::bfloat16_t>(params, stream);
}

@@ -5,6 +5,6 @@
#include "flash_bwd_launch_template.h"
template<>
-void run_mha_bwd_<cutlass::half_t, 256>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim256<cutlass::half_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::half_t, 256>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim256<cutlass::half_t>(params, stream);
}

@@ -5,6 +5,6 @@
#include "flash_bwd_launch_template.h"
template<>
-void run_mha_bwd_<cutlass::bfloat16_t, 32>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim32<cutlass::bfloat16_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::bfloat16_t, 32>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim32<cutlass::bfloat16_t>(params, stream);
}

@@ -5,6 +5,6 @@
#include "flash_bwd_launch_template.h"
template<>
-void run_mha_bwd_<cutlass::half_t, 32>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim32<cutlass::half_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::half_t, 32>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim32<cutlass::half_t>(params, stream);
}

@@ -5,6 +5,6 @@
#include "flash_bwd_launch_template.h"
template<>
-void run_mha_bwd_<cutlass::bfloat16_t, 64>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim64<cutlass::bfloat16_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::bfloat16_t, 64>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim64<cutlass::bfloat16_t>(params, stream);
}

@@ -5,6 +5,6 @@
#include "flash_bwd_launch_template.h"
template<>
-void run_mha_bwd_<cutlass::half_t, 64>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim64<cutlass::half_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::half_t, 64>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim64<cutlass::half_t>(params, stream);
}

@@ -5,6 +5,6 @@
#include "flash_bwd_launch_template.h"
template<>
-void run_mha_bwd_<cutlass::bfloat16_t, 96>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim96<cutlass::bfloat16_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::bfloat16_t, 96>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim96<cutlass::bfloat16_t>(params, stream);
}

@@ -5,6 +5,6 @@
#include "flash_bwd_launch_template.h"
template<>
-void run_mha_bwd_<cutlass::half_t, 96>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim96<cutlass::half_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::half_t, 96>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim96<cutlass::half_t>(params, stream);
}
@@ -43,7 +43,7 @@ __global__ void flash_bwd_convert_dkv_kernel(const Flash_bwd_params params) {
}

template<typename Kernel_traits, bool Is_dropout>
-void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
+void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream) {
    const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
    dim3 grid_m(num_m_block, params.b, params.h);
    const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;

@@ -99,13 +99,12 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
}

template<typename Kernel_traits, bool Is_dropout>
-void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    if (configure) return;
-    run_flash_bwd_seqk_parallel<Kernel_traits, Is_dropout>(params, stream, configure);
+void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
+    run_flash_bwd_seqk_parallel<Kernel_traits, Is_dropout>(params, stream);
}

template<typename T>
-void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
+void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 32;
    int device;
    cudaGetDevice(&device);

@@ -118,18 +117,18 @@ void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        if (max_smem_per_block >= 2 * ((3 * 128 + 2 * 128) * Headdim + 2 * 128 * 128)) { // 104 KB
            if constexpr(!Is_dropout) { // We can afford more registers to keep V in registers
-                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
+                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream);
            } else {
-                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
+                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
            }
        } else { // 96 KB
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream);
        }
    });
}
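
Side note, not part of this patch: each run_mha_bwd_hdim* heuristic compares max_smem_per_block against a per-configuration shared-memory estimate, but the query that fills max_smem_per_block sits in lines elided from this view. A minimal sketch of how such a query is commonly written, assuming the opt-in per-block shared-memory attribute is what gets compared:

#include <cuda_runtime.h>

// Hypothetical helper: query the per-block dynamic shared memory budget that
// the dispatch heuristics compare against (e.g. the 104 KB / 144 KB / 176 KB
// thresholds in this file).
inline int query_max_smem_per_block() {
    int device = 0, max_smem_per_block = 0;
    cudaGetDevice(&device);
    cudaDeviceGetAttribute(&max_smem_per_block,
                           cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    return max_smem_per_block;
}

As a check on the inline comment in run_mha_bwd_hdim32 above: with Headdim = 32, the estimate 2 * ((3 * 128 + 2 * 128) * Headdim + 2 * 128 * 128) = 2 * (20480 + 32768) = 106496 bytes, which is exactly 104 KB.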
template<typename T>
-void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
+void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 64;
    int device;
    cudaGetDevice(&device);

@@ -142,39 +141,39 @@ void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
    // printf("max_smem_per_block = %d\n", max_smem_per_block);
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        // Changing AtomLayoutMdQ from 2 to 4 takes the same time
-        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>>(params, stream, configure);
-        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>>(params, stream, configure);
-        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 2, 4, 4, false, false, T>>(params, stream, configure);
-        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream, configure);
+        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>>(params, stream);
+        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>>(params, stream);
+        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 2, 4, 4, false, false, T>>(params, stream);
+        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream);
        // This is slightly faster. We want to split M more so we need fewer registers to store LSE.
        if (max_smem_per_block >= 144 * 1024) {
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
            // This has a lot of register spilling
-            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
+            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream);
        } else {
            // if (params.h == params.h_k) {
-                // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
-                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
-                // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream, configure);
-                // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>, Is_dropout>(params, stream, configure);
+                // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream);
+                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream);
+                // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream);
+                // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>, Is_dropout>(params, stream);
            // } else {
            // }
        }
    });
-    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>>(params, stream, configure);
-    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, 2, 2, 2, true, false, T>>(params, stream, configure);
-    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 4, 1, 4, 1, false, false, T>>(params, stream, configure);
-    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 16, 128, 4, 1, 4, 1, false, false, T>>(params, stream, configure);
+    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>>(params, stream);
+    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, 2, 2, 2, true, false, T>>(params, stream);
+    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 4, 1, 4, 1, false, false, T>>(params, stream);
+    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 16, 128, 4, 1, 4, 1, false, false, T>>(params, stream);
    // M=128, N=64 is quite slow, I think because we need to read/write dQaccum twice as many times
-    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 2, 2, 2, false, T>>(params, stream, configure);
-    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream, configure);
-    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, false, T>>(params, stream, configure);
-    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 4, 4, 2, 4, false, false, T>>(params, stream, configure);
+    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 2, 2, 2, false, T>>(params, stream);
+    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream);
+    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, false, T>>(params, stream);
+    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 4, 4, 2, 4, false, false, T>>(params, stream);
}
template<typename T>
-void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
+void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 96;
    int device;
    cudaGetDevice(&device);

@@ -188,19 +187,19 @@ void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        if (max_smem_per_block >= 116 * 1024) {
            if constexpr(!Is_dropout) { // 92KB
-                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
+                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream);
            } else { // 116 KB
                // This is faster for dropout since we don't have many registers to spare
-                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
+                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream);
            }
        } else {
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream);
        }
    });
}
template<typename T>
-void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
+void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 128;
    int device;
    cudaGetDevice(&device);

@@ -212,29 +211,29 @@ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
    }
    // printf("max_smem_per_block = %d\n", max_smem_per_block);
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 8, 2, 2, 2, false, false, T>>(params, stream, configure);
+        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 8, 2, 2, 2, false, false, T>>(params, stream);
        // This is faster, in the case of sequence-parallel bwd (where we need fewer registers).
        // Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why.
-        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 2, 2, false, false, T>>(params, stream, configure);
+        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 2, 2, false, false, T>>(params, stream);
        if (max_smem_per_block >= 144 * 1024) {
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>, Is_dropout>(params, stream, configure);
-            // run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
-            // run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream, configure);
-            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>, Is_dropout>(params, stream, configure);
-            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream, configure);
-            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream, configure);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>, Is_dropout>(params, stream);
+            // run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
+            // run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream);
+            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>, Is_dropout>(params, stream);
+            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
+            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream);
        } else {
-            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream, configure);
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream, configure);
+            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream);
        }
-        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>>(params, stream, configure);
-        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 4, 4, false, false, T>>(params, stream, configure);
+        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>>(params, stream);
+        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 4, 4, false, false, T>>(params, stream);
    });
}
template<typename T>
-void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
+void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 160;
    int device;
    cudaGetDevice(&device);

@@ -246,15 +245,15 @@ void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
    }
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        if (max_smem_per_block >= 116 * 1024) {
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
        } else {
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream, configure);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream);
        }
    });
}
template<typename T>
-void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
+void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 192;
    int device;
    cudaGetDevice(&device);

@@ -266,23 +265,23 @@ void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
    }
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        if (max_smem_per_block >= 136 * 1024) {
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream, configure);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
        } else {
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, true, T>, Is_dropout>(params, stream, configure);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, true, T>, Is_dropout>(params, stream);
        }
    });
}
template<typename T>
-void run_mha_bwd_hdim224(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
+void run_mha_bwd_hdim224(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 224;
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
    });
}
template<typename T>
-void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
+void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 256;
    int device;
    cudaGetDevice(&device);

@@ -294,9 +293,9 @@ void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
    }
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        if (max_smem_per_block >= 176 * 1024) { // H100
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream, configure);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
        } else { // A100, we don't do double buffering to save smem
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_dropout>(params, stream, configure);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_dropout>(params, stream);
        }
    });
}
@@ -32,8 +32,8 @@ template void run_mha_fwd_splitkv_dispatch<{DTYPE}, {HEAD_DIM}>(Flash_fwd_params
KERNEL_IMPL_TEMPLATE_BWD = """#include "flash_bwd_launch_template.h"

template<>
-void run_mha_bwd_<{DTYPE}, {HEAD_DIM}>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {{
-    run_mha_bwd_hdim{HEAD_DIM}<{DTYPE}>(params, stream, configure);
+void run_mha_bwd_<{DTYPE}, {HEAD_DIM}>(Flash_bwd_params &params, cudaStream_t stream) {{
+    run_mha_bwd_hdim{HEAD_DIM}<{DTYPE}>(params, stream);
}}
"""