Commit d45db804 authored by Myle Ott, committed by Facebook Github Bot

Merge internal changes (#654)

Summary:
- Add --add-bos-token option to LM task
- Clean up utils.py and options.py
Pull Request resolved: https://github.com/pytorch/fairseq/pull/654

Differential Revision: D15041794

Pulled By: myleott

fbshipit-source-id: 3ad00007769d5f48308052cfd40de39c5ffa1a6e
parent 89a69616
......@@ -46,6 +46,8 @@ class TriangularSchedule(FairseqLRScheduler):
help='max learning rate, must be more than args.lr')
parser.add_argument('--lr-period-updates', default=5000, type=float, metavar='LR',
help='initial number of updates per period (cycle length)')
parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS',
help='shrink factor for annealing')
parser.add_argument('--shrink-min', action='store_true',
help='if set, also shrinks min lr')
# fmt: on
......
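For context, a triangular schedule with a per-cycle shrink factor typically behaves as in the sketch below; min_lr, max_lr, and period mirror the options above, but the function is an illustration, not fairseq's TriangularSchedule implementation.

def triangular_lr(num_updates, min_lr, max_lr, period,
                  lr_shrink=0.1, shrink_min=False):
    """Illustrative triangular schedule whose peak shrinks each cycle."""
    cycle = num_updates // (2 * period)                # index of the current triangle
    shrink = lr_shrink ** cycle                        # shrink applied to the peak
    peak = max_lr * shrink
    floor = min_lr * shrink if shrink_min else min_lr  # --shrink-min also lowers the floor
    # position inside the current triangle: 0 at the valleys, 1 at the peak
    x = 1.0 - abs(num_updates / period - 2 * cycle - 1)
    return floor + (peak - floor) * max(0.0, x)

print(triangular_lr(0, 1e-5, 0.1, 5000))      # ~1e-5 (start of first cycle)
print(triangular_lr(5000, 1e-5, 0.1, 5000))   # 0.1   (peak of first cycle)
print(triangular_lr(15000, 1e-5, 0.1, 5000))  # ~0.01 (peak of second cycle, shrunk)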
......@@ -16,6 +16,16 @@ class FairseqNAG(FairseqOptimizer):
super().__init__(args, params)
self._optimizer = NAG(params, **self.optimizer_config)
@staticmethod
def add_args(parser):
"""Add optimizer-specific arguments to the parser."""
# fmt: off
parser.add_argument('--momentum', default=0.99, type=float, metavar='M',
help='momentum factor')
parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
help='weight decay')
# fmt: on
@property
def optimizer_config(self):
"""
......
......@@ -16,6 +16,16 @@ class SGD(FairseqOptimizer):
super().__init__(args, params)
self._optimizer = torch.optim.SGD(params, **self.optimizer_config)
@staticmethod
def add_args(parser):
"""Add optimizer-specific arguments to the parser."""
# fmt: off
parser.add_argument('--momentum', default=0.0, type=float, metavar='M',
help='momentum factor')
parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
help='weight decay')
# fmt: on
@property
def optimizer_config(self):
"""
......
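Both hunks follow the same pattern: --momentum and --weight-decay are now registered by the optimizer that consumes them (instead of globally in options.py) and flow into the torch optimizer through optimizer_config. A stand-alone sketch of that flow, using a plain argparse parser rather than fairseq's registry plumbing:

import argparse

import torch

parser = argparse.ArgumentParser()
parser.add_argument('--lr', default=0.25, type=float)
# optimizer-specific flags, registered by the optimizer itself
parser.add_argument('--momentum', default=0.0, type=float, metavar='M')
parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD')
args = parser.parse_args(['--momentum', '0.99', '--wd', '0.0001'])

# what an optimizer_config property boils down to: parsed args -> torch kwargs
optimizer_config = {
    'lr': args.lr,
    'momentum': args.momentum,
    'weight_decay': args.weight_decay,
}
params = [torch.nn.Parameter(torch.zeros(10))]
optimizer = torch.optim.SGD(params, **optimizer_config)
print(optimizer.defaults)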
......@@ -303,19 +303,13 @@ def add_optimization_args(parser):
metavar='LR_1,LR_2,...,LR_N',
help='learning rate for the first N epochs; all epochs >N using LR_N'
' (note: this may be interpreted differently depending on --lr-scheduler)')
group.add_argument('--momentum', default=0.99, type=float, metavar='M',
help='momentum factor')
group.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
help='weight decay')
# Learning rate schedulers can be found under fairseq/optim/lr_scheduler/
group.add_argument('--lr-scheduler', default='reduce_lr_on_plateau',
group.add_argument('--lr-scheduler', default='fixed',
choices=LR_SCHEDULER_REGISTRY.keys(),
help='Learning Rate Scheduler')
group.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS',
help='learning rate shrink factor for annealing, lr_new = (lr * lr_shrink)')
group.add_argument('--min-lr', default=-1, type=float, metavar='LR',
help='minimum learning rate')
help='stop training when the learning rate reaches this minimum')
# fmt: on
return group
......
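The default scheduler becomes 'fixed', and --min-lr is now documented as a stopping criterion rather than a floor passed to the scheduler. Roughly, training continues while the annealed rate stays above it; a simplified, self-contained sketch (the real loop lives in train.py and drives the Trainer and scheduler objects):

min_lr = 1e-4      # --min-lr
lr_shrink = 0.5    # --lr-shrink
lr = 0.25          # --lr

epoch = 0
while lr > min_lr and epoch < 100:
    # ... one training epoch would run here ...
    epoch += 1
    lr *= lr_shrink            # annealing step performed by the lr scheduler
print('stopped after {} epochs at lr={:.2e}'.format(epoch, lr))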
......@@ -81,6 +81,8 @@ class LanguageModelingTask(FairseqTask):
help='include future target')
parser.add_argument('--past-target', action='store_true',
help='include past target')
parser.add_argument('--add-bos-token', action='store_true',
help='prepend beginning of sentence token (<s>)')
# fmt: on
def __init__(self, args, dictionary, output_dictionary, targets=None):
......@@ -185,7 +187,7 @@ class LanguageModelingTask(FairseqTask):
self.datasets[split] = MonolingualDataset(
dataset, sizes, self.dictionary, self.output_dictionary,
add_eos_for_other_targets=add_eos_for_other_targets, shuffle=True,
targets=self.targets,
targets=self.targets, add_bos_token=self.args.add_bos_token,
)
def build_dataset_for_inference(self, src_tokens, src_lengths):
......@@ -205,6 +207,7 @@ class LanguageModelingTask(FairseqTask):
self.target_dictionary,
add_eos_for_other_targets=False,
shuffle=False,
add_bos_token=self.args.add_bos_token,
),
eos=self.source_dictionary.eos(),
# remove EOS since this will be used as a prefix for generation
......
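When --add-bos-token is set, MonolingualDataset prepends the dictionary's beginning-of-sentence index (<s>) to each source sequence. The effect, illustrated with a toy tensor rather than the dataset's actual code:

import torch

def maybe_prepend_bos(tokens, bos_idx, add_bos_token):
    """Prepend <s> to a 1-D tensor of token indices when requested."""
    if add_bos_token:
        return torch.cat([tokens.new_tensor([bos_idx]), tokens])
    return tokens

src = torch.LongTensor([41, 17, 99, 2])                        # "... <eos>"
print(maybe_prepend_bos(src, bos_idx=0, add_bos_token=True))   # tensor([ 0, 41, 17, 99,  2])
print(maybe_prepend_bos(src, bos_idx=0, add_bos_token=False))  # unchanged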
......@@ -8,7 +8,7 @@
import itertools
import os
from fairseq import options, utils
from fairseq import options
from fairseq.data import (
ConcatDataset,
data_utils,
......@@ -69,24 +69,6 @@ class TranslationTask(FairseqTask):
help='amount to upsample primary dataset')
# fmt: on
@staticmethod
def load_pretrained_model(path, src_dict_path, tgt_dict_path, arg_overrides=None):
model = utils.load_checkpoint_to_cpu(path)
args = model['args']
state_dict = model['model']
args = utils.override_model_args(args, arg_overrides)
src_dict = Dictionary.load(src_dict_path)
tgt_dict = Dictionary.load(tgt_dict_path)
assert src_dict.pad() == tgt_dict.pad()
assert src_dict.eos() == tgt_dict.eos()
assert src_dict.unk() == tgt_dict.unk()
task = TranslationTask(args, src_dict, tgt_dict)
model = task.build_model(args)
model.upgrade_state_dict(state_dict)
model.load_state_dict(state_dict, strict=True)
return model
def __init__(self, args, src_dict, tgt_dict):
super().__init__(args)
self.src_dict = src_dict
......@@ -102,6 +84,10 @@ class TranslationTask(FairseqTask):
args.left_pad_source = options.eval_bool(args.left_pad_source)
args.left_pad_target = options.eval_bool(args.left_pad_target)
# upgrade old checkpoints
if isinstance(args.data, str):
args.data = [args.data]
# find language pair automatically
if args.source_lang is None or args.target_lang is None:
args.source_lang, args.target_lang = data_utils.infer_language_pair(args.data[0])
......@@ -147,9 +133,7 @@ class TranslationTask(FairseqTask):
src_datasets = []
tgt_datasets = []
data_paths = self.args.data
for dk, data_path in enumerate(data_paths):
for dk, data_path in enumerate(self.args.data):
for k in itertools.count():
split_k = split + (str(k) if k > 0 else '')
......
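With TranslationTask.load_pretrained_model removed, the equivalent flow goes through the new checkpoint_utils module, letting the task rebuild its dictionaries from the stored args. A sketch of the replacement call site; the checkpoint path is hypothetical and the stored args are assumed to point at valid data:

from fairseq import checkpoint_utils, tasks

ckpt = 'checkpoints/checkpoint_best.pt'             # hypothetical path
state = checkpoint_utils.load_checkpoint_to_cpu(ckpt)
task = tasks.setup_task(state['args'])              # rebuilds dictionaries from args.data
models, _args = checkpoint_utils.load_model_ensemble([ckpt], task=task)
model = models[0]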
......@@ -11,10 +11,11 @@ Train a network across multiple GPUs.
from collections import OrderedDict
from itertools import chain
import os
import torch
from fairseq import distributed_utils, models, optim, utils
from fairseq import checkpoint_utils, distributed_utils, models, optim, utils
from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter
from fairseq.optim import lr_scheduler
......@@ -119,16 +120,31 @@ class Trainer(object):
"""Save all training state in a checkpoint file."""
if distributed_utils.is_master(self.args): # only save one checkpoint
extra_state['train_meters'] = self.meters
utils.save_state(
checkpoint_utils.save_state(
filename, self.args, self.get_model().state_dict(), self.criterion, self.optimizer,
self.lr_scheduler, self._num_updates, self._optim_history, extra_state,
)
def load_checkpoint(self, filename, reset_optimizer=False, reset_lr_scheduler=False, optimizer_overrides=None):
"""Load all training state from a checkpoint file."""
extra_state, self._optim_history, last_optim_state = utils.load_model_state(
filename, self.get_model(),
)
extra_state, self._optim_history, last_optim_state = None, [], None
if os.path.exists(filename):
state = checkpoint_utils.load_checkpoint_to_cpu(filename)
# load model parameters
try:
self.get_model().load_state_dict(state['model'], strict=True)
except Exception:
raise Exception(
'Cannot load model parameters from checkpoint, '
'please ensure that the architectures match.'
)
extra_state = state['extra_state']
self._optim_history = state['optimizer_history']
last_optim_state = state['last_optimizer_state']
if last_optim_state is not None and not reset_optimizer:
# rebuild optimizer after loading model, since params may have changed
self._build_optimizer()
......@@ -136,9 +152,9 @@ class Trainer(object):
# only reload optimizer and lr_scheduler if they match
last_optim = self._optim_history[-1]
assert last_optim['criterion_name'] == self.criterion.__class__.__name__, \
'criterion does not match; please reset the optimizer (--reset-optimizer)'
'Criterion does not match; please reset the optimizer (--reset-optimizer).'
assert last_optim['optimizer_name'] == self.optimizer.__class__.__name__, \
'optimizer does not match; please reset the optimizer (--reset-optimizer)'
'Optimizer does not match; please reset the optimizer (--reset-optimizer).'
if not reset_lr_scheduler:
self.lr_scheduler.load_state_dict(last_optim['lr_scheduler_state'])
......
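The reworked load_checkpoint only restores optimizer state when the stored criterion and optimizer names still match the current run; otherwise it asks for --reset-optimizer. The check is simple enough to restate in isolation (names simplified, not the Trainer's exact code):

def can_restore_optimizer(last_optim, criterion, optimizer, reset_optimizer=False):
    """Mirror of the compatibility check before loading last_optimizer_state."""
    if reset_optimizer or last_optim is None:
        return False
    assert last_optim['criterion_name'] == criterion.__class__.__name__, \
        'Criterion does not match; please reset the optimizer (--reset-optimizer).'
    assert last_optim['optimizer_name'] == optimizer.__class__.__name__, \
        'Optimizer does not match; please reset the optimizer (--reset-optimizer).'
    return True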
......@@ -9,182 +9,25 @@ from collections import defaultdict, OrderedDict
from typing import Callable
import copy
import importlib.util
import logging
import os
import re
import sys
import traceback
import warnings
import torch
import torch.nn.functional as F
from torch.serialization import default_restore_location
from fairseq.modules import gelu, gelu_fast
def torch_persistent_save(*args, **kwargs):
for i in range(3):
try:
return torch.save(*args, **kwargs)
except Exception:
if i == 2:
logging.error(traceback.format_exc())
def convert_state_dict_type(state_dict, ttype=torch.FloatTensor):
if isinstance(state_dict, dict):
cpu_dict = OrderedDict()
for k, v in state_dict.items():
cpu_dict[k] = convert_state_dict_type(v)
return cpu_dict
elif isinstance(state_dict, list):
return [convert_state_dict_type(v) for v in state_dict]
elif torch.is_tensor(state_dict):
return state_dict.type(ttype)
else:
return state_dict
def save_state(filename, args, model_state_dict, criterion, optimizer, lr_scheduler,
num_updates, optim_history=None, extra_state=None):
if optim_history is None:
optim_history = []
if extra_state is None:
extra_state = {}
state_dict = {
'args': args,
'model': model_state_dict if model_state_dict else {},
'optimizer_history': optim_history + [
{
'criterion_name': criterion.__class__.__name__,
'optimizer_name': optimizer.__class__.__name__,
'lr_scheduler_state': lr_scheduler.state_dict(),
'num_updates': num_updates,
}
],
'last_optimizer_state': convert_state_dict_type(optimizer.state_dict()),
'extra_state': extra_state,
}
torch_persistent_save(state_dict, filename)
def load_model_state(filename, model):
if not os.path.exists(filename):
return None, [], None
state = torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))
state = _upgrade_state_dict(state)
model.upgrade_state_dict(state['model'])
# load model parameters
try:
model.load_state_dict(state['model'], strict=True)
except Exception:
raise Exception('Cannot load model parameters from checkpoint, '
'please ensure that the architectures match')
return state['extra_state'], state['optimizer_history'], state['last_optimizer_state']
def _upgrade_state_dict(state):
"""Helper for upgrading old model checkpoints."""
# add optimizer_history
if 'optimizer_history' not in state:
state['optimizer_history'] = [
{
'criterion_name': 'CrossEntropyCriterion',
'best_loss': state['best_loss'],
},
]
state['last_optimizer_state'] = state['optimizer']
del state['optimizer']
del state['best_loss']
# move extra_state into sub-dictionary
if 'epoch' in state and 'extra_state' not in state:
state['extra_state'] = {
'epoch': state['epoch'],
'batch_offset': state['batch_offset'],
'val_loss': state['val_loss'],
}
del state['epoch']
del state['batch_offset']
del state['val_loss']
# reduce optimizer history's memory usage (only keep the last state)
if 'optimizer' in state['optimizer_history'][-1]:
state['last_optimizer_state'] = state['optimizer_history'][-1]['optimizer']
for optim_hist in state['optimizer_history']:
del optim_hist['optimizer']
# record the optimizer class name
if 'optimizer_name' not in state['optimizer_history'][-1]:
state['optimizer_history'][-1]['optimizer_name'] = 'FairseqNAG'
# move best_loss into lr_scheduler_state
if 'lr_scheduler_state' not in state['optimizer_history'][-1]:
state['optimizer_history'][-1]['lr_scheduler_state'] = {
'best': state['optimizer_history'][-1]['best_loss'],
}
del state['optimizer_history'][-1]['best_loss']
# keep track of number of updates
if 'num_updates' not in state['optimizer_history'][-1]:
state['optimizer_history'][-1]['num_updates'] = 0
# old model checkpoints may not have separate source/target positions
if hasattr(state['args'], 'max_positions') and not hasattr(state['args'], 'max_source_positions'):
state['args'].max_source_positions = state['args'].max_positions
state['args'].max_target_positions = state['args'].max_positions
# use stateful training data iterator
if 'train_iterator' not in state['extra_state']:
state['extra_state']['train_iterator'] = {
'epoch': state['extra_state']['epoch'],
'iterations_in_epoch': state['extra_state'].get('batch_offset', 0),
}
return state
def load_checkpoint_to_cpu(path):
state = torch.load(path, map_location=lambda s, l: default_restore_location(s, 'cpu'))
state = _upgrade_state_dict(state)
return state
def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
"""Load an ensemble of models for inference.
model_arg_overrides allows you to pass a dictionary model_arg_overrides --
{'arg_name': arg} -- to override model args that were used during model
training
"""
# load model architectures and weights
states = []
for filename in filenames:
if not os.path.exists(filename):
raise IOError('Model file not found: {}'.format(filename))
state = load_checkpoint_to_cpu(filename)
states.append(state)
ensemble = []
for state in states:
args = state['args']
if model_arg_overrides is not None:
args = override_model_args(args, model_arg_overrides)
# build model for ensemble
model = task.build_model(args)
model.upgrade_state_dict(state['model'])
model.load_state_dict(state['model'], strict=True)
ensemble.append(model)
# some args (e.g., tokens_per_sample) might have been updated while building the model
if model_arg_overrides is not None:
args = override_model_args(args, model_arg_overrides)
return ensemble, args
def override_model_args(args, model_arg_overrides):
# Uses model_arg_overrides {'arg_name': arg} to override model args
for arg_name, arg_val in model_arg_overrides.items():
setattr(args, arg_name, arg_val)
return args
from fairseq import checkpoint_utils
deprecation_warning(
'utils.load_ensemble_for_inference is deprecated. '
'Please use checkpoint_utils.load_model_ensemble instead.'
)
return checkpoint_utils.load_model_ensemble(
filenames, arg_overrides=model_arg_overrides, task=task,
)
def move_to_cuda(sample):
......@@ -379,25 +222,6 @@ def fill_with_neg_inf(t):
return t.float().fill_(float('-inf')).type_as(t)
def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'):
"""Retrieves all checkpoints found in `path` directory.
Checkpoints are identified by matching filename to the specified pattern. If
the pattern contains groups, the result will be sorted by the first group in
descending order.
"""
pt_regexp = re.compile(pattern)
files = os.listdir(path)
entries = []
for i, f in enumerate(files):
m = pt_regexp.fullmatch(f)
if m is not None:
idx = int(m.group(1)) if len(m.groups()) > 0 else i
entries.append((idx, m.group(0)))
return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]
def resolve_max_positions(*args):
"""Resolve max position constraints from multiple sources."""
......
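Everything deleted here (torch_persistent_save, save_state, the state-dict upgrade path, checkpoint_paths, the ensemble loader) now lives in fairseq.checkpoint_utils rather than utils.py. The one idiom worth calling out is loading onto CPU regardless of where the checkpoint was saved; a minimal restatement of that helper:

import torch
from torch.serialization import default_restore_location

def load_to_cpu(path):
    # remap every storage to CPU so a GPU-saved checkpoint loads on any machine
    return torch.load(
        path,
        map_location=lambda storage, loc: default_restore_location(storage, 'cpu'),
    )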
......@@ -11,9 +11,8 @@ Translate pre-processed data with a trained model.
import torch
from fairseq import bleu, options, progress_bar, tasks, utils
from fairseq import bleu, checkpoint_utils, options, progress_bar, tasks, utils
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.utils import import_user_module
def main(args):
......@@ -23,7 +22,7 @@ def main(args):
assert args.replace_unk is None or args.raw_text, \
'--replace-unk requires a raw text dataset (--raw-text)'
import_user_module(args)
utils.import_user_module(args)
if args.max_tokens is None and args.max_sentences is None:
args.max_tokens = 12000
......@@ -34,7 +33,6 @@ def main(args):
# Load dataset splits
task = tasks.setup_task(args)
task.load_dataset(args.gen_subset)
print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset))))
# Set dictionaries
try:
......@@ -45,8 +43,10 @@ def main(args):
# Load ensemble
print('| loading model(s) from {}'.format(args.path))
models, _model_args = utils.load_ensemble_for_inference(
args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides),
models, _model_args = checkpoint_utils.load_model_ensemble(
args.path.split(':'),
arg_overrides=eval(args.model_overrides),
task=task,
)
# Optimize ensemble for generation
......
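--model-overrides is still passed as a Python-literal string and eval'd into a dict; checkpoint_utils applies it to the args stored in the checkpoint the same way the removed override_model_args did. Illustratively, with a hypothetical override:

model_overrides = "{'decoder_embed_dim': 280}"     # what arrives from --model-overrides
arg_overrides = eval(model_overrides)

class StoredArgs:                                  # stand-in for the checkpoint's saved args
    decoder_embed_dim = 512

args = StoredArgs()
for name, value in arg_overrides.items():          # the old override_model_args behavior
    setattr(args, name, value)
print(args.decoder_embed_dim)                      # 280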
......@@ -15,9 +15,8 @@ import sys
import torch
from fairseq import options, tasks, utils
from fairseq import checkpoint_utils, options, tasks, utils
from fairseq.sequence_generator import SequenceGenerator
from fairseq.utils import import_user_module
Batch = namedtuple('Batch', 'ids src_tokens src_lengths')
Translation = namedtuple('Translation', 'src_str hypos pos_scores alignments')
......@@ -56,7 +55,7 @@ def make_batches(lines, args, task, max_positions):
def main(args):
import_user_module(args)
utils.import_user_module(args)
if args.buffer_size < 1:
args.buffer_size = 1
......@@ -77,8 +76,10 @@ def main(args):
# Load ensemble
print('| loading model(s) from {}'.format(args.path))
models, _model_args = utils.load_ensemble_for_inference(
args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides),
models, _model_args = checkpoint_utils.load_model_ensemble(
args.path.split(':'),
arg_overrides=eval(args.model_overrides),
task=task,
)
# Set dictionaries
......
......@@ -12,10 +12,9 @@ Data pre-processing: build vocabularies and binarize training data.
from collections import Counter
from itertools import zip_longest
from fairseq import options, tasks
from fairseq import options, tasks, utils
from fairseq.data import indexed_dataset
from fairseq.binarizer import Binarizer
from fairseq.utils import import_user_module
from multiprocessing import Pool
import os
......@@ -23,7 +22,7 @@ import shutil
def main(args):
import_user_module(args)
utils.import_user_module(args)
print(args)
......
......@@ -239,7 +239,20 @@ class TestLanguageModeling(unittest.TestCase):
with tempfile.TemporaryDirectory('test_fconv_lm') as data_dir:
create_dummy_data(data_dir)
preprocess_lm_data(data_dir)
train_language_model(data_dir, 'fconv_lm')
train_language_model(data_dir, 'fconv_lm', [
'--decoder-layers', '[(850, 3)] * 2 + [(1024,4)]',
'--decoder-embed-dim', '280',
'--optimizer', 'nag',
'--lr', '0.1',
])
eval_lm_main(data_dir)
def test_transformer_lm(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_transformer_lm') as data_dir:
create_dummy_data(data_dir)
preprocess_lm_data(data_dir)
train_language_model(data_dir, 'transformer_lm', ['--add-bos-token'])
eval_lm_main(data_dir)
......@@ -534,7 +547,7 @@ def preprocess_lm_data(data_dir):
preprocess.main(preprocess_args)
def train_language_model(data_dir, arch):
def train_language_model(data_dir, arch, extra_flags=None):
train_parser = options.get_training_parser()
train_args = options.parse_args_and_arch(
train_parser,
......@@ -542,12 +555,10 @@ def train_language_model(data_dir, arch):
'--task', 'language_modeling',
data_dir,
'--arch', arch,
'--optimizer', 'nag',
'--lr', '0.1',
'--optimizer', 'adam',
'--lr', '0.0001',
'--criterion', 'adaptive_loss',
'--adaptive-softmax-cutoff', '5,10,15',
'--decoder-layers', '[(850, 3)] * 2 + [(1024,4)]',
'--decoder-embed-dim', '280',
'--max-tokens', '500',
'--tokens-per-sample', '500',
'--save-dir', data_dir,
......@@ -555,7 +566,7 @@ def train_language_model(data_dir, arch):
'--no-progress-bar',
'--distributed-world-size', '1',
'--ddp-backend', 'no_c10d',
],
] + (extra_flags or []),
)
train.main(train_args)
......
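train_language_model now takes an optional extra_flags list so each test can add architecture-specific flags (the fconv decoder layout, or --add-bos-token for the transformer LM). The appending idiom at the end of the args list is just:

def build_flags(base_flags, extra_flags=None):
    # `base + (extra_flags or [])` treats None as "no extra flags"
    return base_flags + (extra_flags or [])

print(build_flags(['--arch', 'transformer_lm'], ['--add-bos-token']))
print(build_flags(['--arch', 'fconv_lm']))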
......@@ -38,7 +38,7 @@ class TestReproducibility(unittest.TestCase):
] + extra_flags,
)
stdout = stdout.getvalue()
train_log, valid_log = map(json.loads, stdout.split('\n')[-4:-2])
train_log, valid_log = map(json.loads, stdout.split('\n')[-5:-3])
# train epoch 2, resuming from previous checkpoint 1
os.rename(
......@@ -56,7 +56,7 @@ class TestReproducibility(unittest.TestCase):
] + extra_flags,
)
stdout = stdout.getvalue()
train_res_log, valid_res_log = map(json.loads, stdout.split('\n')[-4:-2])
train_res_log, valid_res_log = map(json.loads, stdout.split('\n')[-5:-3])
def cast(s):
return round(float(s), 3)
......
......@@ -17,15 +17,14 @@ import random
import torch
from fairseq import distributed_utils, options, progress_bar, tasks, utils
from fairseq import checkpoint_utils, distributed_utils, options, progress_bar, tasks, utils
from fairseq.data import iterators
from fairseq.trainer import Trainer
from fairseq.meters import AverageMeter, StopwatchMeter
from fairseq.utils import import_user_module
def main(args, init_distributed=False):
import_user_module(args)
utils.import_user_module(args)
if args.max_tokens is None:
args.max_tokens = 6000
......@@ -326,14 +325,18 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
if not end_of_epoch and args.keep_interval_updates > 0:
# remove old checkpoints; checkpoints are sorted in descending order
checkpoints = utils.checkpoint_paths(args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt')
checkpoints = checkpoint_utils.checkpoint_paths(
args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt',
)
for old_chk in checkpoints[args.keep_interval_updates:]:
if os.path.lexists(old_chk):
os.remove(old_chk)
if args.keep_last_epochs > 0:
# remove old epoch checkpoints; checkpoints are sorted in descending order
checkpoints = utils.checkpoint_paths(args.save_dir, pattern=r'checkpoint(\d+)\.pt')
checkpoints = checkpoint_utils.checkpoint_paths(
args.save_dir, pattern=r'checkpoint(\d+)\.pt',
)
for old_chk in checkpoints[args.keep_last_epochs:]:
if os.path.lexists(old_chk):
os.remove(old_chk)
......
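checkpoint_paths (now imported from checkpoint_utils) returns matching files sorted by the first regex group in descending order, so slicing off the first keep_last_epochs entries leaves exactly the oldest checkpoints to delete. A self-contained demo of that behavior, with a simplified reimplementation of the helper:

import os
import re
import tempfile

def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'):
    """Simplified: return matches under `path`, highest epoch number first."""
    regexp = re.compile(pattern)
    entries = [(int(m.group(1)), m.group(0))
               for m in map(regexp.fullmatch, os.listdir(path)) if m]
    return [os.path.join(path, name) for _, name in sorted(entries, reverse=True)]

with tempfile.TemporaryDirectory() as save_dir:
    for epoch in range(1, 6):
        open(os.path.join(save_dir, 'checkpoint{}.pt'.format(epoch)), 'w').close()
    keep_last_epochs = 2
    old = checkpoint_paths(save_dir)[keep_last_epochs:]
    print([os.path.basename(p) for p in old])   # ['checkpoint3.pt', 'checkpoint2.pt', 'checkpoint1.pt']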