Commit 5422d23a authored by Lawrence McAfee
Browse files

debugging make_standalone_tensor(), safely_set_tensor_data_attr()

parent 9a8b89ac
......@@ -696,6 +696,21 @@ class ParallelTransformer(MegatronModule):
# See set_input_tensor()
hidden_states = self.input_tensor
# >>>
def make_standalone_tensor(a):
    """Return a tensor that shares ``a``'s data but is not itself a view.

    A fresh one-element placeholder is allocated with ``a``'s dtype and
    device, and its ``.data`` is then rebound to ``a.data``. No element
    data is copied; the returned tensor adopts ``a``'s shape and storage.

    Args:
        a: A tensor that is a view, i.e. ``a._base`` is not None.

    Returns:
        A new tensor object wrapping ``a.data`` that carries ``a``'s
        ``requires_grad`` flag.

    Raises:
        AssertionError: If ``a`` is not a view (``a._base`` is None).
    """
    assert a._base is not None, (
        "make_standalone_tensor() expects a view (tensor._base is not None)."
    )
    # Propagate requires_grad explicitly: torch.empty() defaults to
    # requires_grad=False, and rebinding .data below does not change it.
    b = torch.empty(
        (1,),
        dtype=a.dtype,
        device=a.device,
        requires_grad=a.requires_grad,
    )
    # Rebinding .data swaps in a's storage/shape; the (1,) placeholder
    # allocation above is discarded.
    b.data = a.data
    return b
# <<<
# hidden_states = make_standalone_tensor(hidden_states)
hidden_states = hidden_states.clone()
# >>>
# from lutil import pax
# pax({"hidden_states": hidden_states})
# <<<
if encoder_output is not None:
encoder_output = encoder_output.transpose(0, 1).contiguous()
......
......@@ -98,6 +98,13 @@ def gather_split_1d_tensor(tensor):
group=get_tensor_model_parallel_group())
return gathered
def safely_set_tensor_data_attr(tensor, new_data_tensor):
    """Rebind ``tensor.data`` to ``new_data_tensor``, refusing to do so on a view.

    The check exists because, per the error message below, assigning
    ``.data`` on a tensor whose ``_base`` is set leaks memory.

    Args:
        tensor: Tensor whose ``.data`` attribute is replaced. Must not be
            a view (``tensor._base`` must be None).
        new_data_tensor: Tensor assigned to ``tensor.data``.

    Raises:
        AssertionError: If ``tensor._base`` is not None. ``tensor`` is left
            unmodified in that case.
    """
    base = tensor._base
    if base is not None:
        # Raised explicitly instead of via ``assert`` so the guard is not
        # stripped when Python runs with -O.
        raise AssertionError(
            "Ensure tensor._base is None before setting tensor.data. Otherwise, "
            "a memory leak will occur (and likely accumulate over iterations). "
            "FYI, tensor._base has shape %s, and new_data_tensor has shape %s."
            % (base.shape, new_data_tensor.shape)
        )
    tensor.data = new_data_tensor
class CudaRNGStatesTracker:
"""Tracker for the cuda RNG states.
......@@ -241,9 +248,16 @@ class CheckpointFunction(torch.autograd.Function):
# Divide hidden states across model parallel group and only keep
# the chunk corresponding to the current rank.
if distribute_checkpointed_activations:
# >>>
# from lutil import data_leak_ctx
# with data_leak_ctx(args[0]):
# <<<
ctx.input_0_shape = args[0].data.shape
args[0].data = split_tensor_into_1d_equal_chunks(args[0].data,
new_buffer=True)
# args[0].data = split_tensor_into_1d_equal_chunks(args[0].data,
# new_buffer=True)
safely_set_tensor_data_attr(
args[0],
split_tensor_into_1d_equal_chunks(args[0].data, new_buffer=True))
# Store everything.
ctx.save_for_backward(*args)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment