Commit 9d8130be authored by mli0603's avatar mli0603
Browse files

fixed grad check for pytorch 1.9

parent 91887c3b
import torch
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
from torch.types import _TensorOrTensors
from torch._six import container_abcs, istuple
if TORCH_MAJOR == 1 and TORCH_MINOR < 8:
from torch._six import container_abcs, istuple
else:
import collections.abc as container_abcs
import torch.testing
from torch.overrides import is_tensor_like
from itertools import product
......@@ -203,7 +211,12 @@ def get_analytical_jacobian(input, output, nondet_tol=0.0, grad_out=1.0):
def _as_tuple(x):
if istuple(x):
if TORCH_MAJOR == 1 and TORCH_MINOR < 8:
b_tuple = istuple(x)
else:
b_tuple = isinstance(x, tuple)
if b_tuple:
return x
elif isinstance(x, list):
return tuple(x)
......@@ -211,6 +224,7 @@ def _as_tuple(x):
return x,
def _differentiable_outputs(x):
return tuple(o for o in _as_tuple(x) if o.requires_grad)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment