Commit f32a0a63 authored by Carl Case

Don't touch non-gpu tensors in functional API

parent 378ce1e1
@@ -49,7 +49,7 @@ def maybe_half(x, name='', verbose=False):
     if is_nested(x):
         return type(x)([maybe_half(y) for y in x])
-    if type_string(x) == 'HalfTensor':
+    if not x.is_cuda or type_string(x) == 'HalfTensor':
         return x
     else:
         if verbose:
@@ -60,7 +60,7 @@ def maybe_float(x, name='', verbose=False):
     if is_nested(x):
         return type(x)([maybe_float(y) for y in x])
-    if type_string(x) == 'FloatTensor':
+    if not x.is_cuda or type_string(x) == 'FloatTensor':
         return x
     else:
         if verbose:
@@ -70,7 +70,7 @@ def sequence_promote(mod, fn, verbose=False):
                                      seq, {})
             return orig_fn(cast_seq, *args, **kwargs)
         else:
-            # TODO: other mixed-type cases aren't due to autohalf.
+            # TODO: other mixed-type cases aren't due to amp.
             # Just pass through?
             return orig_fn(seq, *args, **kwargs)
     utils.set_func(mod, fn, wrapper)
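For context, the guard added above can be illustrated with a minimal standalone sketch. This is a reconstruction from the diff, not the full apex source: the simplified type_string helper and the maybe_half body below are assumptions based on the context lines shown. After this change, a tensor that is not on the GPU is returned as-is; only CUDA tensors that are not already half get cast.

import torch

def type_string(x):
    # 'FloatTensor' / 'HalfTensor', with any 'torch.cuda.' prefix stripped
    # (assumed helper; reconstructed for this sketch)
    return x.type().split('.')[-1]

def maybe_half(x, name='', verbose=False):
    # New behavior from this commit: leave non-GPU tensors (and tensors
    # that are already half) untouched; only cast CUDA float tensors.
    if not x.is_cuda or type_string(x) == 'HalfTensor':
        return x
    if verbose:
        print('Float->Half ({})'.format(name))
    return x.half()

cpu = torch.randn(4)
assert maybe_half(cpu) is cpu                       # CPU tensor passes through untouched
if torch.cuda.is_available():
    gpu = torch.randn(4, device='cuda')
    assert maybe_half(gpu).dtype == torch.float16   # CUDA float gets cast

maybe_float is symmetric: the same not x.is_cuda short-circuit, with 'FloatTensor' and x.float() in place of the half variants.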