gaoqiong / flash-attention · Commits

Commit 7fc39832, authored Oct 21, 2022 by Tri Dao (parent a44f48df)

Use block_size=128 for headdim=128 on SM80

Previously we were using block_size=256.
Showing 1 changed file with 11 additions and 14 deletions.
csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu (+11, -14)
```diff
@@ -115,6 +115,7 @@ void run_fmha_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params,
             /*max_splits=*/std::min(30, (launch_params.params.seqlen_q + M - 1) / M));
     }
     // printf("smem_size = %d\n", smem_size);
     dim3 grid(launch_params.params.b, launch_params.params.h, launch_params.params.num_splits);
     kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(launch_params.params);
```
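For context, the `max_splits` cap in the hunk above combines a hard limit of 30 with the standard ceiling-division idiom: `(seqlen_q + M - 1) / M` is the number of M-row tiles needed to cover `seqlen_q`. The sketch below is a minimal, hypothetical illustration of this launch pattern, not the actual fmha kernel; the toy kernel, `M = 16`, and the batch/head sizes are all assumed values for demonstration.

```cuda
#include <algorithm>
#include <cstdio>
#include <cuda_runtime.h>

// Toy stand-in for the fmha kernel: one block per (batch, head, split)
// coordinate, mirroring dim3 grid(b, h, num_splits) in the hunk above.
__global__ void toy_kernel(int seqlen_q, int rows_per_split) {
    int split = blockIdx.z;                  // which seqlen_q slice this block owns
    int row_begin = split * rows_per_split;  // first query row of this slice
    if (row_begin >= seqlen_q) return;       // guard against an empty tail split
    // A real kernel would process rows [row_begin, row_begin + rows_per_split).
}

int main() {
    int b = 2, h = 12, seqlen_q = 2048;  // illustrative problem sizes
    const int M = 16;                    // rows handled per split step (assumed value)

    // Ceiling division: ceil(2048 / 16) = 128 candidate splits, capped at 30.
    int num_splits = std::min(30, (seqlen_q + M - 1) / M);
    int rows_per_split = (seqlen_q + num_splits - 1) / num_splits;

    dim3 grid(b, h, num_splits);  // 3D grid: batch x heads x seqlen_q splits
    toy_kernel<<<grid, 128>>>(seqlen_q, rows_per_split);
    cudaDeviceSynchronize();
    printf("num_splits = %d, rows_per_split = %d\n", num_splits, rows_per_split);
    return 0;
}
```

With `seqlen_q = 2048` and `M = 16` this yields `min(30, 128) = 30` splits, so the grid's z-dimension, and hence the exposed parallelism, grows with sequence length up to the cap. This parallelism over `seqlen_q` is what the commit comment below refers to.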
```diff
@@ -156,20 +157,16 @@ void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params,
             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
         }
     } else if (launch_params.params.d == 128) {
-        if( launch_params.params.seqlen_k == 128 ) {
-            using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
-            run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
-        } else {
-            if (dprops->major == 8 && dprops->minor == 0) {
-                // TD [2022-06-05] Keep K in registers to reduce register spilling
-                // Gives about 6% speedup compared to using block size 128.
-                using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u, elem_type>;
-                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
-            } else {  // Need to use the same block size as backward
-                using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
-                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
-            }
-        }
+        // TD [2022-10-21]: Previously for SM80 we used block size 256 and kept K in shared memory
+        // to reduce register spilling. However, that increases the smem usage from ~41KB to ~105KB,
+        // reducing occupancy (only 1 kernel can be scheduled per SM instead of 2). This strategy gives
+        // some speedup (6-10%) for large batch size, but slows things down for small batch size.
+        // Now that we have better parallelism (over seqlen_q), block size 128 is faster for small
+        // batch size and only slightly slower (~3%) on large batch size.
+        // For causal=True, block size 128 seems always faster (for small & large batch size).
+        // So we're just gonna use block size 128 for simplicity.
+        using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
+        run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
     }
     // if (launch_params.params.d == 64) {
     //     // using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
```
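The occupancy arithmetic in the new comment can be double-checked with the CUDA occupancy API. The sketch below is illustrative, not the project's code: the kernel is a dummy, the two footprints (~41KB for block size 128, ~105KB for block size 256) are taken from the comment above, and the 128-thread block assumes the "4" in `Kernel_traits` is a warp count. On SM80 (A100) an SM offers up to 164KB of shared memory, so two 105KB CTAs cannot be co-resident (2 × 105KB > 164KB), while 41KB CTAs can.

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Dummy stand-in for the fmha forward kernel (illustration only; the real
// kernel's register and thread usage also constrain the result).
__global__ void dummy_fmha_kernel() {}

int main() {
    // Shared-memory footprints quoted in the commit comment above.
    const size_t smem_block128 = 41 * 1024;   // ~41KB with block size 128
    const size_t smem_block256 = 105 * 1024;  // ~105KB with block size 256 (K kept in smem)

    // Any kernel requesting more than 48KB of dynamic smem must opt in first.
    cudaFuncSetAttribute(dummy_fmha_kernel,
                         cudaFuncAttributeMaxDynamicSharedMemorySize,
                         (int)smem_block256);

    // 4 warps = 128 threads, matching the "4" in the traits above (assumed meaning).
    int threads = 128;
    int blocks128 = 0, blocks256 = 0;
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &blocks128, dummy_fmha_kernel, threads, smem_block128);
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &blocks256, dummy_fmha_kernel, threads, smem_block256);

    // Expect 1 resident CTA/SM at 105KB (2 * 105KB > 164KB), and several at
    // 41KB by smem alone; the real kernel lands at 2 once registers count.
    printf("resident CTAs/SM: %d at 41KB vs %d at 105KB\n", blocks128, blocks256);
    return 0;
}
```

This is the tradeoff the commit resolves: the 256 block size wins only when batch × heads is large enough to saturate the GPU despite the halved occupancy, and the new seqlen_q parallelism removes most of that advantage.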