"vscode:/vscode.git/clone" did not exist on "8b430d7dea5695324636fc458c1cce52213bd499"
Unverified Commit 262c184e authored by Robin Zhang's avatar Robin Zhang Committed by GitHub
Browse files

[PyTorch] Add reset cudagraph interface (#2367)



* reset cudagraph
Signed-off-by: default avatarRobin Zhang <robinz@nvidia.com>

* use closure instead of mutable default values
Signed-off-by: default avatarRobin Zhang <robinz@nvidia.com>

* add test
Signed-off-by: default avatarRobin Zhang <robinz@nvidia.com>

* fix test
Signed-off-by: default avatarRobin Zhang <robinz@nvidia.com>

---------
Signed-off-by: default avatarRobin Zhang <robinz@nvidia.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 0ded1134
......@@ -2,7 +2,7 @@
#
# See LICENSE for license information.
from typing import Iterable, List, Union
from typing import Callable, Dict, Iterable, List, Tuple, Union
import pytest
import torch
......@@ -160,6 +160,20 @@ def get_outputs(
return values
def reset_graphs(
    graphed_callables: Union[
        Callable,
        List[Callable],
        Tuple[Callable, ...],
        Dict[Tuple[int, int], Callable],
    ],
) -> None:
    """Reset the CUDA graphs held by one or more graphed callables.

    Accepts a single graphed callable, a list or tuple of them, or a dict
    of them (keyed by ``(int, int)`` tuples per the annotation — presumably
    pipeline (chunk, layer) indices; verify against the caller), and calls
    ``.reset()`` on each one so the captured graphs can be re-captured.
    """
    # Single isinstance call with a tuple of types instead of an `or` chain;
    # list is included because callers may pass a plain list of callables.
    if isinstance(graphed_callables, (tuple, list)):
        for graphed_callable in graphed_callables:
            graphed_callable.reset()
    elif isinstance(graphed_callables, dict):
        for graphed_callable in graphed_callables.values():
            graphed_callable.reset()
    else:
        # Assume a single graphed callable.
        graphed_callables.reset()
class _Sequential(torch.nn.Sequential):
"""Sequential model that forwards keyword arguments to modules"""
......@@ -322,7 +336,12 @@ def _test_cuda_graphs(
output.backward(grad_output)
optimizer.step()
return get_outputs(model, output)
outputs = get_outputs(model, output)
if graph_mode == "full":
reset_graphs(model)
elif graph_mode == "individual":
reset_graphs(modules)
return outputs
@pytest.mark.parametrize("module", _test_cuda_graphs_modules)
......@@ -468,7 +487,10 @@ def _test_cuda_graphs_with_dot_product_attention(
output = model(*inputs)
output.backward(grad_output)
return get_outputs(model, output)
outputs = get_outputs(model, output)
if with_graph:
reset_graphs(model)
return outputs
@pytest.mark.parametrize("dtype", dtypes)
......@@ -553,7 +575,10 @@ def _test_cuda_graphs_with_kwargs(
output.backward(grad_output)
optimizer.step()
return get_outputs(model, output)
outputs = get_outputs(model, output)
if with_graph:
reset_graphs(model)
return outputs
def test_make_graphed_callables_with_kwargs(
......@@ -668,7 +693,10 @@ def _test_cuda_graphs_with_interleaved_pipeline_parallelism(
optimizer.step()
outputs = [y for _, y in sorted(outputs.items())]
return get_outputs(model, outputs)
outputs = get_outputs(model, outputs)
if with_graph:
reset_graphs(layer_forwards)
return outputs
def test_make_graphed_callables_with_interleaved_pipeline_parallelism(
......
......@@ -756,6 +756,21 @@ def _make_graphed_callables(
return functionalized
def make_graphed_attribute_functions(graph_idx):
    """Build the ``backward_dw``/``reset`` attribute functions for graph ``graph_idx``.

    Defined as a closure over ``graph_idx`` (rather than mutable default
    arguments) so each graphed callable gets functions bound to its own
    graph index.
    """

    def backward_dw():
        # Replay the weight-gradient graph only if one was captured for
        # this index; missing entries mean there is nothing to replay.
        if need_bwd_dw_graph.get(graph_idx, False):
            bwd_dw_graphs[graph_idx].replay()

    def reset():
        # Reset forward, backward, and weight-gradient graphs alike.
        for graph_group in (fwd_graphs, bwd_graphs, bwd_dw_graphs):
            graph_group[graph_idx].reset()

    return backward_dw, reset
# Put together the final graphed callables
ret = []
for i in range(len(sample_args)):
......@@ -831,15 +846,9 @@ def _make_graphed_callables(
else:
ret.append(graphed)
# Attach backward_dw as an attribute to the graphed callable.
def backward_dw(
need_backward_dw=need_bwd_dw_graph.get(i, False),
bwd_dw_graph=bwd_dw_graphs[i],
):
if need_backward_dw:
bwd_dw_graph.replay()
setattr(ret[-1], "backward_dw", backward_dw)
backward_dw_func, reset_func = make_graphed_attribute_functions(i)
setattr(ret[-1], "backward_dw", backward_dw_func)
setattr(ret[-1], "reset", reset_func)
if just_one_callable:
return ret[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment