Commit 93bed794 authored by Jared Casper

A bit of cleanup.

parent 5ed304e4
@@ -62,12 +62,7 @@ class MixedFusedLayerNorm(torch.nn.Module):
                  apply_layernorm_1p=False):
         super(MixedFusedLayerNorm, self).__init__()
-        self.apply_layernorm_1p = False
-        if apply_layernorm_1p:
-            self.weight_adjustment = 1
-            self.apply_layernorm_1p = True
-        else:
-            self.weight_adjustment = 0
+        self.apply_layernorm_1p = apply_layernorm_1p
         global fused_mix_prec_layer_norm_cuda
         fused_mix_prec_layer_norm_cuda = importlib.import_module(
@@ -109,13 +104,12 @@ class MixedFusedLayerNorm(torch.nn.Module):
     def forward(self, input):
+        weight = self.weight + 1 if self.apply_layernorm_1p else self.weight
         if self.no_persist_layer_norm:
-            return FusedLayerNormAffineFunction.apply(
-                input, self.weight + self.weight_adjustment, \
-                self.bias, self.normalized_shape, self.eps)
+            return FusedLayerNormAffineFunction.apply(input, weight, self.bias, self.normalized_shape, self.eps)
         else:
-            output = FastLayerNormFN.apply(
-                input, self.weight + self.weight_adjustment, self.bias, self.eps)
+            output = FastLayerNormFN.apply(input, weight, self.bias, self.eps)
             # Apex's fast layer norm function outputs a 'view' tensor (i.e., has
             # a populated '_base' field). This will result in schedule.py's
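The substance of the cleanup: instead of precomputing a separate weight_adjustment in __init__, the flag is stored as-is and the + 1 shift is applied once at the top of forward(). Below is a minimal, hedged sketch of that "1p" pattern, assuming plain torch.nn.functional.layer_norm in place of the fused Apex kernels; the class name LayerNorm1pSketch, the zero initialization of the 1p weight, and the defaults are illustrative assumptions, not taken from this commit.

import torch
import torch.nn.functional as F

class LayerNorm1pSketch(torch.nn.Module):
    """Illustrative sketch (not Megatron's class): in the '1p' variant the
    learnable weight is kept zero-centered and shifted by +1 at forward time."""

    def __init__(self, hidden_size, eps=1e-5, apply_layernorm_1p=False):
        super().__init__()
        self.eps = eps
        self.apply_layernorm_1p = apply_layernorm_1p
        # Assumption: zero-init for the 1p weight so that weight + 1 starts at
        # the usual all-ones scale; standard ones-init otherwise.
        init = torch.zeros(hidden_size) if apply_layernorm_1p else torch.ones(hidden_size)
        self.weight = torch.nn.Parameter(init)
        self.bias = torch.nn.Parameter(torch.zeros(hidden_size))

    def forward(self, x):
        # Same shape as the new forward(): compute the effective weight once,
        # then make a single normalization call with it.
        weight = self.weight + 1 if self.apply_layernorm_1p else self.weight
        return F.layer_norm(x, (x.shape[-1],), weight, self.bias, self.eps)

x = torch.randn(2, 4, 16)
print(LayerNorm1pSketch(16, apply_layernorm_1p=True)(x).shape)  # torch.Size([2, 4, 16])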
@@ -635,17 +635,13 @@ class ParallelTransformerLayer(MegatronModule):
         self.bf16 = args.bf16
         self.fp32_residual_connection = args.fp32_residual_connection
-        apply_layernorm_1p = False
-        if args.apply_layernorm_1p:
-            apply_layernorm_1p = True
         # Layernorm on the input data.
         self.input_layernorm = LayerNorm(
             args.hidden_size,
             eps=args.layernorm_epsilon,
             no_persist_layer_norm=args.no_persist_layer_norm,
             sequence_parallel=args.sequence_parallel,
-            apply_layernorm_1p=apply_layernorm_1p)
+            apply_layernorm_1p=args.apply_layernorm_1p)
         # Self attention.
         self.self_attention = ParallelAttention(
@@ -664,7 +660,7 @@ class ParallelTransformerLayer(MegatronModule):
             eps=args.layernorm_epsilon,
             no_persist_layer_norm=args.no_persist_layer_norm,
             sequence_parallel=args.sequence_parallel,
-            apply_layernorm_1p=apply_layernorm_1p)
+            apply_layernorm_1p=args.apply_layernorm_1p)
         if self.layer_type == LayerType.decoder:
             self.inter_attention = ParallelAttention(
@@ -678,7 +674,7 @@ class ParallelTransformerLayer(MegatronModule):
             eps=args.layernorm_epsilon,
             no_persist_layer_norm=args.no_persist_layer_norm,
             sequence_parallel=args.sequence_parallel,
-            apply_layernorm_1p=apply_layernorm_1p)
+            apply_layernorm_1p=args.apply_layernorm_1p)
         # MLP
         if args.num_experts is not None:
@@ -1027,10 +1023,6 @@ class ParallelTransformer(MegatronModule):
         self.layers = torch.nn.ModuleList(
             [build_layer(i + 1 + offset) for i in range(self.num_layers)])
-        apply_layernorm_1p = False
-        if args.apply_layernorm_1p:
-            apply_layernorm_1p = True
         if self.post_process and self.post_layer_norm:
             # Final layer norm before output.
             self.final_layernorm = LayerNorm(
@@ -1038,7 +1030,7 @@ class ParallelTransformer(MegatronModule):
                 eps=args.layernorm_epsilon,
                 no_persist_layer_norm=args.no_persist_layer_norm,
                 sequence_parallel=args.sequence_parallel,
-                apply_layernorm_1p=apply_layernorm_1p)
+                apply_layernorm_1p=args.apply_layernorm_1p)
     def _get_layer(self, layer_number):
         return self.layers[layer_number]
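In the transformer hunks above, the change is purely a pass-through: args.apply_layernorm_1p is handed directly to each LayerNorm(...) call instead of being copied into a local flag first. A small self-contained sketch of that call pattern, using a hypothetical args namespace and a stand-in class in place of Megatron's fused LayerNorm (both are assumptions for illustration only):

from types import SimpleNamespace
import torch

# Hypothetical parsed-arguments object; the field names mirror the ones the diff reads.
args = SimpleNamespace(hidden_size=16, layernorm_epsilon=1e-5,
                       no_persist_layer_norm=True, sequence_parallel=False,
                       apply_layernorm_1p=True)

class StandInLayerNorm(torch.nn.LayerNorm):
    """Stand-in for Megatron's fused LayerNorm that accepts the same extra kwargs."""
    def __init__(self, hidden_size, eps=1e-5, no_persist_layer_norm=True,
                 sequence_parallel=False, apply_layernorm_1p=False):
        super().__init__(hidden_size, eps=eps)
        self.apply_layernorm_1p = apply_layernorm_1p

# The flag is forwarded straight from args; no intermediate local is needed.
final_layernorm = StandInLayerNorm(
    args.hidden_size,
    eps=args.layernorm_epsilon,
    no_persist_layer_norm=args.no_persist_layer_norm,
    sequence_parallel=args.sequence_parallel,
    apply_layernorm_1p=args.apply_layernorm_1p)
print(final_layernorm.apply_layernorm_1p)  # True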