Commit 806422e5 authored by Lawrence McAfee

partially cleaned

parent 05042081
@@ -27,10 +27,6 @@ from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.fused_bias_gelu import bias_gelu_impl
 from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
-# >>>
-from megatron.mpu.random import make_viewless_tensor
-# <<<
 """ We use the following notation throughout this file:
     h: hidden size
    n: number of attention heads
@@ -700,7 +696,10 @@ class ParallelTransformer(MegatronModule):
            hidden_states = self.input_tensor
         # Viewless tensor
-        hidden_states = make_viewless_tensor(
+        # >>>
+        assert hidden_states is not None, "rank == %d." % torch.distributed.get_rank()
+        # <<<
+        hidden_states = mpu.make_viewless_tensor(
             hidden_states,
             requires_grad = True,
             keep_graph = True,
...
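Annotation (not part of the commit): a hedged sketch, in plain PyTorch, of what "viewless" means in the hunks above and below. A view pins its whole base allocation through `._base`; the helpers build a fresh tensor that shares storage without that back-pointer.

```python
# Illustrative only: why a lingering view can leak memory.
import torch

base = torch.empty(1024, 1024)
view = base.view(-1)
assert view._base is base       # the full base buffer stays alive with the view

# The kernel in mpu/random.py (below) allocates a fresh tensor and swaps in
# the view's storage; the result shares data but has no ._base back-pointer.
viewless = torch.empty((1,), dtype=view.dtype, device=view.device)
viewless.data = view.data
assert viewless._base is None
```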
@@ -65,6 +65,9 @@ from .random import get_cuda_rng_tracker
 from .random import model_parallel_cuda_manual_seed
 from .random import gather_split_1d_tensor
 from .random import split_tensor_into_1d_equal_chunks
+from .random import make_viewless_tensor
+from .random import assert_viewless_tensor
+from .random import safely_set_viewless_tensor_data
 from .utils import divide
 from .utils import split_tensor_along_last_dim
@@ -98,34 +98,12 @@ def gather_split_1d_tensor(tensor):
                                  group=get_tensor_model_parallel_group())
     return gathered
-# >>>
-from lutil import pax # ****************
-# def make_standalone_tensor(a):
-#     assert a._base is not None
-#     b = torch.empty((1,), dtype = a.dtype, device = a.device)
-#     b.data = a.data
-#     return b
-# class MakeStandaloneTensor(torch.autograd.Function):
-# class MakeViewlessTensor_(torch.autograd.Function):
 class MakeViewlessTensor(torch.autograd.Function):
-    # @staticmethod
-    # def forward(ctx, inp):
-    #     assert inp._base is not None
-    #     out = torch.empty((1,), dtype = inp.dtype, device = inp.device)
-    #     out.data = inp.data
-    #     # pax(0, {"inp": inp, "out": out})
-    #     return out
     @staticmethod
     def forward(ctx, inp, requires_grad):
         return _kernel_make_viewless_tensor(inp, requires_grad)
-    # @staticmethod
-    # def forward(ctx, args):
-    #     return [_kernel_make_viewless_tensor(*args)]
     @staticmethod
     def backward(ctx, grad_output):
-        # pax(0, {"grad_output": grad_output})
-        # return grad_output
         return grad_output, None
 def _kernel_make_viewless_tensor(inp, requires_grad):
@@ -136,17 +114,8 @@ def _kernel_make_viewless_tensor(inp, requires_grad):
         requires_grad = requires_grad,
     )
     out.data = inp.data
-    # >>>
-    # pax(0, {"inp": inp, "out": out})
-    # assert out.requires_grad
-    # <<<
     return out
-# def make_viewless_tensor(tensor):
-#     if tensor._base is None:
-#         return tensor
-#     else:
-#         return MakeViewlessTensor_.apply(tensor)
 def make_viewless_tensor(inp, requires_grad, keep_graph):
     # return tensor as-is, if not a 'view'
@@ -155,36 +124,27 @@ def make_viewless_tensor(inp, requires_grad, keep_graph):
     # create viewless tensor
     if keep_graph:
-        # return MakeViewlessTensor.apply((inp, requires_grad))[0]
         return MakeViewlessTensor.apply(inp, requires_grad)
     else:
         return _kernel_make_viewless_tensor(inp, requires_grad)
-    # return MakeViewlessTensor.apply((inp, requires_grad))[0]
-    # return MakeViewlessTensor.apply(inp, requires_grad)
-    # return MakeViewlessTensor.apply(inp)
-    # return MakeViewlessTensor.apply(inp, 7)
-    # return MakeViewlessTensor.apply(inp, 7)[0]
 def assert_viewless_tensor(tensor, extra_msg = None):
     if isinstance(tensor, list):
         [ assert_viewless_tensor(t) for t in tensor ]
-        return
+        return tensor
-    # assert isinstance(tensor, torch.Tensor), \
-    #     "expected Tensor; found %s." % type(tensor).__name__
     if not isinstance(tensor, torch.Tensor):
-        return
+        return tensor
     assert tensor._base is None, (
         "Ensure tensor._base is None before setting tensor.data or storing "
         "tensor to memory buffer. Otherwise, a memory leak will occur (and "
         "likely accumulate over iterations). %s"
     ) % extra_msg
+    return tensor
-# def set_viewless_tensor_data_attr(tensor, new_data_tensor):
-def safely_set_tensor_data_attr(tensor, new_data_tensor):
+def safely_set_viewless_tensor_data(tensor, new_data_tensor):
     assert_viewless_tensor(tensor, extra_msg = "FYI, tensor._base has shape %s, and new_data_tensor has shape %s." % ("--" if tensor._base is None else tensor._base.shape, new_data_tensor.shape))
     tensor.data = new_data_tensor
-# <<<
 class CudaRNGStatesTracker:
     """Tracker for the cuda RNG states.
@@ -328,19 +288,10 @@ class CheckpointFunction(torch.autograd.Function):
         # Divide hidden states across model parallel group and only keep
         # the chunk corresponding to the current rank.
         if distribute_checkpointed_activations:
-            # >>>
-            # raise Exception("distrib.")
-            # from lutil import data_leak_ctx
-            # with data_leak_ctx(args[0]):
-            # <<<
             ctx.input_0_shape = args[0].data.shape
-            # >>>
-            # args[0].data = split_tensor_into_1d_equal_chunks(args[0].data,
-            #                                                  new_buffer=True)
-            safely_set_tensor_data_attr(
+            safely_set_viewless_tensor_data(
                 args[0],
                 split_tensor_into_1d_equal_chunks(args[0].data, new_buffer=True))
-            # <<<
         # Store everything.
         ctx.save_for_backward(*args)
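Annotation: a hedged, single-process emulation of the split/gather round trip in this hunk and the next. The real `split_tensor_into_1d_equal_chunks` / `gather_split_1d_tensor` shard across the tensor-model-parallel group; `world_size` here is illustrative.

```python
import torch

world_size = 4                          # illustrative group size
x = torch.randn(8, 16)
input_0_shape = x.data.shape            # what ctx.input_0_shape records

flat = x.data.view(-1)
chunks = list(flat.chunk(world_size))   # split_tensor_into_1d_equal_chunks
local = chunks[1].clone()               # new_buffer=True: each rank owns a copy

# Backward (next hunk): gather every rank's chunk and restore the shape in
# one expression, so no intermediate view is ever assigned to .data.
restored = torch.cat([c.clone() for c in chunks]).view(input_0_shape)
assert torch.equal(restored, x.data)
```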
@@ -357,12 +308,15 @@ class CheckpointFunction(torch.autograd.Function):
         # >>>
         # inputs[0].data = gather_split_1d_tensor(inputs[0].data)
         # inputs[0].data = inputs[0].data.view(ctx.input_0_shape)
-        safely_set_tensor_data_attr(
-            inputs[0],
-            gather_split_1d_tensor(inputs[0].data))
-        safely_set_tensor_data_attr(
-            inputs[0],
-            inputs[0].data.view(ctx.input_0_shape))
+        # safely_set_tensor_data_attr(
+        #     inputs[0],
+        #     gather_split_1d_tensor(inputs[0].data))
+        # safely_set_tensor_data_attr(
+        #     inputs[0],
+        #     inputs[0].data.view(ctx.input_0_shape))
+        safely_set_viewless_tensor_data(
+            inputs[0],
+            gather_split_1d_tensor(inputs[0].data).view(ctx.input_0_shape))
         # <<<
         # Store the current states.
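Annotation (hedged, illustrative failure mode): the viewless asserts exist because assigning to `.data` does not clear an existing `._base`, so the old base buffer stays pinned; that is the leak described in the assert message above.

```python
import torch

t = torch.empty(1024, 1024).view(-1)    # t._base pins the 1024x1024 buffer
small = torch.empty(16)

t.data = small                          # swaps storage, but ._base survives,
assert t._base is not None              # so the large buffer is never freed

# safely_set_viewless_tensor_data(t, small) would raise here instead,
# forcing the caller to make t viewless before reassigning its data.
```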
...
@@ -20,9 +20,6 @@ import torch
 from megatron import get_args
 from megatron import mpu
-# >>>
-from megatron.mpu.random import make_viewless_tensor
-# <<<
 def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
                  tensor_shape,
@@ -145,16 +142,16 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     if recv_prev:
         tensor_recv_prev = mpu.gather_split_1d_tensor(
             tensor_recv_prev).view(tensor_shape).requires_grad_()
-        tensor_recv_prev = make_viewless_tensor(tensor_recv_prev,
+        tensor_recv_prev = mpu.make_viewless_tensor(tensor_recv_prev,
                                                 requires_grad = True,
                                                 keep_graph = False)
     if recv_next:
         tensor_recv_next = mpu.gather_split_1d_tensor(
             tensor_recv_next).view(tensor_shape).requires_grad_()
-        tensor_recv_next = make_viewless_tensor(tensor_recv_next,
+        tensor_recv_next = mpu.make_viewless_tensor(tensor_recv_next,
                                                 requires_grad = True,
                                                 keep_graph = False)
     return tensor_recv_prev, tensor_recv_next
...
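Annotation (hedged sketch of the receive path above; shapes illustrative): the gathered-and-reshaped receive buffer is itself a view, which is why it is made viewless immediately; `keep_graph = False` suffices because a freshly received tensor has no upstream autograd graph to preserve.

```python
import torch

tensor_shape = (4, 8)
flat = torch.empty(32)          # stand-in for the gathered 1D recv buffer
recv = flat.view(tensor_shape).requires_grad_()
assert recv._base is flat       # still a view: unsafe to store or mutate .data
```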
@@ -29,7 +29,7 @@ from megatron.model import Float16Module
 from megatron.model import ModelType
 # >>>
-from megatron.mpu.random import assert_viewless_tensor
+# from megatron.mpu.random import assert_viewless_tensor
 # <<<
 def get_forward_backward_func():
@@ -115,17 +115,7 @@ def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_r
     unwrap_output_tensor = True
     unwrapped_model.set_input_tensor(input_tensor)
-    # >>>
-    # if input_tensor[0] is not None:
-    #     from lutil import pax, tp
-    #     pax({"input_tensor": tp(input_tensor)})
-    # <<<
     output_tensor, loss_func = forward_step_func(data_iterator, model)
-    # >>>
-    # if input_tensor[0] is not None:
-    #     from lutil import pax, tp
-    #     pax({"input_tensor": tp(input_tensor)})
-    # <<<
     if mpu.is_pipeline_last_stage():
         output_tensor = loss_func(output_tensor)
         loss, loss_reduced = output_tensor
@@ -530,7 +520,6 @@ def recv_forward(tensor_shapes, timers):
     else:
         input_tensors.append(p2p_communication.recv_forward(tensor_shape,
                                                             timers=timers))
-        assert_viewless_tensor(input_tensors[-1])
     return input_tensors
@@ -542,7 +531,6 @@ def recv_backward(tensor_shapes, timers):
     else:
         output_tensor_grads.append(p2p_communication.recv_backward(tensor_shape,
                                                                    timers=timers))
-        assert_viewless_tensor(output_tensor_grads[-1])
     return output_tensor_grads
@@ -575,7 +563,6 @@ def send_forward_recv_backward(output_tensors, tensor_shapes, timers):
         output_tensor_grad = p2p_communication.send_forward_recv_backward(
             output_tensor, tensor_shape, timers=timers)
         output_tensor_grads.append(output_tensor_grad)
-        assert_viewless_tensor(output_tensor_grad)
     return output_tensor_grads
@@ -590,7 +577,6 @@ def send_backward_recv_forward(input_tensor_grads, tensor_shapes, timers):
         input_tensor = p2p_communication.send_backward_recv_forward(
             input_tensor_grad, tensor_shape, timers=timers)
         input_tensors.append(input_tensor)
-        assert_viewless_tensor(input_tensor)
     return input_tensors
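Annotation: the four hunks above can drop the standalone asserts because `assert_viewless_tensor` now returns its argument (see the mpu/random.py hunk), enabling the validate-and-append one-liners used in the hunks below. A hedged, self-contained illustration:

```python
import torch

def assert_viewless_tensor(tensor, extra_msg=None):
    # Simplified copy of the helper from mpu/random.py above.
    if isinstance(tensor, list):
        [assert_viewless_tensor(t) for t in tensor]
        return tensor
    if not isinstance(tensor, torch.Tensor):
        return tensor
    assert tensor._base is None, extra_msg
    return tensor

input_tensors = []
t = torch.empty(4)                               # viewless: ._base is None
input_tensors.append(assert_viewless_tensor(t))  # validate + append in one step
```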
@@ -636,33 +622,13 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
     # Run warmup forward passes.
     for i in range(num_warmup_microbatches):
         input_tensor = recv_forward(recv_tensor_shapes, timers=timers)
-        # >>>
-        # if input_tensor[0] is not None:
-        #     from lutil import pax
-        #     pax({"input_tensor": input_tensor})
-        # <<<
         output_tensor = forward_step(forward_step_func, data_iterator, model,
                                      input_tensor, losses_reduced)
-        # >>>
-        # if True or input_tensor[0] is not None:
-        #     from lutil import pax
-        #     pax({"input_tensor": input_tensor})
-        # <<<
         send_forward(output_tensor, send_tensor_shapes, timers=timers)
         if not forward_only:
-            # >>>
-            # if input_tensor[0] is not None:
-            #     from lutil import pax
-            #     pax({"input_tensor": input_tensor})
-            # if output_tensor[0] is not None:
-            #     from lutil import pax
-            #     pax(0, {"output_tensor / 0": output_tensor[0]})
-            # <<<
-            assert_viewless_tensor(input_tensor)
-            assert_viewless_tensor(output_tensor)
-            input_tensors.append(input_tensor)
-            output_tensors.append(output_tensor)
+            input_tensors.append(mpu.assert_viewless_tensor(input_tensor))
+            output_tensors.append(mpu.assert_viewless_tensor(output_tensor))
             free_output_tensor(output_tensor, args.deallocate_pipeline_outputs)
     # Before running 1F1B, need to receive first forward tensor.
@@ -690,10 +656,8 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
                                                timers=timers)
             # Add input_tensor and output_tensor to end of list.
-            assert_viewless_tensor(input_tensor)
-            assert_viewless_tensor(output_tensor)
-            input_tensors.append(input_tensor)
-            output_tensors.append(output_tensor)
+            input_tensors.append(mpu.assert_viewless_tensor(input_tensor))
+            output_tensors.append(mpu.assert_viewless_tensor(output_tensor))
             free_output_tensor(output_tensor, args.deallocate_pipeline_outputs)
             # Pop input_tensor and output_tensor from the start of the list for
...