Commit bc2c2102 authored by Tri Dao

Don't nest BOOL_SWITCH to work around gcc 7 bug

parent d1fc80a3
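
For context, BOOL_SWITCH is the lambda-based macro this codebase uses to turn a runtime bool into a constexpr template parameter. The sketch below assumes the usual formulation (the exact definition lives in the repo's headers, so treat it as illustrative rather than copied); the nested use at the end is the shape gcc 7 fails to compile, which is what this commit flattens out.

// A BOOL_SWITCH-style macro (assumed formulation, not copied from the repo).
// The body is instantiated once per branch, so CONST_NAME is a compile-time
// constant inside the lambda and can be used as a template argument.
#define BOOL_SWITCH(COND, CONST_NAME, ...)     \
    [&] {                                      \
        if (COND) {                            \
            constexpr bool CONST_NAME = true;  \
            return __VA_ARGS__();              \
        } else {                               \
            constexpr bool CONST_NAME = false; \
            return __VA_ARGS__();              \
        }                                      \
    }()

template <bool IsDropout, bool IsCausal>
void launch_kernel() { /* stand-in for a real kernel launch */ }

void dispatch(bool is_dropout, bool is_causal) {
    // Nested lambdas, each declaring a constexpr local: the pattern that
    // triggers the gcc 7 bug referenced in the comments below.
    BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
        BOOL_SWITCH(is_causal, IsCausalConst, [&] {
            launch_kernel<IsDropoutConst, IsCausalConst>();
        });
    });
}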
@@ -27,16 +27,19 @@ void run_fmha_dgrad_fp16_sm80_loop_(const FMHA_dgrad_params &params, cudaStream_
     // printf("blocksize_c = %d, WARPS_N = %d, Smem size = %d\n", blocksize_c, Kernel_traits::Cta_tile_p::WARPS_N, smem_size_dq_dk_dv);
     bool is_dropout = params.p_dropout < 1.f;  // params.p_dropout is the probability of "keeping"
+    // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
     BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
-        BOOL_SWITCH(params.is_causal, IsCausalConst, [&] {
-        auto kernel = &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<
-            Kernel_traits, IsDropoutConst, IsCausalConst>;
+        auto kernel = params.is_causal
+            ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true>
+            : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false>;
         if (params.seqlen_k == blocksize_c) {
-            kernel = &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<
-                Kernel_traits, IsDropoutConst, IsCausalConst, /*loop_steps=*/1>;
+            kernel = params.is_causal
+                ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true, /*loop_steps=*/1>
+                : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false, /*loop_steps=*/1>;
         } else if (params.seqlen_k == blocksize_c * 2) {
-            kernel = &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<
-                Kernel_traits, IsDropoutConst, IsCausalConst, /*loop_steps=*/2>;
+            kernel = params.is_causal
+                ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true, /*loop_steps=*/2>
+                : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false, /*loop_steps=*/2>;
         }
         if( smem_size_dq_dk_dv >= 48 * 1024 ) {
             FMHA_CHECK_CUDA(cudaFuncSetAttribute(
@@ -46,7 +49,6 @@ void run_fmha_dgrad_fp16_sm80_loop_(const FMHA_dgrad_params &params, cudaStream_
         kernel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
         FMHA_CHECK_CUDA(cudaPeekAtLastError());
-        });
     });
 }
void run_fmha_dgrad_fp16_sm80(const FMHA_dgrad_params &params, cudaStream_t stream) {
......
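
The hunk above keeps the outer BOOL_SWITCH and unrolls the inner one by hand: every instantiation of the kernel template has the same function type, so an ordinary conditional expression can pick between the true and false instantiations at runtime. A minimal self-contained sketch of that workaround, with illustrative stand-in names rather than the repo's:

#include <cstdio>

struct Params { bool is_causal; int seqlen_k; };

// Stand-in for fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel (hypothetical).
template <bool IsDropout, bool IsCausal, int LoopSteps = 0>
void kernel_stub(const Params &) {
    std::printf("dropout=%d causal=%d loop_steps=%d\n", IsDropout, IsCausal, LoopSteps);
}

template <bool IsDropoutConst>   // supplied by the remaining outer BOOL_SWITCH
void dispatch_causal(const Params &params) {
    // Both instantiations have type void (*)(const Params &), so a ternary
    // on the runtime flag replaces the inner BOOL_SWITCH entirely.
    auto kernel = params.is_causal
        ? &kernel_stub<IsDropoutConst, true>
        : &kernel_stub<IsDropoutConst, false>;
    kernel(params);
}

int main() {
    dispatch_causal<true>({/*is_causal=*/true, /*seqlen_k=*/128});
}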
@@ -59,11 +59,17 @@ void run_fmha_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params,
     const int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>()
         + (loop_steps > 1 ? smem_size_softmax_lse : 0);
+    // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
+    // https://github.com/kokkos/kokkos-kernels/issues/349
+    // https://github.com/HazyResearch/flash-attention/issues/21
     BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, [&] {
-        BOOL_SWITCH(launch_params.params.is_causal, IsCausalConst, [&] {
-        BOOL_SWITCH(launch_params.return_softmax, ReturnSoftmaxConst, [&] {
-        auto kernel = &fmha_fprop_fp16_sm80_loop_kernel<
-            Kernel_traits, IsDropoutConst, IsCausalConst, ReturnSoftmaxConst>;
+        auto kernel = launch_params.params.is_causal
+            ? (launch_params.return_softmax
+               ? &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, IsDropoutConst, true, true>
+               : &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, IsDropoutConst, true, false>)
+            : (launch_params.return_softmax
+               ? &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, IsDropoutConst, false, true>
+               : &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, IsDropoutConst, false, false>);
         if( smem_size >= 48 * 1024 ) {
             FMHA_CHECK_CUDA(cudaFuncSetAttribute(
                 kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
@@ -73,8 +79,6 @@ void run_fmha_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params,
             launch_params.params);
         FMHA_CHECK_CUDA(cudaPeekAtLastError());
-        });
-        });
     });
 }
void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params,
......
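
The forward path has two flags left after the dropout switch, so the diff above selects among four instantiations with nested ternaries. A sketch of that shape under the same assumptions (illustrative names; a 2x2 function-pointer table indexed by the two flags would be an equivalent design):

#include <cstdio>

// Stand-in for fmha_fprop_fp16_sm80_loop_kernel (hypothetical).
template <bool IsDropout, bool IsCausal, bool ReturnSoftmax>
void fprop_stub() {
    std::printf("dropout=%d causal=%d return_softmax=%d\n",
                IsDropout, IsCausal, ReturnSoftmax);
}

template <bool IsDropoutConst>   // supplied by the outer BOOL_SWITCH
void dispatch_fprop(bool is_causal, bool return_softmax) {
    // Nested ternaries cover the 2x2 grid of remaining flag combinations;
    // all four instantiations share the type void (*)().
    auto kernel = is_causal
        ? (return_softmax ? &fprop_stub<IsDropoutConst, true, true>
                          : &fprop_stub<IsDropoutConst, true, false>)
        : (return_softmax ? &fprop_stub<IsDropoutConst, false, true>
                          : &fprop_stub<IsDropoutConst, false, false>);
    kernel();
}

int main() {
    dispatch_fprop<false>(/*is_causal=*/true, /*return_softmax=*/false);
}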