"vscode:/vscode.git/clone" did not exist on "ea26c37bad0614fe67411daf501d034d0848e2cb"
Commit cd1e5c09 authored by Dmytro Okhonko, committed by Facebook Github Bot

Move save/load checkpoint functions to utils

Summary:
Move `load_checkpoint`, `save_checkpoint`, and `reload_train` from train.py to checkpoint_utils.py.
Move `get_perplexity` from train.py to utils.py.
This will make train.py lighter and allow us to reuse these utilities when fairseq is used as an external library.

Reviewed By: myleott

Differential Revision: D15289607

fbshipit-source-id: 4b7c95225ac22e402bcda3497811361809110df1
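With the helpers out of train.py, downstream callers no longer need the training script on their path. A minimal sketch of what programmatic use could look like after this change (assuming a fairseq install at this revision; the loss value is only an illustration):

```python
# Hedged sketch: the moved helpers are imported from the package, not train.py.
from fairseq import checkpoint_utils, utils

# utils.get_perplexity formats 2 ** loss, so a loss of 1.0 gives '2.00'.
print(utils.get_perplexity(1.0))

# checkpoint_utils now also exposes save_checkpoint, load_checkpoint and
# reload_train for programmatic use (their arguments are omitted here).
assert hasattr(checkpoint_utils, 'load_checkpoint')
```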
parent c124d272
@@ -7,16 +7,122 @@
from collections import OrderedDict
from typing import Union
import collections
import logging
import os
import re
import traceback
import shutil
import torch
from torch.serialization import default_restore_location
from fairseq import tasks
from fairseq import tasks, distributed_utils
from fairseq.models import FairseqEncoder, FairseqDecoder
from fairseq.meters import StopwatchMeter
def save_checkpoint(args, trainer, epoch_itr, val_loss):
if args.no_save or not distributed_utils.is_master(args):
return
write_timer = StopwatchMeter()
write_timer.start()
epoch = epoch_itr.epoch
end_of_epoch = epoch_itr.end_of_epoch()
updates = trainer.get_num_updates()
checkpoint_conds = collections.OrderedDict()
checkpoint_conds['checkpoint{}.pt'.format(epoch)] = (
end_of_epoch and not args.no_epoch_checkpoints and
epoch % args.save_interval == 0
)
checkpoint_conds['checkpoint_{}_{}.pt'.format(epoch, updates)] = (
not end_of_epoch and args.save_interval_updates > 0 and
updates % args.save_interval_updates == 0
)
checkpoint_conds['checkpoint_best.pt'] = (
val_loss is not None and
(not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best)
)
checkpoint_conds['checkpoint_last.pt'] = True # keep this last so that it's a symlink
prev_best = getattr(save_checkpoint, 'best', val_loss)
if val_loss is not None:
save_checkpoint.best = min(val_loss, prev_best)
extra_state = {
'train_iterator': epoch_itr.state_dict(),
'val_loss': val_loss,
}
if hasattr(save_checkpoint, 'best'):
extra_state.update({'best': save_checkpoint.best})
checkpoints = [os.path.join(args.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond]
if len(checkpoints) > 0:
trainer.save_checkpoint(checkpoints[0], extra_state)
for cp in checkpoints[1:]:
shutil.copyfile(checkpoints[0], cp)
write_timer.stop()
print('| saved checkpoint {} (epoch {} @ {} updates) (writing took {} seconds)'.format(
checkpoints[0], epoch, updates, write_timer.sum))
if not end_of_epoch and args.keep_interval_updates > 0:
# remove old checkpoints; checkpoints are sorted in descending order
checkpoints = checkpoint_paths(
args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt',
)
for old_chk in checkpoints[args.keep_interval_updates:]:
if os.path.lexists(old_chk):
os.remove(old_chk)
if args.keep_last_epochs > 0:
# remove old epoch checkpoints; checkpoints are sorted in descending order
checkpoints = checkpoint_paths(
args.save_dir, pattern=r'checkpoint(\d+)\.pt',
)
for old_chk in checkpoints[args.keep_last_epochs:]:
if os.path.lexists(old_chk):
os.remove(old_chk)
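save_checkpoint keeps the best validation loss as an attribute on the function object itself (save_checkpoint.best), so it survives across calls without a module-level global. A toy sketch of that pattern, using a made-up function name rather than the real helper:

```python
def track_best(val_loss):
    # Function attributes persist between calls, mimicking save_checkpoint.best.
    prev_best = getattr(track_best, 'best', val_loss)
    if val_loss is not None:
        track_best.best = min(val_loss, prev_best)
    return getattr(track_best, 'best', None)

print(track_best(3.2))  # 3.2 -- first call seeds the attribute
print(track_best(2.8))  # 2.8 -- an improvement replaces it
print(track_best(4.0))  # 2.8 -- a worse loss leaves the best untouched
```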
def load_checkpoint(args, trainer, epoch_itr, max_positions, task):
"""Load a checkpoint and replay dataloader to match."""
# Only rank 0 should attempt to create the required dir
if args.distributed_rank == 0:
os.makedirs(args.save_dir, exist_ok=True)
if os.path.isabs(args.restore_file):
checkpoint_path = args.restore_file
else:
checkpoint_path = os.path.join(args.save_dir, args.restore_file)
if os.path.isfile(checkpoint_path):
extra_state = trainer.load_checkpoint(checkpoint_path, args.reset_optimizer, args.reset_lr_scheduler,
eval(args.optimizer_overrides))
if extra_state is not None:
# replay train iterator to match checkpoint
epoch_itr_state = extra_state['train_iterator']
# If the loaded checkpoint is not at epoch 0, reload train dataset,
# as it could be potentially sharded.
if epoch_itr_state['epoch'] != 0:
epoch_itr = reload_train(args, epoch_itr, max_positions, task)
epoch_itr.load_state_dict(epoch_itr_state)
print('| loaded checkpoint {} (epoch {} @ {} updates)'.format(
checkpoint_path, epoch_itr.epoch, trainer.get_num_updates()))
trainer.lr_step(epoch_itr.epoch)
trainer.lr_step_update(trainer.get_num_updates())
if 'best' in extra_state and not args.reset_optimizer:
save_checkpoint.best = extra_state['best']
return True
else:
print('| no existing checkpoint found {}'.format(checkpoint_path))
return False
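Note that load_checkpoint passes args.optimizer_overrides through eval(), so it expects a string holding a Python dict literal. A rough sketch of the argument shapes involved, using a stand-in Namespace instead of fairseq's real parsed arguments (the actual call is left commented because it needs a live trainer, iterator, and task):

```python
from argparse import Namespace

# Stand-in arguments; real values come from fairseq's argument parser.
args = Namespace(
    distributed_rank=0,
    save_dir='checkpoints',
    restore_file='checkpoint_last.pt',  # relative names resolve under save_dir
    reset_optimizer=False,
    reset_lr_scheduler=False,
    optimizer_overrides='{}',           # eval()'d into a dict of optimizer kwargs
)

# found = checkpoint_utils.load_checkpoint(args, trainer, epoch_itr, max_positions, task)
# `found` is True when a checkpoint file existed and was loaded, False otherwise.
```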
def load_checkpoint_to_cpu(path):
@@ -59,6 +165,28 @@ def load_model_ensemble(filenames, arg_overrides=None, task=None):
return ensemble, args
def reload_train(args, epoch_itr, max_positions, task):
# nothing needs to be done when the dataset is not sharded.
if "data" not in args or ("data" in args and len(args.data.split(":")) == 1):
return epoch_itr
print("| Reloading shard of train data at epoch: ", epoch_itr.epoch)
task.load_dataset(args.train_subset, combine=True, epoch=epoch_itr.epoch)
epoch_itr = task.get_batch_iterator(
dataset=task.dataset(args.train_subset),
max_tokens=args.max_tokens,
max_sentences=args.max_sentences,
max_positions=max_positions,
ignore_invalid_inputs=True,
required_batch_size_multiple=args.required_batch_size_multiple,
seed=args.seed,
num_shards=args.distributed_world_size,
shard_id=args.distributed_rank,
num_workers=args.num_workers,
epoch=epoch_itr.epoch,
)
return epoch_itr
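reload_train only rebuilds the iterator when args.data names several colon-separated shard directories; a small illustration of that guard (the paths are invented for the example):

```python
def is_sharded(data_arg):
    # Mirrors the check in reload_train: a single path means no sharding,
    # so the existing epoch iterator is reused unchanged.
    return len(data_arg.split(":")) > 1

print(is_sharded("/data/wmt16-bin"))            # False -> keep iterator
print(is_sharded("/data/shard0:/data/shard1"))  # True  -> reload each epoch
```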
def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'):
"""Retrieves all checkpoints found in `path` directory.
......
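checkpoint_paths is truncated above; judging from its signature and from the callers in save_checkpoint (which drop everything past the newest N entries), it returns matching files sorted newest-first by the number captured in the pattern. A hedged reimplementation along those lines, not the verbatim fairseq code:

```python
import os
import re

def checkpoint_paths_sketch(path, pattern=r'checkpoint(\d+)\.pt'):
    """Approximation: list files in `path` matching `pattern`, sorted so the
    checkpoint with the largest captured number (the newest) comes first."""
    pt_regexp = re.compile(pattern)
    entries = []
    for fname in os.listdir(path):
        m = pt_regexp.fullmatch(fname)
        if m is not None:
            entries.append((int(m.group(1)), fname))
    return [os.path.join(path, fname) for _, fname in sorted(entries, reverse=True)]
```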
@@ -9,6 +9,7 @@ from collections import defaultdict
from typing import Callable
import copy
import importlib.util
import math
import os
import sys
import warnings
@@ -286,6 +287,13 @@ def log_softmax(x, dim, onnx_trace=False):
return F.log_softmax(x, dim=dim, dtype=torch.float32)
def get_perplexity(loss):
try:
return '{:.2f}'.format(math.pow(2, loss))
except OverflowError:
return float('inf')
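A quick check of both branches of get_perplexity, reimplemented locally so the snippet stands alone (note the finite branch returns a formatted string while the overflow branch returns a float):

```python
import math

def get_perplexity(loss):
    # Same body as the helper moved into fairseq/utils.py.
    try:
        return '{:.2f}'.format(math.pow(2, loss))
    except OverflowError:
        return float('inf')

print(get_perplexity(3.0))      # '8.00'
print(get_perplexity(100000.0)) # inf -- math.pow raises OverflowError here
```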
def deprecation_warning(message, stacklevel=3):
# don't use DeprecationWarning, since it's ignored by default
warnings.warn(message, stacklevel=stacklevel)
......
@@ -12,9 +12,7 @@ from unittest.mock import MagicMock, patch
import torch
from fairseq import data
import train
from fairseq import data, checkpoint_utils
def mock_trainer(epoch, num_updates, iterations_in_epoch):
@@ -72,8 +70,9 @@ class TestLoadCheckpoint(unittest.TestCase):
with contextlib.redirect_stdout(StringIO()):
trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 200, 50)
with patch('train.reload_train', return_value=epoch_itr):
train.load_checkpoint(self.args_mock, trainer, epoch_itr, 512, None)
with patch('fairseq.checkpoint_utils.reload_train', return_value=epoch_itr):
checkpoint_utils.load_checkpoint(
self.args_mock, trainer, epoch_itr, 512, None)
self.assertEqual(epoch_itr.epoch, 2)
self.assertEqual(epoch_itr.iterations_in_epoch, 50)
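The tests change their patch target because mock.patch replaces a name in the module where it is looked up, and reload_train is now resolved through fairseq.checkpoint_utils rather than train. A minimal standard-library illustration of the same idea (math.sqrt is only a stand-in here):

```python
import math
from unittest.mock import patch

# Patch the attribute on the module where the code under test resolves it.
with patch('math.sqrt', return_value=42.0):
    print(math.sqrt(9))  # 42.0 while the patch is active
print(math.sqrt(9))      # 3.0 after the context manager restores the original
```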
@@ -101,8 +100,9 @@ class TestLoadCheckpoint(unittest.TestCase):
with contextlib.redirect_stdout(StringIO()):
trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 300, 150)
with patch('train.reload_train', return_value=epoch_itr):
train.load_checkpoint(self.args_mock, trainer, epoch_itr, 512, None)
with patch('fairseq.checkpoint_utils.reload_train', return_value=epoch_itr):
checkpoint_utils.load_checkpoint(
self.args_mock, trainer, epoch_itr, 512, None)
itr = epoch_itr.next_epoch_itr(shuffle=False)
self.assertEqual(epoch_itr.epoch, 3)
@@ -114,7 +114,7 @@ class TestLoadCheckpoint(unittest.TestCase):
trainer, epoch_itr = get_trainer_and_epoch_itr(0, 150, 0, 0)
self.patches['os.path.isfile'].return_value = False
train.load_checkpoint(self.args_mock, trainer, epoch_itr, 512, None)
checkpoint_utils.load_checkpoint(self.args_mock, trainer, epoch_itr, 512, None)
itr = epoch_itr.next_epoch_itr(shuffle=False)
self.assertEqual(epoch_itr.epoch, 1)
......
@@ -13,7 +13,6 @@ import collections
import math
import os
import random
import shutil
import torch
@@ -84,7 +83,8 @@ def main(args, init_distributed=False):
)
# Load the latest checkpoint if one is available
load_checkpoint(args, trainer, epoch_itr, max_positions, task)
checkpoint_utils.load_checkpoint(
args, trainer, epoch_itr, max_positions, task)
# Train until the learning rate gets too small
max_epoch = args.max_epoch or math.inf
@@ -106,35 +106,14 @@
# save checkpoint
if epoch_itr.epoch % args.save_interval == 0:
save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
checkpoint_utils.save_checkpoint(
args, trainer, epoch_itr, valid_losses[0])
epoch_itr = reload_train(args, epoch_itr, max_positions, task)
epoch_itr = checkpoint_utils.reload_train(args, epoch_itr, max_positions, task)
train_meter.stop()
print('| done training in {:.1f} seconds'.format(train_meter.sum))
def reload_train(args, epoch_itr, max_positions, task):
# nothing needs to be done when the dataset is not sharded.
if "data" not in args or ("data" in args and len(args.data.split(":")) == 1):
return epoch_itr
print("| Reloading shard of train data at epoch: ", epoch_itr.epoch)
task.load_dataset(args.train_subset, combine=True, epoch=epoch_itr.epoch)
epoch_itr = task.get_batch_iterator(
dataset=task.dataset(args.train_subset),
max_tokens=args.max_tokens,
max_sentences=args.max_sentences,
max_positions=max_positions,
ignore_invalid_inputs=True,
required_batch_size_multiple=args.required_batch_size_multiple,
seed=args.seed,
num_shards=args.distributed_world_size,
shard_id=args.distributed_rank,
num_workers=args.num_workers,
epoch=epoch_itr.epoch,
)
return epoch_itr
def train(args, trainer, task, epoch_itr):
"""Train the model for one epoch."""
# Update parameters every N batches
@@ -178,7 +157,7 @@ def train(args, trainer, task, epoch_itr):
num_updates = trainer.get_num_updates()
if args.save_interval_updates > 0 and num_updates % args.save_interval_updates == 0 and num_updates > 0:
valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
if num_updates >= max_update:
break
@@ -206,7 +185,7 @@ def get_training_stats(trainer):
stats['nll_loss'] = nll_loss
else:
nll_loss = trainer.get_meter('train_loss')
stats['ppl'] = get_perplexity(nll_loss.avg)
stats['ppl'] = utils.get_perplexity(nll_loss.avg)
stats['wps'] = trainer.get_meter('wps')
stats['ups'] = trainer.get_meter('ups')
stats['wpb'] = trainer.get_meter('wpb')
@@ -282,122 +261,14 @@ def get_valid_stats(trainer):
stats['nll_loss'] = nll_loss
else:
nll_loss = stats['loss']
stats['ppl'] = get_perplexity(nll_loss.avg)
stats['ppl'] = utils.get_perplexity(nll_loss.avg)
stats['num_updates'] = trainer.get_num_updates()
if hasattr(save_checkpoint, 'best'):
stats['best_loss'] = min(save_checkpoint.best, stats['loss'].avg)
if hasattr(checkpoint_utils.save_checkpoint, 'best'):
stats['best_loss'] = min(
checkpoint_utils.save_checkpoint.best, stats['loss'].avg)
return stats
def get_perplexity(loss):
try:
return '{:.2f}'.format(math.pow(2, loss))
except OverflowError:
return float('inf')
def save_checkpoint(args, trainer, epoch_itr, val_loss):
if args.no_save or not distributed_utils.is_master(args):
return
write_timer = StopwatchMeter()
write_timer.start()
epoch = epoch_itr.epoch
end_of_epoch = epoch_itr.end_of_epoch()
updates = trainer.get_num_updates()
checkpoint_conds = collections.OrderedDict()
checkpoint_conds['checkpoint{}.pt'.format(epoch)] = (
end_of_epoch and not args.no_epoch_checkpoints and
epoch % args.save_interval == 0
)
checkpoint_conds['checkpoint_{}_{}.pt'.format(epoch, updates)] = (
not end_of_epoch and args.save_interval_updates > 0 and
updates % args.save_interval_updates == 0
)
checkpoint_conds['checkpoint_best.pt'] = (
val_loss is not None and
(not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best)
)
checkpoint_conds['checkpoint_last.pt'] = True # keep this last so that it's a symlink
prev_best = getattr(save_checkpoint, 'best', val_loss)
if val_loss is not None:
save_checkpoint.best = min(val_loss, prev_best)
extra_state = {
'train_iterator': epoch_itr.state_dict(),
'val_loss': val_loss,
}
if hasattr(save_checkpoint, 'best'):
extra_state.update({'best': save_checkpoint.best})
checkpoints = [os.path.join(args.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond]
if len(checkpoints) > 0:
trainer.save_checkpoint(checkpoints[0], extra_state)
for cp in checkpoints[1:]:
shutil.copyfile(checkpoints[0], cp)
write_timer.stop()
print('| saved checkpoint {} (epoch {} @ {} updates) (writing took {} seconds)'.format(
checkpoints[0], epoch, updates, write_timer.sum))
if not end_of_epoch and args.keep_interval_updates > 0:
# remove old checkpoints; checkpoints are sorted in descending order
checkpoints = checkpoint_utils.checkpoint_paths(
args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt',
)
for old_chk in checkpoints[args.keep_interval_updates:]:
if os.path.lexists(old_chk):
os.remove(old_chk)
if args.keep_last_epochs > 0:
# remove old epoch checkpoints; checkpoints are sorted in descending order
checkpoints = checkpoint_utils.checkpoint_paths(
args.save_dir, pattern=r'checkpoint(\d+)\.pt',
)
for old_chk in checkpoints[args.keep_last_epochs:]:
if os.path.lexists(old_chk):
os.remove(old_chk)
def load_checkpoint(args, trainer, epoch_itr, max_positions, task):
"""Load a checkpoint and replay dataloader to match."""
# Only rank 0 should attempt to create the required dir
if args.distributed_rank == 0:
os.makedirs(args.save_dir, exist_ok=True)
if os.path.isabs(args.restore_file):
checkpoint_path = args.restore_file
else:
checkpoint_path = os.path.join(args.save_dir, args.restore_file)
if os.path.isfile(checkpoint_path):
extra_state = trainer.load_checkpoint(checkpoint_path, args.reset_optimizer, args.reset_lr_scheduler,
eval(args.optimizer_overrides))
if extra_state is not None:
# replay train iterator to match checkpoint
epoch_itr_state = extra_state['train_iterator']
# If the loaded checkpoint is not at epoch 0, reload train dataset,
# as it could be potentially sharded.
if epoch_itr_state['epoch'] != 0:
epoch_itr = reload_train(args, epoch_itr, max_positions, task)
epoch_itr.load_state_dict(epoch_itr_state)
print('| loaded checkpoint {} (epoch {} @ {} updates)'.format(
checkpoint_path, epoch_itr.epoch, trainer.get_num_updates()))
trainer.lr_step(epoch_itr.epoch)
trainer.lr_step_update(trainer.get_num_updates())
if 'best' in extra_state and not args.reset_optimizer:
save_checkpoint.best = extra_state['best']
return True
else:
print('| no existing checkpoint found {}'.format(checkpoint_path))
return False
def distributed_main(i, args, start_rank=0):
args.device_id = i
if args.distributed_rank is None: # torch.multiprocessing.spawn
......