Commit 321c57d0 authored by Tri Dao

Set block size of SM75 fwd to 256 if there's no dropout

This speeds up the fwd by 1.5x.
parent f2d8d410
@@ -144,7 +144,7 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, tot
     TORCH_CHECK(head_size == 16 || head_size == 32 || head_size == 64 || head_size == 128);
     // int base_N = head_size == 16 ? 512 : (head_size == 128 ? 128 : 256);
-    int base_N = (head_size == 128 || (is_sm75 && head_size == 64)) ? 128 : 256;
+    int base_N = (head_size == 128 || (is_sm75 && head_size == 64 && is_dropout)) ? 128 : 256;
     // int base_N = 256;
     int seq_len = 512;
     if( max_seq_len <= 128 ) {
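For readers skimming the diff, here is a minimal standalone sketch of the block-size rule this hunk implements. `pick_base_N` is a hypothetical helper name introduced for illustration; the condition itself is copied from the `mha_fwd` change above, and `head_size`, `is_sm75`, and `is_dropout` mirror the variables there.

```cpp
// Minimal sketch (not from the repo) of the base_N selection in mha_fwd.
// head_size 128 always gets a 128-wide block; SM75 with head_size 64 now
// only drops to 128 when dropout is enabled, and uses 256 otherwise.
#include <cstdio>

int pick_base_N(int head_size, bool is_sm75, bool is_dropout) {
    return (head_size == 128 || (is_sm75 && head_size == 64 && is_dropout)) ? 128 : 256;
}

int main() {
    // SM75, head_size 64: 256 without dropout (the new fast path), 128 with it.
    printf("%d\n", pick_base_N(64, /*is_sm75=*/true, /*is_dropout=*/false)); // 256
    printf("%d\n", pick_base_N(64, /*is_sm75=*/true, /*is_dropout=*/true));  // 128
    return 0;
}
```

With dropout off, SM75 at head_size 64 now takes the 256 path, which is the 1.5x forward speedup the commit message cites.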
@@ -111,8 +111,13 @@ void run_fmha_fp16_sm80(Launch_params<Fused_multihead_attention_fprop_params> &l
                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
                 run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
             } else if (dprops->major == 7 && dprops->minor == 5) {
-                using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
-                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+                if (launch_params.is_dropout) { // Need to use the same block size as backward
+                    using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
+                    run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+                } else {
+                    using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
+                    run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+                }
             }
         }
     } else if (launch_params.params.d == 128) {
@@ -136,8 +141,13 @@ void run_fmha_fp16_sm80(Launch_params<Fused_multihead_attention_fprop_params> &l
 //                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
 //                 run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
 //             } else if (dprops->major == 7 && dprops->minor == 5) {
-//                 using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
-//                 run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+//                 if (launch_params.is_dropout) { // Need to use the same block size as backward
+//                     using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
+//                     run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+//                 } else {
+//                     using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
+//                     run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+//                 }
+//             }
 //         }
 //     }
 // }