Unverified Commit e950ceb0 authored by buptzyb's avatar buptzyb Committed by GitHub
Browse files

[PyTorch] Optimize cudagraph static_grad_outputs reuse (#1992)



* optimize static grad outputs
Signed-off-by: default avatarRobin Zhang <robinz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarRobin Zhang <robinz@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
parent 25a82192
......@@ -422,7 +422,7 @@ def _make_graphed_callables(
per_callable_static_grad_inputs = [None] * len(flatten_sample_args)
fwd_idx = [0] * num_model_chunks
bwd_idx = [0] * num_model_chunks
static_grad_outputs = None
static_grad_outputs_dict = {}
previous_per_callable_bwd_idx = None
for c_id in _order:
if c_id > 0:
......@@ -454,9 +454,21 @@ def _make_graphed_callables(
static_outputs = per_callable_static_outputs[per_callable_bwd_idx]
bwd_graph = bwd_graphs[per_callable_bwd_idx]
# For now, assumes all static_outputs require grad
if not _reuse_graph_input_output_buffers or static_grad_outputs is None:
if _reuse_graph_input_output_buffers:
# Note for _reuse_graph_input_output_buffers: grad output is only used
# within backward, so we can reuse the same static buffers every time.
static_grad_outputs_keys = tuple(
(o.shape, o.dtype, o.layout) for o in static_outputs if o.requires_grad
)
if static_grad_outputs_keys in static_grad_outputs_dict:
static_grad_outputs = static_grad_outputs_dict[static_grad_outputs_keys]
else:
static_grad_outputs = tuple(
torch.empty_like(o) if o.requires_grad else None
for o in static_outputs
)
static_grad_outputs_dict[static_grad_outputs_keys] = static_grad_outputs
else:
static_grad_outputs = tuple(
torch.empty_like(o) if o.requires_grad else None for o in static_outputs
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment