Commit bf1da27e authored by Mostofa Patwary

addressing comments

parent 5e079c87
megatron/model/__init__.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm
-from .fused_layer_norm import MixedFusedLayerNorm1P as LayerNorm1P
+#from .fused_layer_norm import MixedFusedLayerNorm1P as LayerNorm1P
 from .distributed import DistributedDataParallel
 from .bert_model import BertModel
...
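The separate LayerNorm1P export is dropped because, after this change, the 1p behaviour is a runtime property of MixedFusedLayerNorm itself (selected by --apply-layernorm-1p via get_args(), as the fused_layer_norm.py hunks below show) rather than a separate subclass picked at the call site. A minimal standalone sketch of that pattern, using plain PyTorch and a stand-in for Megatron's global args; all names here are illustrative, not the Megatron implementation:

import torch
import torch.nn.functional as F
from types import SimpleNamespace

# Stand-in for megatron.get_args(); in Megatron the flag comes from --apply-layernorm-1p.
_GLOBAL_ARGS = SimpleNamespace(apply_layernorm_1p=True)

def get_args():
    return _GLOBAL_ARGS

class LayerNorm(torch.nn.LayerNorm):
    """Single exported class; the 1p flag is read in __init__, so callers never
    have to choose between two layer-norm classes."""

    def __init__(self, hidden_size, eps=1e-5):
        super().__init__(hidden_size, eps=eps)
        args = get_args()
        self.weight_adjustment = 1 if args.apply_layernorm_1p else 0
        if args.apply_layernorm_1p:
            # Store the scale centred at zero; the effective scale is weight + 1.
            torch.nn.init.zeros_(self.weight)

    def forward(self, x):
        return F.layer_norm(x, self.normalized_shape,
                            self.weight + self.weight_adjustment,
                            self.bias, self.eps)

print(LayerNorm(8)(torch.randn(2, 8)).shape)  # torch.Size([2, 8])

Reading the flag inside the class keeps its effect co-located with the parameterization it modifies, instead of scattering conditional imports across call sites.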
megatron/model/fused_layer_norm.py
@@ -10,6 +10,7 @@ from torch.nn.parameter import Parameter
 from torch.nn import init
 import importlib
+from megatron import get_args
 from megatron.core.utils import make_viewless_tensor
 try:
@@ -89,6 +90,10 @@ class MixedFusedLayerNorm(torch.nn.Module):
         setattr(self.weight, 'sequence_parallel', self.sequence_parallel)
         setattr(self.bias, 'sequence_parallel', self.sequence_parallel)
+        args = get_args()
+        self.weight_adjustment = 0
+        if args.apply_layernorm_1p:
+            self.weight_adjustment = 1
     def reset_parameters(self):
@@ -100,10 +105,10 @@ class MixedFusedLayerNorm(torch.nn.Module):
         if self.no_persist_layer_norm:
             return FusedLayerNormAffineFunction.apply(
-                input, self.weight, self.bias, self.normalized_shape, self.eps)
+                input, self.weight + self.weight_adjustment, self.bias, self.normalized_shape, self.eps)
         else:
             output = FastLayerNormFN.apply(
-                input, self.weight, self.bias, self.eps)
+                input, self.weight + self.weight_adjustment, self.bias, self.eps)
             # Apex's fast layer norm function outputs a 'view' tensor (i.e., has
             # a populated '_base' field). This will result in schedule.py's
@@ -117,26 +122,26 @@ class MixedFusedLayerNorm(torch.nn.Module):
-class MixedFusedLayerNorm1P(MixedFusedLayerNorm):
-    def reset_parameters(self):
-        init.zeros_(self.weight)
-        init.zeros_(self.bias)
-
-    def forward(self, input):
-
-        if self.no_persist_layer_norm:
-            return FusedLayerNormAffineFunction.apply(
-                input, self.weight + 1, self.bias, self.normalized_shape, self.eps)
-        else:
-            output = FastLayerNormFN.apply(
-                input, self.weight + 1, self.bias, self.eps)
-
-            # Apex's fast layer norm function outputs a 'view' tensor (i.e., has
-            # a populated '_base' field). This will result in schedule.py's
-            # deallocate_output_tensor() throwing an error, so a viewless tensor is
-            # created to prevent this.
-            output = make_viewless_tensor(inp = output,
-                                          requires_grad = input.requires_grad,
-                                          keep_graph = True)
-
-            return output
+#class MixedFusedLayerNorm1P(MixedFusedLayerNorm):
+#    def reset_parameters(self):
+#        init.zeros_(self.weight)
+#        init.zeros_(self.bias)
+#
+#    def forward(self, input):
+#
+#        if self.no_persist_layer_norm:
+#            return FusedLayerNormAffineFunction.apply(
+#                input, self.weight + 1, self.bias, self.normalized_shape, self.eps)
+#        else:
+#            output = FastLayerNormFN.apply(
+#                input, self.weight + 1, self.bias, self.eps)
+#
+#            # Apex's fast layer norm function outputs a 'view' tensor (i.e., has
+#            # a populated '_base' field). This will result in schedule.py's
+#            # deallocate_output_tensor() throwing an error, so a viewless tensor is
+#            # created to prevent this.
+#            output = make_viewless_tensor(inp = output,
+#                                          requires_grad = input.requires_grad,
+#                                          keep_graph = True)
+#
+#            return output
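As a quick sanity check on the weight + self.weight_adjustment formulation above, torch.nn.functional.layer_norm is used below as a stand-in for the Apex kernels (the fused functions are assumed to apply the same affine transform). With the flag off the adjustment is 0 and behaviour is unchanged; with the flag on, a zero-initialized weight plus an adjustment of 1 reproduces a standard unit-scale layer norm, and weight decay on the stored parameter now pulls the effective scale towards 1 rather than towards 0, which is the point of the 1p parameterization:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden = 8
eps = 1e-5
x = torch.randn(4, hidden)

# Ordinary parameterization (weight_adjustment == 0): weight initialized to ones.
weight = torch.ones(hidden)
bias = torch.zeros(hidden)
y_ref = F.layer_norm(x, (hidden,), weight, bias, eps)

# 1p parameterization (weight_adjustment == 1): weight stored centred at zero.
weight_1p = torch.zeros(hidden)
weight_adjustment = 1
y_1p = F.layer_norm(x, (hidden,), weight_1p + weight_adjustment, bias, eps)

assert torch.allclose(y_ref, y_1p)  # identical outputs at initialization

# Weight decay acts on the *stored* parameter, so under the 1p parameterization
# it shrinks the effective scale (weight_1p + 1) towards 1 instead of towards 0.
decay = 0.9
print((decay * weight)[0].item())         # 0.9 -> ordinary scale drifts towards 0
print((decay * weight_1p + 1)[0].item())  # 1.0 -> 1p effective scale stays near 1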
megatron/model/transformer.py
@@ -14,6 +14,7 @@ from megatron.model.enums import AttnMaskType, LayerType, AttnType
 from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.fused_bias_gelu import bias_gelu_impl
 from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
+from megatron.model import LayerNorm
 try:
     from einops import rearrange
@@ -634,10 +635,10 @@ class ParallelTransformerLayer(MegatronModule):
         self.bf16 = args.bf16
         self.fp32_residual_connection = args.fp32_residual_connection
-        if args.apply_layernorm_1p:
-            from megatron.model import LayerNorm1P as LayerNorm
-        else:
-            from megatron.model import LayerNorm
+        #if args.apply_layernorm_1p:
+        #    from megatron.model import LayerNorm1P as LayerNorm
+        #else:
+        #    from megatron.model import LayerNorm
         # Layernorm on the input data.
         self.input_layernorm = LayerNorm(
@@ -1024,10 +1025,10 @@ class ParallelTransformer(MegatronModule):
         self.layers = torch.nn.ModuleList(
             [build_layer(i + 1 + offset) for i in range(self.num_layers)])
-        if args.apply_layernorm_1p:
-            from megatron.model import LayerNorm1P as LayerNorm
-        else:
-            from megatron.model import LayerNorm
+        #if args.apply_layernorm_1p:
+        #    from megatron.model import LayerNorm1P as LayerNorm
+        #else:
+        #    from megatron.model import LayerNorm
         if self.post_process and self.post_layer_norm:
             # Final layer norm before output.
...
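With the flag handled inside MixedFusedLayerNorm, the call sites in transformer.py no longer need to pick a layer-norm class conditionally, which is why those branches are commented out above. Illustrative usage after this change; it assumes Megatron-LM is installed and its global args have already been initialized, and the hidden_size / layernorm_epsilon arguments mirror what ParallelTransformerLayer normally passes (exact keyword arguments may differ):

from megatron import get_args
from megatron.model import LayerNorm  # always MixedFusedLayerNorm after this commit

args = get_args()

# One code path regardless of args.apply_layernorm_1p; the flag is consumed
# inside MixedFusedLayerNorm.__init__ via get_args(), as shown above.
input_layernorm = LayerNorm(args.hidden_size, eps=args.layernorm_epsilon)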