Commit 4c598f9d authored by Lawrence McAfee

debugging.

parent 53f3efc4
@@ -100,12 +100,35 @@ class MixedFusedLayerNorm(torch.nn.Module):
         init.zeros_(self.bias)
+    # def forward(self, input):
+    #     if self.no_persist_layer_norm:
+    #         return FusedLayerNormAffineFunction.apply(
+    #             input, self.weight, self.bias, self.normalized_shape, self.eps)
+    #     else:
+    #         return FastLayerNormFN.apply(
+    #             input, self.weight, self.bias, self.eps)
     def forward(self, input):
         if self.no_persist_layer_norm:
-            return FusedLayerNormAffineFunction.apply(
+            result = FusedLayerNormAffineFunction.apply(
                 input, self.weight, self.bias, self.normalized_shape, self.eps)
         else:
-            return FastLayerNormFN.apply(
+            result = FastLayerNormFN.apply(
                 input, self.weight, self.bias, self.eps)
+        result = make_viewless_tensor(inp = input, requires_grad = input.requires_grad, keep_grad = True)
+        # >>>
+        # if torch.distributed.get_rank() == 3:
+        #     # from lutil import pax
+        #     # pax({"result": result})
+        #     from megatron import get_args
+        #     args = get_args()
+        #     raise Exception("r %d ... hid %d, persist %d, view %d." % (
+        #         torch.distributed.get_rank(),
+        #         args.hidden_size,
+        #         not args.no_persist_layer_norm,
+        #         result._base is not None,
+        #     ))
+        # <<<
+        return result
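For context on the `make_viewless_tensor(...)` call added above: a tensor produced by a view op keeps a reference to its base (exposed as `tensor._base`), so its storage cannot later be released independently by `deallocate_output_tensor` (see the last hunk below). Below is a minimal sketch of the viewless-tensor idea, assuming a pass-through autograd function; the name `MakeViewlessTensorSketch` and the test tensors are illustrative, not Megatron's exact API.

```python
import torch

class MakeViewlessTensorSketch(torch.autograd.Function):
    """Return a tensor that shares storage with the input but is not
    registered as a view (its ._base is None); gradients pass straight
    through in backward."""

    @staticmethod
    def forward(ctx, inp, requires_grad):
        # Allocate a tiny placeholder, then point its .data at the input's
        # storage; the result carries no view relationship to `inp`.
        out = torch.empty((1,), dtype=inp.dtype, device=inp.device,
                          requires_grad=requires_grad)
        out.data = inp.data
        return out

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None


x = torch.randn(4, 8, requires_grad=True)
view = x.transpose(0, 1)                       # a view: view._base is set
assert view._base is not None
viewless = MakeViewlessTensorSketch.apply(view, True)
assert viewless._base is None                  # same storage, no view base
```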
@@ -337,6 +337,12 @@ class TransformerLanguageModel(MegatronModule):
         else:
             self.encoder = None
+        # >>>
+        # if torch.distributed.get_rank() == 3:
+        #     print(self.encoder)
+        #     raise Exception("bye.")
+        # <<<
         # Decoder (usually set to False, True if part of an encoder-decoder
         # architecture and in decoder-only stage).
         if self.add_decoder:
...
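The commented-out block above is the usual rank-gated debugging pattern: act only when `torch.distributed.get_rank()` matches one chosen pipeline rank. A small, generic sketch of that pattern follows; the helper name `debug_on_rank` is illustrative, not part of Megatron.

```python
import torch

def debug_on_rank(rank, obj):
    # Illustrative helper: print `obj` only on the requested rank of an
    # initialized torch.distributed process group.
    if torch.distributed.is_initialized() and torch.distributed.get_rank() == rank:
        print("[rank %d] %s" % (rank, obj))

# e.g. debug_on_rank(3, self.encoder) inside TransformerLanguageModel.__init__
```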
@@ -651,6 +651,14 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
         if not forward_only:
             input_tensors.append(input_tensor)
             output_tensors.append(output_tensor)
+            # >>>
+            if output_tensor[0]._base is not None:
+                # from lutil import pax
+                # pax({
+                #     "output tensor / 0" : output_tensor[0],
+                # })
+                raise Exception(">>>>>> r %d, output / 0 == view." % torch.distributed.get_rank())
+            # <<<
             deallocate_output_tensor(output_tensor[0])
     # Before running 1F1B, need to receive first forward tensor.
...
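The new check above refuses to hand a view to `deallocate_output_tensor`: the deallocation trick only works when the output tensor owns its storage. A sketch of that trick, assuming logic along the lines of Megatron's `deallocate_output_tensor` (the function name `deallocate_sketch` is illustrative):

```python
import torch

def deallocate_sketch(out):
    # Keep the tensor object (and its grad_fn) alive for the backward pass,
    # but release the large activation buffer by swapping in a 1-element
    # .data tensor. This is only safe when `out` is not a view: if out._base
    # were set, the base tensor would still reference the original storage
    # and nothing would actually be freed.
    assert out._base is None, "cannot safely deallocate a view's storage"
    out.data = torch.empty((1,), device=out.device, dtype=out.dtype)
```

This is why the layer-norm hunk above converts its output to a viewless tensor before it becomes a pipeline stage's output.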