Commit dec4fdd6 authored by ptrblck, committed by mcarilli

Enable Checkpointing (#420)

* add state_dict, load_state_dict

* add test_restoring, test_loss_scale_decrease

* disable amp outputs for checkpoint tests

* add test for amp.state_dict, cleanup

* add state_dict patch, add test

* fixed testing, cleanup

* add readme for checkpointing

* add docs to source/amp

* add review changes to doc
parent 30ed793e
@@ -54,6 +54,45 @@ global batch size across all processes (which, technically, is the correct
formulation).
Synchronous BN has been observed to improve converged accuracy in some of our research models.
### Checkpointing
To properly save and load your `amp` training, we introduce `amp.state_dict()`, which contains all `loss_scaler`s and their corresponding unskipped steps, as well as `amp.load_state_dict()` to restore these attributes.
In order to get bitwise accuracy, we recommend the following workflow:
```python
# Initialization
opt_level = 'O1'
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
# Train your model
...
# Save checkpoint
checkpoint = {
'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'amp': amp.state_dict()
}
torch.save(checkpoint, 'amp_checkpoint.pt')
...
# Restore
model = ...
optimizer = ...
checkpoint = torch.load('amp_checkpoint.pt')
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
amp.load_state_dict(checkpoint['amp'])
# Continue training
...
```
Note that we recommend restoring the model using the same `opt_level`, and calling the `load_state_dict` methods after `amp.initialize`.
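For reference, `amp.state_dict()` returns an ordered dict with one entry per loss scaler (keyed `loss_scaler0`, `loss_scaler1`, ...). A minimal sketch of inspecting it, based on the `state_dict()` implementation in this commit:
```python
amp_state = amp.state_dict()
for name, scaler_state in amp_state.items():
    # Each entry stores the current loss scale and the number of
    # unskipped steps tracked by that scaler.
    print(name, scaler_state['loss_scale'], scaler_state['unskipped'])
```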
# Requirements
Python 3
...
from .amp import init, half_function, float_function, promote_function,\
    register_half_function, register_float_function, register_promote_function
from .handle import scale_loss, disable_casts
from .frontend import initialize, state_dict, load_state_dict
from ._amp_state import master_params, _amp_state
@@ -126,6 +126,18 @@ def check_optimizers(optimizers):
"optimizers.\n")
class O2StateDictHook(object):
def __init__(self, fn):
self.fn = fn
def __call__(self, module, state_dict, prefix, local_metadata):
for key in state_dict:
param = state_dict[key]
if 'Half' in param.type():
param = param.to(torch.float32)
state_dict[key] = param
def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs=None):
from apex.parallel import DistributedDataParallel as apex_DDP
from .amp import init as amp_init
@@ -188,6 +200,12 @@ def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs
# State dict trick to recast any preexisting per-param state tensors
for optimizer in optimizers:
optimizer.load_state_dict(optimizer.state_dict())
# patch model.state_dict() to return float32 params
for model in models:
for module in model.modules():
module._register_state_dict_hook(O2StateDictHook(functools.partial(to_type, torch.float32)))
elif cast_model_outputs is not None:
output_caster = functools.partial(to_type, cast_model_outputs)
...
import torch
from ._initialize import _initialize
from ._amp_state import _amp_state, warn_or_err, maybe_print
from collections import OrderedDict
class Properties(object):
@@ -357,6 +358,48 @@ def initialize(
return _initialize(models, optimizers, _amp_state.opt_properties, num_losses, cast_model_outputs)
def state_dict(destination=None):
if destination is None:
destination = OrderedDict()
for idx, loss_scaler in enumerate(_amp_state.loss_scalers):
destination['loss_scaler%d' % idx] = {
'loss_scale': loss_scaler.loss_scale(),
'unskipped': loss_scaler._unskipped,
}
return destination
def load_state_dict(state_dict):
# Check if state_dict contains the same number of loss_scalers as the current setup
if len(state_dict) != len(_amp_state.loss_scalers):
print('Warning: state_dict contains {} entries, while {} loss_scalers are used'.format(
len(state_dict), len(_amp_state.loss_scalers)))
state_dict = state_dict.copy()
nb_loss_scalers = len(_amp_state.loss_scalers)
unexpected_keys = []
# Track idx manually; enumerate would also count unexpected keys
idx = 0
for key in state_dict:
if 'loss_scaler' not in key:
unexpected_keys.append(key)
else:
if idx > (nb_loss_scalers - 1):
print('Skipping loss_scaler[{}], since num_losses was set to {}'.format(
idx, nb_loss_scalers))
break
_amp_state.loss_scalers[idx]._loss_scale = state_dict[key]['loss_scale']
_amp_state.loss_scalers[idx]._unskipped = state_dict[key]['unskipped']
idx += 1
if len(unexpected_keys) > 0:
raise RuntimeError(
'Error(s) in loading state_dict. Unexpected key(s) in state_dict: {}. '.format(
', '.join('"{}"'.format(k) for k in unexpected_keys)))
# TODO: is this necessary/useful?
# def check_option_consistency(enabled=True,
#                              opt_level=None,
...
@@ -177,6 +177,44 @@ Unified API
.. autofunction:: master_params
Checkpointing
-------------
To properly save and load your amp training, we introduce the ``amp.state_dict()``, which contains all ``loss_scaler``\ s and their corresponding unskipped steps, as well as ``amp.load_state_dict()`` to restore these attributes.
In order to get bitwise accuracy, we recommend the following workflow::
# Initialization
opt_level = 'O1'
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
# Train your model
...
# Save checkpoint
checkpoint = {
'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'amp': amp.state_dict()
}
torch.save(checkpoint, 'amp_checkpoint.pt')
...
# Restore
model = ...
optimizer = ...
checkpoint = torch.load('amp_checkpoint.pt')
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
amp.load_state_dict(checkpoint['amp'])
# Continue training
...
Note that we recommend restoring the model using the same ``opt_level``, and calling the ``load_state_dict`` methods after ``amp.initialize``.
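When training with multiple losses, ``amp.state_dict()`` contains one entry per loss scaler, so we suggest re-initializing the restored run with the same ``num_losses`` before calling ``amp.load_state_dict()``. A minimal sketch (the ``num_losses=2`` value is purely illustrative)::
# Initialize with the same number of losses used when the checkpoint was saved
model, optimizer = amp.initialize(model, optimizer,
                                  opt_level=opt_level, num_losses=2)
amp.load_state_dict(checkpoint['amp'])  # restores loss_scaler0 and loss_scaler1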
Advanced use cases
------------------
...
import unittest
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from apex import amp
from utils import common_init, FLOAT
class MyModel(torch.nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 3, 1, 1)
self.bn1 = nn.BatchNorm2d(6)
self.param = nn.Parameter(torch.randn(1))
def forward(self, x):
x = x * self.param
x = F.relu(self.conv1(x))
x = self.bn1(x)
return x
class TestCheckpointing(unittest.TestCase):
def setUp(self):
self.initial_lr = 1e-3
self.test_opt_levels = ("O0", "O1", "O2", "O3")
def seed(self):
torch.manual_seed(2809)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
def check_state_dict_fp32(self, state_dict):
for key in state_dict:
if 'num_batches_tracked' in key:
continue
param = state_dict[key]
self.assertEqual(param.type(), FLOAT,
'Parameter in state_dict not FLOAT')
def train_step(self, model, optimizer, data, loss_ids):
optimizer.zero_grad()
output = model(data)
# Call backward once per loss id
for idx in loss_ids:
loss = output.mean()
with amp.scale_loss(loss, optimizer, loss_id=idx) as scaled_loss:
scaled_loss.backward(retain_graph=True)
optimizer.step()
return output
def compare_models(self, modelA, modelB):
state_dictA = modelA.state_dict()
state_dictB = modelB.state_dict()
self.assertEqual(len(state_dictA), len(state_dictB),
'state_dicts have different lengths')
for key in state_dictA:
paramA = state_dictA[key]
paramB = state_dictB[key]
self.assertTrue(torch.allclose(paramA.float(), paramB.float(), rtol=0, atol=1e-4),
msg='Parameters in state_dicts not equal.' +
'key: {}\nparam: {}\nrestored: {}\ndiff: {}'.format(
key, paramA, paramB, paramA - paramB))
def test_restoring(self):
nb_epochs = 10
nb_epochs_restore = nb_epochs // 2
for opt_level in self.test_opt_levels:
for res_opt_level in self.test_opt_levels:
for amp_before_load in [True, False]:
for num_losses in range(1, 3):
# print('#' * 75 + '\n' + \
# f'opt_level {opt_level}\n' + \
# f'restore_opt_level {res_opt_level}\n' + \
# f'amp_before_load {amp_before_load}\n' + \
# f'num_losses {num_losses}\n')
self.seed()
# Create reference model
model = MyModel().to('cuda')
optimizer = optim.SGD(model.parameters(),
lr=self.initial_lr)
# Initialize with num_losses*2 for the original model and the restored one
model, optimizer = amp.initialize(
model, optimizer, opt_level=opt_level,
num_losses=num_losses*2, verbosity=0)
# Compare training behavior for same restore option
# We cannot really generalize it, since a saved model in O0
# would introduce a skipped step in O1, which will raise an error
if opt_level == res_opt_level:
# train for nb_epochs and restore after nb_epochs_restore
for epoch in range(nb_epochs):
x = torch.randn(16, 3, 24, 24, device='cuda')
output = self.train_step(
model, optimizer, x, range(num_losses))
# Create and restore the second model one epoch before comparing;
# otherwise the batchnorm layers in restore_model would receive
# an extra update.
if epoch == (nb_epochs_restore - 1):
# Save model, optimizer, and amp state to a checkpoint
checkpoint = {
'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'amp': amp.state_dict()
}
# Check state_dict for FP32 tensors
self.check_state_dict_fp32(checkpoint['model'])
# Restore model
restore_model = MyModel().to('cuda')
restore_optimizer = optim.SGD(
restore_model.parameters(),
lr=self.initial_lr)
if amp_before_load:
restore_model, restore_optimizer = amp.initialize(
restore_model,
restore_optimizer,
opt_level=res_opt_level,
num_losses=num_losses*2,
verbosity=0)
restore_model.load_state_dict(checkpoint['model'])
restore_optimizer.load_state_dict(checkpoint['optimizer'])
# FIXME: We cannot test the amp.state_dict in the same script
# amp.load_state_dict(checkpoint['amp'])
if not amp_before_load:
restore_model, restore_optimizer = amp.initialize(
restore_model,
restore_optimizer,
opt_level=res_opt_level,
num_losses=num_losses*2,
verbosity=0)
elif epoch >= nb_epochs_restore:
restore_output = self.train_step(
restore_model,
restore_optimizer,
x,
range(num_losses, num_losses*2))
self.assertTrue(
torch.allclose(output.float(), restore_output.float()),
'Output of reference and restored models differ')
self.compare_models(model, restore_model)
# if opt_level != res_opt_level
else:
# Only check state_dict
checkpoint = {
'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'amp': amp.state_dict()
}
# Check state_dict for FP32 tensors
self.check_state_dict_fp32(checkpoint['model'])
# Restore model
restore_model = MyModel().to('cuda')
restore_optimizer = optim.SGD(
restore_model.parameters(),
lr=self.initial_lr)
if amp_before_load:
restore_model, restore_optimizer = amp.initialize(
restore_model,
restore_optimizer,
opt_level=res_opt_level,
num_losses=num_losses,
verbosity=0)
restore_model.load_state_dict(checkpoint['model'])
restore_optimizer.load_state_dict(checkpoint['optimizer'])
# FIXME: We cannot test the amp.state_dict in the same script
# amp.load_state_dict(checkpoint['amp'])
if not amp_before_load:
restore_model, restore_optimizer = amp.initialize(
restore_model,
restore_optimizer,
opt_level=res_opt_level,
num_losses=num_losses,
verbosity=0)
self.compare_models(model, restore_model)
def test_loss_scale_decrease(self):
num_losses = 3
nb_decrease_loss_scales = [0, 1, 2]
for opt_level in self.test_opt_levels:
#print('#' * 75 + f'\n opt_level {opt_level}\n')
# Create new tmp copy for this run
nb_decrease_loss_scales_tmp = list(nb_decrease_loss_scales)
model = MyModel().to('cuda')
optimizer = optim.SGD(model.parameters(),
lr=self.initial_lr)
model, optimizer = amp.initialize(
model, optimizer, opt_level=opt_level, num_losses=num_losses,
verbosity=0)
if amp._amp_state.opt_properties.loss_scale != 'dynamic':
#print('Static loss scale set. Skipping opt_level.')
continue
# force to skip some updates to decrease the loss_scale
initial_loss_scales = []
for idx in range(num_losses):
initial_loss_scales.append(
amp._amp_state.loss_scalers[idx].loss_scale())
for _ in range(len(nb_decrease_loss_scales)):
x = torch.randn(16, 3, 24, 24, device='cuda')
for idx in range(num_losses):
while nb_decrease_loss_scales_tmp[idx] > 0:
optimizer.zero_grad()
output = model(x * 2**17)
loss = output.mean()
with amp.scale_loss(loss, optimizer, loss_id=idx) as scaled_loss:
scaled_loss.backward(retain_graph=True)
optimizer.step()
nb_decrease_loss_scales_tmp[idx] -= 1
# Check loss scales afterwards
updated_loss_scales = []
for idx in range(num_losses):
updated_loss_scales.append(
amp._amp_state.loss_scalers[idx].loss_scale())
for factor, update_ls, init_ls in zip(nb_decrease_loss_scales,
updated_loss_scales,
initial_loss_scales):
self.assertEqual(update_ls, init_ls / 2**factor)
# Check state dict
amp_state_dict = amp.state_dict()
for scaler_idx, factor, init_ls in zip(amp_state_dict,
nb_decrease_loss_scales,
initial_loss_scales):
scaler = amp_state_dict[scaler_idx]
self.assertEqual(scaler['loss_scale'], init_ls / 2**factor)
unskipped_target = 0
self.assertEqual(scaler['unskipped'], unskipped_target)
def test_state_dict(self):
for opt_level in self.test_opt_levels:
# Skip O3
if opt_level == 'O3':
continue
model = MyModel().to('cuda')
optimizer = optim.Adam(model.parameters(), lr=1e-3)
model, optimizer = amp.initialize(
model, optimizer, opt_level=opt_level, verbosity=0)
# Export state_dict and check for Half
state_dict = model.state_dict()
for key in state_dict:
self.assertFalse('Half' in state_dict[key].type())
# Check if the model is still trainable
# Create dummy data
data = torch.randn(10, 3, 4, 4, device='cuda')
target = torch.randn(10, 6, 4, 4, device='cuda')
# Get initial loss
optimizer.zero_grad()
output = model(data)
loss = F.mse_loss(output, target)
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
optimizer.step()
last_loss = loss.item()
# train for some epochs
for epoch in range(10):
optimizer.zero_grad()
output = model(data)
loss = F.mse_loss(output, target)
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
optimizer.step()
self.assertTrue(loss.item() < last_loss)
last_loss = loss.item()
if __name__=='__main__':
unittest.main()