Commit 2063287b authored by Michael Carilli's avatar Michael Carilli
Browse files

Using temporary Python-side inf+nan check in amp/scaler.py

parent ae921de2
...@@ -16,7 +16,7 @@ def scale_check_overflow(d_grads, scale): ...@@ -16,7 +16,7 @@ def scale_check_overflow(d_grads, scale):
return True return True
d_grads.mul_(scale) d_grads.mul_(scale)
return False return False
class LossScaler(object): class LossScaler(object):
def __init__(self): def __init__(self):
self._loss_scale = 2.**16 self._loss_scale = 2.**16
...@@ -36,7 +36,8 @@ class LossScaler(object): ...@@ -36,7 +36,8 @@ class LossScaler(object):
if p.grad is not None: if p.grad is not None:
self._has_overflow = scale_check_overflow(p.grad.data, self._has_overflow = scale_check_overflow(p.grad.data,
1. / scale) 1. / scale)
if self._has_overflow: break if self._has_overflow:
break
# if self._overflow_buf.any(): # if self._overflow_buf.any():
if self._has_overflow: if self._has_overflow:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment