"...git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "2f61c40192d0cfdc825c4063c62c9a12a352ffa9"
Unverified Commit 50352325 authored by Robin Zhang's avatar Robin Zhang Committed by GitHub
Browse files

[PyTorch] Convert sample tuple to list in cudagraph input reuse (#2426)



Convert sample tuple to list in reuse
Signed-off-by: default avatarRobin Zhang <robinz@nvidia.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 811e0908
...@@ -218,9 +218,10 @@ def _make_graphed_callables( ...@@ -218,9 +218,10 @@ def _make_graphed_callables(
assert ( assert (
is_training is_training
), "`_reuse_graph_input_output_buffers` is only available in training mode." ), "`_reuse_graph_input_output_buffers` is only available in training mode."
assert isinstance( if isinstance(sample_args, tuple):
sample_args, list sample_args = list(sample_args)
), "sample_args must be a list for _reuse_graph_input_output_buffers." if isinstance(sample_kwargs, tuple):
sample_kwargs = list(sample_kwargs)
# Reorganize args and kwargs for input tensor reuse. # Reorganize args and kwargs for input tensor reuse.
# fwd_sample_qs is keyed by model chunk index. The value is a queue of tuples. # fwd_sample_qs is keyed by model chunk index. The value is a queue of tuples.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment