"git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "560719da41b2acf0847e887c1b17fd039ed24bba"
Commit fc312d28 authored by Alexei Baevski's avatar Alexei Baevski Committed by Myle Ott
Browse files

Ability to checkpoint when reaching a certain number of updates

parent 58e2c449
......@@ -201,6 +201,12 @@ def add_checkpoint_args(parser):
help='filename in save-dir from which to load checkpoint')
group.add_argument('--save-interval', type=int, default=1, metavar='N',
help='save a checkpoint every N epochs')
group.add_argument('--save-interval-updates', type=int, metavar='N',
help='if specified, saves best/last checkpoint every this many updates. '
'will also validate before saving to determine if val loss is better')
group.add_argument('--keep-interval-updates', type=int, default=0, metavar='N',
help='if --save-interval-updates is specified, keep the last this many checkpoints'
' created after specified number of updates (format is checkpoint_[epoch]_[numupd].pt')
group.add_argument('--no-save', action='store_true',
help='don\'t save models or checkpoints')
group.add_argument('--no-epoch-checkpoints', action='store_true',
......
......@@ -117,6 +117,7 @@ class json_progress_bar(progress_bar):
def print(self, stats):
    """Emit the end-of-epoch stats as a single JSON line on stdout."""
    # Keep the raw stats around, then print the formatted view.
    self.stats = stats
    formatted = self._format_stats(self.stats, epoch=self.epoch)
    print(json.dumps(formatted), flush=True)
......
......@@ -9,6 +9,7 @@ from collections import defaultdict, OrderedDict
import contextlib
import logging
import os
import re
import torch
import traceback
......@@ -351,10 +352,11 @@ def buffered_arange(max):
def convert_padding_direction(
src_tokens,
padding_idx,
right_to_left=False,
left_to_right=False,
src_tokens,
src_lengths,
padding_idx,
right_to_left=False,
left_to_right=False,
):
assert right_to_left ^ left_to_right
pad_mask = src_tokens.eq(padding_idx)
......@@ -396,3 +398,19 @@ def clip_grad_norm_(tensor, max_norm):
def fill_with_neg_inf(t):
"""FP16-compatible function that fills a tensor with -inf."""
return t.float().fill_(float('-inf')).type_as(t)
def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'):
    """Retrieve all checkpoint files found in the directory *path*.

    Checkpoints are identified by matching each filename against
    *pattern*. If the pattern contains capture groups, the result is
    sorted in descending order by the integer value of the first group;
    otherwise the listing position is used as the sort key.
    """
    matcher = re.compile(pattern)
    indexed = []
    for position, name in enumerate(os.listdir(path)):
        match = matcher.fullmatch(name)
        if match is None:
            continue
        # Prefer the numeric capture; fall back to listing position.
        sort_key = int(match.group(1)) if match.groups() else position
        indexed.append((sort_key, match.group(0)))
    indexed.sort(reverse=True)
    return [os.path.join(path, name) for _, name in indexed]
......@@ -62,10 +62,13 @@ def average_checkpoints(inputs):
return new_state
def last_n_checkpoints(paths, n):
def last_n_checkpoints(paths, n, update_based):
assert len(paths) == 1
path = paths[0]
pt_regexp = re.compile(r'checkpoint(\d+)\.pt')
if update_based:
pt_regexp = re.compile(r'checkpoint_\d+_(\d+)\.pt')
else:
pt_regexp = re.compile(r'checkpoint(\d+)\.pt')
files = os.listdir(path)
entries = []
......@@ -81,7 +84,7 @@ def last_n_checkpoints(paths, n):
def main():
parser = argparse.ArgumentParser(
description='Tool to average the params of input checkpoints to '
'produce a new checkpoint',
'produce a new checkpoint',
)
parser.add_argument(
......@@ -95,7 +98,7 @@ def main():
required=True,
metavar='FILE',
help='Write the new checkpoint containing the averaged weights to this '
'path.',
'path.',
)
parser.add_argument(
'--num',
......@@ -103,11 +106,16 @@ def main():
help='if set, will try to find checkpoints with names checkpoint_xx.pt in the path specified by input, '
'and average last num of those',
)
parser.add_argument(
'--update-based-checkpoints',
action='store_true',
help='if set and used together with --num, averages update-based checkpoints instead of epoch-based checkpoints'
)
args = parser.parse_args()
print(args)
if args.num is not None:
args.inputs = last_n_checkpoints(args.inputs, args.num)
args.inputs = last_n_checkpoints(args.inputs, args.num, args.update_based_checkpoints)
print('averaging checkpoints: ', args.inputs)
new_state = average_checkpoints(args.inputs)
......
......@@ -15,10 +15,10 @@ from fairseq import criterions, data, models, options, progress_bar
from fairseq.fp16_trainer import FP16Trainer
from fairseq.trainer import Trainer
from fairseq.meters import AverageMeter, StopwatchMeter
from fairseq.utils import checkpoint_paths
def main(args):
if args.max_tokens is None:
args.max_tokens = 6000
......@@ -82,26 +82,22 @@ def main(args):
max_epoch = args.max_epoch or math.inf
max_update = args.max_update or math.inf
lr = trainer.get_lr()
first_val_loss = None
train_meter = StopwatchMeter()
train_meter.start()
while lr > args.min_lr and epoch <= max_epoch and trainer.get_num_updates() < max_update:
# train for one epoch
train(args, trainer, next(train_dataloader), epoch)
train(args, trainer, next(train_dataloader), epoch, dataset)
# evaluate on validate set
first_val_loss = None
if epoch % args.validate_interval == 0:
for k, subset in enumerate(args.valid_subset.split(',')):
val_loss = validate(args, trainer, dataset, subset, epoch)
if k == 0:
first_val_loss = val_loss
first_val_loss = val_loss(args, trainer, dataset, epoch)
# only use first validation loss to update the learning rate
lr = trainer.lr_step(epoch, first_val_loss)
# save checkpoint
if not args.no_save and epoch % args.save_interval == 0:
save_checkpoint(trainer, args, epoch, first_val_loss)
save_checkpoint(trainer, args, epoch, end_of_epoch=True, val_loss=first_val_loss)
epoch += 1
train_meter.stop()
......@@ -120,7 +116,7 @@ def load_dataset(args, splits):
return dataset
def train(args, trainer, itr, epoch):
def train(args, trainer, itr, epoch, dataset):
"""Train the model for one epoch."""
# Set seed based on args.seed and the epoch number so that we get
......@@ -168,7 +164,12 @@ def train(args, trainer, itr, epoch):
if i == 0:
trainer.get_meter('wps').reset()
if trainer.get_num_updates() >= max_update:
num_updates = trainer.get_num_updates()
if not args.no_save and (args.save_interval_updates or 0) > 0 and num_updates % args.save_interval_updates == 0:
first_val_loss = val_loss(args, trainer, dataset, epoch, num_updates)
save_checkpoint(trainer, args, epoch, end_of_epoch=False, val_loss=first_val_loss)
if num_updates >= max_update:
break
# log end-of-epoch stats
......@@ -202,7 +203,7 @@ def get_training_stats(trainer):
return stats
def validate(args, trainer, dataset, subset, epoch):
def validate(args, trainer, dataset, subset, epoch, num_updates, verbose):
"""Evaluate the model on the validation set and return the average loss."""
# Initialize dataloader
......@@ -236,19 +237,24 @@ def validate(args, trainer, dataset, subset, epoch):
for sample in progress:
log_output = trainer.valid_step(sample)
# log mid-validation stats
stats = get_valid_stats(trainer)
for k, v in log_output.items():
if k in ['loss', 'nll_loss', 'sample_size']:
continue
extra_meters[k].update(v)
stats[k] = extra_meters[k].avg
progress.log(stats)
if verbose:
# log mid-validation stats
stats = get_valid_stats(trainer)
for k, v in log_output.items():
if k in ['loss', 'nll_loss', 'sample_size']:
continue
extra_meters[k].update(v)
stats[k] = extra_meters[k].avg
progress.log(stats)
# log validation stats
stats = get_valid_stats(trainer)
for k, meter in extra_meters.items():
stats[k] = meter.avg
if num_updates is not None:
stats['num_updates'] = num_updates
progress.print(stats)
return stats['valid_loss']
......@@ -273,16 +279,33 @@ def get_perplexity(loss):
return float('inf')
def save_checkpoint(trainer, args, epoch, val_loss=None):
def val_loss(args, trainer, dataset, epoch, num_updates=None):
    """Validate every subset in ``args.valid_subset`` and return the first loss.

    All subsets are evaluated so their stats get printed, but only the
    loss of the first subset is returned (``None`` when no losses were
    produced).
    """
    results = []
    for subset in args.valid_subset.split(','):
        results.append(
            validate(args, trainer, dataset, subset, epoch, num_updates, verbose=False)
        )
    return results[0] if results else None
def save_checkpoint(trainer, args, epoch, end_of_epoch, val_loss):
extra_state = {
'epoch': epoch,
'val_loss': val_loss,
'wall_time': trainer.get_meter('wall').elapsed_time,
}
if not args.no_epoch_checkpoints:
if end_of_epoch and not args.no_epoch_checkpoints:
epoch_filename = os.path.join(args.save_dir, 'checkpoint{}.pt'.format(epoch))
trainer.save_checkpoint(epoch_filename, extra_state)
elif not end_of_epoch and args.keep_interval_updates > 0:
checkpoint_filename = os.path.join(args.save_dir,
'checkpoint_{}_{}.pt'.format(epoch, trainer.get_num_updates()))
trainer.save_checkpoint(checkpoint_filename, extra_state)
# remove old checkpoints
checkpoints = checkpoint_paths(args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt')
# checkpoints are sorted in descending order
for old_chk in checkpoints[args.keep_interval_updates:]:
os.remove(old_chk)
assert val_loss is not None
if not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best:
......@@ -317,9 +340,11 @@ if __name__ == '__main__':
if args.distributed_port > 0 or args.distributed_init_method is not None:
from distributed_train import main as distributed_main
distributed_main(args)
elif args.distributed_world_size > 1:
from multiprocessing_train import main as multiprocessing_main
multiprocessing_main(args)
else:
main(args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment