Commit 321c57d0 authored by Tri Dao

Set block size of SM75 fwd to 256 if there's no dropout

This speeds up the fwd by 1.5x.
parent f2d8d410
@@ -144,7 +144,7 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, tot
     TORCH_CHECK(head_size == 16 || head_size == 32 || head_size == 64 || head_size == 128);
     // int base_N = head_size == 16 ? 512 : (head_size == 128 ? 128 : 256);
-    int base_N = (head_size == 128 || (is_sm75 && head_size == 64)) ? 128 : 256;
+    int base_N = (head_size == 128 || (is_sm75 && head_size == 64 && is_dropout)) ? 128 : 256;
     // int base_N = 256;
     int seq_len = 512;
     if( max_seq_len <= 128 ) {
@@ -111,8 +111,13 @@ void run_fmha_fp16_sm80(Launch_params<Fused_multihead_attention_fprop_params> &l
                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
                 run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
             } else if (dprops->major == 7 && dprops->minor == 5) {
-                using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
-                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+                if (launch_params.is_dropout) { // Need to use the same block size as backward
+                    using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
+                    run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+                } else {
+                    using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
+                    run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+                }
             }
         }
     } else if (launch_params.params.d == 128) {
@@ -136,8 +141,13 @@ void run_fmha_fp16_sm80(Launch_params<Fused_multihead_attention_fprop_params> &l
     //             using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
     //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
     //         } else if (dprops->major == 7 && dprops->minor == 5) {
-    //             using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
-    //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+    //             if (launch_params.is_dropout) { // Need to use the same block size as backward
+    //                 using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
+    //                 run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+    //             } else {
+    //                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
+    //                 run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+    //             }
     //         }
     //     }
     // }
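For readers skimming the diff, the sketch below condenses the selection logic this commit arrives at: on SM75 with head_size 64, the forward pass now uses a 256-wide block unless dropout is enabled, since dropout forces the forward to use the backward's 128-wide block size. The helper name choose_fwd_base_N is made up for illustration and is not part of the flash-attention API.

// Minimal standalone sketch of the block-size choice after this commit.
// choose_fwd_base_N is a hypothetical helper, not a flash-attention function.
#include <cstdio>

int choose_fwd_base_N(int head_size, bool is_sm75, bool is_dropout) {
    // 128-wide blocks only when head_size == 128, or on SM75 with head_size 64
    // and dropout enabled (forward must then match the backward's block size).
    return (head_size == 128 || (is_sm75 && head_size == 64 && is_dropout)) ? 128 : 256;
}

int main() {
    std::printf("SM75, d=64, no dropout: %d\n", choose_fwd_base_N(64, true, false)); // 256: the faster path
    std::printf("SM75, d=64, dropout:    %d\n", choose_fwd_base_N(64, true, true));  // 128
    return 0;
}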