Commit 4c2ef2de authored by alexeib, committed by Myle Ott

Conv lm implementation

This implements the convolutional language model from https://arxiv.org/pdf/1612.08083.pdf

There are 3 modes for constructing batches (a minimal sketch of the selection logic follows the list):

- token block: fill each sample with a specified number of tokens without regard for sentence delimiters; this is what was used for training in the paper
- complete: fill each sample with a specified number of tokens, but only with complete sentences (i.e. if the next sentence would go over the token block limit, it is moved to the next sample); this was used for evaluation in the paper
- eos: one sentence per sample (blank lines are skipped)
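
A minimal sketch of how the three modes could partition a stream of sentences into samples (hypothetical helper, not the code in this commit; each sentence length is assumed to include its end-of-sentence token):

```python
def break_into_samples(sent_lens, tokens_per_sample, mode='none'):
    """Return the size (in tokens) of each sample for a stream of sentences.

    mode mirrors --sample-break-mode: 'none' (token block), 'complete', 'eos'.
    """
    if mode == 'eos':
        # One sentence per sample; blank lines contribute no tokens.
        return [n for n in sent_lens if n > 0]
    total = sum(sent_lens)
    if mode == 'none':
        # Fixed-size blocks that ignore sentence boundaries (training mode).
        sizes = [tokens_per_sample] * (total // tokens_per_sample)
        if total % tokens_per_sample:
            sizes.append(total % tokens_per_sample)
        return sizes
    # mode == 'complete': pack whole sentences until the next one won't fit.
    sizes, cur = [], 0
    for n in sent_lens:
        if cur and cur + n > tokens_per_sample:
            sizes.append(cur)
            cur = 0
        cur += n
    if cur:
        sizes.append(cur)
    return sizes
```

For example, `break_into_samples([5, 7, 4], 10, 'complete')` returns `[5, 7, 4]`, while `'none'` returns `[10, 6]`.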

Some results (perplexity):

| Model | Dataset | Perplexity |
| --- | --- | --- |
| GCNN-13 | GBW | 37.46 |
| GCNN-14B | GBW | 33.88 |
| GCNN-8 | Wiki103 | 43.76 |
| GCNN-14 | Wiki103 | 35.66 |

train:

python train.py /private/home/abaevski/data/wiki103 --save-dir /tmp --fp16 --max-epoch 35 --save-interval 1 --save-interval-updates 1000 --keep-interval-updates 25 --arch fconv_lm --optimizer nag --lr 1.0 --lr-scheduler reduce_lr_on_plateau --lr-shrink 0.5 --decoder-embed-dim 280 --decoder-layers '[(850, 6)] * 3 + [(850,1)] + [(850,5)] * 4 + [(850,1)] + [(850,4)] * 3 + [(1024,4)] + [(2048, 4)]' --clip-norm 0.1 --dropout 0.2 --weight-decay 5e-06 --criterion cross_entropy --max-tokens 1024 --max-target-positions 1024 --seed 1 --log-format json --log-interval 500

eval:

python eval_lm.py ~abaevski/data/wiki103 --path '/checkpoint02/abaevski/2018-04-27/lm_wiki.fp16.mxup300000.fconv.adam.lrs=reduce_lr_on_plateau.emb280.layers(850,6)*3+(850,1)+(850,5)*4+(850,1)+(850,4)*3+(1024,1)+(2048,4).lr0.0005.clp0.1.drp0.3.wd0.0.crt=cross_entropy.mxtk2048.smptk256.seed1.ngpu8/checkpoint_last.pt'
parent 4e1ec2d8
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import torch.nn.functional as F
from torch import nn


class AdaptiveSoftmax(nn.Module):
    """An implementation of the efficient softmax approximation for GPUs,
    described in the paper "Efficient softmax approximation for GPUs"
    (http://arxiv.org/abs/1609.04309)."""

    def __init__(self, vocab_size, input_dim, cutoff, dropout):
        super().__init__()

        # Make sure the last cutoff covers the full vocabulary.
        if vocab_size > cutoff[-1]:
            cutoff = cutoff + [vocab_size]

        # The head predicts the most frequent words plus one class per tail cluster.
        output_dim = cutoff[0] + len(cutoff) - 1

        self.vocab_size = vocab_size
        self.cutoff = cutoff
        self.dropout = dropout

        self.lsm = nn.LogSoftmax(dim=1)
        self.head = nn.Linear(input_dim, output_dim, bias=False)
        self.tail = nn.ModuleList()

        # Each tail cluster projects down to a progressively smaller hidden
        # size (input_dim / 4^i) before predicting the words in its range.
        for i in range(len(cutoff) - 1):
            self.tail.append(
                nn.Sequential(
                    nn.Linear(input_dim, input_dim // 4 ** i, bias=False),
                    nn.Dropout(dropout),
                    nn.Linear(input_dim // 4 ** i, cutoff[i + 1] - cutoff[i], bias=False),
                )
            )

        def init_weights(m):
            if hasattr(m, 'weight'):
                # xavier_uniform was the PyTorch 0.4-era name (xavier_uniform_ today).
                nn.init.xavier_uniform(m.weight)

        self.apply(init_weights)

    def adapt_target(self, target):
        """In order to be efficient, the AdaptiveSoftmax does not compute the
        scores for all the words of the vocabulary for all the examples. It is
        thus necessary to call adapt_target inside each forward pass to remap
        targets into head classes and per-cluster targets."""

        target = target.view(-1)
        new_target = [target.clone()]
        target_idxs = []

        for i in range(len(self.cutoff) - 1):
            mask = target.ge(self.cutoff[i]).mul(target.lt(self.cutoff[i + 1]))
            # Words that fall into cluster i are replaced by the cluster's
            # class index in the head.
            new_target[0][mask] = self.cutoff[0] + i - 1

            if mask.any():
                target_idxs.append(mask.nonzero().squeeze(1))
                new_target.append(target[mask].add(-self.cutoff[i]))
            else:
                target_idxs.append(None)
                new_target.append(None)

        return new_target, target_idxs

    def forward(self, input, target):
        """Accepts input (b x t x d) and target (b x t) and returns two lists:
        output logits for each cutoff section and new targets by cutoff."""

        input = input.contiguous().view(-1, input.size(-1))
        input = F.dropout(input, p=self.dropout, training=self.training)

        new_target, target_idxs = self.adapt_target(target)
        output = [self.head(input)]

        # Only run each tail on the positions whose target falls in its range.
        for i in range(len(target_idxs)):
            if target_idxs[i] is not None:
                output.append(self.tail[i](input.index_select(0, target_idxs[i])))
            else:
                output.append(None)

        return output, new_target

    def get_log_prob(self, input, target):
        """Computes the log probabilities for all the words of the vocabulary,
        given a 3D tensor of hidden vectors."""

        bsz, length, dim = input.size()
        input = input.contiguous().view(-1, dim)

        if target is not None:
            _, target_idxs = self.adapt_target(target)
        else:
            target_idxs = None

        head_y = self.head(input)
        log_probs = head_y.new_zeros(input.size(0), self.vocab_size)

        head_sz = self.cutoff[0] + len(self.tail)
        log_probs[:, :head_sz] = self.lsm(head_y)
        # Log probability of each tail cluster, taken from the head.
        tail_priors = log_probs[:, self.cutoff[0] - 1: head_sz - 1].clone()

        for i in range(len(self.tail)):
            start = self.cutoff[i]
            end = self.cutoff[i + 1]

            if target_idxs is None:
                tail_out = log_probs[:, start:end]
                tail_out.copy_(self.tail[i](input))
                log_probs[:, start:end] = self.lsm(tail_out).add_(tail_priors[:, i, None])
            elif target_idxs[i] is not None:
                idxs = target_idxs[i]
                tail_out = log_probs[idxs, start:end]
                tail_out.copy_(self.tail[i](input[idxs]))
                log_probs[idxs, start:end] = self.lsm(tail_out).add_(tail_priors[idxs, i, None])

        log_probs = log_probs.view(bsz, length, -1)
        return log_probs
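
For reference, a minimal usage sketch of the module above (the sizes and tensors are toy values invented for illustration, not part of this commit):

```python
import torch
import torch.nn.functional as F

vocab_size, dim = 10000, 64
# Head covers the 2000 most frequent words; two tail clusters cover the rest.
asm = AdaptiveSoftmax(vocab_size, dim, cutoff=[2000, 6000], dropout=0.1)

x = torch.randn(2, 8, dim)                     # b x t x d hidden states
target = torch.randint(0, vocab_size, (2, 8))  # b x t gold token indices

# Training: one set of logits/targets per section; sum their cross-entropies.
outputs, new_targets = asm(x, target)
loss = sum(
    F.cross_entropy(out, tgt, reduction='sum')
    for out, tgt in zip(outputs, new_targets)
    if out is not None
)

# Evaluation: full-vocabulary log-probabilities, shape b x t x vocab_size.
log_probs = asm.get_log_prob(x, target=None)
```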
@@ -43,6 +43,13 @@ def _eval_float_list(x):
     return [float(x)]


+def get_eval_lm_parser():
+    parser = get_parser('Evaluate Language Model')
+    add_dataset_args(parser, gen=True)
+    add_eval_lm_args(parser)
+    return parser
+
+
 def parse_args_and_arch(parser, input_args=None):
     # The parser doesn't know about model/criterion/optimizer-specific args, so
     # we parse twice. First we parse the model/criterion/optimizer, then we
@@ -102,7 +109,7 @@ def add_dataset_args(parser, train=False, gen=False):
                        help='target language')
     group.add_argument('--max-source-positions', default=1024, type=int, metavar='N',
                        help='max number of tokens in the source sequence')
-    group.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
+    group.add_argument('--max-target-positions', '--tokens-per-sample', default=1024, type=int, metavar='N',
                        help='max number of tokens in the target sequence')
     group.add_argument('--skip-invalid-size-inputs-valid-test', action='store_true',
                        help='Ignore too long or too short lines in valid and test set')
@@ -110,6 +117,12 @@ def add_dataset_args(parser, train=False, gen=False):
                        help='maximum number of tokens in a batch')
     group.add_argument('--max-sentences', '--batch-size', type=int, metavar='N',
                        help='maximum number of sentences in a batch')
+    group.add_argument('--sample-break-mode', metavar='VAL',
+                       choices=['none', 'complete', 'eos'],
+                       help='Used only for LM datasets. If omitted or "none", fills each sample with tokens-per-sample '
+                            'tokens. If set to "complete", splits samples only at the end of sentence, but may include '
+                            'multiple sentences per sample. If set to "eos", includes only one sentence per sample')
     if train:
         group.add_argument('--train-subset', default='train', metavar='SPLIT',
                            choices=['train', 'valid', 'test'],
@@ -216,10 +229,24 @@ def add_checkpoint_args(parser):
     return group


-def add_generation_args(parser):
-    group = parser.add_argument_group('Generation')
+def add_common_eval_args(group):
     group.add_argument('--path', metavar='FILE', action='append',
                        help='path(s) to model file(s)')
+    group.add_argument('--remove-bpe', nargs='?', const='@@ ', default=None,
+                       help='remove BPE tokens before scoring')
+    group.add_argument('--cpu', action='store_true', help='generate on CPU')
+    group.add_argument('--quiet', action='store_true',
+                       help='only print final scores')
+
+
+def add_eval_lm_args(parser):
+    group = parser.add_argument_group('LM Evaluation')
+    add_common_eval_args(group)
+
+
+def add_generation_args(parser):
+    group = parser.add_argument_group('Generation')
+    add_common_eval_args(group)
     group.add_argument('--beam', default=5, type=int, metavar='N',
                        help='beam size')
     group.add_argument('--nbest', default=1, type=int, metavar='N',
@@ -230,15 +257,12 @@ def add_generation_args(parser):
     group.add_argument('--max-len-b', default=200, type=int, metavar='N',
                        help=('generate sequences of maximum length ax + b, '
                              'where x is the source length'))
-    group.add_argument('--remove-bpe', nargs='?', const='@@ ', default=None,
-                       help='remove BPE tokens before scoring')
     group.add_argument('--no-early-stop', action='store_true',
                        help=('continue searching even after finalizing k=beam '
                              'hypotheses; this is more correct, but increases '
                              'generation time by 50%%'))
     group.add_argument('--unnormalized', action='store_true',
                        help='compare unnormalized hypothesis scores')
-    group.add_argument('--cpu', action='store_true', help='generate on CPU')
     group.add_argument('--no-beamable-mm', action='store_true',
                        help='don\'t use BeamableMM in attention layers')
     group.add_argument('--lenpen', default=1, type=float,
@@ -247,8 +271,6 @@ def add_generation_args(parser):
                        help='unknown word penalty: <0 produces more unks, >0 produces fewer')
     group.add_argument('--replace-unk', nargs='?', const=True, default=None,
                        help='perform unknown replacement (optionally with alignment dictionary)')
-    group.add_argument('--quiet', action='store_true',
-                       help='only print final scores')
     group.add_argument('--score-reference', action='store_true',
                        help='just score the reference translation')
     group.add_argument('--prefix-size', default=0, type=int, metavar='PS',
...
@@ -9,7 +9,6 @@ import math
 import torch

 from fairseq import utils
-from fairseq.data import LanguagePairDataset
 from fairseq.models import FairseqIncrementalDecoder
...
@@ -28,8 +28,6 @@ class SequenceScorer(object):
             if timer is not None:
                 timer.start()
             pos_scores, attn = self.score(s)
-            if timer is not None:
-                timer.stop(s['ntokens'])
             for i, id in enumerate(s['id'].data):
                 # remove padding from ref
                 src = utils.strip_pad(s['net_input']['src_tokens'].data[i, :], self.pad)
@@ -37,8 +35,11 @@ class SequenceScorer(object):
                 tgt_len = ref.numel()
                 pos_scores_i = pos_scores[i][:tgt_len]
                 score_i = pos_scores_i.sum() / tgt_len
-                attn_i = attn[i]
-                _, alignment = attn_i.max(dim=0)
+                if attn is not None:
+                    attn_i = attn[i]
+                    _, alignment = attn_i.max(dim=0)
+                else:
+                    attn_i = alignment = None
                 hypos = [{
                     'tokens': ref,
                     'score': score_i,
@@ -46,6 +47,8 @@ class SequenceScorer(object):
                     'alignment': alignment,
                     'positional_scores': pos_scores_i,
                 }]
+                if timer is not None:
+                    timer.stop(s['ntokens'])
                 # return results in the same format as SequenceGenerator
                 yield id, src, ref, hypos
@@ -59,16 +62,10 @@ class SequenceScorer(object):
         for model in self.models:
             with utils.maybe_no_grad():
                 model.eval()
-                encoder_out = model.encoder(
-                    net_input['src_tokens'],
-                    net_input['src_lengths'],
-                )
-                decoder_out = model.decoder(
-                    net_input['prev_output_tokens'],
-                    encoder_out,
-                )
+                decoder_out = model.forward(**net_input)
                 attn = decoder_out[1]
-                probs = model.get_normalized_probs(decoder_out, log_probs=False).data
+                probs = model.get_normalized_probs(decoder_out, log_probs=False, sample=sample).data
                 if avg_probs is None:
                     avg_probs = probs
                 else:
...
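The hunk above ends mid-branch; the remainder of this loop averages per-token probabilities across the ensemble. A minimal sketch of that pattern (variable names are illustrative, not the commit's exact code):

```python
# Ensemble scoring sketch: sum per-model probabilities, then average
# and take the log to get ensemble log-probabilities.
avg_probs = None
for model in models:  # `models` as in SequenceScorer.models
    decoder_out = model(**net_input)
    probs = model.get_normalized_probs(decoder_out, log_probs=False, sample=sample).data
    if avg_probs is None:
        avg_probs = probs
    else:
        avg_probs.add_(probs)
avg_probs.div_(len(models))
avg_probs.log_()
```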
@@ -68,7 +68,7 @@ def load_model_state(filename, model):
         return None, [], None
     state = torch.load(filename)
     state = _upgrade_state_dict(state)
-    state['model'] = model.upgrade_state_dict(state['model'])
+    model.upgrade_state_dict(state['model'])

     # load model parameters
     try:
@@ -134,7 +134,8 @@ def load_ensemble_for_inference(filenames, src_dict=None, dst_dict=None,
     {'arg_name': arg} -- to override model args that were used during model
     training
     """
-    from fairseq import data, models
+    from fairseq import models
+    from fairseq.data import data_utils

     # load model architectures and weights
     states = []
@@ -150,7 +151,7 @@ def load_ensemble_for_inference(filenames, src_dict=None, dst_dict=None,
     if src_dict is None or dst_dict is None:
         assert data_dir is not None
-        src_dict, dst_dict = data.load_dictionaries(data_dir, args.source_lang, args.target_lang)
+        src_dict, dst_dict = data_utils.load_dictionaries(data_dir, args.source_lang, args.target_lang)

     # build ensemble
     ensemble = []
...
@@ -8,7 +8,8 @@
 import torch

-from fairseq import bleu, data, options, progress_bar, tokenizer, utils
+from fairseq import bleu, options, progress_bar, tokenizer, utils
+from fairseq.data import data_utils, data_loaders
 from fairseq.meters import StopwatchMeter, TimeMeter
 from fairseq.sequence_generator import SequenceGenerator
 from fairseq.sequence_scorer import SequenceScorer
@@ -27,23 +28,7 @@ def main(args):
     use_cuda = torch.cuda.is_available() and not args.cpu

     # Load dataset
-    if args.replace_unk is None:
-        dataset = data.load_dataset(
-            args.data,
-            [args.gen_subset],
-            args.source_lang,
-            args.target_lang,
-        )
-    else:
-        dataset = data.load_raw_text_dataset(
-            args.data,
-            [args.gen_subset],
-            args.source_lang,
-            args.target_lang,
-        )
-    if args.source_lang is None or args.target_lang is None:
-        # record inferred languages in args
-        args.source_lang, args.target_lang = dataset.src, dataset.dst
+    dataset = data_loaders.load_dataset(args, [args.gen_subset], args.replace_unk is not None)

     # Load ensemble
     print('| loading model(s) from {}'.format(', '.join(args.path)))
@@ -75,7 +60,7 @@ def main(args):
     if args.num_shards > 1:
         if args.shard_id < 0 or args.shard_id >= args.num_shards:
             raise ValueError('--shard-id must be between 0 and num_shards')
-        itr = data.sharded_iterator(itr, args.num_shards, args.shard_id)
+        itr = data_utils.sharded_iterator(itr, args.num_shards, args.shard_id)

     # Initialize generator
     gen_timer = StopwatchMeter()
...
@@ -13,7 +13,8 @@ from collections import namedtuple
 from torch.autograd import Variable

 from fairseq import options, tokenizer, utils
-from fairseq.data import LanguagePairDataset
+from fairseq.data.data_utils import collate_tokens
+from fairseq.data.consts import LEFT_PAD_SOURCE
 from fairseq.sequence_generator import SequenceGenerator

 Batch = namedtuple('Batch', 'srcs tokens lengths')
@@ -41,9 +42,8 @@ def make_batches(lines, batch_size, src_dict):
     batches = np.array_split(indices, num_batches)
     for batch_idxs in batches:
         batch_toks = [tokens[i] for i in batch_idxs]
-        batch_toks = LanguagePairDataset.collate_tokens(batch_toks, src_dict.pad(), src_dict.eos(),
-                                                        LanguagePairDataset.LEFT_PAD_SOURCE,
-                                                        move_eos_to_beginning=False)
+        batch_toks = collate_tokens(batch_toks, src_dict.pad(), src_dict.eos(), LEFT_PAD_SOURCE,
+                                    move_eos_to_beginning=False)
         yield Batch(
             srcs=[lines[i] for i in batch_idxs],
             tokens=batch_toks,
...
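For context, collate_tokens pads a list of variable-length token vectors into one batch matrix. A self-contained sketch of its contract (behavior inferred from its use here; not copied from fairseq.data.data_utils):

```python
import torch

def collate_tokens_sketch(values, pad_idx, eos_idx, left_pad, move_eos_to_beginning=False):
    """Pad a list of 1D LongTensors into a (batch x max_len) matrix."""
    size = max(v.size(0) for v in values)
    res = values[0].new(len(values), size).fill_(pad_idx)
    for i, v in enumerate(values):
        # Left-pad (as with LEFT_PAD_SOURCE) or right-pad, as requested.
        dst = res[i][size - len(v):] if left_pad else res[i][:len(v)]
        if move_eos_to_beginning:
            # Shift right so the decoder input starts with EOS.
            assert v[-1] == eos_idx
            dst[0] = eos_idx
            dst[1:] = v[:-1]
        else:
            dst.copy_(v)
    return res
```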
@@ -12,7 +12,7 @@ from itertools import zip_longest
 import os
 import shutil

-from fairseq import dictionary, indexed_dataset
+from fairseq.data import indexed_dataset, dictionary
 from fairseq.tokenizer import Tokenizer, tokenize_line
@@ -54,33 +54,53 @@ def main(args):
         Tokenizer.add_file_to_dictionary(filename, d, tokenize_line)
         return d

+    def train_path(lang):
+        return '{}{}'.format(args.trainpref, ('.' + lang) if lang else '')
+
+    def file_name(prefix, lang):
+        fname = prefix
+        if lang is not None:
+            fname += f'.{lang}'
+        return fname
+
+    def dest_path(prefix, lang):
+        return os.path.join(args.destdir, file_name(prefix, lang))
+
+    def dict_path(lang):
+        return dest_path('dict', lang) + '.txt'
+
+    def dataset_dest_path(output_prefix, lang, extension):
+        base = f'{args.destdir}/{output_prefix}'
+        lang_part = f'.{args.source_lang}-{args.target_lang}.{lang}' if lang is not None else ''
+        return f'{base}{lang_part}.{extension}'
+
     if args.joined_dictionary:
         assert not args.srcdict, 'cannot combine --srcdict and --joined-dictionary'
         assert not args.tgtdict, 'cannot combine --tgtdict and --joined-dictionary'
-        src_dict = build_dictionary([
-            '{}.{}'.format(args.trainpref, lang)
-            for lang in [args.source_lang, args.target_lang]
-        ])
+        src_dict = build_dictionary(set([
+            train_path(lang)
+            for lang in [args.source_lang, args.target_lang]
+        ]))
         tgt_dict = src_dict
     else:
         if args.srcdict:
             src_dict = dictionary.Dictionary.load(args.srcdict)
         else:
             assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
-            src_dict = build_dictionary(['{}.{}'.format(args.trainpref, args.source_lang)])
+            src_dict = build_dictionary([train_path(args.source_lang)])
         if target:
             if args.tgtdict:
                 tgt_dict = dictionary.Dictionary.load(args.tgtdict)
             else:
                 assert args.trainpref, "--trainpref must be set if --tgtdict is not specified"
-                tgt_dict = build_dictionary(['{}.{}'.format(args.trainpref, args.target_lang)])
+                tgt_dict = build_dictionary([train_path(args.target_lang)])

     src_dict.finalize(
         threshold=args.thresholdsrc,
         nwords=args.nwordssrc,
         padding_factor=args.padding_factor,
     )
-    src_dict.save(os.path.join(args.destdir, 'dict.{}.txt'.format(args.source_lang)))
+    src_dict.save(dict_path(args.source_lang))
     if target:
         if not args.joined_dictionary:
             tgt_dict.finalize(
@@ -88,36 +108,31 @@ def main(args):
                 nwords=args.nwordstgt,
                 padding_factor=args.padding_factor,
             )
-        tgt_dict.save(os.path.join(args.destdir, 'dict.{}.txt'.format(args.target_lang)))
+        tgt_dict.save(dict_path(args.target_lang))

     def make_binary_dataset(input_prefix, output_prefix, lang):
-        dict = dictionary.Dictionary.load(os.path.join(args.destdir, 'dict.{}.txt'.format(lang)))
+        dict = dictionary.Dictionary.load(dict_path(lang))
         print('| [{}] Dictionary: {} types'.format(lang, len(dict) - 1))

-        ds = indexed_dataset.IndexedDatasetBuilder(
-            '{}/{}.{}-{}.{}.bin'.format(args.destdir, output_prefix, args.source_lang,
-                                        args.target_lang, lang)
-        )
+        ds = indexed_dataset.IndexedDatasetBuilder(dataset_dest_path(output_prefix, lang, 'bin'))

         def consumer(tensor):
             ds.add_item(tensor)

-        input_file = '{}.{}'.format(input_prefix, lang)
+        input_file = '{}{}'.format(input_prefix, ('.' + lang) if lang is not None else '')
         res = Tokenizer.binarize(input_file, dict, consumer)
         print('| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}'.format(
             lang, input_file, res['nseq'], res['ntok'],
             100 * res['nunk'] / res['ntok'], dict.unk_word))
-        ds.finalize('{}/{}.{}-{}.{}.idx'.format(
-            args.destdir, output_prefix,
-            args.source_lang, args.target_lang, lang))
+        ds.finalize(dataset_dest_path(output_prefix, lang, 'idx'))

     def make_dataset(input_prefix, output_prefix, lang, output_format='binary'):
         if output_format == 'binary':
             make_binary_dataset(input_prefix, output_prefix, lang)
         elif output_format == 'raw':
             # Copy original text file to destination folder
-            output_text_file = os.path.join(args.destdir, '{}.{}'.format(output_prefix, lang))
-            shutil.copyfile('{}.{}'.format(input_prefix, lang), output_text_file)
+            output_text_file = dest_path(output_prefix, lang)
+            shutil.copyfile(file_name(input_prefix, lang), output_text_file)

     def make_all(args, make_dataset, lang):
         if args.trainpref:
@@ -139,10 +154,10 @@ def main(args):
     if args.alignfile:
         assert args.trainpref, "--trainpref must be set if --alignfile is specified"
-        src_file_name = '{}.{}'.format(args.trainpref, args.source_lang)
-        tgt_file_name = '{}.{}'.format(args.trainpref, args.target_lang)
-        src_dict = dictionary.Dictionary.load(os.path.join(args.destdir, 'dict.{}.txt'.format(args.source_lang)))
-        tgt_dict = dictionary.Dictionary.load(os.path.join(args.destdir, 'dict.{}.txt'.format(args.target_lang)))
+        src_file_name = train_path(args.source_lang)
+        tgt_file_name = train_path(args.target_lang)
+        src_dict = dictionary.Dictionary.load(dict_path(args.source_lang))
+        tgt_dict = dictionary.Dictionary.load(dict_path(args.target_lang))
         freq_map = {}
         with open(args.alignfile, 'r') as align_file:
             with open(src_file_name, 'r') as src_file:
...
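To make the new path helpers concrete, here is how they resolve under hypothetical arguments (args.trainpref='corpus/train', destdir='data-bin', source_lang='en', target_lang='de'); these are worked examples, not part of the diff:

```python
train_path('en')                         # -> 'corpus/train.en'
train_path(None)                         # -> 'corpus/train' (monolingual LM case)
dict_path('en')                          # -> 'data-bin/dict.en.txt'
dataset_dest_path('train', 'en', 'bin')  # -> 'data-bin/train.en-de.en.bin'
dataset_dest_path('valid', None, 'idx')  # -> 'data-bin/valid.idx'
```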
@@ -11,7 +11,8 @@ import argparse
 import os
 import sys

-from fairseq import bleu, dictionary, tokenizer
+from fairseq import bleu, tokenizer
+from fairseq.data import dictionary


 def main():
...
@@ -84,6 +84,8 @@ class TestBinaries(unittest.TestCase):
                 '--max-epoch', '1',
                 '--no-progress-bar',
                 '--distributed-world-size', '1',
+                '--source-lang', 'in',
+                '--target-lang', 'out',
             ],
         )
         train.main(train_args)
...
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import unittest
from unittest.mock import MagicMock, patch

import train


def mock_trainer(epoch, num_updates):
    trainer = MagicMock()
    trainer.load_checkpoint.return_value = {'epoch': epoch}
    trainer.get_num_updates.return_value = num_updates
    return trainer


def mock_loader(length):
    ds = MagicMock()
    ds.__len__.return_value = length
    loader = MagicMock()
    loader.__next__.return_value = ds
    return loader


class TestLoadCheckpoint(unittest.TestCase):

    def setUp(self):
        self.patches = {
            'os.makedirs': MagicMock(),
            'os.path.join': MagicMock(),
            'os.path.isfile': MagicMock(return_value=True),
        }
        self.applied_patches = [patch(p, d) for p, d in self.patches.items()]
        [p.start() for p in self.applied_patches]

    def test_load_partial_checkpoint(self):
        # 200 updates done with 150 batches per epoch: the resumed epoch-2
        # dataset should be a (wrapped) OffsetDataset with 50 batches left,
        # not the raw MagicMock.
        trainer = mock_trainer(2, 200)
        loader = mock_loader(150)
        epoch, ds = train.load_checkpoint(MagicMock(), trainer, loader)
        self.assertEqual(epoch, 2)
        self.assertEqual(len(ds), 50)
        self.assertNotIsInstance(ds, MagicMock)

    def test_load_full_checkpoint(self):
        # Update count lines up exactly with epoch boundaries: no offset needed.
        trainer = mock_trainer(2, 150)
        loader = mock_loader(150)
        epoch, ds = train.load_checkpoint(MagicMock(), trainer, loader)
        self.assertEqual(epoch, 2)
        self.assertEqual(len(ds), 150)
        self.assertIsInstance(ds, MagicMock)

    def test_load_no_checkpoint(self):
        trainer = mock_trainer(0, 0)
        loader = mock_loader(150)
        self.patches['os.path.isfile'].return_value = False
        epoch, ds = train.load_checkpoint(MagicMock(), trainer, loader)
        self.assertEqual(epoch, 1)
        self.assertEqual(len(ds), 150)
        self.assertIsInstance(ds, MagicMock)

    def tearDown(self):
        patch.stopall()


if __name__ == '__main__':
    unittest.main()
@@ -8,7 +8,9 @@
 import torch
 from torch.autograd import Variable

-from fairseq import data, dictionary, utils
+from fairseq.data.language_pair_dataset import collate
+from fairseq import utils
+from fairseq.data import dictionary
 from fairseq.models import (
     FairseqEncoder,
     FairseqIncrementalDecoder,
@@ -45,10 +47,11 @@ def dummy_dataloader(
         dataset,
         batch_size=batch_size,
         collate_fn=(
-            lambda samples: data.LanguagePairDataset.collate(
+            lambda samples: collate(
                 samples,
                 padding_idx,
                 eos_idx,
+                has_target=True,
             )
         ),
     )
@@ -134,7 +137,7 @@ class TestIncrementalDecoder(FairseqIncrementalDecoder):
         return Variable(probs), Variable(attn)

-    def get_normalized_probs(self, net_output, log_probs):
+    def get_normalized_probs(self, net_output, log_probs, _):
         # the decoder returns probabilities directly
         probs = net_output[0]
         if log_probs:
...
@@ -11,7 +11,8 @@ import os
 import math
 import torch

-from fairseq import criterions, data, models, options, progress_bar
+from fairseq import criterions, models, options, progress_bar
+from fairseq.data import data_utils, data_loaders, OffsetDataset
 from fairseq.fp16_trainer import FP16Trainer
 from fairseq.trainer import Trainer
 from fairseq.meters import AverageMeter, StopwatchMeter
@@ -72,10 +73,10 @@ def main(args):
     )

     # Load the latest checkpoint if one is available
-    epoch = load_checkpoint(args, trainer, train_dataloader)
+    epoch, next_ds = load_checkpoint(args, trainer, train_dataloader)

     # Send a dummy batch to warm the caching allocator
-    dummy_batch = data.get_dummy_batch(args.max_tokens, dataset.src_dict, dataset.dst_dict)
+    dummy_batch = data_utils.get_dummy_batch(args.max_tokens, dataset.src_dict, dataset.dst_dict)
     trainer.dummy_train_step(dummy_batch)

     # Train until the learning rate gets too small
@@ -87,7 +88,7 @@ def main(args):
     train_meter.start()
     while lr > args.min_lr and epoch <= max_epoch and trainer.get_num_updates() < max_update:
         # train for one epoch
-        train(args, trainer, next(train_dataloader), epoch, dataset)
+        train(args, trainer, next_ds, epoch, dataset)

         if epoch % args.validate_interval == 0:
             first_val_loss = val_loss(args, trainer, dataset, epoch)
@@ -100,19 +101,14 @@ def main(args):
             save_checkpoint(trainer, args, epoch, end_of_epoch=True, val_loss=first_val_loss)

         epoch += 1
+        next_ds = next(train_dataloader)

     train_meter.stop()
     print('| done training in {:.1f} seconds'.format(train_meter.sum))


 def load_dataset(args, splits):
-    if data.has_binary_files(args.data, splits):
-        dataset = data.load_dataset(args.data, splits, args.source_lang, args.target_lang)
-    else:
-        dataset = data.load_raw_text_dataset(args.data, splits, args.source_lang, args.target_lang)
-    if args.source_lang is None or args.target_lang is None:
-        # record inferred languages in args, so that it's saved in checkpoints
-        args.source_lang, args.target_lang = dataset.src, dataset.dst
+    is_raw = not data_utils.has_binary_files(args.data, splits)
+    dataset = data_loaders.load_dataset(args, splits, is_raw)
     return dataset
@@ -311,17 +307,29 @@ def load_checkpoint(args, trainer, train_dataloader):
     os.makedirs(args.save_dir, exist_ok=True)
     checkpoint_path = os.path.join(args.save_dir, args.restore_file)
     epoch = 1
+    ds = None
     if os.path.isfile(checkpoint_path):
         extra_state = trainer.load_checkpoint(checkpoint_path)
         if extra_state is not None:
             epoch = extra_state['epoch']
-            print('| loaded checkpoint {} (epoch {})'.format(checkpoint_path, epoch))
+            trainer_updates = trainer.get_num_updates()
+            print('| loaded checkpoint {} (epoch {} @ {} updates)'.format(checkpoint_path, epoch, trainer_updates))
             trainer.lr_step(epoch)
+            updates = 0
             for i in range(epoch):
-                _ = next(train_dataloader)
-            epoch += 1
+                ds = next(train_dataloader)
+                updates += len(ds)
+            if ds is not None and updates > trainer_updates:
+                ds = OffsetDataset(ds, updates - trainer_updates)
+            else:
+                ds = next(train_dataloader)
+                epoch += 1
             trainer.get_meter('wall').reset(init=extra_state.get('wall_time', 0))
-    return epoch
+    return epoch, ds or next(train_dataloader)


 if __name__ == '__main__':
...
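train.py now imports OffsetDataset from fairseq.data to resume mid-epoch; its implementation is not shown on this page, but a minimal wrapper consistent with the usage above and with the new tests (its length shrinks by the offset) could look like this sketch:

```python
import torch.utils.data


class OffsetDataset(torch.utils.data.Dataset):
    """Expose only the last len(dataset) - offset items of `dataset`, so
    training can resume partway through an epoch. Sketch only; the real
    class lives in fairseq.data."""

    def __init__(self, dataset, offset):
        assert 0 <= offset <= len(dataset)
        self.dataset = dataset
        self.offset = offset

    def __len__(self):
        return len(self.dataset) - self.offset

    def __getitem__(self, index):
        return self.dataset[index + self.offset]
```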