Commit 56de058f authored by George Papandreou's avatar George Papandreou
Browse files

replace print with warning for cpu tensors and allow numpy arrays as inputs to forward function

parent 74c06d87
import torch import torch
from torch._six import string_classes from torch._six import string_classes
import functools import functools
import numpy as np
import warnings
from ._amp_state import _amp_state, warn_or_err, container_abcs from ._amp_state import _amp_state, warn_or_err, container_abcs
from .handle import disable_casts from .handle import disable_casts
from .scaler import LossScaler from .scaler import LossScaler
...@@ -15,7 +17,7 @@ def to_type(dtype, t): ...@@ -15,7 +17,7 @@ def to_type(dtype, t):
if isinstance(t, torch.Tensor): if isinstance(t, torch.Tensor):
if not t.is_cuda: if not t.is_cuda:
# This should not be a hard error, since it may be legitimate. # This should not be a hard error, since it may be legitimate.
print("Warning: An input tensor was not cuda. ") warnings.warn("An input tensor was not cuda.")
# GANs require this. # GANs require this.
# if t.requires_grad: # if t.requires_grad:
# warn_or_err("input data requires grad. Since input data is not a model parameter,\n" # warn_or_err("input data requires grad. Since input data is not a model parameter,\n"
...@@ -34,6 +36,8 @@ def applier(value, fn): ...@@ -34,6 +36,8 @@ def applier(value, fn):
return fn(value) return fn(value)
elif isinstance(value, string_classes): elif isinstance(value, string_classes):
return value return value
elif isinstance(value, np.ndarray):
return value
elif isinstance(value, container_abcs.Mapping): elif isinstance(value, container_abcs.Mapping):
return {applier(k, fn) : applier(v, fn) for k, v in value.items()} return {applier(k, fn) : applier(v, fn) for k, v in value.items()}
elif isinstance(value, container_abcs.Iterable): elif isinstance(value, container_abcs.Iterable):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment