Unverified commit 324be332, authored by Robin Zhang and committed by GitHub

[PyTorch] Support cudagraph recomputation (#2518)



* replace autograd.grad with autograd.backward
Signed-off-by: Robin Zhang <robinz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* get/set graphable rng state
Signed-off-by: Robin Zhang <robinz@nvidia.com>

* fix lint
Signed-off-by: Robin Zhang <robinz@nvidia.com>

---------
Signed-off-by: Robin Zhang <robinz@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 697b52cb
@@ -90,6 +90,11 @@ def graph_safe_rng_available() -> bool:
     )
 
 
+def is_graph_safe_rng_state(state: Union[torch.Tensor, torch.Generator]) -> bool:
+    """Returns whether the rng state is a graph safe version."""
+    return graph_safe_rng_available() and isinstance(state, torch.Generator)
+
+
 def _get_cuda_rng_state(
     device: Union[int, str, torch.device] = "cuda",
     clone: bool = False,
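For context, the two RNG-state flavors that the new is_graph_safe_rng_state() helper distinguishes can be seen with plain PyTorch APIs. A minimal sketch, not part of the commit, assuming a CUDA-enabled build:

    import torch

    # ByteTensor snapshot of the CUDA RNG state: not graph safe, since
    # restoring it cannot be replayed correctly inside a CUDA graph.
    tensor_state = torch.cuda.get_rng_state()
    assert isinstance(tensor_state, torch.Tensor)

    # The torch.Generator object itself is the graph-safe flavor: it can
    # be manipulated through the graph-safe RNG APIs during graph capture.
    generator = torch.cuda.default_generators[torch.cuda.current_device()]
    assert isinstance(generator, torch.Generator)

The helper returns True only for the Generator flavor, and only when graph_safe_rng_available() reports that the running PyTorch exposes graph-safe RNG support.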
@@ -340,9 +345,16 @@ class _CheckpointFunction(torch.autograd.Function):
 
         # Copy the rng states.
         ctx.fwd_cpu_rng_state = torch.get_rng_state()
-        ctx.fwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=False)
         if get_rng_state_tracker is not None:
             ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker().get_states()
+            ctx.graph_safe_rng_state = (
+                is_graph_safe_rng_state(next(iter(ctx.fwd_cuda_rng_state_tracker.values())))
+                if ctx.fwd_cuda_rng_state_tracker
+                else False
+            )
+        else:
+            ctx.graph_safe_rng_state = False
+        ctx.fwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=ctx.graph_safe_rng_state)
 
         if context_fn is not None:
             forward_ctx, recompute_ctx = context_fn()
@@ -406,13 +418,13 @@ class _CheckpointFunction(torch.autograd.Function):
 
         # Store the current states.
         bwd_cpu_rng_state = torch.get_rng_state()
-        bwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=False)
+        bwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=ctx.graph_safe_rng_state)
         if get_rng_state_tracker is not None:
             bwd_cuda_rng_state_tracker = get_rng_state_tracker().get_states()
 
         # Set the states to what it used to be before the forward pass.
         torch.set_rng_state(ctx.fwd_cpu_rng_state)
-        _set_cuda_rng_state(ctx.fwd_cuda_rng_state, graph_safe=False)
+        _set_cuda_rng_state(ctx.fwd_cuda_rng_state, graph_safe=ctx.graph_safe_rng_state)
         if get_rng_state_tracker is not None:
             get_rng_state_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)
@@ -427,7 +439,7 @@ class _CheckpointFunction(torch.autograd.Function):
 
         # Set the states back to what it was at the start of this function.
         torch.set_rng_state(bwd_cpu_rng_state)
-        _set_cuda_rng_state(bwd_cuda_rng_state, graph_safe=False)
+        _set_cuda_rng_state(bwd_cuda_rng_state, graph_safe=ctx.graph_safe_rng_state)
         if get_rng_state_tracker is not None:
             get_rng_state_tracker().set_states(bwd_cuda_rng_state_tracker)
@@ -470,12 +482,21 @@ class _CheckpointFrame:
     def cache_rng_states(self, forward=True):
         """Cache fwd/bwd RNG states in the frame to restore later."""
-        rng_states = (
-            torch.get_rng_state(),
-            _get_cuda_rng_state(graph_safe=False),
-        )
+        rng_states = (torch.get_rng_state(),)
         if self.get_rng_state_tracker is not None:
-            rng_states += (self.get_rng_state_tracker().get_states(),)
+            tracker_states = self.get_rng_state_tracker().get_states()
+            self.graph_safe_rng_state = (
+                is_graph_safe_rng_state(next(iter(tracker_states.values())))
+                if tracker_states
+                else False
+            )
+            rng_states += (
+                _get_cuda_rng_state(graph_safe=self.graph_safe_rng_state),
+                tracker_states,
+            )
+        else:
+            self.graph_safe_rng_state = False
+            rng_states += (_get_cuda_rng_state(graph_safe=self.graph_safe_rng_state),)
 
         if forward:
             self.fwd_rng_states = rng_states
@@ -490,7 +511,7 @@ class _CheckpointFrame:
             rng_states = self.bwd_rng_states
 
         torch.set_rng_state(rng_states[0])
-        _set_cuda_rng_state(rng_states[1], graph_safe=False)
+        _set_cuda_rng_state(rng_states[1], graph_safe=self.graph_safe_rng_state)
         if self.get_rng_state_tracker is not None:
             self.get_rng_state_tracker().set_states(rng_states[2])
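Taken together, these hunks thread a single graph_safe decision, inferred from the first state held by the RNG tracker, through every CUDA RNG save and restore on the checkpoint path. A rough sketch of what the _get_cuda_rng_state/_set_cuda_rng_state pair is assumed to do with that flag (the bodies below are illustrative, not the library's actual implementation; the graphsafe_set_state/clone_state calls are the graph-safe Generator APIs in recent PyTorch):

    import torch

    def _get_cuda_rng_state_sketch(graph_safe: bool):
        gen = torch.cuda.default_generators[torch.cuda.current_device()]
        if graph_safe:
            # A cloned torch.Generator: usable while a CUDA graph records.
            return gen.clone_state()
        # A ByteTensor snapshot: fine in eager mode, unsafe during capture.
        return gen.get_state()

    def _set_cuda_rng_state_sketch(state, graph_safe: bool):
        gen = torch.cuda.default_generators[torch.cuda.current_device()]
        if graph_safe:
            gen.graphsafe_set_state(state)  # expects a torch.Generator
        else:
            gen.set_state(state)            # expects a ByteTensor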
@@ -62,6 +62,21 @@ def graph_pool_handle():
     return _graph_pool_handle()
 
 
+@contextlib.contextmanager
+def _none_grad_context_wrapper(inputs):
+    """
+    Wrapper to set the gradients of the inputs to None,
+    in case the backward pass makes grad accumulations.
+    """
+    original_input_grads = []
+    for input_tensor in inputs:
+        original_input_grads.append(input_tensor.grad)
+        input_tensor.grad = None
+    yield
+    for input_tensor, original_grad in zip(inputs, original_input_grads):
+        input_tensor.grad = original_grad
+
+
 @contextlib.contextmanager
 def _graph_context_wrapper(*args, **kwargs):
     """Wrapper around `torch.cuda.graph`.
@@ -434,13 +449,15 @@ def _make_graphed_callables(
         for hook in hooks:
             hook.remove()
 
         if is_training:
-            grad_inputs = torch.autograd.grad(
-                outputs=tuple(o for o in outputs if o.requires_grad),
-                inputs=tuple(i for i in static_input_surface if i.requires_grad),
-                grad_outputs=tuple(torch.empty_like(o) for o in outputs if o.requires_grad),
-                only_inputs=True,
-                allow_unused=allow_unused_input,
-            )
+            inputs = tuple(i for i in static_input_surface if i.requires_grad)
+            with _none_grad_context_wrapper(inputs):
+                torch.autograd.backward(
+                    tuple(o for o in outputs if o.requires_grad),
+                    grad_tensors=tuple(
+                        torch.empty_like(o) for o in outputs if o.requires_grad
+                    ),
+                )
+            grad_inputs = tuple(input.grad for input in inputs)
 
         # Filter module params that get None grad from grad_inputs and remove them
         # from static_input_surface. This is to ensure that the backward hooks
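The substitution above relies on torch.autograd.backward() depositing gradients into each input's .grad field, which, with .grad cleared first, matches what torch.autograd.grad() would have returned. A small standalone check of that equivalence, with made-up tensors:

    import torch

    x = torch.randn(3, requires_grad=True)
    y = (x ** 2).sum()

    # Direct route: gradients returned as a tuple, .grad left untouched.
    (g_direct,) = torch.autograd.grad(y, (x,), retain_graph=True)

    # Backward route: gradient lands in x.grad and is read back afterwards.
    x.grad = None  # what _none_grad_context_wrapper guarantees
    torch.autograd.backward((y,))
    assert torch.equal(g_direct, x.grad)  # both equal 2 * x

One behavioral difference: unlike autograd.grad(..., only_inputs=True), backward() accumulates into the .grad of every leaf reachable from the outputs, which is another reason the wrapper snapshots and restores the original .grad values over the whole static input surface.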
@@ -455,6 +472,14 @@ def _make_graphed_callables(
         module_params_with_grad = []
         for grad_inputs_idx, inputs_idx in enumerate(required_grad_input_idx):
             if (
+                grad_inputs[grad_inputs_idx] is None
+                and grad_inputs_idx < num_required_grad_sample_args
+            ):
+                assert allow_unused_input, (
+                    "The input tensor requires grad, but the grad is None after"
+                    " backward pass."
+                )
+            elif (
                 grad_inputs[grad_inputs_idx] is not None
                 and grad_inputs_idx >= num_required_grad_sample_args
             ):
@@ -606,15 +631,17 @@ def _make_graphed_callables(
             torch.empty_like(o) if o.requires_grad else None for o in static_outputs
         )
         if is_training:
-            with _graph_context_wrapper(bwd_graph, pool=mempool):
-                grad_inputs = torch.autograd.grad(
-                    outputs=tuple(o for o in static_outputs if o.requires_grad),
-                    inputs=tuple(i for i in static_input_surface if i.requires_grad),
-                    grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
-                    only_inputs=True,
-                    allow_unused=allow_unused_input,
-                    retain_graph=retain_graph_in_backward,
-                )
+            inputs = tuple(i for i in static_input_surface if i.requires_grad)
+            with _none_grad_context_wrapper(inputs), _graph_context_wrapper(
+                bwd_graph, pool=mempool
+            ):
+                torch.autograd.backward(
+                    tuple(o for o in static_outputs if o.requires_grad),
+                    grad_tensors=tuple(o for o in static_grad_outputs if o is not None),
+                    retain_graph=retain_graph_in_backward,
+                )
+            grad_inputs = tuple(input.grad for input in inputs)
 
         # Constructs a tuple suitable for returning from Graphed.backward:
         # Pads out the actually-needed grads with Nones in gradient slots for inputs
         # that don't require grad. I couldn't think of a one-liner for this pattern.
@@ -695,15 +722,17 @@ def _make_graphed_callables(
             torch.empty_like(o) if o.requires_grad else None for o in static_outputs
         )
         if is_training:
-            with _graph_context_wrapper(bwd_graph, pool=mempool):
-                grad_inputs = torch.autograd.grad(
-                    outputs=tuple(o for o in static_outputs if o.requires_grad),
-                    inputs=tuple(i for i in static_input_surface if i.requires_grad),
-                    grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
-                    only_inputs=True,
-                    allow_unused=allow_unused_input,
-                    retain_graph=retain_graph_in_backward,
-                )
+            inputs = tuple(i for i in static_input_surface if i.requires_grad)
+            with _none_grad_context_wrapper(inputs), _graph_context_wrapper(
+                bwd_graph, pool=mempool
+            ):
+                torch.autograd.backward(
+                    tuple(o for o in static_outputs if o.requires_grad),
+                    grad_tensors=tuple(o for o in static_grad_outputs if o is not None),
+                    retain_graph=retain_graph_in_backward,
+                )
+            grad_inputs = tuple(input.grad for input in inputs)
 
         if need_bwd_dw_graph[bwd_idx]:
             with _graph_context_wrapper(bwd_dw_graph, pool=mempool):
                 for module in visited_te_modules[bwd_idx]:
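At capture time the same pattern is recorded into bwd_graph: the .grad writes made by the AccumulateGrad nodes become part of the captured graph and are redone on every replay. A minimal standalone sketch of capturing a backward pass this way, assuming a CUDA device (toy names, not the library's code; it follows PyTorch's documented warmup-then-capture recipe):

    import torch

    x = torch.randn(8, device="cuda", requires_grad=True)
    static_grad_out = torch.ones(8, device="cuda")

    # Warmup on a side stream, as recommended before CUDA graph capture.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        y = 3 * x
        x.grad = None
        torch.autograd.backward((y,), grad_tensors=(static_grad_out,))
    torch.cuda.current_stream().wait_stream(s)

    # Capture: the backward kernels, including the write into x.grad,
    # are recorded rather than executed.
    y = 3 * x
    x.grad = None
    bwd_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(bwd_graph):
        torch.autograd.backward((y,), grad_tensors=(static_grad_out,))

    bwd_graph.replay()  # re-runs the captured backward
    assert torch.equal(x.grad, 3 * static_grad_out)

Because the gradient buffer is allocated from the graph's memory pool during capture, x.grad stays a static tensor across replays, which is what lets a graphed backward drive checkpoint recomputation.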