Commit 93bed794 authored by Jared Casper

A bit of cleanup.

parent 5ed304e4
@@ -62,12 +62,7 @@ class MixedFusedLayerNorm(torch.nn.Module):
                  apply_layernorm_1p=False):
         super(MixedFusedLayerNorm, self).__init__()
 
-        self.apply_layernorm_1p = False
-        if apply_layernorm_1p:
-            self.weight_adjustment = 1
-            self.apply_layernorm_1p = True
-        else:
-            self.weight_adjustment = 0
+        self.apply_layernorm_1p = apply_layernorm_1p
 
         global fused_mix_prec_layer_norm_cuda
         fused_mix_prec_layer_norm_cuda = importlib.import_module(
@@ -109,13 +104,12 @@ class MixedFusedLayerNorm(torch.nn.Module):
     def forward(self, input):
 
+        weight = self.weight + 1 if self.apply_layernorm_1p else self.weight
+
         if self.no_persist_layer_norm:
-            return FusedLayerNormAffineFunction.apply(
-                input, self.weight + self.weight_adjustment, \
-                self.bias, self.normalized_shape, self.eps)
+            return FusedLayerNormAffineFunction.apply(input, weight, self.bias, self.normalized_shape, self.eps)
         else:
-            output = FastLayerNormFN.apply(
-                input, self.weight + self.weight_adjustment, self.bias, self.eps)
+            output = FastLayerNormFN.apply(input, weight, self.bias, self.eps)
 
             # Apex's fast layer norm function outputs a 'view' tensor (i.e., has
             # a populated '_base' field). This will result in schedule.py's
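The hunks above drop the stored weight_adjustment attribute and instead apply the +1 offset to the weight at forward time whenever apply_layernorm_1p is set; the computed result is unchanged, only the bookkeeping is simplified. A minimal sketch of that pattern in plain PyTorch (no Apex fused kernels; the class name SimpleLayerNorm1p and the self-check at the bottom are illustrative, not from this repository):

import torch
import torch.nn.functional as F

class SimpleLayerNorm1p(torch.nn.Module):
    """Plain-PyTorch stand-in for a layer norm with the '1p' option."""

    def __init__(self, hidden_size, eps=1e-5, apply_layernorm_1p=False):
        super().__init__()
        self.eps = eps
        self.apply_layernorm_1p = apply_layernorm_1p
        # With the 1p variant the stored weight is kept centered around zero,
        # so it is initialized to zeros and the +1 is added back in forward.
        init = torch.zeros(hidden_size) if apply_layernorm_1p else torch.ones(hidden_size)
        self.weight = torch.nn.Parameter(init)
        self.bias = torch.nn.Parameter(torch.zeros(hidden_size))

    def forward(self, input):
        # Same one-liner as the diff: add 1 only when the 1p variant is enabled.
        weight = self.weight + 1 if self.apply_layernorm_1p else self.weight
        return F.layer_norm(input, (input.shape[-1],), weight, self.bias, self.eps)

# Both variants start out numerically identical at initialization.
x = torch.randn(2, 4, 8)
assert torch.allclose(SimpleLayerNorm1p(8)(x), SimpleLayerNorm1p(8, apply_layernorm_1p=True)(x))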
@@ -635,17 +635,13 @@ class ParallelTransformerLayer(MegatronModule):
         self.bf16 = args.bf16
         self.fp32_residual_connection = args.fp32_residual_connection
 
-        apply_layernorm_1p = False
-        if args.apply_layernorm_1p:
-            apply_layernorm_1p = True
-
         # Layernorm on the input data.
         self.input_layernorm = LayerNorm(
             args.hidden_size,
             eps=args.layernorm_epsilon,
             no_persist_layer_norm=args.no_persist_layer_norm,
             sequence_parallel=args.sequence_parallel,
-            apply_layernorm_1p=apply_layernorm_1p)
+            apply_layernorm_1p=args.apply_layernorm_1p)
 
         # Self attention.
         self.self_attention = ParallelAttention(
@@ -664,7 +660,7 @@ class ParallelTransformerLayer(MegatronModule):
             eps=args.layernorm_epsilon,
             no_persist_layer_norm=args.no_persist_layer_norm,
             sequence_parallel=args.sequence_parallel,
-            apply_layernorm_1p=apply_layernorm_1p)
+            apply_layernorm_1p=args.apply_layernorm_1p)
 
         if self.layer_type == LayerType.decoder:
             self.inter_attention = ParallelAttention(
@@ -678,7 +674,7 @@ class ParallelTransformerLayer(MegatronModule):
             eps=args.layernorm_epsilon,
             no_persist_layer_norm=args.no_persist_layer_norm,
             sequence_parallel=args.sequence_parallel,
-            apply_layernorm_1p=apply_layernorm_1p)
+            apply_layernorm_1p=args.apply_layernorm_1p)
 
         # MLP
         if args.num_experts is not None:
@@ -1027,10 +1023,6 @@ class ParallelTransformer(MegatronModule):
         self.layers = torch.nn.ModuleList(
             [build_layer(i + 1 + offset) for i in range(self.num_layers)])
 
-        apply_layernorm_1p = False
-        if args.apply_layernorm_1p:
-            apply_layernorm_1p = True
-
         if self.post_process and self.post_layer_norm:
             # Final layer norm before output.
             self.final_layernorm = LayerNorm(
@@ -1038,7 +1030,7 @@ class ParallelTransformer(MegatronModule):
                 eps=args.layernorm_epsilon,
                 no_persist_layer_norm=args.no_persist_layer_norm,
                 sequence_parallel=args.sequence_parallel,
-                apply_layernorm_1p=apply_layernorm_1p)
+                apply_layernorm_1p=args.apply_layernorm_1p)
 
     def _get_layer(self, layer_number):
         return self.layers[layer_number]
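The transformer-side hunks are the same kind of cleanup: the intermediate apply_layernorm_1p local is gone and args.apply_layernorm_1p is forwarded straight to the LayerNorm constructor. A short usage sketch with a stand-in args namespace and the SimpleLayerNorm1p class from the sketch above (both illustrative, not the repository's real LayerNorm or argument parser):

from types import SimpleNamespace

# Hypothetical stand-in for the parsed training arguments.
args = SimpleNamespace(hidden_size=8, layernorm_epsilon=1e-5, apply_layernorm_1p=True)

# Before: a redundant local re-derived the boolean.
#     apply_layernorm_1p = False
#     if args.apply_layernorm_1p:
#         apply_layernorm_1p = True
# After: the flag is passed through directly.
input_layernorm = SimpleLayerNorm1p(
    args.hidden_size,
    eps=args.layernorm_epsilon,
    apply_layernorm_1p=args.apply_layernorm_1p)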