Unverified Commit 388c520b authored by Myle Ott, committed by GitHub

0.4.0 -> 0.5.0

Changelog:
- 97b58b46: add Transformer model from Vaswani et al. (2017)
- b2374e52: faster Transformer inference with improved caching
- 2d27ae08: simulate large mini-batch training with delayed updates (`--update-freq`)
- 7ee1d284: add FP16 training support (`--fp16`)
- 2a84f46b: faster inference by removing completed sentences from the batch
- 663fd806: batched interactive generation
- 4c2ef2de: add language modeling / gated convolutional model from Dauphin et al. (2017)
- b59815bc: add Hierarchical Neural Story Generation model from Fan et al. (2018)
- ff68a9ef: add FairseqTask to modularize task definitions (e.g., translation, language modeling)
parents ec0031df 5383b5db
# Introduction
Fairseq(-py) is a sequence modeling toolkit that allows researchers and developers to train custom models for translation, summarization, language modeling and other text generation tasks. It provides reference implementations of various sequence-to-sequence models, including:
- **Convolutional Neural Networks (CNN)**
- [Dauphin et al. (2017): Language Modeling with Gated Convolutional Networks](https://arxiv.org/abs/1612.08083)
- [Gehring et al. (2017): Convolutional Sequence to Sequence Learning](https://arxiv.org/abs/1705.03122)
- **_New_** [Edunov et al. (2018): Classical Structured Prediction Losses for Sequence to Sequence Learning](https://arxiv.org/abs/1711.04956)
- **_New_** [Fan et al. (2018): Hierarchical Neural Story Generation](https://arxiv.org/abs/1805.04833)
- **Long Short-Term Memory (LSTM) networks**
- [Luong et al. (2015): Effective Approaches to Attention-based Neural Machine Translation](https://arxiv.org/abs/1508.04025)
- [Wiseman and Rush (2016): Sequence-to-Sequence Learning as Beam-Search Optimization](https://arxiv.org/abs/1606.02960)
- **Transformer (self-attention) networks**
- [Vaswani et al. (2017): Attention Is All You Need](https://arxiv.org/abs/1706.03762)
- **_New_** [Ott et al. (2018): Scaling Neural Machine Translation](https://arxiv.org/abs/1806.00187)
Fairseq features:
- multi-GPU (distributed) training on one machine or across multiple machines
- fast beam search generation on both CPU and GPU
- large mini-batch training (even on a single GPU) via delayed updates
- fast half-precision floating point (FP16) training
We also provide [pre-trained models](#pre-trained-models) for several benchmark translation datasets.
![Model](fairseq.gif)
...@@ -38,6 +49,7 @@ The following command-line tools are provided:
* `python generate.py`: Translate pre-processed data with a trained model
* `python interactive.py`: Translate raw text with a trained model
* `python score.py`: BLEU scoring of generated translations against reference translations
* `python eval_lm.py`: Language model evaluation
## Evaluating Pre-trained Models
First, download a pre-trained model along with its vocabularies:
...@@ -71,16 +83,19 @@ This generation script produces four types of outputs: a line prefixed with *S*
Check [below](#pre-trained-models) for a full list of pre-trained models available.
## Training a New Model
The following tutorial is for machine translation.
For an example of how to use Fairseq for other tasks, such as [language modeling](examples/language_model/README.md), please see the `examples/` directory.
### Data Pre-processing
Fairseq contains example pre-processing scripts for several translation datasets: IWSLT 2014 (German-English), WMT 2014 (English-French) and WMT 2014 (English-German).
To pre-process and binarize the IWSLT dataset:
```
$ cd examples/translation/
$ bash prepare-iwslt14.sh
$ cd ../..
$ TEXT=data/iwslt14.tokenized.de-en
$ python preprocess.py --source-lang de --target-lang en \
--trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
...@@ -125,15 +140,25 @@ BPE continuation markers can be removed with the `--remove-bpe` flag.
# Pre-trained Models
We provide the following pre-trained models and pre-processed, binarized test sets:
### Translation
Description | Dataset | Model | Test set(s)
---|---|---|---
Convolutional <br> ([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/models/wmt14.v2.en-fr.fconv-py.tar.bz2) | newstest2014: <br> [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/data/wmt14.v2.en-fr.newstest2014.tar.bz2) <br> newstest2012/2013: <br> [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/data/wmt14.v2.en-fr.ntst1213.tar.bz2)
Convolutional <br> ([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT14 English-German](https://nlp.stanford.edu/projects/nmt) | [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/models/wmt14.v2.en-de.fconv-py.tar.bz2) | newstest2014: <br> [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/data/wmt14.v2.en-de.newstest2014.tar.bz2)
Transformer <br> ([Ott et al., 2018](https://arxiv.org/abs/1806.00187)) | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/models/wmt14.en-fr.joined-dict.transformer.tar.bz2) | newstest2014 (shared vocab): <br> [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/models/wmt14.en-fr.joined-dict.transformer.tar.bz2)
Transformer <br> ([Ott et al., 2018](https://arxiv.org/abs/1806.00187)) | [WMT16 English-German](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8) | [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/models/wmt16.en-de.joined-dict.transformer.tar.bz2) | newstest2014 (shared vocab): <br> [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/models/wmt16.en-de.joined-dict.newstest2014.tar.bz2)
### Language models
Description | Dataset | Model | Test set(s)
---|---|---|---
Convolutional <br> ([Dauphin et al., 2017](https://arxiv.org/abs/1612.08083)) | [Google Billion Words](https://github.com/ciprian-chelba/1-billion-word-language-modeling-benchmark) | [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/models/gbw_fconv_lm.tar.bz2) | [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/data/gbw_test_lm.tar.bz2)
Convolutional <br> ([Dauphin et al., 2017](https://arxiv.org/abs/1612.08083)) | [WikiText-103](https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset) | [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/models/wiki103_fconv_lm.tar.bz2) | [download (.tar.bz2)](https://s3.amazonaws.com/fairseq-py/data/wiki103_test_lm.tar.bz2)
### Usage
Generation with the binarized test sets can be run in batch mode as follows, e.g. for WMT 2014 English-French on a GTX-1080ti:
```
...@@ -153,39 +178,66 @@ $ python score.py --sys /tmp/gen.out.sys --ref /tmp/gen.out.ref
BLEU4 = 40.83, 67.5/46.9/34.4/25.5 (BP=1.000, ratio=1.006, syslen=83262, reflen=82787)
```
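The BLEU4 value is the brevity penalty times the geometric mean of the four n-gram precisions printed on the same line, so the reported score can be sanity-checked directly (using the rounded precisions above, hence the tiny discrepancy):
```
import math

precisions = [0.675, 0.469, 0.344, 0.255]  # 1- to 4-gram precisions reported by score.py
bp = 1.0                                   # brevity penalty (syslen >= reflen here)
bleu = bp * math.exp(sum(math.log(p) for p in precisions) / len(precisions))
print('{:.2f}'.format(100 * bleu))         # ~40.82, matching the BLEU4 value above
```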
# Large mini-batch training with delayed updates
The `--update-freq` option can be used to accumulate gradients from multiple mini-batches and delay updating,
creating a larger effective batch size.
Delayed updates can also improve training speed by reducing inter-GPU communication costs and by saving idle time caused by variance in workload across GPUs.
See [Ott et al. (2018)](https://arxiv.org/abs/1806.00187) for more details.
To train on a single GPU with an effective batch size that is equivalent to training on 8 GPUs:
```
CUDA_VISIBLE_DEVICES=0 python train.py --update-freq 8 (...)
```
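Conceptually, `--update-freq N` accumulates gradients over N forward/backward passes before each optimizer step. A minimal PyTorch sketch of the idea (an illustration only, not fairseq's actual trainer; `model`, `criterion` and `data_loader` are placeholders):
```
def train_with_delayed_updates(model, criterion, optimizer, data_loader, update_freq=8):
    """Accumulate gradients over `update_freq` mini-batches, then apply one update."""
    model.train()
    optimizer.zero_grad()
    for i, (inputs, targets) in enumerate(data_loader):
        loss = criterion(model(inputs), targets)
        # Scale the loss so the accumulated gradient matches one large mini-batch.
        (loss / update_freq).backward()
        if (i + 1) % update_freq == 0:
            optimizer.step()
            optimizer.zero_grad()
```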
# Training with half precision floating point (FP16)
> Note: FP16 training requires a Volta GPU and CUDA 9.1 or greater
Recent GPUs enable efficient half precision floating point computation, e.g., using [Nvidia Tensor Cores](https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html).
Fairseq supports FP16 training with the `--fp16` flag:
```
python train.py --fp16 (...)
```
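Internally, FP16 training typically keeps an FP32 master copy of the parameters and scales the loss so small gradients do not underflow in half precision. A simplified sketch of that general pattern (not fairseq's `--fp16` implementation; the optimizer is assumed to hold the FP32 copies):
```
def fp16_update(fp16_params, fp32_params, loss, optimizer, loss_scale=128.0):
    """One optimizer step with loss scaling and an FP32 master copy of the weights."""
    (loss * loss_scale).backward()                    # scaled backward pass in FP16
    for p16, p32 in zip(fp16_params, fp32_params):
        if p16.grad is not None:
            p32.grad = p16.grad.float() / loss_scale  # unscale gradients into FP32
    optimizer.step()                                  # update the FP32 master weights
    for p16, p32 in zip(fp16_params, fp32_params):
        p16.data.copy_(p32.data)                      # copy updated weights back to FP16
        p16.grad = None                               # reset FP16 gradients for the next step
```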
# Distributed training
Distributed training in fairseq is implemented on top of [torch.distributed](http://pytorch.org/docs/master/distributed.html).
Training begins by launching one worker process per GPU.
These workers discover each other via a unique host and port (required) that can be used to establish an initial connection.
Additionally, each worker has a rank, a unique number from 0 to n-1, where n is the total number of GPUs.
If you run on a cluster managed by [SLURM](https://slurm.schedmd.com/) you can train a large English-French model on the WMT 2014 dataset on 16 nodes with 8 GPUs each (in total 128 GPUs) using this command:
```
$ DATA=... # path to the preprocessed dataset, must be visible from all nodes
$ PORT=9218 # any available TCP port that can be used by the trainer to establish initial connection
$ sbatch --job-name fairseq-py --gres gpu:8 --cpus-per-task 10 \
--nodes 16 --ntasks-per-node 8 \
--wrap 'srun --output train.log.node%t --error train.stderr.node%t.%j \
python train.py $DATA \
--distributed-world-size 128 \
--distributed-port $PORT \
--force-anneal 50 --lr-scheduler fixed --max-epoch 55 \
--arch fconv_wmt_en_fr --optimizer nag --lr 0.1,4 --max-tokens 3000 \
--clip-norm 0.1 --dropout 0.1 --criterion label_smoothed_cross_entropy \
--label-smoothing 0.1 --wd 0.0001'
```
Alternatively you can manually start one process per GPU:
```
$ DATA=... # path to the preprocessed dataset, must be visible from all nodes
$ HOST_PORT=master.devserver.com:9218 # one of the hosts used by the job
$ RANK=... # the rank of this process, from 0 to 127 in case of 128 GPUs
$ python train.py $DATA \
--distributed-world-size 128 \
--distributed-init-method "tcp://$HOST_PORT" \
--distributed-rank $RANK \
--force-anneal 50 --lr-scheduler fixed --max-epoch 55 \
--arch fconv_wmt_en_fr --optimizer nag --lr 0.1,4 --max-tokens 3000 \
--clip-norm 0.1 --dropout 0.1 --criterion label_smoothed_cross_entropy \
--label-smoothing 0.1 --wd 0.0001
```
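Under the hood, `--distributed-init-method`, `--distributed-rank` and `--distributed-world-size` map onto a `torch.distributed` process group; roughly, every worker runs something like the following (a simplified sketch, not the exact fairseq code):
```
import torch.distributed as dist

def init_worker(host_port, rank, world_size):
    """Join the process group; all workers must use the same host:port."""
    dist.init_process_group(
        backend='nccl',                    # backend choice depends on your setup (e.g. 'gloo' for CPU)
        init_method='tcp://' + host_port,  # e.g. 'tcp://master.devserver.com:9218'
        rank=rank,                         # unique id in [0, world_size)
        world_size=world_size,             # total number of workers (GPUs)
    )
```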
# Join the fairseq community
...
...@@ -10,7 +10,7 @@ import os
import socket
import subprocess
from train import main as single_process_main
from fairseq import distributed_utils, options
...
#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import numpy as np
import torch
from fairseq import data, options, progress_bar, tasks, utils
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.sequence_scorer import SequenceScorer
def main(args):
    assert args.path is not None, '--path required for evaluation!'

    if args.tokens_per_sample is None:
        args.tokens_per_sample = 1024
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset))))

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    models, _ = utils.load_ensemble_for_inference(args.path.split(':'), task)

    # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
    for model in models:
        model.make_generation_fast_()

    itr = data.EpochBatchIterator(
        dataset=task.dataset(args.gen_subset),
        max_sentences=args.max_sentences or 4,
        max_positions=model.max_positions(),
        num_shards=args.num_shards,
        shard_id=args.shard_id,
    ).next_epoch_itr(shuffle=False)

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(models, task.target_dictionary)
    if use_cuda:
        scorer.cuda()

    score_sum = 0.
    count = 0
    with progress_bar.build_progress_bar(args, itr) as t:
        results = scorer.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
        wps_meter = TimeMeter()
        for _, src_tokens, __, hypos in results:
            for hypo in hypos:
                pos_scores = hypo['positional_scores']
                inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
                if inf_scores.any():
                    print('| Skipping tokens with inf scores:',
                          task.target_dictionary.string(hypo['tokens'][inf_scores.nonzero()]))
                    pos_scores = pos_scores[(~inf_scores).nonzero()]
                score_sum += pos_scores.sum()
                count += pos_scores.numel()
            wps_meter.update(src_tokens.size(0))
            t.log({'wps': round(wps_meter.avg)})

    avg_nll_loss = -score_sum / count
    print('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    print('| Loss: {:.4f}, Perplexity: {:.2f}'.format(avg_nll_loss, np.exp(avg_nll_loss)))


if __name__ == '__main__':
    parser = options.get_eval_lm_parser()
    args = options.parse_args_and_arch(parser)
    main(args)
Sample data processing scripts for the FAIR Sequence-to-Sequence Toolkit
These scripts provide an example of pre-processing data for the Language Modeling task.
# prepare-wikitext-103.sh
Provides an example of pre-processing for [WikiText-103 language modeling task](https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset):
Example usage:
```
$ cd examples/language_model/
$ bash prepare-wikitext-103.sh
$ cd ../..
# Binarize the dataset:
$ TEXT=examples/language_model/wikitext-103
$ python preprocess.py --only-source \
--trainpref $TEXT/wiki.train.tokens --validpref $TEXT/wiki.valid.tokens --testpref $TEXT/wiki.test.tokens \
--destdir data-bin/wikitext-103
# Train the model:
# If it runs out of memory, try to reduce --max-tokens and --max-target-positions
$ mkdir -p checkpoints/wikitext-103
$ python train.py --task language_modeling data-bin/wikitext-103 \
--save-dir checkpoints/wikitext-103 \
--max-epoch 35 --arch fconv_lm_dauphin_wikitext103 --optimizer nag \
--lr 1.0 --lr-scheduler reduce_lr_on_plateau --lr-shrink 0.5 \
--clip-norm 0.1 --dropout 0.2 --weight-decay 5e-06 --criterion adaptive_loss \
--adaptive-softmax-cutoff 10000,20000,200000 --max-tokens 1024 --tokens-per-sample 1024
# Evaluate:
$ python eval_lm.py data-bin/wikitext-103 --path 'checkpoints/wikitext-103/checkpoint_best.pt'
```
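`eval_lm.py` reports both the average per-token negative log-likelihood (in nats) and its exponential, the perplexity; the two are related as follows (illustrative numbers, not actual results):
```
import math

avg_nll = 4.2                       # hypothetical average negative log-likelihood per token, in nats
perplexity = math.exp(avg_nll)      # this is what eval_lm.py prints as "Perplexity"
print('{:.2f}'.format(perplexity))  # 66.69
```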
#!/bin/bash
# Adapted from https://github.com/facebookresearch/MIXER/blob/master/prepareData.sh
URLS=(
"https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip"
)
FILES=(
"wikitext-103-v1.zip"
)
for ((i=0;i<${#URLS[@]};++i)); do
file=${FILES[i]}
if [ -f $file ]; then
echo "$file already exists, skipping download"
else
url=${URLS[i]}
wget "$url"
if [ -f $file ]; then
echo "$url successfully downloaded."
else
echo "$url not successfully downloaded."
exit -1
fi
if [ ${file: -4} == ".tgz" ]; then
tar zxvf $file
elif [ ${file: -4} == ".tar" ]; then
tar xvf $file
elif [ ${file: -4} == ".zip" ]; then
unzip $file
fi
fi
done
cd ..
FAIR Sequence-to-Sequence Toolkit for Story Generation
The following commands provide an example of pre-processing data, training a model, and generating text for story generation with the WritingPrompts dataset.
The dataset can be downloaded like this:
```
curl https://s3.amazonaws.com/fairseq-py/data/writingPrompts.tar.gz | tar xvzf -
```
and contains a train, test, and valid split. The dataset is described here: https://arxiv.org/abs/1805.04833, where only the first 1000 words of each story are modeled.
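If you prepare your own split, truncating each target story to its first 1000 words could look like this (an illustrative sketch; the released WritingPrompts data is already prepared):
```
def truncate_stories(in_path, out_path, max_words=1000):
    """Keep only the first `max_words` whitespace-separated tokens of each story."""
    with open(in_path, encoding='utf-8') as fin, open(out_path, 'w', encoding='utf-8') as fout:
        for line in fin:
            fout.write(' '.join(line.split()[:max_words]) + '\n')
```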
Example usage:
```
# Binarize the dataset:
$ TEXT=examples/stories/writingPrompts
$ python preprocess.py --source-lang wp_source --target-lang wp_target \
--trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
--destdir data-bin/writingPrompts --thresholdtgt 10 --thresholdsrc 10
# Train the model:
$ python train.py data-bin/writingPrompts -a fconv_self_att_wp \
--lr 0.25 --clip-norm 0.1 --max-tokens 1500 --lr-scheduler reduce_lr_on_plateau \
--decoder-attention True --encoder-attention False \
--criterion label_smoothed_cross_entropy --weight-decay .0000001 --label-smoothing 0 \
--source-lang wp_source --target-lang wp_target \
--gated-attention True --self-attention True --project-input True --pretrained False
# Train a fusion model:
# add the arguments: --pretrained True --pretrained-checkpoint path/to/checkpoint
# Generate:
$ python generate.py data-bin/writingPrompts \
--path /path/to/trained/model/checkpoint_best.pt \
--batch-size 32 --beam 1 --sampling --sampling-topk 10 --sampling-temperature 0.8 --nbest 1
```
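The generation command above uses top-k sampling with a temperature instead of beam search; a minimal sketch of that sampling rule for a single decoding step (illustrative only, not fairseq's sequence generator):
```
import torch
import torch.nn.functional as F

def sample_top_k(logits, k=10, temperature=0.8):
    """Sample the next token id from the k most likely candidates."""
    scaled = logits / temperature       # temperature < 1 sharpens the distribution
    top_vals, top_idx = scaled.topk(k)  # keep only the k highest-scoring tokens
    probs = F.softmax(top_vals, dim=-1)
    choice = torch.multinomial(probs, num_samples=1)
    return top_idx[choice]
```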
...@@ -8,9 +8,9 @@ Provides an example of pre-processing for IWSLT'14 German to English translation
Example usage:
```
$ cd examples/translation/
$ bash prepare-iwslt14.sh
$ cd ../..
# Binarize the dataset:
$ TEXT=data/iwslt14.tokenized.de-en
...@@ -47,9 +47,9 @@ $ bash prepare-wmt14en2de.sh --icml17
Example usage:
```
$ cd examples/translation/
$ bash prepare-wmt14en2de.sh
$ cd ../..
# Binarize the dataset:
$ TEXT=data/wmt14_en_de
...@@ -79,9 +79,9 @@ Provides an example of pre-processing for the WMT'14 English to French translati
Example usage:
```
$ cd examples/translation/
$ bash prepare-wmt14en2fr.sh
$ cd ../..
# Binarize the dataset:
$ TEXT=data/wmt14_en_fr
...@@ -103,4 +103,3 @@ $ python generate.py data-bin/fconv_wmt_en_fr \
--path checkpoints/fconv_wmt_en_fr/checkpoint_best.pt --beam 5 --remove-bpe
```
...@@ -15,8 +15,8 @@ CRITERION_REGISTRY = {}
CRITERION_CLASS_NAMES = set()
def build_criterion(args, task):
return CRITERION_REGISTRY[args.criterion](args, task)
def register_criterion(name):
...
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import math
import torch.nn.functional as F
from fairseq import utils
from . import FairseqCriterion, register_criterion
@register_criterion('adaptive_loss')
class AdaptiveLoss(FairseqCriterion):
"""This is an implementation of the loss function accompanying the adaptive softmax approximation for
graphical processing units (GPU), described in the paper "Efficient softmax approximation for GPUs"
(http://arxiv.org/abs/1609.04309)."""
def __init__(self, args, task):
super().__init__(args, task)
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
assert hasattr(model.decoder, 'adaptive_softmax') and model.decoder.adaptive_softmax is not None
adaptive_softmax = model.decoder.adaptive_softmax
net_output = model(**sample['net_input'])
target = model.get_targets(sample, net_output).view(-1)
bsz = target.size(0)
logits, target = adaptive_softmax(net_output[0], target)
assert len(target) == len(logits)
loss = net_output[0].new(1 if reduce else bsz).zero_()
for i in range(len(target)):
if target[i] is not None:
assert (target[i].min() >= 0 and target[i].max() <= logits[i].size(1))
loss += F.cross_entropy(logits[i], target[i], size_average=False, ignore_index=self.padding_idx,
reduce=reduce)
sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
logging_output = {
'loss': utils.item(loss.data) if reduce else loss.data,
'ntokens': sample['ntokens'],
'sample_size': sample_size,
}
return loss, sample_size, logging_output
@staticmethod
def aggregate_logging_outputs(logging_outputs):
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
agg_output = {
'loss': loss_sum / sample_size / math.log(2),
'sample_size': sample_size,
}
if sample_size != ntokens:
agg_output['nll_loss'] = loss_sum / ntokens / math.log(2)
return agg_output
...@@ -16,14 +16,14 @@ from . import FairseqCriterion, register_criterion
@register_criterion('cross_entropy')
class CrossEntropyCriterion(FairseqCriterion):
def __init__(self, args, task):
super().__init__(args, task)
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
...
...@@ -10,10 +10,10 @@ from torch.nn.modules.loss import _Loss
class FairseqCriterion(_Loss):
def __init__(self, args, task):
super().__init__()
self.args = args
self.padding_idx = task.target_dictionary.pad()
@staticmethod
def add_args(parser):
...@@ -24,7 +24,7 @@ class FairseqCriterion(_Loss):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
...
...@@ -15,8 +15,8 @@ from . import FairseqCriterion, register_criterion
@register_criterion('label_smoothed_cross_entropy')
class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
def __init__(self, args, task):
super().__init__(args, task)
self.eps = args.label_smoothing
@staticmethod
...@@ -29,7 +29,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
...
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import contextlib
import itertools
import glob
import math
import numbers
import numpy as np
import os
import torch
import torch.utils.data
from fairseq.dictionary import Dictionary
from fairseq.indexed_dataset import IndexedDataset, IndexedInMemoryDataset, IndexedRawTextDataset
def has_binary_files(data_dir, splits):
for split in splits:
if len(glob.glob(os.path.join(data_dir, '{}.*-*.*.bin'.format(split)))) < 2:
return False
return True
def infer_language_pair(path, splits):
"""Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx"""
src, dst = None, None
for filename in os.listdir(path):
parts = filename.split('.')
for split in splits:
if parts[0] == split and parts[-1] == 'idx':
src, dst = parts[1].split('-')
break
return src, dst
def load_dictionaries(path, src_lang, dst_lang):
"""Load dictionaries for a given language pair."""
src_dict = Dictionary.load(os.path.join(path, 'dict.{}.txt'.format(src_lang)))
dst_dict = Dictionary.load(os.path.join(path, 'dict.{}.txt'.format(dst_lang)))
return src_dict, dst_dict
def load_dataset(path, load_splits, src=None, dst=None):
"""Loads specified data splits (e.g., test, train or valid) from the
specified folder and check that files exist."""
if src is None and dst is None:
# find language pair automatically
src, dst = infer_language_pair(path, load_splits)
assert src is not None and dst is not None, 'Source and target languages should be provided'
src_dict, dst_dict = load_dictionaries(path, src, dst)
dataset = LanguageDatasets(src, dst, src_dict, dst_dict)
# Load dataset from binary files
def all_splits_exist(src, dst, lang):
for split in load_splits:
filename = '{0}.{1}-{2}.{3}.idx'.format(split, src, dst, lang)
if not os.path.exists(os.path.join(path, filename)):
return False
return True
# infer langcode
if all_splits_exist(src, dst, src):
langcode = '{}-{}'.format(src, dst)
elif all_splits_exist(dst, src, src):
langcode = '{}-{}'.format(dst, src)
else:
raise Exception('Dataset cannot be loaded from path: ' + path)
def fmt_path(fmt, *args):
return os.path.join(path, fmt.format(*args))
for split in load_splits:
for k in itertools.count():
prefix = "{}{}".format(split, k if k > 0 else '')
src_path = fmt_path('{}.{}.{}', prefix, langcode, src)
dst_path = fmt_path('{}.{}.{}', prefix, langcode, dst)
if not IndexedInMemoryDataset.exists(src_path):
break
target_dataset = None
if IndexedInMemoryDataset.exists(dst_path):
target_dataset = IndexedInMemoryDataset(dst_path)
dataset.splits[prefix] = LanguagePairDataset(
IndexedInMemoryDataset(src_path),
target_dataset,
pad_idx=dataset.src_dict.pad(),
eos_idx=dataset.src_dict.eos(),
)
return dataset
def load_raw_text_dataset(path, load_splits, src=None, dst=None):
"""Loads specified data splits (e.g., test, train or valid) from raw text
files in the specified folder."""
if src is None and dst is None:
# find language pair automatically
src, dst = infer_language_pair(path, load_splits)
assert src is not None and dst is not None, 'Source and target languages should be provided'
src_dict, dst_dict = load_dictionaries(path, src, dst)
dataset = LanguageDatasets(src, dst, src_dict, dst_dict)
# Load dataset from raw text files
for split in load_splits:
src_path = os.path.join(path, '{}.{}'.format(split, src))
dst_path = os.path.join(path, '{}.{}'.format(split, dst))
dataset.splits[split] = LanguagePairDataset(
IndexedRawTextDataset(src_path, src_dict),
IndexedRawTextDataset(dst_path, dst_dict),
pad_idx=dataset.src_dict.pad(),
eos_idx=dataset.src_dict.eos(),
)
return dataset
class LanguageDatasets(object):
def __init__(self, src, dst, src_dict, dst_dict):
self.src = src
self.dst = dst
self.src_dict = src_dict
self.dst_dict = dst_dict
self.splits = {}
assert self.src_dict.pad() == self.dst_dict.pad()
assert self.src_dict.eos() == self.dst_dict.eos()
assert self.src_dict.unk() == self.dst_dict.unk()
def train_dataloader(self, split, max_tokens=None,
max_sentences=None, max_positions=(1024, 1024),
seed=None, epoch=1, sample_without_replacement=0,
sort_by_source_size=False, shard_id=0, num_shards=1):
dataset = self.splits[split]
with numpy_seed(seed):
batch_sampler = shuffled_batches_by_size(
dataset.src, dataset.dst, max_tokens=max_tokens,
max_sentences=max_sentences, epoch=epoch,
sample=sample_without_replacement, max_positions=max_positions,
sort_by_source_size=sort_by_source_size)
batch_sampler = mask_batches(batch_sampler, shard_id=shard_id, num_shards=num_shards)
return torch.utils.data.DataLoader(
dataset, collate_fn=dataset.collater,
batch_sampler=batch_sampler)
def eval_dataloader(self, split, num_workers=0, max_tokens=None,
max_sentences=None, max_positions=(1024, 1024),
skip_invalid_size_inputs_valid_test=False,
descending=False, shard_id=0, num_shards=1):
dataset = self.splits[split]
batch_sampler = batches_by_size(
dataset.src, dataset.dst, max_tokens, max_sentences,
max_positions=max_positions,
ignore_invalid_inputs=skip_invalid_size_inputs_valid_test,
descending=descending)
batch_sampler = mask_batches(batch_sampler, shard_id=shard_id, num_shards=num_shards)
return torch.utils.data.DataLoader(
dataset, num_workers=num_workers, collate_fn=dataset.collater,
batch_sampler=batch_sampler)
class sharded_iterator(object):
def __init__(self, itr, num_shards, shard_id):
assert shard_id >= 0 and shard_id < num_shards
self.itr = itr
self.num_shards = num_shards
self.shard_id = shard_id
def __len__(self):
return len(self.itr)
def __iter__(self):
for i, v in enumerate(self.itr):
if i % self.num_shards == self.shard_id:
yield v
class LanguagePairDataset(torch.utils.data.Dataset):
# padding constants
LEFT_PAD_SOURCE = True
LEFT_PAD_TARGET = False
def __init__(self, src, dst, pad_idx, eos_idx):
self.src = src
self.dst = dst
self.pad_idx = pad_idx
self.eos_idx = eos_idx
def __getitem__(self, i):
# subtract 1 for 0-based indexing
source = self.src[i].long() - 1
res = {'id': i, 'source': source}
if self.dst:
res['target'] = self.dst[i].long() - 1
return res
def __len__(self):
return len(self.src)
def collater(self, samples):
return LanguagePairDataset.collate(samples, self.pad_idx, self.eos_idx, self.dst is not None)
@staticmethod
def collate(samples, pad_idx, eos_idx, has_target=True):
if len(samples) == 0:
return {}
def merge(key, left_pad, move_eos_to_beginning=False):
return LanguagePairDataset.collate_tokens(
[s[key] for s in samples],
pad_idx, eos_idx, left_pad, move_eos_to_beginning,
)
id = torch.LongTensor([s['id'] for s in samples])
src_tokens = merge('source', left_pad=LanguagePairDataset.LEFT_PAD_SOURCE)
# sort by descending source length
src_lengths = torch.LongTensor([s['source'].numel() for s in samples])
src_lengths, sort_order = src_lengths.sort(descending=True)
id = id.index_select(0, sort_order)
src_tokens = src_tokens.index_select(0, sort_order)
prev_output_tokens = None
target = None
ntokens = None
if has_target:
target = merge('target', left_pad=LanguagePairDataset.LEFT_PAD_TARGET)
# we create a shifted version of targets for feeding the
# previous output token(s) into the next decoder step
prev_output_tokens = merge(
'target',
left_pad=LanguagePairDataset.LEFT_PAD_TARGET,
move_eos_to_beginning=True,
)
prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
target = target.index_select(0, sort_order)
ntokens = sum(len(s['target']) for s in samples)
return {
'id': id,
'ntokens': ntokens,
'net_input': {
'src_tokens': src_tokens,
'src_lengths': src_lengths,
'prev_output_tokens': prev_output_tokens,
},
'target': target,
}
@staticmethod
def collate_tokens(values, pad_idx, eos_idx, left_pad, move_eos_to_beginning=False):
size = max(v.size(0) for v in values)
res = values[0].new(len(values), size).fill_(pad_idx)
def copy_tensor(src, dst):
assert dst.numel() == src.numel()
if move_eos_to_beginning:
assert src[-1] == eos_idx
dst[0] = eos_idx
dst[1:] = src[:-1]
else:
dst.copy_(src)
for i, v in enumerate(values):
if left_pad:
copy_tensor(v, res[i][size-len(v):])
else:
copy_tensor(v, res[i][:len(v)])
return res
def _valid_size(src_size, dst_size, max_positions):
if isinstance(max_positions, numbers.Number):
max_src_positions, max_dst_positions = max_positions, max_positions
else:
max_src_positions, max_dst_positions = max_positions
if src_size < 1 or src_size > max_src_positions:
return False
if dst_size is not None and (dst_size < 1 or dst_size > max_dst_positions):
return False
return True
def _make_batches(src, dst, indices, max_tokens, max_sentences, max_positions,
ignore_invalid_inputs=False, allow_different_src_lens=False):
batch = []
def yield_batch(next_idx, num_tokens):
if len(batch) == 0:
return False
if len(batch) == max_sentences:
return True
if num_tokens > max_tokens:
return True
if not allow_different_src_lens and \
(src.sizes[batch[0]] != src.sizes[next_idx]):
return True
return False
sample_len = 0
ignored = []
for idx in map(int, indices):
src_size = src.sizes[idx]
dst_size = dst.sizes[idx] if dst else src_size
if not _valid_size(src_size, dst_size, max_positions):
if ignore_invalid_inputs:
ignored.append(idx)
continue
raise Exception((
"Sample #{} has size (src={}, dst={}) but max size is {}."
" Skip this example with --skip-invalid-size-inputs-valid-test"
).format(idx, src_size, dst_size, max_positions))
sample_len = max(sample_len, src_size, dst_size)
num_tokens = (len(batch) + 1) * sample_len
if yield_batch(idx, num_tokens):
yield batch
batch = []
sample_len = max(src_size, dst_size)
batch.append(idx)
if len(batch) > 0:
yield batch
if len(ignored) > 0:
print("Warning! {} samples are either too short or too long "
"and will be ignored, first few sample ids={}".format(len(ignored), ignored[:10]))
def batches_by_size(src, dst, max_tokens=None, max_sentences=None,
max_positions=(1024, 1024), ignore_invalid_inputs=False,
descending=False):
"""Returns batches of indices sorted by size. Sequences with different
source lengths are not allowed in the same batch."""
assert isinstance(src, IndexedDataset) and (dst is None or isinstance(dst, IndexedDataset))
if max_tokens is None:
max_tokens = float('Inf')
if max_sentences is None:
max_sentences = float('Inf')
indices = np.argsort(src.sizes, kind='mergesort')
if descending:
indices = np.flip(indices, 0)
return list(_make_batches(
src, dst, indices, max_tokens, max_sentences, max_positions,
ignore_invalid_inputs, allow_different_src_lens=False))
def shuffled_batches_by_size(src, dst, max_tokens=None, max_sentences=None,
epoch=1, sample=0, max_positions=(1024, 1024),
sort_by_source_size=False):
"""Returns batches of indices, bucketed by size and then shuffled. Batches
may contain sequences of different lengths."""
assert isinstance(src, IndexedDataset) and isinstance(dst, IndexedDataset)
if max_tokens is None:
max_tokens = float('Inf')
if max_sentences is None:
max_sentences = float('Inf')
indices = np.random.permutation(len(src))
# sort by sizes
indices = indices[np.argsort(dst.sizes[indices], kind='mergesort')]
indices = indices[np.argsort(src.sizes[indices], kind='mergesort')]
batches = list(_make_batches(
src, dst, indices, max_tokens, max_sentences, max_positions,
ignore_invalid_inputs=True, allow_different_src_lens=True))
if not sort_by_source_size:
np.random.shuffle(batches)
if sample:
offset = (epoch - 1) * sample
while offset > len(batches):
np.random.shuffle(batches)
offset -= len(batches)
result = batches[offset:(offset + sample)]
while len(result) < sample:
np.random.shuffle(batches)
result += batches[:(sample - len(result))]
assert len(result) == sample, \
"batch length is not correct {}".format(len(result))
batches = result
return batches
def mask_batches(batch_sampler, shard_id, num_shards):
if num_shards == 1:
return batch_sampler
res = [
batch
for i, batch in enumerate(batch_sampler)
if i % num_shards == shard_id
]
expected_length = int(math.ceil(len(batch_sampler) / num_shards))
return res + [[]] * (expected_length - len(res))
@contextlib.contextmanager
def numpy_seed(seed):
"""Context manager which seeds the NumPy PRNG with the specified seed and
restores the state afterward"""
if seed is None:
yield
return
state = np.random.get_state()
np.random.seed(seed)
try:
yield
finally:
np.random.set_state(state)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from .dictionary import Dictionary
from .fairseq_dataset import FairseqDataset
from .indexed_dataset import IndexedInMemoryDataset, IndexedRawTextDataset
from .language_pair_dataset import LanguagePairDataset
from .monolingual_dataset import MonolingualDataset
from .token_block_dataset import TokenBlockDataset
from .data_utils import EpochBatchIterator
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import contextlib
import itertools
import os
import numpy as np
import torch
from . import FairseqDataset
def infer_language_pair(path):
"""Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx"""
src, dst = None, None
for filename in os.listdir(path):
parts = filename.split('.')
if len(parts) >= 3 and len(parts[1].split('-')) == 2:
return parts[1].split('-')
return src, dst
class ShardedIterator(object):
"""A sharded wrapper around an iterable (padded to length)."""
def __init__(self, iterable, num_shards, shard_id, fill_value=None):
if shard_id < 0 or shard_id >= num_shards:
raise ValueError('shard_id must be between 0 and num_shards')
self._sharded_len = len(iterable) // num_shards
if len(iterable) % num_shards > 0:
self._sharded_len += 1
self.itr = itertools.zip_longest(
range(self._sharded_len),
itertools.islice(iterable, shard_id, len(iterable), num_shards),
fillvalue=fill_value,
)
def __len__(self):
return self._sharded_len
def __iter__(self):
return self
def __next__(self):
return next(self.itr)[1]
class CountingIterator(object):
"""Wrapper around an iterable that maintains the iteration count."""
def __init__(self, iterable):
self.iterable = iterable
self.count = 0
self.itr = iter(self)
def __len__(self):
return len(self.iterable)
def __iter__(self):
for x in self.iterable:
self.count += 1
yield x
def __next__(self):
return next(self.itr)
def has_next(self):
return self.count < len(self)
def skip(self, num_to_skip):
next(itertools.islice(self.itr, num_to_skip, num_to_skip), None)
return self
def collate_tokens(values, pad_idx, eos_idx, left_pad, move_eos_to_beginning=False):
"""Convert a list of 1d tensors into a padded 2d tensor."""
size = max(v.size(0) for v in values)
res = values[0].new(len(values), size).fill_(pad_idx)
def copy_tensor(src, dst):
assert dst.numel() == src.numel()
if move_eos_to_beginning:
assert src[-1] == eos_idx
dst[0] = eos_idx
dst[1:] = src[:-1]
else:
dst.copy_(src)
for i, v in enumerate(values):
copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)])
return res
class EpochBatchIterator(object):
"""Iterate over a FairseqDataset and yield batches bucketed by size.
Batches may contain sequences of different lengths. This iterator can be
reused across multiple epochs with the next_epoch_itr() method.
Args:
dataset: a FairseqDataset
max_tokens: max number of tokens in each batch
max_sentences: max number of sentences in each batch
max_positions: max sentence length supported by the model
ignore_invalid_inputs: don't raise Exception for sentences that are too long
required_batch_size_multiple: require batch size to be a multiple of N
seed: seed for random number generator for reproducibility
num_shards: shard the data iterator into N shards
shard_id: which shard of the data iterator to return
"""
def __init__(
self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
ignore_invalid_inputs=False, required_batch_size_multiple=1, seed=1,
num_shards=1, shard_id=0,
):
assert isinstance(dataset, FairseqDataset)
self.dataset = dataset
self.max_tokens = max_tokens if max_tokens is not None else float('Inf')
self.max_sentences = max_sentences if max_sentences is not None else float('Inf')
self.max_positions = max_positions
self.ignore_invalid_inputs = ignore_invalid_inputs
self.bsz_mult = required_batch_size_multiple
self.seed = seed
self.num_shards = num_shards
self.shard_id = shard_id
with numpy_seed(self.seed):
self.frozen_batches = tuple(self._batch_generator())
self.epoch = 0
self._cur_epoch_itr = None
self._next_epoch_itr = None
def __len__(self):
return len(self.frozen_batches)
def next_epoch_itr(self, shuffle=True):
"""Shuffle batches and return a new iterator over the dataset."""
if self._next_epoch_itr is not None:
self._cur_epoch_itr = self._next_epoch_itr
self._next_epoch_itr = None
else:
self.epoch += 1
self._cur_epoch_itr = self._get_iterator_for_epoch(self.epoch, shuffle)
return self._cur_epoch_itr
def end_of_epoch(self):
return not self._cur_epoch_itr.has_next()
@property
def iterations_in_epoch(self):
if self._cur_epoch_itr is not None:
return self._cur_epoch_itr.count
elif self._next_epoch_itr is not None:
return self._next_epoch_itr.count
return 0
def state_dict(self):
return {
'epoch': self.epoch,
'iterations_in_epoch': self.iterations_in_epoch,
}
def load_state_dict(self, state_dict):
self.epoch = state_dict['epoch']
itr_pos = state_dict.get('iterations_in_epoch', 0)
if itr_pos > 0:
# fast-forward epoch iterator
itr = self._get_iterator_for_epoch(self.epoch, state_dict.get('shuffle', True))
if itr_pos < len(itr):
self._next_epoch_itr = itr.skip(itr_pos)
def _get_iterator_for_epoch(self, epoch, shuffle):
if shuffle:
# set seed based on the seed and epoch number so that we get
# reproducible results when resuming from checkpoints
with numpy_seed(self.seed + epoch):
batches = list(self.frozen_batches) # copy
np.random.shuffle(batches)
else:
batches = self.frozen_batches
return CountingIterator(torch.utils.data.DataLoader(
self.dataset,
collate_fn=self.dataset.collater,
batch_sampler=ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[]),
))
def _batch_generator(self):
batch = []
def is_batch_full(num_tokens):
if len(batch) == 0:
return False
if len(batch) == self.max_sentences:
return True
if num_tokens > self.max_tokens:
return True
return False
sample_len = 0
sample_lens = []
ignored = []
for idx in self.dataset.ordered_indices():
if not self.dataset.valid_size(idx, self.max_positions):
if self.ignore_invalid_inputs:
ignored.append(idx)
continue
raise Exception((
'Size of sample #{} is invalid, max_positions={}, skip this '
'example with --skip-invalid-size-inputs-valid-test'
).format(idx, self.max_positions))
sample_lens.append(self.dataset.num_tokens(idx))
sample_len = max(sample_len, sample_lens[-1])
num_tokens = (len(batch) + 1) * sample_len
if is_batch_full(num_tokens):
mod_len = max(
self.bsz_mult * (len(batch) // self.bsz_mult),
len(batch) % self.bsz_mult,
)
yield batch[:mod_len]
batch = batch[mod_len:]
sample_lens = sample_lens[mod_len:]
sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
batch.append(idx)
if len(batch) > 0:
yield batch
if len(ignored) > 0:
print((
'| WARNING: {} samples have invalid sizes and will be skipped, '
'max_positions={}, first few sample ids={}'
).format(len(ignored), self.max_positions, ignored[:10]))
@contextlib.contextmanager
def numpy_seed(seed):
"""Context manager which seeds the NumPy PRNG with the specified seed and
restores the state afterward"""
if seed is None:
yield
return
state = np.random.get_state()
np.random.seed(seed)
try:
yield
finally:
np.random.set_state(state)
...@@ -5,8 +5,9 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from collections import Counter
import os
import torch
...@@ -58,7 +59,7 @@ class Dictionary(object):
sent = ' '.join(token_string(i) for i in tensor if i != self.eos())
if bpe_symbol is not None:
sent = (sent + ' ').replace(bpe_symbol, '').rstrip()
return sent
def unk_string(self, escape=False):
...@@ -94,13 +95,44 @@ class Dictionary(object):
self.symbols.append(word)
self.count.append(new_dict.count[idx2])
def finalize(self, threshold=1, nwords=-1, padding_factor=8):
"""Sort symbols by frequency in descending order, ignoring special ones.
Args:
- threshold defines the minimum word count
- nwords defines the total number of words in the final dictionary,
including special symbols
- padding_factor can be used to pad the dictionary size to be a
multiple of 8, which is important on some hardware (e.g., Nvidia
Tensor Cores).
"""
if nwords == -1:
nwords = len(self)
new_symbols = self.symbols[:self.nspecial]
new_count = self.count[:self.nspecial]
c = Counter(dict(zip(self.symbols[self.nspecial:], self.count[self.nspecial:])))
for symbol, count in c.most_common(nwords - self.nspecial):
if count >= threshold:
new_symbols.append(symbol)
new_count.append(count)
else:
break
threshold_nwords = len(new_symbols)
if padding_factor > 1:
i = 0
while threshold_nwords % padding_factor != 0:
new_symbols.append('madeupword{:04d}'.format(i))
i += 1
threshold_nwords += 1
assert min(new_count[self.nspecial:]) >= threshold
assert len(new_symbols) % padding_factor == 0
self.count = tuple(new_count)
self.symbols = tuple(new_symbols)
def pad(self):
"""Helper to get index of pad symbol"""
...@@ -124,7 +156,6 @@ class Dictionary(object):
...
```
"""
if isinstance(f, str):
try:
if not ignore_utf_errors:
...@@ -155,9 +186,10 @@ class Dictionary(object):
os.makedirs(os.path.dirname(f), exist_ok=True)
with open(f, 'w', encoding='utf-8') as fd:
return self.save(fd, threshold, nwords)
for symbol, count in zip(self.symbols[self.nspecial:], self.count[self.nspecial:]):
print('{} {}'.format(symbol, count), file=f)
def dummy_sentence(self, length):
t = torch.Tensor(length).uniform_(self.nspecial + 1, len(self)).long()
t[-1] = self.eos()
return t
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch.utils.data
class FairseqDataset(torch.utils.data.Dataset):
"""A dataset that provides helpers for batching."""
def __getitem__(self, index):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
def collater(self, samples):
"""Merge a list of samples to form a mini-batch."""
raise NotImplementedError
def get_dummy_batch(self, num_tokens, max_positions):
"""Return a dummy batch with a given number of tokens."""
raise NotImplementedError
def num_tokens(self, index):
"""Return an example's length (number of tokens), used for batching."""
raise NotImplementedError
def ordered_indices(self):
"""Ordered indices for batching."""
raise NotImplementedError
def valid_size(self, index, max_positions):
"""Check if an example's size is valid according to max_positions."""
raise NotImplementedError