Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
1737ce1e
Unverified
Commit
1737ce1e
authored
May 23, 2018
by
mcarilli
Committed by
GitHub
May 23, 2018
Browse files
Merge pull request #4 from NVIDIA/amp_compat_fix
Fix compatibility checks for 18.04 container
parents
ee117aa8
9ce3a33d
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
5 additions
and
3 deletions
+5
-3
apex/amp/compat.py
apex/amp/compat.py
+4
-0
apex/amp/lists/tensor_overrides.py
apex/amp/lists/tensor_overrides.py
+1
-1
apex/amp/wrap.py
apex/amp/wrap.py
+0
-2
No files found.
apex/amp/compat.py
View file @
1737ce1e
...
...
@@ -5,6 +5,10 @@ def variable_is_tensor():
v
=
torch
.
autograd
.
Variable
()
return
isinstance
(
v
,
torch
.
Tensor
)
def
tensor_is_variable
():
x
=
torch
.
Tensor
()
return
type
(
x
)
==
torch
.
autograd
.
Variable
# False for post-0.4
def
tensor_is_float_tensor
():
x
=
torch
.
Tensor
()
...
...
apex/amp/lists/tensor_overrides.py
View file @
1737ce1e
...
...
@@ -5,7 +5,7 @@ import importlib
import
torch
if
compat
.
variable_is_tensor
():
if
compat
.
variable_is_tensor
()
and
not
compat
.
tensor_is_variable
()
:
MODULE
=
torch
.
Tensor
else
:
MODULE
=
torch
.
autograd
.
Variable
...
...
apex/amp/wrap.py
View file @
1737ce1e
...
...
@@ -8,8 +8,6 @@ import torch
def
cached_cast
(
mod
,
fn
,
cast_fn
,
handle
,
try_caching
=
False
,
verbose
=
False
):
if
not
utils
.
has_func
(
mod
,
fn
):
# Should happen only pre-0.4
assert
not
compat
.
variable_is_tensor
()
return
orig_fn
=
utils
.
get_func
(
mod
,
fn
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment