Commit 4e891fe9 authored by Jared Casper

Merge branch 'next-best-lm/merge-layernorm-1p-main' into 'main'

layernorm1p added

See merge request ADLR/megatron-lm!557
parents 7bd25e26 33a58153
@@ -534,6 +534,9 @@ def _add_network_size_args(parser):
                        'This is added for computational efficiency reasons.')
     group.add_argument('--layernorm-epsilon', type=float, default=1e-5,
                        help='Layer norm epsilon.')
+    group.add_argument('--apply-layernorm-1p', action='store_true',
+                       help='Adjust LayerNorm weights such that they are centered '
+                       'around zero. This improves numerical stability.')
     group.add_argument('--apply-residual-connection-post-layernorm',
                        action='store_true',
                        help='If set, use original BERT residual connection '
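The help text for the new --apply-layernorm-1p flag attributes the benefit to numerical stability. One way to see the intuition (a hedged illustration of floating-point spacing, not code from this merge request) is that low-precision formats such as bf16 are much coarser near 1.0 than near 0.0, so a small update to a scale parameter stored around 1.0 can be rounded away, while the same update to a parameter stored around 0.0 survives:

import torch

# bf16 keeps ~8 bits of mantissa, so representable values near 1.0 are
# spaced about 2**-7 apart; an update of 1e-3 is rounded away.
one = torch.tensor(1.0, dtype=torch.bfloat16)
delta = torch.tensor(1e-3, dtype=torch.bfloat16)
print(one + delta)    # tensor(1., dtype=torch.bfloat16)

# Near zero the spacing is far finer, so the same update is preserved.
zero = torch.tensor(0.0, dtype=torch.bfloat16)
print(zero + delta)   # tensor(0.0010, dtype=torch.bfloat16)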
@@ -58,9 +58,12 @@ class MixedFusedLayerNorm(torch.nn.Module):
     def __init__(self, normalized_shape, eps=1e-5,
                  no_persist_layer_norm=True,
-                 sequence_parallel=False):
+                 sequence_parallel=False,
+                 apply_layernorm_1p=False):
         super(MixedFusedLayerNorm, self).__init__()
+        self.apply_layernorm_1p = apply_layernorm_1p
 
         global fused_mix_prec_layer_norm_cuda
         fused_mix_prec_layer_norm_cuda = importlib.import_module(
             "fused_mix_prec_layer_norm_cuda")
@@ -92,18 +95,21 @@ class MixedFusedLayerNorm(torch.nn.Module):
     def reset_parameters(self):
-        init.ones_(self.weight)
-        init.zeros_(self.bias)
+        if self.apply_layernorm_1p:
+            init.zeros_(self.weight)
+            init.zeros_(self.bias)
+        else:
+            init.ones_(self.weight)
+            init.zeros_(self.bias)
 
     def forward(self, input):
+        weight = self.weight + 1 if self.apply_layernorm_1p else self.weight
         if self.no_persist_layer_norm:
-            return FusedLayerNormAffineFunction.apply(
-                input, self.weight, self.bias, self.normalized_shape, self.eps)
+            return FusedLayerNormAffineFunction.apply(input, weight, self.bias, self.normalized_shape, self.eps)
         else:
-            output = FastLayerNormFN.apply(
-                input, self.weight, self.bias, self.eps)
+            output = FastLayerNormFN.apply(input, weight, self.bias, self.eps)
             # Apex's fast layer norm function outputs a 'view' tensor (i.e., has
             # a populated '_base' field). This will result in schedule.py's
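Outside the fused Apex kernels used above, the reparameterization itself is small: the learnable scale is stored centered around zero (initialized to zeros instead of ones) and shifted by +1 in the forward pass, so a freshly constructed module is functionally identical to a standard LayerNorm. A minimal pure-PyTorch sketch of that idea (the class name and the use of F.layer_norm are illustrative assumptions, not part of this diff):

import torch
import torch.nn as nn
import torch.nn.functional as F

class LayerNorm1P(nn.Module):
    # Stores the learnable scale centered around zero and adds 1 back at
    # forward time, mirroring the reset_parameters()/forward() changes above.
    def __init__(self, normalized_shape, eps=1e-5):
        super().__init__()
        if isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = tuple(normalized_shape)
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(self.normalized_shape))  # zeros, not ones
        self.bias = nn.Parameter(torch.zeros(self.normalized_shape))

    def forward(self, x):
        # Effective scale is weight + 1, so the zero-initialized weight acts
        # like the usual ones-initialized LayerNorm scale.
        return F.layer_norm(x, self.normalized_shape,
                            self.weight + 1, self.bias, self.eps)

# Sanity check: matches a freshly initialized nn.LayerNorm.
x = torch.randn(4, 16)
assert torch.allclose(LayerNorm1P(16)(x), nn.LayerNorm(16)(x), atol=1e-6)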
@@ -8,7 +8,7 @@ from megatron import get_args
 from megatron.core import tensor_parallel
 from megatron.model.enums import AttnMaskType
 from megatron.model.language_model import parallel_lm_logits, get_language_model
-from megatron.model.transformer import LayerNorm
+from megatron.model import LayerNorm
 from megatron.model.utils import (
     openai_gelu,
     get_linear_layer,
@@ -10,8 +10,8 @@ from megatron import get_timers, get_args, core, get_num_microbatches
 from .module import MegatronModule
 from megatron.core import mpu, tensor_parallel
 from megatron.core.enums import ModelType
-from megatron.model.enums import AttnMaskType, LayerType, AttnType
+from megatron.model import LayerNorm
+from megatron.model.enums import AttnMaskType, LayerType, AttnType
 from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.fused_bias_gelu import bias_gelu_impl
 from megatron.model.rotary_pos_embedding import apply_rotary_pos_emb
@@ -712,7 +712,8 @@ class ParallelTransformerLayer(MegatronModule):
             args.hidden_size,
             eps=args.layernorm_epsilon,
             no_persist_layer_norm=args.no_persist_layer_norm,
-            sequence_parallel=args.sequence_parallel)
+            sequence_parallel=args.sequence_parallel,
+            apply_layernorm_1p=args.apply_layernorm_1p)
 
         # Self attention.
         self.self_attention = ParallelAttention(
@@ -730,7 +731,8 @@ class ParallelTransformerLayer(MegatronModule):
             args.hidden_size,
             eps=args.layernorm_epsilon,
             no_persist_layer_norm=args.no_persist_layer_norm,
-            sequence_parallel=args.sequence_parallel)
+            sequence_parallel=args.sequence_parallel,
+            apply_layernorm_1p=args.apply_layernorm_1p)
 
         if self.layer_type == LayerType.decoder:
             self.inter_attention = ParallelAttention(
@@ -743,7 +745,8 @@ class ParallelTransformerLayer(MegatronModule):
             args.hidden_size,
             eps=args.layernorm_epsilon,
             no_persist_layer_norm=args.no_persist_layer_norm,
-            sequence_parallel=args.sequence_parallel)
+            sequence_parallel=args.sequence_parallel,
+            apply_layernorm_1p=args.apply_layernorm_1p)
 
         # MLP
         if args.num_experts is not None:
@@ -1108,7 +1111,8 @@ class ParallelTransformer(MegatronModule):
             args.hidden_size,
             eps=args.layernorm_epsilon,
             no_persist_layer_norm=args.no_persist_layer_norm,
-            sequence_parallel=args.sequence_parallel)
+            sequence_parallel=args.sequence_parallel,
+            apply_layernorm_1p=args.apply_layernorm_1p)
 
     def _get_layer(self, layer_number):
         return self.layers[layer_number]
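Because forward() applies self.weight + 1, a checkpoint trained without the flag can in principle be carried over by shifting each LayerNorm scale down by one; the stored 1p weight w - 1 then reproduces the original scale w exactly. A hedged sketch of that shift (the helper and the key-selection convention are illustrative assumptions, not a utility shipped in this merge request):

import torch

def shift_layernorm_weights_to_1p(state_dict, weight_keys):
    # Subtract 1 from each selected LayerNorm scale so that, with the +1
    # added back in forward(), the model computes the same function.
    converted = dict(state_dict)
    for key in weight_keys:
        converted[key] = state_dict[key] - 1.0
    return converted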