Unverified Commit 1a3315e3 authored by littsk, committed by GitHub

[hotfix] Add layer norm gradients all-reduce for sequence parallel (#4926)



* [hotfix] Add layer norm gradients all-reduce for sequence parallel. (#4915)

* Add layer norm gradients all-reduce for sequence parallel.

* skip pipeline inference test

* [hotfix] fixing policies of sequence parallel (#4922)

* Add layer norm gradients all-reduce for sequence parallel.

* fix parameter passing when calling get_autopolicy

---------
Co-authored-by: littsk <1214689160@qq.com>

* Hotfix/add grad all reduce for sequence parallel (#4927)

* Add layer norm gradients all-reduce for sequence parallel.


* fix parameter passing when calling get_autopolicy

* fix bug using wrong variables

---------
Co-authored-by: littsk <1214689160@qq.com>

* fix policy initialization

* fix bloom and chatglm policies

* polish code of handling layernorm

* fix moe module

* polish code of class initializing

---------
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
parent d99b2c96
......@@ -338,7 +338,14 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
if not isinstance(model, ModelWrapper):
use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
model = HybridParallelModule(
model, self.precision, self.shard_config, self.dp_group, use_ddp, self.ddp_config, self.custom_policy
module=model,
precision=self.precision,
shard_config=self.shard_config,
dp_group=self.dp_group,
tp_group=self.tp_group,
use_ddp=use_ddp,
ddp_config=self.ddp_config,
custom_policy=self.custom_policy,
)
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
if self.zero_stage == 0:
......
......@@ -218,10 +218,10 @@ class TPInferEngine:
), "Discrepancy between the tp size of TPInferEngine and the tp size of shard config"
model_name = model.__class__.__name__
assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference."
model = model.model if self.shard_config.inference_gptq else model
policy = get_autopolicy(model, shard_config=self.shard_config)
policy = get_autopolicy(model, inference_only=True)
self.model, _ = shardformer.optimize(model, policy)
if self.shard_config.inference_gptq:
......
......@@ -235,6 +235,14 @@ class SubModuleReplacementDescription:
class Policy(ABC):
r"""
The base class for all the policies. For each different model, it should have a different policy class,
like BertPolicy for Bert Model or OPTPolicy for OPT model.
Shardformer has provided many built-in sharding policies for the mainstream models. You can use the
built-in policies by setting `policy = None`, which is already the default argument for `Shardformer.optimize`.
If you want to define your own policy, you can inherit from this class and overwrite the methods you want to modify.
"""
def __init__(self):
self.model = None
......@@ -245,6 +253,16 @@ class Policy(ABC):
"""
self.model = model
def set_shard_config(self, shard_config: ShardConfig) -> None:
r"""
Set shard config as an attribute of the Policy object.
Args:
shard_config (:class:`ShardConfig`): The shard config to be used
"""
self.shard_config = shard_config
self.config_sanity_check()
@abstractmethod
def preprocess(self) -> nn.Module:
"""
......
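A minimal sketch of the custom-policy workflow described in the Policy docstring above. Hedged: it assumes colossalai and transformers are installed, is meant to run under a distributed launcher (e.g. torchrun --nproc_per_node=1), and MyBertPolicy is a hypothetical no-op subclass; launch/config arguments may differ between colossalai versions.

import colossalai
from transformers import BertConfig, BertModel

from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.bert import BertModelPolicy


class MyBertPolicy(BertModelPolicy):
    # Override module_policy() here to customize how submodules are replaced.
    pass


colossalai.launch_from_torch(config={})
model = BertModel(BertConfig())
shardformer = ShardFormer(shard_config=ShardConfig(enable_tensor_parallelism=False))
# Passing policy=None (the default) would resolve a built-in policy via get_autopolicy.
sharded_model, _ = shardformer.optimize(model, policy=MyBertPolicy())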
......@@ -2,7 +2,7 @@ from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
from .embedding import Embedding1D, VocabParallelEmbedding1D
from .linear import Linear1D_Col, Linear1D_Row
from .loss import cross_entropy_1d
from .normalization import FusedLayerNorm, FusedRMSNorm
from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
from .parallel_module import ParallelModule
from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
......@@ -16,6 +16,9 @@ __all__ = [
"DropoutForParallelInput",
"DropoutForReplicatedInput",
"cross_entropy_1d",
"BaseLayerNorm",
"LayerNorm",
"RMSNorm",
"FusedLayerNorm",
"FusedRMSNorm",
"FusedLinear1D_Col",
......
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from abc import ABC, abstractmethod
import torch.nn as nn
from colossalai.lazy import LazyInitContext
__all__ = ["FusedLayerNorm", "FusedRMSNorm"]
from .utils import SeqParallelUtils
__all__ = ["FusedLayerNorm", "FusedRMSNorm", "LayerNorm", "RMSNorm", "BaseLayerNorm"]
FAST_LAYERNORM_SUPPORTED_SIZE = [
1024,
......@@ -35,7 +38,103 @@ FAST_LAYERNORM_SUPPORTED_SIZE = [
]
class FusedLayerNorm:
class BaseLayerNorm(ABC):
@abstractmethod
def from_native_module(module: nn.Module, sp_partial_derived: bool = False):
"""
Convert a native PyTorch layer normalization module to a specific layer normalization module,
and optionally mark parameters for gradient aggregation.
Args:
module (nn.Module): The native PyTorch layer normalization module to be converted.
sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
Returns:
nn.Module: The specific layer normalization module.
Raises:
AssertionError: If the provided module is not an instance of the supported layer normalization type.
"""
class RMSNorm(BaseLayerNorm):
r"""
This is a wrapper around the RMSNorm. It is meant to be used only with the from_native_module interface.
"""
def __init__(self) -> None:
raise NotImplementedError(
"FusedLayerNorm is not implemented as a physical class. "
"It is meant to be used only with the from_native_module interface to convert a native RMSNorm module to colossalai layer norm module."
)
@staticmethod
def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
"""
Convert a native RMSNorm module to colossalai layer norm module,
and optionally mark parameters for gradient aggregation.
Args:
module (nn.Module): The native RMSNorm module to be converted.
sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
Returns:
nn.Module: The RMSNorm module.
"""
LazyInitContext.materialize(module)
if sp_partial_derived:
# Since gradients are computed using only a subset of the data,
# aggregation of these gradients is necessary during backpropagation.
# Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight)
return module
class LayerNorm(BaseLayerNorm):
r"""
This is a wrapper around the torch.nn.LayerNorm. It is meant to be used only with the from_native_module interface.
"""
def __init__(self) -> None:
raise NotImplementedError(
"LayerNorm is not implemented as a physical class. "
"It is meant to be used only with the from_native_module interface to convert a native pytorch layer norm module to colossalai layer norm module."
)
@staticmethod
def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
r"""
Convert a native pytorch layer norm module to colossalai layer norm module,
and optionally mark parameters for gradient aggregation.
Args:
module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted.
sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
Returns:
nn.Module: The LayerNorm module.
Raises:
AssertionError: If the provided module is not an instance of nn.LayerNorm.
"""
assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm."
LazyInitContext.materialize(module)
if sp_partial_derived:
# Since gradients are computed using only a subset of the data,
# aggregation of these gradients is necessary during backpropagation.
# Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight)
SeqParallelUtils.marked_as_sp_partial_derived_param(module.bias)
return module
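A short usage sketch of the wrapper above (hedged: import paths are assumptions; no distributed setup is needed here, since from_native_module only tags the parameters for later synchronization):

import torch.nn as nn

from colossalai.shardformer.layer import LayerNorm
from colossalai.shardformer.layer.utils import SeqParallelUtils

# The returned module is the original nn.LayerNorm; its weight and bias are now marked
# so that their gradients get all-reduced across the tensor-parallel group later on.
norm = LayerNorm.from_native_module(nn.LayerNorm(1024), sp_partial_derived=True)
assert SeqParallelUtils.is_sp_partial_derived_param(norm.weight)
assert SeqParallelUtils.is_sp_partial_derived_param(norm.bias)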
class FusedLayerNorm(BaseLayerNorm):
r"""
This is a wrapper around the apex fused layernorm implementation. It is meant to be used only with the from_native_module interface.
"""
......@@ -43,15 +142,29 @@ class FusedLayerNorm:
def __init__(self) -> None:
raise NotImplementedError(
"FusedLayerNorm is not implemented as a physical class. "
"It is meant to be used only with the from_native_module interface to wrap the fused layernorm implementation provided by apex."
"It is meant to be used only with the from_native_module interface convert a native pytorch layer norm module to FusedLayerNorm module provided by apex."
)
@staticmethod
def from_native_module(module: nn.LayerNorm, *args, **kwargs) -> nn.Module:
def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
r"""
Convert a native pytorch layer norm module to colossalai layer norm module
Convert a native pytorch layer norm module to the FusedLayerNorm module provided by apex,
and optionally mark parameters for gradient aggregation.
Args:
module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted.
sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
Returns:
nn.Module: Union[FastLayerNorm, FusedLayerNorm].
Raises:
AssertionError: If the provided module is not an instance of nn.LayerNorm.
"""
# check if apex is installed
assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm."
try:
pass
except ImportError:
......@@ -85,10 +198,18 @@ class FusedLayerNorm:
layernorm.weight = module.weight
layernorm.bias = module.bias
if sp_partial_derived:
# Since gradients are computed using only a subset of the data,
# aggregation of these gradients is necessary during backpropagation.
# Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.weight)
SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.bias)
return layernorm
class FusedRMSNorm:
class FusedRMSNorm(BaseLayerNorm):
"""
This is a wrapper around the apex fused rms norm implementation. It is meant to be used only with the from_native_module interface.
"""
......@@ -96,11 +217,22 @@ class FusedRMSNorm:
def __init__(self) -> None:
raise NotImplementedError(
"FusedRMSNorm is not implemented as a physical class. "
"It is meant to be used only with the from_native_module interface to wrap the fused rms norm implementation provided by apex."
"It is meant to be used only with the from_native_module interface to Convert a native RMSNorm module to FusedRMSNorm module provided by apex."
)
@staticmethod
def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module:
def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
r"""
Convert a native RMSNorm module to the FusedRMSNorm module provided by apex,
and optionally mark parameters for gradient aggregation.
Args:
module (nn.Module): The native RMSNorm module to be converted.
sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
Returns:
nn.Module: FusedRMSNorm module.
"""
try:
from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
except ImportError:
......@@ -124,4 +256,10 @@ class FusedRMSNorm:
rmsnorm.weight = module.weight
if sp_partial_derived:
# Since gradients are computed using only a subset of the data,
# aggregation of these gradients is necessary during backpropagation.
# Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
SeqParallelUtils.marked_as_sp_partial_derived_param(rmsnorm.weight)
return rmsnorm
from contextlib import contextmanager
from typing import List
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from torch import nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.distributed import ProcessGroup, get_world_size
class SeqParallelUtils:
@staticmethod
def marked_as_sp_partial_derived_param(param):
"""
Mark a parameter as partially derived in sequence parallelism.
Args:
param: The parameter to mark as partially derived.
"""
setattr(param, "partial_derived", True)
@staticmethod
def is_sp_partial_derived_param(param):
"""
Check if a parameter is marked as partially derived in sequence parallelism.
Args:
param: The parameter to check.
Returns:
bool: True if the parameter is marked as partially derived, False otherwise.
"""
return getattr(param, "partial_derived", False)
@staticmethod
def allreduce_partial_data_grad(tp_group: ProcessGroup, model: nn.Module = None, grads: List[torch.Tensor] = None):
"""
Allreduce partial derived gradients across the specified process group.
This function performs gradient synchronization for parameters that are marked as partially derived in sequence parallelism.
Args:
tp_group (ProcessGroup): The process group for gradient synchronization.
model (nn.Module): The model from which gradients will be synchronized.
grads (List[torch.Tensor]): The list of gradients to be synchronized.
Raises:
AssertionError: If both `model` and `grads` are provided or neither is provided.
"""
# Ensure that exactly one of `model` and `grads` is provided for gradient synchronization.
assert (model is not None) ^ (grads is not None), "Exactly one of model and grads must be provided."
# Get the size of the process group, which determines whether synchronization is needed.
tp_size = get_world_size(tp_group) if tp_group is not None else 1
if tp_size == 1:
# If the process group size is 1, no synchronization is required.
return
if model is not None:
# If `model` is provided, extract partial derived gradients from the model's parameters.
grads = []
for p in model.parameters():
if p.grad is not None and SeqParallelUtils.is_sp_partial_derived_param(p):
grads.append(p.grad.data)
# Flatten and reduce the gradients using the specified process group.
coalesced = _flatten_dense_tensors(grads)
dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=tp_group)
# Unflatten the synchronized gradients and update the model's gradients.
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced)
else:
# If `grads` are provided explicitly, synchronize those gradients directly.
coalesced = _flatten_dense_tensors(grads)
dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=tp_group)
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced)
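A self-contained, single-process sketch of the mark -> backward -> all-reduce flow added here (hedged: the import path and the gloo backend are assumptions; with world_size == 1 the all-reduce is a no-op, but on a real tensor-parallel group the marked gradients are summed):

import torch
import torch.distributed as dist
import torch.nn as nn

from colossalai.shardformer.layer.utils import SeqParallelUtils

dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29512", rank=0, world_size=1)

model = nn.Linear(8, 8)
# Pretend this weight was derived from a sequence shard only, as a sequence-parallel
# layer norm weight would be.
SeqParallelUtils.marked_as_sp_partial_derived_param(model.weight)

model(torch.randn(4, 8)).sum().backward()
# Sum the partially derived gradients over the tensor-parallel group (here WORLD).
SeqParallelUtils.allreduce_partial_data_grad(tp_group=dist.group.WORLD, model=model)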
class Randomizer:
......
......@@ -4,6 +4,7 @@ from typing import Optional
import torch.nn as nn
from ..shard.shard_config import ShardConfig
from .base_policy import Policy
__all__ = ["PolicyLocation", "get_autopolicy", "import_policy"]
......@@ -197,7 +198,7 @@ def _fullname(obj):
return module + "." + klass.__qualname__
def get_autopolicy(model: nn.Module, inference_only: Optional[bool] = False) -> Policy:
def get_autopolicy(model: nn.Module, shard_config: ShardConfig = None) -> Policy:
r"""
Return the auto policy for the model
......@@ -208,7 +209,7 @@ def get_autopolicy(model: nn.Module, inference_only: Optional[bool] = False) ->
:class:`Policy`: The auto policy for the model
"""
full_name = _fullname(model)
if inference_only:
if shard_config.inference_only:
policy_location = _INFER_POLICY_LIST.get(full_name, None)
else:
policy_location = _POLICY_LIST.get(full_name, None)
......@@ -218,5 +219,5 @@ def get_autopolicy(model: nn.Module, inference_only: Optional[bool] = False) ->
f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())} and {list(_INFER_POLICY_LIST.keys())}"
)
else:
policy = import_policy(policy_location, inference_only)
policy = import_policy(policy_location, shard_config.inference_only)
return policy()
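A hedged usage sketch of the new get_autopolicy signature (assumes colossalai and transformers are installed; building the model from a default config is illustrative only):

from transformers import BertConfig, BertModel

from colossalai.shardformer import ShardConfig
from colossalai.shardformer.policies.auto_policy import get_autopolicy

model = BertModel(BertConfig())
# The policy is now resolved from the model together with a ShardConfig; the config's
# inference_only flag decides whether an inference or a training policy is returned.
policy = get_autopolicy(model, shard_config=ShardConfig(enable_tensor_parallelism=False))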
......@@ -11,6 +11,7 @@ from torch.nn import Module
from colossalai.pipeline.stage_manager import PipelineStageManager
from ..layer.normalization import BaseLayerNorm
from ..layer.parallel_module import ParallelModule
from ..shard.shard_config import ShardConfig
......@@ -29,7 +30,7 @@ class SubModuleReplacementDescription:
ignore_if_not_exist (bool): if the submodule does not exist, ignore it or raise an exception
"""
suffix: str
target_module: ParallelModule
target_module: Union[ParallelModule, BaseLayerNorm]
kwargs: Dict[str, Any] = None
ignore_if_not_exist: bool = False
......@@ -77,7 +78,6 @@ class Policy(ABC):
def set_model(self, model: nn.Module) -> None:
r"""
Set model as an attribute of the Policy object so that we can access the model's attributes.
Args:
model (:class:`nn.Module`): The model to be sharded
"""
......@@ -86,11 +86,11 @@ class Policy(ABC):
def set_shard_config(self, shard_config: ShardConfig) -> None:
r"""
Set shard config as an attribute of the Policy object.
Args:
shard_config (:class:`ShardConfig`): The shard config to be used
"""
self.shard_config = shard_config
self.config_sanity_check()
@property
......
......@@ -60,6 +60,12 @@ class BertPolicy(Policy):
)
policy = {}
if self.shard_config.enable_fused_normalization:
norm_cls = col_nn.FusedLayerNorm
else:
norm_cls = col_nn.LayerNorm
use_sequence_parallel = self.shard_config.enable_sequence_parallelism
overlap = self.shard_config.enable_sequence_overlap
if self.shard_config.enable_tensor_parallelism:
......@@ -141,33 +147,34 @@ class BertPolicy(Policy):
)
# optimization configuration
if self.shard_config.enable_fused_normalization:
# Handle bert layer
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="attention.output.LayerNorm",
target_module=col_nn.FusedLayerNorm,
),
SubModuleReplacementDescription(
suffix="output.LayerNorm",
target_module=col_nn.FusedLayerNorm,
),
],
policy=policy,
target_key=BertLayer,
)
# handle embedding layer
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="LayerNorm",
target_module=col_nn.FusedLayerNorm,
)
],
policy=policy,
target_key=BertEmbeddings,
)
# Handle bert layer
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="attention.output.LayerNorm",
target_module=norm_cls,
kwargs={"sp_partial_derived": use_sequence_parallel},
),
SubModuleReplacementDescription(
suffix="output.LayerNorm",
target_module=norm_cls,
kwargs={"sp_partial_derived": use_sequence_parallel},
),
],
policy=policy,
target_key=BertLayer,
)
# handle embedding layer
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="LayerNorm",
target_module=norm_cls,
)
],
policy=policy,
target_key=BertEmbeddings,
)
# use flash attention
if self.shard_config.enable_flash_attention:
......@@ -288,9 +295,6 @@ class BertPolicy(Policy):
# BertModel
class BertModelPolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
policy = super().module_policy()
from transformers.models.bert.modeling_bert import BertModel
......@@ -313,9 +317,6 @@ class BertModelPolicy(BertPolicy):
# BertForPreTraining
class BertForPreTrainingPolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
policy = super().module_policy()
policy = self.add_lm_head_policy(policy)
......@@ -355,9 +356,6 @@ class BertForPreTrainingPolicy(BertPolicy):
# BertLMHeadModel
class BertLMHeadModelPolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
policy = super().module_policy()
policy = self.add_lm_head_policy(policy)
......@@ -396,9 +394,6 @@ class BertLMHeadModelPolicy(BertPolicy):
# BertForMaskedLM
class BertForMaskedLMPolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
policy = super().module_policy()
policy = self.add_lm_head_policy(policy)
......@@ -437,9 +432,6 @@ class BertForMaskedLMPolicy(BertPolicy):
# BertForSequenceClassification
class BertForSequenceClassificationPolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.bert.modeling_bert import BertForSequenceClassification
......@@ -484,9 +476,6 @@ class BertForSequenceClassificationPolicy(BertPolicy):
# BertForTokenClassification
class BertForTokenClassificationPolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.bert.modeling_bert import BertForTokenClassification
......@@ -531,9 +520,6 @@ class BertForTokenClassificationPolicy(BertPolicy):
# BertForNextSentencePrediction
class BertForNextSentencePredictionPolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
policy = super().module_policy()
from transformers.models.bert.modeling_bert import BertForNextSentencePrediction
......@@ -564,9 +550,6 @@ class BertForNextSentencePredictionPolicy(BertPolicy):
# BertForMultipleChoice
class BertForMultipleChoicePolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.bert.modeling_bert import BertForMultipleChoice
......@@ -610,9 +593,6 @@ class BertForMultipleChoicePolicy(BertPolicy):
class BertForQuestionAnsweringPolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.bert.modeling_bert import BertForQuestionAnswering
......
......@@ -43,6 +43,11 @@ class BlipPolicy(Policy):
policy = {}
if self.shard_config.enable_fused_normalization:
norm_cls = col_nn.FusedLayerNorm
else:
norm_cls = col_nn.LayerNorm
if self.shard_config.enable_tensor_parallelism:
policy[Blip2EncoderLayer] = ModulePolicyDescription(
attribute_replacement={
......@@ -214,94 +219,93 @@ class BlipPolicy(Policy):
policy[Blip2Attention] = ModulePolicyDescription(method_replacement={"forward": forward_fn()})
# optimization configuration
if self.shard_config.enable_fused_normalization:
# Handle Blip2EncoderLayer layer
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="layer_norm1",
target_module=col_nn.FusedLayerNorm,
),
SubModuleReplacementDescription(
suffix="layer_norm2",
target_module=col_nn.FusedLayerNorm,
),
],
policy=policy,
target_key=Blip2EncoderLayer,
)
# Handle Blip2EncoderLayer layer
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="layer_norm1",
target_module=norm_cls,
),
SubModuleReplacementDescription(
suffix="layer_norm2",
target_module=norm_cls,
),
],
policy=policy,
target_key=Blip2EncoderLayer,
)
# handle Blip2VisionModel layer
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="post_layernorm",
target_module=col_nn.FusedLayerNorm,
)
],
policy=policy,
target_key=Blip2VisionModel,
)
# handle Blip2VisionModel layer
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="post_layernorm",
target_module=norm_cls,
)
],
policy=policy,
target_key=Blip2VisionModel,
)
# handle Blip2VisionModel layer
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="layernorm",
target_module=col_nn.FusedLayerNorm,
)
],
policy=policy,
target_key=Blip2QFormerModel,
)
# handle Blip2VisionModel layer
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="layernorm",
target_module=norm_cls,
)
],
policy=policy,
target_key=Blip2QFormerModel,
)
# handle Blip2QFormerLayer layer
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="attention.output.LayerNorm",
target_module=col_nn.FusedLayerNorm,
),
SubModuleReplacementDescription(
suffix="crossattention.output.LayerNorm",
target_module=col_nn.FusedLayerNorm,
),
SubModuleReplacementDescription(
suffix="output_query.LayerNorm",
target_module=col_nn.FusedLayerNorm,
),
],
policy=policy,
target_key=Blip2QFormerLayer,
)
# handle Blip2QFormerLayer layer
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="attention.output.LayerNorm",
target_module=norm_cls,
),
SubModuleReplacementDescription(
suffix="crossattention.output.LayerNorm",
target_module=norm_cls,
),
SubModuleReplacementDescription(
suffix="output_query.LayerNorm",
target_module=norm_cls,
),
],
policy=policy,
target_key=Blip2QFormerLayer,
)
# handle OPTForCausalLM layer
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="model.decoder.final_layer_norm",
target_module=col_nn.FusedLayerNorm,
)
],
policy=policy,
target_key=OPTForCausalLM,
)
# handle OPTForCausalLM layer
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="model.decoder.final_layer_norm",
target_module=norm_cls,
)
],
policy=policy,
target_key=OPTForCausalLM,
)
# handle OPTDecoderLayer layer
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="self_attn_layer_norm",
target_module=col_nn.FusedLayerNorm,
),
SubModuleReplacementDescription(
suffix="final_layer_norm",
target_module=col_nn.FusedLayerNorm,
),
],
policy=policy,
target_key=OPTDecoderLayer,
)
# handle OPTDecoderLayer layer
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="self_attn_layer_norm",
target_module=norm_cls,
),
SubModuleReplacementDescription(
suffix="final_layer_norm",
target_module=norm_cls,
),
],
policy=policy,
target_key=OPTDecoderLayer,
)
# use flash attention
if self.shard_config.enable_flash_attention:
......
......@@ -42,6 +42,10 @@ class BloomPolicy(Policy):
policy = {}
if self.shard_config.enable_fused_normalization:
norm_cls = col_nn.FusedLayerNorm
else:
norm_cls = col_nn.LayerNorm
use_sequence_parallel = self.shard_config.enable_sequence_parallelism
overlap = self.shard_config.enable_sequence_overlap
if self.shard_config.enable_tensor_parallelism:
......@@ -97,38 +101,39 @@ class BloomPolicy(Policy):
)
# optimization configuration
if self.shard_config.enable_fused_normalization:
# handle bloom model
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="ln_f",
target_module=col_nn.FusedLayerNorm,
),
SubModuleReplacementDescription(
suffix="word_embeddings_layernorm",
target_module=col_nn.FusedLayerNorm,
),
],
policy=policy,
target_key=BloomModel,
)
# handle bloom block
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="input_layernorm",
target_module=col_nn.FusedLayerNorm,
),
SubModuleReplacementDescription(
suffix="post_attention_layernorm",
target_module=col_nn.FusedLayerNorm,
),
],
policy=policy,
target_key=BloomBlock,
)
# handle bloom model
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="ln_f",
target_module=norm_cls,
),
SubModuleReplacementDescription(
suffix="word_embeddings_layernorm",
target_module=norm_cls,
),
],
policy=policy,
target_key=BloomModel,
)
# handle bloom block
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="input_layernorm",
target_module=norm_cls,
kwargs={"sp_partial_derived": use_sequence_parallel},
),
SubModuleReplacementDescription(
suffix="post_attention_layernorm",
target_module=norm_cls,
kwargs={"sp_partial_derived": use_sequence_parallel},
),
],
policy=policy,
target_key=BloomBlock,
)
if use_sequence_parallel:
self.append_or_create_method_replacement(
......@@ -225,9 +230,6 @@ class BloomPolicy(Policy):
class BloomModelPolicy(BloomPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
policy = super().module_policy()
from transformers.models.bloom.modeling_bloom import BloomModel
......
......@@ -45,6 +45,16 @@ class ChatGLMPolicy(Policy):
policy = {}
if self.shard_config.enable_fused_normalization:
if self.model.config.rmsnorm:
norm_cls = col_nn.FusedRMSNorm
else:
norm_cls = col_nn.FusedLayerNorm
else:
if self.model.config.rmsnorm:
norm_cls = col_nn.RMSNorm
else:
norm_cls = col_nn.LayerNorm
use_sequence_parallel = self.shard_config.enable_sequence_parallelism
overlap = self.shard_config.enable_sequence_overlap
if self.shard_config.enable_tensor_parallelism:
......@@ -96,52 +106,34 @@ class ChatGLMPolicy(Policy):
)
# optimization configuration
if self.shard_config.enable_fused_normalization:
if not self.model.config.rmsnorm:
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(suffix="input_layernorm", target_module=col_nn.FusedLayerNorm),
SubModuleReplacementDescription(
suffix="post_attention_layernorm", target_module=col_nn.FusedLayerNorm
),
],
policy=policy,
target_key=GLMBlock,
)
if self.model.config.post_layer_norm:
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="encoder.final_layernorm", target_module=col_nn.FusedLayerNorm
)
],
policy=policy,
target_key=ChatGLMModel,
)
else:
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(suffix="input_layernorm", target_module=col_nn.FusedRMSNorm),
SubModuleReplacementDescription(
suffix="post_attention_layernorm", target_module=col_nn.FusedRMSNorm
),
],
policy=policy,
target_key=GLMBlock,
)
if self.model.config.post_layer_norm:
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="encoder.final_layernorm", target_module=col_nn.FusedRMSNorm
)
],
policy=policy,
target_key=ChatGLMModel,
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="input_layernorm",
target_module=norm_cls,
kwargs={"sp_partial_derived": use_sequence_parallel},
),
SubModuleReplacementDescription(
suffix="post_attention_layernorm",
target_module=norm_cls,
kwargs={"sp_partial_derived": use_sequence_parallel},
),
],
policy=policy,
target_key=GLMBlock,
)
if self.model.config.post_layer_norm:
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="encoder.final_layernorm",
target_module=norm_cls,
)
],
policy=policy,
target_key=ChatGLMModel,
)
# use flash attention
if self.shard_config.enable_flash_attention:
......@@ -224,9 +216,6 @@ class ChatGLMPolicy(Policy):
class ChatGLMModelPolicy(ChatGLMPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
pass
......
......@@ -39,6 +39,11 @@ class GPT2Policy(Policy):
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model
policy = {}
if self.shard_config.enable_fused_normalization:
norm_cls = col_nn.FusedLayerNorm
else:
norm_cls = col_nn.LayerNorm
use_sequence_parallel = self.shard_config.enable_sequence_parallelism
overlap = self.shard_config.enable_sequence_overlap
if self.shard_config.enable_tensor_parallelism:
......@@ -102,33 +107,37 @@ class GPT2Policy(Policy):
)
# optimization configuration
if self.shard_config.enable_fused_normalization:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="ln_f",
target_module=col_nn.FusedLayerNorm,
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="ln_f",
target_module=norm_cls,
),
policy=policy,
target_key=GPT2Model,
)
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="ln_1",
target_module=norm_cls,
kwargs={"sp_partial_derived": use_sequence_parallel},
),
policy=policy,
target_key=GPT2Model,
)
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="ln_1",
target_module=col_nn.FusedLayerNorm,
),
SubModuleReplacementDescription(
suffix="ln_2",
target_module=col_nn.FusedLayerNorm,
),
SubModuleReplacementDescription(
suffix="ln_cross_attn", target_module=col_nn.FusedLayerNorm, ignore_if_not_exist=True
),
],
policy=policy,
target_key=GPT2Block,
)
SubModuleReplacementDescription(
suffix="ln_2",
target_module=norm_cls,
kwargs={"sp_partial_derived": use_sequence_parallel},
),
SubModuleReplacementDescription(
suffix="ln_cross_attn",
target_module=norm_cls,
ignore_if_not_exist=True,
kwargs={"sp_partial_derived": use_sequence_parallel},
),
],
policy=policy,
target_key=GPT2Block,
)
if self.shard_config.enable_flash_attention:
self.append_or_create_method_replacement(
......@@ -192,9 +201,6 @@ class GPT2Policy(Policy):
# GPT2Model
class GPT2ModelPolicy(GPT2Policy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
......@@ -216,9 +222,6 @@ class GPT2ModelPolicy(GPT2Policy):
# GPT2LMHeadModel
class GPT2LMHeadModelPolicy(GPT2Policy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
......@@ -263,9 +266,6 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
# GPT2DoubleHeadsModel
class GPT2DoubleHeadsModelPolicy(GPT2Policy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.gpt2.modeling_gpt2 import GPT2DoubleHeadsModel
......@@ -317,9 +317,6 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy):
# GPT2ForQuestionAnswering
class GPT2ForQuestionAnsweringPolicy(GPT2Policy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.gpt2.modeling_gpt2 import GPT2ForQuestionAnswering
......@@ -347,9 +344,6 @@ class GPT2ForQuestionAnsweringPolicy(GPT2Policy):
# GPT2ForTokenClassification
class GPT2ForTokenClassificationPolicy(GPT2Policy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.gpt2.modeling_gpt2 import GPT2ForTokenClassification
......@@ -387,9 +381,6 @@ class GPT2ForTokenClassificationPolicy(GPT2Policy):
# GPT2ForSequenceClassification
class GPT2ForSequenceClassificationPolicy(GPT2Policy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.gpt2.modeling_gpt2 import GPT2ForSequenceClassification
......
......@@ -6,7 +6,7 @@ import torch.nn as nn
from torch import Tensor
from torch.nn import Module
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, RMSNorm, VocabParallelEmbedding1D
from ..modeling.llama import LlamaPipelineForwards, get_llama_flash_attention_forward
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
......@@ -35,6 +35,11 @@ class LlamaPolicy(Policy):
policy = {}
if self.shard_config.enable_fused_normalization:
norm_cls = FusedRMSNorm
else:
norm_cls = RMSNorm
if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False
warnings.warn("Llama dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
......@@ -93,31 +98,31 @@ class LlamaPolicy(Policy):
)
# optimization configuration
if self.shard_config.enable_fused_normalization:
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="input_layernorm",
target_module=FusedRMSNorm,
),
SubModuleReplacementDescription(
suffix="post_attention_layernorm",
target_module=FusedRMSNorm,
),
],
policy=policy,
target_key=LlamaDecoderLayer,
)
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="norm",
target_module=FusedRMSNorm,
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="input_layernorm",
target_module=norm_cls,
),
policy=policy,
target_key=LlamaModel,
)
SubModuleReplacementDescription(
suffix="post_attention_layernorm",
target_module=norm_cls,
),
],
policy=policy,
target_key=LlamaDecoderLayer,
)
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="norm",
target_module=norm_cls,
),
policy=policy,
target_key=LlamaModel,
)
# use flash attention
if self.shard_config.enable_flash_attention:
self.append_or_create_method_replacement(
description={
......@@ -174,9 +179,6 @@ class LlamaPolicy(Policy):
class LlamaModelPolicy(LlamaPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
policy = super().module_policy()
from transformers.models.llama.modeling_llama import LlamaModel
......
......@@ -5,7 +5,7 @@ from typing import Callable, Dict, List
import torch.nn as nn
from torch import Tensor, nn
from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
from colossalai.shardformer.layer import FusedLayerNorm, LayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
from .._utils import getattr_
from ..modeling.jit import get_jit_fused_dropout_add_func
......@@ -42,6 +42,12 @@ class OPTPolicy(Policy):
from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer
policy = {}
if self.shard_config.enable_fused_normalization:
norm_cls = FusedLayerNorm
else:
norm_cls = LayerNorm
if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False
warnings.warn("OPT dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
......@@ -94,26 +100,25 @@ class OPTPolicy(Policy):
)
# optimization configuration
if self.shard_config.enable_fused_normalization:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="final_layer_norm", target_module=norm_cls, ignore_if_not_exist=True
),
policy=policy,
target_key=OPTDecoder,
)
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="self_attn_layer_norm", target_module=norm_cls, ignore_if_not_exist=True
),
policy=policy,
target_key=OPTDecoder,
)
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="self_attn_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True
),
SubModuleReplacementDescription(
suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True
),
],
policy=policy,
target_key=OPTDecoderLayer,
)
SubModuleReplacementDescription(
suffix="final_layer_norm", target_module=norm_cls, ignore_if_not_exist=True
),
],
policy=policy,
target_key=OPTDecoderLayer,
)
# use flash attention
if self.shard_config.enable_flash_attention:
......@@ -183,9 +188,6 @@ class OPTPolicy(Policy):
class OPTModelPolicy(OPTPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.opt.modeling_opt import OPTModel
......@@ -253,9 +255,6 @@ class OPTForCausalLMPolicy(OPTPolicy):
class OPTForSequenceClassificationPolicy(OPTPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.opt.modeling_opt import OPTForSequenceClassification
......@@ -281,9 +280,6 @@ class OPTForSequenceClassificationPolicy(OPTPolicy):
class OPTForQuestionAnsweringPolicy(OPTPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.opt.modeling_opt import OPTForQuestionAnswering
......
......@@ -24,6 +24,11 @@ class SamPolicy(Policy):
policy = {}
if self.shard_config.enable_fused_normalization:
norm_cls = col_nn.FusedLayerNorm
else:
norm_cls = col_nn.LayerNorm
if self.shard_config.enable_tensor_parallelism:
policy[SamVisionLayer] = ModulePolicyDescription(
attribute_replacement={
......@@ -151,58 +156,57 @@ class SamPolicy(Policy):
)
# optimization configuration
if self.shard_config.enable_fused_normalization:
# Handle SamVisionLayer
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="layer_norm1",
target_module=col_nn.FusedLayerNorm,
),
SubModuleReplacementDescription(
suffix="layer_norm2",
target_module=col_nn.FusedLayerNorm,
),
],
policy=policy,
target_key=SamVisionLayer,
)
# Handle SamVisionLayer
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="layer_norm1",
target_module=norm_cls,
),
SubModuleReplacementDescription(
suffix="layer_norm2",
target_module=norm_cls,
),
],
policy=policy,
target_key=SamVisionLayer,
)
# Handle SamTwoWayAttentionBlock
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="layer_norm1",
target_module=col_nn.FusedLayerNorm,
),
SubModuleReplacementDescription(
suffix="layer_norm2",
target_module=col_nn.FusedLayerNorm,
),
SubModuleReplacementDescription(
suffix="layer_norm3",
target_module=col_nn.FusedLayerNorm,
),
SubModuleReplacementDescription(
suffix="layer_norm4",
target_module=col_nn.FusedLayerNorm,
),
],
policy=policy,
target_key=SamTwoWayAttentionBlock,
)
# Handle SamTwoWayAttentionBlock
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="layer_norm1",
target_module=norm_cls,
),
SubModuleReplacementDescription(
suffix="layer_norm2",
target_module=norm_cls,
),
SubModuleReplacementDescription(
suffix="layer_norm3",
target_module=norm_cls,
),
SubModuleReplacementDescription(
suffix="layer_norm4",
target_module=norm_cls,
),
],
policy=policy,
target_key=SamTwoWayAttentionBlock,
)
# Handle SamTwoWayTransformer
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="layer_norm_final_attn",
target_module=col_nn.FusedLayerNorm,
)
],
policy=policy,
target_key=SamTwoWayTransformer,
)
# Handle SamTwoWayTransformer
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="layer_norm_final_attn",
target_module=norm_cls,
)
],
policy=policy,
target_key=SamTwoWayTransformer,
)
# use flash attention
if self.shard_config.enable_flash_attention:
......
......@@ -11,6 +11,7 @@ from colossalai.shardformer.layer import (
FusedRMSNorm,
Linear1D_Col,
Linear1D_Row,
RMSNorm,
VocabParallelEmbedding1D,
)
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription
......@@ -58,6 +59,11 @@ class T5BasePolicy(Policy):
policy = {}
if self.shard_config.enable_fused_normalization:
norm_cls = FusedRMSNorm
else:
norm_cls = RMSNorm
if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False
warnings.warn("T5 dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
......@@ -169,38 +175,37 @@ class T5BasePolicy(Policy):
)
# optimization configuration
if self.shard_config.enable_fused_normalization:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="layer_norm",
target_module=FusedRMSNorm,
),
policy=policy,
target_key=T5LayerFF,
)
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="layer_norm",
target_module=FusedRMSNorm,
),
policy=policy,
target_key=T5LayerFF,
)
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm),
policy=policy,
target_key=T5LayerSelfAttention,
)
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm),
policy=policy,
target_key=T5LayerCrossAttention,
)
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(suffix="final_layer_norm", target_module=FusedRMSNorm),
policy=policy,
target_key=T5Stack,
)
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="layer_norm",
target_module=norm_cls,
),
policy=policy,
target_key=T5LayerFF,
)
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="layer_norm",
target_module=norm_cls,
),
policy=policy,
target_key=T5LayerFF,
)
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(suffix="layer_norm", target_module=norm_cls),
policy=policy,
target_key=T5LayerSelfAttention,
)
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(suffix="layer_norm", target_module=norm_cls),
policy=policy,
target_key=T5LayerCrossAttention,
)
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(suffix="final_layer_norm", target_module=norm_cls),
policy=policy,
target_key=T5Stack,
)
# use flash attention
if self.shard_config.enable_flash_attention:
......@@ -363,9 +368,6 @@ class T5BasePolicy(Policy):
class T5ModelPolicy(T5BasePolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers import T5Model
......@@ -402,9 +404,6 @@ class T5ModelPolicy(T5BasePolicy):
class T5ForConditionalGenerationPolicy(T5BasePolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers import T5ForConditionalGeneration
......@@ -466,9 +465,6 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
class T5EncoderPolicy(T5BasePolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers import T5EncoderModel
......
......@@ -159,9 +159,6 @@ class ViTPolicy(Policy):
# ViTModel
class ViTModelPolicy(ViTPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.vit.modeling_vit import ViTModel
......@@ -227,9 +224,6 @@ class ViTForImageClassificationPolicy(ViTPolicy):
# ViTForMaskedImageModeling
class ViTForMaskedImageModelingPolicy(ViTPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.vit.modeling_vit import ViTForMaskedImageModeling, ViTModel
......
......@@ -52,6 +52,11 @@ class WhisperPolicy(Policy):
policy = {}
if self.shard_config.enable_fused_normalization:
norm_cls = col_nn.FusedLayerNorm
else:
norm_cls = col_nn.LayerNorm
if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False
warnings.warn(
......@@ -161,62 +166,61 @@ class WhisperPolicy(Policy):
)
# optimization configuration
if self.shard_config.enable_fused_normalization:
# Handle encoder layer
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="self_attn_layer_norm",
target_module=col_nn.FusedLayerNorm,
),
SubModuleReplacementDescription(
suffix="final_layer_norm",
target_module=col_nn.FusedLayerNorm,
),
],
policy=policy,
target_key=WhisperEncoderLayer,
)
# Handle encoder layer
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="self_attn_layer_norm",
target_module=norm_cls,
),
SubModuleReplacementDescription(
suffix="final_layer_norm",
target_module=norm_cls,
),
],
policy=policy,
target_key=WhisperEncoderLayer,
)
# Handle decoder layer
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="self_attn_layer_norm",
target_module=col_nn.FusedLayerNorm,
),
SubModuleReplacementDescription(
suffix="final_layer_norm",
target_module=col_nn.FusedLayerNorm,
),
],
policy=policy,
target_key=WhisperDecoderLayer,
)
# Handle decoder layer
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="self_attn_layer_norm",
target_module=norm_cls,
),
SubModuleReplacementDescription(
suffix="final_layer_norm",
target_module=norm_cls,
),
],
policy=policy,
target_key=WhisperDecoderLayer,
)
# handle encoder layer
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="layer_norm",
target_module=col_nn.FusedLayerNorm,
)
],
policy=policy,
target_key=WhisperEncoder,
)
# handle encoder layer
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="layer_norm",
target_module=norm_cls,
)
],
policy=policy,
target_key=WhisperEncoder,
)
# handle decoder layer
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="layer_norm",
target_module=col_nn.FusedLayerNorm,
)
],
policy=policy,
target_key=WhisperDecoder,
)
# handle decoder layer
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="layer_norm",
target_module=norm_cls,
)
],
policy=policy,
target_key=WhisperDecoder,
)
# enable flash attention
if self.shard_config.enable_flash_attention:
......@@ -416,9 +420,6 @@ class WhisperPolicy(Policy):
# WhisperModel
class WhisperModelPolicy(WhisperPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers import WhisperModel
......@@ -441,9 +442,6 @@ class WhisperModelPolicy(WhisperPolicy):
# WhisperForConditionalGeneration
class WhisperForConditionalGenerationPolicy(WhisperPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers import WhisperForConditionalGeneration
......@@ -502,9 +500,6 @@ class WhisperForConditionalGenerationPolicy(WhisperPolicy):
# WhisperForAudioClassification
class WhisperForAudioClassificationPolicy(WhisperPolicy):
def __init__(self) -> None:
super().__init__()
def preprocess(self):
return self.model
......