"docs/vscode:/vscode.git/clone" did not exist on "fad855cdb4910ddc44730de0a2bb39a3e9f6c4a6"
Commit 0760822b authored by Mostofa Patwary

addressing the comments

parent bf1da27e
@@ -10,7 +10,6 @@ from torch.nn.parameter import Parameter
 from torch.nn import init
 import importlib
-from megatron import get_args
 from megatron.core.utils import make_viewless_tensor

 try:
@@ -59,9 +58,17 @@ class MixedFusedLayerNorm(torch.nn.Module):
     def __init__(self, normalized_shape, eps=1e-5,
                  no_persist_layer_norm=True,
-                 sequence_parallel=False):
+                 sequence_parallel=False,
+                 apply_layernorm_1p=False):
         super(MixedFusedLayerNorm, self).__init__()
+        self.apply_layernorm_1p = False
+        if apply_layernorm_1p:
+            self.weight_adjustment = 1
+            self.apply_layernorm_1p = True
+        else:
+            self.weight_adjustment = 0

         global fused_mix_prec_layer_norm_cuda
         fused_mix_prec_layer_norm_cuda = importlib.import_module(
             "fused_mix_prec_layer_norm_cuda")
@@ -89,23 +96,23 @@ class MixedFusedLayerNorm(torch.nn.Module):
         # set sequence parallelism flag on weight and bias parameters
         setattr(self.weight, 'sequence_parallel', self.sequence_parallel)
         setattr(self.bias, 'sequence_parallel', self.sequence_parallel)
-        args = get_args()
-        self.weight_adjustment = 0
-        if args.apply_layernorm_1p:
-            self.weight_adjustment = 1

     def reset_parameters(self):
-        init.ones_(self.weight)
-        init.zeros_(self.bias)
+        if self.apply_layernorm_1p:
+            init.zeros_(self.weight)
+            init.zeros_(self.bias)
+        else:
+            init.ones_(self.weight)
+            init.zeros_(self.bias)

     def forward(self, input):
         if self.no_persist_layer_norm:
             return FusedLayerNormAffineFunction.apply(
-                input, self.weight + self.weight_adjustment, self.bias, self.normalized_shape, self.eps)
+                input, self.weight + self.weight_adjustment, \
+                self.bias, self.normalized_shape, self.eps)
         else:
             output = FastLayerNormFN.apply(
                 input, self.weight + self.weight_adjustment, self.bias, self.eps)
......
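For context on the MixedFusedLayerNorm change above: the "layernorm 1p" scheme stores the learnable scale centered at zero (zero-initialized in reset_parameters) and folds the +1 back in at forward time via weight_adjustment, so training still starts from an effective scale of 1 while the stored parameter stays centered at zero. Below is a minimal plain-PyTorch sketch of that parameterization; it does not use the apex fused kernels from this diff, and the class name and helper names are illustrative only, not code from this repository.

```python
import torch
from torch import nn
from torch.nn import functional as F

class LayerNorm1PSketch(nn.Module):
    """Illustrative layer norm with the '1p' weight parameterization:
    the stored scale starts at 0 and +1 is added in forward(), so the
    effective scale at initialization is exactly 1."""

    def __init__(self, normalized_shape, eps=1e-5, apply_layernorm_1p=False):
        super().__init__()
        if isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = tuple(normalized_shape)
        self.eps = eps
        self.apply_layernorm_1p = apply_layernorm_1p
        # Mirrors weight_adjustment in the diff: +1 only in 1p mode.
        self.weight_adjustment = 1.0 if apply_layernorm_1p else 0.0
        self.weight = nn.Parameter(torch.empty(self.normalized_shape))
        self.bias = nn.Parameter(torch.empty(self.normalized_shape))
        self.reset_parameters()

    def reset_parameters(self):
        # 1p mode: zero-initialized scale (0 + 1 = 1 effective).
        # Standard mode: ones, matching nn.LayerNorm.
        if self.apply_layernorm_1p:
            nn.init.zeros_(self.weight)
        else:
            nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        return F.layer_norm(x, self.normalized_shape,
                            self.weight + self.weight_adjustment,
                            self.bias, self.eps)

# At initialization both parameterizations give identical outputs.
x = torch.randn(4, 16)
assert torch.allclose(nn.LayerNorm(16)(x),
                      LayerNorm1PSketch(16, apply_layernorm_1p=True)(x),
                      atol=1e-6)
```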
@@ -10,11 +10,11 @@ from megatron import get_timers, get_args, core, get_num_microbatches
 from .module import MegatronModule
 from megatron.core import mpu, tensor_parallel
 from megatron.core.enums import ModelType
+from megatron.model import LayerNorm
 from megatron.model.enums import AttnMaskType, LayerType, AttnType
 from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.fused_bias_gelu import bias_gelu_impl
 from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
-from megatron.model import LayerNorm

 try:
     from einops import rearrange
@@ -635,8 +635,10 @@ class ParallelTransformerLayer(MegatronModule):
         self.bf16 = args.bf16
         self.fp32_residual_connection = args.fp32_residual_connection

-        #if args.apply_layernorm_1p:
-        #    from megatron.model import LayerNorm1P as LayerNorm
+        apply_layernorm_1p = False
+        if args.apply_layernorm_1p:
+            apply_layernorm_1p = True
+        #from megatron.model import LayerNorm1P as LayerNorm
         #else:
         #    from megatron.model import LayerNorm
@@ -645,7 +647,8 @@ class ParallelTransformerLayer(MegatronModule):
             args.hidden_size,
             eps=args.layernorm_epsilon,
             no_persist_layer_norm=args.no_persist_layer_norm,
-            sequence_parallel=args.sequence_parallel)
+            sequence_parallel=args.sequence_parallel,
+            apply_layernorm_1p=apply_layernorm_1p)

         # Self attention.
         self.self_attention = ParallelAttention(
@@ -663,7 +666,8 @@ class ParallelTransformerLayer(MegatronModule):
             args.hidden_size,
             eps=args.layernorm_epsilon,
             no_persist_layer_norm=args.no_persist_layer_norm,
-            sequence_parallel=args.sequence_parallel)
+            sequence_parallel=args.sequence_parallel,
+            apply_layernorm_1p=apply_layernorm_1p)

         if self.layer_type == LayerType.decoder:
             self.inter_attention = ParallelAttention(
@@ -676,7 +680,8 @@ class ParallelTransformerLayer(MegatronModule):
                 args.hidden_size,
                 eps=args.layernorm_epsilon,
                 no_persist_layer_norm=args.no_persist_layer_norm,
-                sequence_parallel=args.sequence_parallel)
+                sequence_parallel=args.sequence_parallel,
+                apply_layernorm_1p=apply_layernorm_1p)

         # MLP
         if args.num_experts is not None:
@@ -1025,8 +1030,10 @@ class ParallelTransformer(MegatronModule):
         self.layers = torch.nn.ModuleList(
             [build_layer(i + 1 + offset) for i in range(self.num_layers)])

-        #if args.apply_layernorm_1p:
-        #    from megatron.model import LayerNorm1P as LayerNorm
+        apply_layernorm_1p = False
+        if args.apply_layernorm_1p:
+            apply_layernorm_1p = True
+        #from megatron.model import LayerNorm1P as LayerNorm
         #else:
         #    from megatron.model import LayerNorm
@@ -1036,7 +1043,8 @@ class ParallelTransformer(MegatronModule):
                 args.hidden_size,
                 eps=args.layernorm_epsilon,
                 no_persist_layer_norm=args.no_persist_layer_norm,
-                sequence_parallel=args.sequence_parallel)
+                sequence_parallel=args.sequence_parallel,
+                apply_layernorm_1p=apply_layernorm_1p)

     def _get_layer(self, layer_number):
         return self.layers[layer_number]
......
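The hunks in the second file all follow one pattern: read args.apply_layernorm_1p once in the owning module (ParallelTransformerLayer / ParallelTransformer) and pass it down as a constructor argument, instead of having the layer norm call get_args() itself. A reduced sketch of that wiring is below, reusing the illustrative LayerNorm1PSketch class from the earlier block; TinyArgs and TinyTransformerLayer are made-up names for this example, not Megatron classes.

```python
from dataclasses import dataclass
import torch

@dataclass
class TinyArgs:
    # Stand-in for Megatron's parsed args; only the fields used here.
    hidden_size: int = 16
    layernorm_epsilon: float = 1e-5
    apply_layernorm_1p: bool = False

class TinyTransformerLayer(torch.nn.Module):
    """Reduced illustration of the wiring in this diff: the flag is read
    from args once and forwarded to every layer norm the module owns."""

    def __init__(self, args: TinyArgs):
        super().__init__()
        apply_layernorm_1p = bool(args.apply_layernorm_1p)
        self.input_layernorm = LayerNorm1PSketch(
            args.hidden_size,
            eps=args.layernorm_epsilon,
            apply_layernorm_1p=apply_layernorm_1p)
        self.post_attention_layernorm = LayerNorm1PSketch(
            args.hidden_size,
            eps=args.layernorm_epsilon,
            apply_layernorm_1p=apply_layernorm_1p)

layer = TinyTransformerLayer(TinyArgs(apply_layernorm_1p=True))
print(layer.input_layernorm.weight.sum())  # 0. under the 1p scheme (zero-initialized scale)
```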