Commit fc312d28 authored by Alexei Baevski, committed by Myle Ott

ability to checkpoint when reaching a certain number of updates

parent 58e2c449
@@ -201,6 +201,12 @@ def add_checkpoint_args(parser):
                       help='filename in save-dir from which to load checkpoint')
    group.add_argument('--save-interval', type=int, default=1, metavar='N',
                       help='save a checkpoint every N epochs')
    group.add_argument('--save-interval-updates', type=int, metavar='N',
                       help='if specified, save the best/last checkpoint every N updates; '
                            'also validates before saving to determine whether val loss improved')
    group.add_argument('--keep-interval-updates', type=int, default=0, metavar='N',
                       help='if --save-interval-updates is set, keep only the last N update-based '
                            'checkpoints (named checkpoint_[epoch]_[numupd].pt)')
    group.add_argument('--no-save', action='store_true',
                       help='don\'t save models or checkpoints')
    group.add_argument('--no-epoch-checkpoints', action='store_true',
......
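Taken together, the two new flags write a checkpoint_[epoch]_[numupd].pt file every N updates and bound how many of those files survive. A minimal sketch of the intended retention behavior, with illustrative values (the real logic lives in save_checkpoint further down):

    # e.g. --save-interval-updates 1000 --keep-interval-updates 2, 3500 updates in epoch 1
    save_interval_updates, keep_interval_updates = 1000, 2
    saved = []
    for num_updates in range(1, 3501):
        if num_updates % save_interval_updates == 0:
            saved.append('checkpoint_1_{}.pt'.format(num_updates))
            saved = saved[-keep_interval_updates:]  # prune all but the newest N
    print(saved)  # ['checkpoint_1_2000.pt', 'checkpoint_1_3000.pt']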
@@ -117,6 +117,7 @@ class json_progress_bar(progress_bar):
    def print(self, stats):
        """Print end-of-epoch stats."""
        self.stats = stats
        stats = self._format_stats(self.stats, epoch=self.epoch)
        print(json.dumps(stats), flush=True)
......
@@ -9,6 +9,7 @@ from collections import defaultdict, OrderedDict
import contextlib
import logging
import os
import re
import torch
import traceback
@@ -352,6 +353,7 @@ def buffered_arange(max):
def convert_padding_direction(
    src_tokens,
    src_lengths,
    padding_idx,
    right_to_left=False,
    left_to_right=False,
@@ -396,3 +398,19 @@ def clip_grad_norm_(tensor, max_norm):
def fill_with_neg_inf(t):
    """FP16-compatible function that fills a tensor with -inf."""
    return t.float().fill_(float('-inf')).type_as(t)


def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'):
    """Retrieves all checkpoints found in `path` directory.

    Checkpoints are identified by matching filename to the specified pattern.
    If the pattern contains groups, the result will be sorted by the first
    group in descending order.
    """
    pt_regexp = re.compile(pattern)
    files = os.listdir(path)

    entries = []
    for i, f in enumerate(files):
        m = pt_regexp.fullmatch(f)
        if m is not None:
            idx = int(m.group(1)) if len(m.groups()) > 0 else i
            entries.append((idx, m.group(0)))
    return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]
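As a quick illustration of the sorting contract, assuming fairseq.utils.checkpoint_paths from the hunk above is importable (the directory and filenames are throwaway):

    import os
    import tempfile

    from fairseq.utils import checkpoint_paths

    tmp = tempfile.mkdtemp()
    for name in ('checkpoint_1_500.pt', 'checkpoint_2_1500.pt', 'checkpoint_1_1000.pt'):
        open(os.path.join(tmp, name), 'w').close()

    # the (\d+) group captures the update count, so paths come back newest-first
    print(checkpoint_paths(tmp, pattern=r'checkpoint_\d+_(\d+)\.pt'))
    # [.../checkpoint_2_1500.pt, .../checkpoint_1_1000.pt, .../checkpoint_1_500.pt]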
@@ -62,9 +62,12 @@ def average_checkpoints(inputs):
    return new_state


def last_n_checkpoints(paths, n):
def last_n_checkpoints(paths, n, update_based):
    assert len(paths) == 1
    path = paths[0]
    if update_based:
        pt_regexp = re.compile(r'checkpoint_\d+_(\d+)\.pt')
    else:
        pt_regexp = re.compile(r'checkpoint(\d+)\.pt')
    files = os.listdir(path)
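For clarity, the two patterns correspond to the two checkpoint naming schemes (example filenames are illustrative):

    import re

    update_re = re.compile(r'checkpoint_\d+_(\d+)\.pt')
    epoch_re = re.compile(r'checkpoint(\d+)\.pt')

    print(update_re.fullmatch('checkpoint_3_40000.pt').group(1))  # '40000' (update count)
    print(epoch_re.fullmatch('checkpoint3.pt').group(1))          # '3' (epoch number)
    print(update_re.fullmatch('checkpoint3.pt'))                  # None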
@@ -103,11 +106,16 @@ def main():
        help='if set, will try to find checkpoints with names checkpoint_xx.pt in the path specified by input, '
             'and average last num of those',
    )
    parser.add_argument(
        '--update-based-checkpoints',
        action='store_true',
        help='if set and used together with --num, average update-based checkpoints '
             'instead of epoch-based checkpoints',
    )
    args = parser.parse_args()
    print(args)

    if args.num is not None:
        args.inputs = last_n_checkpoints(args.inputs, args.num)
        args.inputs = last_n_checkpoints(args.inputs, args.num, args.update_based_checkpoints)
        print('averaging checkpoints: ', args.inputs)

    new_state = average_checkpoints(args.inputs)
......
@@ -15,10 +15,10 @@ from fairseq import criterions, data, models, options, progress_bar
from fairseq.fp16_trainer import FP16Trainer
from fairseq.trainer import Trainer
from fairseq.meters import AverageMeter, StopwatchMeter
from fairseq.utils import checkpoint_paths


def main(args):
    if args.max_tokens is None:
        args.max_tokens = 6000
@@ -82,26 +82,22 @@ def main(args):
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    first_val_loss = None
    train_meter = StopwatchMeter()
    train_meter.start()
    while lr > args.min_lr and epoch <= max_epoch and trainer.get_num_updates() < max_update:
        # train for one epoch
        train(args, trainer, next(train_dataloader), epoch)
        train(args, trainer, next(train_dataloader), epoch, dataset)

        # evaluate on validate set
        first_val_loss = None
        if epoch % args.validate_interval == 0:
            for k, subset in enumerate(args.valid_subset.split(',')):
                val_loss = validate(args, trainer, dataset, subset, epoch)
                if k == 0:
                    first_val_loss = val_loss
            first_val_loss = val_loss(args, trainer, dataset, epoch)

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch, first_val_loss)

        # save checkpoint
        if not args.no_save and epoch % args.save_interval == 0:
            save_checkpoint(trainer, args, epoch, first_val_loss)
            save_checkpoint(trainer, args, epoch, end_of_epoch=True, val_loss=first_val_loss)

        epoch += 1
    train_meter.stop()
@@ -120,7 +116,7 @@ def load_dataset(args, splits):
    return dataset


def train(args, trainer, itr, epoch):
def train(args, trainer, itr, epoch, dataset):
    """Train the model for one epoch."""

    # Set seed based on args.seed and the epoch number so that we get
@@ -168,7 +164,12 @@ def train(args, trainer, itr, epoch):
        if i == 0:
            trainer.get_meter('wps').reset()

        if trainer.get_num_updates() >= max_update:
        num_updates = trainer.get_num_updates()
        if not args.no_save and (args.save_interval_updates or 0) > 0 and num_updates % args.save_interval_updates == 0:
            first_val_loss = val_loss(args, trainer, dataset, epoch, num_updates)
            save_checkpoint(trainer, args, epoch, end_of_epoch=False, val_loss=first_val_loss)

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
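Note the defensive (args.save_interval_updates or 0) > 0 guard: argparse leaves the flag as None when it is not passed, and evaluating num_updates % None would raise a TypeError, so the condition must short-circuit first. A tiny illustration:

    for flag in (None, 0, 1000):
        print(flag, (flag or 0) > 0)  # None -> False, 0 -> False, 1000 -> True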
@@ -202,7 +203,7 @@ def get_training_stats(trainer):
    return stats


def validate(args, trainer, dataset, subset, epoch):
def validate(args, trainer, dataset, subset, epoch, num_updates, verbose):
    """Evaluate the model on the validation set and return the average loss."""

    # Initialize dataloader
@@ -236,6 +237,7 @@ def validate(args, trainer, dataset, subset, epoch):
    for sample in progress:
        log_output = trainer.valid_step(sample)

        if verbose:
            # log mid-validation stats
            stats = get_valid_stats(trainer)
            for k, v in log_output.items():
@@ -249,6 +251,10 @@ def validate(args, trainer, dataset, subset, epoch):
    stats = get_valid_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg

    if num_updates is not None:
        stats['num_updates'] = num_updates

    progress.print(stats)

    return stats['valid_loss']
@@ -273,16 +279,33 @@ def get_perplexity(loss):
        return float('inf')


def save_checkpoint(trainer, args, epoch, val_loss=None):
def val_loss(args, trainer, dataset, epoch, num_updates=None):
    # evaluate on validate set
    subsets = args.valid_subset.split(',')
    # we want to validate all subsets so the results get printed out, but return only the first
    losses = [validate(args, trainer, dataset, subset, epoch, num_updates, verbose=False) for subset in subsets]
    return losses[0] if len(losses) > 0 else None


def save_checkpoint(trainer, args, epoch, end_of_epoch, val_loss):
    extra_state = {
        'epoch': epoch,
        'val_loss': val_loss,
        'wall_time': trainer.get_meter('wall').elapsed_time,
    }

    if not args.no_epoch_checkpoints:
    if end_of_epoch and not args.no_epoch_checkpoints:
        epoch_filename = os.path.join(args.save_dir, 'checkpoint{}.pt'.format(epoch))
        trainer.save_checkpoint(epoch_filename, extra_state)
    elif not end_of_epoch and args.keep_interval_updates > 0:
        checkpoint_filename = os.path.join(
            args.save_dir, 'checkpoint_{}_{}.pt'.format(epoch, trainer.get_num_updates()))
        trainer.save_checkpoint(checkpoint_filename, extra_state)

        # remove old checkpoints
        checkpoints = checkpoint_paths(args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt')
        # checkpoints are sorted in descending order
        for old_chk in checkpoints[args.keep_interval_updates:]:
            os.remove(old_chk)

    assert val_loss is not None
    if not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best:
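The best-loss bookkeeping above relies on a Python function attribute (save_checkpoint.best) as persistent state across calls. A minimal sketch of the idiom, with a hypothetical name:

    def track_best(val_loss):
        # function attributes persist between calls, like a static variable
        if not hasattr(track_best, 'best') or val_loss < track_best.best:
            track_best.best = val_loss
            return True  # the caller would write checkpoint_best.pt here
        return False

    print(track_best(2.0), track_best(2.5), track_best(1.5))  # True False True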
@@ -317,9 +340,11 @@ if __name__ == '__main__':
    if args.distributed_port > 0 or args.distributed_init_method is not None:
        from distributed_train import main as distributed_main
        distributed_main(args)
    elif args.distributed_world_size > 1:
        from multiprocessing_train import main as multiprocessing_main
        multiprocessing_main(args)
    else:
        main(args)