"git@developer.sourcefind.cn:OpenDAS/fairseq.git" did not exist on "e46b924dea823407a1822099580b0be667a6e9a4"
Commit d45db804 authored by Myle Ott, committed by Facebook GitHub Bot

Merge internal changes (#654)

Summary:
- Add --add-bos-token option to LM task
- Cleanup utils.py and options.py
Pull Request resolved: https://github.com/pytorch/fairseq/pull/654

Differential Revision: D15041794

Pulled By: myleott

fbshipit-source-id: 3ad00007769d5f48308052cfd40de39c5ffa1a6e
parent 89a69616
@@ -46,6 +46,8 @@ class TriangularSchedule(FairseqLRScheduler):
                             help='max learning rate, must be more than args.lr')
         parser.add_argument('--lr-period-updates', default=5000, type=float, metavar='LR',
                             help='initial number of updates per period (cycle length)')
+        parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS',
+                            help='shrink factor for annealing')
         parser.add_argument('--shrink-min', action='store_true',
                             help='if set, also shrinks min lr')
         # fmt: on
...
@@ -16,6 +16,16 @@ class FairseqNAG(FairseqOptimizer):
         super().__init__(args, params)
         self._optimizer = NAG(params, **self.optimizer_config)

+    @staticmethod
+    def add_args(parser):
+        """Add optimizer-specific arguments to the parser."""
+        # fmt: off
+        parser.add_argument('--momentum', default=0.99, type=float, metavar='M',
+                            help='momentum factor')
+        parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
+                            help='weight decay')
+        # fmt: on
+
     @property
     def optimizer_config(self):
         """
...
@@ -16,6 +16,16 @@ class SGD(FairseqOptimizer):
         super().__init__(args, params)
         self._optimizer = torch.optim.SGD(params, **self.optimizer_config)

+    @staticmethod
+    def add_args(parser):
+        """Add optimizer-specific arguments to the parser."""
+        # fmt: off
+        parser.add_argument('--momentum', default=0.0, type=float, metavar='M',
+                            help='momentum factor')
+        parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
+                            help='weight decay')
+        # fmt: on
+
     @property
     def optimizer_config(self):
         """
...
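Note: together with the add_optimization_args hunk below, this moves --momentum and --weight-decay out of the global optimization args and into each optimizer's own add_args, so every optimizer declares exactly the hyperparameters it supports. A minimal standalone sketch of the pattern, using plain argparse (MyNAG and build_parser are illustrative names, not fairseq API):

    import argparse


    class MyNAG:
        @staticmethod
        def add_args(parser):
            # each optimizer contributes only its own hyperparameters
            parser.add_argument('--momentum', default=0.99, type=float)
            parser.add_argument('--weight-decay', '--wd', default=0.0, type=float)


    def build_parser(optimizer_cls):
        parser = argparse.ArgumentParser()
        parser.add_argument('--lr', default=0.25, type=float)  # shared arg
        optimizer_cls.add_args(parser)  # optimizer-specific args
        return parser


    args = build_parser(MyNAG).parse_args(['--momentum', '0.95'])
    assert args.momentum == 0.95 and args.weight_decay == 0.0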
@@ -303,19 +303,13 @@ def add_optimization_args(parser):
                        metavar='LR_1,LR_2,...,LR_N',
                        help='learning rate for the first N epochs; all epochs >N using LR_N'
                             ' (note: this may be interpreted differently depending on --lr-scheduler)')
-    group.add_argument('--momentum', default=0.99, type=float, metavar='M',
-                       help='momentum factor')
-    group.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
-                       help='weight decay')

     # Learning rate schedulers can be found under fairseq/optim/lr_scheduler/
-    group.add_argument('--lr-scheduler', default='reduce_lr_on_plateau',
+    group.add_argument('--lr-scheduler', default='fixed',
                        choices=LR_SCHEDULER_REGISTRY.keys(),
                        help='Learning Rate Scheduler')
-    group.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS',
-                       help='learning rate shrink factor for annealing, lr_new = (lr * lr_shrink)')
     group.add_argument('--min-lr', default=-1, type=float, metavar='LR',
-                       help='minimum learning rate')
+                       help='stop training when the learning rate reaches this minimum')
     # fmt: on
     return group
...
@@ -81,6 +81,8 @@ class LanguageModelingTask(FairseqTask):
                             help='include future target')
         parser.add_argument('--past-target', action='store_true',
                             help='include past target')
+        parser.add_argument('--add-bos-token', action='store_true',
+                            help='prepend beginning of sentence token (<s>)')
         # fmt: on

     def __init__(self, args, dictionary, output_dictionary, targets=None):
@@ -185,7 +187,7 @@ class LanguageModelingTask(FairseqTask):
         self.datasets[split] = MonolingualDataset(
             dataset, sizes, self.dictionary, self.output_dictionary,
             add_eos_for_other_targets=add_eos_for_other_targets, shuffle=True,
-            targets=self.targets,
+            targets=self.targets, add_bos_token=self.args.add_bos_token,
         )

     def build_dataset_for_inference(self, src_tokens, src_lengths):
@@ -205,6 +207,7 @@ class LanguageModelingTask(FairseqTask):
                 self.target_dictionary,
                 add_eos_for_other_targets=False,
                 shuffle=False,
+                add_bos_token=self.args.add_bos_token,
             ),
             eos=self.source_dictionary.eos(),
             # remove EOS since this will be used as a prefix for generation
...
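Note: with --add-bos-token set, MonolingualDataset prepends the dictionary's beginning-of-sentence index to every sample. A toy illustration of the effect (not the MonolingualDataset source; assumes the default fairseq dictionary layout, where <s> is index 0 and </s> is index 2):

    import torch

    bos_idx, eos_idx = 0, 2  # default fairseq special-token indices (assumed)
    tokens = torch.tensor([4, 5, 6, eos_idx])  # a tokenized sentence ending in </s>

    add_bos_token = True
    if add_bos_token:
        tokens = torch.cat([tokens.new_tensor([bos_idx]), tokens])

    print(tokens.tolist())  # [0, 4, 5, 6, 2]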
@@ -8,7 +8,7 @@
 import itertools
 import os

-from fairseq import options, utils
+from fairseq import options
 from fairseq.data import (
     ConcatDataset,
     data_utils,
@@ -69,24 +69,6 @@ class TranslationTask(FairseqTask):
                             help='amount to upsample primary dataset')
         # fmt: on

-    @staticmethod
-    def load_pretrained_model(path, src_dict_path, tgt_dict_path, arg_overrides=None):
-        model = utils.load_checkpoint_to_cpu(path)
-        args = model['args']
-        state_dict = model['model']
-        args = utils.override_model_args(args, arg_overrides)
-        src_dict = Dictionary.load(src_dict_path)
-        tgt_dict = Dictionary.load(tgt_dict_path)
-        assert src_dict.pad() == tgt_dict.pad()
-        assert src_dict.eos() == tgt_dict.eos()
-        assert src_dict.unk() == tgt_dict.unk()
-        task = TranslationTask(args, src_dict, tgt_dict)
-        model = task.build_model(args)
-        model.upgrade_state_dict(state_dict)
-        model.load_state_dict(state_dict, strict=True)
-        return model
-
     def __init__(self, args, src_dict, tgt_dict):
         super().__init__(args)
         self.src_dict = src_dict
@@ -102,6 +84,10 @@ class TranslationTask(FairseqTask):
         args.left_pad_source = options.eval_bool(args.left_pad_source)
         args.left_pad_target = options.eval_bool(args.left_pad_target)

+        # upgrade old checkpoints
+        if isinstance(args.data, str):
+            args.data = [args.data]
+
         # find language pair automatically
         if args.source_lang is None or args.target_lang is None:
             args.source_lang, args.target_lang = data_utils.infer_language_pair(args.data[0])
@@ -147,9 +133,7 @@ class TranslationTask(FairseqTask):
         src_datasets = []
         tgt_datasets = []

-        data_paths = self.args.data
-
-        for dk, data_path in enumerate(data_paths):
+        for dk, data_path in enumerate(self.args.data):
             for k in itertools.count():
                 split_k = split + (str(k) if k > 0 else '')
...
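Note: the isinstance check added above is a backward-compatibility shim: older checkpoints stored args.data as a single path string, while load_dataset now iterates over a list of paths. A toy sketch (the path and Args holder are placeholders):

    class Args:
        data = '/path/to/wmt14'  # old-style single path (placeholder)

    args = Args()
    if isinstance(args.data, str):  # upgrade old checkpoints
        args.data = [args.data]

    for dk, data_path in enumerate(args.data):
        print(dk, data_path)  # 0 /path/to/wmt14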
@@ -11,10 +11,11 @@ Train a network across multiple GPUs.
 from collections import OrderedDict
 from itertools import chain
+import os

 import torch

-from fairseq import distributed_utils, models, optim, utils
+from fairseq import checkpoint_utils, distributed_utils, models, optim, utils
 from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter
 from fairseq.optim import lr_scheduler
@@ -119,16 +120,31 @@ class Trainer(object):
         """Save all training state in a checkpoint file."""
         if distributed_utils.is_master(self.args):  # only save one checkpoint
             extra_state['train_meters'] = self.meters
-            utils.save_state(
+            checkpoint_utils.save_state(
                 filename, self.args, self.get_model().state_dict(), self.criterion, self.optimizer,
                 self.lr_scheduler, self._num_updates, self._optim_history, extra_state,
             )

     def load_checkpoint(self, filename, reset_optimizer=False, reset_lr_scheduler=False, optimizer_overrides=None):
         """Load all training state from a checkpoint file."""
-        extra_state, self._optim_history, last_optim_state = utils.load_model_state(
-            filename, self.get_model(),
-        )
+        extra_state, self._optim_history, last_optim_state = None, [], None
+
+        if os.path.exists(filename):
+            state = checkpoint_utils.load_checkpoint_to_cpu(filename)
+
+            # load model parameters
+            try:
+                self.get_model().load_state_dict(state['model'], strict=True)
+            except Exception:
+                raise Exception(
+                    'Cannot load model parameters from checkpoint, '
+                    'please ensure that the architectures match.'
+                )
+
+            extra_state = state['extra_state']
+            self._optim_history = state['optimizer_history']
+            last_optim_state = state['last_optimizer_state']

         if last_optim_state is not None and not reset_optimizer:
             # rebuild optimizer after loading model, since params may have changed
             self._build_optimizer()
@@ -136,9 +152,9 @@ class Trainer(object):
             # only reload optimizer and lr_scheduler if they match
             last_optim = self._optim_history[-1]
             assert last_optim['criterion_name'] == self.criterion.__class__.__name__, \
-                'criterion does not match; please reset the optimizer (--reset-optimizer)'
+                'Criterion does not match; please reset the optimizer (--reset-optimizer).'
             assert last_optim['optimizer_name'] == self.optimizer.__class__.__name__, \
-                'optimizer does not match; please reset the optimizer (--reset-optimizer)'
+                'Optimizer does not match; please reset the optimizer (--reset-optimizer).'

             if not reset_lr_scheduler:
                 self.lr_scheduler.load_state_dict(last_optim['lr_scheduler_state'])
...
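Note: Trainer.load_checkpoint now reads the file itself via checkpoint_utils.load_checkpoint_to_cpu instead of delegating to utils.load_model_state. Matching save_state above, the returned dict has five top-level keys. A minimal sketch, assuming a checkpoint exists at the placeholder path:

    from fairseq import checkpoint_utils

    state = checkpoint_utils.load_checkpoint_to_cpu('checkpoints/checkpoint_last.pt')
    print(sorted(state.keys()))
    # ['args', 'extra_state', 'last_optimizer_state', 'model', 'optimizer_history']
    print(state['optimizer_history'][-1]['num_updates'])  # update count at save time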
@@ -9,182 +9,25 @@ from collections import defaultdict, OrderedDict
 from typing import Callable
 import copy
 import importlib.util
-import logging
 import os
-import re
 import sys
-import traceback
 import warnings

 import torch
 import torch.nn.functional as F
-from torch.serialization import default_restore_location

 from fairseq.modules import gelu, gelu_fast


-def torch_persistent_save(*args, **kwargs):
-    for i in range(3):
-        try:
-            return torch.save(*args, **kwargs)
-        except Exception:
-            if i == 2:
-                logging.error(traceback.format_exc())
-
-
-def convert_state_dict_type(state_dict, ttype=torch.FloatTensor):
-    if isinstance(state_dict, dict):
-        cpu_dict = OrderedDict()
-        for k, v in state_dict.items():
-            cpu_dict[k] = convert_state_dict_type(v)
-        return cpu_dict
-    elif isinstance(state_dict, list):
-        return [convert_state_dict_type(v) for v in state_dict]
-    elif torch.is_tensor(state_dict):
-        return state_dict.type(ttype)
-    else:
-        return state_dict
-
-
-def save_state(filename, args, model_state_dict, criterion, optimizer, lr_scheduler,
-               num_updates, optim_history=None, extra_state=None):
-    if optim_history is None:
-        optim_history = []
-    if extra_state is None:
-        extra_state = {}
-    state_dict = {
-        'args': args,
-        'model': model_state_dict if model_state_dict else {},
-        'optimizer_history': optim_history + [
-            {
-                'criterion_name': criterion.__class__.__name__,
-                'optimizer_name': optimizer.__class__.__name__,
-                'lr_scheduler_state': lr_scheduler.state_dict(),
-                'num_updates': num_updates,
-            }
-        ],
-        'last_optimizer_state': convert_state_dict_type(optimizer.state_dict()),
-        'extra_state': extra_state,
-    }
-    torch_persistent_save(state_dict, filename)
-
-
-def load_model_state(filename, model):
-    if not os.path.exists(filename):
-        return None, [], None
-    state = torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))
-    state = _upgrade_state_dict(state)
-    model.upgrade_state_dict(state['model'])
-
-    # load model parameters
-    try:
-        model.load_state_dict(state['model'], strict=True)
-    except Exception:
-        raise Exception('Cannot load model parameters from checkpoint, '
-                        'please ensure that the architectures match')
-
-    return state['extra_state'], state['optimizer_history'], state['last_optimizer_state']
-
-
-def _upgrade_state_dict(state):
-    """Helper for upgrading old model checkpoints."""
-    # add optimizer_history
-    if 'optimizer_history' not in state:
-        state['optimizer_history'] = [
-            {
-                'criterion_name': 'CrossEntropyCriterion',
-                'best_loss': state['best_loss'],
-            },
-        ]
-        state['last_optimizer_state'] = state['optimizer']
-        del state['optimizer']
-        del state['best_loss']
-    # move extra_state into sub-dictionary
-    if 'epoch' in state and 'extra_state' not in state:
-        state['extra_state'] = {
-            'epoch': state['epoch'],
-            'batch_offset': state['batch_offset'],
-            'val_loss': state['val_loss'],
-        }
-        del state['epoch']
-        del state['batch_offset']
-        del state['val_loss']
-    # reduce optimizer history's memory usage (only keep the last state)
-    if 'optimizer' in state['optimizer_history'][-1]:
-        state['last_optimizer_state'] = state['optimizer_history'][-1]['optimizer']
-        for optim_hist in state['optimizer_history']:
-            del optim_hist['optimizer']
-    # record the optimizer class name
-    if 'optimizer_name' not in state['optimizer_history'][-1]:
-        state['optimizer_history'][-1]['optimizer_name'] = 'FairseqNAG'
-    # move best_loss into lr_scheduler_state
-    if 'lr_scheduler_state' not in state['optimizer_history'][-1]:
-        state['optimizer_history'][-1]['lr_scheduler_state'] = {
-            'best': state['optimizer_history'][-1]['best_loss'],
-        }
-        del state['optimizer_history'][-1]['best_loss']
-    # keep track of number of updates
-    if 'num_updates' not in state['optimizer_history'][-1]:
-        state['optimizer_history'][-1]['num_updates'] = 0
-    # old model checkpoints may not have separate source/target positions
-    if hasattr(state['args'], 'max_positions') and not hasattr(state['args'], 'max_source_positions'):
-        state['args'].max_source_positions = state['args'].max_positions
-        state['args'].max_target_positions = state['args'].max_positions
-    # use stateful training data iterator
-    if 'train_iterator' not in state['extra_state']:
-        state['extra_state']['train_iterator'] = {
-            'epoch': state['extra_state']['epoch'],
-            'iterations_in_epoch': state['extra_state'].get('batch_offset', 0),
-        }
-    return state
-
-
-def load_checkpoint_to_cpu(path):
-    state = torch.load(path, map_location=lambda s, l: default_restore_location(s, 'cpu'))
-    state = _upgrade_state_dict(state)
-    return state
-
-
 def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
-    """Load an ensemble of models for inference.
-
-    model_arg_overrides allows you to pass a dictionary model_arg_overrides --
-    {'arg_name': arg} -- to override model args that were used during model
-    training
-    """
-    # load model architectures and weights
-    states = []
-    for filename in filenames:
-        if not os.path.exists(filename):
-            raise IOError('Model file not found: {}'.format(filename))
-        state = load_checkpoint_to_cpu(filename)
-        states.append(state)
-
-    ensemble = []
-    for state in states:
-        args = state['args']
-        if model_arg_overrides is not None:
-            args = override_model_args(args, model_arg_overrides)
-
-        # build model for ensemble
-        model = task.build_model(args)
-        model.upgrade_state_dict(state['model'])
-        model.load_state_dict(state['model'], strict=True)
-        ensemble.append(model)
-
-    # some args (e.g., tokens_per_sample) might have been updated while building the model
-    if model_arg_overrides is not None:
-        args = override_model_args(args, model_arg_overrides)
-
-    return ensemble, args
-
-
-def override_model_args(args, model_arg_overrides):
-    # Uses model_arg_overrides {'arg_name': arg} to override model args
-    for arg_name, arg_val in model_arg_overrides.items():
-        setattr(args, arg_name, arg_val)
-    return args
+    from fairseq import checkpoint_utils
+    deprecation_warning(
+        'utils.load_ensemble_for_inference is deprecated. '
+        'Please use checkpoint_utils.load_model_ensemble instead.'
+    )
+    return checkpoint_utils.load_model_ensemble(
+        filenames, arg_overrides=model_arg_overrides, task=task,
+    )


 def move_to_cuda(sample):
@@ -379,25 +222,6 @@ def fill_with_neg_inf(t):
     return t.float().fill_(float('-inf')).type_as(t)


-def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'):
-    """Retrieves all checkpoints found in `path` directory.
-
-    Checkpoints are identified by matching filename to the specified pattern. If
-    the pattern contains groups, the result will be sorted by the first group in
-    descending order.
-    """
-    pt_regexp = re.compile(pattern)
-    files = os.listdir(path)
-
-    entries = []
-    for i, f in enumerate(files):
-        m = pt_regexp.fullmatch(f)
-        if m is not None:
-            idx = int(m.group(1)) if len(m.groups()) > 0 else i
-            entries.append((idx, m.group(0)))
-    return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]
-
-
 def resolve_max_positions(*args):
     """Resolve max position constraints from multiple sources."""
...
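Note: utils.load_ensemble_for_inference is now a thin deprecation shim; new call sites should use checkpoint_utils.load_model_ensemble, which returns (models, args) for the ensemble. A usage sketch (the checkpoint paths are placeholders, and task construction is left to the caller):

    from fairseq import checkpoint_utils


    def load_models(paths, task):
        # task: any FairseqTask instance; paths: checkpoint files on disk
        models, ensemble_args = checkpoint_utils.load_model_ensemble(
            paths,
            arg_overrides={'max_target_positions': 1024},  # optional {'arg_name': value}
            task=task,
        )
        return models, ensemble_args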
@@ -11,9 +11,8 @@ Translate pre-processed data with a trained model.
 import torch

-from fairseq import bleu, options, progress_bar, tasks, utils
+from fairseq import bleu, checkpoint_utils, options, progress_bar, tasks, utils
 from fairseq.meters import StopwatchMeter, TimeMeter
-from fairseq.utils import import_user_module


 def main(args):
@@ -23,7 +22,7 @@ def main(args):
     assert args.replace_unk is None or args.raw_text, \
         '--replace-unk requires a raw text dataset (--raw-text)'

-    import_user_module(args)
+    utils.import_user_module(args)

     if args.max_tokens is None and args.max_sentences is None:
         args.max_tokens = 12000
@@ -34,7 +33,6 @@ def main(args):
     # Load dataset splits
     task = tasks.setup_task(args)
     task.load_dataset(args.gen_subset)
-    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset))))

     # Set dictionaries
     try:
@@ -45,8 +43,10 @@ def main(args):
     # Load ensemble
     print('| loading model(s) from {}'.format(args.path))
-    models, _model_args = utils.load_ensemble_for_inference(
-        args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides),
+    models, _model_args = checkpoint_utils.load_model_ensemble(
+        args.path.split(':'),
+        arg_overrides=eval(args.model_overrides),
+        task=task,
     )

     # Optimize ensemble for generation
...
@@ -15,9 +15,8 @@ import sys
 import torch

-from fairseq import options, tasks, utils
+from fairseq import checkpoint_utils, options, tasks, utils
 from fairseq.sequence_generator import SequenceGenerator
-from fairseq.utils import import_user_module


 Batch = namedtuple('Batch', 'ids src_tokens src_lengths')
 Translation = namedtuple('Translation', 'src_str hypos pos_scores alignments')
@@ -56,7 +55,7 @@ def make_batches(lines, args, task, max_positions):
 def main(args):
-    import_user_module(args)
+    utils.import_user_module(args)

     if args.buffer_size < 1:
         args.buffer_size = 1
@@ -77,8 +76,10 @@ def main(args):
     # Load ensemble
     print('| loading model(s) from {}'.format(args.path))
-    models, _model_args = utils.load_ensemble_for_inference(
-        args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides),
+    models, _model_args = checkpoint_utils.load_model_ensemble(
+        args.path.split(':'),
+        arg_overrides=eval(args.model_overrides),
+        task=task,
     )

     # Set dictionaries
...
@@ -12,10 +12,9 @@ Data pre-processing: build vocabularies and binarize training data.
 from collections import Counter
 from itertools import zip_longest

-from fairseq import options, tasks
+from fairseq import options, tasks, utils
 from fairseq.data import indexed_dataset
 from fairseq.binarizer import Binarizer
-from fairseq.utils import import_user_module

 from multiprocessing import Pool

 import os
@@ -23,7 +22,7 @@ import shutil
 def main(args):
-    import_user_module(args)
+    utils.import_user_module(args)
     print(args)
...
@@ -239,7 +239,20 @@ class TestLanguageModeling(unittest.TestCase):
             with tempfile.TemporaryDirectory('test_fconv_lm') as data_dir:
                 create_dummy_data(data_dir)
                 preprocess_lm_data(data_dir)
-                train_language_model(data_dir, 'fconv_lm')
+                train_language_model(data_dir, 'fconv_lm', [
+                    '--decoder-layers', '[(850, 3)] * 2 + [(1024,4)]',
+                    '--decoder-embed-dim', '280',
+                    '--optimizer', 'nag',
+                    '--lr', '0.1',
+                ])
+                eval_lm_main(data_dir)
+
+    def test_transformer_lm(self):
+        with contextlib.redirect_stdout(StringIO()):
+            with tempfile.TemporaryDirectory('test_transformer_lm') as data_dir:
+                create_dummy_data(data_dir)
+                preprocess_lm_data(data_dir)
+                train_language_model(data_dir, 'transformer_lm', ['--add-bos-token'])
                 eval_lm_main(data_dir)
@@ -534,7 +547,7 @@ def preprocess_lm_data(data_dir):
     preprocess.main(preprocess_args)


-def train_language_model(data_dir, arch):
+def train_language_model(data_dir, arch, extra_flags=None):
     train_parser = options.get_training_parser()
     train_args = options.parse_args_and_arch(
         train_parser,
@@ -542,12 +555,10 @@ def train_language_model(data_dir, arch):
             '--task', 'language_modeling',
             data_dir,
             '--arch', arch,
-            '--optimizer', 'nag',
-            '--lr', '0.1',
+            '--optimizer', 'adam',
+            '--lr', '0.0001',
             '--criterion', 'adaptive_loss',
             '--adaptive-softmax-cutoff', '5,10,15',
-            '--decoder-layers', '[(850, 3)] * 2 + [(1024,4)]',
-            '--decoder-embed-dim', '280',
             '--max-tokens', '500',
             '--tokens-per-sample', '500',
             '--save-dir', data_dir,
@@ -555,7 +566,7 @@ def train_language_model(data_dir, arch):
             '--no-progress-bar',
             '--distributed-world-size', '1',
             '--ddp-backend', 'no_c10d',
-        ],
+        ] + (extra_flags or []),
     )
     train.main(train_args)
...
@@ -38,7 +38,7 @@ class TestReproducibility(unittest.TestCase):
             ] + extra_flags,
         )
         stdout = stdout.getvalue()
-        train_log, valid_log = map(json.loads, stdout.split('\n')[-4:-2])
+        train_log, valid_log = map(json.loads, stdout.split('\n')[-5:-3])

         # train epoch 2, resuming from previous checkpoint 1
         os.rename(
@@ -56,7 +56,7 @@ class TestReproducibility(unittest.TestCase):
             ] + extra_flags,
         )
         stdout = stdout.getvalue()
-        train_res_log, valid_res_log = map(json.loads, stdout.split('\n')[-4:-2])
+        train_res_log, valid_res_log = map(json.loads, stdout.split('\n')[-5:-3])

         def cast(s):
             return round(float(s), 3)
...
@@ -17,15 +17,14 @@ import random
 import torch

-from fairseq import distributed_utils, options, progress_bar, tasks, utils
+from fairseq import checkpoint_utils, distributed_utils, options, progress_bar, tasks, utils
 from fairseq.data import iterators
 from fairseq.trainer import Trainer
 from fairseq.meters import AverageMeter, StopwatchMeter
-from fairseq.utils import import_user_module


 def main(args, init_distributed=False):
-    import_user_module(args)
+    utils.import_user_module(args)

     if args.max_tokens is None:
         args.max_tokens = 6000
@@ -326,14 +325,18 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
     if not end_of_epoch and args.keep_interval_updates > 0:
         # remove old checkpoints; checkpoints are sorted in descending order
-        checkpoints = utils.checkpoint_paths(args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt')
+        checkpoints = checkpoint_utils.checkpoint_paths(
+            args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt',
+        )
         for old_chk in checkpoints[args.keep_interval_updates:]:
             if os.path.lexists(old_chk):
                 os.remove(old_chk)

     if args.keep_last_epochs > 0:
         # remove old epoch checkpoints; checkpoints are sorted in descending order
-        checkpoints = utils.checkpoint_paths(args.save_dir, pattern=r'checkpoint(\d+)\.pt')
+        checkpoints = checkpoint_utils.checkpoint_paths(
+            args.save_dir, pattern=r'checkpoint(\d+)\.pt',
+        )
         for old_chk in checkpoints[args.keep_last_epochs:]:
             if os.path.lexists(old_chk):
                 os.remove(old_chk)
...
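Note: checkpoint_paths moved from utils to checkpoint_utils unchanged; it returns matching checkpoint files sorted in descending order, so keeping the N most recent means deleting everything past index N, as the hunk above does. A usage sketch (the directory is a placeholder):

    import os

    from fairseq import checkpoint_utils

    keep_last_epochs = 5
    paths = checkpoint_utils.checkpoint_paths('checkpoints/', pattern=r'checkpoint(\d+)\.pt')
    for old_chk in paths[keep_last_epochs:]:
        if os.path.lexists(old_chk):
            os.remove(old_chk)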