Unverified Commit 8ebb47e5 authored by Lifu Zhang's avatar Lifu Zhang Committed by GitHub
Browse files

Fix on TE to support Mcore Vision Encoder CUDA Graph (#2657)



* Fix on TE to support Mcore Vision Encoder CUDA Graph
Signed-off-by: Lifu Zhang <lifuz@login-lyris02.lyris.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* refactoring code
Signed-off-by: Lifu Zhang <lifuz@login-lyris02.lyris.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Lifu Zhang <lifuz@login-lyris02.lyris.clusters.nvidia.com>
Co-authored-by: Lifu Zhang <lifuz@login-lyris02.lyris.clusters.nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 8d152585
......@@ -451,11 +451,12 @@ def _make_graphed_callables(
if is_training:
inputs = tuple(i for i in static_input_surface if i.requires_grad)
with _none_grad_context_wrapper(inputs):
outputs_requiring_grad = tuple(
o for o in outputs if o is not None and o.requires_grad
)
torch.autograd.backward(
tuple(o for o in outputs if o.requires_grad),
grad_tensors=tuple(
torch.empty_like(o) for o in outputs if o.requires_grad
),
outputs_requiring_grad,
grad_tensors=tuple(torch.empty_like(o) for o in outputs_requiring_grad),
)
grad_inputs = tuple(input.grad for input in inputs)
......@@ -616,19 +617,22 @@ def _make_graphed_callables(
# Note for _reuse_graph_input_output_buffers: grad output is only used
# within backward, so we can reuse the same static buffers every time.
static_grad_outputs_keys = tuple(
(o.shape, o.dtype, o.layout) for o in static_outputs if o.requires_grad
(o.shape, o.dtype, o.layout)
for o in static_outputs
if o is not None and o.requires_grad
)
if static_grad_outputs_keys in static_grad_outputs_dict:
static_grad_outputs = static_grad_outputs_dict[static_grad_outputs_keys]
else:
static_grad_outputs = tuple(
torch.empty_like(o) if o.requires_grad else None
torch.empty_like(o) if o is not None and o.requires_grad else None
for o in static_outputs
)
static_grad_outputs_dict[static_grad_outputs_keys] = static_grad_outputs
else:
static_grad_outputs = tuple(
torch.empty_like(o) if o.requires_grad else None for o in static_outputs
torch.empty_like(o) if o is not None and o.requires_grad else None
for o in static_outputs
)
if is_training:
inputs = tuple(i for i in static_input_surface if i.requires_grad)
......@@ -636,7 +640,9 @@ def _make_graphed_callables(
bwd_graph, pool=mempool
):
torch.autograd.backward(
tuple(o for o in static_outputs if o.requires_grad),
tuple(
o for o in static_outputs if o is not None and o.requires_grad
),
grad_tensors=tuple(o for o in static_grad_outputs if o is not None),
retain_graph=retain_graph_in_backward,
)
......@@ -719,7 +725,8 @@ def _make_graphed_callables(
):
# For now, assumes all static_outputs require grad
static_grad_outputs = tuple(
torch.empty_like(o) if o.requires_grad else None for o in static_outputs
torch.empty_like(o) if o is not None and o.requires_grad else None
for o in static_outputs
)
if is_training:
inputs = tuple(i for i in static_input_surface if i.requires_grad)
......@@ -727,7 +734,7 @@ def _make_graphed_callables(
bwd_graph, pool=mempool
):
torch.autograd.backward(
tuple(o for o in static_outputs if o.requires_grad),
tuple(o for o in static_outputs if o is not None and o.requires_grad),
grad_tensors=tuple(o for o in static_grad_outputs if o is not None),
retain_graph=retain_graph_in_backward,
)
......@@ -794,7 +801,7 @@ def _make_graphed_callables(
# Replay forward graph
fwd_graph.replay()
assert isinstance(static_outputs, tuple)
return tuple(o.detach() for o in static_outputs)
return tuple(o.detach() if o is not None else o for o in static_outputs)
@staticmethod
@torch.autograd.function.once_differentiable
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment