Unverified Commit 96944a81 authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

[PyTorch] Avoid garbage collection when capturing a CUDA Graph (#2092)



Avoid garbage collection when capturing a CUDA Graph
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
parent bc99a88d
......@@ -4,6 +4,8 @@
"""Functions for CUDA Graphs support in FP8"""
from collections.abc import Iterable
import contextlib
import gc
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
import torch
......@@ -58,6 +60,25 @@ def graph_pool_handle():
return _graph_pool_handle()
@contextlib.contextmanager
def _graph_context_wrapper(*args, **kwargs):
"""Wrapper around `torch.cuda.graph`.
This wrapper is a temporary workaround for a PyTorch bug:
automatic garbage collection can destroy a graph while another
graph is being captured, resulting in a CUDA error. See
https://github.com/pytorch/pytorch/pull/161037.
"""
gc_is_enabled = gc.isenabled()
if gc_is_enabled:
gc.disable()
with torch.cuda.graph(*args, **kwargs):
yield
if gc_is_enabled:
gc.enable()
def _make_graphed_callables(
callables: SingleOrTuple[Callable],
sample_args: SingleOrTuple[Tuple[torch.Tensor, ...]],
......@@ -445,7 +466,7 @@ def _make_graphed_callables(
args = sample_args[per_callable_fwd_idx]
kwargs = sample_kwargs[per_callable_fwd_idx]
fwd_graph = fwd_graphs[per_callable_fwd_idx]
with torch.cuda.graph(fwd_graph, pool=mempool):
with _graph_context_wrapper(fwd_graph, pool=mempool):
outputs = func(*args, **kwargs)
flatten_outputs, spec = _tree_flatten(outputs)
per_callable_static_outputs[per_callable_fwd_idx] = tuple(flatten_outputs)
......@@ -483,7 +504,7 @@ def _make_graphed_callables(
torch.empty_like(o) if o.requires_grad else None for o in static_outputs
)
if is_training:
with torch.cuda.graph(bwd_graph, pool=mempool):
with _graph_context_wrapper(bwd_graph, pool=mempool):
grad_inputs = torch.autograd.grad(
outputs=tuple(o for o in static_outputs if o.requires_grad),
inputs=tuple(i for i in static_input_surface if i.requires_grad),
......@@ -548,7 +569,7 @@ def _make_graphed_callables(
per_callable_output_unflatten_spec = []
graph_id = 0
for func, args, kwargs, fwd_graph in zip(callables, sample_args, sample_kwargs, fwd_graphs):
with torch.cuda.graph(fwd_graph, pool=mempool):
with _graph_context_wrapper(fwd_graph, pool=mempool):
outputs = func(*args, **kwargs)
graph_callables[graph_id] = func
graph_id += 1
......@@ -570,7 +591,7 @@ def _make_graphed_callables(
torch.empty_like(o) if o.requires_grad else None for o in static_outputs
)
if is_training:
with torch.cuda.graph(bwd_graph, pool=mempool):
with _graph_context_wrapper(bwd_graph, pool=mempool):
grad_inputs = torch.autograd.grad(
outputs=tuple(o for o in static_outputs if o.requires_grad),
inputs=tuple(i for i in static_input_surface if i.requires_grad),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment