"doc/vscode:/vscode.git/clone" did not exist on "c98c5dd1b2f9998b28e3d7b177af3d4bafcb36a4"
Commit bc2c2102 authored by Tri Dao

Don't nest BOOL_SWITCH to work around gcc 7 bug

parent d1fc80a3
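
Note: the BOOL_SWITCH macro itself is not part of this diff. Below is a minimal, hypothetical sketch of a BOOL_SWITCH-style macro and of the flattening this commit applies, reconstructed only from the usage visible in the hunks; the repository's actual macro definition and kernel names may differ. The macro's job is to turn a runtime bool into a constexpr bool so the lambda body can pass it as a template argument; the commit keeps a single BOOL_SWITCH level and handles the remaining boolean with an explicit branch, since gcc 7 fails to compile the nested form (see the kokkos-kernels and flash-attention issues linked in the forward-pass file).

#include <cstdio>

// Hypothetical BOOL_SWITCH-style macro, sketched from its usage in the diff;
// the repository's real definition may differ. It maps a runtime bool onto a
// constexpr bool that the lambda body can use as a template argument.
#define BOOL_SWITCH(COND, CONST_NAME, ...)      \
    [&] {                                       \
        if (COND) {                             \
            constexpr bool CONST_NAME = true;   \
            return __VA_ARGS__();               \
        } else {                                \
            constexpr bool CONST_NAME = false;  \
            return __VA_ARGS__();               \
        }                                       \
    }()

// Placeholder for selecting a kernel instantiation (not a real FMHA kernel).
template <bool IsDropout, bool IsCausal>
void launch_kernel() {
    std::printf("IsDropout=%d IsCausal=%d\n", IsDropout, IsCausal);
}

void dispatch(bool is_dropout, bool is_causal) {
    // Nested form that the commit removes (gcc 7 fails to compile this
    // pattern, per the issues linked in the diff):
    //
    //   BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
    //       BOOL_SWITCH(is_causal, IsCausalConst, [&] {
    //           launch_kernel<IsDropoutConst, IsCausalConst>();
    //       });
    //   });
    //
    // Flattened form matching the commit: one BOOL_SWITCH, and the second
    // boolean handled with an explicit branch over template arguments.
    BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
        if (is_causal) {
            launch_kernel<IsDropoutConst, true>();
        } else {
            launch_kernel<IsDropoutConst, false>();
        }
    });
}

int main() {
    dispatch(true, false);
    dispatch(false, true);
}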
@@ -27,16 +27,19 @@ void run_fmha_dgrad_fp16_sm80_loop_(const FMHA_dgrad_params &params, cudaStream_
     // printf("blocksize_c = %d, WARPS_N = %d, Smem size = %d\n", blocksize_c, Kernel_traits::Cta_tile_p::WARPS_N, smem_size_dq_dk_dv);
     bool is_dropout = params.p_dropout < 1.f;  // params.p_dropout is the probability of "keeping"
 
+    // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
     BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
-        BOOL_SWITCH(params.is_causal, IsCausalConst, [&] {
-            auto kernel = &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<
-                Kernel_traits, IsDropoutConst, IsCausalConst>;
+        auto kernel = params.is_causal
+            ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true>
+            : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false>;
         if (params.seqlen_k == blocksize_c) {
-            kernel = &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<
-                Kernel_traits, IsDropoutConst, IsCausalConst, /*loop_steps=*/1>;
+            kernel = params.is_causal
+                ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true, /*loop_steps=*/1>
+                : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false, /*loop_steps=*/1>;
         } else if (params.seqlen_k == blocksize_c * 2) {
-            kernel = &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<
-                Kernel_traits, IsDropoutConst, IsCausalConst, /*loop_steps=*/2>;
+            kernel = params.is_causal
+                ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true, /*loop_steps=*/2>
+                : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false, /*loop_steps=*/2>;
         }
         if( smem_size_dq_dk_dv >= 48 * 1024 ) {
             FMHA_CHECK_CUDA(cudaFuncSetAttribute(
@@ -46,7 +49,6 @@ void run_fmha_dgrad_fp16_sm80_loop_(const FMHA_dgrad_params &params, cudaStream_
         kernel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
         FMHA_CHECK_CUDA(cudaPeekAtLastError());
-        });
     });
 }
 
 void run_fmha_dgrad_fp16_sm80(const FMHA_dgrad_params &params, cudaStream_t stream) {
......
@@ -59,11 +59,17 @@ void run_fmha_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params,
     const int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>()
         + (loop_steps > 1 ? smem_size_softmax_lse : 0);
 
+    // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
+    // https://github.com/kokkos/kokkos-kernels/issues/349
+    // https://github.com/HazyResearch/flash-attention/issues/21
     BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, [&] {
-        BOOL_SWITCH(launch_params.params.is_causal, IsCausalConst, [&] {
-            BOOL_SWITCH(launch_params.return_softmax, ReturnSoftmaxConst, [&] {
-                auto kernel = &fmha_fprop_fp16_sm80_loop_kernel<
-                    Kernel_traits, IsDropoutConst, IsCausalConst, ReturnSoftmaxConst>;
+        auto kernel = launch_params.params.is_causal
+            ? (launch_params.return_softmax
+               ? &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, IsDropoutConst, true, true>
+               : &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, IsDropoutConst, true, false>)
+            : (launch_params.return_softmax
+               ? &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, IsDropoutConst, false, true>
+               : &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, IsDropoutConst, false, false>);
         if( smem_size >= 48 * 1024 ) {
             FMHA_CHECK_CUDA(cudaFuncSetAttribute(
                 kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
@@ -73,8 +79,6 @@ void run_fmha_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params,
             launch_params.params);
         FMHA_CHECK_CUDA(cudaPeekAtLastError());
-        });
-        });
     });
 }
 
 void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params,
......
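
A side note on why the ternary selection in both files type-checks: every instantiation of the kernel template has the same function type (the template parameters only change the body, not the signature), so the conditional expression has one common function-pointer type that can initialize a single kernel variable, which the launch then dispatches through at runtime. A small, self-contained sketch with placeholder names (Params and fake_kernel are illustrative, not the real FMHA types or kernels):

#include <cstdio>
#include <type_traits>

struct Params { int seqlen_k; };

// Placeholder kernel template; all instantiations share the type void(Params).
template <bool IsDropout, bool IsCausal, int LoopSteps = -1>
void fake_kernel(Params) {}

void pick_and_run(const Params &params, bool is_causal) {
    // Ternary over template arguments, mirroring the pattern in the diff.
    auto kernel = is_causal ? &fake_kernel<true, true> : &fake_kernel<true, false>;
    static_assert(std::is_same<decltype(kernel), void (*)(Params)>::value,
                  "all instantiations share one function-pointer type");
    // Runtime re-selection with a different template argument, analogous to
    // the seqlen_k special cases in the backward-pass file (128 is arbitrary).
    if (params.seqlen_k == 128) {
        kernel = is_causal ? &fake_kernel<true, true, 1> : &fake_kernel<true, false, 1>;
    }
    kernel(params);  // host-side call here; the real code launches with <<<...>>>
}

int main() {
    pick_and_run(Params{128}, true);
    pick_and_run(Params{256}, false);
}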