Commit 4c2ef2de authored by alexeib, committed by Myle Ott

Conv LM implementation

This implements the convolutional language model from https://arxiv.org/pdf/1612.08083.pdf ("Language Modeling with Gated Convolutional Networks", Dauphin et al.).

There are three modes for constructing batches (a short sketch follows the list):

- token block: fill each sample with a specified number of tokens, without regard for sentence boundaries - this is what was used for training in the paper
- complete: fill each sample with up to a specified number of tokens, but only with complete sentences (i.e. if the next sentence would exceed the token-block limit, it is moved to the next sample) - this was used for evaluation in the paper
- eos: one sentence per sample (blank lines are skipped)
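
For illustration, here is a minimal sketch of how the three modes chop a toy token stream into samples, using the TokenBlockDataset added in this commit (the tensor and sentence sizes below are made up):

import torch
from fairseq.data import TokenBlockDataset

tokens = torch.arange(12)  # a toy stream of 3 sentences with 4, 3 and 5 tokens (eos included)
sizes = [4, 3, 5]
for mode in [None, 'complete', 'eos']:
    ds = TokenBlockDataset(tokens, block_size=8, sizes=sizes, break_mode=mode)
    print(mode, [len(ds[i]) for i in range(len(ds))])
# None ("token block"): fixed blocks of block_size tokens, ignoring sentence boundaries
# 'complete': blocks of up to block_size tokens made only of whole sentences
# 'eos': one sentence per block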

Some results (perplexity):

GCNN-13 - GBW - 37.46
GCNN-14B - GBW - 33.88
GCNN-8 - Wiki103 - 43.76
GCNN-14 - Wiki103 - 35.66

train:

python train.py /private/home/abaevski/data/wiki103 --save-dir /tmp --fp16 --max-epoch 35 --save-interval 1 --save-interval-updates 1000 --keep-interval-updates 25 --arch fconv_lm --optimizer nag --lr 1.0 --lr-scheduler reduce_lr_on_plateau --lr-shrink 0.5 --decoder-embed-dim 280 --decoder-layers '[(850, 6)] * 3 + [(850,1)] + [(850,5)] * 4 + [(850,1)] + [(850,4)] * 3 + [(1024,4)] + [(2048, 4)]' --clip-norm 0.1 --dropout 0.2 --weight-decay 5e-06 --criterion cross_entropy --max-tokens 1024 --max-target-positions 1024 --seed 1 --log-format json --log-interval 500

eval:

python eval_lm.py ~abaevski/data/wiki103 --path '/checkpoint02/abaevski/2018-04-27/lm_wiki.fp16.mxup300000.fconv.adam.lrs=reduce_lr_on_plateau.emb280.layers(850,6)*3+(850,1)+(850,5)*4+(850,1)+(850,4)*3+(1024,1)+(2048,4).lr0.0005.clp0.1.drp0.3.wd0.0.crt=cross_entropy.mxtk2048.smptk256.seed1.ngpu8/checkpoint_last.pt'
parent 4e1ec2d8
#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import numpy as np
import torch
from fairseq import options, utils, progress_bar
from fairseq.data import data_utils, data_loaders
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.sequence_scorer import SequenceScorer
def main(args):
assert args.path is not None, '--path required for evaluation!'
print(args)
if args.max_target_positions is None:
args.max_target_positions = 1024
use_cuda = torch.cuda.is_available() and not args.cpu
dataset = data_loaders.load_dataset(args, [args.gen_subset], False)
# Load ensemble
print('| loading model(s) from {}'.format(', '.join(args.path)))
models, _ = utils.load_ensemble_for_inference(args.path, dataset.src_dict, dataset.dst_dict)
print('| Dictionary: {} types'.format(len(dataset.src_dict)))
print('| {} {} {} examples'.format(args.data, args.gen_subset, len(dataset.splits[args.gen_subset])))
# Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
for model in models:
model.make_generation_fast_()
model.src_dict = dataset.src_dict
model.dst_dict = dataset.dst_dict
itr = dataset.eval_dataloader(
args.gen_subset,
max_sentences=args.max_sentences or 4,
max_positions=args.max_target_positions or 1024,
descending=True,
)
if args.num_shards > 1:
if args.shard_id < 0 or args.shard_id >= args.num_shards:
raise ValueError('--shard-id must be between 0 and num_shards')
itr = data_utils.sharded_iterator(itr, args.num_shards, args.shard_id)
gen_timer = StopwatchMeter()
scorer = SequenceScorer(models)
if use_cuda:
scorer.cuda()
score_sum = 0.
count = 0
with progress_bar.build_progress_bar(args, itr) as t:
results = scorer.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
wps_meter = TimeMeter()
for _, src_tokens, __, hypos in results:
for hypo in hypos:
pos_scores = hypo['positional_scores']
inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
if inf_scores.any():
print('| Skipping tokens with inf scores:',
dataset.src_dict.string(hypo['tokens'][inf_scores.nonzero()]))
pos_scores = pos_scores[(~inf_scores).nonzero()]
score_sum += pos_scores.sum()
count += pos_scores.numel()
wps_meter.update(src_tokens.size(0))
t.log({'wps': round(wps_meter.avg)})
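# positional scores are log probabilities in nats; the negated average is the NLL loss,
# and the perplexity printed below is simply its exponential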
avg_nll_loss = -score_sum / count
print('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
print('| Loss: {:.4f}, Perplexity: {:.2f}'.format(avg_nll_loss, np.exp(avg_nll_loss)))
if __name__ == '__main__':
parser = options.get_eval_lm_parser()
args = parser.parse_args()
main(args)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import math
import torch.nn.functional as F
from fairseq import utils
from . import FairseqCriterion, register_criterion
@register_criterion('adaptive_loss')
class AdaptiveLoss(FairseqCriterion):
"""This is an implementation of the loss function accompanying the adaptive softmax approximation for
graphical processing units (GPU), described in the paper "Efficient softmax approximation for GPUs"
(http://arxiv.org/abs/1609.04309)."""
def __init__(self, args, src_dict, dst_dict):
super().__init__(args, src_dict, dst_dict)
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss, as a Variable
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
assert hasattr(model.decoder, 'adaptive_softmax') and model.decoder.adaptive_softmax is not None
adaptive_softmax = model.decoder.adaptive_softmax
net_output = model(**sample['net_input'])
target = model.get_targets(sample, net_output).view(-1)
bsz = target.size(0)
logits, target = adaptive_softmax(net_output[0], target)
assert len(target) == len(logits)
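# the adaptive softmax splits the vocabulary into a head and several tail clusters; logits and
# target hold one entry per cluster, and the total loss sums the cross-entropy of each cluster
# over its slice of the targets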
loss = net_output[0].new(1 if reduce else bsz).zero_()
for i in range(len(target)):
if target[i] is not None:
assert (target[i].min() >= 0 and target[i].max() <= logits[i].size(1))
loss += F.cross_entropy(logits[i], target[i], size_average=False, ignore_index=self.padding_idx,
reduce=reduce)
sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
logging_output = {
'loss': utils.item(loss.data) if reduce else loss.data,
'ntokens': sample['ntokens'],
'sample_size': sample_size,
}
return loss, sample_size, logging_output
@staticmethod
def aggregate_logging_outputs(logging_outputs):
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
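# the summed loss is in nats; dividing by log(2) converts it to a base-2 (bits per token) value for logging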
agg_output = {
'loss': loss_sum / sample_size / math.log(2),
'sample_size': sample_size,
}
if sample_size != ntokens:
agg_output['nll_loss'] = loss_sum / ntokens / math.log(2)
return agg_output
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from .dictionary import Dictionary
from .token_block_dataset import TokenBlockDataset
from .language_dataset import LanguageDatasets
from .language_pair_dataset import LanguagePairDataset
from .monolingual_dataset import MonolingualDataset
from .offset_dataset import OffsetDataset
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
# padding constants
LEFT_PAD_SOURCE = True
LEFT_PAD_TARGET = False
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import os
from fairseq.data import LanguagePairDataset, MonolingualDataset
from fairseq.data.data_utils import infer_language_pair
def load_dataset(args, splits, is_raw):
""" Detect if we have a multi language dataset, or a single language dataset """
if args.source_lang is None and args.target_lang is None:
# find language pair automatically
args.source_lang, args.target_lang = infer_language_pair(args.data, splits)
if args.source_lang is None and args.target_lang is None and all(
os.path.exists(os.path.join(args.data, '{}.bin'.format(split))) for split in splits):
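# monolingual (language model) data is stored as '<split>.bin' without a language-pair suffix,
# so if every requested split exists in that form, treat this as a single-language dataset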
cls = MonolingualDataset
else:
cls = LanguagePairDataset
return cls.create_dataset(args, splits, is_raw)
@@ -6,7 +6,6 @@
 # can be found in the PATENTS file in the same directory.
 import contextlib
-import itertools
 import glob
 import math
 import numbers
@@ -17,13 +16,13 @@ import torch
 from torch.autograd import Variable
 import torch.utils.data
-from fairseq.dictionary import Dictionary
-from fairseq.indexed_dataset import IndexedDataset, IndexedInMemoryDataset, IndexedRawTextDataset
+from fairseq.data.dictionary import Dictionary
+from fairseq.data.indexed_dataset import SizedDataset
 def has_binary_files(data_dir, splits):
     for split in splits:
-        if len(glob.glob(os.path.join(data_dir, '{}.*-*.*.bin'.format(split)))) < 2:
+        if len(glob.glob(os.path.join(data_dir, '{}*.bin'.format(split)))) == 0:
             return False
     return True
@@ -34,7 +33,7 @@ def infer_language_pair(path, splits):
     for filename in os.listdir(path):
         parts = filename.split('.')
         for split in splits:
-            if parts[0] == split and parts[-1] == 'idx':
+            if len(parts) >= 3 and parts[0] == split and parts[-1] == 'idx':
                 src, dst = parts[1].split('-')
                 break
     return src, dst
@@ -47,148 +46,8 @@ def load_dictionaries(path, src_lang, dst_lang):
     return src_dict, dst_dict
-def load_dataset(path, load_splits, src=None, dst=None):
-    """Loads specified data splits (e.g., test, train or valid) from the
-    specified folder and check that files exist."""
-    if src is None and dst is None:
-        # find language pair automatically
-        src, dst = infer_language_pair(path, load_splits)
-    assert src is not None and dst is not None, 'Source and target languages should be provided'
-    src_dict, dst_dict = load_dictionaries(path, src, dst)
-    dataset = LanguageDatasets(src, dst, src_dict, dst_dict)
-    # Load dataset from binary files
-    def all_splits_exist(src, dst, lang):
-        for split in load_splits:
-            filename = '{0}.{1}-{2}.{3}.idx'.format(split, src, dst, lang)
-            if not os.path.exists(os.path.join(path, filename)):
-                return False
-        return True
-    # infer langcode
-    if all_splits_exist(src, dst, src):
-        langcode = '{}-{}'.format(src, dst)
-    elif all_splits_exist(dst, src, src):
-        langcode = '{}-{}'.format(dst, src)
-    else:
-        raise Exception('Dataset cannot be loaded from path: ' + path)
-    def fmt_path(fmt, *args):
-        return os.path.join(path, fmt.format(*args))
-    for split in load_splits:
-        for k in itertools.count():
-            prefix = "{}{}".format(split, k if k > 0 else '')
-            src_path = fmt_path('{}.{}.{}', prefix, langcode, src)
-            dst_path = fmt_path('{}.{}.{}', prefix, langcode, dst)
-            if not IndexedInMemoryDataset.exists(src_path):
-                break
-            target_dataset = None
-            if IndexedInMemoryDataset.exists(dst_path):
-                target_dataset = IndexedInMemoryDataset(dst_path)
-            dataset.splits[prefix] = LanguagePairDataset(
-                IndexedInMemoryDataset(src_path),
-                target_dataset,
-                pad_idx=dataset.src_dict.pad(),
-                eos_idx=dataset.src_dict.eos(),
-            )
-    return dataset
-def load_raw_text_dataset(path, load_splits, src=None, dst=None):
-    """Loads specified data splits (e.g., test, train or valid) from raw text
-    files in the specified folder."""
-    if src is None and dst is None:
-        # find language pair automatically
-        src, dst = infer_language_pair(path, load_splits)
-    assert src is not None and dst is not None, 'Source and target languages should be provided'
-    src_dict, dst_dict = load_dictionaries(path, src, dst)
-    dataset = LanguageDatasets(src, dst, src_dict, dst_dict)
-    # Load dataset from raw text files
-    for split in load_splits:
-        src_path = os.path.join(path, '{}.{}'.format(split, src))
-        dst_path = os.path.join(path, '{}.{}'.format(split, dst))
-        dataset.splits[split] = LanguagePairDataset(
-            IndexedRawTextDataset(src_path, src_dict),
-            IndexedRawTextDataset(dst_path, dst_dict),
-            pad_idx=dataset.src_dict.pad(),
-            eos_idx=dataset.src_dict.eos(),
-        )
-    return dataset
-class LanguageDatasets(object):
-    def __init__(self, src, dst, src_dict, dst_dict):
-        self.src = src
-        self.dst = dst
-        self.src_dict = src_dict
-        self.dst_dict = dst_dict
-        self.splits = {}
-        assert self.src_dict.pad() == self.dst_dict.pad()
-        assert self.src_dict.eos() == self.dst_dict.eos()
-        assert self.src_dict.unk() == self.dst_dict.unk()
-    def train_dataloader_generator(
-            self, split, max_tokens=None, max_sentences=None,
-            max_positions=(1024, 1024), seed=None, sample_without_replacement=0,
-            shard_id=0, num_shards=1
-    ):
-        dataset = self.splits[split]
-        with numpy_seed(seed):
-            batches = uneven_batches_by_size(
-                dataset.src, dataset.dst, max_tokens=max_tokens,
-                max_sentences=max_sentences, max_positions=max_positions,
-                # FP16: during training keep the batch size a multiple of 8
-                required_batch_size_multiple=8,
-            )
-            frozen_batches = tuple(batches)  # freeze
-        def dataloader(b):
-            b = mask_batches(b, shard_id=shard_id, num_shards=num_shards)  # shard dataset
-            return torch.utils.data.DataLoader(dataset, collate_fn=dataset.collater, batch_sampler=b)
-        for epoch in itertools.count(1):
-            # set seed based on the seed and epoch number so that we get
-            # reproducible results when resuming from checkpoints
-            with numpy_seed(seed + epoch):
-                batches = list(frozen_batches)  # copy
-                np.random.shuffle(batches)
-            if sample_without_replacement > 0:
-                # emit sub-epoch dataloaders
-                while len(batches) >= sample_without_replacement:
-                    sampled_batches = batches[:sample_without_replacement]
-                    remaining_batches = batches[sample_without_replacement:]
-                    yield dataloader(sampled_batches)
-                    batches = remaining_batches
-                if len(batches) > 0:
-                    yield dataloader(batches)
-            else:
-                # emit full dataloader
-                yield dataloader(batches)
-    def eval_dataloader(self, split, num_workers=0, max_tokens=None,
-                        max_sentences=None, max_positions=(1024, 1024),
-                        skip_invalid_size_inputs_valid_test=False,
-                        descending=False, shard_id=0, num_shards=1):
-        dataset = self.splits[split]
-        batch_sampler = batches_by_size(
-            dataset.src, dataset.dst, max_tokens, max_sentences,
-            max_positions=max_positions,
-            ignore_invalid_inputs=skip_invalid_size_inputs_valid_test,
-            descending=descending,
-            allow_different_src_lens=True)
-        batch_sampler = mask_batches(batch_sampler, shard_id=shard_id, num_shards=num_shards)
-        return torch.utils.data.DataLoader(
-            dataset, num_workers=num_workers, collate_fn=dataset.collater,
-            batch_sampler=batch_sampler)
+def fmt_path(path, fmt, *args):
+    return os.path.join(path, fmt.format(*args))
 class sharded_iterator(object):
@@ -208,99 +67,25 @@ class sharded_iterator(object):
                 yield v
-class LanguagePairDataset(torch.utils.data.Dataset):
-    # padding constants
-    LEFT_PAD_SOURCE = True
-    LEFT_PAD_TARGET = False
-    def __init__(self, src, dst, pad_idx, eos_idx):
-        self.src = src
-        self.dst = dst
-        self.pad_idx = pad_idx
-        self.eos_idx = eos_idx
-    def __getitem__(self, i):
-        # subtract 1 for 0-based indexing
-        source = self.src[i].long() - 1
-        res = {'id': i, 'source': source}
-        if self.dst:
-            res['target'] = self.dst[i].long() - 1
-        return res
-    def __len__(self):
-        return len(self.src)
-    def collater(self, samples):
-        return LanguagePairDataset.collate(samples, self.pad_idx, self.eos_idx, self.dst is not None)
-    @staticmethod
-    def collate(samples, pad_idx, eos_idx, has_target=True):
-        if len(samples) == 0:
-            return {}
-        def merge(key, left_pad, move_eos_to_beginning=False):
-            return LanguagePairDataset.collate_tokens(
-                [s[key] for s in samples],
-                pad_idx, eos_idx, left_pad, move_eos_to_beginning,
-            )
-        id = torch.LongTensor([s['id'] for s in samples])
-        src_tokens = merge('source', left_pad=LanguagePairDataset.LEFT_PAD_SOURCE)
-        # sort by descending source length
-        src_lengths = torch.LongTensor([s['source'].numel() for s in samples])
-        src_lengths, sort_order = src_lengths.sort(descending=True)
-        id = id.index_select(0, sort_order)
-        src_tokens = src_tokens.index_select(0, sort_order)
-        prev_output_tokens = None
-        target = None
-        ntokens = None
-        if has_target:
-            target = merge('target', left_pad=LanguagePairDataset.LEFT_PAD_TARGET)
-            # we create a shifted version of targets for feeding the
-            # previous output token(s) into the next decoder step
-            prev_output_tokens = merge(
-                'target',
-                left_pad=LanguagePairDataset.LEFT_PAD_TARGET,
-                move_eos_to_beginning=True,
-            )
-            prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
-            target = target.index_select(0, sort_order)
-            ntokens = sum(len(s['target']) for s in samples)
-        return {
-            'id': id,
-            'ntokens': ntokens,
-            'net_input': {
-                'src_tokens': src_tokens,
-                'src_lengths': src_lengths,
-                'prev_output_tokens': prev_output_tokens,
-            },
-            'target': target,
-        }
-    @staticmethod
-    def collate_tokens(values, pad_idx, eos_idx, left_pad, move_eos_to_beginning=False):
-        size = max(v.size(0) for v in values)
-        res = values[0].new(len(values), size).fill_(pad_idx)
-        def copy_tensor(src, dst):
-            assert dst.numel() == src.numel()
-            if move_eos_to_beginning:
-                assert src[-1] == eos_idx
-                dst[0] = eos_idx
-                dst[1:] = src[:-1]
-            else:
-                dst.copy_(src)
-        for i, v in enumerate(values):
-            if left_pad:
-                copy_tensor(v, res[i][size-len(v):])
-            else:
-                copy_tensor(v, res[i][:len(v)])
-        return res
+def collate_tokens(values, pad_idx, eos_idx, left_pad, move_eos_to_beginning=False):
+    size = max(v.size(0) for v in values)
+    res = values[0].new(len(values), size).fill_(pad_idx)
+    def copy_tensor(src, dst):
+        assert dst.numel() == src.numel()
+        if move_eos_to_beginning:
+            assert src[-1] == eos_idx
+            dst[0] = eos_idx
+            dst[1:] = src[:-1]
+        else:
+            dst.copy_(src)
+    for i, v in enumerate(values):
+        if left_pad:
+            copy_tensor(v, res[i][size - len(v):])
+        else:
+            copy_tensor(v, res[i][:len(v)])
    return res
 def _valid_size(src_size, dst_size, max_positions):
@@ -344,9 +129,9 @@ def _make_batches(src, dst, indices, max_tokens, max_sentences, max_positions,
                 ignored.append(idx)
                 continue
             raise Exception((
                 "Sample #{} has size (src={}, dst={}) but max size is {}."
                 " Skip this example with --skip-invalid-size-inputs-valid-test"
             ).format(idx, src_size, dst_size, max_positions))
         sample_lens.append(max(src_size, dst_size))
         sample_len = max(sample_len, sample_lens[-1])
@@ -373,7 +158,7 @@ def batches_by_size(src, dst, max_tokens=None, max_sentences=None,
                     descending=False, required_batch_size_multiple=1, allow_different_src_lens=False):
     """Returns batches of indices sorted by size. Sequences with different
     source lengths are not allowed in the same batch."""
-    assert isinstance(src, IndexedDataset) and (dst is None or isinstance(dst, IndexedDataset))
+    assert isinstance(src, SizedDataset) and (dst is None or isinstance(dst, SizedDataset))
     if max_tokens is None:
         max_tokens = float('Inf')
     if max_sentences is None:
@@ -393,7 +178,7 @@ def uneven_batches_by_size(src, dst, max_tokens=None, max_sentences=None,
                            required_batch_size_multiple=1):
     """Returns batches of indices bucketed by size. Batches may contain
     sequences of different lengths."""
-    assert isinstance(src, IndexedDataset) and isinstance(dst, IndexedDataset)
+    assert isinstance(src, SizedDataset) and isinstance(dst, SizedDataset)
     if max_tokens is None:
         max_tokens = float('Inf')
     if max_sentences is None:
...
@@ -9,6 +9,7 @@ import numpy as np
 import os
 import struct
 import torch
+import torch.utils.data
 from fairseq.tokenizer import Tokenizer
@@ -48,10 +49,20 @@ def data_file_path(prefix_path):
     return prefix_path + '.bin'
-class IndexedDataset(object):
+class SizedDataset(torch.utils.data.Dataset):
+    def __init__(self):
+        self._sizes = None
+    @property
+    def sizes(self):
+        return self._sizes
+class IndexedDataset(SizedDataset):
     """Loader for TorchNet IndexedDataset"""
     def __init__(self, path):
+        super().__init__()
         with open(index_file_path(path), 'rb') as f:
             magic = f.read(8)
             assert magic == b'TNTIDX\x00\x00'
@@ -62,7 +73,7 @@ class IndexedDataset(object):
             self.size, self.s = struct.unpack('<QQ', f.read(16))
             self.dim_offsets = read_longs(f, self.size + 1)
             self.data_offsets = read_longs(f, self.size + 1)
-            self.sizes = read_longs(f, self.s)
+            self._sizes = read_longs(f, self.s)
         self.read_data(path)
     def read_data(self, path):
@@ -121,7 +132,7 @@ class IndexedRawTextDataset(IndexedDataset):
     def __init__(self, path, dictionary, append_eos=True, reverse_order=False):
         self.tokens_list = []
         self.lines = []
-        self.sizes = []
+        self._sizes = []
         self.append_eos = append_eos
         self.reverse_order = reverse_order
         self.read_data(path, dictionary)
@@ -136,8 +147,8 @@ class IndexedRawTextDataset(IndexedDataset):
                     append_eos=self.append_eos, reverse_order=self.reverse_order,
                 ) + 1  # +1 for Lua compatibility
                 self.tokens_list.append(tokens)
-                self.sizes.append(len(tokens))
-        self.sizes = np.array(self.sizes)
+                self._sizes.append(len(tokens))
+        self._sizes = np.array(self._sizes)
     def __getitem__(self, i):
         self.check_index(i)
@@ -155,10 +166,9 @@ class IndexedDatasetBuilder(object):
     element_sizes = {
         np.uint8: 1,
         np.int8: 1,
         np.int16: 2,
         np.int32: 4,
         np.int64: 8,
...
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import itertools
import numpy as np
import torch
from fairseq.data.data_utils import numpy_seed, uneven_batches_by_size, mask_batches, batches_by_size
class LanguageDatasets(object):
def __init__(self, src, dst, src_dict, dst_dict):
self.src = src
self.dst = dst
self.src_dict = src_dict
self.dst_dict = dst_dict
self.splits = {}
assert self.src_dict.pad() == self.dst_dict.pad()
assert self.src_dict.eos() == self.dst_dict.eos()
assert self.src_dict.unk() == self.dst_dict.unk()
def train_dataloader_generator(
self, split, max_tokens=None, max_sentences=None,
max_positions=(1024, 1024), seed=None, sample_without_replacement=0,
shard_id=0, num_shards=1
):
dataset = self.splits[split]
with numpy_seed(seed):
batches = uneven_batches_by_size(
dataset.src, dataset.dst, max_tokens=max_tokens,
max_sentences=max_sentences, max_positions=max_positions,
# FP16: during training keep the batch size a multiple of 8
required_batch_size_multiple=8,
)
frozen_batches = tuple(batches) # freeze
def dataloader(b):
b = mask_batches(b, shard_id=shard_id, num_shards=num_shards) # shard dataset
return torch.utils.data.DataLoader(dataset, collate_fn=dataset.collater, batch_sampler=b)
for epoch in itertools.count(1):
# set seed based on the seed and epoch number so that we get
# reproducible results when resuming from checkpoints
with numpy_seed(seed + epoch):
batches = list(frozen_batches) # copy
np.random.shuffle(batches)
if sample_without_replacement > 0:
# emit sub-epoch dataloaders
while len(batches) >= sample_without_replacement:
sampled_batches = batches[:sample_without_replacement]
remaining_batches = batches[sample_without_replacement:]
yield dataloader(sampled_batches)
batches = remaining_batches
if len(batches) > 0:
yield dataloader(batches)
else:
# emit full dataloader
yield dataloader(batches)
def eval_dataloader(self, split, num_workers=0, max_tokens=None,
max_sentences=None, max_positions=(1024, 1024),
skip_invalid_size_inputs_valid_test=False,
descending=False, shard_id=0, num_shards=1):
dataset = self.splits[split]
batch_sampler = batches_by_size(
dataset.src, dataset.dst, max_tokens, max_sentences,
max_positions=max_positions,
ignore_invalid_inputs=skip_invalid_size_inputs_valid_test,
descending=descending,
allow_different_src_lens=True)
batch_sampler = mask_batches(batch_sampler, shard_id=shard_id, num_shards=num_shards)
return torch.utils.data.DataLoader(
dataset, num_workers=num_workers, collate_fn=dataset.collater,
batch_sampler=batch_sampler)
\ No newline at end of file
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import itertools
import os
import torch
import torch.utils
from fairseq.data import LanguageDatasets
from fairseq.data.consts import LEFT_PAD_TARGET, LEFT_PAD_SOURCE
from fairseq.data.data_utils import fmt_path, load_dictionaries, collate_tokens
from fairseq.data.indexed_dataset import IndexedInMemoryDataset, IndexedRawTextDataset
def collate(samples, pad_idx, eos_idx, has_target):
if len(samples) == 0:
return {}
def merge(key, left_pad, move_eos_to_beginning=False):
return collate_tokens(
[s[key] for s in samples],
pad_idx, eos_idx, left_pad, move_eos_to_beginning,
)
id = torch.LongTensor([s['id'] for s in samples])
src_tokens = merge('source', left_pad=LEFT_PAD_SOURCE)
# sort by descending source length
src_lengths = torch.LongTensor([s['source'].numel() for s in samples])
src_lengths, sort_order = src_lengths.sort(descending=True)
id = id.index_select(0, sort_order)
src_tokens = src_tokens.index_select(0, sort_order)
prev_output_tokens = None
target = None
ntokens = None
if has_target:
target = merge('target', left_pad=LEFT_PAD_TARGET)
# we create a shifted version of targets for feeding the
# previous output token(s) into the next decoder step
prev_output_tokens = merge(
'target',
left_pad=LEFT_PAD_TARGET,
move_eos_to_beginning=True,
)
prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
target = target.index_select(0, sort_order)
ntokens = sum(len(s['target']) for s in samples)
return {
'id': id,
'ntokens': ntokens,
'net_input': {
'src_tokens': src_tokens,
'src_lengths': src_lengths,
'prev_output_tokens': prev_output_tokens,
},
'target': target,
}
class LanguagePairDataset(torch.utils.data.Dataset):
def __init__(self, src, dst, pad_idx, eos_idx):
self.src = src
self.dst = dst
self.pad_idx = pad_idx
self.eos_idx = eos_idx
def __getitem__(self, i):
# subtract 1 for 0-based indexing
source = self.src[i].long() - 1
res = {'id': i, 'source': source}
if self.dst:
res['target'] = self.dst[i].long() - 1
return res
def __len__(self):
return len(self.src)
def collater(self, samples):
return collate(samples, self.pad_idx, self.eos_idx, self.dst is not None)
@staticmethod
def create_dataset(args, splits, is_raw):
src, dst = args.source_lang, args.target_lang
assert src is not None and dst is not None, 'Source and target languages should be provided'
src_dict, dst_dict = load_dictionaries(args.data, src, dst)
dataset = LanguageDatasets(src, dst, src_dict, dst_dict)
def create_raw_dataset():
"""Loads specified data splits (e.g., test, train or valid) from raw text
files in the specified folder."""
# Load dataset from raw text files
for split in splits:
src_path = os.path.join(args.data, '{}.{}'.format(split, src))
dst_path = os.path.join(args.data, '{}.{}'.format(split, dst))
dataset.splits[split] = LanguagePairDataset(
IndexedRawTextDataset(src_path, src_dict),
IndexedRawTextDataset(dst_path, dst_dict),
pad_idx=dataset.src_dict.pad(),
eos_idx=dataset.src_dict.eos(),
)
return dataset
def create_binary_dataset():
"""Loads specified data splits (e.g., test, train or valid) from the
specified folder and check that files exist."""
# Load dataset from binary files
def all_splits_exist(src, dst, lang):
for split in splits:
filename = '{0}.{1}-{2}.{3}.idx'.format(split, src, dst, lang)
if not os.path.exists(os.path.join(args.data, filename)):
return False
return True
# infer langcode
if all_splits_exist(src, dst, src):
langcode = '{}-{}'.format(src, dst)
elif all_splits_exist(dst, src, src):
langcode = '{}-{}'.format(dst, src)
else:
raise Exception('Dataset cannot be loaded from path: ' + args.data)
for split in splits:
for k in itertools.count():
prefix = "{}{}".format(split, k if k > 0 else '')
src_path = fmt_path(args.data, '{}.{}.{}', prefix, langcode, src)
dst_path = fmt_path(args.data, '{}.{}.{}', prefix, langcode, dst)
if not IndexedInMemoryDataset.exists(src_path):
break
target_dataset = None
if IndexedInMemoryDataset.exists(dst_path):
target_dataset = IndexedInMemoryDataset(dst_path)
dataset.splits[prefix] = LanguagePairDataset(
IndexedInMemoryDataset(src_path),
target_dataset,
pad_idx=dataset.src_dict.pad(),
eos_idx=dataset.src_dict.eos(),
)
return dataset
return create_raw_dataset() if is_raw else create_binary_dataset()
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import itertools
import os
import torch
from torch.utils.data import Dataset
from fairseq.data import TokenBlockDataset, Dictionary, LanguageDatasets
from fairseq.data.indexed_dataset import IndexedInMemoryDataset
from fairseq.data.data_utils import fmt_path, collate_tokens
def collate(samples, pad_idx, eos_idx, has_target):
if len(samples) == 0:
return {}
def merge(key):
return collate_tokens(
[s[key] for s in samples],
pad_idx, eos_idx, left_pad=False, move_eos_to_beginning=False,
)
id = torch.LongTensor([s['id'] for s in samples])
# language models only have a decoder, which is not padding-aware, so do not left-pad for them
src_tokens = merge('source')
# sort by descending source length
src_lengths = torch.LongTensor([s['source'].numel() for s in samples])
_, sort_order = src_lengths.sort(descending=True)
id = id.index_select(0, sort_order)
src_tokens = src_tokens.index_select(0, sort_order)
target = None
ntokens = None
if has_target:
target = merge('target')
target = target.index_select(0, sort_order)
ntokens = sum(len(s['target']) for s in samples)
return {
'id': id,
'ntokens': ntokens,
'net_input': {
'src_tokens': src_tokens,
},
'target': target,
}
class MonolingualDataset(Dataset):
def __init__(self, tokens, sizes, token_block_size, break_mode, pad_idx, eos_idx, next_token_is_target):
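# when next_token_is_target is set, the source blocks are rotated back by one token relative to the
# target blocks (offset=1 vs offset=0), so position t of the source is used to predict position t of the target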
if next_token_is_target:
self.src = TokenBlockDataset(tokens, token_block_size, sizes, offset=1, break_mode=break_mode)
self.dst = TokenBlockDataset(tokens, token_block_size, sizes, offset=0, break_mode=break_mode)
else:
self.src = TokenBlockDataset(tokens, token_block_size, sizes, offset=0, break_mode=break_mode)
self.dst = None
self.pad_idx = pad_idx
self.eos_idx = eos_idx
def __getitem__(self, i):
# subtract 1 for 0-based indexing
source = self.src[i].long() - 1
res = {'id': i, 'source': source}
if self.dst:
res['target'] = self.dst[i].long() - 1
return res
def __len__(self):
return len(self.src)
def collater(self, samples):
return collate(samples, self.pad_idx, self.eos_idx, self.dst is not None)
@staticmethod
def create_dataset(args, splits, is_raw):
"""Loads specified data splits (e.g., test, train or valid) from the
specified folder and check that files exist."""
if is_raw:
raise Exception('raw text single language data sets are currently not supported')
assert args.sample_break_mode == 'eos' or args.max_target_positions is not None
path = args.data
dict = Dictionary.load(os.path.join(path, 'dict.txt'))
dataset = LanguageDatasets(None, None, dict, dict)
assert all(os.path.exists(os.path.join(path, '{}.bin'.format(split))) for split in splits)
for split in splits:
for k in itertools.count():
prefix = "{}{}".format(split, k if k > 0 else '')
split_path = fmt_path(path, '{}', prefix)
if not IndexedInMemoryDataset.exists(split_path):
break
ds = IndexedInMemoryDataset(split_path)
tokens = torch.from_numpy(ds.buffer)
dataset.splits[prefix] = MonolingualDataset(
tokens,
ds.sizes,
args.max_target_positions,
args.sample_break_mode,
pad_idx=dataset.src_dict.pad(),
eos_idx=dataset.src_dict.eos(),
next_token_is_target=True,
)
return dataset
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from torch.utils.data import Dataset
class OffsetDataset(Dataset):
""" Wraps an existing dataset, but starts iterating from a particular offset """
def __init__(self, dataset, offset):
"""
Args:
dataset: Dataset to wrap
offset: An integer. offset from which to start iterating
"""
super().__init__()
assert len(dataset) >= offset
self.dataset = dataset
self.offset = offset
def __getitem__(self, i):
return self.dataset[i + self.offset]
def __len__(self):
return len(self.dataset) - self.offset
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import math
import numpy as np
import torch
from fairseq.data.indexed_dataset import SizedDataset
class TokenBlockDataset(SizedDataset):
"""Given a 1d tensor of tokens, this dataset will break tokens into blocks based on parameters. The blocks are
fetched from the original tensor so no additional memory is allocated"""
def __init__(self, tokens, block_size, sizes, offset=0, break_mode=None):
"""
Args:
tokens: torch tensor of tokens to break into blocks
block_size: An integer. the maximum size of each block (note this has no effect in 'eos' break mode)
sizes: A list of integers. sizes of sentences in the block. the sum of the sizes should add up to the
length of tokens
offset: An integer. rotates the tokens by this much before computing blocks. useful for language model targets
break_mode: A boolean if None/'none' then breaks tokens into equally sized blocks of size block_size
if 'complete' then breaks tokens into block sizes of up to block_size such that each block
contains complete sentences. block_size may be exceeded if some sentences exceed block_size
if 'eos' then each block contains a single sentence. does not respect block_size"""
super().__init__()
self.tokens = tokens
self.offset = offset
self.slice_indices = []
if break_mode is None or break_mode == 'none':
length = math.ceil(tokens.numel() / block_size)
def block_at(i):
start = i * block_size
end = min(start + block_size, len(tokens))
return (start, end)
self.slice_indices = [block_at(i) for i in np.arange(length)]
elif break_mode == 'complete':
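# greedily pack whole sentences into the current block until adding the next sentence would exceed
# block_size (a single sentence longer than block_size still becomes its own block)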
tok_idx = 0
sz_idx = 0
curr_size = 0
while sz_idx < len(sizes):
if curr_size + sizes[sz_idx] <= block_size or curr_size == 0:
curr_size += sizes[sz_idx]
sz_idx += 1
else:
self.slice_indices.append((tok_idx, tok_idx + curr_size))
tok_idx += curr_size
curr_size = 0
if curr_size > 0:
self.slice_indices.append((tok_idx, tok_idx + curr_size))
elif break_mode == 'eos':
curr = 0
for sz in sizes:
# skip sentences that contain only a single token (which would be just the eos token)
if sz > 1:
self.slice_indices.append((curr, curr + sz))
curr += sz
else:
raise Exception('invalid break_mode. Supported values: none, complete, eos')
self._sizes = np.array([e - s for s, e in self.slice_indices])
def _slice(self, s, e):
# when offset > 0, only the first block requires a copy (via torch.cat); rotating the whole
# tensor with torch.cat() up front would have copied every block
if s < self.offset:
return torch.cat([self.tokens[s - self.offset:], self.tokens[s:e - self.offset]])
return self.tokens[s - self.offset:e - self.offset]
def __getitem__(self, i):
return self._slice(*self.slice_indices[i])
def __len__(self):
return len(self.slice_indices)
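
For illustration, a quick sketch of how the offset parameter builds language model targets (toy values; this mirrors how MonolingualDataset pairs its source and target blocks):

import torch
from fairseq.data import TokenBlockDataset

tokens = torch.arange(10)
src = TokenBlockDataset(tokens, 5, [10], offset=1, break_mode=None)  # rotated back by one token
dst = TokenBlockDataset(tokens, 5, [10], offset=0, break_mode=None)
# src[i][t] holds the token that precedes dst[i][t], so the model reads token t-1 and predicts token t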
@@ -11,8 +11,7 @@ import os
 from .fairseq_decoder import FairseqDecoder  # noqa: F401
 from .fairseq_encoder import FairseqEncoder  # noqa: F401
 from .fairseq_incremental_decoder import FairseqIncrementalDecoder  # noqa: F401
-from .fairseq_model import FairseqModel  # noqa: F401
+from .fairseq_model import BaseFairseqModel, FairseqModel, FairseqLanguageModel  # noqa: F401
 MODEL_REGISTRY = {}
 ARCH_MODEL_REGISTRY = {}
@@ -29,8 +28,8 @@ def register_model(name):
     def register_model_cls(cls):
         if name in MODEL_REGISTRY:
             raise ValueError('Cannot register duplicate model ({})'.format(name))
-        if not issubclass(cls, FairseqModel):
-            raise ValueError('Model ({}: {}) must extend FairseqModel'.format(name, cls.__name__))
+        if not issubclass(cls, BaseFairseqModel):
+            raise ValueError('Model ({}: {}) must extend BaseFairseqModel'.format(name, cls.__name__))
         MODEL_REGISTRY[name] = cls
         return cls
...
@@ -19,7 +19,7 @@ class FairseqDecoder(nn.Module):
     def forward(self, prev_output_tokens, encoder_out):
         raise NotImplementedError
-    def get_normalized_probs(self, net_output, log_probs):
+    def get_normalized_probs(self, net_output, log_probs, _):
         """Get normalized probabilities (or log probs) from a net's output."""
         logits = net_output[0].float()
         if log_probs:
...
@@ -5,28 +5,17 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.
 import torch.nn as nn
 from . import FairseqDecoder, FairseqEncoder
-class FairseqModel(nn.Module):
-    """Base class for encoder-decoder models."""
-    def __init__(self, encoder, decoder):
+class BaseFairseqModel(nn.Module):
+    """Base class for fairseq models."""
+    def __init__(self):
         super().__init__()
-        self.encoder = encoder
-        self.decoder = decoder
-        assert isinstance(self.encoder, FairseqEncoder)
-        assert isinstance(self.decoder, FairseqDecoder)
-        self.src_dict = encoder.dictionary
-        self.dst_dict = decoder.dictionary
-        assert self.src_dict.pad() == self.dst_dict.pad()
-        assert self.src_dict.eos() == self.dst_dict.eos()
-        assert self.src_dict.unk() == self.dst_dict.unk()
         self._is_generation_fast = False
     @staticmethod
@@ -34,32 +23,10 @@ class FairseqModel(nn.Module):
         """Add model-specific arguments to the parser."""
         pass
-    @classmethod
-    def build_model(cls, args, src_dict, dst_dict):
-        """Build a new model instance."""
-        raise NotImplementedError
-    def forward(self, src_tokens, src_lengths, prev_output_tokens):
-        encoder_out = self.encoder(src_tokens, src_lengths)
-        decoder_out = self.decoder(prev_output_tokens, encoder_out)
-        return decoder_out
-    def get_normalized_probs(self, net_output, log_probs):
-        """Get normalized probabilities (or log probs) from a net's output."""
-        return self.decoder.get_normalized_probs(net_output, log_probs)
     def get_targets(self, sample, net_output):
         """Get targets from either the sample or the net's output."""
         return sample['target']
-    def max_encoder_positions(self):
-        """Maximum input length supported by the encoder."""
-        return self.encoder.max_positions()
-    def max_decoder_positions(self):
-        """Maximum output length supported by the decoder."""
-        return self.decoder.max_positions()
     def load_state_dict(self, state_dict, strict=True):
         """Copies parameters and buffers from state_dict into this module and
         its descendants.
@@ -67,13 +34,17 @@ class FairseqModel(nn.Module):
         Overrides the method in nn.Module; compared with that method this
         additionally "upgrades" state_dicts from old checkpoints.
         """
-        state_dict = self.upgrade_state_dict(state_dict)
+        self.upgrade_state_dict(state_dict)
         super().load_state_dict(state_dict, strict)
     def upgrade_state_dict(self, state_dict):
-        state_dict = self.encoder.upgrade_state_dict(state_dict)
-        state_dict = self.decoder.upgrade_state_dict(state_dict)
-        return state_dict
+        assert state_dict is not None
+        def do_upgrade(m):
+            if m != self and hasattr(m, 'upgrade_state_dict'):
+                m.upgrade_state_dict(state_dict)
+        self.apply(do_upgrade)
     def make_generation_fast_(self, **kwargs):
         """Optimize model for faster generation."""
@@ -87,11 +58,13 @@ class FairseqModel(nn.Module):
                 nn.utils.remove_weight_norm(module)
             except ValueError:  # this module didn't have weight norm
                 return
         self.apply(apply_remove_weight_norm)
         def apply_make_generation_fast_(module):
             if module != self and hasattr(module, 'make_generation_fast_'):
                 module.make_generation_fast_(**kwargs)
         self.apply(apply_make_generation_fast_)
         def train(mode):
@@ -101,3 +74,66 @@ class FairseqModel(nn.Module):
             # this model should no longer be used for training
             self.eval()
         self.train = train
+class FairseqModel(BaseFairseqModel):
+    """Base class for encoder-decoder models."""
+    def __init__(self, encoder, decoder):
+        super().__init__()
+        self.encoder = encoder
+        self.decoder = decoder
+        assert isinstance(self.encoder, FairseqEncoder)
+        assert isinstance(self.decoder, FairseqDecoder)
+        self.src_dict = encoder.dictionary
+        self.dst_dict = decoder.dictionary
+        assert self.src_dict.pad() == self.dst_dict.pad()
+        assert self.src_dict.eos() == self.dst_dict.eos()
+        assert self.src_dict.unk() == self.dst_dict.unk()
+    @classmethod
+    def build_model(cls, args, src_dict, dst_dict):
+        """Build a new model instance."""
+        raise NotImplementedError
+    def forward(self, src_tokens, src_lengths, prev_output_tokens):
+        encoder_out = self.encoder(src_tokens, src_lengths)
+        decoder_out = self.decoder(prev_output_tokens, encoder_out)
+        return decoder_out
+    def get_normalized_probs(self, net_output, log_probs, sample=None):
+        """Get normalized probabilities (or log probs) from a net's output."""
+        return self.decoder.get_normalized_probs(net_output, log_probs, sample)
+    def max_encoder_positions(self):
+        """Maximum input length supported by the encoder."""
+        return self.encoder.max_positions()
+    def max_decoder_positions(self):
+        """Maximum output length supported by the decoder."""
+        return self.decoder.max_positions()
+class FairseqLanguageModel(BaseFairseqModel):
+    """Base class for decoder-only models."""
+    def __init__(self, decoder):
+        super().__init__()
+        self.decoder = decoder
+        assert isinstance(self.decoder, FairseqDecoder)
+    def forward(self, src_tokens, **unused):
+        return self.decoder(src_tokens)
+    def get_normalized_probs(self, net_output, log_probs, sample=None):
+        """Get normalized probabilities (or log probs) from a net's output."""
+        return self.decoder.get_normalized_probs(net_output, log_probs, sample)
+    def max_decoder_positions(self):
+        """Maximum output length supported by the decoder."""
+        return self.decoder.max_positions()
+    def max_encoder_positions(self):
+        return self.max_decoder_positions()
@@ -11,7 +11,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from fairseq import utils
-from fairseq.data import LanguagePairDataset
+from fairseq.data import consts
 from . import FairseqEncoder, FairseqIncrementalDecoder, FairseqModel, register_model, register_model_architecture
@@ -117,7 +117,7 @@ class LSTMEncoder(FairseqEncoder):
     def __init__(
         self, dictionary, embed_dim=512, hidden_size=512, num_layers=1,
         dropout_in=0.1, dropout_out=0.1, bidirectional=False,
-        left_pad_source=LanguagePairDataset.LEFT_PAD_SOURCE,
+        left_pad_source=consts.LEFT_PAD_SOURCE,
         pretrained_embed=None,
         padding_value=0.,
     ):
...
@@ -11,7 +11,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from fairseq.data import LanguagePairDataset
+from fairseq.data.consts import LEFT_PAD_SOURCE, LEFT_PAD_TARGET
 from fairseq.modules import (
     LearnedPositionalEmbedding, MultiheadAttention,
     SinusoidalPositionalEmbedding,
@@ -108,7 +108,7 @@ class TransformerEncoder(FairseqEncoder):
         self.embed_scale = math.sqrt(embed_dim)
         self.embed_positions = PositionalEmbedding(
             1024, embed_dim, self.padding_idx,
-            left_pad=LanguagePairDataset.LEFT_PAD_SOURCE,
+            left_pad=LEFT_PAD_SOURCE,
             learned=args.encoder_learned_pos,
         )
@@ -169,7 +169,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
         self.embed_scale = math.sqrt(embed_dim)
         self.embed_positions = PositionalEmbedding(
             1024, embed_dim, padding_idx,
-            left_pad=LanguagePairDataset.LEFT_PAD_TARGET,
+            left_pad=LEFT_PAD_TARGET,
             learned=args.decoder_learned_pos,
         )
@@ -181,7 +181,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
         if not self.share_input_output_embed:
             self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), embed_dim))
-            nn.init.normal(self.embed_out, mean=0, std=embed_dim**-0.5)
+            nn.init.normal(self.embed_out, mean=0, std=embed_dim ** -0.5)
     def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
         # embed positions
@@ -363,7 +363,7 @@ class TransformerDecoderLayer(nn.Module):
 def Embedding(num_embeddings, embedding_dim, padding_idx):
     m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
-    nn.init.normal(m.weight, mean=0, std=embedding_dim**-0.5)
+    nn.init.normal(m.weight, mean=0, std=embedding_dim ** -0.5)
     return m
@@ -382,7 +382,7 @@ def Linear(in_features, out_features, bias=True):
 def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad, learned=False):
     if learned:
         m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad)
-        nn.init.normal(m.weight, mean=0, std=embedding_dim**-0.5)
+        nn.init.normal(m.weight, mean=0, std=embedding_dim ** -0.5)
         nn.init.constant(m.weight[padding_idx], 0)
     else:
         m = SinusoidalPositionalEmbedding(embedding_dim, padding_idx, left_pad, init_size=num_embeddings)
...
@@ -5,6 +5,7 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.
+from .adaptive_softmax import AdaptiveSoftmax
 from .beamable_mm import BeamableMM
 from .conv_tbc import ConvTBC
 from .grad_multiply import GradMultiply
@@ -14,6 +15,7 @@ from .multihead_attention import MultiheadAttention
 from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding
 __all__ = [
+    'AdaptiveSoftmax',
     'BeamableMM',
     'ConvTBC',
     'GradMultiply',
...