Commit a1fe4805 authored by Lawrence McAfee

added comments

parent 806422e5
@@ -696,9 +696,6 @@ class ParallelTransformer(MegatronModule):
         hidden_states = self.input_tensor
         # Viewless tensor
-        # >>>
-        assert hidden_states is not None, "rank == %d." % torch.distributed.get_rank()
-        # <<<
         hidden_states = mpu.make_viewless_tensor(
             hidden_states,
             requires_grad = True,
...
@@ -98,15 +98,15 @@ def gather_split_1d_tensor(tensor):
                                  group=get_tensor_model_parallel_group())
     return gathered
-class MakeViewlessTensor(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, inp, requires_grad):
-        return _kernel_make_viewless_tensor(inp, requires_grad)
-    @staticmethod
-    def backward(ctx, grad_output):
-        return grad_output, None
 def _kernel_make_viewless_tensor(inp, requires_grad):
+    '''Make a viewless tensor.
+
+    View tensors have the undesirable side-effect of retaining a reference
+    to the originally-viewed tensor, even after manually setting the '.data'
+    field. This method creates a new tensor that links to the old tensor's
+    data, without linking the viewed tensor, referenced via the '._base'
+    field.
+    '''
     out = torch.empty(
         (1,),
         dtype = inp.dtype,
@@ -116,7 +116,31 @@ def _kernel_make_viewless_tensor(inp, requires_grad):
     out.data = inp.data
     return out
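The docstring above hinges on PyTorch view semantics: a tensor produced by a view operation keeps a reference to the tensor it was viewed from via '._base', and that reference survives reassigning '.data'. A minimal standalone illustration (not part of this commit; behaviour as described by the docstring):

```python
import torch

base = torch.zeros(2, 8)
view = base.view(16)            # a view; shares storage with 'base'
assert view._base is base      # the view pins 'base' through '._base'

view.data = torch.zeros(4)     # reassigning '.data' swaps the underlying storage ...
print(view._base is base)      # ... but, per the docstring, '._base' still references 'base'

# '_kernel_make_viewless_tensor' avoids this by allocating a fresh, non-view tensor
# and pointing its '.data' at the original data, so its '._base' stays None.
```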
+class MakeViewlessTensor(torch.autograd.Function):
+    '''
+    Autograd function to make a viewless tensor.
+
+    This function should be used in cases where the computation graph needs
+    to be propagated, but we only want a viewless tensor (e.g.,
+    ParallelTransformer's hidden_states). Call this function by passing
+    'keep_graph = True' to 'make_viewless_tensor()'.
+    '''
+    @staticmethod
+    def forward(ctx, inp, requires_grad):
+        return _kernel_make_viewless_tensor(inp, requires_grad)
+    @staticmethod
+    def backward(ctx, grad_output):
+        return grad_output, None
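A quick sanity-check sketch of what the autograd path buys (standalone, assuming the two definitions above are in scope; not part of the commit): the returned tensor is viewless, yet gradients still reach the original leaf because backward() passes grad_output straight through.

```python
import torch

x = torch.randn(4, requires_grad=True)
v = x.view(2, 2)                         # 'v' is a view of 'x'
y = MakeViewlessTensor.apply(v, True)    # viewless copy that stays on the autograd graph

print(y._base is None)                   # True: 'y' no longer pins a base tensor
y.sum().backward()                       # grad_output is forwarded unchanged to 'v'
print(torch.allclose(x.grad, torch.ones(4)))   # True: gradients flowed back to 'x'
```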
 def make_viewless_tensor(inp, requires_grad, keep_graph):
+    '''
+    Entry-point for creating viewless tensors.
+
+    This method should be used, rather than calling 'MakeViewlessTensor'
+    or '_kernel_make_viewless_tensor' directly. This method acts as a
+    switch for determining if an autograd function or a regular method
+    should be used to create the tensor.
+    '''
     # return tensor as-is, if not a 'view'
     if inp._base is None:
@@ -129,6 +153,8 @@ def make_viewless_tensor(inp, requires_grad, keep_graph):
         return _kernel_make_viewless_tensor(inp, requires_grad)
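Usage sketch with illustrative shapes and names (the mpu.make_viewless_tensor spelling matches the ParallelTransformer call in the first hunk; everything else here is hypothetical):

```python
import torch
from megatron import mpu

# A transpose() produces a view, so 'hidden_states._base' would otherwise stay set.
hidden_states = torch.randn(128, 4, 1024).transpose(0, 1)

# The graph must keep flowing through hidden_states, so route via the autograd function:
hidden_states = mpu.make_viewless_tensor(
    hidden_states,
    requires_grad = True,
    keep_graph = True,
)

# A tensor that merely needs its view-ness stripped can take the cheaper,
# non-autograd path; keep_graph = False goes straight to the kernel:
flat = mpu.make_viewless_tensor(
    torch.zeros(4, 1024).view(-1), requires_grad = False, keep_graph = False)
```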
 def assert_viewless_tensor(tensor, extra_msg = None):
+    '''Assert that a tensor is not a view (i.e., its '._base' field is
+    not set).'''
     if isinstance(tensor, list):
         [ assert_viewless_tensor(t) for t in tensor ]
         return tensor
@@ -142,6 +168,11 @@ def assert_viewless_tensor(tensor, extra_msg = None):
     return tensor
 def safely_set_viewless_tensor_data(tensor, new_data_tensor):
+    '''Safely set tensor's '.data' field.
+
+    Check first that the tensor is viewless (i.e., '._base' not set). If not,
+    raise an exception.
+    '''
     assert_viewless_tensor(tensor, extra_msg = "FYI, tensor._base has shape %s, and new_data_tensor has shape %s." % ("--" if tensor._base is None else tensor._base.shape, new_data_tensor.shape))
     tensor.data = new_data_tensor
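A short sketch (standalone, assuming the helpers above are in scope) of the failure mode this guard catches: the guarded setter refuses a tensor whose '._base' is still set, and accepts it once the tensor has been made viewless.

```python
import torch

base = torch.zeros(8)
view = base.view(2, 4)                    # 'view._base' is 'base'

try:
    safely_set_viewless_tensor_data(view, torch.ones(2, 4))
except Exception as err:                  # the docstring only promises 'an exception'
    print("rejected:", err)               # '._base' is still set, so the update is refused

clean = make_viewless_tensor(view, requires_grad=False, keep_graph=False)
safely_set_viewless_tensor_data(clean, torch.ones(2, 4))   # accepted: 'clean._base' is None
```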
@@ -305,19 +336,9 @@ class CheckpointFunction(torch.autograd.Function):
                                "please use .backward() if possible")
         inputs = ctx.saved_tensors
         if ctx.distribute_checkpointed_activations:
-            # >>>
-            # inputs[0].data = gather_split_1d_tensor(inputs[0].data)
-            # inputs[0].data = inputs[0].data.view(ctx.input_0_shape)
-            # safely_set_tensor_data_attr(
-            #     inputs[0],
-            #     gather_split_1d_tensor(inputs[0].data))
-            # safely_set_tensor_data_attr(
-            #     inputs[0],
-            #     inputs[0].data.view(ctx.input_0_shape))
             safely_set_viewless_tensor_data(
                 inputs[0],
                 gather_split_1d_tensor(inputs[0].data).view(ctx.input_0_shape))
-            # <<<
         # Store the current states.
         bwd_cpu_rng_state = torch.get_rng_state()
...
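For the activation-checkpointing hunk above: with distribute_checkpointed_activations enabled, the saved input was split 1-D across the tensor-model-parallel group, so backward() must gather the shards and restore the original shape before recomputation. The commit replaces the commented-out experiments with a single guarded update; roughly (a sketch of that step in isolation, with the unguarded form mirroring the removed comment lines):

```python
# Unguarded form (what the removed, commented-out lines experimented with):
full = gather_split_1d_tensor(inputs[0].data)      # reassemble shards from the TP group
inputs[0].data = full.view(ctx.input_0_shape)      # restore the pre-split shape

# Committed form: the same update, routed through the viewless-tensor guard so a
# stale '._base' reference can never be carried silently into recomputation:
safely_set_viewless_tensor_data(
    inputs[0],
    gather_split_1d_tensor(inputs[0].data).view(ctx.input_0_shape))
```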
@@ -28,9 +28,6 @@ from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model import Float16Module
 from megatron.model import ModelType
-# >>>
-# from megatron.mpu.random import assert_viewless_tensor
-# <<<
 def get_forward_backward_func():
     args = get_args()
...