"vscode:/vscode.git/clone" did not exist on "efa3cbcf07d90e7bfe98f04d5b6076c7a31ba6e7"
README.md 11.8 KB
Newer Older
1
# Neural Machine Translation

## Pre-trained models

Description | Dataset | Model | Test set(s)
---|---|---|---
Convolutional <br> ([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2) | newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.v2.en-fr.newstest2014.tar.bz2) <br> newstest2012/2013: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.v2.en-fr.ntst1213.tar.bz2)
Convolutional <br> ([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT14 English-German](http://statmt.org/wmt14/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-de.fconv-py.tar.bz2) | newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.en-de.newstest2014.tar.bz2)
Convolutional <br> ([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT17 English-German](http://statmt.org/wmt17/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt17.v2.en-de.fconv-py.tar.bz2) | newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt17.v2.en-de.newstest2014.tar.bz2)
Transformer <br> ([Ott et al., 2018](https://arxiv.org/abs/1806.00187)) | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-fr.joined-dict.transformer.tar.bz2) | newstest2014 (shared vocab): <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.en-fr.joined-dict.newstest2014.tar.bz2)
Transformer <br> ([Ott et al., 2018](https://arxiv.org/abs/1806.00187)) | [WMT16 English-German](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt16.en-de.joined-dict.transformer.tar.bz2) | newstest2014 (shared vocab): <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt16.en-de.joined-dict.newstest2014.tar.bz2)
Transformer <br> ([Edunov et al., 2018](https://arxiv.org/abs/1808.09381); WMT'18 winner) | [WMT'18 English-German](http://www.statmt.org/wmt18/translation-task.html) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt18.en-de.ensemble.tar.gz) | See NOTE in the archive
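
For example, the WMT'16 English-German Transformer model and its pre-processed test set from the table can be downloaded and evaluated as follows. This is only a sketch: the directory names under `data-bin` are assumed to match the archive names.

```
$ mkdir -p data-bin
$ curl https://dl.fbaipublicfiles.com/fairseq/models/wmt16.en-de.joined-dict.transformer.tar.bz2 | tar xvjf - -C data-bin
$ curl https://dl.fbaipublicfiles.com/fairseq/data/wmt16.en-de.joined-dict.newstest2014.tar.bz2 | tar xvjf - -C data-bin
$ fairseq-generate data-bin/wmt16.en-de.joined-dict.newstest2014 \
  --path data-bin/wmt16.en-de.joined-dict.transformer/model.pt \
  --beam 5 --batch-size 128 --remove-bpe
```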

## Example usage (torch.hub)

Interactive generation via PyTorch Hub:
```
>>> import torch
>>> en2de = torch.hub.load(
...   'pytorch/fairseq',
...   'transformer',
...   model_name_or_path='transformer.wmt16.en-de',
...   data_name_or_path='.',
...   tokenizer='moses',
...   aggressive_dash_splits=True,
...   bpe='subword_nmt',
... )
>>> print(en2de.models[0].__class__)
<class 'fairseq.models.transformer.TransformerModel'>
>>> print(en2de.generate('Hello world!'))
Hallo Welt!
```

Available models are listed in the `hub_models()` method in each model file, for example:
[transformer.py](https://github.com/pytorch/fairseq/blob/master/fairseq/models/transformer.py).

## Example usage (CLI tools)

Generation with the binarized test sets can be run in batch mode as follows, e.g. for WMT 2014 English-French on a GTX-1080ti:
```
$ mkdir -p data-bin
$ curl https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2 | tar xvjf - -C data-bin
$ curl https://dl.fbaipublicfiles.com/fairseq/data/wmt14.v2.en-fr.newstest2014.tar.bz2 | tar xvjf - -C data-bin
$ fairseq-generate data-bin/wmt14.en-fr.newstest2014  \
  --path data-bin/wmt14.en-fr.fconv-py/model.pt \
  --beam 5 --batch-size 128 --remove-bpe | tee /tmp/gen.out
...
| Translated 3003 sentences (96311 tokens) in 166.0s (580.04 tokens/s)
| Generate test with beam=5: BLEU4 = 40.83, 67.5/46.9/34.4/25.5 (BP=1.000, ratio=1.006, syslen=83262, reflen=82787)

# Compute BLEU score
$ grep ^H /tmp/gen.out | cut -f3- > /tmp/gen.out.sys
$ grep ^T /tmp/gen.out | cut -f2- > /tmp/gen.out.ref
$ fairseq-score --sys /tmp/gen.out.sys --ref /tmp/gen.out.ref
BLEU4 = 40.83, 67.5/46.9/34.4/25.5 (BP=1.000, ratio=1.006, syslen=83262, reflen=82787)
```
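
The same downloaded model can also translate new inputs interactively. The following is a minimal sketch, assuming the extracted model directory contains the dictionary files and that the input text has been tokenized and BPE-encoded in the same way as the model's training data:

```
$ fairseq-interactive data-bin/wmt14.en-fr.fconv-py \
  --path data-bin/wmt14.en-fr.fconv-py/model.pt \
  --source-lang en --target-lang fr \
  --beam 5 --remove-bpe
```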

## Preprocessing

These scripts provide examples of pre-processing data for the NMT task.

### prepare-iwslt14.sh

Provides an example of pre-processing for the IWSLT'14 German to English translation task: ["Report on the 11th IWSLT evaluation campaign" by Cettolo et al.](http://workshop2014.iwslt.org/downloads/proceeding.pdf)

Example usage:
```
$ cd examples/translation/
$ bash prepare-iwslt14.sh
$ cd ../..

# Binarize the dataset:
$ TEXT=examples/translation/iwslt14.tokenized.de-en
$ fairseq-preprocess --source-lang de --target-lang en \
  --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
  --destdir data-bin/iwslt14.tokenized.de-en

# Train the model (better for a single GPU setup):
$ mkdir -p checkpoints/fconv
$ CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/iwslt14.tokenized.de-en \
  --lr 0.25 --clip-norm 0.1 --dropout 0.2 --max-tokens 4000 \
  --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
  --lr-scheduler fixed --force-anneal 200 \
  --arch fconv_iwslt_de_en --save-dir checkpoints/fconv

# Generate:
$ fairseq-generate data-bin/iwslt14.tokenized.de-en \
  --path checkpoints/fconv/checkpoint_best.pt \
  --batch-size 128 --beam 5 --remove-bpe

```

To train a Transformer model on IWSLT'14 German to English:
```
# Preparation steps are the same as for fconv model.

# Train the model (better for a single GPU setup):
$ mkdir -p checkpoints/transformer
$ CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/iwslt14.tokenized.de-en \
  -a transformer_iwslt_de_en --optimizer adam --lr 0.0005 -s de -t en \
  --label-smoothing 0.1 --dropout 0.3 --max-tokens 4000 \
  --min-lr '1e-09' --lr-scheduler inverse_sqrt --weight-decay 0.0001 \
  --criterion label_smoothed_cross_entropy --max-update 50000 \
  --warmup-updates 4000 --warmup-init-lr '1e-07' \
  --adam-betas '(0.9, 0.98)' --save-dir checkpoints/transformer

# Average 10 latest checkpoints:
$ python scripts/average_checkpoints.py --inputs checkpoints/transformer \
   --num-epoch-checkpoints 10 --output checkpoints/transformer/model.pt

# Generate:
$ fairseq-generate data-bin/iwslt14.tokenized.de-en \
  --path checkpoints/transformer/model.pt \
  --batch-size 128 --beam 5 --remove-bpe

```

### prepare-wmt14en2de.sh

The WMT English to German dataset can be preprocessed using the `prepare-wmt14en2de.sh` script.
By default it will produce a dataset that was modeled after ["Attention Is All You Need" (Vaswani et al., 2017)](https://arxiv.org/abs/1706.03762), but with news-commentary-v12 data from WMT'17.

To use only data available in WMT'14 or to replicate results obtained in the original ["Convolutional Sequence to Sequence Learning" (Gehring et al., 2017)](https://arxiv.org/abs/1705.03122) paper, please use the `--icml17` option.

```
$ bash prepare-wmt14en2de.sh --icml17
```

Example usage:

```
$ cd examples/translation/
$ bash prepare-wmt14en2de.sh
$ cd ../..

# Binarize the dataset:
$ TEXT=examples/translation/wmt17_en_de
$ fairseq-preprocess --source-lang en --target-lang de \
  --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
  --destdir data-bin/wmt17_en_de --thresholdtgt 0 --thresholdsrc 0

# Train the model:
# If it runs out of memory, try to set --max-tokens 1500 instead
$ mkdir -p checkpoints/fconv_wmt_en_de
$ fairseq-train data-bin/wmt17_en_de \
  --lr 0.5 --clip-norm 0.1 --dropout 0.2 --max-tokens 4000 \
  --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
  --lr-scheduler fixed --force-anneal 50 \
  --arch fconv_wmt_en_de --save-dir checkpoints/fconv_wmt_en_de

# Generate:
$ fairseq-generate data-bin/wmt17_en_de \
  --path checkpoints/fconv_wmt_en_de/checkpoint_best.pt --beam 5 --remove-bpe

```

### prepare-wmt14en2fr.sh

Provides an example of pre-processing for the WMT'14 English to French translation task.

Example usage:

```
$ cd examples/translation/
$ bash prepare-wmt14en2fr.sh
$ cd ../..

# Binarize the dataset:
$ TEXT=examples/translation/wmt14_en_fr
$ fairseq-preprocess --source-lang en --target-lang fr \
  --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
  --destdir data-bin/wmt14_en_fr --thresholdtgt 0 --thresholdsrc 0

# Train the model:
# If it runs out of memory, try to set --max-tokens 1000 instead
$ mkdir -p checkpoints/fconv_wmt_en_fr
$ fairseq-train data-bin/wmt14_en_fr \
  --lr 0.5 --clip-norm 0.1 --dropout 0.1 --max-tokens 3000 \
  --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
  --lr-scheduler fixed --force-anneal 50 \
  --arch fconv_wmt_en_fr --save-dir checkpoints/fconv_wmt_en_fr

# Generate:
$ fairseq-generate data-bin/wmt14_en_fr \
  --path checkpoints/fconv_wmt_en_fr/checkpoint_best.pt --beam 5 --remove-bpe

```

## Multilingual Translation

We also support training multilingual translation models. In this example we'll
train a multilingual `{de,fr}-en` translation model using the IWSLT'17 datasets.

Note that we use slightly different preprocessing here than for the IWSLT'14
De-En data above. In particular, we learn a joint BPE code for all three
languages and use `fairseq-interactive` and `sacrebleu` for scoring the test set.

```
# First install sacrebleu and sentencepiece
$ pip install sacrebleu sentencepiece

# Then download and preprocess the data
$ cd examples/translation/
$ bash prepare-iwslt17-multilingual.sh
$ cd ../..

# Binarize the de-en dataset
$ TEXT=examples/translation/iwslt17.de_fr.en.bpe16k
$ fairseq-preprocess --source-lang de --target-lang en \
  --trainpref $TEXT/train.bpe.de-en --validpref $TEXT/valid.bpe.de-en \
  --joined-dictionary \
  --destdir data-bin/iwslt17.de_fr.en.bpe16k \
  --workers 10

# Binarize the fr-en dataset
# NOTE: it's important to reuse the en dictionary from the previous step
$ fairseq-preprocess --source-lang fr --target-lang en \
  --trainpref $TEXT/train.bpe.fr-en --validpref $TEXT/valid.bpe.fr-en \
  --joined-dictionary --tgtdict data-bin/iwslt17.de_fr.en.bpe16k/dict.en.txt \
  --destdir data-bin/iwslt17.de_fr.en.bpe16k \
  --workers 10

# Train a multilingual transformer model
# NOTE: the command below assumes 1 GPU, but accumulates gradients from
#       8 fwd/bwd passes to simulate training on 8 GPUs
$ mkdir -p checkpoints/multilingual_transformer
$ CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/iwslt17.de_fr.en.bpe16k/ \
  --max-epoch 50 \
  --ddp-backend=no_c10d \
  --task multilingual_translation --lang-pairs de-en,fr-en \
  --arch multilingual_transformer_iwslt_de_en \
  --share-decoders --share-decoder-input-output-embed \
  --optimizer adam --adam-betas '(0.9, 0.98)' \
  --lr 0.0005 --lr-scheduler inverse_sqrt --min-lr '1e-09' \
  --warmup-updates 4000 --warmup-init-lr '1e-07' \
  --label-smoothing 0.1 --criterion label_smoothed_cross_entropy \
  --dropout 0.3 --weight-decay 0.0001 \
  --save-dir checkpoints/multilingual_transformer \
  --max-tokens 4000 \
  --update-freq 8

# Generate and score the test set with sacrebleu
$ SRC=de
$ sacrebleu --test-set iwslt17 --language-pair ${SRC}-en --echo src \
  | python scripts/spm_encode.py --model examples/translation/iwslt17.de_fr.en.bpe16k/sentencepiece.bpe.model \
  > iwslt17.test.${SRC}-en.${SRC}.bpe
$ cat iwslt17.test.${SRC}-en.${SRC}.bpe | fairseq-interactive data-bin/iwslt17.de_fr.en.bpe16k/ \
  --task multilingual_translation --source-lang ${SRC} --target-lang en \
  --path checkpoints/multilingual_transformer/checkpoint_best.pt \
  --buffer 2000 --batch-size 128 \
  --beam 5 --remove-bpe=sentencepiece \
  > iwslt17.test.${SRC}-en.en.sys
$ grep ^H iwslt17.test.${SRC}-en.en.sys | cut -f3 \
  | sacrebleu --test-set iwslt17 --language-pair ${SRC}-en
```
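
The fr-en direction can be scored the same way, since the commands above are parameterized on `${SRC}`:

```
$ SRC=fr
# ...then rerun the spm_encode / fairseq-interactive / sacrebleu pipeline above
```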

### Argument format during inference
During inference it is required to specify a single `--source-lang` and
`--target-lang`, which together indicate the translation direction.
`--lang-pairs`, `--encoder-langtok` and `--decoder-langtok` have to be set to
the same values used during training.
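
For example, translating the de-en direction with the multilingual model trained above might look like the following sketch, which mirrors the earlier `fairseq-interactive` call with the language arguments made explicit:

```
$ cat iwslt17.test.de-en.de.bpe | fairseq-interactive data-bin/iwslt17.de_fr.en.bpe16k/ \
  --task multilingual_translation --lang-pairs de-en,fr-en \
  --source-lang de --target-lang en \
  --path checkpoints/multilingual_transformer/checkpoint_best.pt \
  --buffer 2000 --batch-size 128 \
  --beam 5 --remove-bpe=sentencepiece
```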