"tests/python/test_tokenizer.py" did not exist on "5ea40abf613e47bb56a0c06f695644d55671f585"
Unverified Commit 1737ce1e authored by mcarilli's avatar mcarilli Committed by GitHub
Browse files

Merge pull request #4 from NVIDIA/amp_compat_fix

Fix compatibility checks for 18.04 container
parents ee117aa8 9ce3a33d
...@@ -5,6 +5,10 @@ def variable_is_tensor(): ...@@ -5,6 +5,10 @@ def variable_is_tensor():
v = torch.autograd.Variable() v = torch.autograd.Variable()
return isinstance(v, torch.Tensor) return isinstance(v, torch.Tensor)
def tensor_is_variable():
x = torch.Tensor()
return type(x) == torch.autograd.Variable
# False for post-0.4 # False for post-0.4
def tensor_is_float_tensor(): def tensor_is_float_tensor():
x = torch.Tensor() x = torch.Tensor()
......
...@@ -5,7 +5,7 @@ import importlib ...@@ -5,7 +5,7 @@ import importlib
import torch import torch
if compat.variable_is_tensor(): if compat.variable_is_tensor() and not compat.tensor_is_variable():
MODULE = torch.Tensor MODULE = torch.Tensor
else: else:
MODULE = torch.autograd.Variable MODULE = torch.autograd.Variable
......
...@@ -8,8 +8,6 @@ import torch ...@@ -8,8 +8,6 @@ import torch
def cached_cast(mod, fn, cast_fn, handle, def cached_cast(mod, fn, cast_fn, handle,
try_caching=False, verbose=False): try_caching=False, verbose=False):
if not utils.has_func(mod, fn): if not utils.has_func(mod, fn):
# Should happen only pre-0.4
assert not compat.variable_is_tensor()
return return
orig_fn = utils.get_func(mod, fn) orig_fn = utils.get_func(mod, fn)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment