OpenDAS / Megatron-LM · Commits

Commit 254e8815
authored Apr 06, 2023 by Jimmy Zhang

refactor flash attention

parent f1a50a3c
Showing 1 changed file with 20 additions and 28 deletions

megatron/model/transformer.py  +20 −28
@@ -361,43 +361,35 @@ class FlashSelfAttention(torch.nn.Module):
         ---------
             q, k, v: The tensor containing the query, key, and value. (B, S, H, D)
         """
-        assert q.dtype in [torch.float16, torch.bfloat16]
-        assert q.is_cuda
+        assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q,k,v)))
+        assert all((i.is_cuda for i in (q,k,v)))
 
         batch_size, seqlen_q = q.shape[0], q.shape[1]
+        seqlen_k = k.shape[1]
 
-        if self.training:
-            # during training q,k,v all have same seqlen
-            q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]]
-            cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32,
-                                      device=q.device)
-            output = flash_attn_unpadded_func(
-                q, k, v, cu_seqlens, cu_seqlens, seqlen_q, seqlen_q,
-                self.dropout_p if self.training else 0.0,
-                softmax_scale=self.softmax_scale, causal=self.causal
-            )
-        else:
-            # during inference q seqlen is different than k,v seqlen
-            assert k.dtype in [torch.float16, torch.bfloat16]
-            assert k.is_cuda
-            seqlen_k = k.shape[1]
-            # only on first autoregressive step q,k,v have same seqlen
-            is_causal = seqlen_q == seqlen_k
-            q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]]
-            cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32,
-                                        device=q.device)
-            cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32,
-                                        device=q.device)
-            output = flash_attn_unpadded_func(
-                q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
-                0.0,
-                softmax_scale=self.softmax_scale, causal=is_causal
-            )
+        q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]]
+        cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32,
+                                    device=q.device)
+
+        if self.training:
+            # during training q,k,v always have same seqlen
+            assert seqlen_k == seqlen_q
+
+            is_causal = self.causal
+            cu_seqlens_k = cu_seqlens_q
+        else:
+            # turn off FA causal mask after first inference autoregressive iteration
+            # only on first autoregressive step do q,k,v have same seqlen
+            is_causal = seqlen_q == seqlen_k
+            cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32,
+                                        device=q.device)
+
+        output = flash_attn_unpadded_func(
+            q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
+            0.0,
+            softmax_scale=self.softmax_scale, causal=is_causal
+        )
 
         output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
         return output
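The refactored path flattens q, k, v from (B, S, H, D) to (B*S, H, D) and hands flash_attn_unpadded_func the batch boundaries as cumulative sequence lengths, so a single call now serves both the training and inference branches. A minimal sketch of that bookkeeping (not part of the commit; it assumes only torch and einops, with made-up sizes):

import torch
from einops import rearrange

batch_size, seqlen_q, heads, dim = 3, 4, 2, 8     # illustrative sizes
q = torch.randn(batch_size, seqlen_q, heads, dim)

# flatten (B, S, ...) -> (B*S, ...), as the diff does for q, k, v
q_flat = rearrange(q, 'b s ... -> (b s) ...')
print(q_flat.shape)      # torch.Size([12, 2, 8])

# cu_seqlens marks where each sequence starts in the flattened batch
cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q,
                            dtype=torch.int32)
print(cu_seqlens_q)      # tensor([ 0,  4,  8, 12], dtype=torch.int32)

Because every sequence in the batch has the same length here, cu_seqlens is a plain arithmetic progression; the unpadded API also accepts ragged batches, in which case the offsets would differ per sequence.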
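The is_causal = seqlen_q == seqlen_k line carries the inference trick described in the comments: only the first autoregressive step pushes the whole prompt through attention, so only that step needs the causal mask; on later steps a single new query attends to every cached key, and masking changes nothing (flash-attn's own causal flag could not be relied on there, since at the time of this commit it aligned the mask to the top-left corner when seqlen_q != seqlen_k). A small equivalence sketch under assumed shapes (illustration only, not code from the repository):

import torch

def attn(q, k, v, causal):
    # plain softmax attention; the causal mask is aligned so that the
    # last query row always sees every key (bottom-right alignment)
    scores = q @ k.transpose(-1, -2) / k.shape[-1] ** 0.5
    if causal:
        s_q, s_k = scores.shape[-2], scores.shape[-1]
        keep = torch.ones(s_q, s_k, dtype=torch.bool).tril(s_k - s_q)
        scores = scores.masked_fill(~keep, float('-inf'))
    return scores.softmax(dim=-1) @ v

q = torch.randn(1, 1, 16)   # one new query token (a later decode step)
k = torch.randn(1, 7, 16)   # cached keys
v = torch.randn(1, 7, 16)
print(torch.allclose(attn(q, k, v, causal=True),
                     attn(q, k, v, causal=False)))   # True

With a single trailing query the masked and unmasked results coincide, which is why the kernel's causal flag can simply be switched off after the first iteration.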