Unverified Commit 324be332 authored by Robin Zhang's avatar Robin Zhang Committed by GitHub
Browse files

[PyTorch] Support cudagraph recomputation (#2518)



* replace autograd.grad with autograd.backward
Signed-off-by: default avatarRobin Zhang <robinz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* get/set graphable rng state
Signed-off-by: default avatarRobin Zhang <robinz@nvidia.com>

* fix lint
Signed-off-by: default avatarRobin Zhang <robinz@nvidia.com>

---------
Signed-off-by: default avatarRobin Zhang <robinz@nvidia.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 697b52cb
......@@ -90,6 +90,11 @@ def graph_safe_rng_available() -> bool:
)
def is_graph_safe_rng_state(state: Union[torch.Tensor, torch.Generator]) -> bool:
"""Returns whether the rng state is a graph safe version."""
return graph_safe_rng_available() and isinstance(state, torch.Generator)
def _get_cuda_rng_state(
device: Union[int, str, torch.device] = "cuda",
clone: bool = False,
......@@ -340,9 +345,16 @@ class _CheckpointFunction(torch.autograd.Function):
# Copy the rng states.
ctx.fwd_cpu_rng_state = torch.get_rng_state()
ctx.fwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=False)
if get_rng_state_tracker is not None:
ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker().get_states()
ctx.graph_safe_rng_state = (
is_graph_safe_rng_state(next(iter(ctx.fwd_cuda_rng_state_tracker.values())))
if ctx.fwd_cuda_rng_state_tracker
else False
)
else:
ctx.graph_safe_rng_state = False
ctx.fwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=ctx.graph_safe_rng_state)
if context_fn is not None:
forward_ctx, recompute_ctx = context_fn()
......@@ -406,13 +418,13 @@ class _CheckpointFunction(torch.autograd.Function):
# Store the current states.
bwd_cpu_rng_state = torch.get_rng_state()
bwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=False)
bwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=ctx.graph_safe_rng_state)
if get_rng_state_tracker is not None:
bwd_cuda_rng_state_tracker = get_rng_state_tracker().get_states()
# Set the states to what it used to be before the forward pass.
torch.set_rng_state(ctx.fwd_cpu_rng_state)
_set_cuda_rng_state(ctx.fwd_cuda_rng_state, graph_safe=False)
_set_cuda_rng_state(ctx.fwd_cuda_rng_state, graph_safe=ctx.graph_safe_rng_state)
if get_rng_state_tracker is not None:
get_rng_state_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)
......@@ -427,7 +439,7 @@ class _CheckpointFunction(torch.autograd.Function):
# Set the states back to what it was at the start of this function.
torch.set_rng_state(bwd_cpu_rng_state)
_set_cuda_rng_state(bwd_cuda_rng_state, graph_safe=False)
_set_cuda_rng_state(bwd_cuda_rng_state, graph_safe=ctx.graph_safe_rng_state)
if get_rng_state_tracker is not None:
get_rng_state_tracker().set_states(bwd_cuda_rng_state_tracker)
......@@ -470,12 +482,21 @@ class _CheckpointFrame:
def cache_rng_states(self, forward=True):
"""Cache fwd/bwd RNG states in the frame to restore later."""
rng_states = (
torch.get_rng_state(),
_get_cuda_rng_state(graph_safe=False),
)
rng_states = (torch.get_rng_state(),)
if self.get_rng_state_tracker is not None:
rng_states += (self.get_rng_state_tracker().get_states(),)
tracker_states = self.get_rng_state_tracker().get_states()
self.graph_safe_rng_state = (
is_graph_safe_rng_state(next(iter(tracker_states.values())))
if tracker_states
else False
)
rng_states += (
_get_cuda_rng_state(graph_safe=self.graph_safe_rng_state),
tracker_states,
)
else:
self.graph_safe_rng_state = False
rng_states += (_get_cuda_rng_state(graph_safe=self.graph_safe_rng_state),)
if forward:
self.fwd_rng_states = rng_states
......@@ -490,7 +511,7 @@ class _CheckpointFrame:
rng_states = self.bwd_rng_states
torch.set_rng_state(rng_states[0])
_set_cuda_rng_state(rng_states[1], graph_safe=False)
_set_cuda_rng_state(rng_states[1], graph_safe=self.graph_safe_rng_state)
if self.get_rng_state_tracker is not None:
self.get_rng_state_tracker().set_states(rng_states[2])
......
......@@ -62,6 +62,21 @@ def graph_pool_handle():
return _graph_pool_handle()
@contextlib.contextmanager
def _none_grad_context_wrapper(inputs):
"""
Wrapper to set the gradients of the inputs to None,
in case the backward pass makes grad accumulations.
"""
original_input_grads = []
for input_tensor in inputs:
original_input_grads.append(input_tensor.grad)
input_tensor.grad = None
yield
for input_tensor, original_grad in zip(inputs, original_input_grads):
input_tensor.grad = original_grad
@contextlib.contextmanager
def _graph_context_wrapper(*args, **kwargs):
"""Wrapper around `torch.cuda.graph`.
......@@ -434,13 +449,15 @@ def _make_graphed_callables(
for hook in hooks:
hook.remove()
if is_training:
grad_inputs = torch.autograd.grad(
outputs=tuple(o for o in outputs if o.requires_grad),
inputs=tuple(i for i in static_input_surface if i.requires_grad),
grad_outputs=tuple(torch.empty_like(o) for o in outputs if o.requires_grad),
only_inputs=True,
allow_unused=allow_unused_input,
)
inputs = tuple(i for i in static_input_surface if i.requires_grad)
with _none_grad_context_wrapper(inputs):
torch.autograd.backward(
tuple(o for o in outputs if o.requires_grad),
grad_tensors=tuple(
torch.empty_like(o) for o in outputs if o.requires_grad
),
)
grad_inputs = tuple(input.grad for input in inputs)
# Filter module params that get None grad from grad_inputs and remove them
# from static_input_surface. This is to ensure that the backward hooks
......@@ -455,6 +472,14 @@ def _make_graphed_callables(
module_params_with_grad = []
for grad_inputs_idx, inputs_idx in enumerate(required_grad_input_idx):
if (
grad_inputs[grad_inputs_idx] is None
and grad_inputs_idx < num_required_grad_sample_args
):
assert allow_unused_input, (
"The input tensor requires grad, but the grad is None after"
" backward pass."
)
elif (
grad_inputs[grad_inputs_idx] is not None
and grad_inputs_idx >= num_required_grad_sample_args
):
......@@ -606,15 +631,17 @@ def _make_graphed_callables(
torch.empty_like(o) if o.requires_grad else None for o in static_outputs
)
if is_training:
with _graph_context_wrapper(bwd_graph, pool=mempool):
grad_inputs = torch.autograd.grad(
outputs=tuple(o for o in static_outputs if o.requires_grad),
inputs=tuple(i for i in static_input_surface if i.requires_grad),
grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
only_inputs=True,
allow_unused=allow_unused_input,
inputs = tuple(i for i in static_input_surface if i.requires_grad)
with _none_grad_context_wrapper(inputs), _graph_context_wrapper(
bwd_graph, pool=mempool
):
torch.autograd.backward(
tuple(o for o in static_outputs if o.requires_grad),
grad_tensors=tuple(o for o in static_grad_outputs if o is not None),
retain_graph=retain_graph_in_backward,
)
grad_inputs = tuple(input.grad for input in inputs)
# Constructs a tuple suitable for returning from Graphed.backward:
# Pads out the actually-needed grads with Nones in gradient slots for inputs
# that don't require grad. I couldn't think of a one-liner for this pattern.
......@@ -695,15 +722,17 @@ def _make_graphed_callables(
torch.empty_like(o) if o.requires_grad else None for o in static_outputs
)
if is_training:
with _graph_context_wrapper(bwd_graph, pool=mempool):
grad_inputs = torch.autograd.grad(
outputs=tuple(o for o in static_outputs if o.requires_grad),
inputs=tuple(i for i in static_input_surface if i.requires_grad),
grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
only_inputs=True,
allow_unused=allow_unused_input,
inputs = tuple(i for i in static_input_surface if i.requires_grad)
with _none_grad_context_wrapper(inputs), _graph_context_wrapper(
bwd_graph, pool=mempool
):
torch.autograd.backward(
tuple(o for o in static_outputs if o.requires_grad),
grad_tensors=tuple(o for o in static_grad_outputs if o is not None),
retain_graph=retain_graph_in_backward,
)
grad_inputs = tuple(input.grad for input in inputs)
if need_bwd_dw_graph[bwd_idx]:
with _graph_context_wrapper(bwd_dw_graph, pool=mempool):
for module in visited_te_modules[bwd_idx]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment