Unverified Commit 82b02233 authored by Zach Teed's avatar Zach Teed Committed by GitHub
Browse files

Merge pull request #7 from mli0603/master

fixed grad check for pytorch 1.9
parents 91887c3b 9d8130be
import torch

# Parse the running PyTorch version once at import time; used below (and by
# _as_tuple later in this file) to gate the torch._six compatibility import.
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])

from torch.types import _TensorOrTensors

# torch._six (and its istuple helper) was removed around PyTorch 1.8;
# fall back to the stdlib collections.abc on newer builds.
if TORCH_MAJOR == 1 and TORCH_MINOR < 8:
    from torch._six import container_abcs, istuple
else:
    import collections.abc as container_abcs

import torch.testing
from torch.overrides import is_tensor_like
from itertools import product
...@@ -203,12 +211,18 @@ def get_analytical_jacobian(input, output, nondet_tol=0.0, grad_out=1.0): ...@@ -203,12 +211,18 @@ def get_analytical_jacobian(input, output, nondet_tol=0.0, grad_out=1.0):
def _as_tuple(x):
    """Normalize ``x`` into a tuple.

    Tuples are returned unchanged, lists are converted to tuples, and any
    other value is wrapped in a 1-tuple.

    The tuple check is version-gated on the module-level ``TORCH_MAJOR`` /
    ``TORCH_MINOR`` constants because ``torch._six.istuple`` only exists on
    PyTorch builds older than 1.8; on newer builds a plain
    ``isinstance(x, tuple)`` is used instead.
    """
    if TORCH_MAJOR == 1 and TORCH_MINOR < 8:
        # NOTE(review): istuple presumably matches the same values as
        # isinstance(x, tuple) on those old builds — confirm against
        # torch._six if pre-1.8 support still matters.
        is_tup = istuple(x)
    else:
        is_tup = isinstance(x, tuple)
    if is_tup:
        return x
    if isinstance(x, list):
        return tuple(x)
    return (x,)
def _differentiable_outputs(x): def _differentiable_outputs(x):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.