gaoqiong / flash-attention · Commits · 65b4064b

Commit 65b4064b, authored Dec 31, 2022 by Tri Dao

[FusedDense] Kick off input all_gather before weight dtype conversion

Parent: 71befc19
Showing 2 changed files with 38 additions and 24 deletions.
flash_attn/modules/mha.py       +3  -1
flash_attn/ops/fused_dense.py   +35 -23
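Both files pick up the same scheduling change: when tensor parallelism is active, the all_gather of the input is now kicked off asynchronously before the autocast dtype conversion of the weights, so the communication overlaps with the conversion instead of queueing behind it. A minimal sketch of the pattern, assuming an initialized torch.distributed process group on a NCCL backend (the helper name and the bfloat16 target dtype are illustrative, not taken from this commit):

    import torch
    import torch.distributed as dist

    def gather_input_then_convert_weight(x, weight, process_group):
        world_size = dist.get_world_size(process_group)
        total_x = torch.empty(world_size * x.shape[0], *x.shape[1:],
                              dtype=x.dtype, device=x.device)
        # 1. Launch the all_gather first; NCCL runs it on its own CUDA stream.
        handle = dist.all_gather_into_tensor(total_x, x.contiguous(),
                                             group=process_group, async_op=True)
        # 2. Independent work (here, the weight dtype conversion) overlaps with
        #    the in-flight communication instead of delaying it.
        weight = weight.to(dtype=torch.bfloat16)
        # 3. Block only when the gathered input is actually needed.
        handle.wait()
        return total_x, weight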
flash_attn/modules/mha.py (View file @ 65b4064b):
@@ -472,13 +472,15 @@ class ParallelMHA(nn.Module):
     """
     def __init__(self, embed_dim, num_heads, process_group, bias=True, dropout=0.0,
-                 softmax_scale=None, causal=False, rotary_emb_dim=0, rotary_emb_scale_base=0,
+                 softmax_scale=None, causal=False, layer_idx=None, rotary_emb_dim=0,
+                 rotary_emb_scale_base=0,
                  use_flash_attn=False, checkpointing=False, device=None, dtype=None) -> None:
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
         self.process_group = process_group
         self.embed_dim = embed_dim
         self.causal = causal
+        self.layer_idx = layer_idx
         self.rotary_emb_dim = rotary_emb_dim
         self.use_flash_attn = use_flash_attn
         self.checkpointing = checkpointing
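The mha.py change just threads a new layer_idx argument through ParallelMHA and stores it on the module; it defaults to None, so existing callers are unaffected. A hypothetical construction sketch (the process group, sizes, and dtype are placeholders, not values from this commit):

    import torch
    import torch.distributed as dist
    from flash_attn.modules.mha import ParallelMHA

    dist.init_process_group(backend='nccl')  # assumed: launched via torchrun
    pg = dist.group.WORLD
    layers = [
        ParallelMHA(embed_dim=1024, num_heads=16, process_group=pg,
                    layer_idx=i, use_flash_attn=True,
                    device='cuda', dtype=torch.bfloat16)
        for i in range(12)
    ]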
flash_attn/ops/fused_dense.py (View file @ 65b4064b):
@@ -32,24 +32,28 @@ class FusedDenseFunc(torch.autograd.Function):
         ctx.process_group = process_group
         if torch.is_autocast_enabled():
-            dtype = torch.get_autocast_gpu_dtype()
-            x, weight = [a.to(dtype=dtype) for a in [x, weight]]
-            bias = bias.to(dtype=dtype) if bias is not None else None
+            x = x.to(dtype=torch.get_autocast_gpu_dtype())
         x = x.contiguous()
+        if process_group is not None:
+            # We want to kick off the all_gather early, before weight dtype conversion
+            total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
+        else:
+            total_x = x
+        if torch.is_autocast_enabled():
+            weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
+            bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None
         weight = weight.contiguous()
+        if process_group is not None:
+            handle_x.wait()
+        batch_shape = total_x.shape[:-1]
+        batch_dim = batch_shape.numel()
+        assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
+        output = F.linear(total_x, weight, bias)
         if ctx.compute_weight_gradient:
             ctx.save_for_backward(x, weight)
         else:
             ctx.save_for_backward(weight)
-        batch_shape, n = x.shape[:-1], x.shape[-1]
-        batch_dim = batch_shape.numel()
-        assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
-        if process_group is not None:
-            total_x, _ = all_gather_raw(x, process_group)
-        else:
-            total_x = x
-        output = F.linear(total_x, weight, bias)
         return output if not return_residual else (output, x)

     @staticmethod
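all_gather_raw itself is not part of this diff. Judging only from its call sites above (it takes x and process_group, accepts async_op, and returns the gathered tensor plus a work handle that handle_x.wait() later blocks on), a plausible sketch is:

    import torch
    import torch.distributed as dist

    def all_gather_raw(x, process_group, async_op=False):
        # Gather along the first (sequence-parallel) dimension into one tensor.
        world_size = dist.get_world_size(process_group)
        total_x = torch.empty(world_size * x.shape[0], *x.shape[1:],
                              dtype=x.dtype, device=x.device)
        handle = dist.all_gather_into_tensor(total_x, x.contiguous(),
                                             group=process_group, async_op=async_op)
        # With async_op=False the handle is None and the result is ready on return.
        return total_x, handle

Launching the gather with async_op=True is what lets the weight conversion in the new code run while the communication is in flight; the matching handle_x.wait() is the synchronization point before total_x is consumed.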
@@ -188,32 +192,42 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
         2: recompute gelu_in and gelu_out in the bwd
         """
         assert -1 <= heuristic <= 4
-        if torch.is_autocast_enabled():
-            dtype = torch.get_autocast_gpu_dtype()
-            x, weight1, weight2 = [a.to(dtype=dtype) for a in [x, weight1, weight2]]
-            bias1 = bias1.to(dtype=dtype) if bias1 is not None else None
-            bias2 = bias2.to(dtype=dtype) if bias2 is not None else None
         if not save_pre_act:
             checkpoint_lvl = 2
         assert checkpoint_lvl in [0, 1, 2]
         ctx.return_residual = return_residual
         ctx.process_group = process_group
+        ctx.checkpoint_lvl = checkpoint_lvl
+        ctx.heuristic = heuristic
+
+        if torch.is_autocast_enabled():
+            x = x.to(dtype=torch.get_autocast_gpu_dtype())
         x = x.contiguous()
+        if process_group is not None:
+            # We want to kick off the all_gather early, before weight dtype conversion
+            total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
+        else:
+            total_x = x
+
+        if torch.is_autocast_enabled():
+            dtype = torch.get_autocast_gpu_dtype()
+            weight1, weight2 = [a.to(dtype=dtype) for a in [weight1, weight2]]
+            bias1 = bias1.to(dtype=dtype) if bias1 is not None else None
+            bias2 = bias2.to(dtype=dtype) if bias2 is not None else None
         weight1 = weight1.contiguous()
         bias1 = bias1.contiguous() if bias1 is not None else None
         weight2 = weight2.contiguous()
         bias2 = bias2.contiguous() if bias2 is not None else None
         if process_group is not None:
-            total_x, _ = all_gather_raw(x, process_group)
-        else:
-            total_x = x
+            handle_x.wait()
         batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
         batch_dim = batch_shape.numel()
         assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
         if heuristic == -1:
             gelu_in = F.linear(total_x, weight1, bias1)
             output1 = F.gelu(gelu_in, approximate='tanh')
-            # gelu_in = F.linear(total_x.reshape(batch_dim, n), weight1) # This is before adding bias1
+            # This is before adding bias1
+            # gelu_in = F.linear(total_x.reshape(batch_dim, n), weight1)
             # with torch.jit.fuser('fuser2'):
             #     output1 = bias_gelu(gelu_in, bias1)
         else:
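For reference, approximate='tanh' in the forward path above selects the GELU approximation 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))). A quick self-contained check:

    import math
    import torch
    import torch.nn.functional as F

    x = torch.randn(8)
    ref = 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi)
                                      * (x + 0.044715 * x.pow(3))))
    assert torch.allclose(F.gelu(x, approximate='tanh'), ref, atol=1e-6)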
@@ -223,8 +237,6 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
             if save_pre_act:
                 gelu_in = rest[0]
         output2 = F.linear(output1, weight2, bias2)
-        ctx.checkpoint_lvl = checkpoint_lvl
-        ctx.heuristic = heuristic
         if checkpoint_lvl == 0:
             ctx.save_for_backward(x, weight1, weight2, gelu_in, output1)
         elif checkpoint_lvl == 1:
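This last hunk only moves the ctx.checkpoint_lvl / ctx.heuristic bookkeeping earlier in forward. The levels themselves trade memory for recompute: per the docstring fragment in the previous hunk, level 2 recomputes gelu_in (and the gelu output) in the backward pass instead of saving them. A self-contained sketch of that recompute idea, assuming 2D inputs and a non-None bias; it is illustrative, not the repo's implementation:

    import torch
    import torch.nn.functional as F

    class LinearGeluRecompute(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, weight, bias):
            ctx.save_for_backward(x, weight, bias)  # no intermediates saved
            return F.gelu(F.linear(x, weight, bias), approximate='tanh')

        @staticmethod
        def backward(ctx, grad_out):
            x, weight, bias = ctx.saved_tensors
            with torch.enable_grad():
                # Re-run the forward under grad mode, then backprop through it.
                x_, w_, b_ = (t.detach().requires_grad_() for t in (x, weight, bias))
                out = F.gelu(F.linear(x_, w_, b_), approximate='tanh')  # recomputed
            return torch.autograd.grad(out, (x_, w_, b_), grad_out)

    x = torch.randn(4, 8, requires_grad=True)
    w = torch.randn(16, 8, requires_grad=True)
    b = torch.randn(16, requires_grad=True)
    LinearGeluRecompute.apply(x, w, b).sum().backward()  # gradients via recompute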