Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
ee69ab64
Unverified
Commit
ee69ab64
authored
Mar 19, 2019
by
mcarilli
Committed by
GitHub
Mar 19, 2019
Browse files
Merge pull request #207 from arielai/master
More permissive inputs to forward function
parents
ac7dbf67
56de058f
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
1 deletion
+5
-1
apex/amp/_initialize.py
apex/amp/_initialize.py
+5
-1
No files found.
apex/amp/_initialize.py
View file @
ee69ab64
import
torch
import
torch
from
torch._six
import
string_classes
from
torch._six
import
string_classes
import
functools
import
functools
import
numpy
as
np
import
warnings
from
._amp_state
import
_amp_state
,
warn_or_err
,
container_abcs
from
._amp_state
import
_amp_state
,
warn_or_err
,
container_abcs
from
.handle
import
disable_casts
from
.handle
import
disable_casts
from
.scaler
import
LossScaler
from
.scaler
import
LossScaler
...
@@ -15,7 +17,7 @@ def to_type(dtype, t):
...
@@ -15,7 +17,7 @@ def to_type(dtype, t):
if
isinstance
(
t
,
torch
.
Tensor
):
if
isinstance
(
t
,
torch
.
Tensor
):
if
not
t
.
is_cuda
:
if
not
t
.
is_cuda
:
# This should not be a hard error, since it may be legitimate.
# This should not be a hard error, since it may be legitimate.
print
(
"Warning:
An input tensor was not cuda.
"
)
warnings
.
warn
(
"
An input tensor was not cuda."
)
# GANs require this.
# GANs require this.
# if t.requires_grad:
# if t.requires_grad:
# warn_or_err("input data requires grad. Since input data is not a model parameter,\n"
# warn_or_err("input data requires grad. Since input data is not a model parameter,\n"
...
@@ -34,6 +36,8 @@ def applier(value, fn):
...
@@ -34,6 +36,8 @@ def applier(value, fn):
return
fn
(
value
)
return
fn
(
value
)
elif
isinstance
(
value
,
string_classes
):
elif
isinstance
(
value
,
string_classes
):
return
value
return
value
elif
isinstance
(
value
,
np
.
ndarray
):
return
value
elif
isinstance
(
value
,
container_abcs
.
Mapping
):
elif
isinstance
(
value
,
container_abcs
.
Mapping
):
return
{
applier
(
k
,
fn
)
:
applier
(
v
,
fn
)
for
k
,
v
in
value
.
items
()}
return
{
applier
(
k
,
fn
)
:
applier
(
v
,
fn
)
for
k
,
v
in
value
.
items
()}
elif
isinstance
(
value
,
container_abcs
.
Iterable
):
elif
isinstance
(
value
,
container_abcs
.
Iterable
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment