Commit f3b6aaa6 authored by Frank Lee
Browse files

[shardformer] supported fused normalization (#4112)

parent b1c29015
from .dropout import DropoutForParallelInput, DropoutForReplicatedInput from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
from .embedding import Embedding1D, VocabParallelEmbedding1D from .embedding import Embedding1D, VocabParallelEmbedding1D
from .layernorm import FusedLayerNorm
from .linear import Linear1D_Col, Linear1D_Row from .linear import Linear1D_Col, Linear1D_Row
from .loss import cross_entropy_1d from .loss import cross_entropy_1d
from .normalization import FusedLayerNorm, FusedRMSNorm
from .qkv_fused_linear import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row from .qkv_fused_linear import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
__all__ = [ __all__ = [
"Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", 'GPT2FusedLinearConv1D_Col', "Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", 'GPT2FusedLinearConv1D_Col',
'GPT2FusedLinearConv1D_Row', 'DropoutForParallelInput', 'DropoutForReplicatedInput', "cross_entropy_1d", 'GPT2FusedLinearConv1D_Row', 'DropoutForParallelInput', 'DropoutForReplicatedInput', "cross_entropy_1d",
'FusedLayerNorm' 'FusedLayerNorm', 'FusedRMSNorm'
] ]
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
__all__ = ['FusedLayerNorm'] __all__ = ['FusedLayerNorm', 'FusedRMSNorm']
FAST_LAYERNORM_SUPPORTED_SIZE = [ FAST_LAYERNORM_SUPPORTED_SIZE = [
1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480, 24576, 1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480, 24576,
...@@ -61,4 +61,44 @@ class FusedLayerNorm(): ...@@ -61,4 +61,44 @@ class FusedLayerNorm():
# copy weight and bias # copy weight and bias
layernorm.weight.copy_(module.weight) layernorm.weight.copy_(module.weight)
layernorm.bias.copy_(module.bias) layernorm.bias.copy_(module.bias)
return layernorm return layernorm
\ No newline at end of file
class FusedRMSNorm():
    """
    Wrapper around the apex fused RMS norm implementation.

    The class is a namespace only: instantiating it raises ``NotImplementedError``.
    Use :meth:`from_native_module` to build an apex ``FusedRMSNorm`` from an existing norm layer.
    """

    def __init__(self) -> None:
        raise NotImplementedError(
            'FusedRMSNorm is not implemented as a physical class. '
            'It is meant to be used only with the from_native_module interface to wrap the fused rms norm implementation provided by apex.'
        )

    @staticmethod
    def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module:
        """
        Build an apex ``FusedRMSNorm`` that mirrors ``module``.

        Args:
            module (nn.Module): the norm layer to convert. If its class is named ``LlamaRMSNorm``,
                ``weight`` and ``variance_epsilon`` are read from it; otherwise ``normalized_shape``,
                ``eps`` and ``elementwise_affine`` are read from it.
            *args, **kwargs: accepted for interface compatibility; not used here.

        Returns:
            nn.Module: the apex fused RMS norm, with the weight copied from ``module`` when the
            fused module has a weight parameter.

        Raises:
            ImportError: if ``apex.normalization.FusedRMSNorm`` cannot be imported.
        """
        try:
            from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
        except ImportError as e:
            # chain the original error so the real import failure stays visible
            raise ImportError(
                'Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel'
            ) from e

        # to check if it is huggingface LlamaRMSNorm
        if module.__class__.__name__ == "LlamaRMSNorm":
            # this module exposes its size through the weight and names its epsilon differently
            normalized_shape = module.weight.shape[0]
            eps = module.variance_epsilon
            elementwise_affine = True
        else:
            # get the attributes of the module
            normalized_shape = module.normalized_shape
            eps = module.eps
            elementwise_affine = module.elementwise_affine

        rmsnorm = ApexFusedRMSNorm(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine)

        # RMS norm has no bias, so only the weight is copied. When the fused module was built
        # without an affine transform its weight is None and there is nothing to copy.
        if getattr(rmsnorm, "weight", None) is not None:
            with torch.no_grad():
                rmsnorm.weight.copy_(module.weight)
        return rmsnorm
...@@ -98,6 +98,14 @@ class Policy(ABC): ...@@ -98,6 +98,14 @@ class Policy(ABC):
shard_config (:class:`ShardConfig`): The shard config to be perform shard_config (:class:`ShardConfig`): The shard config to be perform
""" """
self.shard_config = shard_config self.shard_config = shard_config
self.config_sanity_check()
@abstractmethod
def config_sanity_check(self):
    """
    Check if the shard config is valid for the model. Raise an exception if the config is invalid.

    This is an abstract hook, so every concrete policy has to provide its own implementation.
    It is invoked right after ``self.shard_config`` has been assigned, so an implementation
    can read the config from that attribute.
    """
    pass
@abstractmethod @abstractmethod
def preprocess(self) -> nn.Module: def preprocess(self) -> nn.Module:
......
...@@ -16,6 +16,9 @@ from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDes ...@@ -16,6 +16,9 @@ from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDes
class BertPolicy(Policy): class BertPolicy(Policy):
def config_sanity_check(self):
    """No-op: this policy performs no validation of the shard config."""
    pass
def preprocess(self): def preprocess(self):
# reshape the embedding layer # reshape the embedding layer
r""" r"""
...@@ -99,7 +102,8 @@ class BertPolicy(Policy): ...@@ -99,7 +102,8 @@ class BertPolicy(Policy):
]) ])
} }
if self.shard_config.fused_layernorm: # optimization configuration
if self.shard_config.enable_fused_normalization:
base_policy[BertLayer].sub_module_replacement.append( base_policy[BertLayer].sub_module_replacement.append(
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="attention.output.LayerNorm", suffix="attention.output.LayerNorm",
...@@ -150,12 +154,16 @@ class BertForPretrainingPolicy(BertPolicy): ...@@ -150,12 +154,16 @@ class BertForPretrainingPolicy(BertPolicy):
kwargs={"gather_output": True}), kwargs={"gather_output": True}),
]) ])
} }
if self.shard_config.fused_layernorm:
# optimization configuration
if self.shard_config.enable_fused_normalization:
addon_module[BertLMPredictionHead].sub_module_replacement.append( addon_module[BertLMPredictionHead].sub_module_replacement.append(
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="transform.LayerNorm", suffix="transform.LayerNorm",
target_module=col_nn.FusedLayerNorm, target_module=col_nn.FusedLayerNorm,
)) ))
# append extra policy
module_policy.update(addon_module) module_policy.update(addon_module)
return module_policy return module_policy
...@@ -187,7 +195,7 @@ class BertLMHeadModelPolicy(BertPolicy): ...@@ -187,7 +195,7 @@ class BertLMHeadModelPolicy(BertPolicy):
kwargs={"gather_output": True}), kwargs={"gather_output": True}),
]) ])
} }
if self.shard_config.fused_layernorm: if self.shard_config.enable_fused_normalization:
addon_module[BertLMPredictionHead].sub_module_replacement.append( addon_module[BertLMPredictionHead].sub_module_replacement.append(
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="transform.LayerNorm", suffix="transform.LayerNorm",
...@@ -224,12 +232,15 @@ class BertForMaskedLMPolicy(BertPolicy): ...@@ -224,12 +232,15 @@ class BertForMaskedLMPolicy(BertPolicy):
kwargs={"gather_output": True}), kwargs={"gather_output": True}),
]) ])
} }
if self.shard_config.fused_layernorm:
# optimization configuration
if self.shard_config.enable_fused_normalization:
addon_module[BertLMPredictionHead].sub_module_replacement.append( addon_module[BertLMPredictionHead].sub_module_replacement.append(
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="transform.LayerNorm", suffix="transform.LayerNorm",
target_module=col_nn.FusedLayerNorm, target_module=col_nn.FusedLayerNorm,
)) ))
module_policy.update(addon_module) module_policy.update(addon_module)
return module_policy return module_policy
...@@ -316,4 +327,4 @@ class BertForMultipleChoicePolicy(BertPolicy): ...@@ -316,4 +327,4 @@ class BertForMultipleChoicePolicy(BertPolicy):
]) ])
} }
module_policy.update(addon_module) module_policy.update(addon_module)
return module_policy return module_policy
\ No newline at end of file
...@@ -65,6 +65,9 @@ def build_bloom_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int, ...@@ -65,6 +65,9 @@ def build_bloom_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int,
class BloomPolicy(Policy): class BloomPolicy(Policy):
def config_sanity_check(self):
    """No-op: this policy performs no validation of the shard config."""
    pass
def preprocess(self): def preprocess(self):
# reshape the embedding layer # reshape the embedding layer
r""" r"""
...@@ -81,7 +84,7 @@ class BloomPolicy(Policy): ...@@ -81,7 +84,7 @@ class BloomPolicy(Policy):
def module_policy(self): def module_policy(self):
from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel
return { base_policy = {
BloomBlock: BloomBlock:
ModulePolicyDescription( ModulePolicyDescription(
attribute_replacement={ attribute_replacement={
...@@ -99,7 +102,6 @@ class BloomPolicy(Policy): ...@@ -99,7 +102,6 @@ class BloomPolicy(Policy):
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="self_attention.query_key_value", suffix="self_attention.query_key_value",
target_module=col_nn.Linear1D_Col, target_module=col_nn.Linear1D_Col,
# kwargs={'n_fused': 3}
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="self_attention.dense", suffix="self_attention.dense",
...@@ -132,6 +134,31 @@ class BloomPolicy(Policy): ...@@ -132,6 +134,31 @@ class BloomPolicy(Policy):
]) ])
} }
# optimization configuration
if self.shard_config.enable_fused_normalization:
base_policy[BloomModel].sub_module_replacement.extend([
SubModuleReplacementDescription(
suffix="ln_f",
target_module=col_nn.FusedLayerNorm,
),
SubModuleReplacementDescription(
suffix="word_embeddings_layernorm",
target_module=col_nn.FusedLayerNorm,
)
])
base_policy[BloomBlock].sub_module_replacement.extend([
SubModuleReplacementDescription(
suffix="input_layernorm",
target_module=col_nn.FusedLayerNorm,
),
SubModuleReplacementDescription(
suffix="post_attention_layernorm",
target_module=col_nn.FusedLayerNorm,
)
])
return base_policy
def new_model_class(self): def new_model_class(self):
# do nothing # do nothing
return self.model return self.model
......
...@@ -9,6 +9,9 @@ from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDes ...@@ -9,6 +9,9 @@ from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDes
class GPT2Policy(Policy): class GPT2Policy(Policy):
def config_sanity_check(self):
    """No-op: this policy performs no validation of the shard config."""
    pass
def preprocess(self): def preprocess(self):
# reshape the embedding layer # reshape the embedding layer
r""" r"""
...@@ -22,7 +25,7 @@ class GPT2Policy(Policy): ...@@ -22,7 +25,7 @@ class GPT2Policy(Policy):
return self.model return self.model
def module_policy(self): def module_policy(self):
return { base_policy = {
GPT2Model: GPT2Model:
ModulePolicyDescription(attribute_replacement={}, ModulePolicyDescription(attribute_replacement={},
param_replacement=[], param_replacement=[],
...@@ -77,6 +80,30 @@ class GPT2Policy(Policy): ...@@ -77,6 +80,30 @@ class GPT2Policy(Policy):
]) ])
} }
# optimization configuration
if self.shard_config.enable_fused_normalization:
base_policy[GPT2Model].sub_module_replacement.append(
SubModuleReplacementDescription(
suffix="ln_f",
target_module=col_nn.FusedLayerNorm,
))
base_policy[GPT2Block].sub_module_replacement.extend([
SubModuleReplacementDescription(
suffix="ln_1",
target_module=col_nn.FusedLayerNorm,
),
SubModuleReplacementDescription(
suffix="ln_2",
target_module=col_nn.FusedLayerNorm,
),
SubModuleReplacementDescription(suffix="ln_cross_attn",
target_module=col_nn.FusedLayerNorm,
ignore_if_not_exist=True)
])
return base_policy
def new_model_class(self): def new_model_class(self):
return self.model return self.model
......
...@@ -4,13 +4,16 @@ import torch.nn as nn ...@@ -4,13 +4,16 @@ import torch.nn as nn
from transformers import LlamaForCausalLM, LlamaForSequenceClassification from transformers import LlamaForCausalLM, LlamaForSequenceClassification
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
class LlamaPolicy(Policy): class LlamaPolicy(Policy):
def config_sanity_check(self):
    """No-op: this policy performs no validation of the shard config."""
    pass
def preprocess(self): def preprocess(self):
# Resize embedding # Resize embedding
vocab_size = self.model.config.vocab_size vocab_size = self.model.config.vocab_size
...@@ -23,7 +26,7 @@ class LlamaPolicy(Policy): ...@@ -23,7 +26,7 @@ class LlamaPolicy(Policy):
return self.model return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
return { base_policy = {
LlamaDecoderLayer: LlamaDecoderLayer:
ModulePolicyDescription( ModulePolicyDescription(
attribute_replacement={ attribute_replacement={
...@@ -75,6 +78,27 @@ class LlamaPolicy(Policy): ...@@ -75,6 +78,27 @@ class LlamaPolicy(Policy):
]) ])
} }
# optimization configuration
if self.shard_config.enable_fused_normalization:
base_policy[LlamaDecoderLayer].sub_module_replacement.extend([
SubModuleReplacementDescription(
suffix="input_layernorm",
target_module=FusedRMSNorm,
),
SubModuleReplacementDescription(
suffix="post_attention_layernorm",
target_module=FusedRMSNorm,
)
])
base_policy[LlamaModel].sub_module_replacement.append(
SubModuleReplacementDescription(
suffix="norm",
target_module=FusedRMSNorm,
))
return base_policy
def new_model_class(self): def new_model_class(self):
return None return None
......
...@@ -13,6 +13,9 @@ from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDes ...@@ -13,6 +13,9 @@ from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDes
class OPTPolicy(Policy): class OPTPolicy(Policy):
def config_sanity_check(self):
    """No-op: this policy performs no validation of the shard config."""
    pass
def preprocess(self): def preprocess(self):
# reshape the embedding layer # reshape the embedding layer
r""" r"""
...@@ -74,7 +77,9 @@ class OPTPolicy(Policy): ...@@ -74,7 +77,9 @@ class OPTPolicy(Policy):
), ),
]), ]),
} }
if self.shard_config.fused_layernorm:
# optimization configuration
if self.shard_config.enable_fused_normalization:
base_policy[OPTDecoder].sub_module_replacement.append( base_policy[OPTDecoder].sub_module_replacement.append(
SubModuleReplacementDescription(suffix="final_layer_norm", SubModuleReplacementDescription(suffix="final_layer_norm",
target_module=FusedLayerNorm, target_module=FusedLayerNorm,
...@@ -87,6 +92,7 @@ class OPTPolicy(Policy): ...@@ -87,6 +92,7 @@ class OPTPolicy(Policy):
target_module=FusedLayerNorm, target_module=FusedLayerNorm,
ignore_if_not_exist=True) ignore_if_not_exist=True)
]) ])
return base_policy return base_policy
def new_model_class(self): def new_model_class(self):
......
...@@ -9,7 +9,7 @@ from transformers.models.t5.modeling_t5 import ( ...@@ -9,7 +9,7 @@ from transformers.models.t5.modeling_t5 import (
T5Stack, T5Stack,
) )
from colossalai.shardformer.layer import DropoutForParallelInput, Embedding1D, Linear1D_Col, Linear1D_Row from colossalai.shardformer.layer import DropoutForParallelInput, Embedding1D, FusedRMSNorm, Linear1D_Col, Linear1D_Row
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
...@@ -18,6 +18,9 @@ __all__ = ["T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy ...@@ -18,6 +18,9 @@ __all__ = ["T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy
class T5ModelPolicy(Policy): class T5ModelPolicy(Policy):
def config_sanity_check(self):
    """No-op: this policy performs no validation of the shard config."""
    pass
def preprocess(self): def preprocess(self):
# reshape the embedding layer # reshape the embedding layer
r""" r"""
...@@ -31,7 +34,7 @@ class T5ModelPolicy(Policy): ...@@ -31,7 +34,7 @@ class T5ModelPolicy(Policy):
return self.model return self.model
def module_policy(self): def module_policy(self):
return { base_policy = {
T5Stack: T5Stack:
ModulePolicyDescription(attribute_replacement={}, ModulePolicyDescription(attribute_replacement={},
param_replacement=[], param_replacement=[],
...@@ -139,6 +142,19 @@ class T5ModelPolicy(Policy): ...@@ -139,6 +142,19 @@ class T5ModelPolicy(Policy):
]) ])
} }
# optimization configuration
if self.shard_config.enable_fused_normalization:
base_policy[T5LayerFF].sub_module_replacement.append(
SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm))
base_policy[T5LayerSelfAttention].sub_module_replacement.append(
SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm))
base_policy[T5LayerCrossAttention].sub_module_replacement.append(
SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm))
base_policy[T5Stack].sub_module_replacement.append(
SubModuleReplacementDescription(suffix="final_layer_norm", target_module=FusedRMSNorm))
return base_policy
def new_model_class(self): def new_model_class(self):
return None return None
...@@ -167,4 +183,4 @@ class T5ForConditionalGenerationPolicy(T5ModelPolicy): ...@@ -167,4 +183,4 @@ class T5ForConditionalGenerationPolicy(T5ModelPolicy):
class T5EncoderPolicy(T5ModelPolicy): class T5EncoderPolicy(T5ModelPolicy):
pass pass
\ No newline at end of file
...@@ -3,13 +3,16 @@ from typing import Dict, Union ...@@ -3,13 +3,16 @@ from typing import Dict, Union
import torch.nn as nn import torch.nn as nn
from transformers.models.vit.modeling_vit import ViTAttention, ViTEmbeddings, ViTLayer, ViTModel from transformers.models.vit.modeling_vit import ViTAttention, ViTEmbeddings, ViTLayer, ViTModel
from colossalai.shardformer.layer import DropoutForReplicatedInput, Linear1D_Col, Linear1D_Row from colossalai.shardformer.layer import DropoutForReplicatedInput, FusedLayerNorm, Linear1D_Col, Linear1D_Row
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
class ViTPolicy(Policy): class ViTPolicy(Policy):
def config_sanity_check(self):
    """No-op: this policy performs no validation of the shard config."""
    pass
def preprocess(self): def preprocess(self):
# Resize embedding # Resize embedding
vocab_size = self.model.config.vocab_size vocab_size = self.model.config.vocab_size
...@@ -22,7 +25,7 @@ class ViTPolicy(Policy): ...@@ -22,7 +25,7 @@ class ViTPolicy(Policy):
return self.model return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
return { base_policy = {
ViTEmbeddings: ViTEmbeddings:
ModulePolicyDescription(attribute_replacement={}, ModulePolicyDescription(attribute_replacement={},
param_replacement=[], param_replacement=[],
...@@ -80,6 +83,26 @@ class ViTPolicy(Policy): ...@@ -80,6 +83,26 @@ class ViTPolicy(Policy):
]), ]),
} }
# optimization configuration
if self.shard_config.enable_fused_normalization:
base_policy[ViTAttention].sub_module_replacement.extend([
SubModuleReplacementDescription(
suffix="layernorm_before",
target_module=FusedLayerNorm,
),
SubModuleReplacementDescription(
suffix="layernorm_after",
target_module=FusedLayerNorm,
)
])
base_policy[ViTModel].sub_module_replacement.append(
SubModuleReplacementDescription(
suffix="layernorm",
target_module=FusedLayerNorm,
))
return base_policy
def new_model_class(self): def new_model_class(self):
return None return None
......
...@@ -12,16 +12,10 @@ class ShardConfig: ...@@ -12,16 +12,10 @@ class ShardConfig:
Args: Args:
tensor_parallel_size (int): The size of tensor parallel tensor_parallel_size (int): The size of tensor parallel
use_mixedfusedLN (bool): Whether to use the `MixedFusedLayerNorm` enable_fused_normalization (bool): Whether to use fused layernorm, default is False
data_parallel_size (int): The size of data parallel
pipeline_parallel_size (int): The size of pipeline parallel
tensor_parallel_mode (List): The mode of tensor parallel, choose from `['1d','2d','2.5d','3d']
inference_only (bool): Whether to use the inference only mode, when setting to `True`, the model
will not calculate the loss and just return the output.
gather_output (bool): Whether to gather the output of the model of the last layer
""" """
tensor_parallel_size: int tensor_parallel_size: int
fused_layernorm: bool = False enable_fused_normalization: bool = False
# TODO: add support for tensor parallel # TODO: add support for tensor parallel
# pipeline_parallel_size: int # pipeline_parallel_size: int
......
...@@ -8,11 +8,11 @@ def build_model(world_size, model_fn): ...@@ -8,11 +8,11 @@ def build_model(world_size, model_fn):
org_model = model_fn().cuda() org_model = model_fn().cuda()
# shard model # shard model
shard_config = ShardConfig(tensor_parallel_size=world_size, fused_layernorm=True) shard_config = ShardConfig(tensor_parallel_size=world_size, enable_fused_normalization=True)
model_copy = copy.deepcopy(org_model) model_copy = copy.deepcopy(org_model)
shard_former = ShardFormer(shard_config=shard_config) shard_former = ShardFormer(shard_config=shard_config)
shard_former.init_distributed() shard_former.init_distributed()
sharded_model = shard_former.shard_model(model_copy) sharded_model = shard_former.shard_model(model_copy).cuda()
return org_model, sharded_model return org_model, sharded_model
...@@ -33,4 +33,4 @@ def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, ...@@ -33,4 +33,4 @@ def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn,
shard_output = sharded_model(**data) shard_output = sharded_model(**data)
shard_output = output_transform_fn(shard_output) shard_output = output_transform_fn(shard_output)
shard_loss = loss_fn(shard_output) shard_loss = loss_fn(shard_output)
return org_output, org_loss, shard_output, shard_loss return org_output, org_loss, shard_output, shard_loss
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment