Commit fc312d28 authored by Alexei Baevski, committed by Myle Ott

Ability to checkpoint when reaching a certain number of updates

parent 58e2c449
@@ -201,6 +201,12 @@ def add_checkpoint_args(parser):
                        help='filename in save-dir from which to load checkpoint')
     group.add_argument('--save-interval', type=int, default=1, metavar='N',
                        help='save a checkpoint every N epochs')
+    group.add_argument('--save-interval-updates', type=int, metavar='N',
+                       help='if specified, save the best/last checkpoint every N updates; '
+                            'also validates before saving to determine whether the val loss improved')
+    group.add_argument('--keep-interval-updates', type=int, default=0, metavar='N',
+                       help='if --save-interval-updates is specified, keep only the last N '
+                            'update-based checkpoints (named checkpoint_[epoch]_[numupd].pt)')
     group.add_argument('--no-save', action='store_true',
                        help='don\'t save models or checkpoints')
     group.add_argument('--no-epoch-checkpoints', action='store_true',
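To make the intended interaction of the two new options concrete, here is a minimal sketch; the settings and filenames below are illustrative, not defaults from this commit. A validate-and-save cycle fires every --save-interval-updates updates, and the update-based checkpoint files are written and rotated only when --keep-interval-updates is positive.

    # Sketch under hypothetical settings --save-interval-updates 1000
    # --keep-interval-updates 2; all numbers are illustrative only.
    save_interval_updates = 1000
    keep_interval_updates = 2

    epoch = 1
    written = []
    for num_updates in range(1, 3501):
        if num_updates % save_interval_updates == 0:
            # each trigger validates and, since keep_interval_updates > 0,
            # also writes an update-based checkpoint:
            written.append('checkpoint_{}_{}.pt'.format(epoch, num_updates))

    print(written)
    # ['checkpoint_1_1000.pt', 'checkpoint_1_2000.pt', 'checkpoint_1_3000.pt']

    # pruning keeps only the newest `keep_interval_updates` files; the real
    # code sorts via checkpoint_paths, the lambda here is a stand-in:
    kept = sorted(written, key=lambda f: int(f.rsplit('_', 1)[1][:-3]),
                  reverse=True)[:keep_interval_updates]
    print(kept)  # ['checkpoint_1_3000.pt', 'checkpoint_1_2000.pt']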
@@ -117,6 +117,7 @@ class json_progress_bar(progress_bar):

     def print(self, stats):
         """Print end-of-epoch stats."""
+        self.stats = stats
         stats = self._format_stats(self.stats, epoch=self.epoch)
         print(json.dumps(stats), flush=True)
@@ -9,6 +9,7 @@ from collections import defaultdict, OrderedDict
 import contextlib
 import logging
 import os
+import re
 import torch
 import traceback

@@ -351,10 +352,11 @@ def buffered_arange(max):

 def convert_padding_direction(
     src_tokens,
+    src_lengths,
     padding_idx,
     right_to_left=False,
     left_to_right=False,
 ):
     assert right_to_left ^ left_to_right
     pad_mask = src_tokens.eq(padding_idx)

@@ -396,3 +398,19 @@ def clip_grad_norm_(tensor, max_norm):

 def fill_with_neg_inf(t):
     """FP16-compatible function that fills a tensor with -inf."""
     return t.float().fill_(float('-inf')).type_as(t)
+
+
+def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'):
+    """Retrieve all checkpoints found in `path` directory.
+
+    Checkpoints are identified by matching filename to the specified pattern.
+    If the pattern contains groups, the result will be sorted by the first
+    group in descending order.
+    """
+    pt_regexp = re.compile(pattern)
+    files = os.listdir(path)
+
+    entries = []
+    for i, f in enumerate(files):
+        m = pt_regexp.fullmatch(f)
+        if m is not None:
+            idx = int(m.group(1)) if len(m.groups()) > 0 else i
+            entries.append((idx, m.group(0)))
+    return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]
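A quick, self-contained sketch of the new helper's sorting behavior, using a temporary directory and made-up filenames (checkpoint_paths is the function added above):

    import os
    import tempfile

    from fairseq.utils import checkpoint_paths  # added in this commit

    with tempfile.TemporaryDirectory() as d:
        for name in ('checkpoint_1_500.pt', 'checkpoint_2_1500.pt',
                     'checkpoint_1_1000.pt', 'dict.en.txt'):
            open(os.path.join(d, name), 'w').close()

        # non-matching files are skipped; matches are sorted by the first
        # regex group (here the update count), in descending order
        for p in checkpoint_paths(d, pattern=r'checkpoint_\d+_(\d+)\.pt'):
            print(os.path.basename(p))
    # checkpoint_2_1500.pt
    # checkpoint_1_1000.pt
    # checkpoint_1_500.pt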
@@ -62,10 +62,13 @@ def average_checkpoints(inputs):
     return new_state


-def last_n_checkpoints(paths, n):
+def last_n_checkpoints(paths, n, update_based):
     assert len(paths) == 1
     path = paths[0]
-    pt_regexp = re.compile(r'checkpoint(\d+)\.pt')
+    if update_based:
+        pt_regexp = re.compile(r'checkpoint_\d+_(\d+)\.pt')
+    else:
+        pt_regexp = re.compile(r'checkpoint(\d+)\.pt')
     files = os.listdir(path)

     entries = []

@@ -81,7 +84,7 @@ def last_n_checkpoints(paths, n):
 def main():
     parser = argparse.ArgumentParser(
         description='Tool to average the params of input checkpoints to '
                     'produce a new checkpoint',
     )
     parser.add_argument(

@@ -95,7 +98,7 @@ def main():
         required=True,
         metavar='FILE',
         help='Write the new checkpoint containing the averaged weights to this '
             'path.',
     )
     parser.add_argument(
         '--num',

@@ -103,11 +106,16 @@ def main():
         help='if set, will try to find checkpoints with names checkpoint_xx.pt in the path specified by input, '
             'and average last num of those',
     )
+    parser.add_argument(
+        '--update-based-checkpoints',
+        action='store_true',
+        help='if set and used together with --num, averages update-based checkpoints '
+             'instead of epoch-based checkpoints',
+    )
     args = parser.parse_args()
     print(args)

     if args.num is not None:
-        args.inputs = last_n_checkpoints(args.inputs, args.num)
+        args.inputs = last_n_checkpoints(args.inputs, args.num, args.update_based_checkpoints)
     print('averaging checkpoints: ', args.inputs)

     new_state = average_checkpoints(args.inputs)
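The only difference between the two averaging modes is which filename pattern supplies the sort key; a quick check with invented filenames:

    import re

    epoch_regexp = re.compile(r'checkpoint(\d+)\.pt')        # epoch-based names
    update_regexp = re.compile(r'checkpoint_\d+_(\d+)\.pt')  # update-based names

    print(epoch_regexp.fullmatch('checkpoint12.pt').group(1))         # 12 (epoch)
    print(update_regexp.fullmatch('checkpoint_3_40000.pt').group(1))  # 40000 (updates)
    print(update_regexp.fullmatch('checkpoint12.pt'))                 # None -> ignored

So with --update-based-checkpoints, --num selects the checkpoints with the highest update counts rather than the highest epoch numbers.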
@@ -15,10 +15,10 @@ from fairseq import criterions, data, models, options, progress_bar
 from fairseq.fp16_trainer import FP16Trainer
 from fairseq.trainer import Trainer
 from fairseq.meters import AverageMeter, StopwatchMeter
+from fairseq.utils import checkpoint_paths


 def main(args):
     if args.max_tokens is None:
         args.max_tokens = 6000

@@ -82,26 +82,22 @@ def main(args):
     max_epoch = args.max_epoch or math.inf
     max_update = args.max_update or math.inf
     lr = trainer.get_lr()
-    first_val_loss = None
     train_meter = StopwatchMeter()
     train_meter.start()
     while lr > args.min_lr and epoch <= max_epoch and trainer.get_num_updates() < max_update:
         # train for one epoch
-        train(args, trainer, next(train_dataloader), epoch)
+        train(args, trainer, next(train_dataloader), epoch, dataset)

-        # evaluate on validation set
+        first_val_loss = None
         if epoch % args.validate_interval == 0:
-            for k, subset in enumerate(args.valid_subset.split(',')):
-                val_loss = validate(args, trainer, dataset, subset, epoch)
-                if k == 0:
-                    first_val_loss = val_loss
+            first_val_loss = val_loss(args, trainer, dataset, epoch)

         # only use first validation loss to update the learning rate
         lr = trainer.lr_step(epoch, first_val_loss)

         # save checkpoint
         if not args.no_save and epoch % args.save_interval == 0:
-            save_checkpoint(trainer, args, epoch, first_val_loss)
+            save_checkpoint(trainer, args, epoch, end_of_epoch=True, val_loss=first_val_loss)

         epoch += 1
     train_meter.stop()

@@ -120,7 +116,7 @@ def load_dataset(args, splits):
     return dataset


-def train(args, trainer, itr, epoch):
+def train(args, trainer, itr, epoch, dataset):
     """Train the model for one epoch."""

     # Set seed based on args.seed and the epoch number so that we get

@@ -168,7 +164,12 @@ def train(args, trainer, itr, epoch):
         if i == 0:
             trainer.get_meter('wps').reset()

-        if trainer.get_num_updates() >= max_update:
+        num_updates = trainer.get_num_updates()
+        if not args.no_save and (args.save_interval_updates or 0) > 0 and num_updates % args.save_interval_updates == 0:
+            first_val_loss = val_loss(args, trainer, dataset, epoch, num_updates)
+            save_checkpoint(trainer, args, epoch, end_of_epoch=False, val_loss=first_val_loss)
+
+        if num_updates >= max_update:
             break

     # log end-of-epoch stats

@@ -202,7 +203,7 @@ def get_training_stats(trainer):
     return stats


-def validate(args, trainer, dataset, subset, epoch):
+def validate(args, trainer, dataset, subset, epoch, num_updates, verbose):
     """Evaluate the model on the validation set and return the average loss."""

     # Initialize dataloader

@@ -236,19 +237,24 @@ def validate(args, trainer, dataset, subset, epoch):
     for sample in progress:
         log_output = trainer.valid_step(sample)

-        # log mid-validation stats
-        stats = get_valid_stats(trainer)
-        for k, v in log_output.items():
-            if k in ['loss', 'nll_loss', 'sample_size']:
-                continue
-            extra_meters[k].update(v)
-            stats[k] = extra_meters[k].avg
-        progress.log(stats)
+        if verbose:
+            # log mid-validation stats
+            stats = get_valid_stats(trainer)
+            for k, v in log_output.items():
+                if k in ['loss', 'nll_loss', 'sample_size']:
+                    continue
+                extra_meters[k].update(v)
+                stats[k] = extra_meters[k].avg
+            progress.log(stats)

     # log validation stats
     stats = get_valid_stats(trainer)
     for k, meter in extra_meters.items():
         stats[k] = meter.avg
+    if num_updates is not None:
+        stats['num_updates'] = num_updates
     progress.print(stats)

     return stats['valid_loss']

@@ -273,16 +279,33 @@ def get_perplexity(loss):
         return float('inf')


-def save_checkpoint(trainer, args, epoch, val_loss=None):
+def val_loss(args, trainer, dataset, epoch, num_updates=None):
+    # evaluate on the validation set
+    subsets = args.valid_subset.split(',')
+    # validate all subsets so their results get printed, but return only the first loss
+    losses = [validate(args, trainer, dataset, subset, epoch, num_updates, verbose=False) for subset in subsets]
+    return losses[0] if len(losses) > 0 else None
+
+
+def save_checkpoint(trainer, args, epoch, end_of_epoch, val_loss):
     extra_state = {
         'epoch': epoch,
         'val_loss': val_loss,
         'wall_time': trainer.get_meter('wall').elapsed_time,
     }

-    if not args.no_epoch_checkpoints:
+    if end_of_epoch and not args.no_epoch_checkpoints:
         epoch_filename = os.path.join(args.save_dir, 'checkpoint{}.pt'.format(epoch))
         trainer.save_checkpoint(epoch_filename, extra_state)
+    elif not end_of_epoch and args.keep_interval_updates > 0:
+        checkpoint_filename = os.path.join(
+            args.save_dir, 'checkpoint_{}_{}.pt'.format(epoch, trainer.get_num_updates()))
+        trainer.save_checkpoint(checkpoint_filename, extra_state)
+
+        # remove old update-based checkpoints; checkpoint_paths returns them
+        # sorted by update count in descending order
+        checkpoints = checkpoint_paths(args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt')
+        for old_chk in checkpoints[args.keep_interval_updates:]:
+            os.remove(old_chk)

     assert val_loss is not None
     if not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best:

@@ -317,9 +340,11 @@ if __name__ == '__main__':
     if args.distributed_port > 0 or args.distributed_init_method is not None:
         from distributed_train import main as distributed_main
+
         distributed_main(args)
     elif args.distributed_world_size > 1:
         from multiprocessing_train import main as multiprocessing_main
+
         multiprocessing_main(args)
     else:
         main(args)
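For reference, a sketch of the two filename schemes save_checkpoint now produces; the save directory, epoch, and update count below are placeholder values:

    import os

    save_dir = 'checkpoints'         # stands in for args.save_dir
    epoch, num_updates = 3, 12000    # placeholder values

    # end-of-epoch save (end_of_epoch=True), skipped under --no-epoch-checkpoints:
    print(os.path.join(save_dir, 'checkpoint{}.pt'.format(epoch)))
    # checkpoints/checkpoint3.pt

    # mid-epoch save (end_of_epoch=False), written only when --keep-interval-updates > 0:
    print(os.path.join(save_dir, 'checkpoint_{}_{}.pt'.format(epoch, num_updates)))
    # checkpoints/checkpoint_3_12000.pt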