Commit d52edb9e authored by Michael Carilli

FP16_Optimizer + dynamic loss scaling now works with optimizers that require closures, e.g. LBFGS

parent 6e39bee3
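This makes dynamic loss scaling usable with closure-based optimizers such as LBFGS: instead of raising a TypeError, step(closure) now re-runs the closure at a reduced loss scale whenever an overflow is detected. A minimal usage sketch of the pattern this enables, assuming the existing apex `FP16_Optimizer` API; the model, data, and loss-function names here are hypothetical, and the closure calls `optimizer.backward(loss)` rather than `loss.backward()`, as the comments in `wrapped_closure` below describe:

```python
import torch
from apex.fp16_utils import FP16_Optimizer

# Hypothetical model and data, just to exercise a closure-based optimizer.
model = torch.nn.Linear(512, 512).cuda().half()
inp = torch.randn(64, 512).cuda().half()
target = torch.randn(64, 512).cuda().half()
loss_fn = torch.nn.MSELoss()

# Wrap a closure-requiring optimizer (LBFGS) with dynamic loss scaling.
optimizer = FP16_Optimizer(torch.optim.LBFGS(model.parameters()),
                           dynamic_loss_scale=True)

def closure():
    optimizer.zero_grad()
    loss = loss_fn(model(inp), target)
    # FP16_Optimizer owns the backward pass: call optimizer.backward(loss),
    # not loss.backward(), so the loss is scaled and the fp32 master
    # gradients are populated for the wrapped optimizer.
    optimizer.backward(loss)
    return loss

# On overflow, step() now retries the closure at a reduced loss scale
# instead of raising a TypeError.
optimizer.step(closure)
```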
@@ -382,9 +382,6 @@ class FP16_Optimizer(object):
         .. _`ordinary Pytorch optimizer use`:
             http://pytorch.org/docs/master/optim.html#optimizer-step-closure
         """
-        if closure is not None and isinstance(self.loss_scaler, DynamicLossScaler):
-            raise TypeError("Using step with a closure is currently not "
-                            "compatible with dynamic loss scaling.")
         scale = self.loss_scaler.loss_scale
         self._update_scale(self.overflow)
@@ -405,6 +402,9 @@ class FP16_Optimizer(object):
 
     def _step_with_closure(self, closure):
         def wrapped_closure():
+            # helpful for debugging
+            # print("Calling wrapped_closure, first_closure_call_this_step = {}"
+            #       .format(self.first_closure_call_this_step))
             if self.first_closure_call_this_step:
                 # We expect that the fp16 params are initially fresh on entering self.step(),
                 # so _master_params_to_model_params() is unnecessary the first time wrapped_closure()
@@ -425,6 +425,13 @@ class FP16_Optimizer(object):
             # for the optimizer to play with, so all wrapped_closure needs to do is call
             # closure() and return the loss.
             temp_loss = closure()
+            while(self.overflow):
+                scale = self.loss_scaler.loss_scale
+                self._update_scale(self.overflow)
+                if self.overflow:
+                    print("OVERFLOW within closure! Skipping step. Attempted loss scale: {}, "
+                          "reducing to {}".format(scale, self.loss_scale))
+                temp_loss = closure()
             return temp_loss
 
         retval = self.optimizer.step(wrapped_closure)