Unverified Commit 262c184e authored by Robin Zhang, committed by GitHub

[PyTorch] Add reset cudagraph interface (#2367)



* reset cudagraph
Signed-off-by: Robin Zhang <robinz@nvidia.com>

* use closure instead of mutable default values
Signed-off-by: Robin Zhang <robinz@nvidia.com>

* add test
Signed-off-by: Robin Zhang <robinz@nvidia.com>

* fix test
Signed-off-by: Robin Zhang <robinz@nvidia.com>

---------
Signed-off-by: Robin Zhang <robinz@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 0ded1134
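
As orientation before the diffs: this commit attaches a reset() attribute to each graphed callable produced by make_graphed_callables, so callers can free a captured CUDA graph without rebuilding the module. A minimal usage sketch follows; the shapes, module choice, and training step are illustrative and not taken from the commit itself:

import torch
import transformer_engine.pytorch as te

model = te.Linear(128, 128)
sample_input = torch.randn(32, 128, device="cuda", requires_grad=True)

# Capture forward/backward as CUDA graphs; warmup iterations run internally.
graphed_model = te.make_graphed_callables(model, (sample_input,))

x = torch.randn(32, 128, device="cuda", requires_grad=True)
out = graphed_model(x)
out.sum().backward()

# New in this commit: drop the captured graphs once they are no longer needed.
graphed_model.reset()
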
@@ -2,7 +2,7 @@
 #
 # See LICENSE for license information.
-from typing import Iterable, List, Union
+from typing import Callable, Dict, Iterable, List, Tuple, Union
 
 import pytest
 import torch
@@ -160,6 +160,20 @@ def get_outputs(
     return values
 
 
+def reset_graphs(
+    graphed_callables: Union[Callable, Tuple[Callable, ...], Dict[Tuple[int, int], Callable]],
+) -> None:
+    """Reset CUDA graphs."""
+    if isinstance(graphed_callables, (tuple, list)):
+        for graphed_callable in graphed_callables:
+            graphed_callable.reset()
+    elif isinstance(graphed_callables, dict):
+        for graphed_callable in graphed_callables.values():
+            graphed_callable.reset()
+    else:
+        graphed_callables.reset()
+
+
 class _Sequential(torch.nn.Sequential):
     """Sequential model that forwards keyword arguments to modules"""
@@ -322,7 +336,12 @@ def _test_cuda_graphs(
             output.backward(grad_output)
             optimizer.step()
 
-    return get_outputs(model, output)
+    outputs = get_outputs(model, output)
+    if graph_mode == "full":
+        reset_graphs(model)
+    elif graph_mode == "individual":
+        reset_graphs(modules)
+    return outputs
 
 
 @pytest.mark.parametrize("module", _test_cuda_graphs_modules)
@@ -468,7 +487,10 @@ def _test_cuda_graphs_with_dot_product_attention(
         output = model(*inputs)
         output.backward(grad_output)
 
-    return get_outputs(model, output)
+    outputs = get_outputs(model, output)
+    if with_graph:
+        reset_graphs(model)
+    return outputs
 
 
 @pytest.mark.parametrize("dtype", dtypes)
@@ -553,7 +575,10 @@ def _test_cuda_graphs_with_kwargs(
         output.backward(grad_output)
         optimizer.step()
 
-    return get_outputs(model, output)
+    outputs = get_outputs(model, output)
+    if with_graph:
+        reset_graphs(model)
+    return outputs
 
 
 def test_make_graphed_callables_with_kwargs(
@@ -668,7 +693,10 @@ def _test_cuda_graphs_with_interleaved_pipeline_parallelism(
         optimizer.step()
 
     outputs = [y for _, y in sorted(outputs.items())]
-    return get_outputs(model, outputs)
+    outputs = get_outputs(model, outputs)
+    if with_graph:
+        reset_graphs(layer_forwards)
+    return outputs
 
 
 def test_make_graphed_callables_with_interleaved_pipeline_parallelism(
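
The reset_graphs test helper above covers the three container shapes the tests produce: a single graphed callable, the tuple returned when several modules are graphed at once, and a dict of graphed callables (keyed by index pairs in the interleaved pipeline-parallel test, per the Dict[Tuple[int, int], Callable] hint). A hedged sketch of the tuple case; sizes are illustrative and the helper is assumed to be in scope:

import torch
import transformer_engine.pytorch as te

modules = (te.Linear(64, 64), te.Linear(64, 64))
sample_args = tuple(
    (torch.randn(8, 64, device="cuda", requires_grad=True),)
    for _ in modules
)

# With multiple modules, make_graphed_callables returns a tuple of graphed
# callables, which is the case reset_graphs's tuple/list branch iterates over.
graphed = te.make_graphed_callables(modules, sample_args)
reset_graphs(graphed)
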
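The second file's diff refactors backward_dw inside _make_graphed_callables. Previously, per-iteration state was frozen via default argument values, which leaks that state into the function's signature (a caller could accidentally pass need_backward_dw=True); the commit moves to a factory function whose closure captures graph_idx instead. A standalone sketch of the underlying Python behavior, independent of this codebase:

# Functions defined in a loop capture variables, not values (late binding):
fns = [lambda: i for i in range(3)]
print([f() for f in fns])  # [2, 2, 2] -- every lambda sees the final i

# Default arguments freeze the value, but become overridable parameters:
fns = [lambda i=i: i for i in range(3)]
print([f() for f in fns])  # [0, 1, 2], yet fns[0](99) silently overrides

# A factory gives each function its own scope and a clean signature:
def make_fn(i):
    return lambda: i

fns = [make_fn(i) for i in range(3)]
print([f() for f in fns])  # [0, 1, 2]
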
@@ -756,6 +756,21 @@ def _make_graphed_callables(
         return functionalized
 
+    def make_graphed_attribute_functions(graph_idx):
+        # Attach backward_dw as an attribute to the graphed callable.
+        def backward_dw():
+            if need_bwd_dw_graph.get(graph_idx, False):
+                bwd_dw_graphs[graph_idx].replay()
+
+        # Attach reset as an attribute to the graphed callable.
+        def reset():
+            fwd_graphs[graph_idx].reset()
+            bwd_graphs[graph_idx].reset()
+            bwd_dw_graphs[graph_idx].reset()
+
+        return backward_dw, reset
+
     # Put together the final graphed callables
     ret = []
     for i in range(len(sample_args)):
@@ -831,15 +846,9 @@ def _make_graphed_callables(
         else:
             ret.append(graphed)
 
-        # Attach backward_dw as an attribute to the graphed callable.
-        def backward_dw(
-            need_backward_dw=need_bwd_dw_graph.get(i, False),
-            bwd_dw_graph=bwd_dw_graphs[i],
-        ):
-            if need_backward_dw:
-                bwd_dw_graph.replay()
-
-        setattr(ret[-1], "backward_dw", backward_dw)
+        backward_dw_func, reset_func = make_graphed_attribute_functions(i)
+        setattr(ret[-1], "backward_dw", backward_dw_func)
+        setattr(ret[-1], "reset", reset_func)
 
     if just_one_callable:
         return ret[0]
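For context on what the new reset() attribute ultimately does: fwd_graphs, bwd_graphs, and bwd_dw_graphs hold torch.cuda.CUDAGraph objects, and CUDAGraph.reset() deletes the graph an instance currently holds. A minimal standalone sketch of that primitive (requires a CUDA device; the captured computation is illustrative):

import torch

g = torch.cuda.CUDAGraph()
static_x = torch.randn(16, device="cuda")

# Capture a trivial computation into the graph.
with torch.cuda.graph(g):
    static_y = static_x * 2

g.replay()  # re-run the captured kernels; result lands in static_y
g.reset()   # delete the captured graph; g cannot be replayed until re-captured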