Commit 03b0eeb8 authored by Michael Carilli

Only warn once in LossScaler constructor

parent a153c41a
@@ -31,13 +31,15 @@ class LossScaler(object):
         try:
             import amp_C
             LossScaler.has_fused_kernel = True
-            LossScaler.scale_check_overflow = amp_C.scale_check_overflow
+            LossScaler.scale_check_overflow_cuda = amp_C.scale_check_overflow
             self._overflow_buf = torch.cuda.ByteTensor(1024,)
         except ImportError as err:
-            print("Warning: Amp fused downscale kernel is unavailable, possibly because apex "
-                  "was installed without --cuda_ext. Using Python fallback. ImportError was: ", err)
+            if not LossScaler.warned_no_fused_kernel:
+                print("Warning: Amp fused downscale kernel is unavailable, possibly because apex "
+                      "was installed without --cuda_ext. Using Python fallback. ImportError was: ",
+                      err)
             LossScaler.has_fused_kernel = False
-            LossScaler.scale_check_overflow = scale_check_overflow_python
+            LossScaler.warned_no_fused_kernel = True
 
     def loss_scale(self):
         return self._loss_scale
@@ -49,19 +51,19 @@ class LossScaler(object):
         for p in iter_params(param_groups):
             if p.grad is not None:
                 if LossScaler.has_fused_kernel and p.grad.data.type() == "torch.cuda.FloatTensor":
-                    LossScaler.scale_check_overflow(p.grad.data,
-                                                    1. / scale,
-                                                    self._overflow_buf)
+                    LossScaler.scale_check_overflow_cuda(p.grad.data,
+                                                         1./scale,
+                                                         self._overflow_buf)
                 else:
-                    if p.grad.data.type() != "torch.cuda.FloatTensor"
-                       and not LossScaler.warned_fp16_grad:
+                    if (p.grad.data.type() != "torch.cuda.FloatTensor"
+                            and not LossScaler.warned_fp16_grad):
                         logger = logging.getLogger("apex.amp")
                         logger.warning("Incoming grads are not fp32 (not master grads). "
                                        "Downscaling non-fp32 grads may indicate an error. "
                                        "When using Amp, you don't need to call .half() on your model.")
                         LossScaler.warned_fp16_grad = True
-                    self._has_overflow = LossScaler.scale_check_overflow(p.grad.data,
-                                                                         1. / scale)
+                    self._has_overflow = scale_check_overflow_python(p.grad.data,
+                                                                     1./scale)
                     if self._has_overflow:
                         break
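For context, the pattern this commit settles on is class-level flags so each warning is printed at most once per process, plus explicit dispatch between the fused CUDA kernel (amp_C) and a pure-Python fallback. The sketch below is a simplified illustration of that pattern, not the actual apex source; the body of scale_check_overflow_python is an assumed stand-in abbreviated for clarity.

```python
import torch

def scale_check_overflow_python(grad_data, scale):
    # Assumed, simplified fallback: downscale in place and report inf/nan overflow.
    grad_data.mul_(scale)
    cpu_sum = float(grad_data.float().sum())
    return cpu_sum in (float('inf'), -float('inf')) or cpu_sum != cpu_sum

class LossScaler(object):
    # Class-level flags: shared by all instances, so each warning fires at most
    # once per process even if many LossScaler objects are constructed.
    warned_no_fused_kernel = False
    warned_fp16_grad = False
    has_fused_kernel = False

    def __init__(self):
        try:
            import amp_C
            LossScaler.has_fused_kernel = True
            LossScaler.scale_check_overflow_cuda = amp_C.scale_check_overflow
            self._overflow_buf = torch.cuda.ByteTensor(1024,)
        except ImportError as err:
            if not LossScaler.warned_no_fused_kernel:
                print("Warning: fused kernel unavailable, using Python fallback:", err)
            LossScaler.has_fused_kernel = False
            LossScaler.warned_no_fused_kernel = True
```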