OpenDAS / ColossalAI · Commits

Commit f3b6aaa6, authored Jun 30, 2023 by Frank Lee

[shardformer] supported fused normalization (#4112)

parent b1c29015

Showing 12 changed files with 207 additions and 31 deletions (+207 -31)
colossalai/shardformer/layer/__init__.py        +2  -2
colossalai/shardformer/layer/normalization.py   +42 -2
colossalai/shardformer/policies/basepolicy.py   +8  -0
colossalai/shardformer/policies/bert.py         +16 -5
colossalai/shardformer/policies/bloom.py        +29 -2
colossalai/shardformer/policies/gpt2.py         +28 -1
colossalai/shardformer/policies/llama.py        +26 -2
colossalai/shardformer/policies/opt.py          +7  -1
colossalai/shardformer/policies/t5.py           +19 -3
colossalai/shardformer/policies/vit.py          +25 -2
colossalai/shardformer/shard/shard_config.py    +2  -8
tests/test_shardformer/test_model/_utils.py     +3  -3
colossalai/shardformer/layer/__init__.py

 from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
 from .embedding import Embedding1D, VocabParallelEmbedding1D
-from .layernorm import FusedLayerNorm
 from .linear import Linear1D_Col, Linear1D_Row
 from .loss import cross_entropy_1d
+from .normalization import FusedLayerNorm, FusedRMSNorm
 from .qkv_fused_linear import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row

 __all__ = [
     "Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", 'GPT2FusedLinearConv1D_Col',
     'GPT2FusedLinearConv1D_Row', 'DropoutForParallelInput', 'DropoutForReplicatedInput', "cross_entropy_1d",
-    'FusedLayerNorm'
+    'FusedLayerNorm', 'FusedRMSNorm'
 ]
colossalai/shardformer/layer/layernorm.py → colossalai/shardformer/layer/normalization.py

@@ -4,7 +4,7 @@
 import torch
 import torch.nn as nn

-__all__ = ['FusedLayerNorm']
+__all__ = ['FusedLayerNorm', 'FusedRMSNorm']

 FAST_LAYERNORM_SUPPORTED_SIZE = [
     1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480, 24576,
@@ -61,4 +61,44 @@ class FusedLayerNorm():
         # copy weight and bias
         layernorm.weight.copy_(module.weight)
         layernorm.bias.copy_(module.bias)
-        return layernorm
\ No newline at end of file
+        return layernorm
+
+
+class FusedRMSNorm():
+    """
+    This is a wrapper around the apex fused rms norm implementation. It is meant to be used only with the from_native_module interface.
+    """
+
+    def __init__(self) -> None:
+        raise NotImplementedError(
+            'FusedRMSNorm is not implemented as a physical class. '
+            'It is meant to be used only with the from_native_module interface to wrap the fused rms norm implementation provided by apex.'
+        )
+
+    @staticmethod
+    def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module:
+        try:
+            from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
+        except ImportError:
+            raise ImportError(
+                'Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel')
+
+        # to check if it is huggingface LlamaRMSNorm
+        if module.__class__.__name__ == "LlamaRMSNorm":
+            normalized_shape = module.weight.shape[0]
+            eps = module.variance_epsilon
+            elementwise_affine = True
+        else:
+            # get the attributes of the module
+            normalized_shape = module.normalized_shape
+            eps = module.eps
+            elementwise_affine = module.elementwise_affine
+
+        rmsnorm = ApexFusedRMSNorm(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine)
+
+        with torch.no_grad():
+            # copy weight and bias
+            rmsnorm.weight.copy_(module.weight)
+
+        return rmsnorm
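For orientation, here is a minimal usage sketch of the wrapper above. This is a sketch, assuming apex is installed and a CUDA device is available; the plain nn.LayerNorm is purely illustrative and stands in for any native module exposing normalized_shape / eps / elementwise_affine (a HuggingFace LlamaRMSNorm would take the special-cased branch instead):

import torch
import torch.nn as nn

from colossalai.shardformer.layer import FusedRMSNorm

# illustrative native module; from_native_module copies its weight and eps into the apex kernel
native_norm = nn.LayerNorm(4096).cuda()
fused_norm = FusedRMSNorm.from_native_module(native_norm).cuda()  # an apex FusedRMSNorm instance
out = fused_norm(torch.randn(2, 16, 4096, device='cuda'))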
colossalai/shardformer/policies/basepolicy.py

@@ -98,6 +98,14 @@ class Policy(ABC):
             shard_config (:class:`ShardConfig`): The shard config to be perform
         """
         self.shard_config = shard_config
+        self.config_sanity_check()
+
+    @abstractmethod
+    def config_sanity_check(self):
+        """
+        Check if the shard config is valid for the model. Raise an exception if the config is invalid.
+        """
+        pass

     @abstractmethod
     def preprocess(self) -> nn.Module:
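Pulled together from the concrete policies below, a new policy now implements at least the following hooks. This skeleton is inferred from this commit's policy files and is not itself part of the diff; the real base class may require further hooks beyond these:

from colossalai.shardformer.policies.basepolicy import Policy


class MyModelPolicy(Policy):

    def config_sanity_check(self):
        # validate self.shard_config; raise if the config is invalid for this model
        pass

    def preprocess(self):
        # e.g. reshape the embedding layer, then hand the model back
        return self.model

    def module_policy(self):
        # map module classes to ModulePolicyDescription objects, appending
        # fused-norm replacements when self.shard_config.enable_fused_normalization is set
        return {}

    def new_model_class(self):
        # return a replacement model class, or None to keep the original
        return None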
colossalai/shardformer/policies/bert.py

@@ -16,6 +16,9 @@ from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDes
 class BertPolicy(Policy):

+    def config_sanity_check(self):
+        pass
+
     def preprocess(self):
         # reshape the embedding layer
         r"""

@@ -99,7 +102,8 @@ class BertPolicy(Policy):
             ])
         }

-        if self.shard_config.fused_layernorm:
+        # optimization configuration
+        if self.shard_config.enable_fused_normalization:
             base_policy[BertLayer].sub_module_replacement.append(
                 SubModuleReplacementDescription(
                     suffix="attention.output.LayerNorm",

@@ -150,12 +154,16 @@ class BertForPretrainingPolicy(BertPolicy):
                                                kwargs={"gather_output": True}),
             ])
         }

-        if self.shard_config.fused_layernorm:
+        # optimization configuration
+        if self.shard_config.enable_fused_normalization:
             addon_module[BertLMPredictionHead].sub_module_replacement.append(
                 SubModuleReplacementDescription(
                     suffix="transform.LayerNorm",
                     target_module=col_nn.FusedLayerNorm,
                 ))
+
+        # append extra policy
         module_policy.update(addon_module)
         return module_policy

@@ -187,7 +195,7 @@ class BertLMHeadModelPolicy(BertPolicy):
                                                kwargs={"gather_output": True}),
             ])
         }

-        if self.shard_config.fused_layernorm:
+        if self.shard_config.enable_fused_normalization:
             addon_module[BertLMPredictionHead].sub_module_replacement.append(
                 SubModuleReplacementDescription(
                     suffix="transform.LayerNorm",

@@ -224,12 +232,15 @@ class BertForMaskedLMPolicy(BertPolicy):
                                                kwargs={"gather_output": True}),
             ])
         }

-        if self.shard_config.fused_layernorm:
+        # optimization configuration
+        if self.shard_config.enable_fused_normalization:
             addon_module[BertLMPredictionHead].sub_module_replacement.append(
                 SubModuleReplacementDescription(
                     suffix="transform.LayerNorm",
                     target_module=col_nn.FusedLayerNorm,
                 ))
+
         module_policy.update(addon_module)
         return module_policy

@@ -316,4 +327,4 @@ class BertForMultipleChoicePolicy(BertPolicy):
             ])
         }
         module_policy.update(addon_module)
-        return module_policy
\ No newline at end of file
+        return module_policy
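Every policy file in this commit repeats the same pattern: when shard_config.enable_fused_normalization is set, it appends a SubModuleReplacementDescription that points a submodule suffix at a fused target module. As a rough illustration of what such a description amounts to when applied, here is a simplified sketch; it is not ColossalAI's actual sharding code, and apply_replacement is a hypothetical helper:

import torch.nn as nn


def apply_replacement(root: nn.Module, suffix: str, target_module) -> None:
    # locate the parent of the submodule named by `suffix`
    parent_path, _, attr_name = suffix.rpartition('.')
    parent = root.get_submodule(parent_path) if parent_path else root
    native = getattr(parent, attr_name)
    # wrap the native module, copying its weights (see FusedLayerNorm / FusedRMSNorm above)
    setattr(parent, attr_name, target_module.from_native_module(native))

For example, apply_replacement(bert_layer, "attention.output.LayerNorm", col_nn.FusedLayerNorm) would mirror the replacement registered for BertLayer above.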
colossalai/shardformer/policies/bloom.py

@@ -65,6 +65,9 @@ def build_bloom_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int,
 class BloomPolicy(Policy):

+    def config_sanity_check(self):
+        pass
+
     def preprocess(self):
         # reshape the embedding layer
         r"""

@@ -81,7 +84,7 @@ class BloomPolicy(Policy):
     def module_policy(self):
         from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel

-        return {
+        base_policy = {
             BloomBlock:
                 ModulePolicyDescription(
                     attribute_replacement={

@@ -99,7 +102,6 @@ class BloomPolicy(Policy):
             SubModuleReplacementDescription(
                 suffix="self_attention.query_key_value",
                 target_module=col_nn.Linear1D_Col,
-                # kwargs={'n_fused': 3}
             ),
             SubModuleReplacementDescription(
                 suffix="self_attention.dense",

@@ -132,6 +134,31 @@ class BloomPolicy(Policy):
             ])
         }

+        # optimization configuration
+        if self.shard_config.enable_fused_normalization:
+            base_policy[BloomModel].sub_module_replacement.extend([
+                SubModuleReplacementDescription(
+                    suffix="ln_f",
+                    target_module=col_nn.FusedLayerNorm,
+                ),
+                SubModuleReplacementDescription(
+                    suffix="word_embeddings_layernorm",
+                    target_module=col_nn.FusedLayerNorm,
+                )
+            ])
+
+            base_policy[BloomBlock].sub_module_replacement.extend([
+                SubModuleReplacementDescription(
+                    suffix="input_layernorm",
+                    target_module=col_nn.FusedLayerNorm,
+                ),
+                SubModuleReplacementDescription(
+                    suffix="post_attention_layernorm",
+                    target_module=col_nn.FusedLayerNorm,
+                )
+            ])
+
+        return base_policy
+
     def new_model_class(self):
         # do nothing
         return self.model
colossalai/shardformer/policies/gpt2.py

@@ -9,6 +9,9 @@ from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDes
 class GPT2Policy(Policy):

+    def config_sanity_check(self):
+        pass
+
     def preprocess(self):
         # reshape the embedding layer
         r"""

@@ -22,7 +25,7 @@ class GPT2Policy(Policy):
         return self.model

     def module_policy(self):
-        return {
+        base_policy = {
             GPT2Model:
                 ModulePolicyDescription(attribute_replacement={},
                                         param_replacement=[],

@@ -77,6 +80,30 @@ class GPT2Policy(Policy):
             ])
         }

+        # optimization configuration
+        if self.shard_config.enable_fused_normalization:
+            base_policy[GPT2Model].sub_module_replacement.append(
+                SubModuleReplacementDescription(
+                    suffix="ln_f",
+                    target_module=col_nn.FusedLayerNorm,
+                ))
+
+            base_policy[GPT2Block].sub_module_replacement.extend([
+                SubModuleReplacementDescription(
+                    suffix="ln_1",
+                    target_module=col_nn.FusedLayerNorm,
+                ),
+                SubModuleReplacementDescription(
+                    suffix="ln_2",
+                    target_module=col_nn.FusedLayerNorm,
+                ),
+                SubModuleReplacementDescription(suffix="ln_cross_attn",
+                                                target_module=col_nn.FusedLayerNorm,
+                                                ignore_if_not_exist=True)
+            ])
+
+        return base_policy
+
     def new_model_class(self):
         return self.model
colossalai/shardformer/policies/llama.py

@@ -4,13 +4,16 @@ import torch.nn as nn
 from transformers import LlamaForCausalLM, LlamaForSequenceClassification
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel

-from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
+from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D

 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription


 class LlamaPolicy(Policy):

+    def config_sanity_check(self):
+        pass
+
     def preprocess(self):
         # Resize embedding
         vocab_size = self.model.config.vocab_size

@@ -23,7 +26,7 @@ class LlamaPolicy(Policy):
         return self.model

     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        return {
+        base_policy = {
             LlamaDecoderLayer:
                 ModulePolicyDescription(
                     attribute_replacement={

@@ -75,6 +78,27 @@ class LlamaPolicy(Policy):
             ])
         }

+        # optimization configuration
+        if self.shard_config.enable_fused_normalization:
+            base_policy[LlamaDecoderLayer].sub_module_replacement.extend([
+                SubModuleReplacementDescription(
+                    suffix="input_layernorm",
+                    target_module=FusedRMSNorm,
+                ),
+                SubModuleReplacementDescription(
+                    suffix="post_attention_layernorm",
+                    target_module=FusedRMSNorm,
+                )
+            ])
+
+            base_policy[LlamaModel].sub_module_replacement.append(
+                SubModuleReplacementDescription(
+                    suffix="norm",
+                    target_module=FusedRMSNorm,
+                ))
+
+        return base_policy
+
     def new_model_class(self):
         return None
colossalai/shardformer/policies/opt.py

@@ -13,6 +13,9 @@ from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDes
 class OPTPolicy(Policy):

+    def config_sanity_check(self):
+        pass
+
     def preprocess(self):
         # reshape the embedding layer
         r"""

@@ -74,7 +77,9 @@ class OPTPolicy(Policy):
                 ),
             ]),
         }

-        if self.shard_config.fused_layernorm:
+        # optimization configuration
+        if self.shard_config.enable_fused_normalization:
             base_policy[OPTDecoder].sub_module_replacement.append(
                 SubModuleReplacementDescription(
                     suffix="final_layer_norm",
                     target_module=FusedLayerNorm,

@@ -87,6 +92,7 @@ class OPTPolicy(Policy):
                     target_module=FusedLayerNorm,
                     ignore_if_not_exist=True)
             ])
+
         return base_policy

     def new_model_class(self):
colossalai/shardformer/policies/t5.py

@@ -9,7 +9,7 @@ from transformers.models.t5.modeling_t5 import (
     T5Stack,
 )

-from colossalai.shardformer.layer import DropoutForParallelInput, Embedding1D, Linear1D_Col, Linear1D_Row
+from colossalai.shardformer.layer import DropoutForParallelInput, Embedding1D, FusedRMSNorm, Linear1D_Col, Linear1D_Row

 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

@@ -18,6 +18,9 @@ __all__ = ["T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy
 class T5ModelPolicy(Policy):

+    def config_sanity_check(self):
+        pass
+
     def preprocess(self):
         # reshape the embedding layer
         r"""

@@ -31,7 +34,7 @@ class T5ModelPolicy(Policy):
         return self.model

     def module_policy(self):
-        return {
+        base_policy = {
             T5Stack:
                 ModulePolicyDescription(attribute_replacement={},
                                         param_replacement=[],

@@ -139,6 +142,19 @@ class T5ModelPolicy(Policy):
             ])
         }

+        # optimization configuration
+        if self.shard_config.enable_fused_normalization:
+            base_policy[T5LayerFF].sub_module_replacement.append(
+                SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm))
+            base_policy[T5LayerSelfAttention].sub_module_replacement.append(
+                SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm))
+            base_policy[T5LayerCrossAttention].sub_module_replacement.append(
+                SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm))
+            base_policy[T5Stack].sub_module_replacement.append(
+                SubModuleReplacementDescription(suffix="final_layer_norm", target_module=FusedRMSNorm))
+
+        return base_policy
+
     def new_model_class(self):
         return None

@@ -167,4 +183,4 @@ class T5ForConditionalGenerationPolicy(T5ModelPolicy):
 class T5EncoderPolicy(T5ModelPolicy):
-    pass
\ No newline at end of file
+    pass
colossalai/shardformer/policies/vit.py

@@ -3,13 +3,16 @@ from typing import Dict, Union
 import torch.nn as nn

 from transformers.models.vit.modeling_vit import ViTAttention, ViTEmbeddings, ViTLayer, ViTModel

-from colossalai.shardformer.layer import DropoutForReplicatedInput, Linear1D_Col, Linear1D_Row
+from colossalai.shardformer.layer import DropoutForReplicatedInput, FusedLayerNorm, Linear1D_Col, Linear1D_Row

 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription


 class ViTPolicy(Policy):

+    def config_sanity_check(self):
+        pass
+
     def preprocess(self):
         # Resize embedding
         vocab_size = self.model.config.vocab_size

@@ -22,7 +25,7 @@ class ViTPolicy(Policy):
         return self.model

     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        return {
+        base_policy = {
             ViTEmbeddings:
                 ModulePolicyDescription(attribute_replacement={},
                                         param_replacement=[],

@@ -80,6 +83,26 @@ class ViTPolicy(Policy):
             ]),
         }

+        # optimization configuration
+        if self.shard_config.enable_fused_normalization:
+            base_policy[ViTAttention].sub_module_replacement.extend([
+                SubModuleReplacementDescription(
+                    suffix="layernorm_before",
+                    target_module=FusedLayerNorm,
+                ),
+                SubModuleReplacementDescription(
+                    suffix="layernorm_after",
+                    target_module=FusedLayerNorm,
+                )
+            ])
+
+            base_policy[ViTModel].sub_module_replacement.append(
+                SubModuleReplacementDescription(
+                    suffix="layernorm",
+                    target_module=FusedLayerNorm,
+                ))
+
+        return base_policy
+
     def new_model_class(self):
         return None
colossalai/shardformer/shard/shard_config.py

@@ -12,16 +12,10 @@ class ShardConfig:
     Args:
         tensor_parallel_size (int): The size of tensor parallel
-        use_mixedfusedLN (bool): Whether to use the `MixedFusedLayerNorm`
-        data_parallel_size (int): The size of data parallel
-        pipeline_parallel_size (int): The size of pipeline parallel
-        tensor_parallel_mode (List): The mode of tensor parallel, choose from `['1d','2d','2.5d','3d']
-        inference_only (bool): Whether to use the inference only mode, when setting to `True`, the model
-            will not calculate the loss and just return the output.
-        gather_output (bool): Whether to gather the output of the model of the last layer
+        enable_fused_normalization (bool): Whether to use fused layernorm, default is False
     """
     tensor_parallel_size: int
-    fused_layernorm: bool = False
+    enable_fused_normalization: bool = False

     # TODO: add support for tensor parallel
     # pipeline_parallel_size: int
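With the rename, enabling the fused kernels is a single flag on ShardConfig. A minimal sketch, mirroring the test helper below; it assumes apex is installed, the distributed environment has already been initialized, that ShardConfig and ShardFormer are importable from colossalai.shardformer as used in the tests, and that `model` is your own module instance:

from colossalai.shardformer import ShardConfig, ShardFormer

shard_config = ShardConfig(tensor_parallel_size=2, enable_fused_normalization=True)
shard_former = ShardFormer(shard_config=shard_config)
shard_former.init_distributed()
sharded_model = shard_former.shard_model(model).cuda()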
tests/test_shardformer/test_model/_utils.py

@@ -8,11 +8,11 @@ def build_model(world_size, model_fn):
     org_model = model_fn().cuda()

     # shard model
-    shard_config = ShardConfig(tensor_parallel_size=world_size, fused_layernorm=True)
+    shard_config = ShardConfig(tensor_parallel_size=world_size, enable_fused_normalization=True)
     model_copy = copy.deepcopy(org_model)
     shard_former = ShardFormer(shard_config=shard_config)
     shard_former.init_distributed()
-    sharded_model = shard_former.shard_model(model_copy)
+    sharded_model = shard_former.shard_model(model_copy).cuda()
     return org_model, sharded_model

@@ -33,4 +33,4 @@ def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn,
     shard_output = sharded_model(**data)
     shard_output = output_transform_fn(shard_output)
     shard_loss = loss_fn(shard_output)
-    return org_output, org_loss, shard_output, shard_loss
\ No newline at end of file
+    return org_output, org_loss, shard_output, shard_loss