Unverified commit c963a72a, authored by Myle Ott, committed by GitHub
Browse files

Add fairscale.nn.misc.checkpoint_activations (#376)



* Add fairscale.utils.containers
Co-authored-by: Min Xu <24926999+min-xu-ai@users.noreply.github.com>

* Add fairscale.nn.misc.checkpoint_activations
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
Co-authored-by: Min Xu <24926999+min-xu-ai@users.noreply.github.com>
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
parent e92e85ce
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import functools
from typing import Any, Dict, Optional, Tuple
import torch
from torch import Tensor
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from fairscale.utils.containers import pack_kwargs, split_non_tensors, unpack_kwargs, unpack_non_tensors
def checkpoint_wrapper(module: nn.Module, offload_to_cpu: bool = False) -> nn.Module:
    """
    Turn on activation checkpointing for *module* by replacing its ``forward``.

    Differences from the stock PyTorch checkpoint helper:

        - an nn.Module is wrapped, so every later call goes through checkpointing
        - keyword arguments to the forward are supported
        - the forward may return non-Tensor values
        - activations can optionally be offloaded to CPU

    Usage::

        checkpointed_module = checkpoint_wrapper(my_module, offload_to_cpu=True)
        a, b = checkpointed_module(x, y=3, z=torch.Tensor([1]))

    Args:
        module (nn.Module): module to wrap; it is modified in place
        offload_to_cpu (Optional, bool): whether to offload activations to CPU

    Returns:
        The same *module* object, with its ``forward`` attribute replaced.
    """
    # Bind the original forward and the offload flag up front; the remaining
    # positional/keyword arguments are supplied by the caller at call time.
    patched_forward = functools.partial(_checkpointed_forward, module.forward, offload_to_cpu)
    module.forward = patched_forward  # type: ignore
    return module
def _checkpointed_forward(original_forward: Any, offload_to_cpu: bool, *args: Any, **kwargs: Any) -> Any:
    """Run *original_forward* through :class:`CheckpointFunction`.

    Keyword arguments are flattened into positional ones before the call,
    and any non-Tensor outputs are merged back into the result afterwards.
    """
    # Autograd Functions in PyTorch work best with positional args, since
    # the backward must return gradients (or None) for every input argument.
    # Flattening the keyword arguments makes this easier.
    keys, flattened = pack_kwargs(*args, **kwargs)

    # Dict shared with CheckpointFunction: carries the offload flag in and
    # the packed non-Tensor outputs back out.
    shared: Dict[str, Any] = {"offload": offload_to_cpu}
    result = CheckpointFunction.apply(original_forward, shared, keys, *flattened)

    if isinstance(result, torch.Tensor):
        return result

    packed = shared["packed_non_tensor_outputs"]
    return unpack_non_tensors(result, packed) if packed else result
def get_rng_state() -> Dict[str, Any]:
    """Capture the current torch RNG state, plus the CUDA RNG state if CUDA is available.

    Returns:
        A dict with key ``"torch_rng_state"`` and, only when
        ``torch.cuda.is_available()`` is true, ``"cuda_rng_state"``.
    """
    snapshot: Dict[str, Any] = {}
    snapshot["torch_rng_state"] = torch.get_rng_state()
    if torch.cuda.is_available():
        snapshot["cuda_rng_state"] = torch.cuda.get_rng_state()
    return snapshot
def set_rng_state(state: Dict[str, Any]) -> None:
    """Restore RNG state from a dict shaped like the one :func:`get_rng_state` builds.

    Args:
        state: must contain ``"torch_rng_state"``; ``"cuda_rng_state"`` is
            read only when ``torch.cuda.is_available()`` is true.
    """
    cpu_state = state["torch_rng_state"]
    torch.set_rng_state(cpu_state)
    if not torch.cuda.is_available():
        return
    torch.cuda.set_rng_state(state["cuda_rng_state"])
class CheckpointFunction(torch.autograd.Function):
    """Similar to the torch version, but support non-Tensor outputs.

    The caller is expected to provide a dict (*parent_ctx_dict*) that will hold
    the non-Tensor outputs. These should be combined with the Tensor *outputs*
    by calling :func:`unpack_non_tensors`.
    """

    @staticmethod
    def forward(  # type: ignore
        ctx: Any,
        run_function: Any,
        parent_ctx_dict: Dict[str, Any],
        kwarg_keys: Tuple[str, ...],
        *args: Any,
        **kwargs: Any
    ) -> Any:
        """Run *run_function* without recording a graph and stash what backward needs.

        Args:
            ctx: autograd context; receives the function, kwarg keys, RNG
                snapshot, saved tensors and packed non-Tensor inputs.
            run_function: callable to execute (and to re-execute in backward).
            parent_ctx_dict: read for the ``"offload"`` flag; written with
                ``"packed_non_tensor_outputs"`` when the output is not a
                single Tensor.
            kwarg_keys: names of the trailing entries of *args* that are
                passed to *run_function* as keyword arguments.
            *args: flattened positional + keyword argument values.
            **kwargs: accepted but never read by this method.

        Returns:
            The output Tensor itself, or the Tensor-only outputs as produced
            by ``split_non_tensors``.
        """
        if torch.is_grad_enabled():  # grad may be disabled, e.g., during validation
            checkpoint.check_backward_validity(args)

        ctx.run_function = run_function
        ctx.kwarg_keys = kwarg_keys
        # Snapshot the RNG so the recomputation in backward replays the same
        # random draws as this forward pass.
        ctx.fwd_rng_state = get_rng_state()

        tensor_inputs, packed_non_tensor_inputs = split_non_tensors(args)
        if parent_ctx_dict["offload"]:
            # Remember where each tensor lived and whether it needed grad,
            # then keep only a CPU copy for backward.
            ctx.fwd_device = tuple(x.device for x in tensor_inputs)
            ctx.grad_requirements = tuple(x.requires_grad for x in tensor_inputs)
            tensor_inputs = tuple(x.cpu() for x in tensor_inputs)
        else:
            ctx.fwd_device, ctx.grad_requirements = None, None

        ctx.save_for_backward(*tensor_inputs)
        ctx.packed_non_tensor_inputs = packed_non_tensor_inputs

        with torch.no_grad():
            unpacked_args, unpacked_kwargs = unpack_kwargs(kwarg_keys, args)
            outputs = run_function(*unpacked_args, **unpacked_kwargs)

        if isinstance(outputs, torch.Tensor):
            return outputs

        # Autograd Functions don't like non-Tensor outputs. We can split the
        # non-Tensor and Tensor outputs, returning the former by reference
        # through *parent_ctx_dict* and returning the latter directly.
        outputs, packed_non_tensor_outputs = split_non_tensors(outputs)
        parent_ctx_dict["packed_non_tensor_outputs"] = packed_non_tensor_outputs
        return outputs

    @staticmethod
    def backward(ctx: Any, *args: Any) -> Tuple[Optional[Tensor], ...]:
        """Re-run the forward with grad enabled and backpropagate through it.

        Args:
            ctx: autograd context populated by :meth:`forward`.
            *args: incoming gradients, one per Tensor output of forward.

        Returns:
            ``(None, None, None)`` for ``run_function``, ``parent_ctx_dict``
            and ``kwarg_keys``, followed by one entry per flattened input:
            its ``.grad`` if it is a Tensor, otherwise ``None``.

        Raises:
            RuntimeError: if ``torch.autograd._is_checkpoint_valid()`` is
                false, or if no recomputed output requires grad.
        """
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible")

        tensor_inputs: Tuple = ctx.saved_tensors
        tensor_inputs = checkpoint.detach_variable(tensor_inputs)
        if ctx.fwd_device is not None:
            # Offloaded inputs: move each one back to its original device and
            # re-apply the requires_grad flag recorded in forward.
            tensor_inputs = tuple(t.to(device) for t, device in zip(tensor_inputs, ctx.fwd_device))
            for t, need_grad in zip(tensor_inputs, ctx.grad_requirements):
                t.requires_grad = need_grad
        inputs = unpack_non_tensors(tensor_inputs, ctx.packed_non_tensor_inputs)

        # Store the current states.
        bwd_rng_state = get_rng_state()
        try:
            # Set the states to what it used to be before the forward pass.
            set_rng_state(ctx.fwd_rng_state)
            with torch.enable_grad():
                unpacked_args, unpacked_kwargs = unpack_kwargs(ctx.kwarg_keys, inputs)
                outputs = ctx.run_function(*unpacked_args, **unpacked_kwargs)
                tensor_outputs, _ = split_non_tensors(outputs)
        finally:
            # Set the states back to what it was at the start of this
            # function, even if the recomputation raised; otherwise the
            # global RNG would be left rewound to the forward-pass state.
            set_rng_state(bwd_rng_state)

        # Run backward() with only Tensors that require grad
        outputs_with_grad = []
        args_with_grad = []
        for i, out in enumerate(tensor_outputs):
            if out.requires_grad:
                outputs_with_grad.append(out)
                args_with_grad.append(args[i])
        if not outputs_with_grad:
            raise RuntimeError("None of the outputs have requires_grad=True, " "this checkpoint() is not necessary")

        torch.autograd.backward(outputs_with_grad, args_with_grad)

        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in inputs)
        return (None, None, None) + grads
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
import torch
"""Useful functions to deal with tensor types with other python container types."""
def apply_to_tensors(fn: Callable, container: Union[torch.Tensor, Dict, List, Tuple, Set]) -> Any:
    """Recursively apply *fn* to every tensor inside dict/list/tuple/set containers.

    Dict values, list items, tuple items and set members are visited
    recursively; each container is rebuilt as a new dict/list/tuple/set.
    Anything that is neither a tensor nor one of those containers is
    returned unchanged.
    """

    def _visit(node: Union[torch.Tensor, Dict, List, Tuple, Set]) -> Any:
        if torch.is_tensor(node):
            return fn(node)
        if isinstance(node, dict):
            return {key: _visit(value) for key, value in node.items()}
        if isinstance(node, list):
            return [_visit(item) for item in node]
        if isinstance(node, tuple):
            return tuple(_visit(item) for item in node)
        if isinstance(node, set):
            return {_visit(item) for item in node}
        return node

    return _visit(container)
def pack_kwargs(*args: Any, **kwargs: Any) -> Tuple[Tuple[str, ...], Tuple[Any, ...]]:
    """
    Turn argument list into separate key list and value list (unpack_kwargs does the opposite)

    The keyword values are appended after the positional values, in the
    order the keywords were passed.

    Usage::

        kwarg_keys, flat_args = pack_kwargs(1, 2, a=3, b=4)
        assert kwarg_keys == ("a", "b")
        assert flat_args == (1, 2, 3, 4)
        args, kwargs = unpack_kwargs(kwarg_keys, flat_args)
        assert args == (1, 2)
        assert kwargs == {"a": 3, "b": 4}

    Returns:
        ``(kwarg_keys, flat_args)`` as two tuples.
    """
    # *args already arrives as a tuple and dicts iterate in insertion order,
    # so both results can be built directly without an append loop.
    return tuple(kwargs), args + tuple(kwargs.values())
def unpack_kwargs(kwarg_keys: Tuple[str, ...], flat_args: Tuple[Any, ...]) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """Split *flat_args* back into positional args and kwargs. See pack_kwargs.

    The last ``len(kwarg_keys)`` entries of *flat_args* are paired with
    *kwarg_keys*; everything before them is returned as positional args.

    Raises:
        AssertionError: if there are more keys than flattened values.
    """
    num_keys = len(kwarg_keys)
    # Raised explicitly (rather than with ``assert``) so the check still runs
    # under ``python -O``; without it, surplus keys would be silently dropped.
    if num_keys > len(flat_args):
        raise AssertionError(f"too many keys {len(kwarg_keys)} vs. {len(flat_args)}")
    if num_keys == 0:
        return flat_args, {}
    split_at = len(flat_args) - num_keys
    return flat_args[:split_at], dict(zip(kwarg_keys, flat_args[split_at:]))
def split_non_tensors(
    mixed: Union[torch.Tensor, Tuple[Any, ...]]
) -> Tuple[Tuple[torch.Tensor, ...], Optional[Dict[str, List[Any]]]]:
    """
    Separate the tensors in *mixed* from everything else, keeping enough
    bookkeeping to rebuild the original sequence later.

    A bare Tensor is returned as a one-element tuple with ``None`` as the
    bookkeeping value.

    Usage::

        x = torch.Tensor([1])
        y = torch.Tensor([2])
        tensors, packed_non_tensors = split_non_tensors((x, y, None, 3))
        assert tensors == (x, y)
        assert packed_non_tensors == {
            "is_tensor": [True, True, False, False],
            "objects": [None, 3],
        }
        recon = unpack_non_tensors(tensors, packed_non_tensors)
        assert recon == (x, y, None, 3)
    """
    if isinstance(mixed, torch.Tensor):
        return (mixed,), None

    found_tensors: List[torch.Tensor] = []
    flags: List[Any] = []
    leftovers: List[Any] = []
    for item in mixed:
        item_is_tensor = isinstance(item, torch.Tensor)
        flags.append(item_is_tensor)
        # Route the item to whichever bucket matches its kind.
        (found_tensors if item_is_tensor else leftovers).append(item)

    packed: Dict[str, List[Any]] = {"is_tensor": flags, "objects": leftovers}
    return tuple(found_tensors), packed
def unpack_non_tensors(
    tensors: Tuple[torch.Tensor, ...], packed_non_tensors: Optional[Dict[str, List[Any]]]
) -> Tuple[Any, ...]:
    """Interleave *tensors* with the packed non-Tensor objects. See split_non_tensors.

    Args:
        tensors: the Tensor-only tuple.
        packed_non_tensors: ``None`` (in which case *tensors* is returned
            as-is) or a dict with ``"is_tensor"`` flags and ``"objects"``.

    Returns:
        A tuple in the order described by the ``"is_tensor"`` flags.

    Raises:
        AssertionError: if *packed_non_tensors* is neither ``None`` nor a
            dict, or if the lengths of tensors + objects do not add up to
            the number of flags.
    """
    if packed_non_tensors is None:
        return tensors

    # These checks are raised explicitly (rather than with ``assert``) so they
    # still run under ``python -O``; otherwise a length mismatch could
    # silently produce a truncated result.
    if not isinstance(packed_non_tensors, dict):
        raise AssertionError(type(packed_non_tensors))

    is_tensor_list = packed_non_tensors["is_tensor"]
    objects = packed_non_tensors["objects"]
    if len(tensors) + len(objects) != len(is_tensor_list):
        raise AssertionError(
            f"len(tensors) {len(tensors)} len(objects) {len(objects)} " f"len(is_tensor_list) {len(is_tensor_list)}"
        )

    mixed: List[Any] = []
    obj_i = tnsr_i = 0
    for is_tensor in is_tensor_list:
        if is_tensor:
            mixed.append(tensors[tnsr_i])
            tnsr_i += 1
        else:
            mixed.append(objects[obj_i])
            obj_i += 1
    return tuple(mixed)
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

# Type stub for the parts of torch.utils.checkpoint used by this project.
from typing import Any, Iterable, Tuple

from .. import Tensor
from torch.nn.modules.module import Module

def detach_variable(inputs: Tuple[Tensor,...]) -> Tuple[Tensor,...]: ...
def checkpoint(function: Module, *args, **kwargs): ...
def check_backward_validity(inputs: Iterable[Any]): ...
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""
Test fairscale.nn.misc.checkpoint_activations
"""
import unittest
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from fairscale.nn.misc.checkpoint_activations import checkpoint_wrapper
class Model(nn.Module):
    """Small feed-forward model used to compare checkpointing variants.

    The constructor seeds torch with 0 before creating any layer, so every
    instance starts from the same initial parameters.
    """

    def __init__(self, use_pytorch_checkpoint=False, use_fairseq_checkpoint=False, **kwargs):
        super().__init__()
        torch.manual_seed(0)
        self.use_pytorch_checkpoint = use_pytorch_checkpoint

        # Layers are created in this exact order so the seeded initialisation
        # is the same for every variant.
        expand = nn.Linear(32, 128)
        # add a Dropout layer to test RNG save/restore
        dropout = nn.Dropout(p=0.5)
        contract = nn.Linear(128, 32)
        ffn = nn.Sequential(expand, dropout, contract)

        # Extra keyword arguments are forwarded to checkpoint_wrapper.
        self.ffn = checkpoint_wrapper(ffn, **kwargs) if use_fairseq_checkpoint else ffn
        self.out = nn.Linear(32, 1)

    def forward(self, x):
        """Apply the FFN (through torch's checkpoint if requested), then the output layer."""
        hidden = checkpoint(self.ffn, x) if self.use_pytorch_checkpoint else self.ffn(x)
        return self.out(hidden)
class TestComparisonToPyTorch(unittest.TestCase):
    """Compare loss and gradient norm across checkpointing variants of Model."""

    def _test_checkpoint_wrapper(self, device, log_memory_usage=False):
        """Check that every checkpointing variant matches the un-checkpointed baseline on *device*."""

        def get_loss_and_gnorm(model):
            # Same seed for every model, so the input and dropout draws match.
            torch.manual_seed(1)
            batch = torch.rand(2, 16, 32).requires_grad_(True).to(device)
            model.zero_grad()
            loss = model(batch).sum()
            loss.backward()
            per_param_norms = [torch.norm(p.grad.detach()) for p in model.parameters()]
            gnorm = torch.norm(torch.stack(per_param_norms))
            return {"loss": loss, "gnorm": gnorm}

        def evaluate(**model_kwargs):
            return get_loss_and_gnorm(Model(**model_kwargs).to(device))

        no_cpt = evaluate()

        variants = (
            {"use_pytorch_checkpoint": True},
            {"use_fairseq_checkpoint": True},
            {"use_fairseq_checkpoint": True, "offload_to_cpu": True},
        )
        for model_kwargs in variants:
            candidate = evaluate(**model_kwargs)
            torch.testing.assert_allclose(no_cpt["loss"], candidate["loss"])
            torch.testing.assert_allclose(no_cpt["gnorm"], candidate["gnorm"])

    def test_checkpoint_wrapper_cpu(self):
        """Run the comparison on CPU."""
        self._test_checkpoint_wrapper(device=torch.device("cpu"))

    @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
    def test_checkpoint_wrapper_cuda(self):
        """Run the comparison on CUDA (skipped when no GPU is available)."""
        self._test_checkpoint_wrapper(device=torch.device("cuda"))
# Run the unittest runner when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# pylint: disable=missing-module-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring
""" Test utility classes from containers.py. """
import random
import pytest
import torch
from fairscale.utils.containers import (
apply_to_tensors,
pack_kwargs,
split_non_tensors,
unpack_kwargs,
unpack_non_tensors,
)
@pytest.mark.parametrize("devices", [["cpu"], ["cuda"], ["cpu", "cuda"]])
def test_apply_to_tensors(devices):
    """Test apply_to_tensors for both cpu & gpu"""
    # Skip only when a CUDA device was requested and none is usable. The
    # parentheses matter: without them `and` binds tighter than `or`, and the
    # CPU-only case would also be skipped on machines without a GPU.
    if "cuda" in devices and (not torch.cuda.is_available() or torch.cuda.device_count() < 1):
        pytest.skip("Skipped due to lack of GPU")
    expected = 0

    def get_a_tensor():
        """Return a random tensor on random device."""
        nonlocal expected
        dev = random.choice(devices)
        shape = random.choice(((1,), (2, 3), (4, 5, 6), (7, 8, 9, 10)))
        t = torch.rand(shape).to(dev)
        # Track how many elements were created so the callback total can be
        # checked against it.
        expected += t.numel()
        return t

    # create a mixed bag of data.
    data = [1, "str"]
    data.append({"key1": get_a_tensor(), "key2": {1: get_a_tensor()}, "key3": 3})
    data.insert(0, {"x", get_a_tensor(), get_a_tensor()})
    data.append(([1], get_a_tensor(), 1, [get_a_tensor()], {1, 2}))

    total = 0

    def fn(t):
        nonlocal total
        total += t.numel()
        return t

    apply_to_tensors(fn, data)
    assert total == expected, f"{total} vs. {expected}"
def test_pack_unpack():
    """Exercise pack_kwargs and unpack_kwargs, including the round trip."""
    # Positional arguments only: no keys are produced.
    keys, values = pack_kwargs(1, 2, 3, 4)
    assert keys == ()
    assert values == (1, 2, 3, 4)

    # Keyword arguments only, with assorted container values.
    keys, values = pack_kwargs(a=1, b={2: "2"}, c={3}, d=[4], e=(5,))
    assert keys == ("a", "b", "c", "d", "e")
    assert values == (1, {2: "2"}, {3}, [4], (5,))

    # Mixed positional and keyword arguments.
    keys, values = pack_kwargs(1, 2, a=3, b=4)
    assert keys == ("a", "b")
    assert values == (1, 2, 3, 4)

    # Round trip with the keys that were just packed.
    positional, named = unpack_kwargs(keys, values)
    assert positional == (1, 2)
    assert named == {"a": 3, "b": 4}

    # No keys: everything stays positional.
    positional, named = unpack_kwargs([], values)
    assert named == {}
    assert positional == (1, 2, 3, 4)

    # As many keys as values: everything becomes a keyword argument.
    positional, named = unpack_kwargs(["a", "b", "c", "d"], values)
    assert named == {"a": 1, "b": 2, "c": 3, "d": 4}
    assert positional == ()

    with pytest.raises(AssertionError):
        # too many keys should assert.
        positional, named = unpack_kwargs(["a", "b", "c", "d", "e"], values)
def test_split_unpack():
    """Exercise split_non_tensors and unpack_non_tensors, including the round trip."""
    x = torch.Tensor([1])
    y = torch.Tensor([2])

    # Check the exact packed representation for one mixed input.
    tensors, packed = split_non_tensors((x, y, None, 3))
    assert tensors == (x, y)
    assert packed == {
        "is_tensor": [True, True, False, False],
        "objects": [None, 3],
    }
    assert unpack_non_tensors(tensors, packed) == (x, y, None, 3)

    # Round trips: non-tensors first, non-tensors only, tensors only.
    for original in ((None, 3, x, y), (None, 3), (x, y)):
        tensors, packed = split_non_tensors(original)
        assert unpack_non_tensors(tensors, packed) == original

    # `tensors` is still (x, y) from the last round trip; with no packed
    # information it is returned unchanged.
    assert unpack_non_tensors(tensors, None) == (x, y)

    with pytest.raises(AssertionError):
        # assert the second arg should be a dict.
        unpack_non_tensors(tensors, set())

    with pytest.raises(AssertionError):
        # assert the content of the second arg should be sane.
        unpack_non_tensors(tensors, {"is_tensor": [], "objects": []})
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment