"...git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "bb0a668feeae647c247a509ed69e5f6c926a045c"
Unverified Commit 1a3315e3 authored by littsk's avatar littsk Committed by GitHub
Browse files

[hotfix] Add layer norm gradients all-reduce for sequence parallel (#4926)



* [hotfix] Add layer norm gradients all-reduce for sequence parallel. (#4915)

* Add layer norm gradients all-reduce for sequence parallel.

* skip pipeline inference test

* [hotfix] fixing policies of sequence parallel (#4922)

* Add layer norm gradients all-reduce for sequence parallel.

* fix parameter passing when calling get_autopolicy

---------
Co-authored-by: littsk <1214689160@qq.com>

* Hotfix/add grad all reduce for sequence parallel (#4927)

* Add layer norm gradients all-reduce for sequence parallel.


* fix parameter passing when calling get_autopolicy

* fix bug using wrong variables

---------
Co-authored-by: littsk <1214689160@qq.com>

* fix policy initialization

* fix bloom and chatglm policies

* polish code of handling layernorm

* fix moe module

* polish code of class initializing

---------
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
parent d99b2c96
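Why the all-reduce is needed: with sequence parallelism, activations are split along the sequence dimension, so each tensor-parallel rank computes layer norm weight/bias gradients from only its own shard of tokens; summing the per-rank gradients across the tensor-parallel group recovers the full-sequence gradient. A single-process sketch of this identity in plain PyTorch (illustrative only, not the patched code):

import torch
import torch.nn as nn

torch.manual_seed(0)
ln = nn.LayerNorm(8)
x = torch.randn(4, 8)  # 4 "tokens", hidden size 8

# Gradient computed from the full sequence.
ln(x).sum().backward()
full_grad = ln.weight.grad.clone()

# Gradients accumulated from two sequence shards sum to the same value;
# this sum is exactly what the all-reduce reconstructs across ranks.
ln.weight.grad = None
ln(x[:2]).sum().backward()
ln(x[2:]).sum().backward()
assert torch.allclose(full_grad, ln.weight.grad, atol=1e-6)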
@@ -338,7 +338,14 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
        if not isinstance(model, ModelWrapper):
            use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
            model = HybridParallelModule(
-                model, self.precision, self.shard_config, self.dp_group, use_ddp, self.ddp_config, self.custom_policy
+                module=model,
+                precision=self.precision,
+                shard_config=self.shard_config,
+                dp_group=self.dp_group,
+                tp_group=self.tp_group,
+                use_ddp=use_ddp,
+                ddp_config=self.ddp_config,
+                custom_policy=self.custom_policy,
            )
        if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
            if self.zero_stage == 0:
...
@@ -218,10 +218,10 @@ class TPInferEngine:
        ), "Discrepancy between the tp size of TPInferEngine and the tp size of shard config"
        model_name = model.__class__.__name__
        assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference."

        model = model.model if self.shard_config.inference_gptq else model
-        policy = get_autopolicy(model, inference_only=True)
+        policy = get_autopolicy(model, shard_config=self.shard_config)
        self.model, _ = shardformer.optimize(model, policy)

        if self.shard_config.inference_gptq:
...
@@ -235,6 +235,14 @@ class SubModuleReplacementDescription:
class Policy(ABC):
+    r"""
+    The base class for all the policies. For each different model, it should have a different policy class,
+    like BertPolicy for Bert Model or OPTPolicy for OPT model.
+
+    Shardformer has provided many built-in sharding policies for the mainstream models. You can use the
+    built-in policies by setting `policy = None`, which is already the default argument for `Shardformer.optimize`.
+    If you want to define your own policy, you can inherit from this class and overwrite the methods you want to modify.
+    """

    def __init__(self):
        self.model = None

@@ -245,6 +253,16 @@ class Policy(ABC):
        """
        self.model = model

+    def set_shard_config(self, shard_config: ShardConfig) -> None:
+        r"""
+        Set shard config as an attribute of the Policy object.
+
+        Args:
+            shard_config (:class:`ShardConfig`): The shard config to be applied.
+        """
+        self.shard_config = shard_config
+        self.config_sanity_check()

    @abstractmethod
    def preprocess(self) -> nn.Module:
        """
...
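The docstring above names the extension point; a minimal custom-policy skeleton might look like the following (hypothetical MyModelPolicy; the exact set of abstract methods may differ between versions, so treat it as a sketch rather than the definitive interface):

from colossalai.shardformer.policies.base_policy import Policy

class MyModelPolicy(Policy):
    def config_sanity_check(self):
        # Validate self.shard_config (set via set_shard_config) here.
        pass

    def preprocess(self):
        # Optionally mutate self.model before sharding, e.g. resize embeddings.
        return self.model

    def module_policy(self):
        # Map module classes to ModulePolicyDescription objects.
        return {}

    def postprocess(self):
        return self.model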
@@ -2,7 +2,7 @@ from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
from .embedding import Embedding1D, VocabParallelEmbedding1D
from .linear import Linear1D_Col, Linear1D_Row
from .loss import cross_entropy_1d
-from .normalization import FusedLayerNorm, FusedRMSNorm
+from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
from .parallel_module import ParallelModule
from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row

@@ -16,6 +16,9 @@ __all__ = [
    "DropoutForParallelInput",
    "DropoutForReplicatedInput",
    "cross_entropy_1d",
+    "BaseLayerNorm",
+    "LayerNorm",
+    "RMSNorm",
    "FusedLayerNorm",
    "FusedRMSNorm",
    "FusedLinear1D_Col",
...
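With the widened export list, callers can pick fused or plain wrappers from a single namespace (usage sketch):

from colossalai.shardformer.layer import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm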
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
+from abc import ABC, abstractmethod

import torch.nn as nn

from colossalai.lazy import LazyInitContext

-__all__ = ["FusedLayerNorm", "FusedRMSNorm"]
+from .utils import SeqParallelUtils
+
+__all__ = ["FusedLayerNorm", "FusedRMSNorm", "LayerNorm", "RMSNorm", "BaseLayerNorm"]

FAST_LAYERNORM_SUPPORTED_SIZE = [
    1024,

@@ -35,7 +38,103 @@ FAST_LAYERNORM_SUPPORTED_SIZE = [
]

-class FusedLayerNorm:
+class BaseLayerNorm(ABC):
+    @abstractmethod
+    def from_native_module(module: nn.Module, sp_partial_derived: bool = False):
+        """
+        Convert a native PyTorch layer normalization module to a specific layer normalization module,
+        and optionally mark parameters for gradient aggregation.
+
+        Args:
+            module (nn.Module): The native PyTorch layer normalization module to be converted.
+            sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
+
+        Returns:
+            nn.Module: The specific layer normalization module.
+
+        Raises:
+            AssertionError: If the provided module is not an instance of the supported layer normalization type.
+        """
+
+
+class RMSNorm(BaseLayerNorm):
+    r"""
+    This is a wrapper around RMSNorm. It is meant to be used only with the from_native_module interface.
+    """
+
+    def __init__(self) -> None:
+        raise NotImplementedError(
+            "RMSNorm is not implemented as a physical class. "
+            "It is meant to be used only with the from_native_module interface to convert a native RMSNorm module to a colossalai layer norm module."
+        )
+
+    @staticmethod
+    def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
+        """
+        Convert a native RMSNorm module to a colossalai layer norm module,
+        and optionally mark parameters for gradient aggregation.
+
+        Args:
+            module (nn.Module): The native RMSNorm module to be converted.
+            sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
+
+        Returns:
+            nn.Module: The RMSNorm module.
+        """
+        LazyInitContext.materialize(module)
+        if sp_partial_derived:
+            # Since gradients are computed using only a subset of the data,
+            # aggregation of these gradients is necessary during backpropagation.
+            # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
+            SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight)
+
+        return module
+
+
+class LayerNorm(BaseLayerNorm):
+    r"""
+    This is a wrapper around torch.nn.LayerNorm. It is meant to be used only with the from_native_module interface.
+    """
+
+    def __init__(self) -> None:
+        raise NotImplementedError(
+            "LayerNorm is not implemented as a physical class. "
+            "It is meant to be used only with the from_native_module interface to convert a native PyTorch layer norm module to a colossalai layer norm module."
+        )
+
+    @staticmethod
+    def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
+        r"""
+        Convert a native PyTorch layer norm module to a colossalai layer norm module,
+        and optionally mark parameters for gradient aggregation.
+
+        Args:
+            module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted.
+            sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
+
+        Returns:
+            nn.Module: The LayerNorm module.
+
+        Raises:
+            AssertionError: If the provided module is not an instance of nn.LayerNorm.
+        """
+        assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm."
+        LazyInitContext.materialize(module)
+        if sp_partial_derived:
+            # Since gradients are computed using only a subset of the data,
+            # aggregation of these gradients is necessary during backpropagation.
+            # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
+            SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight)
+            SeqParallelUtils.marked_as_sp_partial_derived_param(module.bias)
+
+        return module
+
+
+class FusedLayerNorm(BaseLayerNorm):
    r"""
    This is a wrapper around the apex fused layernorm implementation. It is meant to be used only with the from_native_module interface.
    """

@@ -43,15 +142,29 @@ class FusedLayerNorm:
    def __init__(self) -> None:
        raise NotImplementedError(
            "FusedLayerNorm is not implemented as a physical class. "
-            "It is meant to be used only with the from_native_module interface to wrap the fused layernorm implementation provided by apex."
+            "It is meant to be used only with the from_native_module interface to convert a native PyTorch layer norm module to the FusedLayerNorm module provided by apex."
        )

    @staticmethod
-    def from_native_module(module: nn.LayerNorm, *args, **kwargs) -> nn.Module:
+    def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
        r"""
-        Convert a native pytorch layer norm module to colossalai layer norm module
+        Convert a native PyTorch layer norm module to the FusedLayerNorm module provided by apex,
+        and optionally mark parameters for gradient aggregation.
+
+        Args:
+            module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted.
+            sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
+
+        Returns:
+            nn.Module: Union[FastLayerNorm, FusedLayerNorm].
+
+        Raises:
+            AssertionError: If the provided module is not an instance of nn.LayerNorm.
        """
        # check if apex is installed
+        assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm."
        try:
            pass
        except ImportError:

@@ -85,10 +198,18 @@ class FusedLayerNorm:
        layernorm.weight = module.weight
        layernorm.bias = module.bias

+        if sp_partial_derived:
+            # Since gradients are computed using only a subset of the data,
+            # aggregation of these gradients is necessary during backpropagation.
+            # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
+            SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.weight)
+            SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.bias)
+
        return layernorm


-class FusedRMSNorm:
+class FusedRMSNorm(BaseLayerNorm):
    """
    This is a wrapper around the apex fused rms norm implementation. It is meant to be used only with the from_native_module interface.
    """

@@ -96,11 +217,22 @@ class FusedRMSNorm:
    def __init__(self) -> None:
        raise NotImplementedError(
            "FusedRMSNorm is not implemented as a physical class. "
-            "It is meant to be used only with the from_native_module interface to wrap the fused rms norm implementation provided by apex."
+            "It is meant to be used only with the from_native_module interface to convert a native RMSNorm module to the FusedRMSNorm module provided by apex."
        )

    @staticmethod
-    def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module:
+    def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
+        r"""
+        Convert a native RMSNorm module to the FusedRMSNorm module provided by apex,
+        and optionally mark parameters for gradient aggregation.
+
+        Args:
+            module (nn.Module): The native RMSNorm module to be converted.
+            sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
+
+        Returns:
+            nn.Module: The FusedRMSNorm module.
+        """
        try:
            from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
        except ImportError:

@@ -124,4 +256,10 @@ class FusedRMSNorm:
        rmsnorm.weight = module.weight

+        if sp_partial_derived:
+            # Since gradients are computed using only a subset of the data,
+            # aggregation of these gradients is necessary during backpropagation.
+            # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
+            SeqParallelUtils.marked_as_sp_partial_derived_param(rmsnorm.weight)
+
        return rmsnorm
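A usage sketch for the new non-fused wrapper (illustrative snippet, assuming only torch and colossalai): the converter returns the same module instance, and with sp_partial_derived=True it tags the parameters so the utility in the next file can find them after backward.

import torch.nn as nn
from colossalai.shardformer.layer import LayerNorm

native = nn.LayerNorm(1024)
wrapped = LayerNorm.from_native_module(native, sp_partial_derived=True)
assert wrapped is native
assert getattr(wrapped.weight, "partial_derived", False)
assert getattr(wrapped.bias, "partial_derived", False)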
from contextlib import contextmanager
+from typing import List

import torch
import torch.distributed as dist
-from torch.distributed import ProcessGroup
+from torch import nn
+from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+from torch.distributed import ProcessGroup, get_world_size
+
+
+class SeqParallelUtils:
+    @staticmethod
+    def marked_as_sp_partial_derived_param(param):
+        """
+        Mark a parameter as partially derived in sequence parallelism.
+
+        Args:
+            param: The parameter to mark as partially derived.
+        """
+        setattr(param, "partial_derived", True)
+
+    @staticmethod
+    def is_sp_partial_derived_param(param):
+        """
+        Check if a parameter is marked as partially derived in sequence parallelism.
+
+        Args:
+            param: The parameter to check.
+
+        Returns:
+            bool: True if the parameter is marked as partially derived, False otherwise.
+        """
+        return getattr(param, "partial_derived", False)
+
+    @staticmethod
+    def allreduce_partial_data_grad(tp_group: ProcessGroup, model: nn.Module = None, grads: List[torch.Tensor] = None):
+        """
+        Allreduce partially derived gradients across the specified process group.
+
+        This function performs gradient synchronization for parameters that are marked as partially derived in sequence parallelism.
+
+        Args:
+            tp_group (ProcessGroup): The process group for gradient synchronization.
+            model (nn.Module): The model from which gradients will be synchronized.
+            grads (List[torch.Tensor]): The list of gradients to be synchronized.
+
+        Raises:
+            AssertionError: If both `model` and `grads` are provided, or neither is provided.
+        """
+        # Ensure that exactly one of `model` and `grads` is provided for gradient synchronization.
+        assert (model is not None) ^ (grads is not None), "Exactly one of model and grads must not be None."
+
+        # Get the size of the process group, which determines whether synchronization is needed.
+        tp_size = get_world_size(tp_group) if tp_group is not None else 1
+
+        if tp_size == 1:
+            # If the process group size is 1, no synchronization is required.
+            return
+
+        if model is not None:
+            # If `model` is provided, extract partially derived gradients from the model's parameters.
+            grads = []
+            for p in model.parameters():
+                if p.grad is not None and SeqParallelUtils.is_sp_partial_derived_param(p):
+                    grads.append(p.grad.data)
+
+            # Flatten and reduce the gradients using the specified process group.
+            coalesced = _flatten_dense_tensors(grads)
+            dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=tp_group)
+
+            # Unflatten the synchronized gradients and update the model's gradients.
+            for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
+                buf.copy_(synced)
+        else:
+            # If `grads` are provided explicitly, synchronize those gradients directly.
+            coalesced = _flatten_dense_tensors(grads)
+            dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=tp_group)
+            for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
+                buf.copy_(synced)
+
+
class Randomizer:
...
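A call-site sketch for the utility above (hypothetical tp_group and model; assumes torch.distributed is initialized), run after backward and before the optimizer step:

# Sum the partially derived norm gradients across the tensor-parallel group.
SeqParallelUtils.allreduce_partial_data_grad(tp_group=tp_group, model=model)

# Alternative explicit form (use one or the other, not both):
grads = [
    p.grad.data
    for p in model.parameters()
    if p.grad is not None and SeqParallelUtils.is_sp_partial_derived_param(p)
]
SeqParallelUtils.allreduce_partial_data_grad(tp_group=tp_group, grads=grads)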
@@ -4,6 +4,7 @@ from typing import Optional
import torch.nn as nn

+from ..shard.shard_config import ShardConfig
from .base_policy import Policy

__all__ = ["PolicyLocation", "get_autopolicy", "import_policy"]

@@ -197,7 +198,7 @@ def _fullname(obj):
    return module + "." + klass.__qualname__

-def get_autopolicy(model: nn.Module, inference_only: Optional[bool] = False) -> Policy:
+def get_autopolicy(model: nn.Module, shard_config: ShardConfig = None) -> Policy:
    r"""
    Return the auto policy for the model

@@ -208,7 +209,7 @@ def get_autopolicy(model: nn.Module, inference_only: Optional[bool] = False) ->
        :class:`Policy`: The auto policy for the model
    """
    full_name = _fullname(model)
-    if inference_only:
+    if shard_config.inference_only:
        policy_location = _INFER_POLICY_LIST.get(full_name, None)
    else:
        policy_location = _POLICY_LIST.get(full_name, None)

@@ -218,5 +219,5 @@ def get_autopolicy(model: nn.Module, inference_only: Optional[bool] = False) ->
            f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())} and {list(_INFER_POLICY_LIST.keys())}"
        )
    else:
-        policy = import_policy(policy_location, inference_only)
+        policy = import_policy(policy_location, shard_config.inference_only)
    return policy()
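Call-site sketch for the new signature (hypothetical model and shard_config objects): the inference/training switch now rides on the shard config rather than a standalone flag, which is what the TPInferEngine change above relies on.

from colossalai.shardformer.policies.auto_policy import get_autopolicy

# before: policy = get_autopolicy(model, inference_only=True)
policy = get_autopolicy(model, shard_config=shard_config)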
@@ -11,6 +11,7 @@ from torch.nn import Module
from colossalai.pipeline.stage_manager import PipelineStageManager

+from ..layer.normalization import BaseLayerNorm
from ..layer.parallel_module import ParallelModule
from ..shard.shard_config import ShardConfig

@@ -29,7 +30,7 @@ class SubModuleReplacementDescription:
        ignore_if_not_exist (bool): if the submodule does not exist, ignore it or raise an exception
    """
    suffix: str
-    target_module: ParallelModule
+    target_module: Union[ParallelModule, BaseLayerNorm]
    kwargs: Dict[str, Any] = None
    ignore_if_not_exist: bool = False

@@ -77,7 +78,6 @@ class Policy(ABC):
    def set_model(self, model: nn.Module) -> None:
        r"""
        Set model as an attribute of the Policy object so that we can access the model's attributes.
        Args:
            model (:class:`nn.Module`): The model to be sharded.
        """

@@ -86,11 +86,11 @@ class Policy(ABC):
    def set_shard_config(self, shard_config: ShardConfig) -> None:
        r"""
        Set shard config as an attribute of the Policy object.
        Args:
            shard_config (:class:`ShardConfig`): The shard config to be applied.
        """
        self.shard_config = shard_config
        self.config_sanity_check()

    @property
...
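With target_module widened to Union[ParallelModule, BaseLayerNorm], a norm replacement is described like any other submodule replacement; a sketch of what the model policies below now emit (norm_cls and use_sequence_parallel as defined in each policy):

SubModuleReplacementDescription(
    suffix="input_layernorm",
    target_module=norm_cls,  # e.g. FusedLayerNorm or LayerNorm
    kwargs={"sp_partial_derived": use_sequence_parallel},
)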
@@ -60,6 +60,12 @@ class BertPolicy(Policy):
        )

        policy = {}

+        if self.shard_config.enable_fused_normalization:
+            norm_cls = col_nn.FusedLayerNorm
+        else:
+            norm_cls = col_nn.LayerNorm
+
        use_sequence_parallel = self.shard_config.enable_sequence_parallelism
        overlap = self.shard_config.enable_sequence_overlap
        if self.shard_config.enable_tensor_parallelism:

@@ -141,33 +147,34 @@ class BertPolicy(Policy):
        )

        # optimization configuration
-        if self.shard_config.enable_fused_normalization:
        # Handle bert layer
        self.append_or_create_submodule_replacement(
            description=[
                SubModuleReplacementDescription(
                    suffix="attention.output.LayerNorm",
-                    target_module=col_nn.FusedLayerNorm,
+                    target_module=norm_cls,
+                    kwargs={"sp_partial_derived": use_sequence_parallel},
                ),
                SubModuleReplacementDescription(
                    suffix="output.LayerNorm",
-                    target_module=col_nn.FusedLayerNorm,
+                    target_module=norm_cls,
+                    kwargs={"sp_partial_derived": use_sequence_parallel},
                ),
            ],
            policy=policy,
            target_key=BertLayer,
        )
        # handle embedding layer
        self.append_or_create_submodule_replacement(
            description=[
                SubModuleReplacementDescription(
                    suffix="LayerNorm",
-                    target_module=col_nn.FusedLayerNorm,
+                    target_module=norm_cls,
                )
            ],
            policy=policy,
            target_key=BertEmbeddings,
        )

        # use flash attention
        if self.shard_config.enable_flash_attention:

@@ -288,9 +295,6 @@ class BertPolicy(Policy):
# BertModel
class BertModelPolicy(BertPolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
    def module_policy(self):
        policy = super().module_policy()
        from transformers.models.bert.modeling_bert import BertModel

@@ -313,9 +317,6 @@ class BertModelPolicy(BertPolicy):
# BertForPreTraining
class BertForPreTrainingPolicy(BertPolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
    def module_policy(self):
        policy = super().module_policy()
        policy = self.add_lm_head_policy(policy)

@@ -355,9 +356,6 @@ class BertForPreTrainingPolicy(BertPolicy):
# BertLMHeadModel
class BertLMHeadModelPolicy(BertPolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
    def module_policy(self):
        policy = super().module_policy()
        policy = self.add_lm_head_policy(policy)

@@ -396,9 +394,6 @@ class BertLMHeadModelPolicy(BertPolicy):
# BertForMaskedLM
class BertForMaskedLMPolicy(BertPolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
    def module_policy(self):
        policy = super().module_policy()
        policy = self.add_lm_head_policy(policy)

@@ -437,9 +432,6 @@ class BertForMaskedLMPolicy(BertPolicy):
# BertForSequenceClassification
class BertForSequenceClassificationPolicy(BertPolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
    def module_policy(self):
        from transformers.models.bert.modeling_bert import BertForSequenceClassification

@@ -484,9 +476,6 @@ class BertForSequenceClassificationPolicy(BertPolicy):
# BertForTokenClassification
class BertForTokenClassificationPolicy(BertPolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
    def module_policy(self):
        from transformers.models.bert.modeling_bert import BertForTokenClassification

@@ -531,9 +520,6 @@ class BertForTokenClassificationPolicy(BertPolicy):
# BertForNextSentencePrediction
class BertForNextSentencePredictionPolicy(BertPolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
    def module_policy(self):
        policy = super().module_policy()
        from transformers.models.bert.modeling_bert import BertForNextSentencePrediction

@@ -564,9 +550,6 @@ class BertForNextSentencePredictionPolicy(BertPolicy):
# BertForMultipleChoice
class BertForMultipleChoicePolicy(BertPolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
    def module_policy(self):
        from transformers.models.bert.modeling_bert import BertForMultipleChoice

@@ -610,9 +593,6 @@ class BertForMultipleChoicePolicy(BertPolicy):
class BertForQuestionAnsweringPolicy(BertPolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
    def module_policy(self):
        from transformers.models.bert.modeling_bert import BertForQuestionAnswering
...
@@ -43,6 +43,11 @@ class BlipPolicy(Policy):
        policy = {}

+        if self.shard_config.enable_fused_normalization:
+            norm_cls = col_nn.FusedLayerNorm
+        else:
+            norm_cls = col_nn.LayerNorm
+
        if self.shard_config.enable_tensor_parallelism:
            policy[Blip2EncoderLayer] = ModulePolicyDescription(
                attribute_replacement={

@@ -214,94 +219,93 @@ class BlipPolicy(Policy):
        policy[Blip2Attention] = ModulePolicyDescription(method_replacement={"forward": forward_fn()})

        # optimization configuration
-        if self.shard_config.enable_fused_normalization:
        # Handle Blip2EncoderLayer layer
        self.append_or_create_submodule_replacement(
            description=[
                SubModuleReplacementDescription(
                    suffix="layer_norm1",
-                    target_module=col_nn.FusedLayerNorm,
+                    target_module=norm_cls,
                ),
                SubModuleReplacementDescription(
                    suffix="layer_norm2",
-                    target_module=col_nn.FusedLayerNorm,
+                    target_module=norm_cls,
                ),
            ],
            policy=policy,
            target_key=Blip2EncoderLayer,
        )

        # handle Blip2VisionModel layer
        self.append_or_create_submodule_replacement(
            description=[
                SubModuleReplacementDescription(
                    suffix="post_layernorm",
-                    target_module=col_nn.FusedLayerNorm,
+                    target_module=norm_cls,
                )
            ],
            policy=policy,
            target_key=Blip2VisionModel,
        )

        # handle Blip2VisionModel layer
        self.append_or_create_submodule_replacement(
            description=[
                SubModuleReplacementDescription(
                    suffix="layernorm",
-                    target_module=col_nn.FusedLayerNorm,
+                    target_module=norm_cls,
                )
            ],
            policy=policy,
            target_key=Blip2QFormerModel,
        )

        # handle Blip2QFormerLayer layer
        self.append_or_create_submodule_replacement(
            description=[
                SubModuleReplacementDescription(
                    suffix="attention.output.LayerNorm",
-                    target_module=col_nn.FusedLayerNorm,
+                    target_module=norm_cls,
                ),
                SubModuleReplacementDescription(
                    suffix="crossattention.output.LayerNorm",
-                    target_module=col_nn.FusedLayerNorm,
+                    target_module=norm_cls,
                ),
                SubModuleReplacementDescription(
                    suffix="output_query.LayerNorm",
-                    target_module=col_nn.FusedLayerNorm,
+                    target_module=norm_cls,
                ),
            ],
            policy=policy,
            target_key=Blip2QFormerLayer,
        )

        # handle OPTForCausalLM layer
        self.append_or_create_submodule_replacement(
            description=[
                SubModuleReplacementDescription(
                    suffix="model.decoder.final_layer_norm",
-                    target_module=col_nn.FusedLayerNorm,
+                    target_module=norm_cls,
                )
            ],
            policy=policy,
            target_key=OPTForCausalLM,
        )

        # handle OPTDecoderLayer layer
        self.append_or_create_submodule_replacement(
            description=[
                SubModuleReplacementDescription(
                    suffix="self_attn_layer_norm",
-                    target_module=col_nn.FusedLayerNorm,
+                    target_module=norm_cls,
                ),
                SubModuleReplacementDescription(
                    suffix="final_layer_norm",
-                    target_module=col_nn.FusedLayerNorm,
+                    target_module=norm_cls,
                ),
            ],
            policy=policy,
            target_key=OPTDecoderLayer,
        )

        # use flash attention
        if self.shard_config.enable_flash_attention:
...
@@ -42,6 +42,10 @@ class BloomPolicy(Policy):
        policy = {}

+        if self.shard_config.enable_fused_normalization:
+            norm_cls = col_nn.FusedLayerNorm
+        else:
+            norm_cls = col_nn.LayerNorm
+
        use_sequence_parallel = self.shard_config.enable_sequence_parallelism
        overlap = self.shard_config.enable_sequence_overlap
        if self.shard_config.enable_tensor_parallelism:

@@ -97,38 +101,39 @@ class BloomPolicy(Policy):
        )

        # optimization configuration
-        if self.shard_config.enable_fused_normalization:
        # handle bloom model
        self.append_or_create_submodule_replacement(
            description=[
                SubModuleReplacementDescription(
                    suffix="ln_f",
-                    target_module=col_nn.FusedLayerNorm,
+                    target_module=norm_cls,
                ),
                SubModuleReplacementDescription(
                    suffix="word_embeddings_layernorm",
-                    target_module=col_nn.FusedLayerNorm,
+                    target_module=norm_cls,
                ),
            ],
            policy=policy,
            target_key=BloomModel,
        )

        # handle bloom block
        self.append_or_create_submodule_replacement(
            description=[
                SubModuleReplacementDescription(
                    suffix="input_layernorm",
-                    target_module=col_nn.FusedLayerNorm,
+                    target_module=norm_cls,
+                    kwargs={"sp_partial_derived": use_sequence_parallel},
                ),
                SubModuleReplacementDescription(
                    suffix="post_attention_layernorm",
-                    target_module=col_nn.FusedLayerNorm,
+                    target_module=norm_cls,
+                    kwargs={"sp_partial_derived": use_sequence_parallel},
                ),
            ],
            policy=policy,
            target_key=BloomBlock,
        )

        if use_sequence_parallel:
            self.append_or_create_method_replacement(

@@ -225,9 +230,6 @@ class BloomPolicy(Policy):
class BloomModelPolicy(BloomPolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
    def module_policy(self):
        policy = super().module_policy()
        from transformers.models.bloom.modeling_bloom import BloomModel
...
@@ -45,6 +45,16 @@ class ChatGLMPolicy(Policy):
        policy = {}

+        if self.shard_config.enable_fused_normalization:
+            if self.model.config.rmsnorm:
+                norm_cls = col_nn.FusedRMSNorm
+            else:
+                norm_cls = col_nn.FusedLayerNorm
+        else:
+            if self.model.config.rmsnorm:
+                norm_cls = col_nn.RMSNorm
+            else:
+                norm_cls = col_nn.LayerNorm
+
        use_sequence_parallel = self.shard_config.enable_sequence_parallelism
        overlap = self.shard_config.enable_sequence_overlap
        if self.shard_config.enable_tensor_parallelism:

@@ -96,52 +106,34 @@ class ChatGLMPolicy(Policy):
        )

        # optimization configuration
-        if self.shard_config.enable_fused_normalization:
-            if not self.model.config.rmsnorm:
-                self.append_or_create_submodule_replacement(
-                    description=[
-                        SubModuleReplacementDescription(suffix="input_layernorm", target_module=col_nn.FusedLayerNorm),
-                        SubModuleReplacementDescription(
-                            suffix="post_attention_layernorm", target_module=col_nn.FusedLayerNorm
-                        ),
-                    ],
-                    policy=policy,
-                    target_key=GLMBlock,
-                )
-
-                if self.model.config.post_layer_norm:
-                    self.append_or_create_submodule_replacement(
-                        description=[
-                            SubModuleReplacementDescription(
-                                suffix="encoder.final_layernorm", target_module=col_nn.FusedLayerNorm
-                            )
-                        ],
-                        policy=policy,
-                        target_key=ChatGLMModel,
-                    )
-            else:
-                self.append_or_create_submodule_replacement(
-                    description=[
-                        SubModuleReplacementDescription(suffix="input_layernorm", target_module=col_nn.FusedRMSNorm),
-                        SubModuleReplacementDescription(
-                            suffix="post_attention_layernorm", target_module=col_nn.FusedRMSNorm
-                        ),
-                    ],
-                    policy=policy,
-                    target_key=GLMBlock,
-                )
-
-                if self.model.config.post_layer_norm:
-                    self.append_or_create_submodule_replacement(
-                        description=[
-                            SubModuleReplacementDescription(
-                                suffix="encoder.final_layernorm", target_module=col_nn.FusedRMSNorm
-                            )
-                        ],
-                        policy=policy,
-                        target_key=ChatGLMModel,
-                    )
+        self.append_or_create_submodule_replacement(
+            description=[
+                SubModuleReplacementDescription(
+                    suffix="input_layernorm",
+                    target_module=norm_cls,
+                    kwargs={"sp_partial_derived": use_sequence_parallel},
+                ),
+                SubModuleReplacementDescription(
+                    suffix="post_attention_layernorm",
+                    target_module=norm_cls,
+                    kwargs={"sp_partial_derived": use_sequence_parallel},
+                ),
+            ],
+            policy=policy,
+            target_key=GLMBlock,
+        )
+
+        if self.model.config.post_layer_norm:
+            self.append_or_create_submodule_replacement(
+                description=[
+                    SubModuleReplacementDescription(
+                        suffix="encoder.final_layernorm",
+                        target_module=norm_cls,
+                    )
+                ],
+                policy=policy,
+                target_key=ChatGLMModel,
+            )

        # use flash attention
        if self.shard_config.enable_flash_attention:

@@ -224,9 +216,6 @@ class ChatGLMPolicy(Policy):
class ChatGLMModelPolicy(ChatGLMPolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
    def module_policy(self):
        pass
...
@@ -39,6 +39,11 @@ class GPT2Policy(Policy):
        from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model

        policy = {}

+        if self.shard_config.enable_fused_normalization:
+            norm_cls = col_nn.FusedLayerNorm
+        else:
+            norm_cls = col_nn.LayerNorm
+
        use_sequence_parallel = self.shard_config.enable_sequence_parallelism
        overlap = self.shard_config.enable_sequence_overlap
        if self.shard_config.enable_tensor_parallelism:

@@ -102,33 +107,37 @@ class GPT2Policy(Policy):
        )

        # optimization configuration
-        if self.shard_config.enable_fused_normalization:
        self.append_or_create_submodule_replacement(
            description=SubModuleReplacementDescription(
                suffix="ln_f",
-                target_module=col_nn.FusedLayerNorm,
+                target_module=norm_cls,
            ),
            policy=policy,
            target_key=GPT2Model,
        )

        self.append_or_create_submodule_replacement(
            description=[
                SubModuleReplacementDescription(
                    suffix="ln_1",
-                    target_module=col_nn.FusedLayerNorm,
+                    target_module=norm_cls,
+                    kwargs={"sp_partial_derived": use_sequence_parallel},
                ),
                SubModuleReplacementDescription(
                    suffix="ln_2",
-                    target_module=col_nn.FusedLayerNorm,
+                    target_module=norm_cls,
+                    kwargs={"sp_partial_derived": use_sequence_parallel},
                ),
-                SubModuleReplacementDescription(
-                    suffix="ln_cross_attn", target_module=col_nn.FusedLayerNorm, ignore_if_not_exist=True
-                ),
+                SubModuleReplacementDescription(
+                    suffix="ln_cross_attn",
+                    target_module=norm_cls,
+                    ignore_if_not_exist=True,
+                    kwargs={"sp_partial_derived": use_sequence_parallel},
+                ),
            ],
            policy=policy,
            target_key=GPT2Block,
        )

        if self.shard_config.enable_flash_attention:
            self.append_or_create_method_replacement(

@@ -192,9 +201,6 @@ class GPT2Policy(Policy):
# GPT2Model
class GPT2ModelPolicy(GPT2Policy):
-    def __init__(self) -> None:
-        super().__init__()
-
    def module_policy(self):
        from transformers.models.gpt2.modeling_gpt2 import GPT2Model

@@ -216,9 +222,6 @@ class GPT2ModelPolicy(GPT2Policy):
# GPT2LMHeadModel
class GPT2LMHeadModelPolicy(GPT2Policy):
-    def __init__(self) -> None:
-        super().__init__()
-
    def module_policy(self):
        from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel

@@ -263,9 +266,6 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
# GPT2DoubleHeadsModel
class GPT2DoubleHeadsModelPolicy(GPT2Policy):
-    def __init__(self) -> None:
-        super().__init__()
-
    def module_policy(self):
        from transformers.models.gpt2.modeling_gpt2 import GPT2DoubleHeadsModel

@@ -317,9 +317,6 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy):
# GPT2ForQuestionAnswering
class GPT2ForQuestionAnsweringPolicy(GPT2Policy):
-    def __init__(self) -> None:
-        super().__init__()
-
    def module_policy(self):
        from transformers.models.gpt2.modeling_gpt2 import GPT2ForQuestionAnswering

@@ -347,9 +344,6 @@ class GPT2ForQuestionAnsweringPolicy(GPT2Policy):
# GPT2ForTokenClassification
class GPT2ForTokenClassificationPolicy(GPT2Policy):
-    def __init__(self) -> None:
-        super().__init__()
-
    def module_policy(self):
        from transformers.models.gpt2.modeling_gpt2 import GPT2ForTokenClassification

@@ -387,9 +381,6 @@ class GPT2ForTokenClassificationPolicy(GPT2Policy):
# GPT2ForSequenceClassification
class GPT2ForSequenceClassificationPolicy(GPT2Policy):
-    def __init__(self) -> None:
-        super().__init__()
-
    def module_policy(self):
        from transformers.models.gpt2.modeling_gpt2 import GPT2ForSequenceClassification
...
@@ -6,7 +6,7 @@ import torch.nn as nn
from torch import Tensor
from torch.nn import Module

-from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
+from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, RMSNorm, VocabParallelEmbedding1D

from ..modeling.llama import LlamaPipelineForwards, get_llama_flash_attention_forward
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

@@ -35,6 +35,11 @@ class LlamaPolicy(Policy):
        policy = {}

+        if self.shard_config.enable_fused_normalization:
+            norm_cls = FusedRMSNorm
+        else:
+            norm_cls = RMSNorm
+
        if self.shard_config.enable_sequence_parallelism:
            self.shard_config.enable_sequence_parallelism = False
            warnings.warn("Llama doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")

@@ -93,31 +98,31 @@ class LlamaPolicy(Policy):
        )

        # optimization configuration
-        if self.shard_config.enable_fused_normalization:
        self.append_or_create_submodule_replacement(
            description=[
                SubModuleReplacementDescription(
                    suffix="input_layernorm",
-                    target_module=FusedRMSNorm,
+                    target_module=norm_cls,
                ),
                SubModuleReplacementDescription(
                    suffix="post_attention_layernorm",
-                    target_module=FusedRMSNorm,
+                    target_module=norm_cls,
                ),
            ],
            policy=policy,
            target_key=LlamaDecoderLayer,
        )

        self.append_or_create_submodule_replacement(
            description=SubModuleReplacementDescription(
                suffix="norm",
-                target_module=FusedRMSNorm,
+                target_module=norm_cls,
            ),
            policy=policy,
            target_key=LlamaModel,
        )

+        # use flash attention
        if self.shard_config.enable_flash_attention:
            self.append_or_create_method_replacement(
                description={

@@ -174,9 +179,6 @@ class LlamaPolicy(Policy):
class LlamaModelPolicy(LlamaPolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
    def module_policy(self):
        policy = super().module_policy()
        from transformers.models.llama.modeling_llama import LlamaModel
...
@@ -5,7 +5,7 @@ from typing import Callable, Dict, List
import torch.nn as nn
from torch import Tensor, nn

-from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
+from colossalai.shardformer.layer import FusedLayerNorm, LayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D

from .._utils import getattr_
from ..modeling.jit import get_jit_fused_dropout_add_func

@@ -42,6 +42,12 @@ class OPTPolicy(Policy):
        from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer

        policy = {}

+        if self.shard_config.enable_fused_normalization:
+            norm_cls = FusedLayerNorm
+        else:
+            norm_cls = LayerNorm
+
        if self.shard_config.enable_sequence_parallelism:
            self.shard_config.enable_sequence_parallelism = False
            warnings.warn("OPT doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")

@@ -94,26 +100,25 @@ class OPTPolicy(Policy):
        )

        # optimization configuration
-        if self.shard_config.enable_fused_normalization:
        self.append_or_create_submodule_replacement(
            description=SubModuleReplacementDescription(
-                suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True
+                suffix="final_layer_norm", target_module=norm_cls, ignore_if_not_exist=True
            ),
            policy=policy,
            target_key=OPTDecoder,
        )
        self.append_or_create_submodule_replacement(
            description=[
                SubModuleReplacementDescription(
-                    suffix="self_attn_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True
+                    suffix="self_attn_layer_norm", target_module=norm_cls, ignore_if_not_exist=True
                ),
                SubModuleReplacementDescription(
-                    suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True
+                    suffix="final_layer_norm", target_module=norm_cls, ignore_if_not_exist=True
                ),
            ],
            policy=policy,
            target_key=OPTDecoderLayer,
        )

        # use flash attention
        if self.shard_config.enable_flash_attention:

@@ -183,9 +188,6 @@ class OPTPolicy(Policy):
class OPTModelPolicy(OPTPolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
    def module_policy(self):
        from transformers.models.opt.modeling_opt import OPTModel

@@ -253,9 +255,6 @@ class OPTForCausalLMPolicy(OPTPolicy):
class OPTForSequenceClassificationPolicy(OPTPolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
    def module_policy(self):
        from transformers.models.opt.modeling_opt import OPTForSequenceClassification

@@ -281,9 +280,6 @@ class OPTForSequenceClassificationPolicy(OPTPolicy):
class OPTForQuestionAnsweringPolicy(OPTPolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
    def module_policy(self):
        from transformers.models.opt.modeling_opt import OPTForQuestionAnswering
...
@@ -24,6 +24,11 @@ class SamPolicy(Policy):
        policy = {}

+        if self.shard_config.enable_fused_normalization:
+            norm_cls = col_nn.FusedLayerNorm
+        else:
+            norm_cls = col_nn.LayerNorm
+
        if self.shard_config.enable_tensor_parallelism:
            policy[SamVisionLayer] = ModulePolicyDescription(
                attribute_replacement={

@@ -151,58 +156,57 @@ class SamPolicy(Policy):
        )

        # optimization configuration
-        if self.shard_config.enable_fused_normalization:
        # Handle SamVisionLayer
        self.append_or_create_submodule_replacement(
            description=[
                SubModuleReplacementDescription(
                    suffix="layer_norm1",
-                    target_module=col_nn.FusedLayerNorm,
+                    target_module=norm_cls,
                ),
                SubModuleReplacementDescription(
                    suffix="layer_norm2",
-                    target_module=col_nn.FusedLayerNorm,
+                    target_module=norm_cls,
                ),
            ],
            policy=policy,
            target_key=SamVisionLayer,
        )

        # Handle SamTwoWayAttentionBlock
        self.append_or_create_submodule_replacement(
            description=[
                SubModuleReplacementDescription(
                    suffix="layer_norm1",
-                    target_module=col_nn.FusedLayerNorm,
+                    target_module=norm_cls,
                ),
                SubModuleReplacementDescription(
                    suffix="layer_norm2",
-                    target_module=col_nn.FusedLayerNorm,
+                    target_module=norm_cls,
                ),
                SubModuleReplacementDescription(
                    suffix="layer_norm3",
-                    target_module=col_nn.FusedLayerNorm,
+                    target_module=norm_cls,
                ),
                SubModuleReplacementDescription(
                    suffix="layer_norm4",
-                    target_module=col_nn.FusedLayerNorm,
+                    target_module=norm_cls,
                ),
            ],
            policy=policy,
            target_key=SamTwoWayAttentionBlock,
        )

        # Handle SamTwoWayTransformer
        self.append_or_create_submodule_replacement(
            description=[
                SubModuleReplacementDescription(
                    suffix="layer_norm_final_attn",
-                    target_module=col_nn.FusedLayerNorm,
+                    target_module=norm_cls,
                )
            ],
            policy=policy,
            target_key=SamTwoWayTransformer,
        )

        # use flash attention
        if self.shard_config.enable_flash_attention:
...
...@@ -11,6 +11,7 @@ from colossalai.shardformer.layer import (
     FusedRMSNorm,
     Linear1D_Col,
     Linear1D_Row,
+    RMSNorm,
     VocabParallelEmbedding1D,
 )
 from colossalai.shardformer.policies.base_policy import ModulePolicyDescription
...@@ -58,6 +59,11 @@ class T5BasePolicy(Policy):
         policy = {}
+        if self.shard_config.enable_fused_normalization:
+            norm_cls = FusedRMSNorm
+        else:
+            norm_cls = RMSNorm
         if self.shard_config.enable_sequence_parallelism:
             self.shard_config.enable_sequence_parallelism = False
             warnings.warn("T5 dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
...@@ -169,38 +175,37 @@ class T5BasePolicy(Policy):
         )
         # optimization configuration
-        if self.shard_config.enable_fused_normalization:
-            self.append_or_create_submodule_replacement(
-                description=SubModuleReplacementDescription(
-                    suffix="layer_norm",
-                    target_module=FusedRMSNorm,
-                ),
-                policy=policy,
-                target_key=T5LayerFF,
-            )
-            self.append_or_create_submodule_replacement(
-                description=SubModuleReplacementDescription(
-                    suffix="layer_norm",
-                    target_module=FusedRMSNorm,
-                ),
-                policy=policy,
-                target_key=T5LayerFF,
-            )
-            self.append_or_create_submodule_replacement(
-                description=SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm),
-                policy=policy,
-                target_key=T5LayerSelfAttention,
-            )
-            self.append_or_create_submodule_replacement(
-                description=SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm),
-                policy=policy,
-                target_key=T5LayerCrossAttention,
-            )
-            self.append_or_create_submodule_replacement(
-                description=SubModuleReplacementDescription(suffix="final_layer_norm", target_module=FusedRMSNorm),
-                policy=policy,
-                target_key=T5Stack,
-            )
+        self.append_or_create_submodule_replacement(
+            description=SubModuleReplacementDescription(
+                suffix="layer_norm",
+                target_module=norm_cls,
+            ),
+            policy=policy,
+            target_key=T5LayerFF,
+        )
+        self.append_or_create_submodule_replacement(
+            description=SubModuleReplacementDescription(
+                suffix="layer_norm",
+                target_module=norm_cls,
+            ),
+            policy=policy,
+            target_key=T5LayerFF,
+        )
+        self.append_or_create_submodule_replacement(
+            description=SubModuleReplacementDescription(suffix="layer_norm", target_module=norm_cls),
+            policy=policy,
+            target_key=T5LayerSelfAttention,
+        )
+        self.append_or_create_submodule_replacement(
+            description=SubModuleReplacementDescription(suffix="layer_norm", target_module=norm_cls),
+            policy=policy,
+            target_key=T5LayerCrossAttention,
+        )
+        self.append_or_create_submodule_replacement(
+            description=SubModuleReplacementDescription(suffix="final_layer_norm", target_module=norm_cls),
+            policy=policy,
+            target_key=T5Stack,
+        )
         # use flash attention
         if self.shard_config.enable_flash_attention:
...@@ -363,9 +368,6 @@ class T5BasePolicy(Policy):
 class T5ModelPolicy(T5BasePolicy):
-    def __init__(self) -> None:
-        super().__init__()
     def module_policy(self):
         from transformers import T5Model
...@@ -402,9 +404,6 @@ class T5ModelPolicy(T5BasePolicy):
 class T5ForConditionalGenerationPolicy(T5BasePolicy):
-    def __init__(self) -> None:
-        super().__init__()
     def module_policy(self):
         from transformers import T5ForConditionalGeneration
...@@ -466,9 +465,6 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
 class T5EncoderPolicy(T5BasePolicy):
-    def __init__(self) -> None:
-        super().__init__()
     def module_policy(self):
         from transformers import T5EncoderModel
......
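Unlike the LayerNorm-based models in this patch, T5 normalizes with RMSNorm, which is why the fallback here is `RMSNorm` rather than `col_nn.LayerNorm`. As a reference for what either target computes, a plain RMSNorm in T5's formulation (scale by the reciprocal root-mean-square of the last dimension, no mean subtraction, no bias) looks roughly like this sketch; it is not ColossalAI's implementation:

    import torch
    import torch.nn as nn


    class SimpleRMSNorm(nn.Module):
        """Reference RMSNorm in T5's style: rescale by reciprocal RMS, no bias."""

        def __init__(self, hidden_size: int, eps: float = 1e-6):
            super().__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.eps = eps

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Mean of squares over the hidden dimension, kept for broadcasting.
            variance = x.pow(2).mean(-1, keepdim=True)
            return self.weight * x * torch.rsqrt(variance + self.eps)


    out = SimpleRMSNorm(8)(torch.randn(2, 4, 8))  # shape preserved: (2, 4, 8)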
...@@ -159,9 +159,6 @@ class ViTPolicy(Policy):
 # ViTModel
 class ViTModelPolicy(ViTPolicy):
-    def __init__(self) -> None:
-        super().__init__()
     def module_policy(self):
         from transformers.models.vit.modeling_vit import ViTModel
...@@ -227,9 +224,6 @@ class ViTForImageClassificationPolicy(ViTPolicy):
 # ViTForMaskedImageModeling
 class ViTForMaskedImageModelingPolicy(ViTPolicy):
-    def __init__(self) -> None:
-        super().__init__()
     def module_policy(self):
         from transformers.models.vit.modeling_vit import ViTForMaskedImageModeling, ViTModel
......
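The ViT hunks (like the matching deletions in the T5 and Whisper policies) only remove `__init__` methods that did nothing but call `super().__init__()`. Such pass-through constructors are dead weight: when a subclass defines no `__init__`, Python's method resolution falls through to the base class automatically. A quick self-contained illustration:

    class Base:
        def __init__(self) -> None:
            self.ready = True


    class WithBoilerplate(Base):
        # The pattern the commit deletes: a constructor that adds nothing.
        def __init__(self) -> None:
            super().__init__()


    class WithoutBoilerplate(Base):
        # Equivalent behavior with no constructor at all; Base.__init__ runs.
        pass


    assert WithBoilerplate().ready == WithoutBoilerplate().ready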
...@@ -52,6 +52,11 @@ class WhisperPolicy(Policy):
         policy = {}
+        if self.shard_config.enable_fused_normalization:
+            norm_cls = col_nn.FusedLayerNorm
+        else:
+            norm_cls = col_nn.LayerNorm
         if self.shard_config.enable_sequence_parallelism:
             self.shard_config.enable_sequence_parallelism = False
             warnings.warn(
...@@ -161,62 +166,61 @@ class WhisperPolicy(Policy):
         )
         # optimization configuration
-        if self.shard_config.enable_fused_normalization:
-            # Handle encoder layer
-            self.append_or_create_submodule_replacement(
-                description=[
-                    SubModuleReplacementDescription(
-                        suffix="self_attn_layer_norm",
-                        target_module=col_nn.FusedLayerNorm,
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="final_layer_norm",
-                        target_module=col_nn.FusedLayerNorm,
-                    ),
-                ],
-                policy=policy,
-                target_key=WhisperEncoderLayer,
-            )
+        # Handle encoder layer
+        self.append_or_create_submodule_replacement(
+            description=[
+                SubModuleReplacementDescription(
+                    suffix="self_attn_layer_norm",
+                    target_module=norm_cls,
+                ),
+                SubModuleReplacementDescription(
+                    suffix="final_layer_norm",
+                    target_module=norm_cls,
+                ),
+            ],
+            policy=policy,
+            target_key=WhisperEncoderLayer,
+        )
         # Handle decoder layer
         self.append_or_create_submodule_replacement(
             description=[
                 SubModuleReplacementDescription(
                     suffix="self_attn_layer_norm",
-                    target_module=col_nn.FusedLayerNorm,
+                    target_module=norm_cls,
                 ),
                 SubModuleReplacementDescription(
                     suffix="final_layer_norm",
-                    target_module=col_nn.FusedLayerNorm,
+                    target_module=norm_cls,
                 ),
             ],
             policy=policy,
             target_key=WhisperDecoderLayer,
         )
         # handle encoder layer
         self.append_or_create_submodule_replacement(
             description=[
                 SubModuleReplacementDescription(
                     suffix="layer_norm",
-                    target_module=col_nn.FusedLayerNorm,
+                    target_module=norm_cls,
                 )
             ],
             policy=policy,
             target_key=WhisperEncoder,
         )
         # handle decoder layer
         self.append_or_create_submodule_replacement(
             description=[
                 SubModuleReplacementDescription(
                     suffix="layer_norm",
-                    target_module=col_nn.FusedLayerNorm,
+                    target_module=norm_cls,
                 )
             ],
             policy=policy,
             target_key=WhisperDecoder,
         )
         # enable flash attention
         if self.shard_config.enable_flash_attention:
...@@ -416,9 +420,6 @@ class WhisperPolicy(Policy):
 # WhisperModel
 class WhisperModelPolicy(WhisperPolicy):
-    def __init__(self) -> None:
-        super().__init__()
     def module_policy(self):
         from transformers import WhisperModel
...@@ -441,9 +442,6 @@ class WhisperModelPolicy(WhisperPolicy):
 # WhisperForConditionalGeneration
 class WhisperForConditionalGenerationPolicy(WhisperPolicy):
-    def __init__(self) -> None:
-        super().__init__()
     def module_policy(self):
         from transformers import WhisperForConditionalGeneration
...@@ -502,9 +500,6 @@ class WhisperForConditionalGenerationPolicy(WhisperPolicy):
 # WhisperForAudioClassification
 class WhisperForAudioClassificationPolicy(WhisperPolicy):
-    def __init__(self) -> None:
-        super().__init__()
     def preprocess(self):
         return self.model
......
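These policy edits serve the headline fix: once every layer norm is a ColossalAI module, its weight and bias gradients can be all-reduced across the tensor-parallel group. Sequence parallelism requires this because each rank normalizes a different slice of the sequence and therefore holds only a partial gradient for the (replicated) norm parameters. A minimal sketch of that all-reduce, assuming an initialized `torch.distributed` process group; the helper name and structure are illustrative, not ColossalAI's actual hook:

    import torch
    import torch.distributed as dist
    import torch.nn as nn


    def allreduce_layernorm_grads(model: nn.Module, tp_group: dist.ProcessGroup) -> None:
        """Sum LayerNorm parameter gradients across the tensor-parallel group.

        Illustrative helper: call after backward() and before optimizer.step(),
        so every rank applies the same update to the replicated norm weights.
        """
        for module in model.modules():
            if isinstance(module, nn.LayerNorm):
                for param in module.parameters():
                    if param.grad is not None:
                        dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=tp_group)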