Commit d52edb9e authored by Michael Carilli
Browse files

FP16_Optimizer + dynamic loss scaling now works with optimizers that require closures, e.g. LBFGS

parent 6e39bee3
...@@ -382,9 +382,6 @@ class FP16_Optimizer(object): ...@@ -382,9 +382,6 @@ class FP16_Optimizer(object):
.. _`ordinary Pytorch optimizer use`: .. _`ordinary Pytorch optimizer use`:
http://pytorch.org/docs/master/optim.html#optimizer-step-closure http://pytorch.org/docs/master/optim.html#optimizer-step-closure
""" """
if closure is not None and isinstance(self.loss_scaler, DynamicLossScaler):
raise TypeError("Using step with a closure is currently not "
"compatible with dynamic loss scaling.")
scale = self.loss_scaler.loss_scale scale = self.loss_scaler.loss_scale
self._update_scale(self.overflow) self._update_scale(self.overflow)
...@@ -405,6 +402,9 @@ class FP16_Optimizer(object): ...@@ -405,6 +402,9 @@ class FP16_Optimizer(object):
def _step_with_closure(self, closure): def _step_with_closure(self, closure):
def wrapped_closure(): def wrapped_closure():
# helpful for debugging
# print("Calling wrapped_closure, first_closure_call_this_step = {}"
# .format(self.first_closure_call_this_step))
if self.first_closure_call_this_step: if self.first_closure_call_this_step:
# We expect that the fp16 params are initially fresh on entering self.step(), # We expect that the fp16 params are initially fresh on entering self.step(),
# so _master_params_to_model_params() is unnecessary the first time wrapped_closure() # so _master_params_to_model_params() is unnecessary the first time wrapped_closure()
...@@ -425,6 +425,13 @@ class FP16_Optimizer(object): ...@@ -425,6 +425,13 @@ class FP16_Optimizer(object):
# for the optimizer to play with, so all wrapped_closure needs to do is call # for the optimizer to play with, so all wrapped_closure needs to do is call
# closure() and return the loss. # closure() and return the loss.
temp_loss = closure() temp_loss = closure()
while(self.overflow):
scale = self.loss_scaler.loss_scale
self._update_scale(self.overflow)
if self.overflow:
print("OVERFLOW within closure! Skipping step. Attempted loss scale: {}, "
"reducing to {}".format(scale, self.loss_scale))
temp_loss = closure()
return temp_loss return temp_loss
retval = self.optimizer.step(wrapped_closure) retval = self.optimizer.step(wrapped_closure)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment