Commit a153c41a authored by Michael Carilli

FP16 grad downscale (which shouldn't happen in user code) fallback + warning

parent fd03f26a
 import torch
+import logging
 # from apex_C import scale_check_overflow
@@ -18,6 +19,7 @@ def scale_check_overflow_python(d_grads, scale):
 class LossScaler(object):
     warned_no_fused_kernel = False
+    warned_fp16_grad = False
     has_fused_kernel = False
 
     def __init__(self):
@@ -46,18 +48,25 @@ class LossScaler(object):
         self._has_overflow = False
         for p in iter_params(param_groups):
             if p.grad is not None:
-                if LossScaler.has_fused_kernel:
+                if LossScaler.has_fused_kernel and p.grad.data.type() == "torch.cuda.FloatTensor":
                     LossScaler.scale_check_overflow(p.grad.data,
                                                     1. / scale,
                                                     self._overflow_buf)
                 else:
+                    if (p.grad.data.type() != "torch.cuda.FloatTensor"
+                            and not LossScaler.warned_fp16_grad):
+                        logger = logging.getLogger("apex.amp")
+                        logger.warning("Incoming grads are not fp32 (not master grads). "
+                                       "Downscaling non-fp32 grads may indicate an error. "
+                                       "When using Amp, you don't need to call .half() on your model.")
+                        LossScaler.warned_fp16_grad = True
                     self._has_overflow = LossScaler.scale_check_overflow(p.grad.data,
                                                                          1. / scale)
                     if self._has_overflow:
                         break
         # If the fused kernel is available, we only need one D2H memcopy and sync.
-        if LossScaler.has_fused_kernel:
+        if LossScaler.has_fused_kernel and not self._has_overflow:
            self._has_overflow = self._overflow_buf.any()
        if self._has_overflow:
...
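For reference, the hunk header above points at scale_check_overflow_python, the Python fallback path taken when the fused CUDA kernel is unavailable (and, after this commit, whenever the incoming grad is not an fp32 CUDA tensor). The sketch below is a minimal, assumed illustration of what such a downscale-and-overflow check does; it is not copied from apex. It treats an inf/NaN gradient sum as overflow and otherwise multiplies the grads in place by 1/loss_scale.

import torch

def scale_check_overflow_python(d_grads, scale):
    # Reduce the gradient on the host; an inf or NaN sum flags an overflowed step.
    cpu_sum = float(d_grads.float().sum())
    if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
        return True
    # No overflow detected: downscale in place ('scale' is 1/loss_scale at the call site).
    d_grads.mul_(scale)
    return False

# Hypothetical usage: downscale a gradient that was produced under a loss scale of 128.
grad = torch.randn(8, 8) * 128.0
overflowed = scale_check_overflow_python(grad, 1. / 128.0)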