Commit 65e6bc32 authored by Lawrence McAfee

fixed.

parent 4c598f9d
@@ -23,6 +23,8 @@ from torch.nn.parameter import Parameter
 from torch.nn import init
 import importlib
 
+from megatron.mpu import make_viewless_tensor
+
 try:
     from apex.contrib.layer_norm.layer_norm import FastLayerNormFN
     HAVE_PERSIST_LAYER_NORM = True
@@ -100,35 +102,21 @@ class MixedFusedLayerNorm(torch.nn.Module):
         init.zeros_(self.bias)
 
-    # def forward(self, input):
-    #     if self.no_persist_layer_norm:
-    #         return FusedLayerNormAffineFunction.apply(
-    #             input, self.weight, self.bias, self.normalized_shape, self.eps)
-    #     else:
-    #         return FastLayerNormFN.apply(
-    #             input, self.weight, self.bias, self.eps)
     def forward(self, input):
         if self.no_persist_layer_norm:
-            result = FusedLayerNormAffineFunction.apply(
+            return FusedLayerNormAffineFunction.apply(
                 input, self.weight, self.bias, self.normalized_shape, self.eps)
         else:
-            result = FastLayerNormFN.apply(
+            output = FastLayerNormFN.apply(
                 input, self.weight, self.bias, self.eps)
-            result = make_viewless_tensor(inp = input, requires_grad = input.requires_grad, keep_grad = True)
-        # >>>
-        # if torch.distributed.get_rank() == 3:
-        #     # from lutil import pax
-        #     # pax({"result": result})
-        #     from megatron import get_args
-        #     args = get_args()
-        #     raise Exception("r %d ... hid %d, persist %d, view %d." % (
-        #         torch.distributed.get_rank(),
-        #         args.hidden_size,
-        #         not args.no_persist_layer_norm,
-        #         result._base is not None,
-        #     ))
-        # <<<
-        return result
+            # Apex's fast layer norm function outputs a 'view' tensor (i.e., has
+            # a populated '_base' field). This will result in schedule.py's
+            # deallocate_output_tensor() throwing an error, so a viewless tensor is
+            # created to prevent this.
+            output = make_viewless_tensor(inp = output,
+                                          requires_grad = input.requires_grad,
+                                          keep_graph = True)
+
+            return output
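The hunk above replaces an earlier, broken wrapping attempt (it passed `inp = input` and the misspelled keyword `keep_grad = True`) with a call that wraps the FastLayerNormFN output and passes `keep_graph = True`. For readers unfamiliar with the helper, below is a minimal sketch of how a make_viewless_tensor-style utility can drop a tensor's '_base' link while keeping its autograd history; the names and structure here are illustrative and may differ from the actual megatron.mpu implementation.

import torch

def _make_viewless(inp, requires_grad):
    # Allocate a fresh tensor object, then point its '.data' at the view's
    # storage. The result shares memory with 'inp' but has no '_base', so
    # later code can replace its '.data' without touching a view chain.
    out = torch.empty((1,), dtype=inp.dtype, device=inp.device,
                      requires_grad=requires_grad)
    out.data = inp.data
    return out

class MakeViewlessTensor(torch.autograd.Function):
    # Identity in autograd terms; only the forward output loses its '_base'.
    @staticmethod
    def forward(ctx, inp, requires_grad):
        return _make_viewless(inp, requires_grad)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None

def make_viewless_tensor(inp, requires_grad, keep_graph):
    # Nothing to do if the tensor is not a view.
    if inp._base is None:
        return inp
    # keep_graph=True preserves the autograd graph, which forward() needs so
    # that gradients still flow back through the layer norm kernel.
    if keep_graph:
        return MakeViewlessTensor.apply(inp, requires_grad)
    return _make_viewless(inp, requires_grad)

Only the FastLayerNormFN branch needs the wrapper, since (per the new comment) that kernel is the one returning a tensor with a populated '_base'.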
@@ -337,12 +337,6 @@ class TransformerLanguageModel(MegatronModule):
         else:
             self.encoder = None
 
-        # >>>
-        # if torch.distributed.get_rank() == 3:
-        #     print(self.encoder)
-        #     raise Exception("bye.")
-        # <<<
-
         # Decoder (usually set to False, True if part of an encoder-decoder
         # architecture and in decoder-only stage).
         if self.add_decoder:
@@ -651,14 +651,6 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
         if not forward_only:
             input_tensors.append(input_tensor)
             output_tensors.append(output_tensor)
-
-            # >>>
-            if output_tensor[0]._base is not None:
-                # from lutil import pax
-                # pax({
-                #     "output tensor / 0" : output_tensor[0],
-                # })
-                raise Exception(">>>>>> r %d, output / 0 == view." % torch.distributed.get_rank())
-            # <<<
         deallocate_output_tensor(output_tensor[0])
 
         # Before running 1F1B, need to receive first forward tensor.
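The removed debug block was guarding the same condition that deallocate_output_tensor() itself enforces, so it becomes redundant once the layer-norm output is made viewless. As a rough sketch of why that routine rejects views (illustrative; the exact schedules.py code may differ), the idea is to free an output's activation memory once it has been sent downstream by shrinking its '.data' to a scalar, which only makes sense when the tensor owns its storage.

import torch

def deallocate_output_tensor(out):
    # After the output has been sent to the next pipeline stage, only its
    # '.grad_fn' is still needed, so '.data' is replaced with a scalar to
    # release the activation memory early.
    if out is None:
        return
    assert isinstance(out, torch.Tensor), \
        "expected Tensor, found %s." % type(out).__name__
    # A view does not own its storage ('out._base' does), so shrinking
    # '.data' here would free nothing; such tensors are rejected outright.
    assert out._base is None, \
        "counter-productive to free a view of another tensor."
    out.data = torch.empty((1,), device=out.device, dtype=out.dtype)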