Commit c5be45ee authored by Lawrence McAfee's avatar Lawrence McAfee
Browse files

Merge branch 'main' into lmcafee/distrib-opt-nodupe

parents b42c3052 15f6bb1b
......@@ -23,6 +23,8 @@ from torch.nn.parameter import Parameter
from torch.nn import init
import importlib
from megatron.mpu import make_viewless_tensor
try:
from apex.contrib.layer_norm.layer_norm import FastLayerNormFN
HAVE_PERSIST_LAYER_NORM = True
......@@ -106,6 +108,15 @@ class MixedFusedLayerNorm(torch.nn.Module):
return FusedLayerNormAffineFunction.apply(
input, self.weight, self.bias, self.normalized_shape, self.eps)
else:
return FastLayerNormFN.apply(
output = FastLayerNormFN.apply(
input, self.weight, self.bias, self.eps)
# Apex's fast layer norm function outputs a 'view' tensor (i.e., has
# a populated '_base' field). This will result in schedule.py's
# deallocate_output_tensor() throwing an error, so a viewless tensor is
# created to prevent this.
output = make_viewless_tensor(inp = output,
requires_grad = input.requires_grad,
keep_graph = True)
return output
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment