gaoqiong / flash-attention

Commit 321c57d0
Authored Jun 04, 2022 by Tri Dao
Set block size of SM75 fwd to 256 if there's no dropout
This speeds up the fwd by 1.5x.
Parent: f2d8d410
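For orientation, here is a minimal standalone sketch of the selection rule this commit introduces (illustrative only, not code from the repository; fwd_block_size is a made-up helper, while head_size, is_sm75, and is_dropout mirror the variables in the fmha_api.cpp diff below):

// Illustrative sketch of the new base_N rule; not part of flash-attention.
#include <cstdio>

int fwd_block_size(int head_size, bool is_sm75, bool is_dropout) {
    // head_size 128 always uses 128-wide blocks. On SM75 with head_size 64,
    // a 128-wide block is only needed when dropout is enabled (to match the
    // backward pass); otherwise the forward pass can use a 256-wide block,
    // which this commit reports as ~1.5x faster.
    return (head_size == 128 || (is_sm75 && head_size == 64 && is_dropout)) ? 128 : 256;
}

int main() {
    std::printf("%d\n", fwd_block_size(64, /*is_sm75=*/true, /*is_dropout=*/true));   // prints 128
    std::printf("%d\n", fwd_block_size(64, /*is_sm75=*/true, /*is_dropout=*/false));  // prints 256
    return 0;
}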
Showing 2 changed files with 15 additions and 5 deletions (+15 -5)
csrc/flash_attn/fmha_api.cpp (+1 -1)
csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu (+14 -4)
csrc/flash_attn/fmha_api.cpp
@@ -144,7 +144,7 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, tot
     TORCH_CHECK(head_size == 16 || head_size == 32 || head_size == 64 || head_size == 128);
     // int base_N = head_size == 16 ? 512 : (head_size == 128 ? 128 : 256);
-    int base_N = (head_size == 128 || (is_sm75 && head_size == 64)) ? 128 : 256;
+    int base_N = (head_size == 128 || (is_sm75 && head_size == 64 && is_dropout)) ? 128 : 256;
     // int base_N = 256;
     int seq_len = 512;
     if (max_seq_len <= 128) {
csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu
@@ -111,8 +111,13 @@ void run_fmha_fp16_sm80(Launch_params<Fused_multihead_attention_fprop_params> &l
                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
                 run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
             } else if (dprops->major == 7 && dprops->minor == 5) {
-                using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
-                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+                if (launch_params.is_dropout) { // Need to use the same block size as backward
+                    using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
+                    run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+                } else {
+                    using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
+                    run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+                }
             }
         }
     } else if (launch_params.params.d == 128) {
@@ -136,8 +141,13 @@ void run_fmha_fp16_sm80(Launch_params<Fused_multihead_attention_fprop_params> &l
     //             using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
     //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
     //         } else if (dprops->major == 7 && dprops->minor == 5) {
-    //             using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
-    //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+    //             if (launch_params.is_dropout) { // Need to use the same block size as backward
+    //                 using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
+    //                 run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+    //             } else {
+    //                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
+    //                 run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+    //             }
     //         }
     //     }
     // }
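As a side note, the dprops->major / dprops->minor test in the hunks above comes from the CUDA device properties. The following is a small sketch (assumptions: head_size 64 and an illustrative block_n variable; only cudaGetDeviceProperties is a real CUDA runtime call, and the kernel launch itself is omitted) showing how the SM75 check and the new dropout-dependent block size fit together:

// Sketch: detect SM75 at runtime and pick the forward block size.
// Illustrative only; the actual dispatch lives in run_fmha_fp16_sm80 above.
#include <cuda_runtime.h>
#include <cstdio>

int main() {
    cudaDeviceProp prop{};
    if (cudaGetDeviceProperties(&prop, /*device=*/0) != cudaSuccess) {
        std::fprintf(stderr, "no CUDA device found\n");
        return 1;
    }
    const bool is_sm75 = (prop.major == 7 && prop.minor == 5);
    const bool is_dropout = false;  // example configuration, head_size assumed 64
    // After this commit, SM75 without dropout uses the 256-wide forward
    // kernel; with dropout it stays at 128 to match the backward pass.
    const int block_n = (is_sm75 && is_dropout) ? 128 : 256;
    std::printf("sm%d%d -> forward block size %d\n", prop.major, prop.minor, block_n);
    return 0;
}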