Commit 9448264d authored by Tri Dao

Remove seqq_parallel backward kernel that's not used

parent 1274ec3e
@@ -29,11 +29,6 @@ __global__ void flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel(Flash_bwd_params pa
     flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K>(params);
 }
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_N, bool Is_even_K>
-__global__ void flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel(Flash_bwd_params params) {
-    flash::compute_dq_dk_dv_seqq_parallel<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_N, Is_even_K>(params);
-}
 template<typename Kernel_traits>
 __global__ void flash_bwd_convert_dq_kernel(Flash_bwd_params params, const int nsplits) {
     flash::convert_dQ<Kernel_traits>(params, nsplits);
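The deleted seqq kernel above was a thin __global__ wrapper that only forwarded to a templated device function in the flash namespace, mirroring the seqk wrapper that remains. A minimal, self-contained sketch of this wrapper pattern; toy_compute and toy_wrapper_kernel are hypothetical stand-ins, not names from the repo:

#include <cuda_runtime.h>

namespace flash {
// Stand-in for a compute_dq_dk_dv_* device function: all the real work
// happens here, specialized by compile-time flags.
template<bool Is_causal>
__device__ void toy_compute(int *out) {
    out[threadIdx.x] = Is_causal ? int(threadIdx.x) : -int(threadIdx.x);
}
}  // namespace flash

// The __global__ entry point stays a one-liner that forwards its template
// parameters to the device-side implementation.
template<bool Is_causal>
__global__ void toy_wrapper_kernel(int *out) {
    flash::toy_compute<Is_causal>(out);
}

int main() {
    int *out;
    cudaMalloc(&out, 32 * sizeof(int));
    toy_wrapper_kernel<true><<<1, 32>>>(out);
    cudaDeviceSynchronize();
    cudaFree(out);
    return 0;
}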
@@ -100,48 +95,6 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
     C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
-template<typename Kernel_traits, bool Is_dropout>
-void run_flash_bwd_seqq_parallel(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
-    dim3 grid_n(num_n_block, params.b, params.h_k);
-    flash_bwd_clear_dkvaccum_kernel<Kernel_traits><<<grid_n, Kernel_traits::kNThreads, 0, stream>>>(params);
-    C10_CUDA_KERNEL_LAUNCH_CHECK();
-    const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
-    dim3 grid_m(num_m_block, params.b, params.h);
-    // We also use is_even_N to set Unpadded in the BlockInfo constructor, so we need to check
-    // for cu_seqlens_k as well.
-    const bool is_even_N = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0;
-    const bool is_even_K = params.d == Kernel_traits::kHeadDim;
-    constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize1rowblock;
-    // printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv);
-    BOOL_SWITCH(params.is_causal, Is_causal, [&] {
-        BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
-            BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
-                BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
-                    // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
-                    auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, Has_alibi, IsEvenNConst && IsEvenKConst, IsEvenKConst>;
-                    // auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, false, false, IsEvenNConst, IsEvenKConst>;
-                    if (smem_size_dq_dk_dv >= 48 * 1024) {
-                        C10_CUDA_CHECK(cudaFuncSetAttribute(
-                            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
-                    }
-                    kernel<<<grid_m, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
-                    C10_CUDA_KERNEL_LAUNCH_CHECK();
-                });
-            });
-        });
-    });
-    auto kernel_dkv = &flash_bwd_convert_dkv_kernel<Kernel_traits>;
-    if (Kernel_traits::kSmemKVSize >= 48 * 1024) {
-        C10_CUDA_CHECK(cudaFuncSetAttribute(
-            kernel_dkv, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemKVSize));
-    }
-    kernel_dkv<<<grid_n, Kernel_traits::kNThreads, Kernel_traits::kSmemKVSize, stream>>>(params);
-    C10_CUDA_KERNEL_LAUNCH_CHECK();
-}
 template<typename Kernel_traits, bool Is_dropout>
 void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
     if (configure) return;
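The launcher removed above turns runtime flags into template parameters through nested BOOL_SWITCH calls, so each combination (causal x even-N x even-K x alibi) compiles to its own fully specialized kernel. A minimal sketch of how such a macro can be implemented; the repo's actual BOOL_SWITCH (in its static_switch header) may differ in detail:

#include <cstdio>

// Evaluate a runtime condition once, then invoke the continuation with a
// matching constexpr bool in scope, usable as a template argument.
#define BOOL_SWITCH(COND, CONST_NAME, ...)      \
    [&] {                                       \
        if (COND) {                             \
            constexpr bool CONST_NAME = true;   \
            return __VA_ARGS__();               \
        } else {                                \
            constexpr bool CONST_NAME = false;  \
            return __VA_ARGS__();               \
        }                                       \
    }()

// Stand-in for taking the address of a templated kernel and launching it.
template<bool Is_causal, bool Is_even_K>
void launch() {
    printf("Is_causal=%d Is_even_K=%d\n", int(Is_causal), int(Is_even_K));
}

int main() {
    bool is_causal = true, is_even_k = false;  // known only at runtime
    BOOL_SWITCH(is_causal, Is_causal, [&] {
        BOOL_SWITCH(is_even_k, Is_even_K, [&] {
            launch<Is_causal, Is_even_K>();
        });
    });
    return 0;
}

The cost of this pattern is binary size and compile time, which is why the comment in the removed code folds the even-MN flag into IsEvenKConst to cut the number of instantiations.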
@@ -202,7 +155,6 @@ void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream, const boo
         // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream, configure);
         // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>, Is_dropout>(params, stream, configure);
         // } else {
-        // run_flash_bwd_seqq_parallel<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream, configure);
         // }
         }
     });
@@ -231,7 +183,6 @@ void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream, const boo
     }
     // printf("max_smem_per_block = %d\n", max_smem_per_block);
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        // if (params.h == params.h_k) {
         if (max_smem_per_block >= 116 * 1024) {
             if constexpr(!Is_dropout) { // 92KB
                 run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
@@ -242,9 +193,6 @@ void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream, const boo
         } else {
             run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
         }
-        // } else {
-        // run_flash_bwd_seqq_parallel<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 4, 4, false, false, T>>(params, stream, configure);
-        // }
     });
 }
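The hdim96 launcher picks kernel traits by comparing max_smem_per_block against 116 KB, since the opt-in shared-memory limit varies by GPU generation. A hedged sketch of how that limit can be queried with the CUDA runtime (the surrounding code may obtain it differently, e.g. from a cached device property):

#include <cstdio>
#include <cuda_runtime.h>

int main() {
    int device = 0, max_smem_per_block = 0;
    cudaGetDevice(&device);
    // Opt-in limit: the largest dynamic shared-memory size a block may
    // request after cudaFuncSetAttribute(..., MaxDynamicSharedMemorySize, ...).
    cudaDeviceGetAttribute(&max_smem_per_block,
                           cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    printf("opt-in smem per block: %d bytes\n", max_smem_per_block);
    // Roughly 163 KB on A100 but only about 99 KB on many consumer parts,
    // so a 116 KB threshold selects the larger tile only where it fits.
    if (max_smem_per_block >= 116 * 1024) {
        printf("116 KB configuration fits on this device\n");
    }
    return 0;
}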
@@ -261,7 +209,6 @@ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream, const bo
     }
     // printf("max_smem_per_block = %d\n", max_smem_per_block);
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        // if (params.h == params.h_k) {
         // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 8, 2, 2, 2, false, false, T>>(params, stream, configure);
         // This is faster, in the case of sequence-parallel bwd (where we need fewer registers).
         // Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why.
@@ -270,7 +217,6 @@ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream, const bo
         run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>, Is_dropout>(params, stream, configure);
         // run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
         // run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream, configure);
-        // run_flash_bwd_seqq_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
         // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>, Is_dropout>(params, stream, configure);
         // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream, configure);
         // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream, configure);
@@ -281,9 +227,6 @@ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream, const bo
         // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>>(params, stream, configure);
         // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 4, 4, false, false, T>>(params, stream, configure);
-        // } else {
-        // run_flash_bwd_seqq_parallel<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 4, 4, false, false, T>>(params, stream, configure);
-        // }
     });
 }
...
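Every launch path in this file that needs more than 48 KB of dynamic shared memory first raises the per-kernel limit, as in the if (smem_size_dq_dk_dv >= 48 * 1024) block removed above, because kernels are capped at 48 KB by default. A self-contained sketch of that opt-in pattern with a toy kernel (assumes a Volta-or-newer GPU so 64 KB is available):

#include <cstdio>
#include <cuda_runtime.h>

__global__ void toy_smem_kernel(float *out) {
    extern __shared__ float smem[];  // sized at launch, not at compile time
    smem[threadIdx.x] = float(threadIdx.x);
    __syncthreads();
    out[threadIdx.x] = smem[blockDim.x - 1 - threadIdx.x];
}

int main() {
    const int smem_size = 64 * 1024;  // above the 48 KB default cap
    if (smem_size >= 48 * 1024) {
        // Without this opt-in, a launch requesting 64 KB would fail.
        cudaFuncSetAttribute(toy_smem_kernel,
                             cudaFuncAttributeMaxDynamicSharedMemorySize,
                             smem_size);
    }
    float *out;
    cudaMalloc(&out, 256 * sizeof(float));
    toy_smem_kernel<<<1, 256, smem_size>>>(out);
    cudaDeviceSynchronize();
    printf("launch: %s\n", cudaGetErrorString(cudaGetLastError()));
    cudaFree(out);
    return 0;
}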
@@ -288,9 +288,6 @@ struct Flash_bwd_kernel_traits : public Base {
         + (!Is_V_in_regs
             ? kSmemKVSize + kSmemdSSize + kSmemPSize
             : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize));
-    static constexpr int kSmemSize1rowblock = kSmemQdOSize / 3 * 2 + kSmemKVSize / 2 * 3
-        + kSmemdSSize + kSmemPSize;
     static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
     static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
...
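The removed kSmemSize1rowblock was one more shared-memory budget assembled from per-tile sizes, just as the surviving kSmemSize is. To make this kind of arithmetic concrete, a hedged sketch with hypothetical tile shapes; the real kSmem* constants are derived from the traits' cute layouts, not from this formula:

#include <cstdio>

int main() {
    // Hypothetical backward-pass tiles: head dim 128, fp16/bf16 elements.
    constexpr int kBlockM = 64, kBlockN = 128, kHeadDim = 128;
    constexpr int kBytes = 2;                                   // per element
    constexpr int kTileQ  = kBlockM * kHeadDim * kBytes;        // 16 KB
    constexpr int kTileKV = 2 * kBlockN * kHeadDim * kBytes;    // K + V: 64 KB
    constexpr int kTileP  = kBlockM * kBlockN * kBytes;         // 16 KB
    // Q-shaped tiles (Q, dO, dQ) plus K/V plus two M x N tiles (P, dS):
    constexpr int kSmemTotal = 3 * kTileQ + kTileKV + 2 * kTileP;  // 144 KB
    static_assert(kSmemTotal >= 48 * 1024, "past the default cap: opt-in needed");
    printf("hypothetical smem budget: %d KB\n", kSmemTotal / 1024);
    return 0;
}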