gaoqiong / flash-attention · Commits · 908a5b22

Commit 908a5b22, authored Nov 07, 2022 by Tri Dao

    Set num_warps=4 for headdim=64 in Triton fw (h/t Michael Benesty)

Parent: 74797571

Showing 1 changed file with 14 additions and 13 deletions:

    flash_attn/flash_attn_triton.py  (+14, -13)
flash_attn/flash_attn_triton.py  (view file @ 908a5b22)

@@ -44,14 +44,15 @@ import triton
 import triton.language as tl
 
-@triton.autotune(
-    configs=[
-        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=8, num_stages=1),
-        # This config has a race condition when EVEN_M == False, disabling it for now.
-        # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
-    ],
-    key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM']
-)
+# Disabling autotune for now, set num_warps=4 if headdim=64 and num_warps=8 if headdim=128
+# @triton.autotune(
+#     configs=[
+#         triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=1),
+#         # This config has a race condition when EVEN_M == False, disabling it for now.
+#         # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
+#     ],
+#     key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM']
+# )
 @triton.heuristics(
     {
         "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
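With the autotuner commented out, the launch configuration is now chosen directly from the head dimension (see the next two hunks). A minimal sketch of that policy as a standalone helper; the function name pick_launch_config is illustrative and not part of the repo:

# Sketch (not in the repo) of the policy the new comment describes: fixed 128x128 tiles,
# warp count picked from headdim, matching "num_warps=4 if headdim=64, 8 if headdim=128".
def pick_launch_config(headdim: int) -> dict:
    block = 128                              # BLOCK_M = BLOCK_N = 128, as hardcoded below
    num_warps = 4 if headdim <= 64 else 8    # the tweak this commit introduces
    return {"BLOCK_M": block, "BLOCK_N": block, "num_warps": num_warps, "num_stages": 1}

assert pick_launch_config(64)["num_warps"] == 4
assert pick_launch_config(128)["num_warps"] == 8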
@@ -617,8 +618,8 @@ def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
     o = torch.empty_like(q)
     BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
-    # BLOCK = 128
-    # num_warps = 4 if d <= 64 else 8
+    BLOCK = 128
+    num_warps = 4 if d <= 64 else 8
     grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
     _fwd_kernel[grid](
         q, k, v, bias, o,
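For readers unfamiliar with Triton launch grids: the lambda above sizes the grid as one kernel program per BLOCK_M-sized chunk of query rows, for each (batch, head) pair. A small sketch of that arithmetic with made-up shapes (not taken from the commit):

# Launch-grid arithmetic mirroring the lambda above; math.ceil plays the role of triton.cdiv.
import math

batch, nheads, seqlen_q, BLOCK_M = 2, 16, 1000, 128   # illustrative values
grid = (math.ceil(seqlen_q / BLOCK_M), batch * nheads)
print(grid)               # (8, 32)
print(grid[0] * grid[1])  # 256 program instances in total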
@@ -634,9 +635,9 @@ def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
         # Can't use kwargs here because triton autotune expects key to be args, not kwargs
         # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
         bias_type, causal, BLOCK_HEADDIM,
-        # BLOCK_M=BLOCK, BLOCK_N=BLOCK,
-        # num_warps=num_warps,
-        # num_stages=1,
+        BLOCK_M=BLOCK, BLOCK_N=BLOCK,
+        num_warps=num_warps,
+        num_stages=1,
     )
     return o, lse, softmax_scale  # softmax_scale could have been updated
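With BLOCK_M/BLOCK_N, num_warps and num_stages now passed explicitly at the launch site, the caller-facing behavior is unchanged. A hedged usage sketch of the function this commit modifies; it assumes the package import path matching the file path above and a (batch, seqlen, nheads, headdim) fp16 layout, and it needs a CUDA GPU with Triton installed to actually run:

# Usage sketch for _flash_attn_forward, whose signature and return value appear in the
# hunks above: _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None)
# returns (o, lse, softmax_scale). Shapes and dtype below are illustrative assumptions.
import torch
from flash_attn.flash_attn_triton import _flash_attn_forward

batch, seqlen, nheads, headdim = 2, 1024, 16, 64   # headdim <= 64 -> num_warps=4 per this commit
q, k, v = (torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
           for _ in range(3))
o, lse, softmax_scale = _flash_attn_forward(q, k, v, bias=None, causal=True)
print(o.shape, lse.shape, softmax_scale)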