Commit 270d6412 authored by Lawrence McAfee

loss matches; memory savings for multi-node (tested n3, n16)

parent b6d4369b
@@ -699,15 +699,12 @@ class ParallelTransformer(MegatronModule):
             # See set_input_tensor()
             hidden_states = self.input_tensor
-            # hidden_states = make_standalone_tensor(hidden_states)
-            # hidden_states = MakeStandaloneTensor.apply(hidden_states)
-            # hidden_states = MakeViewlessTensor.apply(hidden_states)
-            hidden_states = make_viewless_tensor(hidden_states)
-            # hidden_states = hidden_states.clone()
-            # >>>
-            # from lutil import pax
-            # pax(0, {"hidden_states": hidden_states})
-            # <<<
+            # Viewless tensor
+            hidden_states = make_viewless_tensor(
+                hidden_states,
+                requires_grad = True,
+                keep_graph = True,
+            )

         if encoder_output is not None:
             encoder_output = encoder_output.transpose(0, 1).contiguous()
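Background for the hunk above (not part of the diff): `self.input_tensor` arrives as a view (for example the output of a `.transpose()` or `.view()`), and a view's `._base` attribute keeps the parent buffer alive for as long as the view itself is stored. The sketch below is a minimal, standalone illustration of that behavior and of the `.data`-swap trick that `_kernel_make_viewless_tensor` (added further down in this diff) relies on; the names are illustrative, not repo code.

```python
import torch

parent = torch.empty(1024, 1024)
view = parent.transpose(0, 1)
assert view._base is parent          # the view pins `parent` via ._base

# The trick used by _kernel_make_viewless_tensor: allocate a tiny non-view
# tensor and point its .data at the view's storage. The result shares the
# same data but has ._base == None, so it pins nothing extra.
viewless = torch.empty((1,), dtype=view.dtype, device=view.device)
viewless.data = view.data
assert viewless._base is None
assert viewless.shape == view.shape  # the .data assignment carries the shape
```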
@@ -99,7 +99,7 @@ def gather_split_1d_tensor(tensor):
     return gathered

 # >>>
-# from lutil import pax
+from lutil import pax # ****************
 # def make_standalone_tensor(a):
 #     assert a._base is not None
@@ -107,26 +107,66 @@ def gather_split_1d_tensor(tensor):
 #     b.data = a.data
 #     return b
 # class MakeStandaloneTensor(torch.autograd.Function):
-class MakeViewlessTensor_(torch.autograd.Function):
+# class MakeViewlessTensor_(torch.autograd.Function):
+class MakeViewlessTensor(torch.autograd.Function):

+    # @staticmethod
+    # def forward(ctx, inp):
+    #     assert inp._base is not None
+    #     out = torch.empty((1,), dtype = inp.dtype, device = inp.device)
+    #     out.data = inp.data
+    #     # pax(0, {"inp": inp, "out": out})
+    #     return out
     @staticmethod
-    def forward(ctx, inp):
-        assert inp._base is not None
-        out = torch.empty((1,), dtype = inp.dtype, device = inp.device)
-        out.data = inp.data
-        # pax(0, {"inp": inp, "out": out})
-        return out
+    def forward(ctx, inp, requires_grad):
+        return _kernel_make_viewless_tensor(inp, requires_grad)
+    # @staticmethod
+    # def forward(ctx, args):
+    #     return [_kernel_make_viewless_tensor(*args)]

     @staticmethod
     def backward(ctx, grad_output):
-        # pax(0, {"grad_output": grad_output})
-        return grad_output
+        # return grad_output
+        return grad_output, None

+def _kernel_make_viewless_tensor(inp, requires_grad):
+    out = torch.empty(
+        (1,),
+        dtype = inp.dtype,
+        device = inp.device,
+        requires_grad = requires_grad,
+    )
+    out.data = inp.data
+    # >>>
+    # pax(0, {"inp": inp, "out": out})
+    # assert out.requires_grad
+    # <<<
+    return out

-def make_viewless_tensor(tensor):
-    if tensor._base is None:
-        return tensor
+# def make_viewless_tensor(tensor):
+#     if tensor._base is None:
+#         return tensor
+#     else:
+#         return MakeViewlessTensor_.apply(tensor)
+def make_viewless_tensor(inp, requires_grad, keep_graph):
+    # return tensor as-is, if not a 'view'
+    if inp._base is None:
+        return inp
+    # create viewless tensor
+    if keep_graph:
+        # return MakeViewlessTensor.apply((inp, requires_grad))[0]
+        return MakeViewlessTensor.apply(inp, requires_grad)
     else:
-        return MakeViewlessTensor_.apply(tensor)
+        return _kernel_make_viewless_tensor(inp, requires_grad)
+    # return MakeViewlessTensor.apply((inp, requires_grad))[0]
+    # return MakeViewlessTensor.apply(inp, requires_grad)
+    # return MakeViewlessTensor.apply(inp)
+    # return MakeViewlessTensor.apply(inp, 7)
+    # return MakeViewlessTensor.apply(inp, 7)[0]

-def assert_viewless_tensor(tensor):
+def assert_viewless_tensor(tensor, extra_msg = None):
     if isinstance(tensor, list):
         [ assert_viewless_tensor(t) for t in tensor ]
         return
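A note on the `backward` change above: `forward` now takes two arguments, `inp` and `requires_grad`, and `torch.autograd.Function.backward` must return one value per `forward` input, with `None` for non-tensor inputs, hence `return grad_output, None`. A minimal standalone sketch of that contract (illustrative names, not repo code):

```python
import torch

class PassThrough(torch.autograd.Function):
    """Identity-like Function whose forward takes an extra non-tensor flag."""

    @staticmethod
    def forward(ctx, inp, flag):
        # `flag` is a plain bool, but it still counts as a forward input.
        return inp.clone()

    @staticmethod
    def backward(ctx, grad_output):
        # One return value per forward input: a gradient for `inp`,
        # and None for the non-tensor `flag`.
        return grad_output, None

x = torch.randn(3, requires_grad=True)
y = PassThrough.apply(x, True)
y.sum().backward()
assert torch.equal(x.grad, torch.ones(3))
```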
@@ -137,13 +137,12 @@ def assert_viewless_tensor(tensor):
     assert tensor._base is None, (
         "Ensure tensor._base is None before setting tensor.data or storing "
         "tensor to memory buffer. Otherwise, a memory leak will occur (and "
-        "likely accumulate over iterations). FYI, tensor._base has shape "
-        "%s, and new_data_tensor has shape %s."
-    ) % (tensor._base.shape, new_data_tensor.shape)
+        "likely accumulate over iterations). %s"
+    ) % extra_msg

 # def set_viewless_tensor_data_attr(tensor, new_data_tensor):
 def safely_set_tensor_data_attr(tensor, new_data_tensor):
-    assert_viewless_tensor(tensor)
+    assert_viewless_tensor(tensor, extra_msg = "FYI, tensor._base has shape %s, and new_data_tensor has shape %s." % ("--" if tensor._base is None else tensor._base.shape, new_data_tensor.shape))
     tensor.data = new_data_tensor
 # <<<
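The assert in `safely_set_tensor_data_attr` guards against the failure mode spelled out in the message: re-pointing `.data` on a tensor that is still a view leaves `._base` referencing the old parent buffer, which then can never be reclaimed and accumulates over iterations. Below is a standalone sketch of the same guard; it is an illustrative reimplementation in plain PyTorch, not the repo's helper.

```python
import torch

def set_data_guarded(tensor, new_data):
    # Mirrors the commit's check: only a viewless tensor may have .data swapped.
    assert tensor._base is None, (
        "refusing to set .data on a view; its ._base would keep the old "
        "buffer alive and leak memory across iterations")
    tensor.data = new_data

owner = torch.empty(8)                     # owns its storage; ._base is None
set_data_guarded(owner, torch.zeros(8))    # fine

view = torch.empty(8).view(2, 4)           # a view; ._base is the flat tensor
try:
    set_data_guarded(view, torch.zeros(2, 4))
except AssertionError as err:
    print("blocked:", err)
```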
@@ -145,12 +145,16 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
         if recv_prev:
             tensor_recv_prev = mpu.gather_split_1d_tensor(
                 tensor_recv_prev).view(tensor_shape).requires_grad_()
-            tensor_recv_prev = make_viewless_tensor(tensor_recv_prev)
+            tensor_recv_prev = make_viewless_tensor(tensor_recv_prev,
+                                                    requires_grad = True,
+                                                    keep_graph = False)
         if recv_next:
             tensor_recv_next = mpu.gather_split_1d_tensor(
                 tensor_recv_next).view(tensor_shape).requires_grad_()
-            tensor_recv_next = make_viewless_tensor(tensor_recv_next)
+            tensor_recv_next = make_viewless_tensor(tensor_recv_next,
+                                                    requires_grad = True,
+                                                    keep_graph = False)

     return tensor_recv_prev, tensor_recv_next
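In `_communicate` above, the received tensor is freshly built by `gather_split_1d_tensor(...).view(...).requires_grad_()`, so there is no upstream autograd graph worth preserving and the call passes `keep_graph = False`, taking the plain `_kernel_make_viewless_tensor` path instead of the `MakeViewlessTensor` autograd Function used for `hidden_states` in the transformer. A standalone sketch of that cheaper path (illustrative names, mirroring the kernel added earlier in this diff):

```python
import torch

def kernel_make_viewless(inp, requires_grad):
    # Same idea as _kernel_make_viewless_tensor: a fresh non-view tensor whose
    # .data aliases the input's storage.
    out = torch.empty((1,), dtype=inp.dtype, device=inp.device,
                      requires_grad=requires_grad)
    out.data = inp.data
    return out

flat = torch.empty(6)                      # stands in for the gathered buffer
recv = flat.view(2, 3).requires_grad_()    # a view, as in _communicate()
recv = kernel_make_viewless(recv, requires_grad=True)
assert recv._base is None and recv.requires_grad
```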
@@ -631,13 +631,6 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
         send_forward(output_tensor, send_tensor_shapes, timers=timers)

         if not forward_only:
-            # >>>
-            # if input_tensor[0] is not None:
-            #     from lutil import pax
-            #     pax({
-            #         "input_tensor" : input_tensor,
-            #     })
-            # <<<
             assert_viewless_tensor(input_tensor)
             assert_viewless_tensor(output_tensor)
             input_tensors.append(input_tensor)
@@ -669,15 +662,6 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
                                            timers=timers)

             # Add input_tensor and output_tensor to end of list.
-            # >>>
-            # assert input_tensor[0]._base is None, \
-            #     "rank %s; uh oh." % torch.distributed.get_rank()
-            # if input_tensor[0] is not None:
-            #     from lutil import pax
-            #     pax(4, {
-            #         "input_tensor[0]" : input_tensor[0],
-            #     })
-            # <<<
             assert_viewless_tensor(input_tensor)
             assert_viewless_tensor(output_tensor)
             input_tensors.append(input_tensor)
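In the pipeline schedule, `input_tensor` and `output_tensor` are lists (possibly containing `None` placeholders, e.g. on the first pipeline stage), which is why `assert_viewless_tensor` recurses into lists before checking `._base`. A standalone sketch of that shape of check; the `None` handling here is an assumption for illustration, not a claim about the repo's exact implementation:

```python
import torch

def assert_viewless(tensor, extra_msg=None):
    if isinstance(tensor, list):
        for t in tensor:
            assert_viewless(t, extra_msg)
        return
    if not isinstance(tensor, torch.Tensor):
        return                              # e.g. None placeholders
    assert tensor._base is None, extra_msg or "tensor is still a view"

assert_viewless([torch.empty(2, 2), None])  # a list like the schedule passes
```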