Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
6b5f271c
Commit
6b5f271c
authored
Dec 14, 2022
by
Tri Dao
Browse files
[Triton] Avoid einops repeat by using Tensor.expand
parent
88c4e5db
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
2 additions
and
12 deletions
+2
-12
flash_attn/flash_attn_triton.py
flash_attn/flash_attn_triton.py
+2
-12
No files found.
flash_attn/flash_attn_triton.py
View file @
6b5f271c
...
...
@@ -38,8 +38,6 @@ import math
import
torch
from
einops
import
rearrange
,
repeat
import
triton
import
triton.language
as
tl
...
...
@@ -605,11 +603,7 @@ def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
else
:
raise
RuntimeError
(
'Last 2 dimensions of bias must be (1, seqlen_k)'
' or (seqlen_q, seqlen_k)'
)
if
bias
.
shape
[:
2
]
==
(
1
,
nheads
):
bias
=
repeat
(
bias
,
'1 h ... -> b h ...'
,
b
=
batch
)
elif
bias
.
shape
[:
2
]
==
(
batch
,
1
):
bias
=
repeat
(
bias
,
'b 1 ... -> b h ...'
,
h
=
nheads
)
assert
bias
.
shape
[:
2
]
==
(
batch
,
nheads
),
'First 2 dimensions of bias must be broadcastible to (batch, nheads)'
bias
=
bias
.
expand
(
batch
,
nheads
,
seqlen_q
,
seqlen_k
)
bias_strides
=
(
bias
.
stride
(
0
),
bias
.
stride
(
1
),
bias
.
stride
(
2
))
if
has_bias
else
(
0
,
0
,
0
)
seqlen_q_rounded
=
math
.
ceil
(
seqlen_q
/
128
)
*
128
...
...
@@ -684,11 +678,7 @@ def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=Fals
else
:
raise
RuntimeError
(
'Last 2 dimensions of bias must be (1, seqlen_k)'
' or (seqlen_q, seqlen_k)'
)
if
bias
.
shape
[:
2
]
==
(
1
,
nheads
):
bias
=
repeat
(
bias
,
'1 h ... -> b h ...'
,
b
=
batch
)
elif
bias
.
shape
[:
2
]
==
(
batch
,
1
):
bias
=
repeat
(
bias
,
'b 1 ... -> b h ...'
,
h
=
nheads
)
assert
bias
.
shape
[:
2
]
==
(
batch
,
nheads
),
'First 2 dimensions of bias must be broadcastible to (batch, nheads)'
bias
=
bias
.
expand
(
batch
,
nheads
,
seqlen_q
,
seqlen_k
)
bias_strides
=
(
bias
.
stride
(
0
),
bias
.
stride
(
1
),
bias
.
stride
(
2
))
if
has_bias
else
(
0
,
0
,
0
)
# BLOCK_M = 128
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment