Commit 806422e5 authored by Lawrence McAfee's avatar Lawrence McAfee
Browse files

partially cleaned

parent 05042081
......@@ -27,10 +27,6 @@ from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
# >>>
from megatron.mpu.random import make_viewless_tensor
# <<<
""" We use the following notation throughout this file:
h: hidden size
n: number of attention heads
......@@ -700,7 +696,10 @@ class ParallelTransformer(MegatronModule):
hidden_states = self.input_tensor
# Viewless tensor
hidden_states = make_viewless_tensor(
# >>>
assert hidden_states is not None, "rank == %d." % torch.distributed.get_rank()
# <<<
hidden_states = mpu.make_viewless_tensor(
hidden_states,
requires_grad = True,
keep_graph = True,
......
......@@ -65,6 +65,9 @@ from .random import get_cuda_rng_tracker
from .random import model_parallel_cuda_manual_seed
from .random import gather_split_1d_tensor
from .random import split_tensor_into_1d_equal_chunks
from .random import make_viewless_tensor
from .random import assert_viewless_tensor
from .random import safely_set_viewless_tensor_data
from .utils import divide
from .utils import split_tensor_along_last_dim
......@@ -98,34 +98,12 @@ def gather_split_1d_tensor(tensor):
group=get_tensor_model_parallel_group())
return gathered
# >>>
from lutil import pax # ****************
# def make_standalone_tensor(a):
# assert a._base is not None
# b = torch.empty((1,), dtype = a.dtype, device = a.device)
# b.data = a.data
# return b
# class MakeStandaloneTensor(torch.autograd.Function):
# class MakeViewlessTensor_(torch.autograd.Function):
class MakeViewlessTensor(torch.autograd.Function):
# @staticmethod
# def forward(ctx, inp):
# assert inp._base is not None
# out = torch.empty((1,), dtype = inp.dtype, device = inp.device)
# out.data = inp.data
# # pax(0, {"inp": inp, "out": out})
# return out
@staticmethod
def forward(ctx, inp, requires_grad):
return _kernel_make_viewless_tensor(inp, requires_grad)
# @staticmethod
# def forward(ctx, args):
# return [_kernel_make_viewless_tensor(*args)]
@staticmethod
def backward(ctx, grad_output):
# pax(0, {"grad_output": grad_output})
# return grad_output
return grad_output, None
def _kernel_make_viewless_tensor(inp, requires_grad):
......@@ -136,17 +114,8 @@ def _kernel_make_viewless_tensor(inp, requires_grad):
requires_grad = requires_grad,
)
out.data = inp.data
# >>>
# pax(0, {"inp": inp, "out": out})
# assert out.requires_grad
# <<<
return out
# def make_viewless_tensor(tensor):
# if tensor._base is None:
# return tensor
# else:
# return MakeViewlessTensor_.apply(tensor)
def make_viewless_tensor(inp, requires_grad, keep_graph):
# return tensor as-is, if not a 'view'
......@@ -155,36 +124,27 @@ def make_viewless_tensor(inp, requires_grad, keep_graph):
# create viewless tensor
if keep_graph:
# return MakeViewlessTensor.apply((inp, requires_grad))[0]
return MakeViewlessTensor.apply(inp, requires_grad)
else:
return _kernel_make_viewless_tensor(inp, requires_grad)
# return MakeViewlessTensor.apply((inp, requires_grad))[0]
# return MakeViewlessTensor.apply(inp, requires_grad)
# return MakeViewlessTensor.apply(inp)
# return MakeViewlessTensor.apply(inp, 7)
# return MakeViewlessTensor.apply(inp, 7)[0]
def assert_viewless_tensor(tensor, extra_msg = None):
if isinstance(tensor, list):
[ assert_viewless_tensor(t) for t in tensor ]
return
# assert isinstance(tensor, torch.Tensor), \
# "expected Tensor; found %s." % type(tensor).__name__
return tensor
if not isinstance(tensor, torch.Tensor):
return
return tensor
assert tensor._base is None, (
"Ensure tensor._base is None before setting tensor.data or storing "
"tensor to memory buffer. Otherwise, a memory leak will occur (and "
"likely accumulate over iterations). %s"
) % extra_msg
return tensor
# def set_viewless_tensor_data_attr(tensor, new_data_tensor):
def safely_set_tensor_data_attr(tensor, new_data_tensor):
def safely_set_viewless_tensor_data(tensor, new_data_tensor):
assert_viewless_tensor(tensor, extra_msg = "FYI, tensor._base has shape %s, and new_data_tensor has shape %s." % ("--" if tensor._base is None else tensor._base.shape, new_data_tensor.shape))
tensor.data = new_data_tensor
# <<<
class CudaRNGStatesTracker:
"""Tracker for the cuda RNG states.
......@@ -328,19 +288,10 @@ class CheckpointFunction(torch.autograd.Function):
# Divide hidden states across model parallel group and only keep
# the chunk corresponding to the current rank.
if distribute_checkpointed_activations:
# >>>
# raise Exception("distrib.")
# from lutil import data_leak_ctx
# with data_leak_ctx(args[0]):
# <<<
ctx.input_0_shape = args[0].data.shape
# >>>
# args[0].data = split_tensor_into_1d_equal_chunks(args[0].data,
# new_buffer=True)
safely_set_tensor_data_attr(
safely_set_viewless_tensor_data(
args[0],
split_tensor_into_1d_equal_chunks(args[0].data, new_buffer=True))
# <<<
# Store everything.
ctx.save_for_backward(*args)
......@@ -357,12 +308,15 @@ class CheckpointFunction(torch.autograd.Function):
# >>>
# inputs[0].data = gather_split_1d_tensor(inputs[0].data)
# inputs[0].data = inputs[0].data.view(ctx.input_0_shape)
safely_set_tensor_data_attr(
inputs[0],
gather_split_1d_tensor(inputs[0].data))
safely_set_tensor_data_attr(
# safely_set_tensor_data_attr(
# inputs[0],
# gather_split_1d_tensor(inputs[0].data))
# safely_set_tensor_data_attr(
# inputs[0],
# inputs[0].data.view(ctx.input_0_shape))
safely_set_viewless_tensor_data(
inputs[0],
inputs[0].data.view(ctx.input_0_shape))
gather_split_1d_tensor(inputs[0].data).view(ctx.input_0_shape))
# <<<
# Store the current states.
......
......@@ -20,9 +20,6 @@ import torch
from megatron import get_args
from megatron import mpu
# >>>
from megatron.mpu.random import make_viewless_tensor
# <<<
def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
tensor_shape,
......@@ -145,16 +142,16 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
if recv_prev:
tensor_recv_prev = mpu.gather_split_1d_tensor(
tensor_recv_prev).view(tensor_shape).requires_grad_()
tensor_recv_prev = make_viewless_tensor(tensor_recv_prev,
requires_grad = True,
keep_graph = False)
tensor_recv_prev = mpu.make_viewless_tensor(tensor_recv_prev,
requires_grad = True,
keep_graph = False)
if recv_next:
tensor_recv_next = mpu.gather_split_1d_tensor(
tensor_recv_next).view(tensor_shape).requires_grad_()
tensor_recv_next = make_viewless_tensor(tensor_recv_next,
requires_grad = True,
keep_graph = False)
tensor_recv_next = mpu.make_viewless_tensor(tensor_recv_next,
requires_grad = True,
keep_graph = False)
return tensor_recv_prev, tensor_recv_next
......
......@@ -29,7 +29,7 @@ from megatron.model import Float16Module
from megatron.model import ModelType
# >>>
from megatron.mpu.random import assert_viewless_tensor
# from megatron.mpu.random import assert_viewless_tensor
# <<<
def get_forward_backward_func():
......@@ -115,17 +115,7 @@ def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_r
unwrap_output_tensor = True
unwrapped_model.set_input_tensor(input_tensor)
# >>>
# if input_tensor[0] is not None:
# from lutil import pax, tp
# pax({"input_tensor": tp(input_tensor)})
# <<<
output_tensor, loss_func = forward_step_func(data_iterator, model)
# >>>
# if input_tensor[0] is not None:
# from lutil import pax, tp
# pax({"input_tensor": tp(input_tensor)})
# <<<
if mpu.is_pipeline_last_stage():
output_tensor = loss_func(output_tensor)
loss, loss_reduced = output_tensor
......@@ -530,7 +520,6 @@ def recv_forward(tensor_shapes, timers):
else:
input_tensors.append(p2p_communication.recv_forward(tensor_shape,
timers=timers))
assert_viewless_tensor(input_tensors[-1])
return input_tensors
......@@ -542,7 +531,6 @@ def recv_backward(tensor_shapes, timers):
else:
output_tensor_grads.append(p2p_communication.recv_backward(tensor_shape,
timers=timers))
assert_viewless_tensor(output_tensor_grads[-1])
return output_tensor_grads
......@@ -575,7 +563,6 @@ def send_forward_recv_backward(output_tensors, tensor_shapes, timers):
output_tensor_grad = p2p_communication.send_forward_recv_backward(
output_tensor, tensor_shape, timers=timers)
output_tensor_grads.append(output_tensor_grad)
assert_viewless_tensor(output_tensor_grad)
return output_tensor_grads
......@@ -590,7 +577,6 @@ def send_backward_recv_forward(input_tensor_grads, tensor_shapes, timers):
input_tensor = p2p_communication.send_backward_recv_forward(
input_tensor_grad, tensor_shape, timers=timers)
input_tensors.append(input_tensor)
assert_viewless_tensor(input_tensor)
return input_tensors
......@@ -636,33 +622,13 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
# Run warmup forward passes.
for i in range(num_warmup_microbatches):
input_tensor = recv_forward(recv_tensor_shapes, timers=timers)
# >>>
# if input_tensor[0] is not None:
# from lutil import pax
# pax({"input_tensor": input_tensor})
# <<<
output_tensor = forward_step(forward_step_func, data_iterator, model,
input_tensor, losses_reduced)
# >>>
# if True or input_tensor[0] is not None:
# from lutil import pax
# pax({"input_tensor": input_tensor})
# <<<
send_forward(output_tensor, send_tensor_shapes, timers=timers)
if not forward_only:
# >>>
# if input_tensor[0] is not None:
# from lutil import pax
# pax({"input_tensor": input_tensor})
# if output_tensor[0] is not None:
# from lutil import pax
# pax(0, {"output_tensor / 0": output_tensor[0]})
# <<<
assert_viewless_tensor(input_tensor)
assert_viewless_tensor(output_tensor)
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
input_tensors.append(mpu.assert_viewless_tensor(input_tensor))
output_tensors.append(mpu.assert_viewless_tensor(output_tensor))
free_output_tensor(output_tensor, args.deallocate_pipeline_outputs)
# Before running 1F1B, need to receive first forward tensor.
......@@ -690,10 +656,8 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
timers=timers)
# Add input_tensor and output_tensor to end of list.
assert_viewless_tensor(input_tensor)
assert_viewless_tensor(output_tensor)
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
input_tensors.append(mpu.assert_viewless_tensor(input_tensor))
output_tensors.append(mpu.assert_viewless_tensor(output_tensor))
free_output_tensor(output_tensor, args.deallocate_pipeline_outputs)
# Pop input_tensor and output_tensor from the start of the list for
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment