Unverified commit e5b3c1f4 authored by Myle Ott, committed by GitHub

Merge pull request #54: Version 0.1.0 -> 0.2.0

Release notes:
- 5c7f4954: Added simple LSTM model with input feeding and attention
- 6e4b7e22: Refactored model definitions and incremental generation to be cleaner
- 7ae79c12: Split interactive generation out of generate.py and into a new binary: interactive.py
- 19a3865d: Subtle correctness fix in the beam search decoder. Previously, for a beam size of k, we might emit a hypothesis
           if the <eos> was among the top 2*k candidates. Now we only emit hypotheses for which the <eos> is among the
           top-k candidates. This may subtly change generation results, and in the case of k=1 we will now produce
           strictly greedy outputs (a small sketch of the new rule follows these notes).
- 97d7fcb9: Fixed a bug in padding direction, where previously we right-padded the source and left-padded the target. We
           now left-pad the source and right-pad the target. This should not affect existing trained models, but may
           change (and usually improves) the quality of new models.
- f442f896: Add support for batching based on the number of sentences (`--max-sentences`) in addition to the number of
           tokens (`--max-tokens`). When batching by the number of sentences, one can optionally normalize the gradients
           by the number of sentences with `--sentence-avg` (the default is to normalize by the number of tokens); see the
           usage sketch after these notes.
- c6d6256b: Add `--log-format` option and JSON logger
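
To make the beam-search change (19a3865d) concrete, here is a minimal sketch of the new emission rule; the function and variable names below are illustrative only and are not fairseq's actual decoder API:

```
import torch

def emit_finished_hypotheses(lprobs, eos_idx, beam_size):
    """Hypothetical helper: finalize a hypothesis only if <eos> is among the
    top-k (k = beam_size) candidates at this step. Previously an <eos>
    anywhere in the top 2*k candidates could be emitted."""
    # the search itself still considers the top 2*k candidates per step
    scores, indices = lprobs.topk(min(2 * beam_size, lprobs.numel()))
    finished = []
    for rank, (score, idx) in enumerate(zip(scores.tolist(), indices.tolist())):
        if idx == eos_idx and rank < beam_size:  # new rule: <eos> must rank within the top-k
            finished.append((rank, score))
    return finished

# With beam_size=1 this reduces to emitting <eos> only when it is the single
# best candidate, i.e. strictly greedy output.
```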
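A usage sketch for the new batching and logging options (f442f896, c6d6256b): the dataset path and architecture below are illustrative placeholders, while `--max-sentences`, `--sentence-avg` and `--log-format` are the flags introduced in this release:

```
$ python train.py data-bin/iwslt14.tokenized.de-en \
  --arch fconv_iwslt_de_en \
  --max-sentences 64 --sentence-avg \
  --log-format json
```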
@@ -43,7 +43,7 @@ The following command-line tools are available:
* `python preprocess.py`: Data pre-processing: build vocabularies and binarize training data
* `python train.py`: Train a new model on one or multiple GPUs
* `python generate.py`: Translate pre-processed data with a trained model
-* `python generate.py -i`: Translate raw text with a trained model
+* `python interactive.py`: Translate raw text with a trained model
* `python score.py`: BLEU scoring of generated translations against reference translations
## Evaluating Pre-trained Models
@@ -57,22 +57,21 @@ This can be done with the [apply_bpe.py](https://github.com/rsennrich/subword-nm
`@@` is used as a continuation marker and the original text can be easily recovered with e.g. `sed s/@@ //g` or by passing the `--remove-bpe` flag to `generate.py`.
Prior to BPE, input text needs to be tokenized using `tokenizer.perl` from [mosesdecoder](https://github.com/moses-smt/mosesdecoder).
-Let's use `python generate.py -i` to generate translations.
+Let's use `python interactive.py` to generate translations interactively.
Here, we use a beam size of 5:
```
$ MODEL_DIR=wmt14.en-fr.fconv-py
-$ python generate.py -i \
+$ python interactive.py \
  --path $MODEL_DIR/model.pt $MODEL_DIR \
  --beam 5
+| loading model(s) from wmt14.en-fr.fconv-py/model.pt
| [en] dictionary: 44206 types
| [fr] dictionary: 44463 types
-| model fconv_wmt_en_fr
-| loaded checkpoint /private/home/edunov/wmt14.en-fr.fconv-py/model.pt (epoch 37)
+| Type the input sentence and press return:
> Why is it rare to discover new marine mam@@ mal species ?
-S Why is it rare to discover new marine mam@@ mal species ?
O Why is it rare to discover new marine mam@@ mal species ?
-H -0.08662842959165573 Pourquoi est-il rare de découvrir de nouvelles espèces de mammifères marins ?
-A 0 1 3 3 5 6 6 10 8 8 8 11 12
+H -0.06429661810398102 Pourquoi est-il rare de découvrir de nouvelles espèces de mammifères marins ?
+A 0 1 3 3 5 6 6 8 8 8 7 11 12
```
This generation script produces four types of outputs: a line prefixed with *S* shows the supplied source sentence after applying the vocabulary; *O* is a copy of the original source sentence; *H* is the hypothesis along with an average log-likelihood; and *A* is the attention maxima for each word in the hypothesis, including the end-of-sentence marker which is omitted from the text.
@@ -114,7 +113,7 @@ Also note that the batch size is specified in terms of the maximum number of tok
You may need to use a smaller value depending on the available GPU memory on your system.
### Generation
-Once your model is trained, you can generate translations using `python generate.py` **(for binarized data)** or `python generate.py -i` **(for raw text)**:
+Once your model is trained, you can generate translations using `python generate.py` **(for binarized data)** or `python interactive.py` **(for raw text)**:
```
$ python generate.py data-bin/iwslt14.tokenized.de-en \
  --path checkpoints/fconv/checkpoint_best.pt \
@@ -152,14 +151,14 @@ $ python generate.py data-bin/wmt14.en-fr.newstest2014 \
  --path data-bin/wmt14.en-fr.fconv-py/model.pt \
  --beam 5 --batch-size 128 --remove-bpe | tee /tmp/gen.out
...
-| Translated 3003 sentences (95451 tokens) in 94.5s (1009.60 tokens/s)
-| Generate test with beam=5: BLEU4 = 40.70, 67.7/46.8/34.2/25.4 (BP=1.000, ratio=1.000, syslen=81190, reflen=81194)
+| Translated 3003 sentences (95451 tokens) in 93.8s (1018.09 tokens/s)
+| Generate test with beam=5: BLEU4 = 40.67, 67.7/46.7/34.2/25.3 (BP=1.000, ratio=0.998, syslen=81377, reflen=81194)
# Scoring with score.py:
$ grep ^H /tmp/gen.out | cut -f3- > /tmp/gen.out.sys
$ grep ^T /tmp/gen.out | cut -f2- > /tmp/gen.out.ref
$ python score.py --sys /tmp/gen.out.sys --ref /tmp/gen.out.ref
-BLEU4 = 40.70, 67.7/46.8/34.2/25.4 (BP=1.000, ratio=1.000, syslen=81190, reflen=81194)
+BLEU4 = 40.67, 67.7/46.7/34.2/25.3 (BP=1.000, ratio=0.998, syslen=81377, reflen=81194)
```
# Join the fairseq community
......
@@ -7,7 +7,6 @@
#
from .cross_entropy import CrossEntropyCriterion
-from .fairseq_criterion import FairseqCriterion
from .label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion
__all__ = [
......
@@ -14,9 +14,8 @@ from .fairseq_criterion import FairseqCriterion
class CrossEntropyCriterion(FairseqCriterion):
-    def __init__(self, padding_idx):
-        super().__init__()
-        self.padding_idx = padding_idx
+    def __init__(self, args, dst_dict):
+        super().__init__(args, dst_dict)
    def forward(self, model, sample):
        """Compute the loss for the given sample.
@@ -30,7 +29,7 @@ class CrossEntropyCriterion(FairseqCriterion):
        input = net_output.view(-1, net_output.size(-1))
        target = sample['target'].view(-1)
        loss = F.cross_entropy(input, target, size_average=False, ignore_index=self.padding_idx)
-        sample_size = sample['ntokens']
+        sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
        logging_output = {
            'loss': loss.data[0],
            'sample_size': sample_size,
......
@@ -11,8 +11,10 @@ from torch.nn.modules.loss import _Loss
class FairseqCriterion(_Loss):
-    def __init__(self):
+    def __init__(self, args, dst_dict):
        super().__init__()
+        self.args = args
+        self.padding_idx = dst_dict.pad()
    def forward(self, model, sample):
        """Compute the loss for the given sample.
......
@@ -14,7 +14,7 @@ import torch.nn.functional as F
from .fairseq_criterion import FairseqCriterion
-class LabelSmoothedCrossEntropy(torch.autograd.Function):
+class LabelSmoothedNLLLoss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, target, eps, padding_idx, weights):
@@ -43,10 +43,9 @@ class LabelSmoothedCrossEntropy(torch.autograd.Function):
class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
-    def __init__(self, eps, padding_idx=None, weights=None):
-        super().__init__()
-        self.eps = eps
-        self.padding_idx = padding_idx
+    def __init__(self, args, dst_dict, weights=None):
+        super().__init__(args, dst_dict)
+        self.eps = args.label_smoothing
        self.weights = weights
    def forward(self, model, sample):
@@ -60,8 +59,8 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
        net_output = model(**sample['net_input'])
        input = F.log_softmax(net_output.view(-1, net_output.size(-1)))
        target = sample['target'].view(-1)
-        loss = LabelSmoothedCrossEntropy.apply(input, target, self.eps, self.padding_idx, self.weights)
-        sample_size = sample['ntokens']
+        loss = LabelSmoothedNLLLoss.apply(input, target, self.eps, self.padding_idx, self.weights)
+        sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
        logging_output = {
            'loss': loss.data[0],
            'sample_size': sample_size,
......
@@ -8,70 +8,85 @@
import contextlib import contextlib
import itertools import itertools
import glob
import numbers
import numpy as np import numpy as np
import os import os
import torch import torch
import torch.utils.data import torch.utils.data
from fairseq.dictionary import Dictionary from fairseq.dictionary import Dictionary
from fairseq.indexed_dataset import IndexedDataset, IndexedInMemoryDataset from fairseq.indexed_dataset import IndexedDataset, IndexedInMemoryDataset, IndexedRawTextDataset
def load_with_check(path, load_splits, src=None, dst=None): def has_binary_files(data_dir, splits):
"""Loads specified data splits (e.g., test, train or valid) from the for split in splits:
specified folder and check that files exist.""" if len(glob.glob(os.path.join(data_dir, f'{split}.*-*.*.bin'))) < 2:
return False
return True
def find_language_pair(files):
for split in load_splits:
for filename in files:
parts = filename.split('.')
if parts[0] == split and parts[-1] == 'idx':
return parts[1].split('-')
def split_exists(split, src, dst): def infer_language_pair(path, splits):
filename = '{0}.{1}-{2}.{1}.idx'.format(split, src, dst) """Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx"""
return os.path.exists(os.path.join(path, filename)) src, dst = None, None
for filename in os.listdir(path):
parts = filename.split('.')
for split in splits:
if parts[0] == split and parts[-1] == 'idx':
src, dst = parts[1].split('-')
break
return src, dst
if src is None and dst is None:
# find language pair automatically
src, dst = find_language_pair(os.listdir(path))
if not split_exists(load_splits[0], src, dst): def load_dictionaries(path, src_lang, dst_lang):
# try reversing src and dst """Load dictionaries for a given language pair."""
src, dst = dst, src src_dict = Dictionary.load(os.path.join(path, 'dict.{}.txt'.format(src_lang)))
dst_dict = Dictionary.load(os.path.join(path, 'dict.{}.txt'.format(dst_lang)))
return src_dict, dst_dict
for split in load_splits:
if not split_exists(load_splits[0], src, dst):
raise ValueError('Data split not found: {}-{} ({})'.format(
src, dst, split))
dataset = load(path, load_splits, src, dst)
return dataset
def load_dataset(path, load_splits, src=None, dst=None):
"""Loads specified data splits (e.g., test, train or valid) from the
specified folder and check that files exist."""
if src is None and dst is None:
# find language pair automatically
src, dst = infer_language_pair(path, load_splits)
assert src is not None and dst is not None, 'Source and target languages should be provided'
def load(path, load_splits, src, dst): src_dict, dst_dict = load_dictionaries(path, src, dst)
"""Loads specified data splits (e.g. test, train or valid) from the path.""" dataset = LanguageDatasets(src, dst, src_dict, dst_dict)
langcode = '{}-{}'.format(src, dst) # Load dataset from binary files
def all_splits_exist(src, dst):
for split in load_splits:
filename = '{0}.{1}-{2}.{1}.idx'.format(split, src, dst)
if not os.path.exists(os.path.join(path, filename)):
return False
return True
# infer langcode
if all_splits_exist(src, dst):
langcode = '{}-{}'.format(src, dst)
elif all_splits_exist(dst, src):
langcode = '{}-{}'.format(dst, src)
else:
raise Exception('Dataset cannot be loaded from path: ' + path)
def fmt_path(fmt, *args): def fmt_path(fmt, *args):
return os.path.join(path, fmt.format(*args)) return os.path.join(path, fmt.format(*args))
src_dict = Dictionary.load(fmt_path('dict.{}.txt', src))
dst_dict = Dictionary.load(fmt_path('dict.{}.txt', dst))
dataset = LanguageDatasets(src, dst, src_dict, dst_dict)
for split in load_splits: for split in load_splits:
for k in itertools.count(): for k in itertools.count():
prefix = "{}{}".format(split, k if k > 0 else '') prefix = "{}{}".format(split, k if k > 0 else '')
src_path = fmt_path('{}.{}.{}', prefix, langcode, src) src_path = fmt_path('{}.{}.{}', prefix, langcode, src)
dst_path = fmt_path('{}.{}.{}', prefix, langcode, dst)
if not IndexedInMemoryDataset.exists(src_path): if not IndexedInMemoryDataset.exists(src_path):
break break
dataset.splits[prefix] = LanguagePairDataset( dataset.splits[prefix] = LanguagePairDataset(
IndexedInMemoryDataset(src_path), IndexedInMemoryDataset(src_path),
IndexedInMemoryDataset(fmt_path('{}.{}.{}', prefix, langcode, dst)), IndexedInMemoryDataset(dst_path),
pad_idx=dataset.src_dict.pad(), pad_idx=dataset.src_dict.pad(),
eos_idx=dataset.src_dict.eos(), eos_idx=dataset.src_dict.eos(),
) )
@@ -79,6 +94,30 @@ def load(path, load_splits, src, dst):
return dataset return dataset
def load_raw_text_dataset(path, load_splits, src=None, dst=None):
"""Loads specified data splits (e.g., test, train or valid) from raw text
files in the specified folder."""
if src is None and dst is None:
# find language pair automatically
src, dst = infer_language_pair(path, load_splits)
assert src is not None and dst is not None, 'Source and target languages should be provided'
src_dict, dst_dict = load_dictionaries(path, src, dst)
dataset = LanguageDatasets(src, dst, src_dict, dst_dict)
# Load dataset from raw text files
for split in load_splits:
src_path = os.path.join(path, f'{split}.{src}')
dst_path = os.path.join(path, f'{split}.{dst}')
dataset.splits[split] = LanguagePairDataset(
IndexedRawTextDataset(src_path, src_dict),
IndexedRawTextDataset(dst_path, dst_dict),
pad_idx=dataset.src_dict.pad(),
eos_idx=dataset.src_dict.eos(),
)
return dataset
class LanguageDatasets(object): class LanguageDatasets(object):
def __init__(self, src, dst, src_dict, dst_dict): def __init__(self, src, dst, src_dict, dst_dict):
self.src = src self.src = src
@@ -91,32 +130,34 @@ class LanguageDatasets(object):
assert self.src_dict.eos() == self.dst_dict.eos() assert self.src_dict.eos() == self.dst_dict.eos()
assert self.src_dict.unk() == self.dst_dict.unk() assert self.src_dict.unk() == self.dst_dict.unk()
def dataloader(self, split, batch_size=1, num_workers=0, def train_dataloader(self, split, num_workers=0, max_tokens=None,
max_tokens=None, seed=None, epoch=1, max_sentences=None, max_positions=(1024, 1024),
sample_without_replacement=0, max_positions=1024, seed=None, epoch=1, sample_without_replacement=0,
skip_invalid_size_inputs_valid_test=False): sort_by_source_size=False):
dataset = self.splits[split] dataset = self.splits[split]
if split.startswith('train'): with numpy_seed(seed):
with numpy_seed(seed): batch_sampler = shuffled_batches_by_size(
batch_sampler = shuffled_batches_by_size( dataset.src, dataset.dst, max_tokens=max_tokens,
dataset.src, dataset.dst, max_sentences=max_sentences, epoch=epoch,
max_tokens=max_tokens, epoch=epoch, sample=sample_without_replacement, max_positions=max_positions,
sample=sample_without_replacement, sort_by_source_size=sort_by_source_size)
max_positions=max_positions) return torch.utils.data.DataLoader(
elif split.startswith('valid'): dataset, num_workers=num_workers, collate_fn=dataset.collater,
batch_sampler = list(batches_by_size(dataset.src, batch_size, max_tokens, dst=dataset.dst, batch_sampler=batch_sampler)
max_positions=max_positions,
ignore_invalid_inputs=skip_invalid_size_inputs_valid_test))
else:
batch_sampler = list(batches_by_size(dataset.src, batch_size, max_tokens, max_positions=max_positions,
ignore_invalid_inputs=skip_invalid_size_inputs_valid_test))
def eval_dataloader(self, split, num_workers=0, max_tokens=None,
max_sentences=None, max_positions=(1024, 1024),
skip_invalid_size_inputs_valid_test=False,
descending=False):
dataset = self.splits[split]
batch_sampler = list(batches_by_size(
dataset.src, dataset.dst, max_tokens, max_sentences,
max_positions=max_positions,
ignore_invalid_inputs=skip_invalid_size_inputs_valid_test,
descending=descending))
return torch.utils.data.DataLoader( return torch.utils.data.DataLoader(
dataset, dataset, num_workers=num_workers, collate_fn=dataset.collater,
num_workers=num_workers, batch_sampler=batch_sampler)
collate_fn=dataset.collater,
batch_sampler=batch_sampler,
)
def skip_group_enumerator(it, ngpus, offset=0): def skip_group_enumerator(it, ngpus, offset=0):
@@ -135,6 +176,11 @@ def skip_group_enumerator(it, ngpus, offset=0):
class LanguagePairDataset(object): class LanguagePairDataset(object):
# padding constants
LEFT_PAD_SOURCE = True
LEFT_PAD_TARGET = False
def __init__(self, src, dst, pad_idx, eos_idx): def __init__(self, src, dst, pad_idx, eos_idx):
self.src = src self.src = src
self.dst = dst self.dst = dst
@@ -164,18 +210,15 @@ class LanguagePairDataset(object):
return LanguagePairDataset.collate_tokens( return LanguagePairDataset.collate_tokens(
[s[key] for s in samples], pad_idx, eos_idx, left_pad, move_eos_to_beginning) [s[key] for s in samples], pad_idx, eos_idx, left_pad, move_eos_to_beginning)
def merge_positions(key, left_pad):
return LanguagePairDataset.collate_positions([s[key] for s in samples], pad_idx, left_pad)
ntokens = sum(len(s['target']) for s in samples)
return { return {
'id': torch.LongTensor([s['id'].item() for s in samples]), 'id': torch.LongTensor([s['id'].item() for s in samples]),
'input_tokens': merge('target', left_pad=True, move_eos_to_beginning=True), 'src_tokens': merge('source', left_pad=LanguagePairDataset.LEFT_PAD_SOURCE),
'input_positions': merge_positions('target', left_pad=True), # we create a shifted version of targets for feeding the previous
'target': merge('target', left_pad=True), # output token(s) into the next decoder step
'src_tokens': merge('source', left_pad=False), 'input_tokens': merge('target', left_pad=LanguagePairDataset.LEFT_PAD_TARGET,
'src_positions': merge_positions('source', left_pad=False), move_eos_to_beginning=True),
'ntokens': ntokens, 'target': merge('target', left_pad=LanguagePairDataset.LEFT_PAD_TARGET),
'ntokens': sum(len(s['target']) for s in samples),
} }
@staticmethod @staticmethod
@@ -199,82 +242,91 @@ class LanguagePairDataset(object):
copy_tensor(v, res[i][:len(v)]) copy_tensor(v, res[i][:len(v)])
return res return res
@staticmethod
def collate_positions(values, pad_idx, left_pad):
start = pad_idx + 1
size = max(v.size(0) for v in values)
res = values[0].new(len(values), size).fill_(pad_idx)
for i, v in enumerate(values):
if left_pad:
torch.arange(start, start + len(v), out=res[i][size-len(v):])
else:
torch.arange(start, start + len(v), out=res[i][:len(v)])
return res
def _valid_size(src_size, dst_size, max_positions):
if isinstance(max_positions, numbers.Number):
max_src_positions, max_dst_positions = max_positions, max_positions
else:
max_src_positions, max_dst_positions = max_positions
if src_size < 2 or src_size > max_src_positions:
return False
if dst_size is not None and (dst_size < 2 or dst_size > max_dst_positions):
return False
return True
def batches_by_size(src, batch_size=None, max_tokens=None, dst=None,
max_positions=1024, ignore_invalid_inputs=False):
"""Returns batches of indices sorted by size. Sequences of different lengths
are not allowed in the same batch."""
assert isinstance(src, IndexedDataset)
assert dst is None or isinstance(dst, IndexedDataset)
if max_tokens is None:
max_tokens = float('Inf')
sizes = src.sizes
indices = np.argsort(sizes, kind='mergesort')
if dst is not None:
sizes = np.maximum(sizes, dst.sizes)
def _make_batches(src, dst, indices, max_tokens, max_sentences, max_positions,
ignore_invalid_inputs=False, allow_different_src_lens=False):
batch = [] batch = []
def yield_batch(next_idx, num_tokens): def yield_batch(next_idx, num_tokens):
if len(batch) == 0: if len(batch) == 0:
return False return False
if len(batch) == batch_size: if len(batch) == max_sentences:
return True return True
if sizes[batch[0]] != sizes[next_idx]: if num_tokens > max_tokens:
return True return True
if num_tokens >= max_tokens: if not allow_different_src_lens and \
(src.sizes[batch[0]] != src.sizes[next_idx]):
return True return True
return False return False
cur_max_size = 0 sample_len = 0
ignored = [] ignored = []
for idx in indices: for idx in indices:
# - 2 here stems from make_positions() where we offset positions if not _valid_size(src.sizes[idx], dst.sizes[idx], max_positions):
# by padding_value + 1
if src.sizes[idx] < 2 or \
(False if dst is None else dst.sizes[idx] < 2) or \
sizes[idx] > max_positions - 2:
if ignore_invalid_inputs: if ignore_invalid_inputs:
ignored.append(idx) ignored.append(idx)
continue continue
raise Exception(
"Unable to handle input id {} of size {} / {}.".format(
idx, src.sizes[idx], dst.sizes[idx]))
raise Exception("Unable to handle input id {} of " sample_len = max(sample_len, src.sizes[idx], dst.sizes[idx])
"size {} / {}.".format(idx, src.sizes[idx], num_tokens = (len(batch) + 1) * sample_len
"none" if dst is None else dst.sizes[idx])) if yield_batch(idx, num_tokens):
if yield_batch(idx, cur_max_size * (len(batch) + 1)):
yield batch yield batch
batch = [] batch = []
cur_max_size = 0 sample_len = max(src.sizes[idx], dst.sizes[idx])
batch.append(idx)
cur_max_size = max(cur_max_size, sizes[idx])
if len(ignored) > 0: batch.append(idx)
print("Warning! {} samples are either too short or too long "
"and will be ignored, sample ids={}".format(len(ignored), ignored))
if len(batch) > 0: if len(batch) > 0:
yield batch yield batch
if len(ignored) > 0:
print("Warning! {} samples are either too short or too long "
"and will be ignored, first few sample ids={}".format(len(ignored), ignored[:10]))
def shuffled_batches_by_size(src, dst, max_tokens=None, epoch=1, sample=0, max_positions=1024): def batches_by_size(src, dst, max_tokens=None, max_sentences=None,
max_positions=(1024, 1024), ignore_invalid_inputs=False,
descending=False):
"""Returns batches of indices sorted by size. Sequences with different
source lengths are not allowed in the same batch."""
assert isinstance(src, IndexedDataset) and isinstance(dst, IndexedDataset)
if max_tokens is None:
max_tokens = float('Inf')
if max_sentences is None:
max_sentences = float('Inf')
indices = np.argsort(src.sizes, kind='mergesort')
if descending:
indices = np.flip(indices, 0)
return _make_batches(
src, dst, indices, max_tokens, max_sentences, max_positions,
ignore_invalid_inputs, allow_different_src_lens=False)
def shuffled_batches_by_size(src, dst, max_tokens=None, max_sentences=None,
epoch=1, sample=0, max_positions=(1024, 1024),
sort_by_source_size=False):
"""Returns batches of indices, bucketed by size and then shuffled. Batches """Returns batches of indices, bucketed by size and then shuffled. Batches
may contain sequences of different lengths.""" may contain sequences of different lengths."""
assert isinstance(src, IndexedDataset) and isinstance(dst, IndexedDataset) assert isinstance(src, IndexedDataset) and isinstance(dst, IndexedDataset)
if max_tokens is None: if max_tokens is None:
max_tokens = float('Inf') max_tokens = float('Inf')
if max_sentences is None:
max_sentences = float('Inf')
indices = np.random.permutation(len(src)) indices = np.random.permutation(len(src))
@@ -282,35 +334,12 @@ def shuffled_batches_by_size(src, dst, max_tokens=None, epoch=1, sample=0, max_p
indices = indices[np.argsort(dst.sizes[indices], kind='mergesort')] indices = indices[np.argsort(dst.sizes[indices], kind='mergesort')]
indices = indices[np.argsort(src.sizes[indices], kind='mergesort')] indices = indices[np.argsort(src.sizes[indices], kind='mergesort')]
def make_batches(): batches = list(_make_batches(
batch = [] src, dst, indices, max_tokens, max_sentences, max_positions,
sample_len = 0 ignore_invalid_inputs=True, allow_different_src_lens=True))
ignored = []
for idx in indices:
# - 2 here stems from make_positions() where we offset positions
# by padding_value + 1
if src.sizes[idx] < 2 or dst.sizes[idx] < 2 or \
src.sizes[idx] > max_positions - 2 or \
dst.sizes[idx] > max_positions - 2:
ignored.append(idx)
continue
sample_len = max(sample_len, src.sizes[idx], dst.sizes[idx])
if len(batch) > 0 and (len(batch) + 1) * sample_len > max_tokens:
yield batch
batch = []
sample_len = max(src.sizes[idx], dst.sizes[idx])
batch.append(idx)
if len(batch) > 0:
yield batch
if len(ignored) > 0:
print("Warning! {} samples are either too short or too long "
"and will be ignored, sample ids={}".format(len(ignored), ignored))
batches = list(make_batches()) if not sort_by_source_size:
np.random.shuffle(batches) np.random.shuffle(batches)
if sample: if sample:
offset = (epoch - 1) * sample offset = (epoch - 1) * sample
@@ -327,9 +356,6 @@ def shuffled_batches_by_size(src, dst, max_tokens=None, epoch=1, sample=0, max_p
"batch length is not correct {}".format(len(result)) "batch length is not correct {}".format(len(result))
batches = result batches = result
else:
for _ in range(epoch - 1):
np.random.shuffle(batches)
return batches return batches
......
@@ -17,6 +17,7 @@ class Dictionary(object):
        self.symbols = []
        self.count = []
        self.indices = {}
+        # dictionary indexing starts at 1 for consistency with Lua
        self.add_symbol('<Lua heritage>')
        self.pad_index = self.add_symbol(pad)
        self.eos_index = self.add_symbol(eos)
@@ -44,7 +45,7 @@ class Dictionary(object):
        Can optionally remove BPE symbols or escape <unk> words.
        """
        if torch.is_tensor(tensor) and tensor.dim() == 2:
-            return '\n'.join(self.to_string(t) for t in tensor)
+            return '\n'.join(self.string(t) for t in tensor)
        def token_string(i):
            if i == self.unk():
......
@@ -11,6 +11,8 @@ import os
import struct
import torch
+from fairseq.tokenizer import Tokenizer
def read_longs(f, n):
    a = np.empty(n, dtype=np.int64)
@@ -59,12 +61,15 @@ class IndexedDataset(object):
    def read_data(self, path):
        self.data_file = open(path + '.bin', 'rb', buffering=0)
+    def check_index(self, i):
+        if i < 0 or i >= self.size:
+            raise IndexError('index out of range')
    def __del__(self):
        self.data_file.close()
    def __getitem__(self, i):
-        if i < 0 or i >= self.size:
-            raise IndexError('index out of range')
+        self.check_index(i)
        tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
        a = np.empty(tensor_size, dtype=self.dtype)
        self.data_file.seek(self.data_offsets[i] * self.element_size)
@@ -92,14 +97,49 @@ class IndexedInMemoryDataset(IndexedDataset):
        pass
    def __getitem__(self, i):
-        if i < 0 or i >= self.size:
-            raise IndexError('index out of range')
+        self.check_index(i)
        tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
        a = np.empty(tensor_size, dtype=self.dtype)
        np.copyto(a, self.buffer[self.data_offsets[i]:self.data_offsets[i + 1]])
        return torch.from_numpy(a)
+class IndexedRawTextDataset(IndexedDataset):
+    """Takes a text file as input and binarizes it in memory at instantiation.
+    Original lines are also kept in memory"""
+    def __init__(self, path, dictionary):
+        self.tokens_list = []
+        self.lines = []
+        self.sizes = []
+        self.read_data(path, dictionary)
+        self.size = len(self.tokens_list)
+    def read_data(self, path, dictionary):
+        with open(path, 'r') as f:
+            for line in f:
+                self.lines.append(line.strip('\n'))
+                # +1 for Lua compatibility
+                tokens = Tokenizer.tokenize(line, dictionary, add_if_not_exist=False) + 1
+                self.tokens_list.append(tokens)
+                self.sizes.append(len(tokens))
+        self.sizes = np.array(self.sizes)
+    def __getitem__(self, i):
+        self.check_index(i)
+        return self.tokens_list[i]
+    def get_original_text(self, i):
+        self.check_index(i)
+        return self.lines[i]
+    def __del__(self):
+        pass
+    def __len__(self):
+        return self.size
class IndexedDatasetBuilder(object):
    element_sizes = {
......
@@ -6,10 +6,15 @@
# can be found in the PATENTS file in the same directory.
#
-from . import fconv
+from .fairseq_decoder import FairseqDecoder
+from .fairseq_encoder import FairseqEncoder
+from .fairseq_incremental_decoder import FairseqIncrementalDecoder
+from .fairseq_model import FairseqModel
+from . import fconv, lstm
-__all__ = ['fconv']
+__all__ = ['fconv', 'lstm']
arch_model_map = {}
for model in __all__:
......
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import torch.nn as nn
class FairseqDecoder(nn.Module):
"""Base class for decoders."""
def __init__(self):
super().__init__()
def max_positions(self):
"""Maximum input length supported by the decoder."""
raise NotImplementedError
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import torch.nn as nn
class FairseqEncoder(nn.Module):
"""Base class for encoders."""
def __init__(self):
super().__init__()
def max_positions(self):
"""Maximum input length supported by the encoder."""
raise NotImplementedError
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
from . import FairseqDecoder
class FairseqIncrementalDecoder(FairseqDecoder):
"""Base class for incremental decoders."""
def __init__(self):
super().__init__()
self._is_incremental_eval = False
self._incremental_state = {}
def forward(self, tokens, encoder_out):
if self._is_incremental_eval:
raise NotImplementedError
else:
raise NotImplementedError
def incremental_inference(self):
"""Context manager for incremental inference.
This provides an optimized forward pass for incremental inference
(i.e., it predicts one time step at a time). If the input order changes
between time steps, call reorder_incremental_state to update the
relevant buffers. To generate a fresh sequence, first call
clear_incremental_state.
Usage:
```
with model.decoder.incremental_inference():
for step in range(maxlen):
out, _ = model.decoder(tokens[:, :step], encoder_out)
probs = torch.nn.functional.log_softmax(out[:, -1, :])
```
"""
class IncrementalInference(object):
def __init__(self, decoder):
self.decoder = decoder
def __enter__(self):
self.decoder.incremental_eval(True)
def __exit__(self, *args):
self.decoder.incremental_eval(False)
return IncrementalInference(self)
def incremental_eval(self, mode=True):
"""Sets the decoder and all children in incremental evaluation mode."""
assert self._is_incremental_eval != mode, \
'incremental_eval already set to mode {}'.format(mode)
self._is_incremental_eval = mode
if mode:
self.clear_incremental_state()
def apply_incremental_eval(module):
if module != self and hasattr(module, 'incremental_eval'):
module.incremental_eval(mode)
self.apply(apply_incremental_eval)
def get_incremental_state(self, key):
"""Return cached state or None if not in incremental inference mode."""
if self._is_incremental_eval and key in self._incremental_state:
return self._incremental_state[key]
return None
def set_incremental_state(self, key, value):
"""Cache state needed for incremental inference mode."""
if self._is_incremental_eval:
self._incremental_state[key] = value
return value
def clear_incremental_state(self):
"""Clear all state used for incremental generation.
**For incremental inference only**
This should be called before generating a fresh sequence.
beam_size is required if using BeamableMM.
"""
if self._is_incremental_eval:
self._incremental_state = {}
def apply_clear_incremental_state(module):
if module != self and hasattr(module, 'clear_incremental_state'):
module.clear_incremental_state()
self.apply(apply_clear_incremental_state)
def reorder_incremental_state(self, new_order):
"""Reorder buffered internal state (for incremental generation).
**For incremental inference only**
This should be called when the order of the input has changed from the
previous time step. A typical use case is beam search, where the input
order changes between time steps based on the choice of beams.
"""
if self._is_incremental_eval:
def apply_reorder_incremental_state(module):
if module != self and hasattr(module, 'reorder_incremental_state'):
module.reorder_incremental_state(new_order)
self.apply(apply_reorder_incremental_state)
def set_beam_size(self, beam_size):
"""Sets the beam size in the decoder and all children."""
def apply_set_beam_size(module):
if module != self and hasattr(module, 'set_beam_size'):
module.set_beam_size(beam_size)
self.apply(apply_set_beam_size)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import torch.nn as nn
from . import FairseqDecoder, FairseqEncoder
class FairseqModel(nn.Module):
"""Base class for encoder-decoder models."""
def __init__(self, encoder, decoder):
super().__init__()
self.encoder = encoder
self.decoder = decoder
assert isinstance(self.encoder, FairseqEncoder)
assert isinstance(self.decoder, FairseqDecoder)
self.src_dict = encoder.dictionary
self.dst_dict = decoder.dictionary
assert self.src_dict.pad() == self.dst_dict.pad()
assert self.src_dict.eos() == self.dst_dict.eos()
assert self.src_dict.unk() == self.dst_dict.unk()
self._is_generation_fast = False
def forward(self, src_tokens, input_tokens):
encoder_out = self.encoder(src_tokens)
decoder_out, _ = self.decoder(input_tokens, encoder_out)
return decoder_out.view(-1, decoder_out.size(-1))
def max_encoder_positions(self):
"""Maximum input length supported by the encoder."""
return self.encoder.max_positions()
def max_decoder_positions(self):
"""Maximum output length supported by the decoder."""
return self.decoder.max_positions()
def make_generation_fast_(self, **kwargs):
"""Optimize model for faster generation."""
if self._is_generation_fast:
return # only apply once
self._is_generation_fast = True
# remove weight norm from all modules in the network
def apply_remove_weight_norm(module):
try:
nn.utils.remove_weight_norm(module)
except ValueError: # this module didn't have weight norm
return
self.apply(apply_remove_weight_norm)
def train(mode):
if mode:
raise RuntimeError('cannot train after make_generation_fast')
# this model should no longer be used for training
self.eval()
self.train = train
def apply_make_generation_fast_(module):
if module != self and hasattr(module, 'make_generation_fast_'):
module.make_generation_fast_(**kwargs)
self.apply(apply_make_generation_fast_)
@@ -8,74 +8,42 @@
import math import math
import torch import torch
from torch.autograd import Variable
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from fairseq.modules import BeamableMM, LinearizedConvolution from fairseq.data import LanguagePairDataset
from fairseq.modules import BeamableMM, GradMultiply, LinearizedConvolution
from . import FairseqEncoder, FairseqIncrementalDecoder, FairseqModel
class FConvModel(nn.Module):
def __init__(self, encoder, decoder):
super(FConvModel, self).__init__()
self.encoder = encoder
self.decoder = decoder
self.src_dict = encoder.dictionary
self.dst_dict = decoder.dictionary
assert self.src_dict.pad() == self.dst_dict.pad()
assert self.src_dict.eos() == self.dst_dict.eos()
assert self.src_dict.unk() == self.dst_dict.unk()
self.encoder.num_attention_layers = sum([layer is not None for layer in decoder.attention])
self._is_generation_fast = False
def forward(self, src_tokens, src_positions, input_tokens, input_positions):
encoder_out = self.encoder(src_tokens, src_positions)
decoder_out = self.decoder(input_tokens, input_positions, encoder_out)
return decoder_out.view(-1, decoder_out.size(-1))
def make_generation_fast_(self, use_beamable_mm=False):
"""Optimize model for faster generation.
Optimizations include:
- remove WeightNorm
- (optionally) use BeamableMM in attention layers
The optimized model should not be used again for training. def make_positions(tokens, padding_idx, left_pad, offset=0):
seqlen = tokens.size(1)
if not hasattr(make_positions, 'range'):
make_positions.range = tokens.new()
if make_positions.range.numel() < offset + seqlen:
# offset positions by the padding index
torch.arange(padding_idx + 1, padding_idx + 1 + offset + seqlen,
out=make_positions.range)
mask = tokens.ne(padding_idx)
positions = make_positions.range[offset:offset+seqlen].expand_as(tokens)
if left_pad:
positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1)
return tokens.clone().masked_scatter_(mask, positions[mask])
Note: this can be combined with incremental inference in the Decoder for
even faster generation.
"""
if self._is_generation_fast:
return # only apply once
self._is_generation_fast = True
# remove weight norm from all modules in the network
def remove_weight_norm(m):
try:
nn.utils.remove_weight_norm(m)
except ValueError: # this module didn't have weight norm
return
self.apply(remove_weight_norm)
# use BeamableMM in attention layers
if use_beamable_mm:
self.decoder._use_beamable_mm()
def train(mode): class FConvModel(FairseqModel):
if mode: def __init__(self, encoder, decoder):
raise RuntimeError('cannot train after make_generation_fast') super().__init__(encoder, decoder)
self.encoder.num_attention_layers = sum(layer is not None for layer in decoder.attention)
# this model should no longer be used for training
self.eval()
self.train = train
class Encoder(nn.Module): class FConvEncoder(FairseqEncoder):
"""Convolutional encoder""" """Convolutional encoder"""
def __init__(self, dictionary, embed_dim=512, max_positions=1024, def __init__(self, dictionary, embed_dim=512, max_positions=1024,
convolutions=((512, 3),) * 20, dropout=0.1): convolutions=((512, 3),) * 20, dropout=0.1):
super(Encoder, self).__init__() super().__init__()
self.dictionary = dictionary self.dictionary = dictionary
self.dropout = dropout self.dropout = dropout
self.num_attention_layers = None self.num_attention_layers = None
@@ -99,9 +67,12 @@ class Encoder(nn.Module):
in_channels = out_channels in_channels = out_channels
self.fc2 = Linear(in_channels, embed_dim) self.fc2 = Linear(in_channels, embed_dim)
def forward(self, tokens, positions): def forward(self, src_tokens):
positions = Variable(make_positions(src_tokens.data, self.dictionary.pad(),
left_pad=LanguagePairDataset.LEFT_PAD_SOURCE))
# embed tokens and positions # embed tokens and positions
x = self.embed_tokens(tokens) + self.embed_positions(positions) x = self.embed_tokens(src_tokens) + self.embed_positions(positions)
x = F.dropout(x, p=self.dropout, training=self.training) x = F.dropout(x, p=self.dropout, training=self.training)
input_embedding = x input_embedding = x
@@ -126,17 +97,21 @@ class Encoder(nn.Module):
x = self.fc2(x) x = self.fc2(x)
# scale gradients (this only affects backward, not forward) # scale gradients (this only affects backward, not forward)
x = grad_multiply(x, 1.0 / (2.0 * self.num_attention_layers)) x = GradMultiply.apply(x, 1.0 / (2.0 * self.num_attention_layers))
# add output to input embedding for attention # add output to input embedding for attention
y = (x + input_embedding) * math.sqrt(0.5) y = (x + input_embedding) * math.sqrt(0.5)
return x, y return x, y
def max_positions(self):
"""Maximum input length supported by the encoder."""
return self.embed_positions.num_embeddings - self.dictionary.pad() - 1
class AttentionLayer(nn.Module): class AttentionLayer(nn.Module):
def __init__(self, conv_channels, embed_dim, bmm=None): def __init__(self, conv_channels, embed_dim, bmm=None):
super(AttentionLayer, self).__init__() super().__init__()
# projects from output of convolution to embedding dimension # projects from output of convolution to embedding dimension
self.in_projection = Linear(conv_channels, embed_dim) self.in_projection = Linear(conv_channels, embed_dim)
# projects from embedding dimension to convolution size # projects from embedding dimension to convolution size
@@ -167,13 +142,18 @@ class AttentionLayer(nn.Module):
x = (self.out_projection(x) + residual) * math.sqrt(0.5) x = (self.out_projection(x) + residual) * math.sqrt(0.5)
return x, attn_scores return x, attn_scores
def make_generation_fast_(self, beamable_mm_beam_size=None, **kwargs):
"""Replace torch.bmm with BeamableMM."""
if beamable_mm_beam_size is not None:
self.bmm = BeamableMM(beamable_mm_beam_size)
class Decoder(nn.Module): class FConvDecoder(FairseqIncrementalDecoder):
"""Convolutional decoder""" """Convolutional decoder"""
def __init__(self, dictionary, embed_dim=512, out_embed_dim=256, def __init__(self, dictionary, embed_dim=512, out_embed_dim=256,
max_positions=1024, convolutions=((512, 3),) * 20, max_positions=1024, convolutions=((512, 3),) * 20,
attention=True, dropout=0.1): attention=True, dropout=0.1):
super(Decoder, self).__init__() super().__init__()
self.dictionary = dictionary self.dictionary = dictionary
self.dropout = dropout self.dropout = dropout
@@ -204,25 +184,45 @@ class Decoder(nn.Module):
self.fc2 = Linear(in_channels, out_embed_dim) self.fc2 = Linear(in_channels, out_embed_dim)
self.fc3 = Linear(out_embed_dim, num_embeddings, dropout=dropout) self.fc3 = Linear(out_embed_dim, num_embeddings, dropout=dropout)
self._is_inference_incremental = False def forward(self, input_tokens, encoder_out):
if self._is_incremental_eval:
return self.incremental_forward(input_tokens, encoder_out)
else:
return self.batch_forward(input_tokens, encoder_out)
def batch_forward(self, input_tokens, encoder_out):
"""Forward pass for decoding multiple time steps in batch mode."""
positions = Variable(make_positions(input_tokens.data, self.dictionary.pad(),
left_pad=LanguagePairDataset.LEFT_PAD_TARGET))
return self._forward(input_tokens, positions, encoder_out)
def incremental_forward(self, input_tokens, encoder_out):
"""Forward pass for one time step."""
# positions is the same for every token when decoding a single step
positions = Variable(input_tokens.data.new(1, 1).fill_(
self.dictionary.pad() + input_tokens.size(1)))
# keep only the last token for incremental forward pass
return self._forward(input_tokens[:, -1:], positions, encoder_out)
def _forward(self, input_tokens, positions, encoder_out):
# split and transpose encoder outputs
encoder_a, encoder_b = self._split_encoder_out(encoder_out)
def forward(self, tokens, positions, encoder_out):
# embed tokens and positions # embed tokens and positions
x = self.embed_tokens(tokens) + self.embed_positions(positions) x = self.embed_tokens(input_tokens) + self.embed_positions(positions)
x = F.dropout(x, p=self.dropout, training=self.training) x = F.dropout(x, p=self.dropout, training=self.training)
target_embedding = x target_embedding = x
# project to size of convolution # project to size of convolution
x = self.fc1(x) x = self.fc1(x)
# transpose only once to speed up attention layers
encoder_a, encoder_b = encoder_out
encoder_a = encoder_a.transpose(1, 2).contiguous()
# B x T x C -> T x B x C # B x T x C -> T x B x C
x = x.transpose(0, 1) x = self._transpose_unless_incremental_eval(x)
# temporal convolutions # temporal convolutions
avg_attn_scores = None
num_attn_layers = len(self.attention)
for proj, conv, attention in zip(self.projections, self.convolutions, self.attention): for proj, conv, attention in zip(self.projections, self.convolutions, self.attention):
residual = x if proj is None else proj(x) residual = x if proj is None else proj(x)
@@ -233,172 +233,58 @@ class Decoder(nn.Module):
# attention # attention
if attention is not None: if attention is not None:
x = x.transpose(1, 0) x = self._transpose_unless_incremental_eval(x)
x, _ = attention(x, target_embedding, (encoder_a, encoder_b))
x = x.transpose(1, 0)
# residual
x = (x + residual) * math.sqrt(0.5)
# T x B x C -> B x T x C
x = x.transpose(1, 0)
# project back to size of vocabulary
x = self.fc2(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.fc3(x)
return x
def context_size(self):
"""Maximum number of input elements each output element depends on"""
context = 1
for conv in self.convolutions:
context += conv.kernel_size[0] - 1
return context
def max_positions(self):
"""Returns maximum size of positions embeddings supported by this decoder"""
return self.embed_positions.num_embeddings
def incremental_inference(self, beam_size=None):
"""Context manager for incremental inference.
This provides an optimized forward pass for incremental inference
(i.e., it predicts one time step at a time). If the input order changes
between time steps, call model.decoder.reorder_incremental_state to
update the relevant buffers. To generate a fresh sequence, first call
model.decoder.start_fresh_sequence.
Usage:
```
with model.decoder.incremental_inference():
for step in range(maxlen):
out = model.decoder(tokens[:, :step], positions[:, :step],
encoder_out)
probs = F.log_softmax(out[:, -1, :])
```
"""
class IncrementalInference(object):
def __init__(self, decoder, beam_size):
self.decoder = decoder
self.beam_size = beam_size
def __enter__(self):
self.decoder._start_incremental_inference(self.beam_size)
def __exit__(self, *args):
self.decoder._stop_incremental_inference()
return IncrementalInference(self, beam_size)
def _start_incremental_inference(self, beam_size):
assert not self._is_inference_incremental, \
'already performing incremental inference'
self._is_inference_incremental = True
# save original forward
self._orig_forward = self.forward
# switch to incremental forward
self.forward = self._incremental_forward
# start a fresh sequence
self.start_fresh_sequence(beam_size)
def _stop_incremental_inference(self):
# restore original forward
self.forward = self._orig_forward
self._is_inference_incremental = False
def _incremental_forward(self, tokens, positions, encoder_out):
assert self._is_inference_incremental
# setup initial state
if self.prev_state is None:
# transpose encoder output once to speed up attention layers
encoder_a, encoder_b = encoder_out
encoder_a = encoder_a.transpose(1, 2).contiguous()
self.prev_state = {
'encoder_out': (encoder_a, encoder_b),
}
# load previous state
encoder_a, encoder_b = self.prev_state['encoder_out']
# keep only the last token for incremental forward pass
tokens = tokens[:, -1:]
positions = positions[:, -1:]
# embed tokens and positions
x = self.embed_tokens(tokens) + self.embed_positions(positions)
target_embedding = x
# project to size of convolution
x = self.fc1(x)
# temporal convolutions
avg_attn_scores = None
num_attn_layers = len(self.attention)
for proj, conv, attention in zip(self.projections, self.convolutions, self.attention):
residual = x if proj is None else proj(x)
x = conv.incremental_forward(x)
x = F.glu(x)
# attention
if attention is not None:
x, attn_scores = attention(x, target_embedding, (encoder_a, encoder_b)) x, attn_scores = attention(x, target_embedding, (encoder_a, encoder_b))
attn_scores = attn_scores / num_attn_layers attn_scores = attn_scores / num_attn_layers
if avg_attn_scores is None: if avg_attn_scores is None:
avg_attn_scores = attn_scores avg_attn_scores = attn_scores
else: else:
avg_attn_scores += attn_scores avg_attn_scores.add_(attn_scores)
x = self._transpose_unless_incremental_eval(x)
# residual # residual
x = (x + residual) * math.sqrt(0.5) x = (x + residual) * math.sqrt(0.5)
# T x B x C -> B x T x C
x = self._transpose_unless_incremental_eval(x)
# project back to size of vocabulary # project back to size of vocabulary
x = self.fc2(x) x = self.fc2(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.fc3(x) x = self.fc3(x)
return x, avg_attn_scores return x, avg_attn_scores
def start_fresh_sequence(self, beam_size=None): def reorder_incremental_state(self, new_order):
"""Clear all state used for incremental generation. """Reorder buffered internal state (for incremental generation)."""
super().reorder_incremental_state(new_order)
**For incremental inference only** def max_positions(self):
"""Maximum output length supported by the decoder."""
return self.embed_positions.num_embeddings - self.dictionary.pad() - 1
This should be called before generating a fresh sequence. def _split_encoder_out(self, encoder_out):
beam_size is required if using BeamableMM. """Split and transpose encoder outputs.
"""
if self._is_inference_incremental:
self.prev_state = None
for conv in self.convolutions:
conv.clear_buffer()
for attn in self.attention:
if isinstance(attn.bmm, BeamableMM):
attn.bmm.set_beam_size(beam_size)
def reorder_incremental_state(self, new_order): This is cached when doing incremental inference.
"""Reorder buffered internal state (for incremental generation). """
cached_result = self.get_incremental_state('encoder_out')
if cached_result:
return cached_result
**For incremental inference only** # transpose only once to speed up attention layers
encoder_a, encoder_b = encoder_out
encoder_a = encoder_a.transpose(1, 2).contiguous()
result = (encoder_a, encoder_b)
This should be called when the order of the input has changed from the return self.set_incremental_state('encoder_out', result)
previous time step. A typical use case is beam search, where the input
order changes between time steps based on the choice of beams.
"""
if self._is_inference_incremental:
for conv in self.convolutions:
conv.reorder_buffer(new_order)
def _use_beamable_mm(self): def _transpose_unless_incremental_eval(self, x):
"""Replace torch.bmm with BeamableMM in attention layers.""" if self._is_incremental_eval:
beamable_mm = BeamableMM() return x
for attn in self.attention: return x.transpose(0, 1)
attn.bmm = beamable_mm
def Embedding(num_embeddings, embedding_dim, padding_idx): def Embedding(num_embeddings, embedding_dim, padding_idx):
@@ -434,23 +320,6 @@ def ConvTBC(in_channels, out_channels, kernel_size, dropout=0, **kwargs):
return nn.utils.weight_norm(m, dim=2) return nn.utils.weight_norm(m, dim=2)
def grad_multiply(x, scale):
return GradMultiply.apply(x, scale)
class GradMultiply(torch.autograd.Function):
@staticmethod
def forward(ctx, x, scale):
ctx.scale = scale
res = x.new(x)
ctx.mark_shared_storage((x, res))
return res
@staticmethod
def backward(ctx, grad):
return grad * ctx.scale, None
def get_archs(): def get_archs():
return [ return [
'fconv', 'fconv_iwslt_de_en', 'fconv_wmt_en_ro', 'fconv_wmt_en_de', 'fconv_wmt_en_fr', 'fconv', 'fconv_iwslt_de_en', 'fconv_wmt_en_ro', 'fconv_wmt_en_de', 'fconv_wmt_en_fr',
@@ -518,20 +387,20 @@ def parse_arch(args):
def build_model(args, src_dict, dst_dict): def build_model(args, src_dict, dst_dict):
encoder = Encoder( encoder = FConvEncoder(
src_dict, src_dict,
embed_dim=args.encoder_embed_dim, embed_dim=args.encoder_embed_dim,
convolutions=eval(args.encoder_layers), convolutions=eval(args.encoder_layers),
dropout=args.dropout, dropout=args.dropout,
max_positions=args.max_positions, max_positions=args.max_source_positions,
) )
decoder = Decoder( decoder = FConvDecoder(
dst_dict, dst_dict,
embed_dim=args.decoder_embed_dim, embed_dim=args.decoder_embed_dim,
convolutions=eval(args.decoder_layers), convolutions=eval(args.decoder_layers),
out_embed_dim=args.decoder_out_embed_dim, out_embed_dim=args.decoder_out_embed_dim,
attention=eval(args.decoder_attention), attention=eval(args.decoder_attention),
dropout=args.dropout, dropout=args.dropout,
max_positions=args.max_positions, max_positions=args.max_target_positions,
) )
return FConvModel(encoder, decoder) return FConvModel(encoder, decoder)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from . import FairseqEncoder, FairseqIncrementalDecoder, FairseqModel
class LSTMModel(FairseqModel):
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
class LSTMEncoder(FairseqEncoder):
"""LSTM encoder."""
def __init__(self, dictionary, embed_dim=512, num_layers=1, dropout_in=0.1,
dropout_out=0.1):
super().__init__()
self.dictionary = dictionary
self.dropout_in = dropout_in
self.dropout_out = dropout_out
num_embeddings = len(dictionary)
padding_idx = dictionary.pad()
self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
self.layers = nn.ModuleList([
LSTMCell(embed_dim, embed_dim)
for layer in range(num_layers)
])
def forward(self, src_tokens):
bsz, seqlen = src_tokens.size()
num_layers = len(self.layers)
# embed tokens
x = self.embed_tokens(src_tokens)
x = F.dropout(x, p=self.dropout_in, training=self.training)
embed_dim = x.size(2)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
final_hiddens, final_cells = [], []
outs = [x[j] for j in range(seqlen)]
for i, rnn in enumerate(self.layers):
hidden = Variable(x.data.new(bsz, embed_dim).zero_())
cell = Variable(x.data.new(bsz, embed_dim).zero_())
for j in range(seqlen):
# recurrent cell
hidden, cell = rnn(outs[j], (hidden, cell))
# store the most recent hidden state in outs, either to be used
# as the input for the next layer, or as the final output
outs[j] = F.dropout(hidden, p=self.dropout_out, training=self.training)
# save the final hidden and cell states for every layer
final_hiddens.append(hidden)
final_cells.append(cell)
# collect outputs across time steps
x = torch.cat(outs, dim=0).view(seqlen, bsz, embed_dim)
final_hiddens = torch.cat(final_hiddens, dim=0).view(num_layers, bsz, embed_dim)
final_cells = torch.cat(final_cells, dim=0).view(num_layers, bsz, embed_dim)
return x, final_hiddens, final_cells
def max_positions(self):
"""Maximum input length supported by the encoder."""
return int(1e5) # an arbitrary large number
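For readers skimming the new file, here is a self-contained sketch of the unrolled-`LSTMCell` pattern the encoder uses, with assumed toy sizes and current PyTorch conventions (no `Variable` wrapper); it does not use the classes above:

```
import torch
import torch.nn as nn

bsz, seqlen, embed_dim, num_layers = 2, 5, 8, 2   # toy sizes (assumed)
tokens = torch.randint(0, 100, (bsz, seqlen))
embed = nn.Embedding(100, embed_dim)
layers = nn.ModuleList([nn.LSTMCell(embed_dim, embed_dim) for _ in range(num_layers)])

x = embed(tokens).transpose(0, 1)        # B x T x C -> T x B x C
outs = [x[j] for j in range(seqlen)]
for rnn in layers:
    h = x.new_zeros(bsz, embed_dim)
    c = x.new_zeros(bsz, embed_dim)
    for j in range(seqlen):
        h, c = rnn(outs[j], (h, c))
        outs[j] = h                      # becomes the input to the next layer
x = torch.stack(outs)                    # T x B x C
print(x.shape)                           # torch.Size([5, 2, 8])
```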
class AttentionLayer(nn.Module):
def __init__(self, input_embed_dim, output_embed_dim):
super().__init__()
self.input_proj = Linear(input_embed_dim, output_embed_dim, bias=False)
self.output_proj = Linear(2*output_embed_dim, output_embed_dim, bias=False)
def forward(self, input, source_hids):
# input: bsz x input_embed_dim
# source_hids: srclen x bsz x output_embed_dim
# x: bsz x output_embed_dim
x = self.input_proj(input)
# compute attention
attn_scores = (source_hids * x.unsqueeze(0)).sum(dim=2)
attn_scores = F.softmax(attn_scores.t(), dim=1).t() # srclen x bsz
# sum weighted sources
x = (attn_scores.unsqueeze(2) * source_hids).sum(dim=0)
x = F.tanh(self.output_proj(torch.cat((x, input), dim=1)))
return x, attn_scores
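A shape-only sketch of the dot-product attention computed in `AttentionLayer.forward`, with assumed toy sizes (the projections are omitted):

```
import torch
import torch.nn.functional as F

srclen, bsz, dim = 7, 3, 4
query = torch.randn(bsz, dim)                 # projected decoder state
source_hids = torch.randn(srclen, bsz, dim)   # encoder states

attn_scores = (source_hids * query.unsqueeze(0)).sum(dim=2)    # srclen x bsz
attn_scores = F.softmax(attn_scores.t(), dim=1).t()            # normalize over srclen
context = (attn_scores.unsqueeze(2) * source_hids).sum(dim=0)  # bsz x dim
print(attn_scores.shape, context.shape)   # torch.Size([7, 3]) torch.Size([3, 4])
```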
class LSTMDecoder(FairseqIncrementalDecoder):
"""LSTM decoder."""
def __init__(self, dictionary, encoder_embed_dim=512, embed_dim=512,
out_embed_dim=512, num_layers=1, dropout_in=0.1,
dropout_out=0.1, attention=True):
super().__init__()
self.dictionary = dictionary
self.dropout_in = dropout_in
self.dropout_out = dropout_out
num_embeddings = len(dictionary)
padding_idx = dictionary.pad()
self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
self.layers = nn.ModuleList([
LSTMCell(encoder_embed_dim + embed_dim if layer == 0 else embed_dim, embed_dim)
for layer in range(num_layers)
])
self.attention = AttentionLayer(encoder_embed_dim, embed_dim)
self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out)
def forward(self, input_tokens, encoder_out):
bsz, seqlen = input_tokens.size()
num_layers = len(self.layers)
# get outputs from encoder
encoder_outs, _, _ = encoder_out
srclen = encoder_outs.size(0)
# embed tokens
x = self.embed_tokens(input_tokens)
x = F.dropout(x, p=self.dropout_in, training=self.training)
embed_dim = x.size(2)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
# initialize previous states (or get from cache during incremental generation)
prev_hiddens = self.get_incremental_state('prev_hiddens')
if not prev_hiddens:
# first time step, initialize previous states
prev_hiddens, prev_cells = self._init_prev_states(input_tokens, encoder_out)
input_feed = Variable(x.data.new(bsz, embed_dim).zero_())
else:
# previous states are cached
prev_cells = self.get_incremental_state('prev_cells')
input_feed = self.get_incremental_state('input_feed')
attn_scores = Variable(x.data.new(srclen, seqlen, bsz).zero_())
outs = []
for j in range(seqlen):
# input feeding: concatenate context vector from previous time step
input = torch.cat((x[j, :, :], input_feed), dim=1)
for i, rnn in enumerate(self.layers):
# recurrent cell
hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i]))
# hidden state becomes the input to the next layer
input = F.dropout(hidden, p=self.dropout_out, training=self.training)
# save state for next time step
prev_hiddens[i] = hidden
prev_cells[i] = cell
# apply attention using the last layer's hidden state
out, attn_scores[:, j, :] = self.attention(hidden, encoder_outs)
out = F.dropout(out, p=self.dropout_out, training=self.training)
# input feeding
input_feed = out
# save final output
outs.append(out)
# cache previous states (no-op except during incremental generation)
self.set_incremental_state('prev_hiddens', prev_hiddens)
self.set_incremental_state('prev_cells', prev_cells)
self.set_incremental_state('input_feed', input_feed)
# collect outputs across time steps
x = torch.cat(outs, dim=0).view(seqlen, bsz, embed_dim)
# T x B x C -> B x T x C
x = x.transpose(1, 0)
# srclen x tgtlen x bsz -> bsz x tgtlen x srclen
attn_scores = attn_scores.transpose(0, 2)
# project back to size of vocabulary
x = self.fc_out(x)
return x, attn_scores
def reorder_incremental_state(self, new_order):
"""Reorder buffered internal state (for incremental generation)."""
super().reorder_incremental_state(new_order)
new_order = Variable(new_order)
def reorder_state(key):
old = self.get_incremental_state(key)
if isinstance(old, list):
new = [old_i.index_select(0, new_order) for old_i in old]
else:
new = old.index_select(0, new_order)
self.set_incremental_state(key, new)
reorder_state('prev_hiddens')
reorder_state('prev_cells')
reorder_state('input_feed')
def max_positions(self):
"""Maximum output length supported by the decoder."""
return int(1e5) # an arbitrary large number
def _init_prev_states(self, input_tokens, encoder_out):
_, encoder_hiddens, encoder_cells = encoder_out
bsz = input_tokens.size(0)
num_layers = len(self.layers)
embed_dim = encoder_hiddens.size(2)
prev_hiddens = [encoder_hiddens[i] for i in range(num_layers)]
prev_cells = [encoder_cells[i] for i in range(num_layers)]
return prev_hiddens, prev_cells
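A minimal sketch of the input-feeding step inside the decoder loop above, with assumed toy sizes: the attention context from the previous time step is concatenated with the current target embedding before it enters the first `LSTMCell`:

```
import torch
import torch.nn as nn

bsz, embed_dim, ctx_dim = 2, 8, 8
cell = nn.LSTMCell(embed_dim + ctx_dim, embed_dim)
token_embedding = torch.randn(bsz, embed_dim)
prev_context = torch.zeros(bsz, ctx_dim)    # zeros at the first step
h = torch.zeros(bsz, embed_dim)
c = torch.zeros(bsz, embed_dim)

rnn_input = torch.cat((token_embedding, prev_context), dim=1)
h, c = cell(rnn_input, (h, c))
print(h.shape)  # torch.Size([2, 8])
```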
def Embedding(num_embeddings, embedding_dim, padding_idx):
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
m.weight.data.uniform_(-0.1, 0.1)
return m
def LSTMCell(input_dim, hidden_dim, **kwargs):
m = nn.LSTMCell(input_dim, hidden_dim, **kwargs)
for name, param in m.named_parameters():
if 'weight' in name or 'bias' in name:
param.data.uniform_(-0.1, 0.1)
return m
def Linear(in_features, out_features, bias=True, dropout=0):
"""Weight-normalized Linear layer (input: N x T x C)"""
m = nn.Linear(in_features, out_features, bias=bias)
m.weight.data.uniform_(-0.1, 0.1)
if bias:
m.bias.data.uniform_(-0.1, 0.1)
return m
def get_archs():
return [
'lstm', 'lstm_wiseman_iwslt_de_en', 'lstm_luong_wmt_en_de',
]
def _check_arch(args):
"""Check that the specified architecture is valid and not ambiguous."""
if args.arch not in get_archs():
raise ValueError('Unknown LSTM model architecture: {}'.format(args.arch))
if args.arch != 'lstm':
# check that architecture is not ambiguous
for a in ['encoder_embed_dim', 'encoder_layers', 'decoder_embed_dim', 'decoder_layers',
'decoder_out_embed_dim']:
if hasattr(args, a):
raise ValueError('--{} cannot be combined with --arch={}'.format(a, args.arch))
def parse_arch(args):
_check_arch(args)
if args.arch == 'lstm_wiseman_iwslt_de_en':
args.encoder_embed_dim = 256
args.encoder_layers = 1
args.encoder_dropout_in = 0
args.encoder_dropout_out = 0
args.decoder_embed_dim = 256
args.decoder_layers = 1
args.decoder_out_embed_dim = 256
args.decoder_attention = True
args.decoder_dropout_in = 0
elif args.arch == 'lstm_luong_wmt_en_de':
args.encoder_embed_dim = 1000
args.encoder_layers = 4
args.encoder_dropout_out = 0
args.decoder_embed_dim = 1000
args.decoder_layers = 4
args.decoder_out_embed_dim = 1000
args.decoder_attention = True
args.decoder_dropout_out = 0
else:
assert args.arch == 'lstm'
# default architecture
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
args.encoder_layers = getattr(args, 'encoder_layers', 1)
args.encoder_dropout_in = getattr(args, 'encoder_dropout_in', args.dropout)
args.encoder_dropout_out = getattr(args, 'encoder_dropout_out', args.dropout)
args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
args.decoder_layers = getattr(args, 'decoder_layers', 1)
args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 512)
args.decoder_attention = getattr(args, 'decoder_attention', True)
args.decoder_dropout_in = getattr(args, 'decoder_dropout_in', args.dropout)
args.decoder_dropout_out = getattr(args, 'decoder_dropout_out', args.dropout)
return args
def build_model(args, src_dict, dst_dict):
encoder = LSTMEncoder(
src_dict,
embed_dim=args.encoder_embed_dim,
num_layers=int(args.encoder_layers),
dropout_in=args.encoder_dropout_in,
dropout_out=args.encoder_dropout_out,
)
decoder = LSTMDecoder(
dst_dict,
encoder_embed_dim=args.encoder_embed_dim,
embed_dim=args.decoder_embed_dim,
out_embed_dim=args.decoder_out_embed_dim,
num_layers=int(args.decoder_layers),
attention=bool(args.decoder_attention),
dropout_in=args.decoder_dropout_in,
dropout_out=args.decoder_dropout_out,
)
return LSTMModel(encoder, decoder)
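A hypothetical usage sketch for the helpers above, assuming this module is importable: only `--arch` and `--dropout` are set, and `parse_arch` fills in the remaining defaults before `build_model` would be called with the source and target dictionaries.

```
from argparse import Namespace

args = Namespace(arch='lstm', dropout=0.1)
args = parse_arch(args)
print(args.encoder_embed_dim, args.decoder_layers)  # 512 1
```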
@@ -8,8 +8,12 @@
from .beamable_mm import BeamableMM
from .conv_tbc import ConvTBC
+from .grad_multiply import GradMultiply
from .linearized_convolution import LinearizedConvolution

__all__ = [
-    'BeamableMM', 'LinearizedConvolution', 'ConvTBC',
+    'BeamableMM',
+    'ConvTBC',
+    'GradMultiply',
+    'LinearizedConvolution',
]
@@ -18,9 +18,9 @@ class BeamableMM(nn.Module):
    inference by replacing the inputs {(bsz x 1 x nhu), (bsz x sz2 x nhu)}
    with smaller inputs {(bsz/beam x beam x nhu), (bsz/beam x sz2 x nhu)}.
    """
-    def __init__(self):
+    def __init__(self, beam_size=None):
        super(BeamableMM, self).__init__()
-        self.beam_size = None
+        self.beam_size = beam_size

    def forward(self, input1, input2):
        if (
...
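A self-contained sketch (toy sizes, current PyTorch ops, not part of the commit) of the reshaping trick `BeamableMM` exploits when the encoder-side input is identical across each beam:

```
import torch

B, beam, sz2, nhu = 2, 5, 7, 4
input1 = torch.randn(B * beam, 1, nhu)       # one decoder step per hypothesis
enc = torch.randn(B, sz2, nhu)               # encoder states, one per sentence
input2 = enc.repeat_interleave(beam, dim=0)  # replicated across the beam

naive = torch.bmm(input1, input2.transpose(1, 2))                   # (B*beam) x 1 x sz2
folded = torch.bmm(input1.view(B, beam, nhu), enc.transpose(1, 2))  # B x beam x sz2
folded = folded.view(B * beam, 1, sz2)
print(torch.allclose(naive, folded))  # True
```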
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import torch
class GradMultiply(torch.autograd.Function):
@staticmethod
def forward(ctx, x, scale):
ctx.scale = scale
res = x.new(x)
ctx.mark_shared_storage((x, res))
return res
@staticmethod
def backward(ctx, grad):
return grad * ctx.scale, None
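The effect of this function is to scale gradients flowing through a tensor without changing its forward value. Below is a hypothetical stand-in (`GradScale`, written against current PyTorch instead of the `mark_shared_storage` API above) that demonstrates the behaviour:

```
import torch

class GradScale(torch.autograd.Function):  # hypothetical stand-in for GradMultiply
    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        return x.view_as(x)                # identity in the forward pass

    @staticmethod
    def backward(ctx, grad):
        return grad * ctx.scale, None

x = torch.ones(3, requires_grad=True)
y = GradScale.apply(x, 0.5)
y.sum().backward()
print(x.grad)                              # tensor([0.5000, 0.5000, 0.5000])
```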
@@ -14,38 +14,42 @@ from .conv_tbc import ConvTBC
class LinearizedConvolution(ConvTBC):
    """An optimized version of nn.Conv1d.

-    This module replaces convolutions with linear layers as appropriate
-    and supports optimizations for incremental inference.
+    At training time, this module uses ConvTBC, which is an optimized version
+    of Conv1d. At inference time, it optimizes incremental generation (i.e.,
+    one time step at a time) by replacing the convolutions with linear layers.
    """
    def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
        super().__init__(in_channels, out_channels, kernel_size, **kwargs)
-        self.clear_buffer()
+        self._is_incremental_eval = False
        self._linearized_weight = None
        self.register_backward_hook(self._clear_linearized_weight)

    def remove_future_timesteps(self, x):
        """Remove future time steps created by padding."""
-        if self.kernel_size[0] > 1 and self.padding[0] > 0:
+        if not self._is_incremental_eval and self.kernel_size[0] > 1 and self.padding[0] > 0:
            x = x[:-self.padding[0], :, :]
        return x

+    def incremental_eval(self, mode=True):
+        self._is_incremental_eval = mode
+        if mode:
+            self.clear_incremental_state()
+
+    def forward(self, input):
+        if self._is_incremental_eval:
+            return self.incremental_forward(input)
+        else:
+            return super().forward(input)
+
    def incremental_forward(self, input):
        """Forward convolution one time step at a time.

-        This function maintains an internal state to buffer signal and
-        accepts a single frame as input. If the input order changes
-        between time steps, call reorder_buffer. To apply to fresh
-        inputs, call clear_buffer.
+        This function maintains an internal state to buffer signal and accepts
+        a single frame as input. If the input order changes between time steps,
+        call reorder_incremental_state. To apply to fresh inputs, call
+        clear_incremental_state.
        """
-        if self.training:
-            raise RuntimeError('LinearizedConvolution only supports inference')
-
-        # run forward pre hooks (e.g., weight norm)
-        for hook in self._forward_pre_hooks.values():
-            hook(self, input)
-
        # reshape weight
        weight = self._get_linearized_weight()
        kw = self.kernel_size[0]
@@ -65,10 +69,10 @@ class LinearizedConvolution(ConvTBC):
        output = F.linear(input.view(bsz, -1), weight, self.bias)
        return output.view(bsz, 1, -1)

-    def clear_buffer(self):
+    def clear_incremental_state(self):
        self.input_buffer = None

-    def reorder_buffer(self, new_order):
+    def reorder_incremental_state(self, new_order):
        if self.input_buffer is not None:
            self.input_buffer = self.input_buffer.index_select(0, new_order)
...
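A self-contained sketch of the idea behind `LinearizedConvolution`, with assumed toy sizes and no fairseq classes: a causal `Conv1d` over the whole sequence matches a per-step `F.linear` applied to a rolling buffer of the last `kw` input frames.

```
import torch
import torch.nn.functional as F

bsz, seqlen, in_ch, out_ch, kw = 2, 6, 3, 5, 3
x = torch.randn(bsz, in_ch, seqlen)
weight = torch.randn(out_ch, in_ch, kw)
bias = torch.randn(out_ch)

# full (training-style) forward: left-pad so the convolution is causal
full = F.conv1d(F.pad(x, (kw - 1, 0)), weight, bias)            # bsz x out_ch x seqlen

# incremental (inference-style) forward: one time step at a time
lin_weight = weight.permute(0, 2, 1).reshape(out_ch, kw * in_ch)
buffer = torch.zeros(bsz, kw, in_ch)                            # oldest frame first
steps = []
for t in range(seqlen):
    buffer = torch.cat([buffer[:, 1:], x[:, :, t].unsqueeze(1)], dim=1)
    steps.append(F.linear(buffer.reshape(bsz, -1), lin_weight, bias))
incremental = torch.stack(steps, dim=2)                         # bsz x out_ch x seqlen
print(torch.allclose(full, incremental, atol=1e-5))             # True
```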
@@ -69,7 +69,6 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
        # initialize optimizer
        self.optimizer = self._build_optimizer()
-        self.flat_grads = None
        self.loss = None

        # initialize LR scheduler
@@ -104,8 +103,9 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
            lr_scheduler = LambdaLR(self.optimizer, anneal)
            lr_scheduler.best = None
        else:
-            # decay the LR by 0.1 every time the validation loss plateaus
-            lr_scheduler = ReduceLROnPlateau(self.optimizer, patience=0)
+            # decay the LR by a factor every time the validation loss plateaus
+            lr_scheduler = ReduceLROnPlateau(self.optimizer, patience=0,
+                                             factor=self.args.lrshrink)
        return lr_scheduler

    def get_model(self):
@@ -199,19 +199,21 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
        # backward pass
        self.loss.backward()

-        # flatten grads into a contiguous block of memory
-        if self.flat_grads is None:
-            self.flat_grads = self._flatten_grads_(self.model)
+        # get model parameters as a flattened (contiguous) tensor
+        flat_grads = self._flat_model_grads()

        # all-reduce grads
-        nccl.all_reduce(self.flat_grads)
+        nccl.all_reduce(flat_grads)

        # normalize grads
        if grad_denom != 0:
-            self.flat_grads.div_(grad_denom)
+            flat_grads.div_(grad_denom)

        # clip grads
-        grad_norm = self._clip_grads_(self.flat_grads, self.args.clip_norm)
+        grad_norm = self._clip_grads_(flat_grads, self.args.clip_norm)
+
+        # copy reduced grads back
+        self._set_model_grads_(flat_grads)

        # take an optimization step
        self.optimizer.step()
@@ -221,20 +223,34 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
        return grad_norm

-    def _flatten_grads_(self, model):
-        num_params = sum(p.data.numel() for p in model.parameters())
-        flat_grads = next(model.parameters()).data.new(num_params)
+    def _model_grads(self):
+        return [p.grad for p in self.model.parameters() if p.requires_grad]
+
+    def _flat_model_grads(self):
+        grads = self._model_grads()
+        if not hasattr(self, '_flat_grads'):
+            num_params = sum(g.data.numel() for g in grads)
+            self._flat_grads = grads[0].data.new(num_params)
+        offset = 0
+        for grad in grads:
+            grad = grad.data.view(-1)
+            numel = grad.numel()
+            self._flat_grads[offset:offset+numel].copy_(grad)
+            offset += numel
+        return self._flat_grads
+
+    def _set_model_grads_(self, flat_grads):
+        grads = self._model_grads()
        offset = 0
-        for p in model.parameters():
-            grad = p.grad.data
-            numel, sz = grad.numel(), grad.size()
-            flat_grads[offset:offset+numel] = grad.view(-1)
-            grad.set_(flat_grads[offset:offset+numel])
-            grad.resize_(sz)  # preserve original shape
+        for grad in grads:
+            grad = grad.data.view(-1)
+            numel = grad.numel()
+            grad.copy_(flat_grads[offset:offset+numel])
            offset += numel
-        return flat_grads
+        assert offset == flat_grads.numel()

    def _clip_grads_(self, flat_grads, clipv):
+        """nn.utils.clip_grad_norm for flattened (contiguous) tensors."""
        norm = flat_grads.norm()
        if clipv > 0 and norm > clipv:
            coef = max(norm, 1e-6) / clipv
...
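A self-contained sketch (toy model, single process, so without the NCCL all-reduce) of the flatten, clip, and copy-back pattern used in the hunks above:

```
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()

grads = [p.grad for p in model.parameters() if p.requires_grad]
flat = torch.cat([g.reshape(-1) for g in grads])     # flatten into one tensor

clip_norm = 0.1
norm = flat.norm()
if norm > clip_norm:                                  # clip by global norm
    flat.mul_(clip_norm / max(float(norm), 1e-6))

offset = 0                                            # copy clipped grads back
for g in grads:
    numel = g.numel()
    g.copy_(flat[offset:offset + numel].view_as(g))
    offset += numel
```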