Unverified Commit ec8b1cb0 authored by Olatunji Ruwase, committed by GitHub

Activation checkpointing for non-tensor arguments and return values (#741)

* Activation checkpoint support for non tensor input/output

* Format fixes

* Address PR comments; Add ordering edge case tests
parent 7bf1b837
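
For orientation, a minimal usage sketch of what this change enables (not part of the diff): a checkpointed layer that takes a non-tensor flag and returns non-tensor metadata alongside its tensor output. The import path and setup below are assumptions; as the new unit tests note, checkpoint() expects torch.distributed to be initialized, and a CUDA device is assumed.

import torch
import torch.distributed as dist

# Assumed import path for the checkpointing module patched in this commit.
from deepspeed.runtime.activation_checkpointing import checkpointing

class LinearWithFlag(torch.nn.Linear):
    # Non-tensor input (a bool) and a non-tensor return value next to the tensor output.
    def forward(self, x, collect_stats):
        out = super().forward(x)
        return out, collect_stats

def run():
    # checkpoint() assumes torch.distributed is initialized (see the test comments below).
    dist.init_process_group(backend='gloo',
                            init_method='tcp://127.0.0.1:29500',
                            rank=0,
                            world_size=1)

    layer = LinearWithFlag(8, 8).cuda()
    x = torch.rand(8, device='cuda', requires_grad=True)

    # Before this commit, the bool argument and bool return value would trip up
    # CheckpointFunction; now both pass through while only the tensor is checkpointed.
    out, flag = checkpointing.checkpoint(layer, x, True)
    out.sum().backward()
    print(flag, x.grad.shape)

if __name__ == '__main__':
    run()
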
@@ -23,6 +23,7 @@ from torch.cuda import _lazy_call, device as device_ctx_manager
from deepspeed.runtime.config import DeepSpeedConfig
from deepspeed.utils import logger
+from deepspeed.runtime.utils import move_to_device
from deepspeed.utils.timer import SynchronizedWallClockTimer as Timers

#DeepSpeed Checkpointing Enabled or Disabled
@@ -311,6 +312,53 @@ def get_full_inputs(tensors, device=None):
    return tuple(inputs)
+
+def extract_tensors(all_objects):
+    """
+    Separate objects in list/tuple into tensors and non-tensors and create a mapping to enable re-aggregation.
+    The order of tensors and non-tensors is preserved in their respective output groups.
+
+    Parameters:
+        all_objects (list/tuple): Objects containing tensors and non-tensors to be split.
+
+    Returns:
+        tuple: Containing tensors, non-tensors, and bools of whether each position in original list/tuple was a tensor.
+    """
+    tensor_objects = [v for v in all_objects if torch.is_tensor(v)]
+    non_tensor_objects = [v for v in all_objects if not torch.is_tensor(v)]
+    tensor_flags = [torch.is_tensor(v) for v in all_objects]
+    if type(all_objects) is tuple:
+        return tuple(tensor_objects), tuple(non_tensor_objects), tuple(tensor_flags)
+    return tensor_objects, non_tensor_objects, tensor_flags
+
+
+def merge_tensors(tensor_objects, non_tensor_objects, tensor_flags):
+    """
+    Merge two lists (or tuples) of tensors and non-tensors using a mapping of positions in merged list (or tuple).
+
+    Parameters:
+        tensor_objects (list/tuple): Tensors to merge.
+        non_tensor_objects (list/tuple): Non-tensors to merge.
+        tensor_flags (list/tuple): Indicates whether each position in output is a tensor.
+
+    Returns:
+        tuple: Merge of tensors and non-tensors
+    """
+    merged_objects = []
+    tensor_idx = 0
+    non_tensor_idx = 0
+
+    for is_tensor in tensor_flags:
+        if is_tensor:
+            merged_objects.append(tensor_objects[tensor_idx])
+            tensor_idx += 1
+        else:
+            merged_objects.append(non_tensor_objects[non_tensor_idx])
+            non_tensor_idx += 1
+
+    return tuple(merged_objects)
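
As an aside (not part of the diff), a quick round-trip shows how the tensor_flags mapping lets these two helpers split and restore a mixed argument list; the values in the comments are illustrative and assume the helpers above are in scope:

import torch

args = (torch.ones(2), None, 3, torch.zeros(2), True)

tensors, non_tensors, flags = extract_tensors(all_objects=args)
# tensors     -> (tensor([1., 1.]), tensor([0., 0.]))
# non_tensors -> (None, 3, True)
# flags       -> (True, False, False, True, False)

merged = merge_tensors(tensors, non_tensors, flags)
# merged restores the original interleaving:
# (tensor([1., 1.]), None, 3, tensor([0., 0.]), True)
assert [torch.is_tensor(v) for v in merged] == list(flags)
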
class CheckpointFunction(torch.autograd.Function):
    """This function is adapted from torch.utils.checkpoint with
       two main changes:
@@ -322,9 +370,15 @@ class CheckpointFunction(torch.autograd.Function):
           5) Profile forward and backward functions
    """
    @staticmethod
-    def forward(ctx, run_function, *args):
+    def forward(ctx, run_function, all_outputs, *args):
        global mpu, timers, SYNCHRONIZE, PROFILE_TIME

+        def save_args_for_backward(*all_args):
+            tensor_args, non_tensor_args, tensor_flags = extract_tensors(all_objects=all_args)
+            ctx.save_for_backward(*tensor_args)
+            ctx.non_tensor_args = non_tensor_args
+            ctx.tensor_flags = tensor_flags
+
        if SYNCHRONIZE:
            torch.cuda.synchronize()
@@ -417,12 +471,7 @@
            inputs.append(args[-1])

        #just in case something funky is happening such as reuse of inputs
-        inputs_cuda = []
-        for item in args:
-            if torch.is_tensor(item):
-                inputs_cuda.append(item.to(cuda_device))
-            else:
-                inputs_cuda.append(item)
+        inputs_cuda = move_to_device(args, cuda_device)

        # Copy the rng states.
        ctx.fwd_cpu_rng_state = torch.get_rng_state()
@@ -485,9 +534,10 @@
                #if dist.get_rank() == 0:
                #    logger.info(f"The stored tensor is {contiguous_size} and orginal one is {size} ")
-            ctx.save_for_backward(*new_args)
+            save_args_for_backward(*new_args)
        else:
-            ctx.save_for_backward(*args)
+            save_args_for_backward(*args)
+
        if PROFILE_TIME:
            timers('forward').stop()
            timers.log(['forward'])
@@ -498,9 +548,18 @@
        if torch.is_tensor(outputs):
            non_grad_outputs = [outputs] if not outputs.is_floating_point() else []
        else:
-            non_grad_outputs = [o for o in outputs if not o.is_floating_point()]
+            non_grad_outputs = [
+                o for o in outputs if torch.is_tensor(o) and not o.is_floating_point()
+            ]
        ctx.mark_non_differentiable(*non_grad_outputs)
-        return outputs
+
+        if torch.is_tensor(outputs):
+            all_outputs += [outputs]
+            return outputs
+        else:
+            all_outputs += outputs
+            outputs, _, _ = extract_tensors(all_objects=outputs)
+            return tuple(outputs)
    @staticmethod
    def backward(ctx, *grads):
@@ -544,6 +603,11 @@
        inputs = ctx.saved_tensors
        detached_inputs = detach_variable(inputs)

+        # Add non tensor input args
+        detached_inputs = merge_tensors(tensor_objects=detached_inputs,
+                                        non_tensor_objects=ctx.non_tensor_args,
+                                        tensor_flags=ctx.tensor_flags)
+
        # Store the current states.
        bwd_cpu_rng_state = torch.get_rng_state()
        bwd_cuda_rng_state = torch.cuda.get_rng_state()
@@ -569,6 +633,9 @@
        if isinstance(outputs, torch.Tensor):
            outputs = (outputs, )

+        # Filter out non tensor outputs
+        outputs, _, _ = extract_tensors(all_objects=outputs)
+
        # Construct arguments to autograd.backward().
        # This is usually just outputs and grads, but forward() can return tensors that
        # are not differentiable.
@@ -586,7 +653,7 @@
            timers.log(['backward'])
        if SYNCHRONIZE:
            torch.cuda.synchronize()
-        ret_list = [None]  # first None for ctx
+        ret_list = [None, None]  # first None for ctx
        for inp in detached_inputs:
            if torch.is_tensor(inp):
                ret_list.append(inp.grad)
@@ -599,7 +666,13 @@
def checkpoint(function, *args):
    """Checkpoint a model or part of the model.
    This has been directly copied from torch.utils.checkpoint. """
-    return CheckpointFunction.apply(function, *args)
+
+    all_outputs = []
+    CheckpointFunction.apply(function, all_outputs, *args)
+    if len(all_outputs) == 1:
+        return all_outputs[0]
+    else:
+        return tuple(all_outputs)


def partition_activations_in_checkpoint(partition_activation):
...
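
Why checkpoint() became a wrapper instead of returning CheckpointFunction.apply(...) directly: forward() hands autograd only the tensor outputs, while the complete, ordered result (tensors and non-tensors) is appended to the caller-supplied all_outputs list, which checkpoint() then returns. A stripped-down, hypothetical sketch of just that plumbing (no recomputation, RNG, or partitioning logic; all names below are invented for illustration):

import torch

class _SideChannelFn(torch.autograd.Function):
    # Toy stand-in mirroring only the return-value plumbing of CheckpointFunction above.
    @staticmethod
    def forward(ctx, run_function, all_outputs, *args):
        ctx.num_args = len(args)
        with torch.no_grad():
            outputs = run_function(*args)
        if torch.is_tensor(outputs):
            all_outputs += [outputs]
            return outputs
        # Caller gets the full, ordered result; autograd only ever sees tensors.
        all_outputs += outputs
        return tuple(o for o in outputs if torch.is_tensor(o))

    @staticmethod
    def backward(ctx, *grads):
        # One gradient slot per forward() input: run_function, all_outputs, then *args.
        # A real implementation would recompute here and return input gradients.
        return (None, None) + (None, ) * ctx.num_args

def side_channel_call(run_function, *args):
    all_outputs = []
    _SideChannelFn.apply(run_function, all_outputs, *args)
    return all_outputs[0] if len(all_outputs) == 1 else tuple(all_outputs)

result = side_channel_call(lambda x: (2 * x, True, 'meta'), torch.ones(3))
print(result)  # (tensor([2., 2., 2.]), True, 'meta')
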
@@ -38,6 +38,28 @@ def set_random_seed(seed):
    torch.manual_seed(seed)
+
+
+def move_to_device(item, device):
+    """
+    Move tensor onto device. Works on individual tensors, and tensors contained/nested in lists, tuples, and dicts.
+
+    Parameters:
+        item: tensor to move or (possibly nested) container of tensors to move.
+        device: target device
+
+    Returns:
+        The same structure with tensors moved to the target device; non-tensor values are returned unchanged.
+    """
+    if torch.is_tensor(item):
+        return item.to(device)
+    elif isinstance(item, list):
+        return [move_to_device(v, device) for v in item]
+    elif isinstance(item, tuple):
+        return tuple([move_to_device(v, device) for v in item])
+    elif isinstance(item, dict):
+        return {k: move_to_device(v, device) for k, v in item.items()}
+    else:
+        return item
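
A quick illustration of the recursive behavior (not part of the diff; assumes a CUDA device is available):

import torch

batch = {
    'ids': torch.arange(4),
    'extras': [torch.zeros(2), 'pad', (torch.ones(1), 7)],
}

moved = move_to_device(batch, torch.device('cuda:0'))
# Tensors are copied to the device; the str and int pass through unchanged,
# and the dict/list/tuple nesting is rebuilt around them.
assert moved['ids'].device.type == 'cuda'
assert moved['extras'][1] == 'pad'
assert moved['extras'][2][1] == 7
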
class CheckOverflow(object):
    '''Checks for overflow in gradient across parallel process'''
    def __init__(self, param_groups=None, mpu=None, zero_reduce_scatter=False):
...
@@ -21,7 +21,8 @@ def _compute(module, *inputs, do_checkpoint=False):
    if torch.is_tensor(outputs):
        outputs = (outputs, )

-    sum(o.sum() for o in outputs if o.requires_grad).backward()
+    sum(o.sum() for o in outputs if torch.is_tensor(o) and o.requires_grad).backward()
+
    grads = [p.grad for p in module.parameters()]
    input_grads = [inp.grad for inp in inputs if torch.is_tensor(inp)]
@@ -44,6 +45,19 @@ def _prep_inputs(*inputs):
    return tuple(_inputs)


+def _match_outputs(ref, tgt):
+    assert type(ref) == type(tgt)
+    if type(ref) in [list, tuple]:
+        for x, y in zip(ref, tgt):
+            _match_outputs(x, y)
+    elif not torch.is_tensor(ref):
+        assert ref == tgt
+    elif ref.is_floating_point():
+        assert torch.allclose(ref, tgt)
+    else:
+        assert torch.equal(ref, tgt)
# This is distributed because checkpoint() assumes that torch.distributed is initialized.
# torch.distributed is used with activation partitioning, but not for these simple cases.
@distributed_test(world_size=1)
@@ -64,13 +78,32 @@ def _test_activation_checkpoint(module, *inputs):
    for group in base.keys():
        for b, t in zip(base[group], test[group]):
-            # Catch grad `None`s, etc.
-            if not torch.is_tensor(b):
-                assert b == t
-            elif b.is_floating_point():
-                assert torch.allclose(b, t)
-            else:
-                assert torch.equal(b, t)
+            _match_outputs(b, t)
+
+
+# This is distributed because checkpoint() assumes that torch.distributed is initialized.
+# torch.distributed is used with activation partitioning, but not for these simple cases.
+@distributed_test(world_size=1)
+def _test_activation_checkpoint_ordering(module, expected_ordering, *inputs):
+    # Move to device
+    module.cuda()
+
+    # Get rid of dropouts until we fork the RNG between tests.
+    module.eval()
+
+    module_ = deepcopy(module)
+    inputs_ = _prep_inputs(*inputs)
+    test = _compute(module_, *inputs_, do_checkpoint=True)
+
+    outputs = test['outputs']
+    test_ordering = []
+    for item in outputs:
+        if type(item) in [list, tuple]:
+            test_ordering += [torch.is_tensor(t) for t in item]
+        else:
+            test_ordering += [torch.is_tensor(item)]
+
+    assert expected_ordering == test_ordering


#
@@ -179,3 +212,78 @@ def test_ckpt_arg_none():
    inputs = (torch.rand(HIDDEN_DIM), None)
    inputs[0].requires_grad = True
    _test_activation_checkpoint(module, *inputs)
+
+
+class LinearNonTensorInput(torch.nn.Linear):
+    def forward(self, x, non_tensor_input):
+        return super().forward(x)
+
+
+@pytest.mark.parametrize('non_tensor_input',
+                         [None,
+                          2,
+                          True,
+                          (None, 2.5),
+                          (None, True, torch.randn(HIDDEN_DIM))])
+def test_ckpt_non_tensor_input(non_tensor_input):
+    module = LinearNonTensorInput(HIDDEN_DIM, HIDDEN_DIM)
+    inputs = torch.rand(HIDDEN_DIM)
+    inputs.requires_grad = True
+    _test_activation_checkpoint(module, inputs, non_tensor_input)
+
+
+class LinearNonTensorOutput(torch.nn.Linear):
+    def __init__(self, non_tensor_output):
+        super().__init__(HIDDEN_DIM, HIDDEN_DIM)
+        self.non_tensor_output = non_tensor_output
+
+    def forward(self, x):
+        out = super().forward(x)
+        return out, self.non_tensor_output
+
+
+@pytest.mark.parametrize('non_tensor_output',
+                         [None,
+                          2,
+                          True,
+                          (None, 2.5),
+                          (None, True, torch.randn(HIDDEN_DIM))])
+def test_ckpt_non_tensor_output(non_tensor_output):
+    module = LinearNonTensorOutput(non_tensor_output)
+    inputs = torch.rand(HIDDEN_DIM)
+    inputs.requires_grad = True
+    _test_activation_checkpoint(module, inputs)
+
+
+@pytest.mark.parametrize('non_tensor_output',
+                         [None,
+                          (torch.randn(HIDDEN_DIM), 2.5),
+                          (None, torch.randn(HIDDEN_DIM), True),
+                          (None, True, torch.randn(HIDDEN_DIM))])
+def test_ckpt_non_tensor_output_ordering(non_tensor_output):
+    module = LinearNonTensorOutput(non_tensor_output)
+    inputs = torch.rand(HIDDEN_DIM)
+    inputs.requires_grad = True
+
+    # First return is a tensor
+    ordering = [True]
+    if type(non_tensor_output) in [list, tuple]:
+        ordering += [torch.is_tensor(t) for t in non_tensor_output]
+    else:
+        ordering += [torch.is_tensor(non_tensor_output)]
+
+    _test_activation_checkpoint_ordering(module, ordering, inputs)
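
To make the expected ordering concrete for one parametrization (illustrative only; HIDDEN_DIM stands in for the value defined at the top of the test file): with non_tensor_output = (None, torch.randn(HIDDEN_DIM), True), the checkpointed LinearNonTensorOutput returns a tensor followed by that tuple, so the flattened flags become:

import torch

HIDDEN_DIM = 20  # stand-in value
non_tensor_output = (None, torch.randn(HIDDEN_DIM), True)

# First return value is the linear layer's tensor output, then the nested tuple's flags.
ordering = [True] + [torch.is_tensor(t) for t in non_tensor_output]
print(ordering)  # [True, False, True, False]
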