Commit 68c850d3 authored by Michael Carilli
parent d5e2bb4b
@@ -284,7 +284,9 @@ def _process_optimizer(optimizer, properties):
             _master_params_to_model_params, optimizer)
 
         old_step = optimizer.step
-        def new_step(self):
+        def new_step(self, closure=None):
+            if closure is not None:
+                raise RuntimeError("Currently, Amp does not support closure use with optimizers.")
             retval = old_step()
             self._master_params_to_model_params()
             # Clear the master grads that wouldn't be zeroed by model.zero_grad()
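The hunk above makes the step method that Amp installs on a master-weight optimizer accept an optional closure argument and refuse it explicitly, rather than failing with an unexpected-argument error or silently ignoring the closure. Below is a minimal, self-contained sketch of that wrapping pattern outside of Amp's internals; the helper name patch_step_to_reject_closure is illustrative and not part of this commit.

# Sketch of the step-wrapping pattern from the hunk above, assuming only
# a plain PyTorch optimizer. Amp's real new_step additionally copies master
# params back to model params and clears master grads after the update.
import types

import torch

def patch_step_to_reject_closure(optimizer):
    old_step = optimizer.step          # bound method of the original optimizer
    def new_step(self, closure=None):
        if closure is not None:
            raise RuntimeError("Currently, Amp does not support closure use with optimizers.")
        retval = old_step()            # run the real parameter update
        return retval
    optimizer.step = types.MethodType(new_step, optimizer)
    return optimizer

# Usage: plain step() calls behave as before, but closure-based use
# (as expected by e.g. torch.optim.LBFGS) is rejected up front.
model = torch.nn.Linear(4, 2)
opt = patch_step_to_reject_closure(torch.optim.SGD(model.parameters(), lr=0.1))
model(torch.randn(3, 4)).sum().backward()
opt.step()                             # fine
# opt.step(lambda: 0.0)                # would raise the RuntimeError above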
@@ -136,7 +136,9 @@ def scale_loss(loss,
         # necessary because amp.scale_loss is already creating a temporary scope.
         def patch_step(opt, loss_scaler, loss_id):
             opt_step = opt.step
-            def skip_step():
+            def skip_step(closure=None):
+                if closure is not None:
+                    raise RuntimeError("Currently, Amp does not support closure use with optimizers.")
                 maybe_print(("Gradient overflow. Skipping step, loss scaler " +
                              "{} reducing loss scale to {}").format(loss_id,
                              loss_scaler.loss_scale()))
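The same guard is added to skip_step, the replacement step that scale_loss installs when the loss scaler detects a gradient overflow, so closures are rejected on skipped steps as well. The fragment below is a hypothetical caller-side view of the new behavior, assuming apex is installed and a CUDA device is available; model, data, and target are placeholders, while amp.initialize and amp.scale_loss are the existing Amp entry points.

# Hypothetical caller-side view: once Amp has patched the optimizer, passing a
# closure to step() raises a RuntimeError instead of the closure being ignored.
import torch
from apex import amp

model = torch.nn.Linear(4, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

data = torch.randn(8, 4, device="cuda")
target = torch.randn(8, 2, device="cuda")

def closure():
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(data), target)
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    return loss

closure()                  # compute scaled gradients the supported way
optimizer.step()           # the closure-free call path is unchanged
# optimizer.step(closure)  # would now raise the RuntimeError added above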