Commit cd1e5c09 authored by Dmytro Okhonko, committed by Facebook Github Bot

Move save/load checkpoint functions to utils

Summary:
Move `load_checkpoint`, `save_checkpoint` and `reload_train` from train.py to checkpoint_utils.py.
Move `get_perplexity` from train.py to utils.py.
This will make train.py lighter and allow us to reuse this utility functionality when fairseq is used as an external library.

Reviewed By: myleott

Differential Revision: D15289607

fbshipit-source-id: 4b7c95225ac22e402bcda3497811361809110df1
parent c124d272
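
A brief, hypothetical sketch of what this move enables (not part of the diff below): once the helpers live in fairseq.checkpoint_utils and fairseq.utils, a project that uses fairseq as a library can import them directly instead of depending on the train.py script. Everything passed into train_loop below (args, trainer, epoch_itr, max_positions, task) is assumed to come from the caller's own fairseq setup; the function itself is illustrative only.

    from fairseq import checkpoint_utils, utils

    def train_loop(args, trainer, epoch_itr, max_positions, task):
        # Restore the latest checkpoint, if any, and replay the dataloader to match.
        checkpoint_utils.load_checkpoint(args, trainer, epoch_itr, max_positions, task)

        # ... run training here and compute a validation loss ...
        val_loss = None  # placeholder; a real loop would pass the measured loss

        # Write checkpoints according to the usual epoch/update/best/last conditions.
        checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, val_loss)

        # The perplexity formatting helper now lives in fairseq.utils.
        print(utils.get_perplexity(2.0))  # '4.00'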
fairseq/checkpoint_utils.py
@@ -7,16 +7,122 @@
 from collections import OrderedDict
 from typing import Union

+import collections
 import logging
 import os
 import re
 import traceback
+import shutil

 import torch
 from torch.serialization import default_restore_location

-from fairseq import tasks
+from fairseq import tasks, distributed_utils
 from fairseq.models import FairseqEncoder, FairseqDecoder
+from fairseq.meters import StopwatchMeter
+
+
+def save_checkpoint(args, trainer, epoch_itr, val_loss):
+    if args.no_save or not distributed_utils.is_master(args):
+        return
+
+    write_timer = StopwatchMeter()
+    write_timer.start()
+
+    epoch = epoch_itr.epoch
+    end_of_epoch = epoch_itr.end_of_epoch()
+    updates = trainer.get_num_updates()
+
+    checkpoint_conds = collections.OrderedDict()
+    checkpoint_conds['checkpoint{}.pt'.format(epoch)] = (
+        end_of_epoch and not args.no_epoch_checkpoints and
+        epoch % args.save_interval == 0
+    )
+    checkpoint_conds['checkpoint_{}_{}.pt'.format(epoch, updates)] = (
+        not end_of_epoch and args.save_interval_updates > 0 and
+        updates % args.save_interval_updates == 0
+    )
+    checkpoint_conds['checkpoint_best.pt'] = (
+        val_loss is not None and
+        (not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best)
+    )
+    checkpoint_conds['checkpoint_last.pt'] = True  # keep this last so that it's a symlink
+
+    prev_best = getattr(save_checkpoint, 'best', val_loss)
+    if val_loss is not None:
+        save_checkpoint.best = min(val_loss, prev_best)
+    extra_state = {
+        'train_iterator': epoch_itr.state_dict(),
+        'val_loss': val_loss,
+    }
+    if hasattr(save_checkpoint, 'best'):
+        extra_state.update({'best': save_checkpoint.best})
+
+    checkpoints = [os.path.join(args.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond]
+    if len(checkpoints) > 0:
+        trainer.save_checkpoint(checkpoints[0], extra_state)
+        for cp in checkpoints[1:]:
+            shutil.copyfile(checkpoints[0], cp)
+
+        write_timer.stop()
+        print('| saved checkpoint {} (epoch {} @ {} updates) (writing took {} seconds)'.format(
+            checkpoints[0], epoch, updates, write_timer.sum))
+
+    if not end_of_epoch and args.keep_interval_updates > 0:
+        # remove old checkpoints; checkpoints are sorted in descending order
+        checkpoints = checkpoint_paths(
+            args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt',
+        )
+        for old_chk in checkpoints[args.keep_interval_updates:]:
+            if os.path.lexists(old_chk):
+                os.remove(old_chk)
+
+    if args.keep_last_epochs > 0:
+        # remove old epoch checkpoints; checkpoints are sorted in descending order
+        checkpoints = checkpoint_paths(
+            args.save_dir, pattern=r'checkpoint(\d+)\.pt',
+        )
+        for old_chk in checkpoints[args.keep_last_epochs:]:
+            if os.path.lexists(old_chk):
+                os.remove(old_chk)
+
+
+def load_checkpoint(args, trainer, epoch_itr, max_positions, task):
+    """Load a checkpoint and replay dataloader to match."""
+    # Only rank 0 should attempt to create the required dir
+    if args.distributed_rank == 0:
+        os.makedirs(args.save_dir, exist_ok=True)
+
+    if os.path.isabs(args.restore_file):
+        checkpoint_path = args.restore_file
+    else:
+        checkpoint_path = os.path.join(args.save_dir, args.restore_file)
+    if os.path.isfile(checkpoint_path):
+        extra_state = trainer.load_checkpoint(checkpoint_path, args.reset_optimizer, args.reset_lr_scheduler,
+                                              eval(args.optimizer_overrides))
+        if extra_state is not None:
+            # replay train iterator to match checkpoint
+            epoch_itr_state = extra_state['train_iterator']
+
+            # If the loaded checkpoint is not at epoch 0, reload train dataset,
+            # as it could be potentially sharded.
+            if epoch_itr_state['epoch'] != 0:
+                epoch_itr = reload_train(args, epoch_itr, max_positions, task)
+            epoch_itr.load_state_dict(epoch_itr_state)
+
+            print('| loaded checkpoint {} (epoch {} @ {} updates)'.format(
+                checkpoint_path, epoch_itr.epoch, trainer.get_num_updates()))
+
+            trainer.lr_step(epoch_itr.epoch)
+            trainer.lr_step_update(trainer.get_num_updates())
+            if 'best' in extra_state and not args.reset_optimizer:
+                save_checkpoint.best = extra_state['best']
+        return True
+    else:
+        print('| no existing checkpoint found {}'.format(checkpoint_path))
+        return False


 def load_checkpoint_to_cpu(path):
@@ -59,6 +165,28 @@ def load_model_ensemble(filenames, arg_overrides=None, task=None):
     return ensemble, args


+def reload_train(args, epoch_itr, max_positions, task):
+    # nothing needs to be done when the dataset is not sharded.
+    if "data" not in args or ("data" in args and len(args.data.split(":")) == 1):
+        return epoch_itr
+    print("| Reloading shard of train data at epoch: ", epoch_itr.epoch)
+    task.load_dataset(args.train_subset, combine=True, epoch=epoch_itr.epoch)
+    epoch_itr = task.get_batch_iterator(
+        dataset=task.dataset(args.train_subset),
+        max_tokens=args.max_tokens,
+        max_sentences=args.max_sentences,
+        max_positions=max_positions,
+        ignore_invalid_inputs=True,
+        required_batch_size_multiple=args.required_batch_size_multiple,
+        seed=args.seed,
+        num_shards=args.distributed_world_size,
+        shard_id=args.distributed_rank,
+        num_workers=args.num_workers,
+        epoch=epoch_itr.epoch,
+    )
+    return epoch_itr
+
+
 def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'):
     """Retrieves all checkpoints found in `path` directory.
...
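
A small, self-contained sketch of the pruning logic used by save_checkpoint above, assuming file names follow the checkpoint_<epoch>_<updates>.pt convention: checkpoint_paths sorts matches by the captured group in descending order, and everything beyond args.keep_interval_updates is deleted. The newest_first helper below is illustrative only, not fairseq API.

    import re

    def newest_first(files, pattern=r'checkpoint_\d+_(\d+)\.pt'):
        # Sort by the update count captured by the pattern, newest first,
        # mirroring what checkpoint_paths does for real files on disk.
        matched = [(int(m.group(1)), f) for f in files
                   for m in [re.match(pattern, f)] if m]
        return [f for _, f in sorted(matched, reverse=True)]

    files = ['checkpoint_1_100.pt', 'checkpoint_1_200.pt', 'checkpoint_2_300.pt']
    keep_interval_updates = 2
    print(newest_first(files)[keep_interval_updates:])
    # prints ['checkpoint_1_100.pt'], i.e. what the loop above would os.remove()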
fairseq/utils.py
@@ -9,6 +9,7 @@ from collections import defaultdict
 from typing import Callable
 import copy
 import importlib.util
+import math
 import os
 import sys
 import warnings
...
@@ -286,6 +287,13 @@ def log_softmax(x, dim, onnx_trace=False):
         return F.log_softmax(x, dim=dim, dtype=torch.float32)


+def get_perplexity(loss):
+    try:
+        return '{:.2f}'.format(math.pow(2, loss))
+    except OverflowError:
+        return float('inf')
+
+
 def deprecation_warning(message, stacklevel=3):
     # don't use DeprecationWarning, since it's ignored by default
     warnings.warn(message, stacklevel=stacklevel)
...
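
For reference, a quick usage sketch of the relocated helper (values picked for illustration): get_perplexity formats 2**loss as a string with two decimals and falls back to float('inf') when math.pow overflows.

    from fairseq import utils

    print(utils.get_perplexity(2.0))   # '4.00'  (2 ** 2.0)
    print(utils.get_perplexity(0.0))   # '1.00'
    print(utils.get_perplexity(1e6))   # inf (math.pow raises OverflowError, caught above)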
tests/test_train.py
@@ -12,9 +12,7 @@ from unittest.mock import MagicMock, patch

 import torch

-from fairseq import data
-
-import train
+from fairseq import data, checkpoint_utils


 def mock_trainer(epoch, num_updates, iterations_in_epoch):
...
@@ -72,8 +70,9 @@ class TestLoadCheckpoint(unittest.TestCase):
         with contextlib.redirect_stdout(StringIO()):
             trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 200, 50)

-            with patch('train.reload_train', return_value=epoch_itr):
-                train.load_checkpoint(self.args_mock, trainer, epoch_itr, 512, None)
+            with patch('fairseq.checkpoint_utils.reload_train', return_value=epoch_itr):
+                checkpoint_utils.load_checkpoint(
+                    self.args_mock, trainer, epoch_itr, 512, None)

             self.assertEqual(epoch_itr.epoch, 2)
             self.assertEqual(epoch_itr.iterations_in_epoch, 50)
...
@@ -101,8 +100,9 @@ class TestLoadCheckpoint(unittest.TestCase):
         with contextlib.redirect_stdout(StringIO()):
             trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 300, 150)

-            with patch('train.reload_train', return_value=epoch_itr):
-                train.load_checkpoint(self.args_mock, trainer, epoch_itr, 512, None)
+            with patch('fairseq.checkpoint_utils.reload_train', return_value=epoch_itr):
+                checkpoint_utils.load_checkpoint(
+                    self.args_mock, trainer, epoch_itr, 512, None)

             itr = epoch_itr.next_epoch_itr(shuffle=False)
             self.assertEqual(epoch_itr.epoch, 3)
...
@@ -114,7 +114,7 @@ class TestLoadCheckpoint(unittest.TestCase):
             trainer, epoch_itr = get_trainer_and_epoch_itr(0, 150, 0, 0)
             self.patches['os.path.isfile'].return_value = False

-            train.load_checkpoint(self.args_mock, trainer, epoch_itr, 512, None)
+            checkpoint_utils.load_checkpoint(self.args_mock, trainer, epoch_itr, 512, None)
             itr = epoch_itr.next_epoch_itr(shuffle=False)

             self.assertEqual(epoch_itr.epoch, 1)
...
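
The test changes above imply a migration rule for any external code that mocked these helpers: patches must now target the library module rather than the train script. A hedged example with unittest.mock:

    from unittest.mock import patch

    # The old target ('train.reload_train') no longer exists after this commit;
    # patch the function where it now lives instead.
    with patch('fairseq.checkpoint_utils.reload_train') as mocked_reload:
        pass  # exercise code that would trigger a sharded-dataset reload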
train.py
@@ -13,7 +13,6 @@ import collections
 import math
 import os
 import random
-import shutil

 import torch

...
@@ -84,7 +83,8 @@ def main(args, init_distributed=False):
     )

     # Load the latest checkpoint if one is available
-    load_checkpoint(args, trainer, epoch_itr, max_positions, task)
+    checkpoint_utils.load_checkpoint(
+        args, trainer, epoch_itr, max_positions, task)

     # Train until the learning rate gets too small
     max_epoch = args.max_epoch or math.inf
...
@@ -106,35 +106,14 @@ def main(args, init_distributed=False):
         # save checkpoint
         if epoch_itr.epoch % args.save_interval == 0:
-            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
+            checkpoint_utils.save_checkpoint(
+                args, trainer, epoch_itr, valid_losses[0])

-        epoch_itr = reload_train(args, epoch_itr, max_positions, task)
+        epoch_itr = checkpoint_utils.reload_train(args, epoch_itr, max_positions, task)
     train_meter.stop()
     print('| done training in {:.1f} seconds'.format(train_meter.sum))


-def reload_train(args, epoch_itr, max_positions, task):
-    # nothing needs to be done when the dataset is not sharded.
-    if "data" not in args or ("data" in args and len(args.data.split(":")) == 1):
-        return epoch_itr
-    print("| Reloading shard of train data at epoch: ", epoch_itr.epoch)
-    task.load_dataset(args.train_subset, combine=True, epoch=epoch_itr.epoch)
-    epoch_itr = task.get_batch_iterator(
-        dataset=task.dataset(args.train_subset),
-        max_tokens=args.max_tokens,
-        max_sentences=args.max_sentences,
-        max_positions=max_positions,
-        ignore_invalid_inputs=True,
-        required_batch_size_multiple=args.required_batch_size_multiple,
-        seed=args.seed,
-        num_shards=args.distributed_world_size,
-        shard_id=args.distributed_rank,
-        num_workers=args.num_workers,
-        epoch=epoch_itr.epoch,
-    )
-    return epoch_itr
-
-
 def train(args, trainer, task, epoch_itr):
     """Train the model for one epoch."""
     # Update parameters every N batches
...
@@ -178,7 +157,7 @@ def train(args, trainer, task, epoch_itr):
         num_updates = trainer.get_num_updates()
         if args.save_interval_updates > 0 and num_updates % args.save_interval_updates == 0 and num_updates > 0:
             valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
-            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
+            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

         if num_updates >= max_update:
             break
...
@@ -206,7 +185,7 @@ def get_training_stats(trainer):
         stats['nll_loss'] = nll_loss
     else:
         nll_loss = trainer.get_meter('train_loss')
-    stats['ppl'] = get_perplexity(nll_loss.avg)
+    stats['ppl'] = utils.get_perplexity(nll_loss.avg)
     stats['wps'] = trainer.get_meter('wps')
     stats['ups'] = trainer.get_meter('ups')
     stats['wpb'] = trainer.get_meter('wpb')
...
@@ -282,122 +261,14 @@ def get_valid_stats(trainer):
         stats['nll_loss'] = nll_loss
     else:
         nll_loss = stats['loss']
-    stats['ppl'] = get_perplexity(nll_loss.avg)
+    stats['ppl'] = utils.get_perplexity(nll_loss.avg)
     stats['num_updates'] = trainer.get_num_updates()
-    if hasattr(save_checkpoint, 'best'):
-        stats['best_loss'] = min(save_checkpoint.best, stats['loss'].avg)
+    if hasattr(checkpoint_utils.save_checkpoint, 'best'):
+        stats['best_loss'] = min(
+            checkpoint_utils.save_checkpoint.best, stats['loss'].avg)
     return stats


-def get_perplexity(loss):
-    try:
-        return '{:.2f}'.format(math.pow(2, loss))
-    except OverflowError:
-        return float('inf')
-
-
-def save_checkpoint(args, trainer, epoch_itr, val_loss):
-    if args.no_save or not distributed_utils.is_master(args):
-        return
-
-    write_timer = StopwatchMeter()
-    write_timer.start()
-
-    epoch = epoch_itr.epoch
-    end_of_epoch = epoch_itr.end_of_epoch()
-    updates = trainer.get_num_updates()
-
-    checkpoint_conds = collections.OrderedDict()
-    checkpoint_conds['checkpoint{}.pt'.format(epoch)] = (
-        end_of_epoch and not args.no_epoch_checkpoints and
-        epoch % args.save_interval == 0
-    )
-    checkpoint_conds['checkpoint_{}_{}.pt'.format(epoch, updates)] = (
-        not end_of_epoch and args.save_interval_updates > 0 and
-        updates % args.save_interval_updates == 0
-    )
-    checkpoint_conds['checkpoint_best.pt'] = (
-        val_loss is not None and
-        (not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best)
-    )
-    checkpoint_conds['checkpoint_last.pt'] = True  # keep this last so that it's a symlink
-
-    prev_best = getattr(save_checkpoint, 'best', val_loss)
-    if val_loss is not None:
-        save_checkpoint.best = min(val_loss, prev_best)
-    extra_state = {
-        'train_iterator': epoch_itr.state_dict(),
-        'val_loss': val_loss,
-    }
-    if hasattr(save_checkpoint, 'best'):
-        extra_state.update({'best': save_checkpoint.best})
-
-    checkpoints = [os.path.join(args.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond]
-    if len(checkpoints) > 0:
-        trainer.save_checkpoint(checkpoints[0], extra_state)
-        for cp in checkpoints[1:]:
-            shutil.copyfile(checkpoints[0], cp)
-
-        write_timer.stop()
-        print('| saved checkpoint {} (epoch {} @ {} updates) (writing took {} seconds)'.format(
-            checkpoints[0], epoch, updates, write_timer.sum))
-
-    if not end_of_epoch and args.keep_interval_updates > 0:
-        # remove old checkpoints; checkpoints are sorted in descending order
-        checkpoints = checkpoint_utils.checkpoint_paths(
-            args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt',
-        )
-        for old_chk in checkpoints[args.keep_interval_updates:]:
-            if os.path.lexists(old_chk):
-                os.remove(old_chk)
-
-    if args.keep_last_epochs > 0:
-        # remove old epoch checkpoints; checkpoints are sorted in descending order
-        checkpoints = checkpoint_utils.checkpoint_paths(
-            args.save_dir, pattern=r'checkpoint(\d+)\.pt',
-        )
-        for old_chk in checkpoints[args.keep_last_epochs:]:
-            if os.path.lexists(old_chk):
-                os.remove(old_chk)
-
-
-def load_checkpoint(args, trainer, epoch_itr, max_positions, task):
-    """Load a checkpoint and replay dataloader to match."""
-    # Only rank 0 should attempt to create the required dir
-    if args.distributed_rank == 0:
-        os.makedirs(args.save_dir, exist_ok=True)
-
-    if os.path.isabs(args.restore_file):
-        checkpoint_path = args.restore_file
-    else:
-        checkpoint_path = os.path.join(args.save_dir, args.restore_file)
-    if os.path.isfile(checkpoint_path):
-        extra_state = trainer.load_checkpoint(checkpoint_path, args.reset_optimizer, args.reset_lr_scheduler,
-                                              eval(args.optimizer_overrides))
-        if extra_state is not None:
-            # replay train iterator to match checkpoint
-            epoch_itr_state = extra_state['train_iterator']
-
-            # If the loaded checkpoint is not at epoch 0, reload train dataset,
-            # as it could be potentially sharded.
-            if epoch_itr_state['epoch'] != 0:
-                epoch_itr = reload_train(args, epoch_itr, max_positions, task)
-            epoch_itr.load_state_dict(epoch_itr_state)
-
-            print('| loaded checkpoint {} (epoch {} @ {} updates)'.format(
-                checkpoint_path, epoch_itr.epoch, trainer.get_num_updates()))
-
-            trainer.lr_step(epoch_itr.epoch)
-            trainer.lr_step_update(trainer.get_num_updates())
-            if 'best' in extra_state and not args.reset_optimizer:
-                save_checkpoint.best = extra_state['best']
-        return True
-    else:
-        print('| no existing checkpoint found {}'.format(checkpoint_path))
-        return False
-
-
 def distributed_main(i, args, start_rank=0):
     args.device_id = i
     if args.distributed_rank is None:  # torch.multiprocessing.spawn
...