Commit cec08a41 authored by rohithkrn's avatar rohithkrn
Browse files

revert to original

parent 3ff2178c
...@@ -124,7 +124,7 @@ def init(enabled=True, loss_scale="dynamic", patch_type=torch.float16, enable_ca ...@@ -124,7 +124,7 @@ def init(enabled=True, loss_scale="dynamic", patch_type=torch.float16, enable_ca
# 1.5) Pre-0.4, put the blacklist methods on HalfTensor and whitelist # 1.5) Pre-0.4, put the blacklist methods on HalfTensor and whitelist
# methods on FloatTensor, since they're distinct types. # methods on FloatTensor, since they're distinct types.
if compat.tensor_is_float_tensor(): if compat.tensor_is_float_tensor():
for fn in getattr(tensor_overrides, 'FP16_FUNCS'): for fn in tensor_overrides.FP16_FUNCS:
wrap.cached_cast(torch.cuda.FloatTensor, fn, utils.maybe_half, wrap.cached_cast(torch.cuda.FloatTensor, fn, utils.maybe_half,
handle, try_caching=True, verbose=verbose) handle, try_caching=True, verbose=verbose)
for fn in tensor_overrides.FP32_FUNCS: for fn in tensor_overrides.FP32_FUNCS:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment