Unverified commit 8ebb47e5, authored by Lifu Zhang, committed by GitHub

Fix in TE to support Mcore Vision Encoder CUDA Graph (#2657)



* Fix in TE to support Mcore Vision Encoder CUDA Graph
Signed-off-by: Lifu Zhang <lifuz@login-lyris02.lyris.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* refactoring code
Signed-off-by: Lifu Zhang <lifuz@login-lyris02.lyris.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Lifu Zhang <lifuz@login-lyris02.lyris.clusters.nvidia.com>
Co-authored-by: Lifu Zhang <lifuz@login-lyris02.lyris.clusters.nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 8d152585
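The core of the patch is tolerating `None` entries in a graphed callable's output tuple: before `torch.autograd.backward` is invoked during capture, the outputs are filtered down to those that are not `None` and require grad, and the matching `grad_tensors` are allocated only for that subset. A minimal standalone sketch of the same pattern (the helper name and the example tensors are illustrative, not taken from the TE source):

```python
import torch


def backward_with_optional_outputs(outputs, inputs):
    """Backward over an output tuple that may contain None entries.

    Mirrors the pattern introduced by the patch: only outputs that exist and
    require grad participate in torch.autograd.backward, and grad buffers are
    allocated only for that subset.
    """
    outputs_requiring_grad = tuple(
        o for o in outputs if o is not None and o.requires_grad
    )
    torch.autograd.backward(
        outputs_requiring_grad,
        grad_tensors=tuple(torch.empty_like(o) for o in outputs_requiring_grad),
    )
    return tuple(i.grad for i in inputs)


# Illustrative use: an output tuple with an absent (None) auxiliary output,
# as a vision encoder forward might produce. These names are hypothetical.
x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(8, 8, requires_grad=True)
outputs = (x @ w, None)  # second output is absent
grads = backward_with_optional_outputs(outputs, (x, w))
print([g.shape for g in grads])
```

Without the `o is not None` checks, a `None` entry in the output tuple would raise an `AttributeError` on the `o.requires_grad` access during graph capture.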
@@ -451,11 +451,12 @@ def _make_graphed_callables(
         if is_training:
             inputs = tuple(i for i in static_input_surface if i.requires_grad)
             with _none_grad_context_wrapper(inputs):
+                outputs_requiring_grad = tuple(
+                    o for o in outputs if o is not None and o.requires_grad
+                )
                 torch.autograd.backward(
-                    tuple(o for o in outputs if o.requires_grad),
-                    grad_tensors=tuple(
-                        torch.empty_like(o) for o in outputs if o.requires_grad
-                    ),
+                    outputs_requiring_grad,
+                    grad_tensors=tuple(torch.empty_like(o) for o in outputs_requiring_grad),
                 )
             grad_inputs = tuple(input.grad for input in inputs)
@@ -616,19 +617,22 @@ def _make_graphed_callables(
                 # Note for _reuse_graph_input_output_buffers: grad output is only used
                 # within backward, so we can reuse the same static buffers every time.
                 static_grad_outputs_keys = tuple(
-                    (o.shape, o.dtype, o.layout) for o in static_outputs if o.requires_grad
+                    (o.shape, o.dtype, o.layout)
+                    for o in static_outputs
+                    if o is not None and o.requires_grad
                 )
                 if static_grad_outputs_keys in static_grad_outputs_dict:
                     static_grad_outputs = static_grad_outputs_dict[static_grad_outputs_keys]
                 else:
                     static_grad_outputs = tuple(
-                        torch.empty_like(o) if o.requires_grad else None
+                        torch.empty_like(o) if o is not None and o.requires_grad else None
                         for o in static_outputs
                     )
                     static_grad_outputs_dict[static_grad_outputs_keys] = static_grad_outputs
             else:
                 static_grad_outputs = tuple(
-                    torch.empty_like(o) if o.requires_grad else None for o in static_outputs
+                    torch.empty_like(o) if o is not None and o.requires_grad else None
+                    for o in static_outputs
                 )
             if is_training:
                 inputs = tuple(i for i in static_input_surface if i.requires_grad)
@@ -636,7 +640,9 @@ def _make_graphed_callables(
                 bwd_graph, pool=mempool
             ):
                 torch.autograd.backward(
-                    tuple(o for o in static_outputs if o.requires_grad),
+                    tuple(
+                        o for o in static_outputs if o is not None and o.requires_grad
+                    ),
                     grad_tensors=tuple(o for o in static_grad_outputs if o is not None),
                     retain_graph=retain_graph_in_backward,
                 )
@@ -719,7 +725,8 @@ def _make_graphed_callables(
             ):
                 # For now, assumes all static_outputs require grad
                 static_grad_outputs = tuple(
-                    torch.empty_like(o) if o.requires_grad else None for o in static_outputs
+                    torch.empty_like(o) if o is not None and o.requires_grad else None
+                    for o in static_outputs
                 )
                 if is_training:
                     inputs = tuple(i for i in static_input_surface if i.requires_grad)
@@ -727,7 +734,7 @@ def _make_graphed_callables(
                 bwd_graph, pool=mempool
            ):
                 torch.autograd.backward(
-                    tuple(o for o in static_outputs if o.requires_grad),
+                    tuple(o for o in static_outputs if o is not None and o.requires_grad),
                     grad_tensors=tuple(o for o in static_grad_outputs if o is not None),
                     retain_graph=retain_graph_in_backward,
                 )
@@ -794,7 +801,7 @@ def _make_graphed_callables(
             # Replay forward graph
             fwd_graph.replay()
             assert isinstance(static_outputs, tuple)
-            return tuple(o.detach() for o in static_outputs)
+            return tuple(o.detach() if o is not None else o for o in static_outputs)

         @staticmethod
         @torch.autograd.function.once_differentiable
...
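The remaining hunks apply the same `None`-tolerance to the reused static grad-output buffers: the reuse key is built from the `(shape, dtype, layout)` of the non-`None`, grad-requiring outputs, and buffers are allocated only for those outputs. A rough standalone sketch of that buffer-reuse pattern, with illustrative names that are not the TE implementation:

```python
import torch

# Cache of static grad-output buffers, keyed by the signature of the
# grad-requiring outputs, so repeated captures with the same output layout
# can reuse the same buffers (as the patch does for
# _reuse_graph_input_output_buffers).
_static_grad_outputs_cache = {}


def get_static_grad_outputs(static_outputs):
    """Return one grad buffer per output, with None for absent or no-grad outputs."""
    key = tuple(
        (o.shape, o.dtype, o.layout)
        for o in static_outputs
        if o is not None and o.requires_grad
    )
    if key not in _static_grad_outputs_cache:
        _static_grad_outputs_cache[key] = tuple(
            torch.empty_like(o) if o is not None and o.requires_grad else None
            for o in static_outputs
        )
    return _static_grad_outputs_cache[key]


# Illustrative use: two passes with the same output signature share buffers,
# and the None slot simply gets no buffer.
a = torch.randn(2, 3, requires_grad=True)
first = get_static_grad_outputs((a * 2, None))
second = get_static_grad_outputs((a + 1, None))
assert first[0] is second[0] and first[1] is None
```

The final hunk (`o.detach() if o is not None else o`) applies the same rule when replaying the forward graph, so `None` outputs pass through unchanged instead of failing on `detach()`.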