wxj / Megatron-LM

Commit f1a50a3c
Authored Mar 03, 2023 by Jimmy Zhang

Flash Attention inference fix

Parent: 717c5274
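For context (a reading of the diff below, not part of the commit message): during autoregressive inference with a key/value cache, the query carries only the newly generated token(s) while the keys and values cover the whole context, so q and k/v have different sequence lengths. The previous code built a single cu_seqlens tensor from q's length and passed it for both q and k, which is only valid when the two lengths match. A minimal sketch of the quantities the fix introduces, assuming an illustrative batch of 2 and a decode step with 8 cached tokens:

```python
import torch

# Illustrative shapes for one autoregressive decode step (assumed, not from the commit):
batch_size = 2
seqlen_q = 1   # one new query token per sequence
seqlen_k = 8   # tokens already held in the key/value cache

# flash_attn_unpadded_func takes cumulative sequence-length boundaries per batch entry,
# so q and k each need their own cu_seqlens when their lengths differ.
cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32)
cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32)

print(cu_seqlens_q.tolist())  # [0, 1, 2]
print(cu_seqlens_k.tolist())  # [0, 8, 16]

# The causal mask is only meaningful when q and k cover the same positions
# (the first autoregressive step); afterwards the new token may attend to the whole cache.
is_causal = seqlen_q == seqlen_k
print(is_causal)              # False
```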
Showing 1 changed file with 36 additions and 10 deletions.
megatron/model/transformer.py (+36, -10)
@@ -363,16 +363,42 @@ class FlashSelfAttention(torch.nn.Module):
"""
assert
q
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
]
assert
q
.
is_cuda
batch_size
,
seqlen
=
q
.
shape
[
0
],
q
.
shape
[
1
]
q
,
k
,
v
=
[
rearrange
(
x
,
'b s ... -> (b s) ...'
)
for
x
in
[
q
,
k
,
v
]]
max_s
=
seqlen
cu_seqlens
=
torch
.
arange
(
0
,
(
batch_size
+
1
)
*
seqlen
,
step
=
seqlen
,
dtype
=
torch
.
int32
,
device
=
q
.
device
)
output
=
flash_attn_unpadded_func
(
q
,
k
,
v
,
cu_seqlens
,
cu_seqlens
,
max_s
,
max_s
,
self
.
dropout_p
if
self
.
training
else
0.0
,
softmax_scale
=
self
.
softmax_scale
,
causal
=
self
.
causal
)
batch_size
,
seqlen_q
=
q
.
shape
[
0
],
q
.
shape
[
1
]
if
self
.
training
:
# during training q,k,v all have same seqlen
q
,
k
,
v
=
[
rearrange
(
x
,
'b s ... -> (b s) ...'
)
for
x
in
[
q
,
k
,
v
]]
cu_seqlens
=
torch
.
arange
(
0
,
(
batch_size
+
1
)
*
seqlen_q
,
step
=
seqlen_q
,
dtype
=
torch
.
int32
,
device
=
q
.
device
)
output
=
flash_attn_unpadded_func
(
q
,
k
,
v
,
cu_seqlens
,
cu_seqlens
,
seqlen_q
,
seqlen_q
,
self
.
dropout_p
if
self
.
training
else
0.0
,
softmax_scale
=
self
.
softmax_scale
,
causal
=
self
.
causal
)
else
:
# during inference q seqlen is different than k,v seqlen
assert
k
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
]
assert
k
.
is_cuda
# turn off FA causal mask after first inference autoregressive iteration
# only on first autoregressive step do q,k,v have same seqlen
seqlen_k
=
k
.
shape
[
1
]
is_causal
=
seqlen_q
==
seqlen_k
q
,
k
,
v
=
[
rearrange
(
x
,
'b s ... -> (b s) ...'
)
for
x
in
[
q
,
k
,
v
]]
cu_seqlens_q
=
torch
.
arange
(
0
,
(
batch_size
+
1
)
*
seqlen_q
,
step
=
seqlen_q
,
dtype
=
torch
.
int32
,
device
=
q
.
device
)
cu_seqlens_k
=
torch
.
arange
(
0
,
(
batch_size
+
1
)
*
seqlen_k
,
step
=
seqlen_k
,
dtype
=
torch
.
int32
,
device
=
q
.
device
)
output
=
flash_attn_unpadded_func
(
q
,
k
,
v
,
cu_seqlens_q
,
cu_seqlens_k
,
seqlen_q
,
seqlen_k
,
0.0
,
softmax_scale
=
self
.
softmax_scale
,
causal
=
is_causal
)
output
=
rearrange
(
output
,
'(b s) ... -> b s ...'
,
b
=
batch_size
)
return
output
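As a side note on the packing that both branches rely on: before calling the unpadded kernel, q, k, and v are flattened from (batch, seq, ...) into a single (total_tokens, ...) layout, and the output is unflattened afterwards. A small sketch, assuming einops and torch are available and using made-up head dimensions:

```python
import torch
from einops import rearrange

# Made-up shapes for illustration: (batch, seq, heads, head_dim)
batch_size, seqlen_q, num_heads, head_dim = 2, 1, 4, 16
q = torch.randn(batch_size, seqlen_q, num_heads, head_dim)

# Pack batch and sequence into one "total tokens" axis, as the forward pass does for q, k, v.
q_packed = rearrange(q, 'b s ... -> (b s) ...')
print(q_packed.shape)    # torch.Size([2, 4, 16])

# After flash_attn_unpadded_func, the result is unpacked back to (batch, seq, ...).
q_unpacked = rearrange(q_packed, '(b s) ... -> b s ...', b=batch_size)
print(q_unpacked.shape)  # torch.Size([2, 1, 4, 16])
```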