gaoqiong / flash-attention / Commits / d11341fd

Commit d11341fd
Authored Oct 30, 2022 by Tri Dao
Fix Triton fwd to support seqlen not multiples of 128
Parent: b0c0db81
Showing 2 changed files with 31 additions and 23 deletions (+31, -23):

flash_attn/flash_attn_triton.py   +9  -6
tests/test_flash_attn.py          +22 -17
flash_attn/flash_attn_triton.py

@@ -4,6 +4,7 @@ https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention
 Changes:
 - Support both causal and non-causal attention.
+- Support arbitrary seqlens (not just multiples of 128) in the forward pass.
 - Speed up the forward pass a bit (and only store the LSE instead of m and l).
 - Make the backward for d=128 much faster by reducing register spilling.
 - Add the option to parallelize the backward pass across seqlen_k, to deal with the case of
@@ -30,7 +31,7 @@ import triton.language as tl
 @triton.heuristics(
     {
         "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
-        "EVEN_N": lambda args: args["seqlen_k"] % (args["BLOCK_N"]) == 0,
+        "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
     }
 )
 @triton.jit
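
For readers unfamiliar with @triton.heuristics: each lambda is evaluated against the kernel's launch arguments and its result is handed to the kernel as a compile-time constant, so the EVEN_M / EVEN_N branches cost nothing at run time and Triton caches one specialization per flag combination. A minimal sketch of what the two flags compute (plain Python, not repository code; the 128 block sizes are only illustrative defaults):

def even_flags(seqlen_q, seqlen_k, BLOCK_M=128, BLOCK_N=128):
    # EVEN_M: no row masking needed when loading/storing q and the output
    # EVEN_N: no column masking needed when loading k and v
    return {
        "EVEN_M": seqlen_q % BLOCK_M == 0,
        "EVEN_N": seqlen_k % BLOCK_N == 0,
    }

print(even_flags(1024, 1024))  # {'EVEN_M': True, 'EVEN_N': True}
print(even_flags(113, 203))    # {'EVEN_M': False, 'EVEN_N': False}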
@@ -42,7 +43,7 @@ def _fwd_kernel(
     stride_kb, stride_kh, stride_kn,
     stride_vb, stride_vh, stride_vn,
     stride_ob, stride_oh, stride_om,
-    nheads, seqlen_q, seqlen_k,
+    nheads, seqlen_q, seqlen_k, seqlen_q_rounded,
     CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,
     IS_CAUSAL: tl.constexpr,
     BLOCK_HEADDIM: tl.constexpr,
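
The new seqlen_q_rounded argument gives the kernel a block-aligned row count for indexing the per-row scratch buffers, independent of the true seqlen_q. How it is computed is not visible in this excerpt; a plausible host-side sketch (an assumption, shown only for illustration, with 128 as the assumed block multiple):

import math

def round_up_seqlen(seqlen_q, multiple=128):
    # Round the query length up to the next multiple of the block size so that
    # per-(batch, head) scratch rows have a block-aligned stride.
    return math.ceil(seqlen_q / multiple) * multiple

print(round_up_seqlen(113))   # 128
print(round_up_seqlen(1024))  # 1024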
@@ -68,12 +69,14 @@ def _fwd_kernel(
     k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
     v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
     # initialize pointer to m and l
-    t_ptrs = TMP + off_hb * seqlen_q + offs_m
+    t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
     lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
     m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
     acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
     # load q: it will stay in SRAM throughout
-    if EVEN_M:
+    # [2022-10-30] TD: Idk why but in the case of EVEN_M=True and EVEN_N=False, if we just call
+    # tl.load(q_ptrs), we get the wrong output! Could be a bug in the compiler?
+    if EVEN_M & EVEN_N:
         q = tl.load(q_ptrs)
     else:
         q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
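
The masked tl.load is what makes arbitrary sequence lengths safe: query rows at or past seqlen_q are filled with 0.0 instead of being read out of bounds. Below is a self-contained Triton kernel using the same boundary pattern, offered only as a sketch (it is not part of this repository and needs a CUDA GPU with Triton installed):

import torch
import triton
import triton.language as tl

@triton.jit
def masked_copy_kernel(X, Y, n_elements, BLOCK: tl.constexpr):
    # Each program handles one block; the mask guards the final partial block,
    # mirroring the out-of-bounds handling added to the attention forward pass.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(X + offs, mask=mask, other=0.0)
    tl.store(Y + offs, x, mask=mask)

x = torch.randn(113, device='cuda')  # deliberately not a multiple of the block size
y = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 128),)
masked_copy_kernel[grid](x, y, x.numel(), BLOCK=128)
assert torch.equal(x, y)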
@@ -130,7 +133,7 @@ def _fwd_kernel(
     start_m = tl.program_id(0)
     offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
     # write back l and m
-    lse_ptrs = Lse + off_hb * seqlen_q + offs_m
+    lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
     tl.store(lse_ptrs, lse_i)
     # initialize pointers to output
     offs_n = tl.arange(0, BLOCK_HEADDIM)
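
Since each (batch, head) slice of the LSE scratch is now strided by seqlen_q_rounded, the host side has to allocate TMP and Lse with the rounded length as well; that allocation lies outside this excerpt. A hedged sketch of what it might look like (hypothetical shapes, for illustration only):

import torch

batch, nheads, seqlen_q_rounded = 4, 16, 128  # rounded as sketched above
# One float32 log-sum-exp value (and one scratch value) per padded query row, so that
# off_hb * seqlen_q_rounded + offs_m never indexes past the end of the buffer.
lse = torch.empty((batch, nheads, seqlen_q_rounded), dtype=torch.float32, device='cuda')
tmp = torch.empty((batch, nheads, seqlen_q_rounded), dtype=torch.float32, device='cuda')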
@@ -373,7 +376,7 @@ def _flash_attn_forward(q, k, v, causal=False, softmax_scale=None):
         k.stride(0), k.stride(2), k.stride(1),
         v.stride(0), v.stride(2), v.stride(1),
         o.stride(0), o.stride(2), o.stride(1),
-        nheads, seqlen_q, seqlen_k,
+        nheads, seqlen_q, seqlen_k, seqlen_q_rounded,
         seqlen_q // 32, seqlen_k // 32,  # key for triton cache (limit number of compilations)
         # Can't use kwargs here because triton autotune expects key to be args, not kwargs
         # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
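
The seqlen_q // 32 and seqlen_k // 32 arguments feed CACHE_KEY_SEQLEN_Q / CACHE_KEY_SEQLEN_K, the values Triton keys its compilation cache on, so sequence lengths are bucketed 32 at a time instead of forcing a recompile for every distinct length. A quick illustration of the bucketing (plain Python, not repository code):

# Lengths in the same 32-wide bucket reuse the same compiled kernel;
# only crossing a bucket boundary can trigger a new compilation.
for seqlen in (113, 120, 127, 128, 200, 203):
    print(seqlen, '-> cache key', seqlen // 32)
# 113 -> 3, 120 -> 3, 127 -> 3, 128 -> 4, 200 -> 6, 203 -> 6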
tests/test_flash_attn.py

@@ -855,15 +855,17 @@ def test_flash_attn_multigpu():
 from flash_attn.flash_attn_triton import flash_attn_func
 
+@pytest.mark.skipif(not is_sm80, reason='Triton version is only tested on A100')
 @pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
-# @pytest.mark.parametrize('dtype', [torch.bfloat16])
+# @pytest.mark.parametrize('dtype', [torch.float16])
 @pytest.mark.parametrize('causal', [False, True])
-# @pytest.mark.parametrize('causal', [True])
+# @pytest.mark.parametrize('causal', [False])
 @pytest.mark.parametrize('d', [64, 128])
 # @pytest.mark.parametrize('d', [64])
 # @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
-@pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 512), (512, 256), (1024, 1024), (2048, 2048)])
+@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (256, 512), (512, 256), (1024, 1024), (2048, 2048)])
-# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(512, 256)])
+# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(127, 256)])
 def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
     if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
         pytest.skip()  # Reference implementation OOM
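
The new (113, 203), (128, 217), and (113, 211) shapes are the point of the commit: neither length is a multiple of 128, so they exercise the masked-load path in the forward kernel. Assuming a standard pytest setup on an A100, this test alone could be run with something like: pytest tests/test_flash_attn.py -k test_flash_attn_triton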
@@ -885,6 +887,8 @@ def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
     print(f'Pytorch max diff: {(output_pt - output_ref).abs().max().item()}')
     print(f'Pytorch mean diff: {(output_pt - output_ref).abs().mean().item()}')
+    run_bwd = (seqlen_q % 128 == 0) and (seqlen_k % 128 == 0)
+    if run_bwd:
         g = torch.randn_like(output)
         dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)
         dq_ref, dk_ref, dv_ref, = torch.autograd.grad(output_ref, (q, k, v), g)
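
The run_bwd guard reflects the scope of the commit: only the forward pass now supports arbitrary lengths, so the gradient computation runs only when both sequence lengths are multiples of 128 and the newly added odd-sized shapes skip it (the g / dq / dq_ref lines above move under the if; the original view hides whitespace-only changes). A tiny sketch of the gate (plain Python, with 128 taken from the test itself):

def should_run_bwd(seqlen_q, seqlen_k, block=128):
    # The Triton backward is only exercised for block-aligned lengths.
    return seqlen_q % block == 0 and seqlen_k % block == 0

assert should_run_bwd(1024, 1024)
assert not should_run_bwd(113, 203)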
@@ -901,6 +905,7 @@ def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
     assert (output - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item()
     # assert torch.allclose(output, output_ref, rtol=rtol, atol=atol)
+    if run_bwd:
         assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
         assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
         assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
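
These assertions follow the error-budget pattern used throughout the test file: the fused kernel's maximum absolute deviation from the reference output (output_ref, dq_ref, ...) may be at most twice that of the plain PyTorch implementation (output_pt, dq_pt, ...). Restated as a standalone helper (hypothetical, not part of the test suite):

import torch

def within_error_budget(out, out_ref, out_pt):
    # out:     fused Triton kernel output
    # out_ref: reference output (presumably computed in higher precision)
    # out_pt:  plain PyTorch implementation in the same dtype as out
    # The kernel is allowed at most 2x PyTorch's max deviation from the reference.
    return (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()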