Unverified commit 96c4daab authored by Chunyang Wen, committed by GitHub

minor refactor loss scaler (#261)

parent f5025506
Pipeline #203 failed with stages in 0 seconds
@@ -31,7 +31,29 @@ def to_python_float(t):
     return t[0]
 
 
-class LossScaler:
+class LossScalerBase:
+    """LossScalerBase
+    Base class for a loss scaler
+    """
+    def __init__(self, cur_scale):
+        self.cur_scale = cur_scale
+
+    @property
+    def loss_scale(self):
+        return self.cur_scale
+
+    def scale_gradient(self, module, grad_in, grad_out):
+        return tuple(self.loss_scale * g for g in grad_in)
+
+    def update_scale(self, overflow):
+        pass
+
+    def backward(self, loss, retain_graph=False):
+        scaled_loss = loss * self.loss_scale
+        scaled_loss.backward(retain_graph=retain_graph)
+
+
+class LossScaler(LossScalerBase):
     """
     Class that manages a static loss scale. This class is intended to interact with
     :class:`FP16_Optimizer`, and should not be directly manipulated by the user.
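This first hunk pulls the members shared by the static and dynamic scalers (loss_scale, scale_gradient, update_scale, backward) into a common LossScalerBase. A minimal sketch of what that buys, assuming only the base class above; HalvingLossScaler is a hypothetical subclass invented for illustration, not code from this commit:

# Hypothetical subclass, for illustration only: with LossScalerBase in place,
# a new scaler only supplies overflow detection and a scale-update policy and
# inherits loss_scale / scale_gradient / backward.
class HalvingLossScaler(LossScalerBase):
    def __init__(self, scale=2**16, min_scale=1):
        super(HalvingLossScaler, self).__init__(scale)
        self.min_scale = min_scale

    def has_overflow(self, params):
        # any inf/nan gradient after the scaled backward pass counts as overflow
        for p in params:
            if p.grad is not None:
                cpu_sum = float(p.grad.data.float().sum())
                if cpu_sum in [float('inf'), -float('inf')] or cpu_sum != cpu_sum:
                    return True
        return False

    def update_scale(self, overflow):
        if overflow:
            self.cur_scale = max(self.cur_scale / 2, self.min_scale)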
@@ -43,7 +65,7 @@ class LossScaler:
         scale (float, optional, default=1.0): The loss scale.
     """
     def __init__(self, scale=1):
-        self.cur_scale = scale
+        super(LossScaler, self).__init__(scale)
 
     # `params` is a list / generator of torch.Variable
     def has_overflow(self, params):
@@ -53,22 +75,8 @@ class LossScaler:
     def _has_inf_or_nan(x):
         return False
 
-    def update_scale(self, overflow):
-        pass
-
-    @property
-    def loss_scale(self):
-        return self.cur_scale
-
-    def scale_gradient(self, module, grad_in, grad_out):
-        return tuple(self.loss_scale * g for g in grad_in)
-
-    def backward(self, loss, retain_graph=False):
-        scaled_loss = loss * self.loss_scale
-        scaled_loss.backward(retain_graph=retain_graph)
-
-class DynamicLossScaler:
+
+class DynamicLossScaler(LossScalerBase):
     """
     Class that manages dynamic loss scaling. It is recommended to use :class:`DynamicLossScaler`
     indirectly, by supplying ``dynamic_loss_scale=True`` to the constructor of
@@ -100,7 +108,7 @@ class DynamicLossScaler:
                  min_scale=1,
                  delayed_shift=1,
                  consecutive_hysteresis=False):
-        self.cur_scale = init_scale
+        super(DynamicLossScaler, self).__init__(init_scale)
         self.cur_iter = 0
         self.last_overflow_iter = -1
         self.scale_factor = scale_factor
@@ -113,7 +121,7 @@ class DynamicLossScaler:
 
     # `params` is a list / generator of torch.Variable
     def has_overflow_serial(self, params):
         for p in params:
-            if p.grad is not None and DynamicLossScaler._has_inf_or_nan(p.grad.data):
+            if p.grad is not None and self._has_inf_or_nan(p.grad.data):
                 return True
         return False
@@ -135,7 +143,7 @@ class DynamicLossScaler:
                 raise
             return True
         else:
-            if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
+            if cpu_sum in [float('inf'), -float('inf')] or cpu_sum != cpu_sum:
                 return True
             return False
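The rewritten condition in _has_inf_or_nan is behavior-preserving: the membership test covers both infinities, and cpu_sum != cpu_sum catches NaN, since NaN is the only float that compares unequal to itself. A standalone illustration (the helper name is made up for the demo, not part of the commit):

import torch

def looks_like_inf_or_nan(t):
    # mirrors the refactored check: inf via membership, NaN via self-inequality
    cpu_sum = float(t.float().sum())
    return cpu_sum in [float('inf'), -float('inf')] or cpu_sum != cpu_sum

print(looks_like_inf_or_nan(torch.tensor([1.0, 2.0])))           # False
print(looks_like_inf_or_nan(torch.tensor([1.0, float('inf')])))  # True
print(looks_like_inf_or_nan(torch.tensor([0.0, float('nan')])))  # True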
@@ -157,17 +165,6 @@ class DynamicLossScaler:
             self.cur_scale *= self.scale_factor
         self.cur_iter += 1
 
-    @property
-    def loss_scale(self):
-        return self.cur_scale
-
-    def scale_gradient(self, module, grad_in, grad_out):
-        return tuple(self.loss_scale * g for g in grad_in)
-
-    def backward(self, loss, retain_graph=False):
-        scaled_loss = loss * self.loss_scale
-        scaled_loss.backward(retain_graph=retain_graph)
-
 ##############################################################
 # Example usage below here -- assuming it's in a separate file
...
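The trailing comment points at the example usage kept in the file. As a rough sketch of how DynamicLossScaler is typically driven in an fp16 training loop, assuming a model, optimizer, criterion, and data_loader exist (all placeholder names; in practice FP16_Optimizer wraps these steps):

scaler = DynamicLossScaler(init_scale=2**16)

for inputs, targets in data_loader:
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)    # fp16 forward pass
    scaler.backward(loss)                       # loss * cur_scale, then .backward()

    params = [p for p in model.parameters()]
    overflow = scaler.has_overflow(params)      # inf/nan in the scaled gradients?
    scaler.update_scale(overflow)               # shrink on overflow, grow after a
                                                # window of clean iterations
    if overflow:
        continue                                # skip the step; gradients are unusable

    for p in params:
        if p.grad is not None:
            p.grad.data.div_(scaler.loss_scale)  # unscale before the optimizer step
    optimizer.step()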