Unverified Commit 1737ce1e authored by mcarilli's avatar mcarilli Committed by GitHub
Browse files

Merge pull request #4 from NVIDIA/amp_compat_fix

Fix compatibility checks for 18.04 container
parents ee117aa8 9ce3a33d
......@@ -5,6 +5,10 @@ def variable_is_tensor():
v = torch.autograd.Variable()
return isinstance(v, torch.Tensor)
def tensor_is_variable():
x = torch.Tensor()
return type(x) == torch.autograd.Variable
# False for post-0.4
def tensor_is_float_tensor():
x = torch.Tensor()
......
......@@ -5,7 +5,7 @@ import importlib
import torch
if compat.variable_is_tensor():
if compat.variable_is_tensor() and not compat.tensor_is_variable():
MODULE = torch.Tensor
else:
MODULE = torch.autograd.Variable
......
......@@ -8,8 +8,6 @@ import torch
def cached_cast(mod, fn, cast_fn, handle,
try_caching=False, verbose=False):
if not utils.has_func(mod, fn):
# Should happen only pre-0.4
assert not compat.variable_is_tensor()
return
orig_fn = utils.get_func(mod, fn)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment