Commit fc312d28 authored by Alexei Baevski's avatar Alexei Baevski Committed by Myle Ott
Browse files

ability to checkpoint when reaching certain number of updates

parent 58e2c449
...@@ -201,6 +201,12 @@ def add_checkpoint_args(parser): ...@@ -201,6 +201,12 @@ def add_checkpoint_args(parser):
help='filename in save-dir from which to load checkpoint') help='filename in save-dir from which to load checkpoint')
group.add_argument('--save-interval', type=int, default=1, metavar='N', group.add_argument('--save-interval', type=int, default=1, metavar='N',
help='save a checkpoint every N epochs') help='save a checkpoint every N epochs')
group.add_argument('--save-interval-updates', type=int, metavar='N',
help='if specified, saves best/last checkpoint every this many updates. '
'will also validate before saving to determine if val loss is better')
group.add_argument('--keep-interval-updates', type=int, default=0, metavar='N',
help='if --save-interval-updates is specified, keep the last this many checkpoints'
' created after specified number of updates (format is checkpoint_[epoch]_[numupd].pt')
group.add_argument('--no-save', action='store_true', group.add_argument('--no-save', action='store_true',
help='don\'t save models or checkpoints') help='don\'t save models or checkpoints')
group.add_argument('--no-epoch-checkpoints', action='store_true', group.add_argument('--no-epoch-checkpoints', action='store_true',
......
...@@ -117,6 +117,7 @@ class json_progress_bar(progress_bar): ...@@ -117,6 +117,7 @@ class json_progress_bar(progress_bar):
def print(self, stats): def print(self, stats):
"""Print end-of-epoch stats.""" """Print end-of-epoch stats."""
self.stats = stats
stats = self._format_stats(self.stats, epoch=self.epoch) stats = self._format_stats(self.stats, epoch=self.epoch)
print(json.dumps(stats), flush=True) print(json.dumps(stats), flush=True)
......
...@@ -9,6 +9,7 @@ from collections import defaultdict, OrderedDict ...@@ -9,6 +9,7 @@ from collections import defaultdict, OrderedDict
import contextlib import contextlib
import logging import logging
import os import os
import re
import torch import torch
import traceback import traceback
...@@ -351,10 +352,11 @@ def buffered_arange(max): ...@@ -351,10 +352,11 @@ def buffered_arange(max):
def convert_padding_direction( def convert_padding_direction(
src_tokens, src_tokens,
padding_idx, src_lengths,
right_to_left=False, padding_idx,
left_to_right=False, right_to_left=False,
left_to_right=False,
): ):
assert right_to_left ^ left_to_right assert right_to_left ^ left_to_right
pad_mask = src_tokens.eq(padding_idx) pad_mask = src_tokens.eq(padding_idx)
...@@ -396,3 +398,19 @@ def clip_grad_norm_(tensor, max_norm): ...@@ -396,3 +398,19 @@ def clip_grad_norm_(tensor, max_norm):
def fill_with_neg_inf(t): def fill_with_neg_inf(t):
"""FP16-compatible function that fills a tensor with -inf.""" """FP16-compatible function that fills a tensor with -inf."""
return t.float().fill_(float('-inf')).type_as(t) return t.float().fill_(float('-inf')).type_as(t)
def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'):
    """Return all checkpoint files found in the `path` directory.

    A file is a checkpoint if its name fully matches `pattern`. When the
    pattern has a capturing group, results are ordered by the integer value
    of the first group, descending; otherwise the (reversed) listing order
    is used.
    """
    matcher = re.compile(pattern)
    matched = []
    for position, fname in enumerate(os.listdir(path)):
        match = matcher.fullmatch(fname)
        if match is None:
            continue
        # Sort key: first captured group as int, or listing position as fallback.
        sort_key = int(match.group(1)) if match.groups() else position
        matched.append((sort_key, match.group(0)))
    matched.sort(reverse=True)
    return [os.path.join(path, fname) for _, fname in matched]
...@@ -62,10 +62,13 @@ def average_checkpoints(inputs): ...@@ -62,10 +62,13 @@ def average_checkpoints(inputs):
return new_state return new_state
def last_n_checkpoints(paths, n): def last_n_checkpoints(paths, n, update_based):
assert len(paths) == 1 assert len(paths) == 1
path = paths[0] path = paths[0]
pt_regexp = re.compile(r'checkpoint(\d+)\.pt') if update_based:
pt_regexp = re.compile(r'checkpoint_\d+_(\d+)\.pt')
else:
pt_regexp = re.compile(r'checkpoint(\d+)\.pt')
files = os.listdir(path) files = os.listdir(path)
entries = [] entries = []
...@@ -81,7 +84,7 @@ def last_n_checkpoints(paths, n): ...@@ -81,7 +84,7 @@ def last_n_checkpoints(paths, n):
def main(): def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='Tool to average the params of input checkpoints to ' description='Tool to average the params of input checkpoints to '
'produce a new checkpoint', 'produce a new checkpoint',
) )
parser.add_argument( parser.add_argument(
...@@ -95,7 +98,7 @@ def main(): ...@@ -95,7 +98,7 @@ def main():
required=True, required=True,
metavar='FILE', metavar='FILE',
help='Write the new checkpoint containing the averaged weights to this ' help='Write the new checkpoint containing the averaged weights to this '
'path.', 'path.',
) )
parser.add_argument( parser.add_argument(
'--num', '--num',
...@@ -103,11 +106,16 @@ def main(): ...@@ -103,11 +106,16 @@ def main():
help='if set, will try to find checkpoints with names checkpoint_xx.pt in the path specified by input, ' help='if set, will try to find checkpoints with names checkpoint_xx.pt in the path specified by input, '
'and average last num of those', 'and average last num of those',
) )
parser.add_argument(
'--update-based-checkpoints',
action='store_true',
help='if set and used together with --num, averages update-based checkpoints instead of epoch-based checkpoints'
)
args = parser.parse_args() args = parser.parse_args()
print(args) print(args)
if args.num is not None: if args.num is not None:
args.inputs = last_n_checkpoints(args.inputs, args.num) args.inputs = last_n_checkpoints(args.inputs, args.num, args.update_based_checkpoints)
print('averaging checkpoints: ', args.inputs) print('averaging checkpoints: ', args.inputs)
new_state = average_checkpoints(args.inputs) new_state = average_checkpoints(args.inputs)
......
...@@ -15,10 +15,10 @@ from fairseq import criterions, data, models, options, progress_bar ...@@ -15,10 +15,10 @@ from fairseq import criterions, data, models, options, progress_bar
from fairseq.fp16_trainer import FP16Trainer from fairseq.fp16_trainer import FP16Trainer
from fairseq.trainer import Trainer from fairseq.trainer import Trainer
from fairseq.meters import AverageMeter, StopwatchMeter from fairseq.meters import AverageMeter, StopwatchMeter
from fairseq.utils import checkpoint_paths
def main(args): def main(args):
if args.max_tokens is None: if args.max_tokens is None:
args.max_tokens = 6000 args.max_tokens = 6000
...@@ -82,26 +82,22 @@ def main(args): ...@@ -82,26 +82,22 @@ def main(args):
max_epoch = args.max_epoch or math.inf max_epoch = args.max_epoch or math.inf
max_update = args.max_update or math.inf max_update = args.max_update or math.inf
lr = trainer.get_lr() lr = trainer.get_lr()
first_val_loss = None
train_meter = StopwatchMeter() train_meter = StopwatchMeter()
train_meter.start() train_meter.start()
while lr > args.min_lr and epoch <= max_epoch and trainer.get_num_updates() < max_update: while lr > args.min_lr and epoch <= max_epoch and trainer.get_num_updates() < max_update:
# train for one epoch # train for one epoch
train(args, trainer, next(train_dataloader), epoch) train(args, trainer, next(train_dataloader), epoch, dataset)
# evaluate on validate set
first_val_loss = None
if epoch % args.validate_interval == 0: if epoch % args.validate_interval == 0:
for k, subset in enumerate(args.valid_subset.split(',')): first_val_loss = val_loss(args, trainer, dataset, epoch)
val_loss = validate(args, trainer, dataset, subset, epoch)
if k == 0:
first_val_loss = val_loss
# only use first validation loss to update the learning rate # only use first validation loss to update the learning rate
lr = trainer.lr_step(epoch, first_val_loss) lr = trainer.lr_step(epoch, first_val_loss)
# save checkpoint # save checkpoint
if not args.no_save and epoch % args.save_interval == 0: if not args.no_save and epoch % args.save_interval == 0:
save_checkpoint(trainer, args, epoch, first_val_loss) save_checkpoint(trainer, args, epoch, end_of_epoch=True, val_loss=first_val_loss)
epoch += 1 epoch += 1
train_meter.stop() train_meter.stop()
...@@ -120,7 +116,7 @@ def load_dataset(args, splits): ...@@ -120,7 +116,7 @@ def load_dataset(args, splits):
return dataset return dataset
def train(args, trainer, itr, epoch): def train(args, trainer, itr, epoch, dataset):
"""Train the model for one epoch.""" """Train the model for one epoch."""
# Set seed based on args.seed and the epoch number so that we get # Set seed based on args.seed and the epoch number so that we get
...@@ -168,7 +164,12 @@ def train(args, trainer, itr, epoch): ...@@ -168,7 +164,12 @@ def train(args, trainer, itr, epoch):
if i == 0: if i == 0:
trainer.get_meter('wps').reset() trainer.get_meter('wps').reset()
if trainer.get_num_updates() >= max_update: num_updates = trainer.get_num_updates()
if not args.no_save and (args.save_interval_updates or 0) > 0 and num_updates % args.save_interval_updates == 0:
first_val_loss = val_loss(args, trainer, dataset, epoch, num_updates)
save_checkpoint(trainer, args, epoch, end_of_epoch=False, val_loss=first_val_loss)
if num_updates >= max_update:
break break
# log end-of-epoch stats # log end-of-epoch stats
...@@ -202,7 +203,7 @@ def get_training_stats(trainer): ...@@ -202,7 +203,7 @@ def get_training_stats(trainer):
return stats return stats
def validate(args, trainer, dataset, subset, epoch): def validate(args, trainer, dataset, subset, epoch, num_updates, verbose):
"""Evaluate the model on the validation set and return the average loss.""" """Evaluate the model on the validation set and return the average loss."""
# Initialize dataloader # Initialize dataloader
...@@ -236,19 +237,24 @@ def validate(args, trainer, dataset, subset, epoch): ...@@ -236,19 +237,24 @@ def validate(args, trainer, dataset, subset, epoch):
for sample in progress: for sample in progress:
log_output = trainer.valid_step(sample) log_output = trainer.valid_step(sample)
# log mid-validation stats if verbose:
stats = get_valid_stats(trainer) # log mid-validation stats
for k, v in log_output.items(): stats = get_valid_stats(trainer)
if k in ['loss', 'nll_loss', 'sample_size']: for k, v in log_output.items():
continue if k in ['loss', 'nll_loss', 'sample_size']:
extra_meters[k].update(v) continue
stats[k] = extra_meters[k].avg extra_meters[k].update(v)
progress.log(stats) stats[k] = extra_meters[k].avg
progress.log(stats)
# log validation stats # log validation stats
stats = get_valid_stats(trainer) stats = get_valid_stats(trainer)
for k, meter in extra_meters.items(): for k, meter in extra_meters.items():
stats[k] = meter.avg stats[k] = meter.avg
if num_updates is not None:
stats['num_updates'] = num_updates
progress.print(stats) progress.print(stats)
return stats['valid_loss'] return stats['valid_loss']
...@@ -273,16 +279,33 @@ def get_perplexity(loss): ...@@ -273,16 +279,33 @@ def get_perplexity(loss):
return float('inf') return float('inf')
def save_checkpoint(trainer, args, epoch, val_loss=None): def val_loss(args, trainer, dataset, epoch, num_updates=None):
# evaluate on validate set
subsets = args.valid_subset.split(',')
# we want to validate all subsets so the results get printed out, but return only the first
losses = [validate(args, trainer, dataset, subset, epoch, num_updates, verbose=False) for subset in subsets]
return losses[0] if len(losses) > 0 else None
def save_checkpoint(trainer, args, epoch, end_of_epoch, val_loss):
extra_state = { extra_state = {
'epoch': epoch, 'epoch': epoch,
'val_loss': val_loss, 'val_loss': val_loss,
'wall_time': trainer.get_meter('wall').elapsed_time, 'wall_time': trainer.get_meter('wall').elapsed_time,
} }
if not args.no_epoch_checkpoints: if end_of_epoch and not args.no_epoch_checkpoints:
epoch_filename = os.path.join(args.save_dir, 'checkpoint{}.pt'.format(epoch)) epoch_filename = os.path.join(args.save_dir, 'checkpoint{}.pt'.format(epoch))
trainer.save_checkpoint(epoch_filename, extra_state) trainer.save_checkpoint(epoch_filename, extra_state)
elif not end_of_epoch and args.keep_interval_updates > 0:
checkpoint_filename = os.path.join(args.save_dir,
'checkpoint_{}_{}.pt'.format(epoch, trainer.get_num_updates()))
trainer.save_checkpoint(checkpoint_filename, extra_state)
# remove old checkpoints
checkpoints = checkpoint_paths(args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt')
# checkpoints are sorted in descending order
for old_chk in checkpoints[args.keep_interval_updates:]:
os.remove(old_chk)
assert val_loss is not None assert val_loss is not None
if not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best: if not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best:
...@@ -317,9 +340,11 @@ if __name__ == '__main__': ...@@ -317,9 +340,11 @@ if __name__ == '__main__':
if args.distributed_port > 0 or args.distributed_init_method is not None: if args.distributed_port > 0 or args.distributed_init_method is not None:
from distributed_train import main as distributed_main from distributed_train import main as distributed_main
distributed_main(args) distributed_main(args)
elif args.distributed_world_size > 1: elif args.distributed_world_size > 1:
from multiprocessing_train import main as multiprocessing_main from multiprocessing_train import main as multiprocessing_main
multiprocessing_main(args) multiprocessing_main(args)
else: else:
main(args) main(args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment