Commit d33a44e8 authored by Frank Lee

[shardformer] refactored layernorm (#4086)

parent c4b1b659
 from .dropout import Dropout1D
 from .embedding import Embedding1D, VocabParallelEmbedding1D
-from .layernorm import LayerNorm1D
+from .layernorm import FusedLayerNorm
 from .linear import Linear1D_Col, Linear1D_Row
 from .linear_conv import LinearConv1D_Col, LinearConv1D_Row
 from .loss import cross_entropy_1d
 
 __all__ = [
     "Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", "LinearConv1D_Col", "LinearConv1D_Row",
-    "Dropout1D", "cross_entropy_1d", 'LayerNorm1D'
+    "Dropout1D", "cross_entropy_1d", 'FusedLayerNorm'
 ]
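
For downstream code, the visible effect of this rename is the import path only; a quick before/after comparison (taken from the imports shown in this commit's test file):

# before this commit
from colossalai.shardformer.layer import LayerNorm1D
# after this commit
from colossalai.shardformer.layer import FusedLayerNorm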
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
-from typing import List, Union
-
 import torch
 import torch.nn as nn
-from torch.distributed import ProcessGroup
-
-from colossalai.kernel import LayerNorm
-from colossalai.nn import init as init
-
-from .parallel_module import ParallelModule
-
-__all__ = ['LayerNorm1D']
-
-Fast_LN = None
-try:
-    from apex.contrib.layer_norm.layer_norm import FastLayerNorm
-    Fast_LN = FastLayerNorm
-except ImportError:
-    pass
-
-
-class LayerNorm1D(ParallelModule):
-    r"""
-    Layer Normalization for colossalai
-
-    Args:
-        normalized_shape (int): input shape from an expected input of size.
-            :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
-                \times \ldots \times \text{normalized_shape}[-1]]`
-            If a single integer is used, it is treated as a singleton list, and this module will
-            normalize over the last dimension which is expected to be of that specific size.
-        eps (float): a value added to the denominator for numerical stability, defaults to 1e-05.
-        bias (bool, optional): Whether to add a bias, defaults to ``True``.
-        dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
-    """
-
-    _fast_ln_supported_sizes = [
-        1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480,
-        24576, 25600, 30720, 32768, 40960, 49152, 65536
-    ]
-
-    def __init__(self,
-                 normalized_shape: int,
-                 eps: int = 1e-05,
-                 bias: bool = True,
-                 dtype: torch.dtype = None,
-                 device: torch.device = None):
-        super().__init__()
-        if Fast_LN is not None and normalized_shape in self._fast_ln_supported_sizes:
-            norm = Fast_LN(normalized_shape, eps=eps).to(dtype)
-        else:
-            norm = None
-            try:
-                from apex.normalization import FusedLayerNorm
-                norm = FusedLayerNorm(normalized_shape, eps=eps).to(dtype)
-            except ImportError:
-                norm = LayerNorm(normalized_shape, eps=eps, device=device, dtype=dtype)
-        self.norm = norm
+
+__all__ = ['FusedLayerNorm']
+
+FAST_LAYERNORM_SUPPORTED_SIZE = [
+    1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480, 24576,
+    25600, 30720, 32768, 40960, 49152, 65536
+]
+
+
+class FusedLayerNorm():
+    r"""
+    This is a wrapper around the apex fused layernorm implementation. It is meant to be used only with the from_native_module interface.
+    """
+
+    def __init__(self) -> None:
+        raise NotImplementedError(
+            'FusedLayerNorm is not implemented as a physical class. '
+            'It is meant to be used only with the from_native_module interface to wrap the fused layernorm implementation provided by apex.'
+        )
 
     @staticmethod
-    def from_native_module(module: nn.LayerNorm, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
-                           **kwargs) -> ParallelModule:
+    def from_native_module(module: nn.LayerNorm, *args, **kwargs) -> nn.Module:
         r"""
         Convert a native pytorch layer norm module to colossalai layer norm module
         """
+        # check if apex is installed
+        try:
+            import apex
+        except ImportError:
+            raise ImportError(
+                'Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel')
+
         # get the attributes of the module
         normalized_shape = module.normalized_shape
         eps = module.eps
-        bias = module.bias is not None
+        elementwise_affine = module.elementwise_affine
         dtype = module.weight.dtype
         device = module.weight.device
 
-        # ensure only one process group is passed
-        if isinstance(process_group, (list, tuple)):
-            assert len(process_group) == 1, \
-                f'Expected only one process group, got {len(process_group)}.'
-            process_group = process_group[0]
-
-        # create layer norm
-        layer_norm = LayerNorm1D(normalized_shape, eps=eps, bias=bias, device=device, dtype=dtype).norm
+        # pick the suitable layernorm implementation
+        use_fast_ln = normalized_shape in FAST_LAYERNORM_SUPPORTED_SIZE
+
+        if use_fast_ln:
+            try:
+                from apex.contrib.layer_norm.layer_norm import FastLayerNorm as ApexFusedLayerNorm
+            except ImportError:
+                # fall back to the normal fused layernorm if the fast variant is not built
+                from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm
+        else:
+            from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm
+
+        layernorm = ApexFusedLayerNorm(normalized_shape, eps=eps,
+                                       elementwise_affine=elementwise_affine).to(dtype).to(device)
 
         with torch.no_grad():
             # copy weight and bias
-            layer_norm.weight.copy_(module.weight)
-
-            if bias:
-                layer_norm.bias.copy_(module.bias)
-
-        return layer_norm
+            layernorm.weight.copy_(module.weight)
+            layernorm.bias.copy_(module.bias)
+
+        return layernorm
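
A minimal usage sketch of the refactored wrapper, assuming apex is installed and a CUDA device is available (the size 1024 is taken from FAST_LAYERNORM_SUPPORTED_SIZE purely for illustration): a plain nn.LayerNorm is handed to from_native_module, which returns an apex fused module with the original weight and bias copied over.

import torch
import torch.nn as nn
from torch.testing import assert_close

from colossalai.shardformer.layer import FusedLayerNorm

# a vanilla layernorm whose hidden size is covered by the fast apex kernel
native_norm = nn.LayerNorm(1024, eps=1e-5).cuda()

# the returned module is an apex layernorm instance; FusedLayerNorm itself
# is never instantiated (its __init__ raises NotImplementedError)
fused_norm = FusedLayerNorm.from_native_module(native_norm)

x = torch.randn(4, 1024, device='cuda')
# both modules should agree up to numerical tolerance
assert_close(fused_norm(x), native_norm(x), rtol=1e-3, atol=1e-3)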
@@ -103,17 +103,17 @@ class BertPolicy(Policy):
         base_policy[BertLayer].sub_module_replacement.append(
             SubModuleReplacementDescription(
                 suffix="attention.output.LayerNorm",
-                target_module=col_nn.LayerNorm1D,
+                target_module=col_nn.FusedLayerNorm,
             ))
         base_policy[BertLayer].sub_module_replacement.append(
             SubModuleReplacementDescription(
                 suffix="output.LayerNorm",
-                target_module=col_nn.LayerNorm1D,
+                target_module=col_nn.FusedLayerNorm,
             ))
         base_policy[BertEmbeddings].sub_module_replacement.append(
             SubModuleReplacementDescription(
                 suffix="LayerNorm",
-                target_module=col_nn.LayerNorm1D,
+                target_module=col_nn.FusedLayerNorm,
             ),)
         return base_policy

@@ -154,7 +154,7 @@ class BertForPretrainingPolicy(BertPolicy):
         addon_module[BertLMPredictionHead].sub_module_replacement.append(
             SubModuleReplacementDescription(
                 suffix="transform.LayerNorm",
-                target_module=col_nn.LayerNorm1D,
+                target_module=col_nn.FusedLayerNorm,
             ))
         module_policy.update(addon_module)
         return module_policy

@@ -191,7 +191,7 @@ class BertLMHeadModelPolicy(BertPolicy):
         addon_module[BertLMPredictionHead].sub_module_replacement.append(
             SubModuleReplacementDescription(
                 suffix="transform.LayerNorm",
-                target_module=col_nn.LayerNorm1D,
+                target_module=col_nn.FusedLayerNorm,
             ))
         module_policy.update(addon_module)
         return module_policy

@@ -228,7 +228,7 @@ class BertForMaskedLMPolicy(BertPolicy):
         addon_module[BertLMPredictionHead].sub_module_replacement.append(
             SubModuleReplacementDescription(
                 suffix="transform.LayerNorm",
-                target_module=col_nn.LayerNorm1D,
+                target_module=col_nn.FusedLayerNorm,
             ))
         module_policy.update(addon_module)
         return module_policy

...
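
Conceptually, each SubModuleReplacementDescription asks the shardformer to locate the submodule at the given dotted suffix and rebuild it through the target module's from_native_module hook. The sketch below reproduces that substitution by hand on a Hugging Face BertLayer; the replace_submodule helper is purely illustrative (it is not the shardformer's actual replacement code) and apex is assumed to be installed.

import torch.nn as nn
from transformers import BertConfig
from transformers.models.bert.modeling_bert import BertLayer

from colossalai.shardformer.layer import FusedLayerNorm


def replace_submodule(root: nn.Module, suffix: str, target_cls) -> None:
    # walk the dotted suffix (e.g. "attention.output.LayerNorm") down to its parent module
    *parents, name = suffix.split('.')
    parent = root
    for attr in parents:
        parent = getattr(parent, attr)
    native = getattr(parent, name)
    # swap in the converted module, mirroring what the policy entry requests
    setattr(parent, name, target_cls.from_native_module(native))


layer = BertLayer(BertConfig())
replace_submodule(layer, "attention.output.LayerNorm", FusedLayerNorm)
replace_submodule(layer, "output.LayerNorm", FusedLayerNorm)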
 import torch
-import torch.distributed as dist
 import torch.nn as nn
 from torch.testing import assert_close
 
 import colossalai
-from colossalai.shardformer.layer import LayerNorm1D
+from colossalai.shardformer.layer import FusedLayerNorm
 from colossalai.testing import rerun_if_address_is_in_use, spawn
 
 
-def check_layernorm_1d():
+def check_layernorm():
     norm = nn.LayerNorm(128, 0.00001).cuda()
-    norm1d = LayerNorm1D.from_native_module(norm, process_group=None)
+    norm1d = FusedLayerNorm.from_native_module(norm, process_group=None)
 
     assert norm1d.weight.shape == torch.Size([128])

@@ -33,11 +32,11 @@ def check_layernorm_1d():
 
 def run_dist(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    check_layernorm_1d()
+    check_layernorm()
 
 
 @rerun_if_address_is_in_use()
-def test_layernorm_1d():
+def test_layernorm():
     spawn(run_dist, nprocs=2)

...