Unverified Commit 9c436d53 authored by Paweł Gadziński's avatar Paweł Gadziński Committed by GitHub
Browse files

[PyTorch] Fix saved_tensors access in Ops Fuser (#1807)



fix saved_tensors
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>
parent 097afc00
...@@ -216,8 +216,9 @@ class _OperationFuserAutogradFunction(torch.autograd.Function): ...@@ -216,8 +216,9 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
basic_op_ctxs = func_ctx.basic_op_ctxs basic_op_ctxs = func_ctx.basic_op_ctxs
# Unflatten list of saved tensors # Unflatten list of saved tensors
saved_tensors = func_ctx.saved_tensors
for ctx in basic_op_ctxs: for ctx in basic_op_ctxs:
ctx.saved_tensors = func_ctx.saved_tensors[slice(*ctx._saved_tensors_range)] ctx.saved_tensors = saved_tensors[slice(*ctx._saved_tensors_range)]
ctx._saved_tensors_range = None ctx._saved_tensors_range = None
# Unflatten list of extra tensor output grads # Unflatten list of extra tensor output grads
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment