Commit dec4fdd6 authored by ptrblck's avatar ptrblck Committed by mcarilli
Browse files

Enable Checkpointing (#420)

* add state_dict, load_state_dict

* add test_restoring, test_loss_scale_decrease

* disable amp outputs for checkpoint tests

* add test for amp.state_dict, cleanup

* add state_dict patch, add test

* fixed testing, cleanup

* add readme for checkpointing

* add docs to source/amp

* add review changes to doc
parent 30ed793e
......@@ -54,6 +54,45 @@ global batch size across all processes (which, technically, is the correct
formulation).
Synchronous BN has been observed to improve converged accuracy in some of our research models.
### Checkpointing
To properly save and load your `amp` training, we introduce the `amp.state_dict()`, which contains all `loss_scalers` and their corresponding unskipped steps,
as well as `amp.load_state_dict()` to restore these attributes.
In order to get bitwise accuracy, we recommend the following workflow:
```python
# Initialization
opt_level = 'O1'
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
# Train your model
...
# Save checkpoint
checkpoint = {
'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'amp': amp.state_dict()
}
torch.save(checkpoint, 'amp_checkpoint.pt')
...
# Restore
model = ...
optimizer = ...
checkpoint = torch.load('amp_checkpoint.pt')
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
amp.load_state_dict(checkpoint['amp'])
# Continue training
...
```
Note that we recommend restoring the model using the same `opt_level`. Also note that we recommend calling the `load_state_dict` methods after `amp.initialize`.
# Requirements
Python 3
......
from .amp import init, half_function, float_function, promote_function,\
register_half_function, register_float_function, register_promote_function
from .handle import scale_loss, disable_casts
from .frontend import initialize
from .frontend import initialize, state_dict, load_state_dict
from ._amp_state import master_params, _amp_state
......@@ -126,6 +126,18 @@ def check_optimizers(optimizers):
"optimizers.\n")
class O2StateDictHook(object):
def __init__(self, fn):
self.fn = fn
def __call__(self, module, state_dict, prefix, local_metadata):
for key in state_dict:
param = state_dict[key]
if 'Half' in param.type():
param = param.to(torch.float32)
state_dict[key] = param
def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs=None):
from apex.parallel import DistributedDataParallel as apex_DDP
from .amp import init as amp_init
......@@ -188,6 +200,12 @@ def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs
# State dict trick to recast any preexisting per-param state tensors
for optimizer in optimizers:
optimizer.load_state_dict(optimizer.state_dict())
# patch model.state_dict() to return float32 params
for model in models:
for module in model.modules():
module._register_state_dict_hook(O2StateDictHook(functools.partial(to_type, torch.float32)))
elif cast_model_outputs is not None:
output_caster = functools.partial(to_type, cast_model_outputs)
......
import torch
from ._initialize import _initialize
from ._amp_state import _amp_state, warn_or_err, maybe_print
from collections import OrderedDict
class Properties(object):
......@@ -357,6 +358,48 @@ def initialize(
return _initialize(models, optimizers, _amp_state.opt_properties, num_losses, cast_model_outputs)
def state_dict(destination=None):
if destination is None:
destination = OrderedDict()
for idx, loss_scaler in enumerate(_amp_state.loss_scalers):
destination['loss_scaler%d' % idx] = {
'loss_scale': loss_scaler.loss_scale(),
'unskipped': loss_scaler._unskipped,
}
return destination
def load_state_dict(state_dict):
# Check if state_dict containes the same number of loss_scalers as current setup
if len(state_dict) != len(_amp_state.loss_scalers):
print('Warning: state_dict contains {} entries, while {} loss_scalers are used'.format(
len(state_dict), len(_amp_state.loss_scalers)))
state_dict = state_dict.copy()
nb_loss_scalers = len(_amp_state.loss_scalers)
unexpected_keys = []
# Initialize idx outside, since unexpected_keys will increase it if enumerate is used
idx = 0
for key in state_dict:
if 'loss_scaler' not in key:
unexpected_keys.append(key)
else:
if idx > (nb_loss_scalers - 1):
print('Skipping loss_scaler[{}], since num_losses was set to {}'.format(
idx, nb_loss_scalers))
break
_amp_state.loss_scalers[idx]._loss_scale = state_dict[key]['loss_scale']
_amp_state.loss_scalers[idx]._unskipped = state_dict[key]['unskipped']
idx += 1
if len(unexpected_keys) > 0:
raise RuntimeError(
'Error(s) in loading state_dict. Unexpected key(s) in state_dict: {}. '.format(
', '.join('"{}"'.format(k) for k in unexpected_keys)))
# TODO: is this necessary/useful?
# def check_option_consistency(enabled=True,
# opt_level=None,
......
......@@ -177,6 +177,44 @@ Unified API
.. autofunction:: master_params
Checkpointing
-------------
To properly save and load your amp training, we introduce the ``amp.state_dict()``, which contains all ``loss_scaler``\ s and their corresponding unskipped steps, as well as ``amp.load_state_dict()`` to restore these attributes.
In order to get bitwise accuracy, we recommend the following workflow::
# Initialization
opt_level = 'O1'
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
# Train your model
...
# Save checkpoint
checkpoint = {
'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'amp': amp.state_dict()
}
torch.save(checkpoint, 'amp_checkpoint.pt')
...
# Restore
model = ...
optimizer = ...
checkpoint = torch.load('amp_checkpoint.pt')
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
amp.load_state_dict(checkpoint['amp'])
# Continue training
...
Note that we recommend restoring the model using the same ``opt_level``. Also note that we recommend calling the ``load_state_dict`` methods after ``amp.initialize``.
Advanced use cases
------------------
......
import unittest
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from apex import amp
from utils import common_init, FLOAT
class MyModel(torch.nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 3, 1, 1)
self.bn1 = nn.BatchNorm2d(6)
self.param = nn.Parameter(torch.randn(1))
def forward(self, x):
x = x * self.param
x = F.relu(self.conv1(x))
x = self.bn1(x)
return x
class TestCheckpointing(unittest.TestCase):
def setUp(self):
self.initial_lr = 1e-3
self.test_opt_levels = ("O0", "O1", "O2", "O3")
def seed(self):
torch.manual_seed(2809)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
def check_state_dict_fp32(self, state_dict):
for key in state_dict:
if 'num_batches_tracked' in key:
continue
param = state_dict[key]
self.assertEqual(param.type(), FLOAT,
'Parameter in state_dict not FLOAT')
def train_step(self, model, optimizer, data, loss_ids):
optimizer.zero_grad()
output = model(data)
# Call backward for num_losses-1
for idx in loss_ids:
loss = output.mean()
with amp.scale_loss(loss, optimizer, loss_id=idx) as scaled_loss:
scaled_loss.backward(retain_graph=True)
optimizer.step()
return output
def compare_models(self, modelA, modelB):
state_dictA = modelA.state_dict()
state_dictB = modelB.state_dict()
self.assertEqual(len(state_dictA), len(state_dictB),
'state_dicts have different lengths')
for key in state_dictA:
paramA = state_dictA[key]
paramB = state_dictB[key]
self.assertTrue(torch.allclose(paramA.float(), paramB.float(), rtol=0, atol=1e-4),
msg='Parameters in state_dicts not equal.' +
'key: {}\nparam: {}\nrestored: {}\ndiff: {}'.format(
key, paramA, paramB, paramA - paramB))
def test_restoring(self):
nb_epochs = 10
nb_epochs_restore = nb_epochs // 2
for opt_level in self.test_opt_levels:
for res_opt_level in self.test_opt_levels:
for amp_before_load in [True, False]:
for num_losses in range(1, 3):
# print('#' * 75 + '\n' + \
# f'opt_level {opt_level}\n' + \
# f'restore_opt_level {res_opt_level}\n' + \
# f'amp_before_load {amp_before_load}\n' + \
# f'num_losses {num_losses}\n')
self.seed()
# Create reference model
model = MyModel().to('cuda')
optimizer = optim.SGD(model.parameters(),
lr=self.initial_lr)
# Initialize with num_losses*2 for the original model and the restored one
model, optimizer = amp.initialize(
model, optimizer, opt_level=opt_level,
num_losses=num_losses*2, verbosity=0)
# Compare training behavior for same restore option
# We cannot really generalize it, since a saved model in O0
# would introduce a skipped step in O1, which will raise an error
if opt_level == res_opt_level:
# train for nb_epochs and restore after nb_epochs_restore
for epoch in range(nb_epochs):
x = torch.randn(16, 3, 24, 24, device='cuda')
output = self.train_step(
model, optimizer, x, range(num_losses))
# Initialize model one step before comparing.
# Otherwise the batchnorm layers will be updated
# additionally in restore_model
if epoch == (nb_epochs_restore - 1):
# Load model and optimizer
checkpoint = {
'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'amp': amp.state_dict()
}
# Check state_dict for FP32 tensors
self.check_state_dict_fp32(checkpoint['model'])
# Restore model
restore_model = MyModel().to('cuda')
restore_optimizer = optim.SGD(
restore_model.parameters(),
lr=self.initial_lr)
if amp_before_load:
restore_model, restore_optimizer = amp.initialize(
restore_model,
restore_optimizer,
opt_level=res_opt_level,
num_losses=num_losses*2,
verbosity=0)
restore_model.load_state_dict(checkpoint['model'])
restore_optimizer.load_state_dict(checkpoint['optimizer'])
# FIXME: We cannot test the amp.state_dict in the same script
# amp.load_state_dict(checkpoint['amp'])
if not amp_before_load:
restore_model, restore_optimizer = amp.initialize(
restore_model,
restore_optimizer,
opt_level=res_opt_level,
num_losses=num_losses*2,
verbosity=0)
elif epoch >= nb_epochs_restore:
restore_output = self.train_step(
restore_model,
restore_optimizer,
x,
range(num_losses, num_losses*2))
self.assertTrue(
torch.allclose(output.float(), restore_output.float()),
'Output of reference and restored models differ')
self.compare_models(model, restore_model)
# if opt_level != res_opt_level
else:
# Only check state_dict
checkpoint = {
'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'amp': amp.state_dict()
}
# Check state_dict for FP32 tensors
self.check_state_dict_fp32(checkpoint['model'])
# Restore model
restore_model = MyModel().to('cuda')
restore_optimizer = optim.SGD(
restore_model.parameters(),
lr=self.initial_lr)
if amp_before_load:
restore_model, restore_optimizer = amp.initialize(
restore_model,
restore_optimizer,
opt_level=res_opt_level,
num_losses=num_losses,
verbosity=0)
restore_model.load_state_dict(checkpoint['model'])
restore_optimizer.load_state_dict(checkpoint['optimizer'])
# FIXME: We cannot test the amp.state_dict in the same script
# amp.load_state_dict(checkpoint['amp'])
if not amp_before_load:
restore_model, restore_optimizer = amp.initialize(
restore_model,
restore_optimizer,
opt_level=res_opt_level,
num_losses=num_losses,
verbosity=0)
self.compare_models(model, restore_model)
def test_loss_scale_decrease(self):
num_losses = 3
nb_decrease_loss_scales = [0, 1, 2]
for opt_level in self.test_opt_levels:
#print('#' * 75 + f'\n opt_level {opt_level}\n')
# Create new tmp copy for this run
nb_decrease_loss_scales_tmp = list(nb_decrease_loss_scales)
model = MyModel().to('cuda')
optimizer = optim.SGD(model.parameters(),
lr=1e-3)#self.initial_lr)
model, optimizer = amp.initialize(
model, optimizer, opt_level=opt_level, num_losses=num_losses,
verbosity=0)
if amp._amp_state.opt_properties.loss_scale != 'dynamic':
#print('Static loss scale set. Skipping opt_level.')
continue
# force to skip some updates to decrease the loss_scale
initial_loss_scales = []
for idx in range(num_losses):
initial_loss_scales.append(
amp._amp_state.loss_scalers[idx].loss_scale())
for _ in range(len(nb_decrease_loss_scales)):
x = torch.randn(16, 3, 24, 24, device='cuda')
for idx in range(num_losses):
while nb_decrease_loss_scales_tmp[idx] > 0:
optimizer.zero_grad()
output = model(x * 2**17)
loss = output.mean()
with amp.scale_loss(loss, optimizer, loss_id=idx) as scaled_loss:
scaled_loss.backward(retain_graph=True)
optimizer.step()
nb_decrease_loss_scales_tmp[idx] -= 1
# Check loss scales afterwards
updated_loss_scales = []
for idx in range(num_losses):
updated_loss_scales.append(
amp._amp_state.loss_scalers[idx].loss_scale())
for factor, update_ls, init_ls in zip(nb_decrease_loss_scales,
updated_loss_scales,
initial_loss_scales):
self.assertEqual(update_ls, init_ls / 2**factor)
# Check state dict
amp_state_dict = amp.state_dict()
for scaler_idx, factor, init_ls in zip(amp_state_dict,
nb_decrease_loss_scales,
initial_loss_scales):
scaler = amp_state_dict[scaler_idx]
self.assertEqual(scaler['loss_scale'], init_ls / 2**factor)
unskipped_target = 0
self.assertEqual(scaler['unskipped'], unskipped_target)
def test_state_dict(self):
for opt_level in self.test_opt_levels:
# Skip O3
if opt_level == 'O3':
continue
model = MyModel().to('cuda')
optimizer = optim.Adam(model.parameters(), lr=1e-3)
model, optimizer = amp.initialize(
model, optimizer, opt_level=opt_level, verbosity=0)
# Export state_dict and check for Half
state_dict = model.state_dict()
for key in state_dict:
self.assertFalse('Half' in state_dict[key].type())
# Check, if model is still trainable
# Create dummy data
data = torch.randn(10, 3, 4, 4, device='cuda')
target = torch.randn(10, 6, 4, 4, device='cuda')
# Get initnial loss
optimizer.zero_grad()
output = model(data)
loss = F.mse_loss(output, target)
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
optimizer.step()
last_loss = loss.item()
# train for some epochs
for epoch in range(10):
optimizer.zero_grad()
output = model(data)
loss = F.mse_loss(output, target)
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
optimizer.step()
self.assertTrue(loss.item() < last_loss)
last_loss = loss.item()
if __name__=='__main__':
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment