Commit 56de058f authored by George Papandreou's avatar George Papandreou
Browse files

replace print with warning for cpu tensors and allow numpy arrays as inputs to forward function

parent 74c06d87
import torch
from torch._six import string_classes
import functools
import numpy as np
import warnings
from ._amp_state import _amp_state, warn_or_err, container_abcs
from .handle import disable_casts
from .scaler import LossScaler
......@@ -15,7 +17,7 @@ def to_type(dtype, t):
if isinstance(t, torch.Tensor):
if not t.is_cuda:
# This should not be a hard error, since it may be legitimate.
print("Warning: An input tensor was not cuda. ")
warnings.warn("An input tensor was not cuda.")
# GANs require this.
# if t.requires_grad:
# warn_or_err("input data requires grad. Since input data is not a model parameter,\n"
......@@ -34,6 +36,8 @@ def applier(value, fn):
return fn(value)
elif isinstance(value, string_classes):
return value
elif isinstance(value, np.ndarray):
return value
elif isinstance(value, container_abcs.Mapping):
return {applier(k, fn) : applier(v, fn) for k, v in value.items()}
elif isinstance(value, container_abcs.Iterable):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment