Commit ed0c8714 authored by Lawrence McAfee

more iterating on 'viewless tensor' methods

parent 5422d23a
@@ -27,6 +27,9 @@ from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
# >>>
from megatron.mpu.random import make_viewless_tensor
# <<<
""" We use the following notation throughout this file:
h: hidden size
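
Background for the `make_viewless_tensor` import above: PyTorch calls a tensor a "view" when it shares storage with another tensor and keeps that tensor alive through its `._base` attribute. A minimal sketch of the distinction (plain PyTorch, nothing Megatron-specific assumed):

    import torch

    a = torch.empty(2, 3)
    assert a._base is None       # freshly allocated: "viewless"

    b = a.view(6)                # .view(), .transpose(), slicing, etc.
    assert b._base is not None   # b is a view; it pins a's storage
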
@@ -696,19 +699,14 @@ class ParallelTransformer(MegatronModule):
        # See set_input_tensor()
        hidden_states = self.input_tensor

        # >>>
        # hidden_states = MakeViewlessTensor.apply(hidden_states)
        hidden_states = make_viewless_tensor(hidden_states)
        # hidden_states = hidden_states.clone()
        # from lutil import pax
        # pax(0, {"hidden_states": hidden_states})
        # <<<

        if encoder_output is not None:
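
Why `hidden_states` is converted here: it arrives through `set_input_tensor()` from the pipeline receive path, where it was produced by `.view(tensor_shape)` and is therefore a view. The danger, per the assertion message added in `random.py` below, is that later code replaces `.data` on this tensor, and on a view that leaks the base buffer. A hedged repro of the leak mechanism (sizes are illustrative):

    import torch

    base = torch.empty(1024, 1024)      # ~4 MB base buffer
    view = base.view(-1)
    view.data = torch.empty(1)          # shrink .data to one element...
    assert view._base is not None       # ...but ._base still pins all of
                                        # `base`, so the ~4 MB is never freed

The commented-out `.clone()` alternative above would also break the view relationship, but at the cost of a full copy; `make_viewless_tensor` instead aliases the existing storage.
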
@@ -98,13 +98,54 @@ def gather_split_1d_tensor(tensor):
                                 group=get_tensor_model_parallel_group())
    return gathered

# >>>
# from lutil import pax

# def make_standalone_tensor(a):
#     assert a._base is not None
#     b = torch.empty((1,), dtype = a.dtype, device = a.device)
#     b.data = a.data
#     return b
# class MakeStandaloneTensor(torch.autograd.Function):
class MakeViewlessTensor_(torch.autograd.Function):

    @staticmethod
    def forward(ctx, inp):
        assert inp._base is not None
        out = torch.empty((1,), dtype = inp.dtype, device = inp.device)
        out.data = inp.data
        # pax(0, {"inp": inp, "out": out})
        return out

    @staticmethod
    def backward(ctx, grad_output):
        # pax(0, {"grad_output": grad_output})
        return grad_output

def make_viewless_tensor(tensor):
    if tensor._base is None:
        return tensor
    else:
        return MakeViewlessTensor_.apply(tensor)
def assert_viewless_tensor(tensor):
    if isinstance(tensor, list):
        [ assert_viewless_tensor(t) for t in tensor ]
        return
    # assert isinstance(tensor, torch.Tensor), \
    #     "expected Tensor; found %s." % type(tensor).__name__
    if not isinstance(tensor, torch.Tensor):
        return
    assert tensor._base is None, (
        "Ensure tensor._base is None before setting tensor.data or storing "
        "tensor to a memory buffer. Otherwise, a memory leak will occur "
        "(and likely accumulate over iterations). FYI, tensor._base has "
        "shape %s."
    ) % (tensor._base.shape,)

# def set_viewless_tensor_data_attr(tensor, new_data_tensor):
def safely_set_tensor_data_attr(tensor, new_data_tensor):
    assert_viewless_tensor(tensor)
    tensor.data = new_data_tensor
# <<<
class CudaRNGStatesTracker:
"""Tracker for the cuda RNG states.
@@ -253,11 +294,13 @@ class CheckpointFunction(torch.autograd.Function):
            # with data_leak_ctx(args[0]):
            # <<<
            ctx.input_0_shape = args[0].data.shape
            # >>>
            # args[0].data = split_tensor_into_1d_equal_chunks(args[0].data,
            #                                                  new_buffer=True)
            safely_set_tensor_data_attr(
                args[0],
                split_tensor_into_1d_equal_chunks(args[0].data, new_buffer=True))
            # <<<

        # Store everything.
        ctx.save_for_backward(*args)
@@ -271,8 +314,16 @@ class CheckpointFunction(torch.autograd.Function):
                               "please use .backward() if possible")
        inputs = ctx.saved_tensors
        if ctx.distribute_checkpointed_activations:
            # >>>
            # inputs[0].data = gather_split_1d_tensor(inputs[0].data)
            # inputs[0].data = inputs[0].data.view(ctx.input_0_shape)
            safely_set_tensor_data_attr(
                inputs[0],
                gather_split_1d_tensor(inputs[0].data))
            safely_set_tensor_data_attr(
                inputs[0],
                inputs[0].data.view(ctx.input_0_shape))
            # <<<

        # Store the current states.
        bwd_cpu_rng_state = torch.get_rng_state()
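
What the two `safely_set_tensor_data_attr` call sites in `CheckpointFunction` do: with `distribute_checkpointed_activations`, the forward pass keeps only this rank's 1-D shard of the first input, and the backward pass all-gathers the shards and restores the original shape saved in `ctx.input_0_shape`. A single-process sketch of that round trip, with plain stand-ins for the distributed `mpu` collectives:

    import torch

    def split_1d(t, world_size, rank):
        # stand-in for split_tensor_into_1d_equal_chunks(t, new_buffer=True)
        chunk = t.numel() // world_size
        return t.view(-1)[rank * chunk : (rank + 1) * chunk].clone()

    full = torch.arange(16.).view(4, 4)
    shape = full.shape                         # ctx.input_0_shape
    shards = [split_1d(full, 4, r) for r in range(4)]

    # stand-in for gather_split_1d_tensor(...) followed by .view(shape)
    restored = torch.cat(shards).view(shape)
    assert torch.equal(full, restored)
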
@@ -20,6 +20,9 @@ import torch
from megatron import get_args
from megatron import mpu
# >>>
from megatron.mpu.random import make_viewless_tensor
# <<<
def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
tensor_shape,
@@ -142,10 +145,12 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
        if recv_prev:
            tensor_recv_prev = mpu.gather_split_1d_tensor(
                tensor_recv_prev).view(tensor_shape).requires_grad_()
            tensor_recv_prev = make_viewless_tensor(tensor_recv_prev)
        if recv_next:
            tensor_recv_next = mpu.gather_split_1d_tensor(
                tensor_recv_next).view(tensor_shape).requires_grad_()
            tensor_recv_next = make_viewless_tensor(tensor_recv_next)

    return tensor_recv_prev, tensor_recv_next
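
Why the two `make_viewless_tensor` calls are needed at this exact spot: `.view(tensor_shape)` always returns a view of the gathered 1-D receive buffer, so without the conversion every tensor handed to the schedule would trip the new `assert_viewless_tensor` checks (and would leak if its `.data` were later swapped). An illustration, assuming `make_viewless_tensor` is in scope:

    flat = torch.empty(6)                # e.g. the gathered 1-D recv buffer
    t = flat.view(2, 3)                  # reshape to tensor_shape
    assert t._base is not None           # a view: would fail the new asserts

    t = make_viewless_tensor(t)          # alias the storage, drop ._base
    assert t._base is None
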
@@ -28,6 +28,10 @@ from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
from megatron.model import ModelType
# >>>
from megatron.mpu.random import assert_viewless_tensor
# <<<
def get_forward_backward_func():
args = get_args()
if mpu.get_pipeline_model_parallel_world_size() > 1:
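
A note on how the `assert_viewless_tensor` guards added throughout this file can be called so freely: the helper recurses into lists and silently ignores non-Tensor values (such as the `None` placeholders used at pipeline boundaries), so it works directly on the schedule's heterogeneous bookkeeping values. Illustration:

    assert_viewless_tensor(None)                     # non-Tensor: ignored
    assert_viewless_tensor(torch.empty(2, 3))        # viewless: passes
    assert_viewless_tensor([None, torch.empty(4)])   # lists: checked per item

    # A view anywhere in the structure raises:
    #   assert_viewless_tensor(torch.empty(4).view(2, 2))  # AssertionError
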
@@ -306,6 +310,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
                                     model[model_chunk_id],
                                     input_tensor, losses_reduced)
        output_tensors[model_chunk_id].append(output_tensor)
        assert_viewless_tensor(output_tensor)

        # if forward-only, no need to save tensors for a backward pass
        if forward_only:
@@ -339,6 +344,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
    mpu.set_virtual_pipeline_model_parallel_rank(0)
    input_tensors[0].append(
        p2p_communication.recv_forward(tensor_shape, timers=timers))
    assert_viewless_tensor(input_tensors[0][-1])

    for k in range(num_warmup_microbatches):
        output_tensor = forward_step_helper(k)
@@ -370,6 +376,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
                    tensor_shape=tensor_shape,
                    timers=timers)
            output_tensor_grads[num_model_chunks-1].append(output_tensor_grad)
            assert_viewless_tensor(output_tensor_grad)
        else:
            input_tensor = \
                p2p_communication.send_forward_recv_forward(
@@ -378,6 +385,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
                    timers=timers)
        free_output_tensor(output_tensor, args.deallocate_pipeline_outputs)
        input_tensors[next_forward_model_chunk_id].append(input_tensor)
        assert_viewless_tensor(input_tensor)

    # Run 1F1B in steady state.
    for k in range(num_microbatches_remaining):
@@ -447,15 +455,18 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
            # right location.
            if recv_prev:
                input_tensors[next_forward_model_chunk_id].append(input_tensor)
                assert_viewless_tensor(input_tensor)
            if recv_next:
                output_tensor_grads[next_backward_model_chunk_id].append(
                    output_tensor_grad)
                assert_viewless_tensor(output_tensor_grad)

    # Run cooldown backward passes (flush out pipeline).
    if not forward_only:
        if all_warmup_microbatches:
            output_tensor_grads[num_model_chunks-1].append(
                p2p_communication.recv_backward(tensor_shape, timers=timers))
            assert_viewless_tensor(output_tensor_grads[num_model_chunks-1][-1])
        for k in range(num_microbatches_remaining, num_microbatches):
            input_tensor_grad = backward_step_helper(k)
            next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False)
@@ -470,6 +481,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
                    input_tensor_grad, recv_next=recv_next,
                    tensor_shape=tensor_shape,
                    timers=timers))
            assert_viewless_tensor(output_tensor_grads[next_backward_model_chunk_id][-1])

    return losses_reduced
@@ -508,6 +520,7 @@ def recv_forward(tensor_shapes, timers):
        else:
            input_tensors.append(p2p_communication.recv_forward(tensor_shape,
                                                                timers=timers))
            assert_viewless_tensor(input_tensors[-1])
    return input_tensors
@@ -519,6 +532,7 @@ def recv_backward(tensor_shapes, timers):
        else:
            output_tensor_grads.append(p2p_communication.recv_backward(tensor_shape,
                                                                       timers=timers))
            assert_viewless_tensor(output_tensor_grads[-1])
    return output_tensor_grads
@@ -551,6 +565,7 @@ def send_forward_recv_backward(output_tensors, tensor_shapes, timers):
        output_tensor_grad = p2p_communication.send_forward_recv_backward(
            output_tensor, tensor_shape, timers=timers)
        output_tensor_grads.append(output_tensor_grad)
        assert_viewless_tensor(output_tensor_grad)
    return output_tensor_grads
@@ -565,6 +580,7 @@ def send_backward_recv_forward(input_tensor_grads, tensor_shapes, timers):
        input_tensor = p2p_communication.send_backward_recv_forward(
            input_tensor_grad, tensor_shape, timers=timers)
        input_tensors.append(input_tensor)
        assert_viewless_tensor(input_tensor)
    return input_tensors
@@ -615,6 +631,15 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
        send_forward(output_tensor, send_tensor_shapes, timers=timers)

        if not forward_only:
            # >>>
            if input_tensor[0] is not None:
                from lutil import pax
                pax({
                    "input_tensor" : input_tensor,
                })
            # <<<
            assert_viewless_tensor(input_tensor)
            assert_viewless_tensor(output_tensor)
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            free_output_tensor(output_tensor, args.deallocate_pipeline_outputs)
@@ -644,6 +669,17 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
                                           timers=timers)

            # Add input_tensor and output_tensor to end of list.
            # >>>
            # assert input_tensor[0]._base is None, \
            #     "rank %s; uh oh." % torch.distributed.get_rank()
            # if input_tensor[0] is not None:
            #     from lutil import pax
            #     pax(4, {
            #         "input_tensor[0]" : input_tensor[0],
            #     })
            # <<<
            assert_viewless_tensor(input_tensor)
            assert_viewless_tensor(output_tensor)
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            free_output_tensor(output_tensor, args.deallocate_pipeline_outputs)
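
Why the asserts sit immediately before each `free_output_tensor(output_tensor, args.deallocate_pipeline_outputs)` call: that helper's body is not part of this diff, but deallocating a pipeline output presumably means replacing its `.data` with a tiny placeholder once the tensor has been sent downstream, which is exactly the operation that is only safe on viewless tensors. A plausible sketch, not the actual implementation:

    def free_output_tensor_sketch(output_tensors, deallocate):
        # Hypothetical stand-in for free_output_tensor(); illustrative only.
        if not deallocate:
            return
        if not isinstance(output_tensors, list):
            output_tensors = [output_tensors]
        for out in output_tensors:
            if out is None:
                continue
            assert_viewless_tensor(out)  # a view here would leak its base
            out.data = torch.empty((1,), device=out.device, dtype=out.dtype)
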