Commit 58dd1862 authored by Myle Ott, committed by Facebook Github Bot

Fix resuming from FP16 checkpoints (#424)

Summary:
This was broken in 03a57dec.
Pull Request resolved: https://github.com/pytorch/fairseq/pull/424

Differential Revision: D13557540

Pulled By: myleott

fbshipit-source-id: 62deda5353032aff20d35d046b0bb843da44d27c
parent 31a43973
@@ -63,6 +63,20 @@ class ConvertToFP32(object):
         self.params = params
         self.itr = map(convert_to_fp32, params)
 
+    @staticmethod
+    def wrap_optimizer_(optimizer):
+        for group in optimizer.param_groups:
+            group['params'] = ConvertToFP32(group['params'])
+
+    @staticmethod
+    def unwrap_optimizer_(optimizer):
+        for group in optimizer.param_groups:
+            group['params'] = group['params'].params  # unwrap from ConvertToFP32
+            for p in group['params']:
+                p.data = p.data.half()
+                if p.grad is not None:
+                    p.grad.data = p.grad.data.half()
+
     def __len__(self):
         return len(self.params)
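
For readers skimming the diff, the following is a minimal, self-contained sketch of the pattern these two helpers implement: each param group temporarily presents FP32 views of its FP16 parameters to the inner optimizer, and everything is cast back to FP16 afterwards. This is illustrative only; LazyFP32Params, wrap_ and unwrap_ are hypothetical names, not fairseq APIs.

# Minimal sketch of the lazy FP16 <-> FP32 wrapping idea (illustrative only;
# LazyFP32Params, wrap_ and unwrap_ are made-up names, not fairseq code).
import torch


def _to_fp32(p):
    # cast a parameter (and its gradient, if present) to FP32 in place
    p.data = p.data.float()
    if p.grad is not None:
        p.grad.data = p.grad.data.float()
    return p


class LazyFP32Params(object):
    """Iterable view over FP16 params that casts them to FP32 when iterated."""

    def __init__(self, params):
        self.params = params

    def __iter__(self):
        return map(_to_fp32, self.params)

    def __len__(self):
        return len(self.params)


def wrap_(optimizer):
    # make every param group present FP32 views to the inner optimizer
    for group in optimizer.param_groups:
        group['params'] = LazyFP32Params(group['params'])


def unwrap_(optimizer):
    # restore the raw param lists and cast data/grads back to FP16
    for group in optimizer.param_groups:
        group['params'] = group['params'].params
        for p in group['params']:
            p.data = p.data.half()
            if p.grad is not None:
                p.grad.data = p.grad.data.half()


# demo: run a stock PyTorch optimizer step in FP32 over FP16 parameters
param = torch.nn.Parameter(torch.zeros(8).half())
param.grad = torch.ones(8).half()
opt = torch.optim.SGD([param], lr=0.1, momentum=0.9)

wrap_(opt)    # the optimizer now sees FP32 params and grads
opt.step()    # the update (and momentum buffer) are computed in FP32
unwrap_(opt)  # params and grads are cast back to FP16
assert param.dtype == torch.float16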
@@ -145,7 +159,9 @@ class FP16Optimizer(optim.FairseqOptimizer):
         """
         if 'loss_scale' in state_dict:
             self.scaler.loss_scale = state_dict['loss_scale']
+        ConvertToFP32.wrap_optimizer_(self.wrapped_optimizer.optimizer)
         self.wrapped_optimizer.load_state_dict(state_dict, optimizer_overrides)
+        ConvertToFP32.unwrap_optimizer_(self.wrapped_optimizer.optimizer)
 
     def backward(self, loss):
         loss = loss * self.scaler.loss_scale
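
One plausible reading of why load_state_dict needs the bracketing above: PyTorch's Optimizer.load_state_dict() casts floating-point state tensors (momentum buffers, Adam moments, etc.) to the dtype of the corresponding current parameter, so loading while the inner optimizer still holds FP16 params would leave the restored state in FP16, mismatching the FP32 arithmetic done inside step(). Wrapping makes the params look like FP32 for the duration of the load. A rough continuation of the sketch above (same hypothetical names):

# Continuation of the hypothetical sketch above. Saving happens while the
# params are FP16; loading is bracketed by wrap_/unwrap_ so the restored
# state keeps its FP32 dtype instead of being cast down to FP16.
state = opt.state_dict()

opt2 = torch.optim.SGD([param], lr=0.1, momentum=0.9)
wrap_(opt2)                  # params appear as FP32 while loading
opt2.load_state_dict(state)  # state tensors keep their FP32 dtype
unwrap_(opt2)                # params go back to FP16; state stays FP32
assert opt2.state[param]['momentum_buffer'].dtype == torch.float32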
@@ -194,18 +210,12 @@ class FP16Optimizer(optim.FairseqOptimizer):
         self._unscale_grads()
 
         # convert params and grads to FP32 (lazily)
-        for group in self.wrapped_optimizer.optimizer.param_groups:
-            group['params'] = ConvertToFP32(group['params'])
+        ConvertToFP32.wrap_optimizer_(self.wrapped_optimizer.optimizer)
 
         self.wrapped_optimizer.step(closure)
 
         # convert params back to FP16
-        for group in self.wrapped_optimizer.optimizer.param_groups:
-            group['params'] = group['params'].params  # unwrap from ConvertToFP32
-            for p in group['params']:
-                p.data = p.data.half()
-                if p.grad is not None:
-                    p.grad.data = p.grad.data.half()
+        ConvertToFP32.unwrap_optimizer_(self.wrapped_optimizer.optimizer)
 
     def zero_grad(self):
         """Clears the gradients of all optimized parameters."""
@@ -114,8 +114,9 @@ class Trainer(object):
     def load_checkpoint(self, filename, reset_optimizer=False, reset_lr_scheduler=False, optimizer_overrides=None):
         """Load all training state from a checkpoint file."""
-        extra_state, self._optim_history, last_optim_state = \
-            utils.load_model_state(filename, self.get_model())
+        extra_state, self._optim_history, last_optim_state = utils.load_model_state(
+            filename, self.get_model(),
+        )
 
         if last_optim_state is not None and not reset_optimizer:
             # rebuild optimizer after loading model, since params may have changed
             self._build_optimizer()