Commit 4e891fe9 authored by Jared Casper

Merge branch 'next-best-lm/merge-layernorm-1p-main' into 'main'

Add layernorm1p: store LayerNorm weights centered around zero (effective scale = weight + 1) for improved numerical stability.

See merge request ADLR/megatron-lm!557
parents 7bd25e26 33a58153
@@ -534,6 +534,9 @@ def _add_network_size_args(parser):
                        'This is added for computational efficieny reasons.')
     group.add_argument('--layernorm-epsilon', type=float, default=1e-5,
                        help='Layer norm epsilon.')
+    group.add_argument('--apply-layernorm-1p', action='store_true',
+                       help='Adjust LayerNorm weights such that they are centered '
+                       'around zero. This improves numerical stability.')
     group.add_argument('--apply-residual-connection-post-layernorm',
                        action='store_true',
                        help='If set, use original BERT residula connection '
...
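For context (not part of the diff): `--apply-layernorm-1p` changes how the LayerNorm scale is parameterized. The learnable weight is stored centered around zero, and the effective scale applied during normalization is the stored weight plus one. A minimal standalone sketch of that relationship, using a hypothetical parser rather than Megatron's `_add_network_size_args`:

```python
import argparse

import torch

# Hypothetical standalone parser; in Megatron-LM the flag is registered in
# _add_network_size_args(), as shown in the hunk above.
parser = argparse.ArgumentParser()
parser.add_argument('--apply-layernorm-1p', action='store_true',
                    help='Adjust LayerNorm weights such that they are centered '
                         'around zero. This improves numerical stability.')
args = parser.parse_args(['--apply-layernorm-1p'])

# Under the 1p parametrization the stored weight starts at zero, and the
# effective scale applied to the normalized activations is (stored weight + 1).
hidden_size = 4
stored_weight = torch.zeros(hidden_size)   # what init.zeros_(self.weight) produces
effective_weight = stored_weight + 1 if args.apply_layernorm_1p else stored_weight
print(effective_weight)                    # tensor([1., 1., 1., 1.])
```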
@@ -58,9 +58,12 @@ class MixedFusedLayerNorm(torch.nn.Module):

     def __init__(self, normalized_shape, eps=1e-5,
                  no_persist_layer_norm=True,
-                 sequence_parallel=False):
+                 sequence_parallel=False,
+                 apply_layernorm_1p=False):
         super(MixedFusedLayerNorm, self).__init__()

+        self.apply_layernorm_1p = apply_layernorm_1p
+
         global fused_mix_prec_layer_norm_cuda
         fused_mix_prec_layer_norm_cuda = importlib.import_module(
             "fused_mix_prec_layer_norm_cuda")
@@ -92,18 +95,21 @@ class MixedFusedLayerNorm(torch.nn.Module):

     def reset_parameters(self):

-        init.ones_(self.weight)
-        init.zeros_(self.bias)
+        if self.apply_layernorm_1p:
+            init.zeros_(self.weight)
+            init.zeros_(self.bias)
+        else:
+            init.ones_(self.weight)
+            init.zeros_(self.bias)

     def forward(self, input):

+        weight = self.weight + 1 if self.apply_layernorm_1p else self.weight
         if self.no_persist_layer_norm:
-            return FusedLayerNormAffineFunction.apply(
-                input, self.weight, self.bias, self.normalized_shape, self.eps)
+            return FusedLayerNormAffineFunction.apply(input, weight, self.bias, self.normalized_shape, self.eps)
         else:
-            output = FastLayerNormFN.apply(
-                input, self.weight, self.bias, self.eps)
+            output = FastLayerNormFN.apply(input, weight, self.bias, self.eps)

             # Apex's fast layer norm function outputs a 'view' tensor (i.e., has
             # a populated '_base' field). This will result in schedule.py's
...
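The module above relies on Apex's fused kernels (`FusedLayerNormAffineFunction`, `FastLayerNormFN`). As an illustration only, not the fused implementation in this MR, a plain-PyTorch sketch of the same re-parameterization (with `LayerNorm1PSketch` a hypothetical name) could look like this:

```python
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import init


class LayerNorm1PSketch(nn.Module):
    """Illustrative (non-fused) version of the apply_layernorm_1p behaviour."""

    def __init__(self, normalized_shape, eps=1e-5, apply_layernorm_1p=False):
        super().__init__()
        if isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = tuple(normalized_shape)
        self.eps = eps
        self.apply_layernorm_1p = apply_layernorm_1p
        self.weight = nn.Parameter(torch.empty(self.normalized_shape))
        self.bias = nn.Parameter(torch.empty(self.normalized_shape))
        self.reset_parameters()

    def reset_parameters(self):
        # 1p: the weight is stored centered around zero; otherwise around one.
        if self.apply_layernorm_1p:
            init.zeros_(self.weight)
        else:
            init.ones_(self.weight)
        init.zeros_(self.bias)

    def forward(self, input):
        # Effective scale is (stored weight + 1) under the 1p parametrization.
        weight = self.weight + 1 if self.apply_layernorm_1p else self.weight
        return F.layer_norm(input, self.normalized_shape, weight, self.bias, self.eps)
```

The stored parameter then stays near zero during training, which is what the new help text's "centered around zero ... improves numerical stability" refers to, while the function computed is unchanged.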
@@ -8,7 +8,7 @@ from megatron import get_args
 from megatron.core import tensor_parallel
 from megatron.model.enums import AttnMaskType
 from megatron.model.language_model import parallel_lm_logits, get_language_model
-from megatron.model.transformer import LayerNorm
+from megatron.model import LayerNorm
 from megatron.model.utils import (
     openai_gelu,
     get_linear_layer,
...
@@ -10,8 +10,8 @@ from megatron import get_timers, get_args, core, get_num_microbatches
 from .module import MegatronModule
 from megatron.core import mpu, tensor_parallel
 from megatron.core.enums import ModelType
-from megatron.model.enums import AttnMaskType, LayerType, AttnType
 from megatron.model import LayerNorm
+from megatron.model.enums import AttnMaskType, LayerType, AttnType
 from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.fused_bias_gelu import bias_gelu_impl
 from megatron.model.rotary_pos_embedding import apply_rotary_pos_emb
@@ -712,7 +712,8 @@ class ParallelTransformerLayer(MegatronModule):
             args.hidden_size,
             eps=args.layernorm_epsilon,
             no_persist_layer_norm=args.no_persist_layer_norm,
-            sequence_parallel=args.sequence_parallel)
+            sequence_parallel=args.sequence_parallel,
+            apply_layernorm_1p=args.apply_layernorm_1p)

         # Self attention.
         self.self_attention = ParallelAttention(
@@ -730,7 +731,8 @@ class ParallelTransformerLayer(MegatronModule):
             args.hidden_size,
             eps=args.layernorm_epsilon,
             no_persist_layer_norm=args.no_persist_layer_norm,
-            sequence_parallel=args.sequence_parallel)
+            sequence_parallel=args.sequence_parallel,
+            apply_layernorm_1p=args.apply_layernorm_1p)

         if self.layer_type == LayerType.decoder:
             self.inter_attention = ParallelAttention(
@@ -743,7 +745,8 @@ class ParallelTransformerLayer(MegatronModule):
             args.hidden_size,
             eps=args.layernorm_epsilon,
             no_persist_layer_norm=args.no_persist_layer_norm,
-            sequence_parallel=args.sequence_parallel)
+            sequence_parallel=args.sequence_parallel,
+            apply_layernorm_1p=args.apply_layernorm_1p)

         # MLP
         if args.num_experts is not None:
@@ -1108,7 +1111,8 @@ class ParallelTransformer(MegatronModule):
             args.hidden_size,
             eps=args.layernorm_epsilon,
             no_persist_layer_norm=args.no_persist_layer_norm,
-            sequence_parallel=args.sequence_parallel)
+            sequence_parallel=args.sequence_parallel,
+            apply_layernorm_1p=args.apply_layernorm_1p)

     def _get_layer(self, layer_number):
         return self.layers[layer_number]
...
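As a quick sanity check of the re-parameterization (using the `LayerNorm1PSketch` from the earlier sketch, not the fused module in this MR): a freshly initialized 1p layer norm computes the same function as a standard `torch.nn.LayerNorm`, since a zero-initialized stored weight plus one equals the usual ones-initialized scale.

```python
import torch
from torch import nn

# Assumes LayerNorm1PSketch from the earlier sketch is in scope.
torch.manual_seed(0)
x = torch.randn(2, 8, 16)

ln_1p = LayerNorm1PSketch(16, eps=1e-5, apply_layernorm_1p=True)
ln_ref = nn.LayerNorm(16, eps=1e-5)

# At initialization both parametrizations describe the identical function.
assert torch.allclose(ln_1p(x), ln_ref(x), atol=1e-6)
print("1p LayerNorm matches standard LayerNorm at initialization")
```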