gaoqiong / flash-attention · Commits

Unverified commit 8ee62efc, authored Jul 27, 2023 by Haodong Lyu, committed by GitHub Jul 26, 2023

Implement ParallelGatedMlp (#251)

Parent: 56ccaff1
Showing 3 changed files with 152 additions and 5 deletions (+152, -5):

  flash_attn/models/gpt.py             +11   -3
  flash_attn/modules/mlp.py            +27   -2
  tests/modules/test_mlp_parallel.py   +114  -0
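For context: GatedMlp is a gated MLP (the GLU / SwiGLU / GeGLU family), and this commit adds a tensor-parallel counterpart, ParallelGatedMlp. The gated computation both classes implement is small; below is a minimal sketch in plain PyTorch with made-up sizes, illustrating the value/gate split. It is not the repository's exact code.

import torch
import torch.nn.functional as F

# Toy sizes for illustration: in_features=16, hidden_features=32, batch of 4.
x = torch.randn(4, 16)
w1 = torch.randn(2 * 32, 16)   # fc1: projects to the value half and the gate half
w2 = torch.randn(16, 32)       # fc2: projects back down to in_features

y = F.linear(x, w1)            # (4, 64)
y, gate = y.chunk(2, dim=-1)   # (4, 32) each: value half and gate half
y = y * F.silu(gate)           # 'swiglu' case; 'glu' uses sigmoid, 'geglu' uses gelu
out = F.linear(y, w2)          # (4, 16), same feature width as the input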
flash_attn/models/gpt.py
@@ -18,7 +18,8 @@ from einops import rearrange
 from flash_attn.ops.activations import sqrelu_fwd
 from flash_attn.modules.mha import MHA, ParallelMHA
-from flash_attn.modules.mlp import Mlp, GatedMlp, ParallelMLP, FusedMLP, ParallelFusedMLP
+from flash_attn.modules.mlp import Mlp, ParallelMLP, FusedMLP, ParallelFusedMLP
+from flash_attn.modules.mlp import GatedMlp, ParallelGatedMlp
 from flash_attn.modules.block import Block, ParallelBlock
 from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
 from flash_attn.utils.distributed import sync_shared_params, all_gather_raw
@@ -122,8 +123,13 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
             activation = (F.sigmoid if config.activation_function == 'glu'
                           else (F.silu if config.activation_function == 'swiglu' else F.gelu))
-            mlp_cls = partial(GatedMlp, hidden_features=config.n_inner, activation=activation,
-                              bias1=mlp_fc1_bias, bias2=mlp_fc2_bias, **factory_kwargs)
+            mlp_cls = GatedMlp if process_group is None else ParallelGatedMlp
+            parallel_kwargs = ({'process_group': process_group,
+                                'sequence_parallel': getattr(config, 'sequence_parallel', True)}
+                               if process_group is not None else {})
+            mlp_cls = partial(mlp_cls, hidden_features=config.n_inner, activation=activation,
+                              bias1=mlp_fc1_bias, bias2=mlp_fc2_bias, **parallel_kwargs,
+                              **factory_kwargs)
         else:
             if config.activation_function == 'relu':
                 activation = partial(F.relu, inplace=True)
@@ -160,6 +166,8 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
                               bias1=mlp_fc1_bias, bias2=mlp_fc2_bias,
                               **parallel_kwargs, **factory_kwargs)
         elif fused_dense_sqrelu_dense:
+            if process_group is not None:
+                assert fused_mlp, 'Tensor Parallel is not implemented for FusedDenseSqreluDense'
             assert FusedDenseSqreluDense is not None
             mlp_cls = partial(FusedDenseSqreluDense, hidden_features=config.n_inner,
                               checkpoint_lvl=mlp_checkpoint_lvl, **factory_kwargs)
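As a usage sketch of the create_mlp_cls change above: for the gated activations the gate function is picked by name ('glu' -> sigmoid, 'swiglu' -> silu, 'geglu' -> gelu), and the class is GatedMlp without a process group or ParallelGatedMlp with one. The helper and the config values below are illustrative, not part of the repository.

from functools import partial
import torch.nn.functional as F
from flash_attn.modules.mlp import GatedMlp, ParallelGatedMlp

def pick_gated_mlp_cls(activation_function, n_inner, process_group=None):
    # Mirrors the selection logic in create_mlp_cls for 'glu' / 'swiglu' / 'geglu'.
    activation = (F.sigmoid if activation_function == 'glu'
                  else (F.silu if activation_function == 'swiglu' else F.gelu))
    mlp_cls = GatedMlp if process_group is None else ParallelGatedMlp
    parallel_kwargs = {'process_group': process_group} if process_group is not None else {}
    return partial(mlp_cls, hidden_features=n_inner, activation=activation, **parallel_kwargs)

mlp_cls = pick_gated_mlp_cls('swiglu', n_inner=4096)   # partial(GatedMlp, ..., activation=F.silu)
mlp = mlp_cls(1024)                                    # in_features is supplied when the block is built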
flash_attn/modules/mlp.py
@@ -11,9 +11,10 @@ except ImportError:
     ColumnParallelLinear, RowParallelLinear = None, None
 
 try:
-    from flash_attn.ops.fused_dense import FusedMLP, ParallelFusedMLP
+    from flash_attn.ops.fused_dense import FusedMLP, ParallelFusedMLP, ColumnParallelLinear, RowParallelLinear
 except ImportError:
     FusedMLP, ParallelFusedMLP = None, None
+    ColumnParallelLinear, RowParallelLinear = None, None
 
 
 class Mlp(nn.Module):
@@ -73,7 +74,7 @@ class GatedMlp(nn.Module):
         self.return_residual = return_residual
         self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias1, **factory_kwargs)
         self.activation = activation
-        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias1, **factory_kwargs)
+        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
 
     def forward(self, x):
         y = self.fc1(x)
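The one-character change above is a bug fix: fc2's bias was previously tied to bias1, so GatedMlp(..., bias1=True, bias2=False) still created a bias on fc2. A quick illustrative check of the fixed behaviour (assumes flash_attn is importable):

import torch.nn.functional as F
from flash_attn.modules.mlp import GatedMlp

mlp = GatedMlp(256, hidden_features=512, activation=F.silu, bias1=True, bias2=False)
assert mlp.fc1.bias is not None   # controlled by bias1
assert mlp.fc2.bias is None       # controlled by bias2 after this fix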
@@ -84,3 +85,27 @@ class GatedMlp(nn.Module):
             y = y * self.activation(gate)
         y = self.fc2(y)
         return y if not self.return_residual else (y, x)
+
+
+class ParallelGatedMlp(GatedMlp):
+    """ Parallel GatedMlp """
+
+    def __init__(self, in_features, process_group, hidden_features=None, out_features=None,
+                 activation=F.sigmoid, bias1=True, bias2=True, multiple_of=256,
+                 return_residual=False, sequence_parallel=True, device=None, dtype=None):
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__(in_features, hidden_features=hidden_features, out_features=out_features,
+                         activation=activation, bias1=bias1, bias2=bias2, multiple_of=multiple_of,
+                         return_residual=return_residual, device=device, dtype=dtype)
+        out_features = out_features or in_features
+        hidden_features = hidden_features or int(8 * in_features / 3)
+        hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
+        if ColumnParallelLinear is None or RowParallelLinear is None:
+            raise ImportError('fused_dense is not installed')
+        self.fc1 = ColumnParallelLinear(in_features, 2 * hidden_features, process_group,
+                                        bias=bias1, sequence_parallel=sequence_parallel,
+                                        **factory_kwargs)
+        self.fc2 = RowParallelLinear(hidden_features, out_features, process_group,
+                                     bias=bias2, sequence_parallel=sequence_parallel,
+                                     **factory_kwargs)
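ParallelGatedMlp keeps GatedMlp's forward pass but shards the two projections across a tensor-parallel group: ColumnParallelLinear splits fc1's 2 * hidden_features output rows across ranks, and RowParallelLinear splits fc2's input columns and reduces the partial outputs. Below is a hedged usage sketch, assuming the fused_dense extension is built, several GPUs are available, and the script is launched with torchrun; using the default world group as the tensor-parallel group is a simplification made here for brevity (the test below uses apex's tensor-parallel group instead).

import torch
import torch.nn.functional as F
from flash_attn.modules.mlp import ParallelGatedMlp

torch.distributed.init_process_group(backend='nccl', init_method='env://')
rank = torch.distributed.get_rank()
device = f'cuda:{rank}'

mlp = ParallelGatedMlp(
    1024,                           # in_features
    torch.distributed.group.WORLD,  # process_group (simplification: whole world as one TP group)
    activation=F.silu,
    sequence_parallel=False,        # every rank sees the full sequence
    device=device,
    dtype=torch.float16,
)
x = torch.randn(2048, 1024, device=device, dtype=torch.float16)
out = mlp(x)                        # same shape as x; partial fc2 results are reduced across ranks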
tests/modules/test_mlp_parallel.py (new file, mode 100644)
# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/modules/test_mlp_parallel.py
import torch
import torch.nn.functional as F
import pytest

from einops import rearrange

from apex.transformer import parallel_state
from apex.transformer import tensor_parallel

from flash_attn.modules.mlp import GatedMlp, ParallelGatedMlp


is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8


@pytest.mark.parametrize('dtype', [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('world_size', [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('sequence_parallel', [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False])
@pytest.mark.parametrize('activation', [F.silu, F.sigmoid])
# @pytest.mark.parametrize('activation', [F.silu])
@pytest.mark.parametrize('dim', [1024, 4096])
# @pytest.mark.parametrize('dim', [1024])
def test_mlp_parallel(dim, activation, sequence_parallel, world_size, dtype):
    rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
    device = f'cuda:{torch.distributed.get_rank()}'
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    # set seed
    torch.random.manual_seed(0)
    batch_size = 2
    seqlen = 1024
    assert (batch_size * seqlen) % world_size == 0
    x_pt = torch.randn(batch_size * seqlen, dim, device=device, dtype=dtype, requires_grad=True)
    # We need to generate g here so that all processes get the same gradient,
    # as rank 0 will have an extra bias that changes the RNG.
    # If we don't divide by batch_size, the gradient gets a bit too large.
    g = torch.randn_like(x_pt) / 32
    if sequence_parallel:
        x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_()
    else:
        x = x_pt.detach().clone().requires_grad_()

    model_pt = GatedMlp(dim, activation=activation, device=device, dtype=dtype)
    partition_dim = model_pt.fc1.weight.shape[0] // 2 // world_size
    model = ParallelGatedMlp(dim, parallel_state.get_tensor_model_parallel_group(),
                             activation=activation, sequence_parallel=sequence_parallel,
                             device=device, dtype=dtype)

    with torch.no_grad():
        model.fc1.weight.copy_(
            rearrange(rearrange(model_pt.fc1.weight, '(two o) i -> two o i', two=2)
                      [:, rank * partition_dim:(rank + 1) * partition_dim],
                      'two o i -> (two o) i')
        )
        model.fc1.bias.copy_(
            rearrange(rearrange(model_pt.fc1.bias, '(two o) -> two o', two=2)
                      [:, rank * partition_dim:(rank + 1) * partition_dim],
                      'two o -> (two o)')
        )
        model.fc2.weight.copy_(
            model_pt.fc2.weight[:, rank * partition_dim:(rank + 1) * partition_dim]
        )
        if rank == 0:
            model.fc2.bias.copy_(model_pt.fc2.bias)

    out = model(x)
    out_pt = model_pt(x_pt)
    partition_batch_dim = batch_size * seqlen // world_size
    assert torch.allclose(
        out,
        out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
        if sequence_parallel else out_pt,
        rtol=rtol, atol=atol
    )

    out_pt.backward(g)
    out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
                 if sequence_parallel else g)
    parallel_state.destroy_model_parallel()

    assert torch.allclose(
        x.grad,
        x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
        if sequence_parallel else x_pt.grad,
        rtol=rtol, atol=atol
    )
    assert torch.allclose(
        model.fc1.weight.grad,
        rearrange(rearrange(model_pt.fc1.weight.grad, '(two o) i -> two o i', two=2)
                  [:, rank * partition_dim:(rank + 1) * partition_dim],
                  'two o i -> (two o) i'),
        rtol=rtol, atol=atol
    )
    assert torch.allclose(
        model.fc1.bias.grad,
        rearrange(rearrange(model_pt.fc1.bias.grad, '(two o) -> two o', two=2)
                  [:, rank * partition_dim:(rank + 1) * partition_dim],
                  'two o -> (two o)'),
        rtol=rtol, atol=atol
    )
    assert torch.allclose(
        model.fc2.weight.grad,
        model_pt.fc2.weight.grad[:, rank * partition_dim:(rank + 1) * partition_dim],
        rtol=rtol, atol=atol
    )
    if rank == 0:
        assert torch.allclose(model.fc2.bias.grad, model_pt.fc2.bias.grad, rtol=rtol, atol=atol)
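The fc1 weight and bias copies are the subtle part of this test: fc1 stores the value and gate projections stacked along its output dimension, so each rank must take its slice from both halves rather than one contiguous block, which is what the nested rearrange does. A small self-contained sketch of that partitioning with toy sizes (rank and world_size are plain integers here; no distributed setup is needed):

import torch
from einops import rearrange

world_size, rank = 4, 1
in_features, hidden_features = 4, 8
partition_dim = hidden_features // world_size      # rows each rank owns in each half

# Stand-in for model_pt.fc1.weight, shape (2 * hidden_features, in_features).
full_fc1_weight = torch.arange(2 * hidden_features * in_features, dtype=torch.float32)
full_fc1_weight = full_fc1_weight.reshape(2 * hidden_features, in_features)

# Split into (value, gate) halves, slice this rank's rows from both, restack.
shard = rearrange(
    rearrange(full_fc1_weight, '(two o) i -> two o i', two=2)
    [:, rank * partition_dim:(rank + 1) * partition_dim],
    'two o i -> (two o) i',
)
assert shard.shape == (2 * partition_dim, in_features)   # (4, 4) with these sizes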