gaoqiong / flash-attention

Commit 93383bd5, authored Jan 07, 2023 by Tri Dao

[TP] Implement TensorParallel without sequence parallel

parent ce26d3d7

Showing 11 changed files with 257 additions and 133 deletions (+257 / -133):
flash_attn/models/gpt.py                   +27  -13
flash_attn/modules/block.py                +15   -1
flash_attn/modules/embedding.py             +5   -3
flash_attn/modules/mha.py                   +5   -5
flash_attn/ops/fused_dense.py              +57  -40
flash_attn/utils/distributed.py            +38  -12
tests/models/test_gpt_parallel.py          +14  -10
tests/modules/test_block_parallel.py       +30  -15
tests/modules/test_embedding_parallel.py    +9   -4
tests/modules/test_mha_parallel.py         +18   -8
tests/ops/test_fused_dense_parallel.py     +39  -22
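
Taken together, the diffs below thread a single sequence_parallel flag (default True, i.e. the previous behavior) from the GPT config down through Block, ParallelMHA, the parallel embedding, and the fused dense ops: with the flag on, activations stay scattered along the combined batch×sequence dimension and the parallel linears all-gather their input and reduce-scatter their output; with it off, every rank keeps the full activations and the row-parallel outputs are simply all-reduced. The single-process sketch below is not from the repository; it emulates the two communication patterns with plain tensor ops (cat / sum / chunk standing in for all_gather / reduce_scatter / all_reduce) to show that both give the same result as the unsharded computation.

import torch

# Illustrative names only; one column-parallel matmul followed by one
# row-parallel matmul, as in ColumnParallelLinear -> RowParallelLinear below.
torch.manual_seed(0)
world_size, rows, d_in, d_mid = 4, 8, 16, 32   # rows = batch * seqlen

x = torch.randn(rows, d_in, dtype=torch.float64)
w1 = torch.randn(d_mid, d_in, dtype=torch.float64)   # column-parallel: shard output features
w2 = torch.randn(d_in, d_mid, dtype=torch.float64)   # row-parallel: shard input features
reference = (x @ w1.t()) @ w2.t()

w1_shards = w1.chunk(world_size, dim=0)
w2_shards = w2.chunk(world_size, dim=1)

# sequence_parallel=True (previous behavior): each rank holds rows/world_size rows
# of x, all-gathers them before the first matmul, and reduce-scatters the result.
x_shards = x.chunk(world_size, dim=0)
total_x = torch.cat(x_shards, dim=0)                          # emulated all_gather
partials = [(total_x @ w1_shards[r].t()) @ w2_shards[r].t() for r in range(world_size)]
summed = sum(partials)                                        # the "reduce" part
for r in range(world_size):                                   # the "scatter" part
    assert torch.allclose(summed.chunk(world_size, dim=0)[r],
                          reference.chunk(world_size, dim=0)[r])

# sequence_parallel=False (new in this commit): every rank already has the full x,
# so the only collective is an all_reduce of the row-parallel partial outputs.
out = sum((x @ w1_shards[r].t()) @ w2_shards[r].t() for r in range(world_size))
assert torch.allclose(out, reference)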
flash_attn/models/gpt.py

@@ -20,7 +20,7 @@ from flash_attn.modules.mha import MHA, ParallelMHA
 from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense, ParallelFusedDenseGeluDense
 from flash_attn.modules.block import Block
 from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
-from flash_attn.utils.distributed import sync_sequence_parallel_params
+from flash_attn.utils.distributed import sync_shared_params
 from flash_attn.utils.pretrained import state_dict_from_pretrained
 from flash_attn.utils.generation import GenerationMixin

@@ -62,7 +62,9 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt
     mha_cls = MHA if process_group is None else ParallelMHA
     serial_kwargs = ({'fused_bias_fc': fused_bias_fc, 'dwconv': dwconv}
                      if process_group is None else {})
-    parallel_kwargs = {'process_group': process_group} if process_group is not None else {}
+    parallel_kwargs = ({'process_group': process_group,
+                        'sequence_parallel': getattr(config, 'sequence_parallel', True)}
+                       if process_group is not None else {})
     mixer_cls = partial(mha_cls, num_heads=config.num_attention_heads, dropout=config.attn_pdrop,
                         softmax_scale=softmax_scale, causal=True, layer_idx=layer_idx,
                         rotary_emb_dim=rotary_emb_dim, rotary_emb_scale_base=rotary_emb_scale_base,

@@ -99,7 +101,9 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
         if FusedDenseGeluDense is None:
             raise ImportError('fused_dense is not installed')
         mlp_cls = FusedDenseGeluDense if process_group is None else ParallelFusedDenseGeluDense
-        parallel_kwargs = {'process_group': process_group} if process_group is not None else {}
+        parallel_kwargs = ({'process_group': process_group,
+                            'sequence_parallel': getattr(config, 'sequence_parallel', True)}
+                           if process_group is not None else {})
         mlp_cls = partial(mlp_cls, hidden_features=inner_dim, checkpoint_lvl=mlp_checkpoint_lvl,
                           **parallel_kwargs, **factory_kwargs)
     elif fused_dense_sqrelu_dense:

@@ -113,13 +117,15 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
 def create_block(config, layer_idx=None, process_group=None, device=None, dtype=None):
     factory_kwargs = {'device': device, 'dtype': dtype}
+    sequence_parallel = getattr(config, 'sequence_parallel', True)
     mixer_cls = create_mixer_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
     mlp_cls = create_mlp_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
     norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_epsilon, **factory_kwargs)
     block = Block(config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls,
                   prenorm=True, resid_dropout=config.resid_pdrop,
                   fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False),
-                  sequence_parallel=process_group is not None)
+                  sequence_parallel=sequence_parallel and process_group is not None,
+                  mark_shared_params=process_group is not None)
     block.layer_idx = layer_idx
     return block

@@ -180,6 +186,7 @@ class GPTModel(GPTPreTrainedModel):
         super().__init__(config)
         factory_kwargs = {'device': device, 'dtype': dtype}
         self.process_group = process_group
+        self.sequence_parallel = getattr(config, 'sequence_parallel', True)
         assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'sqrelu']
         self.pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
         if config.vocab_size % self.pad_vocab_size_multiple != 0:

@@ -192,7 +199,8 @@ class GPTModel(GPTPreTrainedModel):
         else:
             self.embeddings = ParallelGPT2Embeddings(
                 config.hidden_size, config.vocab_size, config.max_position_embeddings,
-                process_group=process_group, **factory_kwargs
+                process_group=process_group, sequence_parallel=self.sequence_parallel,
+                **factory_kwargs
             )
         self.emb_drop = nn.Dropout(config.embd_pdrop)

@@ -209,10 +217,13 @@ class GPTModel(GPTPreTrainedModel):
         # is the final layer norm.
         self.ln_0 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon,
                                  **factory_kwargs)
-        # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
         if process_group is not None:
             for p in self.ln_0.parameters():
-                p._sequence_parallel = True
+                # Mark the norm parameters as "shared_params" so that we sync their values at init.
+                p._shared_params = True
+                # Mark the norm params as "sequence_parallel" so we run all-reduce on their grads.
+                if self.sequence_parallel:
+                    p._sequence_parallel = True
         self.layers = nn.ModuleList([create_block(config, layer_idx=i, process_group=process_group,
                                                   **factory_kwargs)

@@ -224,14 +235,14 @@ class GPTModel(GPTPreTrainedModel):
     def tie_weights(self):
         if self.process_group is not None:
-            sync_sequence_parallel_params(self, self.process_group)
+            sync_shared_params(self, self.process_group)

     def forward(self, input_ids, position_ids=None, inference_params=None):
         # If using Tensor Parallel with sequence parallel, we combine the batch and the seqlen
         # dimensions so that we can split on it easily, in case of small batch size.
         # Only the attention layers need to know the seqlen.
         embedding_kwargs = ({'combine_batch_seqlen_dim': True}
-                            if self.process_group is not None else {})
+                            if self.process_group is not None and self.sequence_parallel else {})
         hidden_states = self.embeddings(input_ids, position_ids=position_ids, **embedding_kwargs)
         # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
         if not self.fused_dropout_add_ln:

@@ -243,7 +254,8 @@ class GPTModel(GPTPreTrainedModel):
                 self.emb_drop.p if self.training else 0.0, self.ln_0.eps, prenorm=True,
                 residual_in_fp32=True
             )
-        mixer_kwargs = ({'seqlen': input_ids.shape[1]} if self.process_group is not None else {})
+        mixer_kwargs = ({'seqlen': input_ids.shape[1]}
+                        if self.process_group is not None and self.sequence_parallel else {})
         if inference_params is not None:
             mixer_kwargs['inference_params'] = inference_params
         for layer in self.layers:

@@ -263,8 +275,10 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
         else:
             if ColumnParallelLinear is None:
                 raise ImportError('fused_dense_lib is not installed')
-            self.lm_head = ColumnParallelLinear(config.n_embd, config.vocab_size, process_group,
-                                                bias=False, **factory_kwargs)
+            self.lm_head = ColumnParallelLinear(
+                config.n_embd, config.vocab_size, process_group, bias=False,
+                sequence_parallel=getattr(config, 'sequence_parallel', True), **factory_kwargs
+            )
         # Initialize weights and apply final processing
         self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
                            initializer_range=config.initializer_range))

@@ -273,7 +287,7 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
     def tie_weights(self):
         self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight
         if self.process_group is not None:
-            sync_sequence_parallel_params(self, self.process_group)
+            sync_shared_params(self, self.process_group)

     def forward(self, input_ids, position_ids=None, inference_params=None):
         """
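
The model only ever reads the flag via getattr(config, 'sequence_parallel', True), so existing configs keep the old behavior and turning it off is just a matter of setting the attribute before constructing the model. The snippet below is an illustrative sketch, not taken from the repository's docs; the exact GPTLMHeadModel constructor arguments and the process-group setup are assumptions, hence the construction line is left commented out.

from transformers import GPT2Config

from flash_attn.models.gpt import GPTLMHeadModel

config = GPT2Config(n_embd=1024, n_head=16, n_layer=4)
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_dense_gelu_dense = True
config.fused_dropout_add_ln = True
config.sequence_parallel = False   # new in this commit: Tensor Parallel without sequence parallelism

# process_group would come from e.g. apex.transformer.parallel_state or
# torch.distributed.new_group(); the keyword name is assumed from the diff above.
# model = GPTLMHeadModel(config, process_group=process_group, device='cuda', dtype=torch.float16)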
flash_attn/modules/block.py

@@ -23,7 +23,8 @@ class Block(nn.Module):
     def __init__(self, dim, mixer_cls=None, mlp_cls=None, norm_cls=nn.LayerNorm,
                  dropout_cls=nn.Dropout, prenorm=True, resid_dropout=0., drop_path=0.,
-                 fused_dropout_add_ln=False, return_residual=False, sequence_parallel=False):
+                 fused_dropout_add_ln=False, return_residual=False, sequence_parallel=False,
+                 mark_shared_params=False):
         """
         return_residual: whether each of the sub-layers (mixer and mlp) will return the residual.
         This is for performance reason: for post-norm architecture, returning the input allows us

@@ -51,6 +52,12 @@ class Block(nn.Module):
             assert dropout_add_layer_norm is not None, 'dropout_add_ln is not installed'
             assert isinstance(self.norm1, nn.LayerNorm) and isinstance(self.dropout1, nn.Dropout)
+        # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
+        # then the input to each worker in the tensor parallel group will be different.
+        # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
+        # For now this is not an issue because we always use sequence_parallel=True during training
+        # and only use sequence_parallel=False during inference.
+
         # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
         if sequence_parallel:
             for p in self.norm1.parameters():

@@ -58,6 +65,13 @@ class Block(nn.Module):
             if hasattr(self, 'norm2'):
                 for p in self.norm2.parameters():
                     p._sequence_parallel = True
+        # Mark the norm parameters as "shared_params" so that we sync their values at init.
+        if mark_shared_params:
+            for p in self.norm1.parameters():
+                p._shared_params = True
+            if hasattr(self, 'norm2'):
+                for p in self.norm2.parameters():
+                    p._shared_params = True

     def forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None,
                 mixer_kwargs=None):
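
Block now tags its LayerNorm parameters with two private attributes: _sequence_parallel (their gradients need an extra all-reduce over the tensor-parallel group, because each rank only sees a slice of the sequence) and, new in this commit, _shared_params (their values must stay identical across ranks and are broadcast at init by sync_shared_params). A minimal sketch of the tagging-and-filtering convention, using a stand-in module rather than the real Block:

import torch.nn as nn

# Attribute names match the diff; the module here is only a stand-in.
norm = nn.LayerNorm(1024)
for p in norm.parameters():
    p._shared_params = True        # sync values across the TP group at init
    p._sequence_parallel = True    # all-reduce their grads when sequence parallel is on

model = nn.Sequential(norm, nn.Linear(1024, 1024))

# This is how sync_shared_params / allreduce_sequence_parallel_grad pick them up:
shared = {name: p for name, p in model.named_parameters()
          if getattr(p, '_shared_params', False)}
print(sorted(shared))   # ['0.bias', '0.weight']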
flash_attn/modules/embedding.py

@@ -6,7 +6,7 @@ from torch import Tensor
 from einops import rearrange

-from flash_attn.utils.distributed import reduce_scatter
+from flash_attn.utils.distributed import reduce_scatter, all_reduce


 class GPT2Embeddings(nn.Module):

@@ -130,13 +130,14 @@ class ColumnParallelEmbedding(nn.Embedding):
 class ParallelGPT2Embeddings(nn.Module):

     def __init__(self, embed_dim, vocab_size, max_position_embeddings, process_group,
-                 padding_idx=None, device=None, dtype=None):
+                 padding_idx=None, sequence_parallel=True, device=None, dtype=None):
         """
             If max_position_embeddings <= 0, there's no position embeddings
         """
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
         self.process_group = process_group
+        self.sequence_parallel = sequence_parallel
         self.word_embeddings = VocabParallelEmbedding(
             vocab_size, embed_dim, padding_idx=padding_idx, process_group=process_group,
             **factory_kwargs

@@ -167,4 +168,5 @@ class ParallelGPT2Embeddings(nn.Module):
             embeddings[..., rank * partition_dim:(rank + 1) * partition_dim] += position_embeddings
         if combine_batch_seqlen_dim:
             embeddings = rearrange(embeddings, 'b s d -> (b s) d')
-        return embeddings if world_size <= 1 else reduce_scatter(embeddings, self.process_group)
+        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
+        return embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group)
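
The only behavioral change in ParallelGPT2Embeddings is the final collective: the per-rank partial embeddings (partial because the vocabulary is sharded) are reduce-scattered when sequence parallelism is on, and all-reduced when it is off. The single-process emulation below is not repository code; it shows why summing the partials reconstructs the full embedding, with made-up shapes and names.

import torch

# VocabParallelEmbedding gives each rank 1/world of the vocabulary rows, and tokens
# owned by other ranks look up as zeros, so the sum over ranks restores the embedding.
world_size, vocab, d = 4, 16, 8
tokens = torch.randint(0, vocab, (8,))
full_table = torch.randn(vocab, d)

partials = []
for rank in range(world_size):
    lo, hi = rank * vocab // world_size, (rank + 1) * vocab // world_size
    mine = (tokens >= lo) & (tokens < hi)
    part = torch.zeros(len(tokens), d)
    part[mine] = full_table[tokens[mine]]
    partials.append(part)

summed = sum(partials)                       # all_reduce keeps this full (b*s, d) tensor everywhere
shard0 = summed.chunk(world_size, dim=0)[0]  # reduce_scatter keeps one (b*s/world, d) slice per rank
assert torch.allclose(summed, full_table[tokens])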
flash_attn/modules/mha.py

@@ -497,11 +497,10 @@ class ParallelMHA(nn.Module):
     def __init__(self, embed_dim, num_heads, process_group, bias=True, dropout=0.0,
                  softmax_scale=None, causal=False, layer_idx=None, rotary_emb_dim=0,
-                 rotary_emb_scale_base=0,
-                 use_flash_attn=False, checkpointing=False,
-                 device=None, dtype=None) -> None:
+                 rotary_emb_scale_base=0, use_flash_attn=False, checkpointing=False,
+                 sequence_parallel=True, device=None, dtype=None) -> None:
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
         self.process_group = process_group
         self.embed_dim = embed_dim
         self.causal = causal
         self.layer_idx = layer_idx

@@ -521,12 +520,13 @@ class ParallelMHA(nn.Module):
         if ColumnParallelLinear is None or RowParallelLinear is None:
             raise ImportError('fused_dense is not installed')
         self.Wqkv = ColumnParallelLinear(embed_dim, 3 * embed_dim, process_group, bias=bias,
-                                         **factory_kwargs)
+                                         sequence_parallel=sequence_parallel, **factory_kwargs)
         inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
         self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale,
                                          attention_dropout=dropout)
         # output projection always have the bias (for now)
-        self.out_proj = RowParallelLinear(embed_dim, embed_dim, process_group, **factory_kwargs)
+        self.out_proj = RowParallelLinear(embed_dim, embed_dim, process_group,
+                                          sequence_parallel=sequence_parallel, **factory_kwargs)

     def forward(self, x, seqlen=None, **kwargs):
         """
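
ParallelMHA itself mostly forwards the flag to ColumnParallelLinear and RowParallelLinear; the reason it takes a seqlen kwarg at all (threaded through mixer_kwargs in gpt.py above, and only when sequence parallelism is on) is that the hidden states arrive with the batch and sequence dimensions combined and must be un-flattened before attention. The sketch below only shows that shape bookkeeping; the exact rearrange pattern inside ParallelMHA may differ.

import torch
from einops import rearrange

batch, seqlen, nheads, headdim = 2, 16, 8, 64
d = nheads * headdim

# sequence_parallel=True: the all-gathered QKV projection is (batch*seqlen, 3*d)
# and needs the seqlen kwarg to recover the per-sequence layout.
qkv_flat = torch.randn(batch * seqlen, 3 * d)
qkv = rearrange(qkv_flat, '(b s) (three h dh) -> b s three h dh',
                s=seqlen, three=3, h=nheads)
assert qkv.shape == (batch, seqlen, 3, nheads, headdim)

# sequence_parallel=False: the input never loses its batch dim, so no seqlen kwarg is needed.
x = torch.randn(batch, seqlen, d)
assert x.dim() == 3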
flash_attn/ops/fused_dense.py

@@ -15,26 +15,29 @@ from torch.cuda.amp import custom_bwd, custom_fwd
 import fused_dense_lib as fused_dense_cuda

 from flash_attn.ops.gelu_activation import gelu_bwd
-from flash_attn.utils.distributed import all_gather_raw, reduce_scatter_raw, reduce_scatter
+from flash_attn.utils.distributed import all_gather_raw, reduce_scatter_raw, all_reduce_raw
+from flash_attn.utils.distributed import reduce_scatter, all_reduce


 class FusedDenseFunc(torch.autograd.Function):

     @staticmethod
     @custom_fwd
-    def forward(ctx, x, weight, bias, return_residual=False, process_group=None):
+    def forward(ctx, x, weight, bias, return_residual=False, process_group=None,
+                sequence_parallel=True):
         """
-        If process_group is not None, we're doing Tensor Parallel with sequence parallelism:
-        we do an all_gather_raw of x before doing the matmul.
+        If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
+        with sequence parallelism: we do an all_gather_raw of x before doing the matmul.
         """
         ctx.compute_weight_gradient = weight.requires_grad
         ctx.return_residual = return_residual
         ctx.process_group = process_group
+        ctx.sequence_parallel = sequence_parallel

         if torch.is_autocast_enabled():
             x = x.to(dtype=torch.get_autocast_gpu_dtype())
         x = x.contiguous()
-        if process_group is not None:
+        if process_group is not None and sequence_parallel:
             # We want to kick off the all_gather early, before weight dtype conversion
             total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
         else:

@@ -44,7 +47,7 @@ class FusedDenseFunc(torch.autograd.Function):
             weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
             bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None
         weight = weight.contiguous()
-        if process_group is not None:
+        if process_group is not None and sequence_parallel:
             handle_x.wait()
         batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
         batch_dim = batch_shape.numel()

@@ -66,9 +69,10 @@ class FusedDenseFunc(torch.autograd.Function):
         grad_input, = args
         grad_input = grad_input.contiguous()
         process_group = ctx.process_group
+        sequence_parallel = ctx.sequence_parallel
         if ctx.compute_weight_gradient:
             x, weight = ctx.saved_tensors
-            if process_group is not None:
+            if process_group is not None and sequence_parallel:
                 total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
             else:
                 total_x = x

@@ -86,13 +90,13 @@ class FusedDenseFunc(torch.autograd.Function):
                                             grad_output, weight)
             grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
             if process_group is not None:
-                grad_input, handle_grad_input = reduce_scatter_raw(grad_input, process_group,
-                                                                   async_op=True)
+                reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
+                grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
         else:
             grad_input = None
         if ctx.needs_input_grad[1]:
             assert ctx.compute_weight_gradient
-            if process_group is not None:
+            if process_group is not None and sequence_parallel:
                 handle_x.wait()
             grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad(
                 total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]

@@ -102,15 +106,17 @@ class FusedDenseFunc(torch.autograd.Function):
             grad_bias = grad_output if ctx.needs_input_grad[2] else None
         if process_group is not None and ctx.needs_input_grad[0]:
             handle_grad_input.wait()
-        return grad_input, grad_weight, grad_bias, None, None
+        return grad_input, grad_weight, grad_bias, None, None, None


 def fused_dense_func(x: Tensor, weight: Tensor, bias: Optional[Tensor] = None,
-                     return_residual: bool = False, process_group: Optional[ProcessGroup] = None):
+                     return_residual: bool = False, process_group: Optional[ProcessGroup] = None,
+                     sequence_parallel: bool = True):
     dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
                       or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
     if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:
-        return FusedDenseFunc.apply(x, weight, bias, return_residual, process_group)
+        return FusedDenseFunc.apply(x, weight, bias, return_residual, process_group,
+                                    sequence_parallel)
     else:
         assert process_group is None
         out = F.linear(x, weight, bias)

@@ -136,7 +142,7 @@ class FusedDense(nn.Linear):
 class ColumnParallelLinear(nn.Linear):

     def __init__(self, in_features: int, out_features: int, process_group: ProcessGroup,
-                 bias: bool = True, device=None, dtype=None) -> None:
+                 bias: bool = True, sequence_parallel=True, device=None, dtype=None) -> None:
         world_size = torch.distributed.get_world_size(process_group)
         if out_features % world_size != 0:
             raise ValueError(f'out_features ({out_features}) must be divisible by '

@@ -144,19 +150,20 @@ class ColumnParallelLinear(nn.Linear):
         super().__init__(in_features, out_features // world_size, bias=bias,
                          device=device, dtype=dtype)
         self.process_group = process_group
+        self.sequence_parallel = sequence_parallel

     def forward(self, x):
-        """
-        We're doing Tensor Parallel with sequence parallelism: we do an all_gather of
-        x before doing the matmul.
-        """
-        return fused_dense_func(x, self.weight, self.bias, process_group=self.process_group)
+        # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
+        # we do an all_gather of x before doing the matmul.
+        # If not, then the input is already gathered.
+        return fused_dense_func(x, self.weight, self.bias, process_group=self.process_group,
+                                sequence_parallel=self.sequence_parallel)


 class RowParallelLinear(nn.Linear):

     def __init__(self, in_features: int, out_features: int, process_group: ProcessGroup,
-                 bias: bool = True, device=None, dtype=None) -> None:
+                 bias: bool = True, sequence_parallel=True, device=None, dtype=None) -> None:
         world_size = torch.distributed.get_world_size(process_group)
         rank = torch.distributed.get_rank(process_group)
         if in_features % world_size != 0:

@@ -166,6 +173,7 @@ class RowParallelLinear(nn.Linear):
         super().__init__(in_features // world_size, out_features, bias=bias and rank == 0,
                          device=device, dtype=dtype)
         self.process_group = process_group
+        self.sequence_parallel = sequence_parallel

     def forward(self, x):
         """

@@ -173,7 +181,8 @@ class RowParallelLinear(nn.Linear):
         a reduce_scatter of the result.
         """
         out = fused_dense_func(x, self.weight, self.bias)
-        return reduce_scatter(out, self.process_group)
+        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
+        return reduce_fn(out, self.process_group)


 class FusedDenseGeluDenseFunc(torch.autograd.Function):

@@ -181,10 +190,11 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
     @staticmethod
     @custom_fwd
     def forward(ctx, x, weight1, bias1, weight2, bias2, save_pre_act=True, return_residual=False,
-                checkpoint_lvl=0, heuristic=0, process_group=None):
+                checkpoint_lvl=0, heuristic=0, process_group=None, sequence_parallel=True):
         """
-        If process_group is not None, we're doing Tensor Parallel with sequence parallelism:
-        we do an all_gather of x before doing the matmul.
+        If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
+        with sequence parallelism: we do an all_gather of x before doing the matmul.
+        If sequence_parallel=False, then the input is already gathered.
         checkpoint_lvl:
         0: no recomputation in the bwd

@@ -197,13 +207,14 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
         assert checkpoint_lvl in [0, 1, 2]
         ctx.return_residual = return_residual
         ctx.process_group = process_group
+        ctx.sequence_parallel = sequence_parallel
         ctx.checkpoint_lvl = checkpoint_lvl
         ctx.heuristic = heuristic

         if torch.is_autocast_enabled():
             x = x.to(dtype=torch.get_autocast_gpu_dtype())
         x = x.contiguous()
-        if process_group is not None:
+        if process_group is not None and sequence_parallel:
             # We want to kick off the all_gather early, before weight dtype conversion
             total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
         else:

@@ -218,7 +229,7 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
         bias1 = bias1.contiguous() if bias1 is not None else None
         weight2 = weight2.contiguous()
         bias2 = bias2.contiguous() if bias2 is not None else None
-        if process_group is not None:
+        if process_group is not None and sequence_parallel:
             handle_x.wait()
         batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
         batch_dim = batch_shape.numel()

@@ -257,13 +268,14 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
         grad_input, = args
         grad_input = grad_input.contiguous()
         process_group = ctx.process_group
+        sequence_parallel = ctx.sequence_parallel
         x, weight1, weight2, *rest = ctx.saved_tensors
-        if process_group is None:
+        if process_group is None or not sequence_parallel:
             total_x = x
         batch_shape = grad_output.shape[:-1]
         batch_dim = batch_shape.numel()
         if checkpoint_lvl in [0, 1]:
-            if process_group is not None:
+            if process_group is not None and sequence_parallel:
                 total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
             if checkpoint_lvl == 0:
                 gelu_in, output1 = rest

@@ -272,7 +284,7 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
                 output1 = F.gelu(gelu_in, approximate='tanh')
         elif checkpoint_lvl == 2:
             bias1, = rest
-            if process_group is not None:
+            if process_group is not None and sequence_parallel:
                 total_x, _ = all_gather_raw(x, process_group)
             if ctx.heuristic == -1:
                 gelu_in = F.linear(total_x, weight1, bias1)

@@ -314,13 +326,13 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
                                             grad_gelu, weight1)
             grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
             if process_group is not None:
-                grad_input, handle_grad_input = reduce_scatter_raw(grad_input, process_group,
-                                                                   async_op=True)
+                reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
+                grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
         else:
             grad_input = None
         if ctx.heuristic == -1:
             if ctx.needs_input_grad[1]:
-                if process_group is not None:
+                if process_group is not None and sequence_parallel:
                     handle_x.wait()
                 grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_wgrad(
                     total_x.reshape(batch_dim, total_x.shape[-1]), grad_gelu,

@@ -331,7 +343,7 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
                 grad_bias1 = grad_gelu if ctx.needs_input_grad[2] else None
         else:
             if ctx.needs_input_grad[1]:
-                if process_group is not None:
+                if process_group is not None and sequence_parallel:
                     handle_x.wait()
                 grad_weight1 = F.linear(grad_gelu.t(),
                                         total_x.reshape(batch_dim, total_x.shape[-1]).t())

@@ -340,7 +352,7 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
         if process_group is not None and ctx.needs_input_grad[0]:
             handle_grad_input.wait()
         return (grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2,
-                None, None, None, None, None)
+                None, None, None, None, None, None)


 def fused_dense_gelu_dense_func(

@@ -348,15 +360,16 @@ def fused_dense_gelu_dense_func(
     bias2: Optional[Tensor] = None,
     save_pre_act: bool = True, return_residual: bool = False,
     checkpoint_lvl: int = 0, heuristic: int = 0,
-    process_group: Optional[ProcessGroup] = None
+    process_group: Optional[ProcessGroup] = None,
+    sequence_parallel: bool = True
 ):
     dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
                       or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
     if (x.is_cuda and weight1.is_cuda and weight2.is_cuda and (bias1 is None or bias1.is_cuda)
             and (bias2 is None or bias2.is_cuda) and dtype_eligible):
         return FusedDenseGeluDenseFunc.apply(
             x, weight1, bias1, weight2, bias2,
             save_pre_act, return_residual,
-            checkpoint_lvl, heuristic, process_group
+            checkpoint_lvl, heuristic, process_group, sequence_parallel
         )
     else:
         assert process_group is None

@@ -418,7 +431,7 @@ class ParallelFusedDenseGeluDense(nn.Module):
     def __init__(self, in_features, hidden_features, out_features=None,
                  process_group: ProcessGroup = None, bias1=True, bias2=True,
-                 checkpoint_lvl=0, heuristic=0, device=None, dtype=None):
+                 sequence_parallel=True, checkpoint_lvl=0, heuristic=0, device=None, dtype=None):
         """
         process_group is required. We're doing Tensor Parallel with sequence parallelism:
         we do an all_gather of x before doing the matmul, gelu, then matmul.

@@ -441,6 +454,7 @@ class ParallelFusedDenseGeluDense(nn.Module):
         if out_features is None:
             out_features = in_features
         self.process_group = process_group
+        self.sequence_parallel = sequence_parallel
         self.checkpoint_lvl = checkpoint_lvl
         self.heuristic = heuristic
         self.fc1 = ColumnParallelLinear(in_features, hidden_features, process_group,

@@ -452,6 +466,9 @@ class ParallelFusedDenseGeluDense(nn.Module):
         out = fused_dense_gelu_dense_func(
             x, self.fc1.weight, self.fc2.weight, self.fc1.bias, self.fc2.bias,
             save_pre_act=self.training, checkpoint_lvl=self.checkpoint_lvl,
-            heuristic=self.heuristic, process_group=self.process_group
+            heuristic=self.heuristic, process_group=self.process_group,
+            sequence_parallel=self.sequence_parallel
         )
-        return reduce_scatter(out, self.process_group)
+        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
+        return reduce_fn(out, self.process_group)
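
A side effect of adding sequence_parallel to FusedDenseFunc.forward and FusedDenseGeluDenseFunc.forward is the extra trailing None in both backward returns: autograd requires backward to return exactly one value per forward input, with None for non-tensor arguments. The minimal custom Function below (not from the repository) shows that contract.

import torch

class ScaleFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale, some_flag=True):   # 3 forward inputs
        ctx.scale = scale
        return x * scale

    @staticmethod
    def backward(ctx, grad_out):
        # one value per forward input: x, scale, some_flag
        return grad_out * ctx.scale, None, None

x = torch.randn(4, requires_grad=True)
y = ScaleFunc.apply(x, 2.0, False)
y.sum().backward()
assert torch.allclose(x.grad, torch.full_like(x, 2.0))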
flash_attn/utils/distributed.py

@@ -14,7 +14,7 @@ if "reduce_scatter_tensor" not in dir(torch.distributed):
     torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base


-# Raw operation, oes support autograd, but does support async
+# Raw operation, does not support autograd, but does support async
 def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
     world_size = torch.distributed.get_world_size(process_group)
     output = torch.empty(world_size * input_.shape[0], *input_.shape[1:],

@@ -24,7 +24,7 @@ def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool =
     return output, handle


-# Raw operation, oes support autograd, but does support async
+# Raw operation, does not support autograd, but does support async
 def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
     world_size = torch.distributed.get_world_size(process_group)
     assert input_.shape[0] % world_size == 0

@@ -36,6 +36,13 @@ def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bo
     return output, handle


+# Raw operation, does not support autograd, but does support async
+def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
+    input_ = input_.contiguous()
+    handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op)
+    return input_, handle
+
+
 class AllGatherFunc(torch.autograd.Function):
     """Gather the input from sequence parallel region and concatenate."""

@@ -74,12 +81,30 @@ class ReduceScatterFunc(torch.autograd.Function):
 reduce_scatter = ReduceScatterFunc.apply


-def sync_sequence_parallel_params(model: torch.nn.Module, process_group: ProcessGroup):
-    # We want to iterate over parameters with _sequence_parallel=True in the same order,
+class AllReduceFunc(torch.autograd.Function):
+    """Gather the input from sequence parallel region and concatenate."""
+
+    @staticmethod
+    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
+        ctx.process_group = process_group
+        output, _ = all_reduce_raw(input_, process_group)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output: Tensor):
+        return grad_output, None
+
+
+# Supports autograd, but does not support async
+all_reduce = AllReduceFunc.apply
+
+
+def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup):
+    # We want to iterate over parameters with _shared_params=True in the same order,
     # as different ranks might have different number of parameters (e.g., only rank 0 has bias).
-    params_seqparallel = {name: p for name, p in model.named_parameters()
-                          if getattr(p, '_sequence_parallel', False)}
-    for _, p in sorted(params_seqparallel.items()):
+    pamams_shared = {name: p for name, p in model.named_parameters()
+                     if getattr(p, '_shared_params', False)}
+    for _, p in sorted(pamams_shared.items()):
         with torch.no_grad():
             # Broadcast needs src to be global rank, not group rank
             torch.distributed.broadcast(

@@ -94,8 +119,9 @@ def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: Proc
     params_seqparallel = {name: p for name, p in model.named_parameters()
                           if getattr(p, '_sequence_parallel', False)}
     grads = [p.grad for _, p in sorted(params_seqparallel.items())]
-    with torch.no_grad():
-        coalesced = torch._utils._flatten_dense_tensors(grads)
-        torch.distributed.all_reduce(coalesced, group=process_group)
-        for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
-            buf.copy_(synced)
+    if grads:
+        with torch.no_grad():
+            coalesced = torch._utils._flatten_dense_tensors(grads)
+            torch.distributed.all_reduce(coalesced, group=process_group)
+            for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
+                buf.copy_(synced)
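
all_reduce_raw and AllReduceFunc mirror the existing reduce-scatter helpers: the raw op exposes the async handle, while the autograd wrapper all-reduces in forward and passes the gradient through unchanged in backward. The sketch below re-declares the wrapper locally (cloning instead of reusing the input buffer, unlike the raw helper) and checks the gradient pass-through on a single gloo process; the init_process_group boilerplate is only there to make it runnable without a launcher.

import os
import torch
import torch.distributed as dist

class AllReduceFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_, process_group):
        output = input_.clone()
        dist.all_reduce(output, group=process_group)  # in-place on the clone
        return output

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None   # identity on the gradient, None for the group

if __name__ == '__main__':
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('gloo', rank=0, world_size=1)
    x = torch.randn(4, requires_grad=True)
    out = AllReduceFunc.apply(x, None)   # None = default process group
    out.sum().backward()
    assert torch.allclose(x.grad, torch.ones_like(x))   # gradient passes through untouched
    dist.destroy_process_group()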
tests/models/test_gpt_parallel.py

@@ -23,10 +23,12 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
 # @pytest.mark.parametrize('dtype', [torch.bfloat16])
 @pytest.mark.parametrize('world_size', [1, 2, 4, 8])
 # @pytest.mark.parametrize('world_size', [2])
+@pytest.mark.parametrize('sequence_parallel', [True, False])
+# @pytest.mark.parametrize('sequence_parallel', [False])
 @pytest.mark.parametrize('has_pos_emb', [True, False])
 # @pytest.mark.parametrize('has_pos_emb', [True])
 @pytest.mark.parametrize('dim', [1024])
-def test_gpt_parallel(dim, has_pos_emb, world_size, dtype):
+def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
     head_dim = 64
     assert dim % head_dim == 0
     num_heads = dim // head_dim

@@ -59,7 +61,8 @@ def test_gpt_parallel(dim, has_pos_emb, world_size, dtype):
                        scale_attn_by_inverse_layer_idx=True, use_flash_attn=True,
                        fused_dense_gelu_dense=True, fused_bias_fc=True, fused_dropout_add_ln=True,
                        rotary_emb_fraction=0.0 if has_pos_emb else 0.5,
-                       pad_vocab_size_multiple=8 * world_size)
+                       pad_vocab_size_multiple=8 * world_size,
+                       sequence_parallel=sequence_parallel)
     model_pt = GPTLMHeadModel(config, device=device)

     def init_layer_norm(module):

@@ -75,16 +78,15 @@ def test_gpt_parallel(dim, has_pos_emb, world_size, dtype):
     torch.distributed.all_gather_into_tensor(
         sharded_nparams_all, torch.tensor([sharded_nparams], device=device), group=process_group
     )
-    sequence_parallel_nparams = sum(p.numel() for p in model.parameters()
-                                    if getattr(p, '_sequence_parallel', False))
-    sequence_parallel_nparams_all = torch.empty(world_size, dtype=torch.long, device=device)
+    shared_nparams = sum(p.numel() for p in model.parameters()
+                         if getattr(p, '_shared_params', False))
+    shared_nparams_all = torch.empty(world_size, dtype=torch.long, device=device)
     torch.distributed.all_gather_into_tensor(
-        sequence_parallel_nparams_all, torch.tensor([sequence_parallel_nparams], device=device),
+        shared_nparams_all, torch.tensor([shared_nparams], device=device),
         group=process_group
     )
-    assert torch.all(sequence_parallel_nparams_all == sequence_parallel_nparams)
-    assert total_nparams == ((sharded_nparams_all - sequence_parallel_nparams_all).sum().item()
-                             + sequence_parallel_nparams)
+    assert torch.all(shared_nparams_all == shared_nparams)
+    assert total_nparams == ((sharded_nparams_all - shared_nparams_all).sum().item()
+                             + shared_nparams)
     # vocab_size has been rounded up here
     partition_vocab_size = config.vocab_size // world_size

@@ -96,6 +98,8 @@ def test_gpt_parallel(dim, has_pos_emb, world_size, dtype):
     with torch.autocast(device_type='cuda', dtype=dtype):
         out = model(input_ids[:, :-1]).logits
+        if not sequence_parallel:
+            out = rearrange(out, 'b s d -> (b s) d')
         out_pt = rearrange(model_pt(input_ids[:, :-1]).logits, 'b s d -> (b s) d')
     partition_batch_dim = batch_size * seqlen // world_size
     assert torch.allclose(
tests/modules/test_block_parallel.py
View file @ 93383bd5
...
@@ -23,11 +23,13 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
 @pytest.mark.parametrize('dtype', [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
-# @pytest.mark.parametrize('dtype', [torch.bfloat16])
+# @pytest.mark.parametrize('dtype', [torch.float16])
 @pytest.mark.parametrize('world_size', [1, 2, 4, 8])
 # @pytest.mark.parametrize('world_size', [2])
+@pytest.mark.parametrize('sequence_parallel', [True, False])
+# @pytest.mark.parametrize('sequence_parallel', [False])
 @pytest.mark.parametrize('dim', [1024])
-def test_block_parallel(dim, world_size, dtype):
+def test_block_parallel(dim, sequence_parallel, world_size, dtype):
     head_dim = 64
     assert dim % head_dim == 0
     num_heads = dim // head_dim
...
@@ -41,7 +43,7 @@ def test_block_parallel(dim, world_size, dtype):
     rank = parallel_state.get_tensor_model_parallel_rank()
     # set seed
     torch.random.manual_seed(0)
-    batch_size = 8
+    batch_size = 2
     seqlen = 1024
     assert (batch_size * seqlen) % world_size == 0
     x_pt = torch.randn(batch_size * seqlen, dim, device=device, dtype=dtype,
...
@@ -51,8 +53,12 @@ def test_block_parallel(dim, world_size, dtype):
     # as rank 0 will have an extra bias that changes the RNG.
     # If we don't divide by batch_size, the gradient gets a bit too large.
     g = torch.randn_like(x_pt) / 32
-    x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_()
-    residual = tensor_parallel.scatter_to_sequence_parallel_region(residual_pt).detach().clone().requires_grad_()
+    if sequence_parallel:
+        x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_()
+        residual = tensor_parallel.scatter_to_sequence_parallel_region(residual_pt).detach().clone().requires_grad_()
+    else:
+        x = x_pt.detach().clone().requires_grad_()
+        residual = residual_pt.detach().clone().requires_grad_()
     mixer_cls_pt = partial(MHA, num_heads=num_heads, rotary_emb_dim=int(head_dim // 2),
                            use_flash_attn=True, device=device, dtype=dtype)
...
@@ -69,12 +75,12 @@ def test_block_parallel(dim, world_size, dtype):
     mixer_cls = partial(ParallelMHA, num_heads=num_heads,
                         process_group=parallel_state.get_tensor_model_parallel_group(),
                         rotary_emb_dim=int(head_dim // 2), use_flash_attn=True,
-                        device=device, dtype=dtype)
+                        sequence_parallel=sequence_parallel, device=device, dtype=dtype)
     mlp_cls = partial(ParallelFusedDenseGeluDense, hidden_features=4 * dim,
                       process_group=parallel_state.get_tensor_model_parallel_group(),
-                      device=device, dtype=dtype)
+                      sequence_parallel=sequence_parallel, device=device, dtype=dtype)
     model = Block(dim, mixer_cls, mlp_cls, norm_cls, fused_dropout_add_ln=True,
-                  sequence_parallel=True)
+                  sequence_parallel=sequence_parallel, mark_shared_params=True)
     partition_dim = dim // world_size
     partition_hidden_dim = 4 * dim // world_size
...
@@ -115,25 +121,34 @@ def test_block_parallel(dim, world_size, dtype):
     out_pt, out_residual_pt = [rearrange(x, 'b s d -> (b s) d') for x in [out_pt, out_residual_pt]]
     partition_batch_dim = batch_size * seqlen // world_size
     assert torch.allclose(
-        out, out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
+        out, out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim] if sequence_parallel else out_pt,
         rtol=rtol, atol=atol
     )
     assert torch.allclose(
-        out_residual, out_residual_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
+        out_residual, out_residual_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim] if sequence_parallel else out_residual_pt,
         rtol=rtol, atol=atol
     )
-    out_pt.backward(g)
-    out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim])
+    (out_pt + 2 * out_residual_pt).backward(g)
+    (out + 2 * out_residual).backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim] if sequence_parallel else g)
     allreduce_sequence_parallel_grad(model, parallel_state.get_tensor_model_parallel_group())
     parallel_state.destroy_model_parallel()
     assert torch.allclose(
-        x.grad, x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
-        rtol=rtol, atol=atol
+        x.grad,
+        x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim] if sequence_parallel else x_pt.grad,
+        rtol=rtol, atol=atol / 100  # magnitude of x.grad is quite small
     )
     assert torch.allclose(
-        residual.grad, residual_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
+        residual.grad, residual_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim] if sequence_parallel else residual_pt.grad,
         rtol=rtol, atol=atol
     )
     # The error for d_weight and d_bias is quite a bit higher
...
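Aside (a sketch, not code from the commit): the new sequence_parallel branch follows the same recipe in every test file. When sequence parallelism is on, the reference input is split along the flattened batch*seqlen dimension and each rank checks its slice of the reference output and gradient; when it is off, every rank keeps the full input and compares against the full reference. Assuming scatter_to_sequence_parallel_region is an even split along dim 0, the slicing reduces to:

import torch

batch_size, seqlen, dim, world_size = 2, 1024, 16, 4
x_pt = torch.randn(batch_size * seqlen, dim)          # full reference input
partition_batch_dim = batch_size * seqlen // world_size

for rank in range(world_size):
    # sequence_parallel=True: this rank holds only these rows of x_pt ...
    x_local = x_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
    assert x_local.shape == (partition_batch_dim, dim)
    # ... and its output/gradient is compared against the same slice of the reference.
# sequence_parallel=False: every rank keeps the full (batch_size * seqlen, dim) tensor
# and compares against the unsliced reference.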
tests/modules/test_embedding_parallel.py
View file @ 93383bd5
...
@@ -19,10 +19,12 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
 # @pytest.mark.parametrize('dtype', [torch.bfloat16])
 @pytest.mark.parametrize('world_size', [1, 2, 4, 8])
 # @pytest.mark.parametrize('world_size', [2])
+@pytest.mark.parametrize('sequence_parallel', [True, False])
+# @pytest.mark.parametrize('sequence_parallel', [False])
 @pytest.mark.parametrize('has_pos_emb', [True, False])
 # @pytest.mark.parametrize('has_pos_emb', [True])
 @pytest.mark.parametrize('dim', [1024])
-def test_embedding_parallel(dim, world_size, has_pos_emb, dtype):
+def test_embedding_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
     vocab_size = 50264
     seqlen = 2048
     assert vocab_size % world_size == 0
...
@@ -46,7 +48,7 @@ def test_embedding_parallel(dim, world_size, has_pos_emb, dtype):
                              device=device, dtype=dtype)
     model = ParallelGPT2Embeddings(dim, vocab_size, seqlen if has_pos_emb else 0,
                                    parallel_state.get_tensor_model_parallel_group(),
-                                   device=device, dtype=dtype)
+                                   sequence_parallel=sequence_parallel, device=device, dtype=dtype)
     partition_vocab_size = vocab_size // world_size
     partition_dim = dim // world_size
     with torch.no_grad():
...
@@ -62,13 +64,16 @@ def test_embedding_parallel(dim, world_size, has_pos_emb, dtype):
     out_pt = rearrange(model_pt(input_ids), 'b s d -> (b s) d')
     partition_batch_dim = batch_size * seqlen // world_size
     assert torch.allclose(
-        out, out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
+        out, out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim] if sequence_parallel else out_pt,
         rtol=rtol, atol=atol
     )
     g = torch.randn_like(out_pt)
     out_pt.backward(g)
-    out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim])
+    out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim] if sequence_parallel else g)
     parallel_state.destroy_model_parallel()
     assert torch.allclose(
...
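Aside (illustrative only; the commit's actual weight-copy logic sits in the elided part of the hunk above): the embedding test shards over the vocabulary rather than the sequence. vocab_size must divide evenly by world_size, and each rank owns partition_vocab_size contiguous rows of the word-embedding table:

import torch

vocab_size, dim, world_size = 50264, 64, 8     # dim shrunk here to keep the sketch cheap
assert vocab_size % world_size == 0
partition_vocab_size = vocab_size // world_size

weight_pt = torch.randn(vocab_size, dim)       # reference (unsharded) embedding table
shards = [weight_pt[rank * partition_vocab_size:(rank + 1) * partition_vocab_size]
          for rank in range(world_size)]
assert all(s.shape == (partition_vocab_size, dim) for s in shards)
assert torch.equal(torch.cat(shards), weight_pt)   # the shards tile the full table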
tests/modules/test_mha_parallel.py
View file @ 93383bd5
...
@@ -21,11 +21,13 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
 # @pytest.mark.parametrize('dtype', [torch.float16])
 @pytest.mark.parametrize('world_size', [1, 2, 4, 8])
 # @pytest.mark.parametrize('world_size', [2])
+@pytest.mark.parametrize('sequence_parallel', [True, False])
+# @pytest.mark.parametrize('sequence_parallel', [False])
 @pytest.mark.parametrize('head_dim', [64, 128])
 # @pytest.mark.parametrize('head_dim', [64])
 @pytest.mark.parametrize('embed_dim', [1024, 4096])
 # @pytest.mark.parametrize('embed_dim', [1024])
-def test_mha_parallel(embed_dim, head_dim, world_size, dtype):
+def test_mha_parallel(embed_dim, head_dim, sequence_parallel, world_size, dtype):
     assert embed_dim % head_dim == 0
     num_heads = embed_dim // head_dim
     assert num_heads % world_size == 0
...
@@ -38,7 +40,7 @@ def test_mha_parallel(embed_dim, head_dim, world_size, dtype):
     rank = parallel_state.get_tensor_model_parallel_rank()
     # set seed
     torch.random.manual_seed(0)
-    batch_size = 8
+    batch_size = 2
     seqlen = 1024
     assert (batch_size * seqlen) % world_size == 0
     x_pt = torch.randn(batch_size * seqlen, embed_dim, device=device, dtype=dtype,
...
@@ -47,14 +49,17 @@ def test_mha_parallel(embed_dim, head_dim, world_size, dtype):
     # as rank 0 will have an extra bias that changes the RNG.
     # If we don't divide by batch_size, the gradient gets a bit too large.
     g = torch.randn_like(x_pt) / 32
-    x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_()
+    if sequence_parallel:
+        x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_()
+    else:
+        x = x_pt.detach().clone().requires_grad_()
     model_pt = MHA(embed_dim, num_heads, rotary_emb_dim=int(head_dim // 2),
                    use_flash_attn=True, device=device, dtype=dtype)
     partition_dim = embed_dim // world_size
     model = ParallelMHA(embed_dim, num_heads, parallel_state.get_tensor_model_parallel_group(),
                         rotary_emb_dim=int(head_dim // 2), use_flash_attn=True,
-                        device=device, dtype=dtype)
+                        sequence_parallel=sequence_parallel, device=device, dtype=dtype)
     with torch.no_grad():
         model.Wqkv.weight.copy_(
...
@@ -75,17 +80,22 @@ def test_mha_parallel(embed_dim, head_dim, world_size, dtype):
     out_pt = rearrange(model_pt(rearrange(x_pt, '(b s) d -> b s d', s=seqlen)), 'b s d -> (b s) d')
     partition_batch_dim = batch_size * seqlen // world_size
     assert torch.allclose(
-        out, out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
+        out, out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim] if sequence_parallel else out_pt,
         rtol=rtol, atol=atol
     )
     out_pt.backward(g)
-    out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim])
+    out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim] if sequence_parallel else g)
     parallel_state.destroy_model_parallel()
     assert torch.allclose(
-        x.grad, x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
-        rtol=rtol, atol=atol
+        x.grad,
+        x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim] if sequence_parallel else x_pt.grad,
+        rtol=rtol, atol=atol / 100  # magnitude of x.grad is quite small
     )
     # The error for d_weight and d_bias is quite a bit higher
     assert torch.allclose(
...
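Aside on the tolerance change (the values below are made up, not from the test): g is divided by 32, so x.grad is tiny, and an absolute tolerance sized for the outputs would accept almost any value of that magnitude. Tightening atol by 100x, as the new comment notes, keeps the gradient check meaningful:

import torch

atol = 3e-2                          # output-level absolute tolerance from the test
a = torch.full((4,), 1e-3)
b = torch.zeros(4)                   # completely different, but on the order of atol
assert torch.allclose(a, b, atol=atol)              # passes vacuously
assert not torch.allclose(a, b, atol=atol / 100)    # the tightened check actually discriminates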
tests/ops/test_fused_dense_parallel.py
View file @ 93383bd5
...
@@ -19,14 +19,15 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
 @pytest.mark.parametrize('dtype', [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
 # @pytest.mark.parametrize('dtype', [torch.bfloat16])
 @pytest.mark.parametrize('world_size', [1, 2, 4, 8])
-# @pytest.mark.parametrize('world_size', [8])
+# @pytest.mark.parametrize('world_size', [2])
+@pytest.mark.parametrize('sequence_parallel', [True, False])
+# @pytest.mark.parametrize('sequence_parallel', [False])
 @pytest.mark.parametrize('has_bias', [True, False])
-# @pytest.mark.parametrize('has_bias', [True])
-@pytest.mark.parametrize('out_features', [1024, 4096])
-# @pytest.mark.parametrize('out_features', [1024])
-@pytest.mark.parametrize('in_features', [4096])
-def test_fused_linear_bias(in_features, out_features, has_bias, world_size, dtype):
+# @pytest.mark.parametrize('has_bias', [False])
+@pytest.mark.parametrize('out_features', [1024])
+@pytest.mark.parametrize('in_features', [1024, 4096])
+# @pytest.mark.parametrize('in_features', [4096])
+def test_fused_linear_bias(in_features, out_features, has_bias, sequence_parallel,
+                           world_size, dtype):
     assert out_features % world_size == 0
     rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
     if not torch.distributed.is_initialized():
...
@@ -37,18 +38,21 @@ def test_fused_linear_bias(in_features, out_features, has_bias, world_size, dtyp
     rank = parallel_state.get_tensor_model_parallel_rank()
     # set seed
     torch.random.manual_seed(0)
-    batch_size = 8
+    batch_size = 2
     seqlen = 512
     assert batch_size * seqlen % world_size == 0
     x_pt = torch.randn(batch_size * seqlen, in_features, device=device, dtype=dtype,
                        requires_grad=True)
-    x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_()
+    if sequence_parallel:
+        x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_()
+    else:
+        x = x_pt.detach().clone().requires_grad_()
     model_pt = torch.nn.Linear(in_features, out_features, bias=has_bias, device=device, dtype=dtype)
     partition_out_features = out_features // world_size
     model = ColumnParallelLinear(in_features, out_features,
                                  parallel_state.get_tensor_model_parallel_group(), bias=has_bias,
-                                 device=device, dtype=dtype)
+                                 sequence_parallel=sequence_parallel, device=device, dtype=dtype)
     with torch.no_grad():
         model.weight.copy_(
             model_pt.weight[rank * partition_out_features:(rank + 1) * partition_out_features]
...
@@ -73,7 +77,9 @@ def test_fused_linear_bias(in_features, out_features, has_bias, world_size, dtyp
     partition_batch_dim = batch_size * seqlen // world_size
     assert torch.allclose(
-        x.grad, x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
+        x.grad,
+        x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim] if sequence_parallel else x_pt.grad,
         rtol=rtol, atol=atol
     )
     # The error for d_weight and d_bias is quite a bit higher
...
@@ -94,13 +100,14 @@ def test_fused_linear_bias(in_features, out_features, has_bias, world_size, dtyp
 # @pytest.mark.parametrize('dtype', [torch.bfloat16])
 @pytest.mark.parametrize('world_size', [1, 2, 4, 8])
 # @pytest.mark.parametrize('world_size', [2])
+@pytest.mark.parametrize('sequence_parallel', [True, False])
+# @pytest.mark.parametrize('sequence_parallel', [False])
 @pytest.mark.parametrize('has_bias2', [True, False])
 # @pytest.mark.parametrize('has_bias2', [True])
-@pytest.mark.parametrize('out_features', [1024, 4096])
-# @pytest.mark.parametrize('out_features', [1024])
-@pytest.mark.parametrize('in_features', [1024])
-def test_fused_dense_gelu_dense(in_features, out_features, has_bias2, world_size, dtype):
+@pytest.mark.parametrize('out_features', [4096])
+@pytest.mark.parametrize('in_features', [1024, 4096])
+# @pytest.mark.parametrize('in_features', [1024])
+def test_fused_dense_gelu_dense(in_features, out_features, has_bias2, sequence_parallel,
+                                world_size, dtype):
     assert out_features % world_size == 0
     rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
     if not torch.distributed.is_initialized():
...
@@ -111,7 +118,7 @@ def test_fused_dense_gelu_dense(in_features, out_features, has_bias2, world_size
     rank = parallel_state.get_tensor_model_parallel_rank()
     # set seed
     torch.random.manual_seed(0)
-    batch_size = 8
+    batch_size = 2
     seqlen = 512
     assert batch_size * seqlen % world_size == 0
     x_pt = torch.randn(batch_size * seqlen, in_features, device=device, dtype=dtype,
...
@@ -120,7 +127,10 @@ def test_fused_dense_gelu_dense(in_features, out_features, has_bias2, world_size
     # as rank 0 will have an extra bias that changes the RNG.
     # If we don't divide by batch_size, the gradient gets a bit too large.
     g = torch.randn_like(x_pt) / 32
-    x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_()
+    if sequence_parallel:
+        x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_()
+    else:
+        x = x_pt.detach().clone().requires_grad_()
     model_pt_fc1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
     model_pt_fc2 = torch.nn.Linear(out_features, in_features, bias=has_bias2, device=device,
...
@@ -129,7 +139,9 @@ def test_fused_dense_gelu_dense(in_features, out_features, has_bias2, world_size
     partition_in_features = in_features // world_size
     model = ParallelFusedDenseGeluDense(in_features, out_features, in_features,
                                         process_group=parallel_state.get_tensor_model_parallel_group(),
-                                        bias2=has_bias2 and rank == 0, device=device, dtype=dtype)
+                                        bias2=has_bias2 and rank == 0, sequence_parallel=sequence_parallel,
+                                        device=device, dtype=dtype)
     with torch.no_grad():
         model.fc1.weight.copy_(
...
@@ -148,16 +160,21 @@ def test_fused_dense_gelu_dense(in_features, out_features, has_bias2, world_size
     out_pt = model_pt_fc2(F.gelu(model_pt_fc1(x_pt), approximate='tanh'))
     partition_batch_dim = batch_size * seqlen // world_size
     assert torch.allclose(
-        out, out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
+        out, out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim] if sequence_parallel else out_pt,
         rtol=rtol, atol=atol
     )
     out_pt.backward(g)
-    out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim])
+    out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim] if sequence_parallel else g)
     parallel_state.destroy_model_parallel()
     assert torch.allclose(
-        x.grad, x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
+        x.grad, x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim] if sequence_parallel else x_pt.grad,
         rtol=rtol, atol=atol
     )
     # The error for d_weight and d_bias is quite a bit higher
...
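Aside (a single-process sketch, not the fused kernel itself): the ColumnParallelLinear test seeds each rank with a contiguous row block of the reference weight via model.weight.copy_(model_pt.weight[rank * partition_out_features:(rank + 1) * partition_out_features]), so the local output should equal the matching columns of the dense reference output. With a plain matmul standing in for the fused op (float64 so the comparison is effectively exact):

import torch

in_features, out_features, world_size, rank = 4096, 1024, 8, 3
partition_out_features = out_features // world_size

weight_pt = torch.randn(out_features, in_features, dtype=torch.float64)  # reference nn.Linear weight (bias omitted)
x = torch.randn(2, in_features, dtype=torch.float64)

weight_local = weight_pt[rank * partition_out_features:(rank + 1) * partition_out_features]
out_local = x @ weight_local.t()                      # this rank's columns of the output
out_pt = x @ weight_pt.t()                            # full reference output
torch.testing.assert_close(
    out_local, out_pt[:, rank * partition_out_features:(rank + 1) * partition_out_features])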