Unverified Commit 50352325 authored by Robin Zhang's avatar Robin Zhang Committed by GitHub
Browse files

[PyTorch] Convert sample tuple to list in cudagraph input reuse (#2426)



Convert sample tuple to list in reuse
Signed-off-by: default avatarRobin Zhang <robinz@nvidia.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 811e0908
......@@ -218,9 +218,10 @@ def _make_graphed_callables(
assert (
is_training
), "`_reuse_graph_input_output_buffers` is only available in training mode."
assert isinstance(
sample_args, list
), "sample_args must be a list for _reuse_graph_input_output_buffers."
if isinstance(sample_args, tuple):
sample_args = list(sample_args)
if isinstance(sample_kwargs, tuple):
sample_kwargs = list(sample_kwargs)
# Reorganize args and kwargs for input tensor reuse.
# fwd_sample_qs is keyed by model chunk index. The value is a queue of tuples.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment