Commit 5e079c87 authored by Mostofa Patwary

layernorm1p added

parent 035cae2e
@@ -514,6 +514,8 @@ def _add_network_size_args(parser):
                        'This is added for computational efficiency reasons.')
     group.add_argument('--layernorm-epsilon', type=float, default=1e-5,
                        help='Layer norm epsilon.')
+    group.add_argument('--apply-layernorm-1p', action='store_true',
+                       help='Use layernorm 1p')
     group.add_argument('--apply-residual-connection-post-layernorm',
                        action='store_true',
                        help='If set, use original BERT residual connection '
...
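The new switch is a plain store_true flag, so downstream code can simply test args.apply_layernorm_1p. A minimal, self-contained sketch of that behavior (the parser and group below are stand-ins, not Megatron's actual argument setup):

import argparse

# Illustrative stand-in for the argument group touched in the hunk above.
parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='network size')
group.add_argument('--apply-layernorm-1p', action='store_true',
                   help='Use layernorm 1p')

# Dashes become underscores, and the flag is False unless explicitly passed.
assert parser.parse_args(['--apply-layernorm-1p']).apply_layernorm_1p is True
assert parser.parse_args([]).apply_layernorm_1p is False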
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm
+from .fused_layer_norm import MixedFusedLayerNorm1P as LayerNorm1P
 from .distributed import DistributedDataParallel
 from .bert_model import BertModel
...
@@ -114,3 +114,29 @@ class MixedFusedLayerNorm(torch.nn.Module):
                                           keep_graph = True)
         return output
+
+
+class MixedFusedLayerNorm1P(MixedFusedLayerNorm):
+
+    def reset_parameters(self):
+        init.zeros_(self.weight)
+        init.zeros_(self.bias)
+
+    def forward(self, input):
+        if self.no_persist_layer_norm:
+            return FusedLayerNormAffineFunction.apply(
+                input, self.weight + 1, self.bias, self.normalized_shape, self.eps)
+        else:
+            output = FastLayerNormFN.apply(
+                input, self.weight + 1, self.bias, self.eps)
+            # Apex's fast layer norm function outputs a 'view' tensor (i.e., has
+            # a populated '_base' field). This will result in schedule.py's
+            # deallocate_output_tensor() throwing an error, so a viewless tensor is
+            # created to prevent this.
+            output = make_viewless_tensor(inp = output,
+                                          requires_grad = input.requires_grad,
+                                          keep_graph = True)
+            return output
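MixedFusedLayerNorm1P stores the learnable scale zero-initialized and applies it as weight + 1, so the effective scale still starts at 1 while the stored parameter stays centered around zero. Below is a minimal plain-PyTorch sketch of the same idea, without the Apex fused kernels (FusedLayerNormAffineFunction / FastLayerNormFN) used above; the class name is illustrative and not part of the commit.

import torch
import torch.nn.functional as F

class SimpleLayerNorm1P(torch.nn.Module):
    """Layer norm whose stored scale is (gamma - 1), i.e. zero at init."""

    def __init__(self, hidden_size, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.normalized_shape = (hidden_size,)
        # Both parameters start at zero, matching reset_parameters() above.
        self.weight = torch.nn.Parameter(torch.zeros(hidden_size))
        self.bias = torch.nn.Parameter(torch.zeros(hidden_size))

    def forward(self, x):
        # Effective scale is weight + 1, as in MixedFusedLayerNorm1P.forward.
        return F.layer_norm(x, self.normalized_shape,
                            self.weight + 1, self.bias, self.eps)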
@@ -8,7 +8,7 @@ from megatron import get_args
 from megatron.core import tensor_parallel
 from megatron.model.enums import AttnMaskType
 from megatron.model.language_model import parallel_lm_logits, get_language_model
-from megatron.model.transformer import LayerNorm
+from megatron.model import LayerNorm
 from megatron.model.utils import (
     openai_gelu,
     get_linear_layer,
...
@@ -11,7 +11,6 @@ from .module import MegatronModule
 from megatron.core import mpu, tensor_parallel
 from megatron.core.enums import ModelType
 from megatron.model.enums import AttnMaskType, LayerType, AttnType
-from megatron.model import LayerNorm
 from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.fused_bias_gelu import bias_gelu_impl
 from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
@@ -635,6 +634,11 @@ class ParallelTransformerLayer(MegatronModule):
         self.bf16 = args.bf16
         self.fp32_residual_connection = args.fp32_residual_connection

+        if args.apply_layernorm_1p:
+            from megatron.model import LayerNorm1P as LayerNorm
+        else:
+            from megatron.model import LayerNorm
+
         # Layernorm on the input data.
         self.input_layernorm = LayerNorm(
             args.hidden_size,
@@ -1020,6 +1024,11 @@ class ParallelTransformer(MegatronModule):
         self.layers = torch.nn.ModuleList(
             [build_layer(i + 1 + offset) for i in range(self.num_layers)])

+        if args.apply_layernorm_1p:
+            from megatron.model import LayerNorm1P as LayerNorm
+        else:
+            from megatron.model import LayerNorm
+
         if self.post_process and self.post_layer_norm:
             # Final layer norm before output.
             self.final_layernorm = LayerNorm(
...
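Because the stored scale starts at zero, a layer built with the 1p variant is numerically identical at initialization to one built with the standard LayerNorm. A quick self-contained check of the weight + 1 parameterization (illustrative only, using torch.nn.functional directly rather than the fused kernels):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden = 16
x = torch.randn(4, hidden)
bias = torch.zeros(hidden)

# "1p" parameterization: scale stored as zero, applied as (weight + 1).
out_1p = F.layer_norm(x, (hidden,), torch.zeros(hidden) + 1, bias, 1e-5)

# Standard parameterization: scale stored as one.
out_std = F.layer_norm(x, (hidden,), torch.ones(hidden), bias, 1e-5)

assert torch.allclose(out_1p, out_std)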