gaoqiong / flash-attention · Commits

Commit 75e334d4, authored Jul 22, 2023 by Tri Dao
Parent: b3177dfa

[MLP] Add ParallelMLP
Showing 3 changed files with 41 additions and 6 deletions (+41 -6):

flash_attn/models/gpt.py     +9  -6
flash_attn/modules/block.py  +2  -0
flash_attn/modules/mlp.py    +30 -0
flash_attn/models/gpt.py

```diff
@@ -18,7 +18,7 @@ from einops import rearrange
 from flash_attn.ops.activations import sqrelu_fwd
 from flash_attn.modules.mha import MHA, ParallelMHA
-from flash_attn.modules.mlp import Mlp, GatedMlp, FusedMLP, ParallelFusedMLP
+from flash_attn.modules.mlp import Mlp, GatedMlp, ParallelMLP, FusedMLP, ParallelFusedMLP
 from flash_attn.modules.block import Block, ParallelBlock
 from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
 from flash_attn.utils.distributed import sync_shared_params, all_gather_raw
@@ -112,10 +112,8 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
         assert config.activation_function == 'sqrelu', ('fused_dense_sqrelu_dense only '
                                                         'supports approximate activation_function sqrelu')
     assert not (fused_dense_sqrelu_dense and fused_mlp)
-    if process_group is not None:
-        assert fused_mlp, 'Tensor Parallel is only implemented for FusedMLP'
     if not fused_mlp and not fused_dense_sqrelu_dense:
-        assert config.activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx', 'relu',
+        assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'gelu_approx', 'relu',
                                               'sqrelu', 'glu', 'swiglu', 'geglu']
         if config.activation_function in ['glu', 'swiglu', 'geglu']:
             activation = (F.sigmoid if config.activation_function == 'glu'
@@ -132,8 +130,13 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
             approximate = ('tanh' if config.activation_function
                            in ['gelu_new', 'gelu_fast', 'gelu_approx'] else 'none')
             activation = partial(F.gelu, approximate=approximate)
-        mlp_cls = partial(Mlp, hidden_features=config.n_inner, activation=activation,
-                          bias1=mlp_fc1_bias, bias2=mlp_fc2_bias, **factory_kwargs)
+        mlp_cls = Mlp if process_group is None else ParallelMLP
+        parallel_kwargs = ({'process_group': process_group,
+                            'sequence_parallel': getattr(config, 'sequence_parallel', True)}
+                           if process_group is not None else {})
+        mlp_cls = partial(mlp_cls, hidden_features=config.n_inner, activation=activation,
+                          bias1=mlp_fc1_bias, bias2=mlp_fc2_bias,
+                          **parallel_kwargs, **factory_kwargs)
     else:
         mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
         # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
```
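The net effect of this `create_mlp_cls` change is that tensor parallelism no longer forces `fused_mlp`: when a `process_group` is supplied, the plain MLP path now selects `ParallelMLP` and forwards the process group along with the config's `sequence_parallel` flag. Below is a minimal sketch of that dispatch outside the full function, using a hypothetical `SimpleNamespace` config in place of the real GPT2 config; everything not shown in the diff above is an illustrative assumption.

```python
from functools import partial
from types import SimpleNamespace

import torch.nn.functional as F

from flash_attn.modules.mlp import Mlp, ParallelMLP  # assumes flash-attn is installed

# Hypothetical stand-in for the GPT2 config object that create_mlp_cls receives.
config = SimpleNamespace(n_inner=4096, sequence_parallel=True)
process_group = None  # or a torch.distributed ProcessGroup when using tensor parallelism

# Same selection logic as the new code path (biases hard-coded to True for brevity).
mlp_cls = Mlp if process_group is None else ParallelMLP
parallel_kwargs = ({'process_group': process_group,
                    'sequence_parallel': getattr(config, 'sequence_parallel', True)}
                   if process_group is not None else {})
mlp_cls = partial(mlp_cls, hidden_features=config.n_inner, activation=F.gelu,
                  bias1=True, bias2=True, **parallel_kwargs)

mlp = mlp_cls(1024)  # builds a plain Mlp here; a ParallelMLP when process_group is set
```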
flash_attn/modules/block.py

```diff
@@ -288,6 +288,8 @@ class ParallelBlock(nn.Module):
             hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1).
             residual.
         """
+        # TODO: Ideally we should only do the allgather / allreduce once for
+        # the Linear to MLP & Attention
         fused_add_norm_fn = (dropout_add_rms_norm_parallel_residual
                              if isinstance(self.norm1, RMSNorm)
                              else dropout_add_layer_norm_parallel_residual)
```
flash_attn/modules/mlp.py

```diff
@@ -3,6 +3,12 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.distributed import ProcessGroup
+
+try:
+    from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
+except ImportError:
+    ColumnParallelLinear, RowParallelLinear = None, None
 
 try:
     from flash_attn.ops.fused_dense import FusedMLP, ParallelFusedMLP
@@ -30,6 +36,30 @@ class Mlp(nn.Module):
         return y if not self.return_residual else (y, x)
 
 
+class ParallelMLP(nn.Module):
+
+    def __init__(self, in_features, hidden_features=None, out_features=None,
+                 activation=F.gelu, process_group: ProcessGroup = None,
+                 sequence_parallel=True, bias1=True, bias2=True, device=None, dtype=None):
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__()
+        assert ColumnParallelLinear is not None, "Need to install fused_dense"
+        assert RowParallelLinear is not None, "Need to install fused_dense"
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features * 4
+        self.fc1 = ColumnParallelLinear(in_features, hidden_features, process_group, bias=bias1,
+                                        sequence_parallel=sequence_parallel, **factory_kwargs)
+        self.activation = activation
+        self.fc2 = RowParallelLinear(hidden_features, out_features, process_group, bias=bias2,
+                                     sequence_parallel=sequence_parallel, **factory_kwargs)
+
+    def forward(self, x):
+        y = self.fc1(x)
+        y = self.activation(y)
+        y = self.fc2(y)
+        return y
+
+
 class GatedMlp(nn.Module):
 
     def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.sigmoid,
```
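As a usage note, here is a hedged sketch of how the new `ParallelMLP` might be driven under tensor parallelism: every rank builds the module with the same process group, `fc1` (a `ColumnParallelLinear`) shards the hidden dimension across ranks, and `fc2` (a `RowParallelLinear`) reduces the partial outputs back to the full feature dimension. The launcher, world size, shapes, and script name are illustrative assumptions, not part of the commit; it presumes flash-attn is installed with the fused_dense extension and CUDA GPUs are available.

```python
# Illustrative sketch (not from the commit); run with e.g.
#   torchrun --nproc_per_node=2 parallel_mlp_demo.py
import torch
import torch.distributed as dist

from flash_attn.modules.mlp import ParallelMLP

dist.init_process_group(backend='nccl')
torch.cuda.set_device(dist.get_rank())

dim = 1024
mlp = ParallelMLP(
    dim,
    hidden_features=4 * dim,          # sharded across ranks by fc1 (ColumnParallelLinear)
    process_group=dist.group.WORLD,
    sequence_parallel=False,          # keep the full sequence on every rank for simplicity
    device='cuda',
    dtype=torch.float16,
)

x = torch.randn(2, 512, dim, device='cuda', dtype=torch.float16)
y = mlp(x)            # fc2 (RowParallelLinear) all-reduces the partial outputs
print(y.shape)        # expected: torch.Size([2, 512, 1024])

dist.destroy_process_group()
```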