Commit 11629d52 authored by Y. Xiong, committed by GitHub

Support resume for fp16 training (#1013)

* fix cast nn.Module bug

* fix misleading doc

* add resume function

* del useless; fix typo

* change meta structure; set resume default
parent 69146fe3
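For context, a minimal sketch of the workflow this change enables. The checkpoint path and the commented runner calls are placeholders, not part of this commit; only the `Fp16OptimizerHook` usage reflects the patched code paths below.

```python
# Sketch of the resumable fp16 workflow (paths and commented calls are placeholders).
from mmcv.runner import Fp16OptimizerHook

fp16_hook = Fp16OptimizerHook(loss_scale=512.)
# runner.register_hook(fp16_hook)
# runner.resume('work_dir/latest.pth')  # restores runner.meta, incl. the scaler state
# runner.run(data_loaders, workflow)    # before_run() then reloads the loss scaler
```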
@@ -358,6 +358,9 @@ class BaseRunner(metaclass=ABCMeta):
                 self.logger.info('the iteration number is changed due to '
                                  'change of GPU number')
 
+        # resume meta information
+        self.meta = checkpoint['meta']
+
         if 'optimizer' in checkpoint and resume_optimizer:
             if isinstance(self.optimizer, Optimizer):
                 self.optimizer.load_state_dict(checkpoint['optimizer'])
@@ -31,7 +31,9 @@ def cast_tensor_type(inputs, src_type, dst_type):
     Returns:
         The same type with inputs, but all contained Tensors have been cast.
     """
-    if isinstance(inputs, torch.Tensor):
+    if isinstance(inputs, nn.Module):
+        return inputs
+    elif isinstance(inputs, torch.Tensor):
         return inputs.to(dst_type)
     elif isinstance(inputs, str):
         return inputs
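The first commit-message bullet ("fix cast nn.Module bug") corresponds to the new branch above: modules are passed through untouched, while tensors are still converted. A simplified, hedged restatement of that behaviour (the real helper also handles dicts and other iterables, omitted here):

```python
import torch
import torch.nn as nn


def cast_tensor_type(inputs, src_type, dst_type):
    """Simplified stand-in for the patched helper (container handling omitted)."""
    if isinstance(inputs, nn.Module):
        return inputs  # new branch: never cast or recurse into modules
    elif isinstance(inputs, torch.Tensor):
        return inputs.to(dst_type)
    elif isinstance(inputs, str):
        return inputs
    return inputs


module = nn.Linear(2, 2)
tensor = torch.ones(2)
assert cast_tensor_type(module, torch.float, torch.half) is module
assert cast_tensor_type(tensor, torch.float, torch.half).dtype == torch.half
```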
@@ -376,6 +378,29 @@ class LossScaler:
             self.cur_scale *= self.scale_factor
         self.cur_iter += 1
 
+    def state_dict(self):
+        """Returns the state of the scaler as a :class:`dict`."""
+        return dict(
+            cur_scale=self.cur_scale,
+            cur_iter=self.cur_iter,
+            mode=self.mode,
+            last_overflow_iter=self.last_overflow_iter,
+            scale_factor=self.scale_factor,
+            scale_window=self.scale_window)
+
+    def load_state_dict(self, state_dict):
+        """Loads the loss_scaler state dict.
+
+        Args:
+            state_dict (dict): scaler state.
+        """
+        self.cur_scale = state_dict['cur_scale']
+        self.cur_iter = state_dict['cur_iter']
+        self.mode = state_dict['mode']
+        self.last_overflow_iter = state_dict['last_overflow_iter']
+        self.scale_factor = state_dict['scale_factor']
+        self.scale_window = state_dict['scale_window']
+
     @property
     def loss_scale(self):
         return self.cur_scale
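A round-trip sketch of the two new methods. The import path assumes `LossScaler` is defined in `mmcv.runner.fp16_utils`, as in the file patched here; adjust it if your version differs.

```python
import torch
from mmcv.runner.fp16_utils import LossScaler  # import path is an assumption

scaler = LossScaler(mode='dynamic')
state = scaler.state_dict()  # a plain dict, safe to store in runner.meta
torch.save({'fp16': {'loss_scaler': state}}, 'scaler_demo.pth')

restored = LossScaler(mode='dynamic')
restored.load_state_dict(torch.load('scaler_demo.pth')['fp16']['loss_scaler'])
assert restored.cur_scale == scaler.cur_scale
```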
@@ -70,7 +70,7 @@ if TORCH_VERSION != 'parrots' and TORCH_VERSION >= '1.6.0':
             ...     backoff_factor=0.5,
             ...     growth_interval=2000
             ... )
-            >>> optimizer = Fp16OptimizerHook(loss_scale=loss_scale)
+            >>> optimizer_hook = Fp16OptimizerHook(loss_scale=loss_scale)
         """
 
         def __init__(self,
@@ -99,6 +99,10 @@ if TORCH_VERSION != 'parrots' and TORCH_VERSION >= '1.6.0':
             """Preparing steps before Mixed Precision Training."""
             # wrap model mode to fp16
             wrap_fp16_model(runner.model)
+            # resume from state dict
+            if 'fp16' in runner.meta and 'loss_scaler' in runner.meta['fp16']:
+                scaler_state_dict = runner.meta['fp16']['loss_scaler']
+                self.loss_scaler.load_state_dict(scaler_state_dict)
 
         def copy_grads_to_fp32(self, fp16_net, fp32_weights):
             """Copy gradients from fp16 model to fp32 weight copy."""
@@ -125,6 +129,7 @@ if TORCH_VERSION != 'parrots' and TORCH_VERSION >= '1.6.0':
             2. Backward the loss to obtain the gradients.
             3. Unscale the optimizer's gradient tensors.
             4. Call optimizer.step() and update scale factor.
+            5. Save loss_scaler state_dict for resume purpose.
             """
             # clear grads of last iteration
             runner.model.zero_grad()
@@ -142,6 +147,10 @@ if TORCH_VERSION != 'parrots' and TORCH_VERSION >= '1.6.0':
             # backward and update scaler
             self.loss_scaler.step(runner.optimizer)
             self.loss_scaler.update(self._scale_update_param)
+
+            # save state_dict of loss_scaler
+            runner.meta.setdefault(
+                'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict()
 
 else:
     @HOOKS.register_module()
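The `setdefault` call above (and its twin in the non-amp branch further down) gives `runner.meta` the layout shown below; the numeric values are illustrative only.

```python
# Reproduces runner.meta.setdefault('fp16', {})['loss_scaler'] = ... on a bare dict.
meta = {}
loss_scaler_state = dict(
    cur_scale=65536.0, cur_iter=12000, mode='dynamic',
    last_overflow_iter=11321, scale_factor=2.0, scale_window=1000)
meta.setdefault('fp16', {})['loss_scaler'] = loss_scaler_state
assert meta == {'fp16': {'loss_scaler': loss_scaler_state}}
# The runner merges self.meta into the checkpoint when saving, so this state
# later surfaces as checkpoint['meta']['fp16']['loss_scaler'] on resume.
```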
@@ -210,6 +219,10 @@ else:
             runner.optimizer.state = state
             # convert model to fp16
             wrap_fp16_model(runner.model)
+            # resume from state dict
+            if 'fp16' in runner.meta and 'loss_scaler' in runner.meta['fp16']:
+                scaler_state_dict = runner.meta['fp16']['loss_scaler']
+                self.loss_scaler.load_state_dict(scaler_state_dict)
 
         def copy_grads_to_fp32(self, fp16_net, fp32_weights):
             """Copy gradients from fp16 model to fp32 weight copy."""
@@ -236,6 +249,7 @@ else:
             3. Copy gradients from the model to the fp32 weight copy.
             4. Scale the gradients back and update the fp32 weight copy.
             5. Copy back the params from fp32 weight copy to the fp16 model.
+            6. Save loss_scaler state_dict for resume purpose.
             """
             # clear grads of last iteration
             runner.model.zero_grad()
@@ -276,3 +290,7 @@ else:
             if has_overflow:
                 runner.logger.warning('Check overflow, downscale loss scale '
                                       f'to {self.loss_scaler.cur_scale}')
+
+            # save state_dict of loss_scaler
+            runner.meta.setdefault(
+                'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict()