Unverified Commit ee69ab64 authored by mcarilli's avatar mcarilli Committed by GitHub
Browse files

Merge pull request #207 from arielai/master

More permissive inputs to forward function
parents ac7dbf67 56de058f
import torch import torch
from torch._six import string_classes from torch._six import string_classes
import functools import functools
import numpy as np
import warnings
from ._amp_state import _amp_state, warn_or_err, container_abcs from ._amp_state import _amp_state, warn_or_err, container_abcs
from .handle import disable_casts from .handle import disable_casts
from .scaler import LossScaler from .scaler import LossScaler
...@@ -15,7 +17,7 @@ def to_type(dtype, t): ...@@ -15,7 +17,7 @@ def to_type(dtype, t):
if isinstance(t, torch.Tensor): if isinstance(t, torch.Tensor):
if not t.is_cuda: if not t.is_cuda:
# This should not be a hard error, since it may be legitimate. # This should not be a hard error, since it may be legitimate.
print("Warning: An input tensor was not cuda. ") warnings.warn("An input tensor was not cuda.")
# GANs require this. # GANs require this.
# if t.requires_grad: # if t.requires_grad:
# warn_or_err("input data requires grad. Since input data is not a model parameter,\n" # warn_or_err("input data requires grad. Since input data is not a model parameter,\n"
...@@ -34,6 +36,8 @@ def applier(value, fn): ...@@ -34,6 +36,8 @@ def applier(value, fn):
return fn(value) return fn(value)
elif isinstance(value, string_classes): elif isinstance(value, string_classes):
return value return value
elif isinstance(value, np.ndarray):
return value
elif isinstance(value, container_abcs.Mapping): elif isinstance(value, container_abcs.Mapping):
return {applier(k, fn) : applier(v, fn) for k, v in value.items()} return {applier(k, fn) : applier(v, fn) for k, v in value.items()}
elif isinstance(value, container_abcs.Iterable): elif isinstance(value, container_abcs.Iterable):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment