Commit 12c90639 authored by “change”'s avatar “change”
Browse files

init

parent 417b607b
# CamemBERT: a Tasty French Language Model
## Introduction
[CamemBERT](https://arxiv.org/abs/1911.03894) is a pretrained language model trained on 138GB of French text based on RoBERTa.
Also available in [github.com/huggingface/transformers](https://github.com/huggingface/transformers/).
## Pre-trained models
| Model | #params | Download | Arch. | Training data |
|--------------------------------|---------|--------------------------------------------------------------------------------------------------------------------------|-------|-----------------------------------|
| `camembert` / `camembert-base` | 110M | [camembert-base.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/camembert-base.tar.gz) | Base | OSCAR (138 GB of text) |
| `camembert-large` | 335M | [camembert-large.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/camembert-large.tar.gz) | Large | CCNet (135 GB of text) |
| `camembert-base-ccnet` | 110M | [camembert-base-ccnet.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/camembert-base-ccnet.tar.gz) | Base | CCNet (135 GB of text) |
| `camembert-base-wikipedia-4gb` | 110M | [camembert-base-wikipedia-4gb.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/camembert-base-wikipedia-4gb.tar.gz) | Base | Wikipedia (4 GB of text) |
| `camembert-base-oscar-4gb` | 110M | [camembert-base-oscar-4gb.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/camembert-base-oscar-4gb.tar.gz) | Base | Subsample of OSCAR (4 GB of text) |
| `camembert-base-ccnet-4gb` | 110M | [camembert-base-ccnet-4gb.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/camembert-base-ccnet-4gb.tar.gz) | Base | Subsample of CCNet (4 GB of text) |
## Example usage
### fairseq
##### Load CamemBERT from torch.hub (PyTorch >= 1.1):
```python
import torch
camembert = torch.hub.load('pytorch/fairseq', 'camembert')
camembert.eval() # disable dropout (or leave in train mode to finetune)
```
##### Load CamemBERT (for PyTorch 1.0 or custom models):
```python
# Download camembert model
wget https://dl.fbaipublicfiles.com/fairseq/models/camembert-base.tar.gz
tar -xzvf camembert.tar.gz
# Load the model in fairseq
from fairseq.models.roberta import CamembertModel
camembert = CamembertModel.from_pretrained('/path/to/camembert')
camembert.eval() # disable dropout (or leave in train mode to finetune)
```
##### Filling masks:
```python
masked_line = 'Le camembert est <mask> :)'
camembert.fill_mask(masked_line, topk=3)
# [('Le camembert est délicieux :)', 0.4909118115901947, ' délicieux'),
# ('Le camembert est excellent :)', 0.10556942224502563, ' excellent'),
# ('Le camembert est succulent :)', 0.03453322499990463, ' succulent')]
```
##### Extract features from Camembert:
```python
# Extract the last layer's features
line = "J'aime le camembert !"
tokens = camembert.encode(line)
last_layer_features = camembert.extract_features(tokens)
assert last_layer_features.size() == torch.Size([1, 10, 768])
# Extract all layer's features (layer 0 is the embedding layer)
all_layers = camembert.extract_features(tokens, return_all_hiddens=True)
assert len(all_layers) == 13
assert torch.all(all_layers[-1] == last_layer_features)
```
## Citation
If you use our work, please cite:
```bibtex
@inproceedings{martin2020camembert,
title={CamemBERT: a Tasty French Language Model},
author={Martin, Louis and Muller, Benjamin and Su{\'a}rez, Pedro Javier Ortiz and Dupont, Yoann and Romary, Laurent and de la Clergerie, {\'E}ric Villemonte and Seddah, Djam{\'e} and Sagot, Beno{\^\i}t},
booktitle={Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics},
year={2020}
}
```
# (Vectorized) Lexically constrained decoding with dynamic beam allocation
This page provides instructions for how to use lexically constrained decoding in Fairseq.
Fairseq implements the code described in the following papers:
* [Fast Lexically Constrained Decoding With Dynamic Beam Allocation](https://www.aclweb.org/anthology/N18-1119/) (Post & Vilar, 2018)
* [Improved Lexically Constrained Decoding for Translation and Monolingual Rewriting](https://www.aclweb.org/anthology/N19-1090/) (Hu et al., 2019)
## Quick start
Constrained search is enabled by adding the command-line argument `--constraints` to `fairseq-interactive`.
Constraints are appended to each line of input, separated by tabs. Each constraint (one or more tokens)
is a separate field.
The following command, using [Fairseq's WMT19 German--English model](https://github.com/pytorch/fairseq/blob/master/examples/wmt19/README.md),
translates the sentence *Die maschinelle Übersetzung ist schwer zu kontrollieren.* with the constraints
"hard" and "to influence".
echo -e "Die maschinelle Übersetzung ist schwer zu kontrollieren.\thard\ttoinfluence" \
| normalize.py | tok.py \
| fairseq-interactive /path/to/model \
--path /path/to/model/model1.pt \
--bpe fastbpe \
--bpe-codes /path/to/model/bpecodes \
--constraints \
-s de -t en \
--beam 10
(tok.py and normalize.py can be found in the same directory as this README; they are just shortcuts around Fairseq's WMT19 preprocessing).
This will generate the following output:
[snip]
S-0 Die masch@@ in@@ elle Über@@ setzung ist schwer zu kontrollieren .
W-0 1.844 seconds
C-0 hard
C-0 influence
H-0 -1.5333266258239746 Mach@@ ine trans@@ lation is hard to influence .
D-0 -1.5333266258239746 Machine translation is hard to influence .
P-0 -0.5434 -0.1423 -0.1930 -0.1415 -0.2346 -1.8031 -0.1701 -11.7727 -0.1815 -0.1511
By default, constraints are generated in the order supplied, with any number (zero or more) of tokens generated
between constraints. If you wish for the decoder to order the constraints, then use `--constraints unordered`.
Note that you may want to use a larger beam.
## Implementation details
The heart of the implementation is in `fairseq/search.py`, which adds a `LexicallyConstrainedBeamSearch` instance.
This instance of beam search tracks the progress of each hypothesis in the beam through the set of constraints
provided for each input sentence. It does this using one of two classes, both found in `fairseq/token_generation_contstraints.py`:
* OrderedConstraintState: assumes the `C` input constraints will be generated in the provided order
* UnorderedConstraintState: tries to apply `C` (phrasal) constraints in all `C!` orders
## Differences from Sockeye
There are a number of [differences from Sockeye's implementation](https://awslabs.github.io/sockeye/inference.html#lexical-constraints).
* Generating constraints in the order supplied (the default option here) is not available in Sockeye.
* Due to an improved beam allocation method, there is no need to prune the beam.
* Again due to better allocation, beam sizes as low as 10 or even 5 are often sufficient.
* [The vector extensions described in Hu et al.](https://github.com/edwardjhu/sockeye/tree/trie_constraints) (NAACL 2019) were never merged
into the main Sockeye branch.
## Citation
The paper first describing lexical constraints for seq2seq decoding is:
```bibtex
@inproceedings{hokamp-liu-2017-lexically,
title = "Lexically Constrained Decoding for Sequence Generation Using Grid Beam Search",
author = "Hokamp, Chris and
Liu, Qun",
booktitle = "Proceedings of the 55th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)",
month = jul,
year = "2017",
address = "Vancouver, Canada",
publisher = "Association for Computational Linguistics",
url = "https://www.aclweb.org/anthology/P17-1141",
doi = "10.18653/v1/P17-1141",
pages = "1535--1546",
}
```
The fairseq implementation uses the extensions described in
```bibtex
@inproceedings{post-vilar-2018-fast,
title = "Fast Lexically Constrained Decoding with Dynamic Beam Allocation for Neural Machine Translation",
author = "Post, Matt and
Vilar, David",
booktitle = "Proceedings of the 2018 Conference of the North {A}merican Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long Papers)",
month = jun,
year = "2018",
address = "New Orleans, Louisiana",
publisher = "Association for Computational Linguistics",
url = "https://www.aclweb.org/anthology/N18-1119",
doi = "10.18653/v1/N18-1119",
pages = "1314--1324",
}
```
and
```bibtex
@inproceedings{hu-etal-2019-improved,
title = "Improved Lexically Constrained Decoding for Translation and Monolingual Rewriting",
author = "Hu, J. Edward and
Khayrallah, Huda and
Culkin, Ryan and
Xia, Patrick and
Chen, Tongfei and
Post, Matt and
Van Durme, Benjamin",
booktitle = "Proceedings of the 2019 Conference of the North {A}merican Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers)",
month = jun,
year = "2019",
address = "Minneapolis, Minnesota",
publisher = "Association for Computational Linguistics",
url = "https://www.aclweb.org/anthology/N19-1090",
doi = "10.18653/v1/N19-1090",
pages = "839--850",
}
```
#!/usr/bin/env python3
#
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import sys
from sacremoses.normalize import MosesPunctNormalizer
def main(args):
normalizer = MosesPunctNormalizer(lang=args.lang, penn=args.penn)
for line in sys.stdin:
print(normalizer.normalize(line.rstrip()), flush=True)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--lang", "-l", default="en")
parser.add_argument("--penn", "-p", action="store_true")
args = parser.parse_args()
main(args)
#!/usr/bin/env python3
#
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import sys
import sacremoses
def main(args):
"""Tokenizes, preserving tabs"""
mt = sacremoses.MosesTokenizer(lang=args.lang)
def tok(s):
return mt.tokenize(s, return_str=True)
for line in sys.stdin:
parts = list(map(tok, line.split("\t")))
print(*parts, sep="\t", flush=True)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--lang", "-l", default="en")
parser.add_argument("--penn", "-p", action="store_true")
parser.add_argument("--fields", "-f", help="fields to tokenize")
args = parser.parse_args()
main(args)
# Convolutional Sequence to Sequence Learning (Gehring et al., 2017)
## Pre-trained models
Description | Dataset | Model | Test set(s)
---|---|---|---
Convolutional <br> ([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2) | newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.v2.en-fr.newstest2014.tar.bz2) <br> newstest2012/2013: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.v2.en-fr.ntst1213.tar.bz2)
Convolutional <br> ([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT14 English-German](http://statmt.org/wmt14/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-de.fconv-py.tar.bz2) | newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.en-de.newstest2014.tar.bz2)
Convolutional <br> ([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT17 English-German](http://statmt.org/wmt17/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt17.v2.en-de.fconv-py.tar.bz2) | newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt17.v2.en-de.newstest2014.tar.bz2)
## Example usage
See the [translation README](../translation/README.md) for instructions on reproducing results for WMT'14 En-De and
WMT'14 En-Fr using the `fconv_wmt_en_de` and `fconv_wmt_en_fr` model architectures.
## Citation
```bibtex
@inproceedings{gehring2017convs2s,
title = {Convolutional Sequence to Sequence Learning},
author = {Gehring, Jonas, and Auli, Michael and Grangier, David and Yarats, Denis and Dauphin, Yann N},
booktitle = {Proc. of ICML},
year = 2017,
}
```
# Cross-lingual Retrieval for Iterative Self-Supervised Training
https://arxiv.org/pdf/2006.09526.pdf
## Introduction
CRISS is a multilingual sequence-to-sequnce pretraining method where mining and training processes are applied iteratively, improving cross-lingual alignment and translation ability at the same time.
## Requirements:
* faiss: https://github.com/facebookresearch/faiss
* mosesdecoder: https://github.com/moses-smt/mosesdecoder
* flores: https://github.com/facebookresearch/flores
* LASER: https://github.com/facebookresearch/LASER
## Unsupervised Machine Translation
##### 1. Download and decompress CRISS checkpoints
```
cd examples/criss
wget https://dl.fbaipublicfiles.com/criss/criss_3rd_checkpoints.tar.gz
tar -xf criss_checkpoints.tar.gz
```
##### 2. Download and preprocess Flores test dataset
Make sure to run all scripts from examples/criss directory
```
bash download_and_preprocess_flores_test.sh
```
##### 3. Run Evaluation on Sinhala-English
```
bash unsupervised_mt/eval.sh
```
## Sentence Retrieval
##### 1. Download and preprocess Tatoeba dataset
```
bash download_and_preprocess_tatoeba.sh
```
##### 2. Run Sentence Retrieval on Tatoeba Kazakh-English
```
bash sentence_retrieval/sentence_retrieval_tatoeba.sh
```
## Mining
##### 1. Install faiss
Follow instructions on https://github.com/facebookresearch/faiss/blob/master/INSTALL.md
##### 2. Mine pseudo-parallel data between Kazakh and English
```
bash mining/mine_example.sh
```
## Citation
```bibtex
@article{tran2020cross,
title={Cross-lingual retrieval for iterative self-supervised training},
author={Tran, Chau and Tang, Yuqing and Li, Xian and Gu, Jiatao},
journal={arXiv preprint arXiv:2006.09526},
year={2020}
}
```
#!/bin/bash
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
SPM_ENCODE=flores/scripts/spm_encode.py
DATA=data_tmp
SPM_MODEL=criss_checkpoints/sentence.bpe.model
DICT=criss_checkpoints/dict.txt
download_data() {
CORPORA=$1
URL=$2
if [ -f $CORPORA ]; then
echo "$CORPORA already exists, skipping download"
else
echo "Downloading $URL"
wget $URL -O $CORPORA --no-check-certificate || rm -f $CORPORA
if [ -f $CORPORA ]; then
echo "$URL successfully downloaded."
else
echo "$URL not successfully downloaded."
rm -f $CORPORA
fi
fi
}
if [[ -f flores ]]; then
echo "flores already cloned"
else
git clone https://github.com/facebookresearch/flores
fi
mkdir -p $DATA
download_data $DATA/wikipedia_en_ne_si_test_sets.tgz "https://github.com/facebookresearch/flores/raw/master/data/wikipedia_en_ne_si_test_sets.tgz"
pushd $DATA
pwd
tar -vxf wikipedia_en_ne_si_test_sets.tgz
popd
for lang in ne_NP si_LK; do
datadir=$DATA/${lang}-en_XX-flores
rm -rf $datadir
mkdir -p $datadir
TEST_PREFIX=$DATA/wikipedia_en_ne_si_test_sets/wikipedia.test
python $SPM_ENCODE \
--model ${SPM_MODEL} \
--output_format=piece \
--inputs ${TEST_PREFIX}.${lang:0:2}-en.${lang:0:2} ${TEST_PREFIX}.${lang:0:2}-en.en \
--outputs $datadir/test.bpe.${lang}-en_XX.${lang} $datadir/test.bpe.${lang}-en_XX.en_XX
# binarize data
fairseq-preprocess \
--source-lang ${lang} --target-lang en_XX \
--testpref $datadir/test.bpe.${lang}-en_XX \
--destdir $datadir \
--srcdict ${DICT} \
--joined-dictionary \
--workers 4
done
#!/bin/bash
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
SPM_ENCODE=flores/scripts/spm_encode.py
DATA=data_tmp
SPM_MODEL=criss_checkpoints/sentence.bpe.model
DICT=criss_checkpoints/dict.txt
if [[ -f flores ]]; then
echo "flores already cloned"
else
git clone https://github.com/facebookresearch/flores
fi
if [[ -f LASER ]]; then
echo "LASER already cloned"
else
git clone https://github.com/facebookresearch/LASER
fi
mkdir -p data_tmp
declare -A lang_tatoeba_map=( ["ar_AR"]="ara" ["de_DE"]="deu" ["es_XX"]="spa" ["et_EE"]="est" ["fi_FI"]="fin" ["fr_XX"]="fra" ["hi_IN"]="hin" ["it_IT"]="ita" ["ja_XX"]="jpn" ["ko_KR"]="kor" ["kk_KZ"]="kaz" ["nl_XX"]="nld" ["ru_RU"]="rus" ["tr_TR"]="tur" ["vi_VN"]="vie" ["zh_CN"]="cmn")
for lang in ar_AR de_DE es_XX et_EE fi_FI fr_XX hi_IN it_IT ja_XX kk_KZ ko_KR nl_XX ru_RU tr_TR vi_VN zh_CN; do
lang_tatoeba=${lang_tatoeba_map[$lang]}
echo $lang_tatoeba
datadir=$DATA/${lang}-en_XX-tatoeba
rm -rf $datadir
mkdir -p $datadir
TEST_PREFIX=LASER/data/tatoeba/v1/tatoeba
python $SPM_ENCODE \
--model ${SPM_MODEL} \
--output_format=piece \
--inputs ${TEST_PREFIX}.${lang_tatoeba}-eng.${lang_tatoeba} ${TEST_PREFIX}.${lang_tatoeba}-eng.eng \
--outputs $datadir/test.bpe.${lang}-en_XX.${lang} $datadir/test.bpe.${lang}-en_XX.en_XX
# binarize data
fairseq-preprocess \
--source-lang ${lang} --target-lang en_XX \
--testpref $datadir/test.bpe.${lang}-en_XX \
--destdir $datadir \
--srcdict ${DICT} \
--joined-dictionary \
--workers 4
done
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import glob
from subprocess import check_call
try:
import faiss
has_faiss = True
except ImportError:
has_faiss = False
import numpy as np
GB = 1024 * 1024 * 1024
def call(cmd):
print(cmd)
check_call(cmd, shell=True)
def get_batches(directory, lang, prefix="all_avg_pool"):
print(f"Finding in {directory}/{prefix}.{lang}*")
files = glob.glob(f"{directory}/{prefix}.{lang}*")
emb_files = []
txt_files = []
for emb_fi in files:
emb_files.append(emb_fi)
txt_fi = emb_fi.replace(prefix, "sentences")
txt_files.append(txt_fi)
return emb_files, txt_files
def load_batch(emb_file, dim):
embeddings = np.fromfile(emb_file, dtype=np.float32)
num_rows = int(embeddings.shape[0] / dim)
embeddings = embeddings.reshape((num_rows, dim))
faiss.normalize_L2(embeddings)
return embeddings
def knnGPU_sharded(x_batches_f, y_batches_f, dim, k, direction="x2y"):
if not has_faiss:
raise ImportError("Please install Faiss")
sims = []
inds = []
xfrom = 0
xto = 0
for x_batch_f in x_batches_f:
yfrom = 0
yto = 0
x_batch = load_batch(x_batch_f, dim)
xto = xfrom + x_batch.shape[0]
bsims, binds = [], []
for y_batch_f in y_batches_f:
y_batch = load_batch(y_batch_f, dim)
neighbor_size = min(k, y_batch.shape[0])
yto = yfrom + y_batch.shape[0]
print("{}-{} -> {}-{}".format(xfrom, xto, yfrom, yto))
idx = faiss.IndexFlatIP(dim)
idx = faiss.index_cpu_to_all_gpus(idx)
idx.add(y_batch)
bsim, bind = idx.search(x_batch, neighbor_size)
bsims.append(bsim)
binds.append(bind + yfrom)
yfrom += y_batch.shape[0]
del idx
del y_batch
bsims = np.concatenate(bsims, axis=1)
binds = np.concatenate(binds, axis=1)
aux = np.argsort(-bsims, axis=1)
sim_batch = np.zeros((x_batch.shape[0], k), dtype=np.float32)
ind_batch = np.zeros((x_batch.shape[0], k), dtype=np.int64)
for i in range(x_batch.shape[0]):
for j in range(k):
sim_batch[i, j] = bsims[i, aux[i, j]]
ind_batch[i, j] = binds[i, aux[i, j]]
sims.append(sim_batch)
inds.append(ind_batch)
xfrom += x_batch.shape[0]
del x_batch
sim = np.concatenate(sims, axis=0)
ind = np.concatenate(inds, axis=0)
return sim, ind
def score(sim, fwd_mean, bwd_mean, margin):
return margin(sim, (fwd_mean + bwd_mean) / 2)
def score_candidates(
sim_mat, candidate_inds, fwd_mean, bwd_mean, margin, verbose=False
):
print(" - scoring {:d} candidates".format(sim_mat.shape[0]))
scores = np.zeros(candidate_inds.shape)
for i in range(scores.shape[0]):
for j in range(scores.shape[1]):
k = int(candidate_inds[i, j])
scores[i, j] = score(sim_mat[i, j], fwd_mean[i], bwd_mean[k], margin)
return scores
def load_text(files):
all_sentences = []
for fi in files:
with open(fi) as sentence_fi:
for line in sentence_fi:
all_sentences.append(line.strip())
print(f"Read {len(all_sentences)} sentences")
return all_sentences
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Mine bitext")
parser.add_argument("--src-lang", help="Source language")
parser.add_argument("--tgt-lang", help="Target language")
parser.add_argument(
"--dict-path", help="Path to dictionary file", default="dict.txt"
)
parser.add_argument(
"--spm-path", help="Path to SPM model file", default="sentence.bpe.model"
)
parser.add_argument("--dim", type=int, default=1024, help="Embedding dimension")
parser.add_argument("--mem", type=int, default=5, help="Memory in GB")
parser.add_argument("--src-dir", help="Source directory")
parser.add_argument("--tgt-dir", help="Target directory")
parser.add_argument("--output", help="Output path")
parser.add_argument(
"--neighborhood", type=int, default=4, help="Embedding dimension"
)
parser.add_argument(
"--threshold", type=float, default=1.06, help="Threshold on mined bitext"
)
parser.add_argument(
"--valid-size",
type=int,
default=2000,
help="Number of sentences used for validation set",
)
parser.add_argument(
"--min-count",
type=int,
default=50000,
help="Min num sentences used for each language",
)
args = parser.parse_args()
x_batches_f, x_sents_f = get_batches(args.src_dir, args.src_lang)
y_batches_f, y_sents_f = get_batches(args.tgt_dir, args.tgt_lang)
margin = lambda a, b: a / b
y2x_sim, y2x_ind = knnGPU_sharded(
y_batches_f, x_batches_f, args.dim, args.neighborhood, direction="y2x"
)
x2y_sim, x2y_ind = knnGPU_sharded(
x_batches_f, y_batches_f, args.dim, args.neighborhood, direction="x2y"
)
x2y_mean = x2y_sim.mean(axis=1)
y2x_mean = y2x_sim.mean(axis=1)
fwd_scores = score_candidates(x2y_sim, x2y_ind, x2y_mean, y2x_mean, margin)
bwd_scores = score_candidates(y2x_sim, y2x_ind, y2x_mean, x2y_mean, margin)
fwd_best = x2y_ind[np.arange(x2y_sim.shape[0]), fwd_scores.argmax(axis=1)]
bwd_best = y2x_ind[np.arange(y2x_sim.shape[0]), bwd_scores.argmax(axis=1)]
indices = np.stack(
(
np.concatenate((np.arange(x2y_ind.shape[0]), bwd_best)),
np.concatenate((fwd_best, np.arange(y2x_ind.shape[0]))),
),
axis=1,
)
scores = np.concatenate((fwd_scores.max(axis=1), bwd_scores.max(axis=1)))
x_sentences = load_text(x_sents_f)
y_sentences = load_text(y_sents_f)
threshold = args.threshold
min_count = args.min_count
seen_src, seen_trg = set(), set()
directory = args.output
call(f"mkdir -p {directory}")
src_out = open(
f"{directory}/all.{args.src_lang}",
mode="w",
encoding="utf-8",
errors="surrogateescape",
)
tgt_out = open(
f"{directory}/all.{args.tgt_lang}",
mode="w",
encoding="utf-8",
errors="surrogateescape",
)
scores_out = open(
f"{directory}/all.scores", mode="w", encoding="utf-8", errors="surrogateescape"
)
count = 0
for i in np.argsort(-scores):
src_ind, trg_ind = indices[i]
if src_ind not in seen_src and trg_ind not in seen_trg:
seen_src.add(src_ind)
seen_trg.add(trg_ind)
if scores[i] > threshold or count < min_count:
if x_sentences[src_ind]:
print(scores[i], file=scores_out)
print(x_sentences[src_ind], file=src_out)
print(y_sentences[trg_ind], file=tgt_out)
count += 1
else:
print(f"Ignoring sentence: {x_sentences[src_ind]}")
src_out.close()
tgt_out.close()
scores_out.close()
print(f"Found {count} pairs for threshold={threshold}")
with open(f"{directory}/all.{args.src_lang}") as all_s, open(
f"{directory}/all.{args.tgt_lang}"
) as all_t, open(f"{directory}/valid.{args.src_lang}", "w") as valid_s, open(
f"{directory}/valid.{args.tgt_lang}", "w"
) as valid_t, open(
f"{directory}/train.{args.src_lang}", "w"
) as train_s, open(
f"{directory}/train.{args.tgt_lang}", "w"
) as train_t:
count = 0
for s_line, t_line in zip(all_s, all_t):
s_line = s_line.split("\t")[1]
t_line = t_line.split("\t")[1]
if count >= args.valid_size:
train_s.write(s_line)
train_t.write(t_line)
else:
valid_s.write(s_line)
valid_t.write(t_line)
count += 1
#!/bin/bash
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
source_lang=kk_KZ
target_lang=en_XX
MODEL=criss_checkpoints/criss.3rd.pt
SPM=criss_checkpoints/sentence.bpe.model
SPLIT=test
LANG_DICT=criss_checkpoints/lang_dict.txt
SPM_ENCODE=flores/scripts/spm_encode.py
SAVE_ENCODER=save_encoder.py
ENCODER_SAVE_ROOT=sentence_embeddings/$MODEL
DICT=criss_checkpoints/dict.txt
THRESHOLD=1.02
MIN_COUNT=500
DATA_DIR=data_tmp
SAVE_DIR=mining/${source_lang}_${target_lang}_mined
ENCODER_SAVE_DIR=${ENCODER_SAVE_ROOT}/${source_lang}-${target_lang}
INPUT_DIR=$DATA_DIR/${source_lang}-${target_lang}-tatoeba
mkdir -p $ENCODER_SAVE_DIR/${target_lang}
mkdir -p $ENCODER_SAVE_DIR/${source_lang}
mkdir -p $SAVE_DIR
## Save encoder outputs
# Save encoder outputs for source sentences
python $SAVE_ENCODER \
${INPUT_DIR} \
--path ${MODEL} \
--task translation_multi_simple_epoch \
--lang-pairs ${source_lang}-${target_lang} \
--lang-dict ${LANG_DICT} \
--gen-subset ${SPLIT} \
--bpe 'sentencepiece' \
-s ${source_lang} -t ${target_lang} \
--sentencepiece-model ${SPM} \
--remove-bpe 'sentencepiece' \
--beam 1 \
--lang-tok-style mbart \
--encoder-save-dir ${ENCODER_SAVE_DIR}/${source_lang}
## Save encoder outputs for target sentences
python $SAVE_ENCODER \
${INPUT_DIR} \
--path ${MODEL} \
--lang-pairs ${source_lang}-${target_lang} \
--lang-dict ${LANG_DICT} \
--task translation_multi_simple_epoch \
--gen-subset ${SPLIT} \
--bpe 'sentencepiece' \
-t ${source_lang} -s ${target_lang} \
--sentencepiece-model ${SPM} \
--remove-bpe 'sentencepiece' \
--beam 1 \
--lang-tok-style mbart \
--encoder-save-dir ${ENCODER_SAVE_DIR}/${target_lang}
## Mining
python mining/mine.py \
--src-lang ${source_lang} \
--tgt-lang ${target_lang} \
--dim 1024 \
--mem 10 \
--neighborhood 4 \
--src-dir ${ENCODER_SAVE_DIR}/${source_lang} \
--tgt-dir ${ENCODER_SAVE_DIR}/${target_lang} \
--output $SAVE_DIR \
--threshold ${THRESHOLD} \
--min-count ${MIN_COUNT} \
--valid-size 100 \
--dict-path ${DICT} \
--spm-path ${SPM} \
## Process and binarize mined data
python $SPM_ENCODE \
--model ${SPM} \
--output_format=piece \
--inputs mining/${source_lang}_${target_lang}_mined/train.${source_lang} mining/${source_lang}_${target_lang}_mined/train.${target_lang} \
--outputs mining/${source_lang}_${target_lang}_mined/train.bpe.${source_lang} mining/${source_lang}_${target_lang}_mined/train.bpe.${target_lang}
python $SPM_ENCODE \
--model ${SPM} \
--output_format=piece \
--inputs mining/${source_lang}_${target_lang}_mined/valid.${source_lang} mining/${source_lang}_${target_lang}_mined/valid.${target_lang} \
--outputs mining/${source_lang}_${target_lang}_mined/valid.bpe.${source_lang} mining/${source_lang}_${target_lang}_mined/valid.bpe.${target_lang}
fairseq-preprocess \
--source-lang ${source_lang} \
--target-lang ${target_lang} \
--trainpref mining/${source_lang}_${target_lang}_mined/train.bpe \
--validpref mining/${source_lang}_${target_lang}_mined/valid.bpe \
--destdir mining/${source_lang}_${target_lang}_mined \
--srcdict ${DICT} \
--joined-dictionary \
--workers 8
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Translate pre-processed data with a trained model.
"""
import numpy as np
import torch
from fairseq import checkpoint_utils, options, progress_bar, tasks, utils
from fairseq.sequence_generator import EnsembleModel
def get_avg_pool(
models, sample, prefix_tokens, src_dict, remove_bpe, has_langtok=False
):
model = EnsembleModel(models)
# model.forward normally channels prev_output_tokens into the decoder
# separately, but SequenceGenerator directly calls model.encoder
encoder_input = {
k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
}
# compute the encoder output for each beam
encoder_outs = model.forward_encoder(encoder_input)
np_encoder_outs = encoder_outs[0].encoder_out.cpu().numpy().astype(np.float32)
encoder_mask = 1 - encoder_outs[0].encoder_padding_mask.cpu().numpy().astype(
np.float32
)
encoder_mask = np.expand_dims(encoder_mask.T, axis=2)
if has_langtok:
encoder_mask = encoder_mask[1:, :, :]
np_encoder_outs = np_encoder_outs[1, :, :]
masked_encoder_outs = encoder_mask * np_encoder_outs
avg_pool = (masked_encoder_outs / encoder_mask.sum(axis=0)).sum(axis=0)
return avg_pool
def main(args):
assert args.path is not None, "--path required for generation!"
assert (
not args.sampling or args.nbest == args.beam
), "--sampling requires --nbest to be equal to --beam"
assert (
args.replace_unk is None or args.raw_text
), "--replace-unk requires a raw text dataset (--raw-text)"
args.beam = 1
utils.import_user_module(args)
if args.max_tokens is None:
args.max_tokens = 12000
print(args)
use_cuda = torch.cuda.is_available() and not args.cpu
# Load dataset splits
task = tasks.setup_task(args)
task.load_dataset(args.gen_subset)
# Set dictionaries
try:
src_dict = getattr(task, "source_dictionary", None)
except NotImplementedError:
src_dict = None
tgt_dict = task.target_dictionary
# Load ensemble
print("| loading model(s) from {}".format(args.path))
models, _model_args = checkpoint_utils.load_model_ensemble(
args.path.split(":"),
arg_overrides=eval(args.model_overrides),
task=task,
)
# Optimize ensemble for generation
for model in models:
model.make_generation_fast_(
beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
need_attn=args.print_alignment,
)
if args.fp16:
model.half()
if use_cuda:
model.cuda()
# Load alignment dictionary for unknown word replacement
# (None if no unknown word replacement, empty if no path to align dictionary)
align_dict = utils.load_align_dict(args.replace_unk)
# Load dataset (possibly sharded)
itr = task.get_batch_iterator(
dataset=task.dataset(args.gen_subset),
max_tokens=args.max_tokens,
max_positions=utils.resolve_max_positions(
task.max_positions(),
),
ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
required_batch_size_multiple=args.required_batch_size_multiple,
num_shards=args.num_shards,
shard_id=args.shard_id,
num_workers=args.num_workers,
).next_epoch_itr(shuffle=False)
num_sentences = 0
source_sentences = []
shard_id = 0
all_avg_pool = None
encoder_has_langtok = (
hasattr(task.args, "encoder_langtok")
and task.args.encoder_langtok is not None
and hasattr(task.args, "lang_tok_replacing_bos_eos")
and not task.args.lang_tok_replacing_bos_eos
)
with progress_bar.build_progress_bar(args, itr) as t:
for sample in t:
if sample is None:
print("Skipping None")
continue
sample = utils.move_to_cuda(sample) if use_cuda else sample
if "net_input" not in sample:
continue
prefix_tokens = None
if args.prefix_size > 0:
prefix_tokens = sample["target"][:, : args.prefix_size]
with torch.no_grad():
avg_pool = get_avg_pool(
models,
sample,
prefix_tokens,
src_dict,
args.post_process,
has_langtok=encoder_has_langtok,
)
if all_avg_pool is not None:
all_avg_pool = np.concatenate((all_avg_pool, avg_pool))
else:
all_avg_pool = avg_pool
if not isinstance(sample["id"], list):
sample_ids = sample["id"].tolist()
else:
sample_ids = sample["id"]
for i, sample_id in enumerate(sample_ids):
# Remove padding
src_tokens = utils.strip_pad(
sample["net_input"]["src_tokens"][i, :], tgt_dict.pad()
)
# Either retrieve the original sentences or regenerate them from tokens.
if align_dict is not None:
src_str = task.dataset(args.gen_subset).src.get_original_text(
sample_id
)
else:
if src_dict is not None:
src_str = src_dict.string(src_tokens, args.post_process)
else:
src_str = ""
if not args.quiet:
if src_dict is not None:
print("S-{}\t{}".format(sample_id, src_str))
source_sentences.append(f"{sample_id}\t{src_str}")
num_sentences += sample["nsentences"]
if all_avg_pool.shape[0] >= 1000000:
with open(
f"{args.encoder_save_dir}/all_avg_pool.{args.source_lang}.{shard_id}",
"w",
) as avg_pool_file:
all_avg_pool.tofile(avg_pool_file)
with open(
f"{args.encoder_save_dir}/sentences.{args.source_lang}.{shard_id}",
"w",
) as sentence_file:
sentence_file.writelines(f"{line}\n" for line in source_sentences)
all_avg_pool = None
source_sentences = []
shard_id += 1
if all_avg_pool is not None:
with open(
f"{args.encoder_save_dir}/all_avg_pool.{args.source_lang}.{shard_id}", "w"
) as avg_pool_file:
all_avg_pool.tofile(avg_pool_file)
with open(
f"{args.encoder_save_dir}/sentences.{args.source_lang}.{shard_id}", "w"
) as sentence_file:
sentence_file.writelines(f"{line}\n" for line in source_sentences)
return None
def cli_main():
parser = options.get_generation_parser()
parser.add_argument(
"--encoder-save-dir",
default="",
type=str,
metavar="N",
help="directory to save encoder outputs",
)
args = options.parse_args_and_arch(parser)
main(args)
if __name__ == "__main__":
cli_main()
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import glob
import numpy as np
DIM = 1024
def compute_dist(source_embs, target_embs, k=5, return_sim_mat=False):
target_ids = [tid for tid in target_embs]
source_mat = np.stack(source_embs.values(), axis=0)
normalized_source_mat = source_mat / np.linalg.norm(
source_mat, axis=1, keepdims=True
)
target_mat = np.stack(target_embs.values(), axis=0)
normalized_target_mat = target_mat / np.linalg.norm(
target_mat, axis=1, keepdims=True
)
sim_mat = normalized_source_mat.dot(normalized_target_mat.T)
if return_sim_mat:
return sim_mat
neighbors_map = {}
for i, sentence_id in enumerate(source_embs):
idx = np.argsort(sim_mat[i, :])[::-1][:k]
neighbors_map[sentence_id] = [target_ids[tid] for tid in idx]
return neighbors_map
def load_embeddings(directory, LANGS):
sentence_embeddings = {}
sentence_texts = {}
for lang in LANGS:
sentence_embeddings[lang] = {}
sentence_texts[lang] = {}
lang_dir = f"{directory}/{lang}"
embedding_files = glob.glob(f"{lang_dir}/all_avg_pool.{lang}.*")
for embed_file in embedding_files:
shard_id = embed_file.split(".")[-1]
embeddings = np.fromfile(embed_file, dtype=np.float32)
num_rows = embeddings.shape[0] // DIM
embeddings = embeddings.reshape((num_rows, DIM))
with open(f"{lang_dir}/sentences.{lang}.{shard_id}") as sentence_file:
for idx, line in enumerate(sentence_file):
sentence_id, sentence = line.strip().split("\t")
sentence_texts[lang][sentence_id] = sentence
sentence_embeddings[lang][sentence_id] = embeddings[idx, :]
return sentence_embeddings, sentence_texts
def compute_accuracy(directory, LANGS):
sentence_embeddings, sentence_texts = load_embeddings(directory, LANGS)
top_1_accuracy = {}
top1_str = " ".join(LANGS) + "\n"
for source_lang in LANGS:
top_1_accuracy[source_lang] = {}
top1_str += f"{source_lang} "
for target_lang in LANGS:
top1 = 0
top5 = 0
neighbors_map = compute_dist(
sentence_embeddings[source_lang], sentence_embeddings[target_lang]
)
for sentence_id, neighbors in neighbors_map.items():
if sentence_id == neighbors[0]:
top1 += 1
if sentence_id in neighbors[:5]:
top5 += 1
n = len(sentence_embeddings[target_lang])
top1_str += f"{top1/n} "
top1_str += "\n"
print(top1_str)
print(top1_str, file=open(f"{directory}/accuracy", "w"))
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Analyze encoder outputs")
parser.add_argument("directory", help="Source language corpus")
parser.add_argument("--langs", help="List of langs")
args = parser.parse_args()
langs = args.langs.split(",")
compute_accuracy(args.directory, langs)
#!/bin/bash
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
source_lang=kk_KZ
target_lang=en_XX
MODEL=criss_checkpoints/criss.3rd.pt
SPM=criss_checkpoints/sentence.bpe.model
SPLIT=test
LANG_DICT=criss_checkpoints/lang_dict.txt
ENCODER_ANALYSIS=sentence_retrieval/encoder_analysis.py
SAVE_ENCODER=save_encoder.py
ENCODER_SAVE_ROOT=sentence_embeddings/$MODEL
DATA_DIR=data_tmp
INPUT_DIR=$DATA_DIR/${source_lang}-${target_lang}-tatoeba
ENCODER_SAVE_DIR=${ENCODER_SAVE_ROOT}/${source_lang}-${target_lang}
mkdir -p $ENCODER_SAVE_DIR/${target_lang}
mkdir -p $ENCODER_SAVE_DIR/${source_lang}
# Save encoder outputs for source sentences
python $SAVE_ENCODER \
${INPUT_DIR} \
--path ${MODEL} \
--task translation_multi_simple_epoch \
--lang-dict ${LANG_DICT} \
--gen-subset ${SPLIT} \
--bpe 'sentencepiece' \
--lang-pairs ${source_lang}-${target_lang} \
-s ${source_lang} -t ${target_lang} \
--sentencepiece-model ${SPM} \
--remove-bpe 'sentencepiece' \
--beam 1 \
--lang-tok-style mbart \
--encoder-save-dir ${ENCODER_SAVE_DIR}/${source_lang}
# Save encoder outputs for target sentences
python $SAVE_ENCODER \
${INPUT_DIR} \
--path ${MODEL} \
--lang-dict ${LANG_DICT} \
--task translation_multi_simple_epoch \
--gen-subset ${SPLIT} \
--bpe 'sentencepiece' \
--lang-pairs ${target_lang}-${source_lang} \
-t ${source_lang} -s ${target_lang} \
--sentencepiece-model ${SPM} \
--remove-bpe 'sentencepiece' \
--beam 1 \
--lang-tok-style mbart \
--encoder-save-dir ${ENCODER_SAVE_DIR}/${target_lang}
# Analyze sentence retrieval accuracy
python $ENCODER_ANALYSIS --langs "${source_lang},${target_lang}" ${ENCODER_SAVE_DIR}
#!/bin/bash
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
SRC=si_LK
TGT=en_XX
MODEL=criss_checkpoints/criss.3rd.pt
MULTIBLEU=mosesdecoder/scripts/generic/multi-bleu.perl
MOSES=mosesdecoder
REPLACE_UNICODE_PUNCT=$MOSES/scripts/tokenizer/replace-unicode-punctuation.perl
NORM_PUNC=$MOSES/scripts/tokenizer/normalize-punctuation.perl
REM_NON_PRINT_CHAR=$MOSES/scripts/tokenizer/remove-non-printing-char.perl
TOKENIZER=$MOSES/scripts/tokenizer/tokenizer.perl
GEN_TMP_DIR=gen_tmp
LANG_DICT=criss_checkpoints/lang_dict.txt
if [ ! -d "mosesdecoder" ]; then
git clone https://github.com/moses-smt/mosesdecoder
fi
mkdir -p $GEN_TMP_DIR
fairseq-generate data_tmp/${SRC}-${TGT}-flores \
--task translation_multi_simple_epoch \
--max-tokens 2000 \
--path ${MODEL} \
--skip-invalid-size-inputs-valid-test \
--beam 5 --lenpen 1.0 --gen-subset test \
--remove-bpe=sentencepiece \
--source-lang ${SRC} --target-lang ${TGT} \
--decoder-langtok --lang-pairs 'en_XX-ar_AR,en_XX-de_DE,en_XX-es_XX,en_XX-fr_XX,en_XX-hi_IN,en_XX-it_IT,en_XX-ja_XX,en_XX-ko_KR,en_XX-nl_XX,en_XX-ru_RU,en_XX-zh_CN,en_XX-tr_TR,en_XX-vi_VN,en_XX-ro_RO,en_XX-my_MM,en_XX-ne_NP,en_XX-si_LK,en_XX-cs_CZ,en_XX-lt_LT,en_XX-kk_KZ,en_XX-gu_IN,en_XX-fi_FI,en_XX-et_EE,en_XX-lv_LV,ar_AR-en_XX,cs_CZ-en_XX,de_DE-en_XX,es_XX-en_XX,et_EE-en_XX,fi_FI-en_XX,fr_XX-en_XX,gu_IN-en_XX,hi_IN-en_XX,it_IT-en_XX,ja_XX-en_XX,kk_KZ-en_XX,ko_KR-en_XX,lt_LT-en_XX,lv_LV-en_XX,my_MM-en_XX,ne_NP-en_XX,nl_XX-en_XX,ro_RO-en_XX,ru_RU-en_XX,si_LK-en_XX,tr_TR-en_XX,vi_VN-en_XX,zh_CN-en_XX,ar_AR-es_XX,es_XX-ar_AR,ar_AR-hi_IN,hi_IN-ar_AR,ar_AR-zh_CN,zh_CN-ar_AR,cs_CZ-es_XX,es_XX-cs_CZ,cs_CZ-hi_IN,hi_IN-cs_CZ,cs_CZ-zh_CN,zh_CN-cs_CZ,de_DE-es_XX,es_XX-de_DE,de_DE-hi_IN,hi_IN-de_DE,de_DE-zh_CN,zh_CN-de_DE,es_XX-hi_IN,hi_IN-es_XX,es_XX-zh_CN,zh_CN-es_XX,et_EE-es_XX,es_XX-et_EE,et_EE-hi_IN,hi_IN-et_EE,et_EE-zh_CN,zh_CN-et_EE,fi_FI-es_XX,es_XX-fi_FI,fi_FI-hi_IN,hi_IN-fi_FI,fi_FI-zh_CN,zh_CN-fi_FI,fr_XX-es_XX,es_XX-fr_XX,fr_XX-hi_IN,hi_IN-fr_XX,fr_XX-zh_CN,zh_CN-fr_XX,gu_IN-es_XX,es_XX-gu_IN,gu_IN-hi_IN,hi_IN-gu_IN,gu_IN-zh_CN,zh_CN-gu_IN,hi_IN-zh_CN,zh_CN-hi_IN,it_IT-es_XX,es_XX-it_IT,it_IT-hi_IN,hi_IN-it_IT,it_IT-zh_CN,zh_CN-it_IT,ja_XX-es_XX,es_XX-ja_XX,ja_XX-hi_IN,hi_IN-ja_XX,ja_XX-zh_CN,zh_CN-ja_XX,kk_KZ-es_XX,es_XX-kk_KZ,kk_KZ-hi_IN,hi_IN-kk_KZ,kk_KZ-zh_CN,zh_CN-kk_KZ,ko_KR-es_XX,es_XX-ko_KR,ko_KR-hi_IN,hi_IN-ko_KR,ko_KR-zh_CN,zh_CN-ko_KR,lt_LT-es_XX,es_XX-lt_LT,lt_LT-hi_IN,hi_IN-lt_LT,lt_LT-zh_CN,zh_CN-lt_LT,lv_LV-es_XX,es_XX-lv_LV,lv_LV-hi_IN,hi_IN-lv_LV,lv_LV-zh_CN,zh_CN-lv_LV,my_MM-es_XX,es_XX-my_MM,my_MM-hi_IN,hi_IN-my_MM,my_MM-zh_CN,zh_CN-my_MM,ne_NP-es_XX,es_XX-ne_NP,ne_NP-hi_IN,hi_IN-ne_NP,ne_NP-zh_CN,zh_CN-ne_NP,nl_XX-es_XX,es_XX-nl_XX,nl_XX-hi_IN,hi_IN-nl_XX,nl_XX-zh_CN,zh_CN-nl_XX,ro_RO-es_XX,es_XX-ro_RO,ro_RO-hi_IN,hi_IN-ro_RO,ro_RO-zh_CN,zh_CN-ro_RO,ru_RU-es_XX,es_XX-ru_RU,ru_RU-hi_IN,hi_IN-ru_RU,ru_RU-zh_CN,zh_CN-ru_RU,si_LK-es_XX,es_XX-si_LK,si_LK-hi_IN,hi_IN-si_LK,si_LK-zh_CN,zh_CN-si_LK,tr_TR-es_XX,es_XX-tr_TR,tr_TR-hi_IN,hi_IN-tr_TR,tr_TR-zh_CN,zh_CN-tr_TR,vi_VN-es_XX,es_XX-vi_VN,vi_VN-hi_IN,hi_IN-vi_VN,vi_VN-zh_CN,zh_CN-vi_VN' \
--lang-dict ${LANG_DICT} --lang-tok-style 'mbart' --sampling-method 'temperature' --sampling-temperature '1.0' > $GEN_TMP_DIR/${SRC}_${TGT}.gen
cat $GEN_TMP_DIR/${SRC}_${TGT}.gen | grep -P "^T-" | cut -f2 | $REPLACE_UNICODE_PUNCT | $NORM_PUNC -l ${TGT:0:2} | $REM_NON_PRINT_CHAR | $TOKENIZER -no-escape ${TGT:0:2} > $GEN_TMP_DIR/${SRC}_${TGT}.hyp
cat $GEN_TMP_DIR/${SRC}_${TGT}.gen | grep -P "^H-" | cut -f3 | $REPLACE_UNICODE_PUNCT | $NORM_PUNC -l ${TGT:0:2} | $REM_NON_PRINT_CHAR | $TOKENIZER -no-escape ${TGT:0:2} > $GEN_TMP_DIR/${SRC}_${TGT}.ref
${MULTIBLEU} $GEN_TMP_DIR/${SRC}_${TGT}.ref < $GEN_TMP_DIR/${SRC}_${TGT}.hyp
# Cross-Lingual Language Model Pre-training
Below are some details for training Cross-Lingual Language Models (XLM) - similar to the ones presented in [Lample & Conneau, 2019](https://arxiv.org/pdf/1901.07291.pdf) - in Fairseq. The current implementation only supports the Masked Language Model (MLM) from the paper above.
## Downloading and Tokenizing Monolingual Data
Pointers to the monolingual data from wikipedia, used for training the XLM-style MLM model as well as details on processing (tokenization and BPE) it can be found in the [XLM Github Repository](https://github.com/facebookresearch/XLM#download--preprocess-monolingual-data).
Let's assume the following for the code snippets in later sections to work
- Processed data is in the folder: monolingual_data/processed
- Each language has 3 files for train, test and validation. For example we have the following files for English:
train.en, valid.en
- We are training a model for 5 languages: Arabic (ar), German (de), English (en), Hindi (hi) and French (fr)
- The vocabulary file is monolingual_data/processed/vocab_mlm
## Fairseq Pre-processing and Binarization
Pre-process and binarize the data with the MaskedLMDictionary and cross_lingual_lm task
```bash
# Ensure the output directory exists
DATA_DIR=monolingual_data/fairseq_processed
mkdir -p "$DATA_DIR"
for lg in ar de en hi fr
do
fairseq-preprocess \
--task cross_lingual_lm \
--srcdict monolingual_data/processed/vocab_mlm \
--only-source \
--trainpref monolingual_data/processed/train \
--validpref monolingual_data/processed/valid \
--testpref monolingual_data/processed/test \
--destdir monolingual_data/fairseq_processed \
--workers 20 \
--source-lang $lg
# Since we only have a source language, the output file has a None for the
# target language. Remove this
for stage in train test valid
sudo mv "$DATA_DIR/$stage.$lg-None.$lg.bin" "$stage.$lg.bin"
sudo mv "$DATA_DIR/$stage.$lg-None.$lg.idx" "$stage.$lg.idx"
done
done
```
## Train a Cross-lingual Language Model similar to the XLM MLM model
Use the following command to train the model on 5 languages.
```
fairseq-train \
--task cross_lingual_lm monolingual_data/fairseq_processed \
--save-dir checkpoints/mlm \
--max-update 2400000 --save-interval 1 --no-epoch-checkpoints \
--arch xlm_base \
--optimizer adam --lr-scheduler reduce_lr_on_plateau \
--lr-shrink 0.5 --lr 0.0001 --stop-min-lr 1e-09 \
--dropout 0.1 \
--criterion legacy_masked_lm_loss \
--max-tokens 2048 --tokens-per-sample 256 --attention-dropout 0.1 \
--dataset-impl lazy --seed 0 \
--masked-lm-only \
--monolingual-langs 'ar,de,en,hi,fr' --num-segment 5 \
--ddp-backend=legacy_ddp
```
Some Notes:
- Using tokens_per_sample greater than 256 can cause OOM (out-of-memory) issues. Usually since MLM packs in streams of text, this parameter doesn't need much tuning.
- The Evaluation workflow for computing MLM Perplexity on test data is in progress.
- Finetuning this model on a downstream task is something which is not currently available.
# Language Models not just for Pre-training: Fast Online Neural Noisy Channel Modeling
## Introduction
- [Yee et al. (2019)](https://www.aclweb.org/anthology/D19-1571.pdf) introduce a simple and effective noisy channel modeling approach for neural machine translation. However, the noisy channel online decoding approach introduced in this paper is too slow to be practical.
- To address this, [Bhosale et al. (2020)](http://www.statmt.org/wmt20/pdf/2020.wmt-1.68.pdf) introduces 3 simple approximations to make this approach very fast and practical without much loss in accuracy.
- This README provides intructions on how to run online decoding or generation with the noisy channel modeling approach, including ways to make it very fast without much loss in accuracy.
## Noisy Channel Modeling
[Yee et al. (2019)](https://www.aclweb.org/anthology/D19-1571.pdf) applies the Bayes Rule to predict `P(y|x)`, the probability of the target `y` given the source `x`.
```P(y|x) = P(x|y) * P(y) / P(x)```
- `P(x|y)` predicts the source `x` given the target `y` and is referred to as the **channel model**
- `P(y)` is a **language model** over the target `y`
- `P(x)` is generally not modeled since it is constant for all `y`.
We use Transformer models to parameterize the direct model `P(y|x)`, the channel model `P(x|y)` and the language model `P(y)`.
During online decoding with beam search, we generate the top `K2` candidates per beam and score them with the following linear combination of the channel model, the language model as well as the direct model scores.
```(1 / t) * log(P(y|x) + (1 / s) * ( λ1 * log(P(x|y)) + λ2 * log(P(y) ) )```
- `t` - Target Prefix Length
- `s` - Source Length
- `λ1` - Channel Model Weight
- `λ2` - Language Model Weight
The top `beam_size` candidates based on the above combined scores are chosen to continue the beams in beam search. In beam search with a direct model alone, the scores from the direct model `P(y|x)` are used to choose the top candidates in beam search.
This framework provides a great way to utlize strong target language models trained on large amounts of unlabeled data. Language models can prefer targets unrelated to the source, so we also need a channel model whose role is to ensure that the target preferred by the language model also translates back to the source.
### Training Translation Models and Language Models
For training Transformer models in fairseq for machine translation, refer to instructions [here](https://github.com/pytorch/fairseq/tree/master/examples/translation)
For training Transformer models in fairseq for language modeling, refer to instructions [here](https://github.com/pytorch/fairseq/tree/master/examples/language_model)
### Generation with Language Model for German-English translation with fairseq
Here are instructions to generate using a direct model and a target-side language model.
Note:
- Download and install fairseq as per instructions [here](https://github.com/pytorch/fairseq)
- Preprocess and binarize the dataset as per instructions in section [Test Data Preprocessing](#test-data-preprocessing)
```sh
binarized_data=data_dir/binarized
direct_model=de_en_seed4.pt
lm_model=en_lm.pt
lm_data=lm_data
wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/direct_models/seed4.pt -O ${direct_model}
wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/transformer_lm.pt -O ${lm_model}
mkdir -p ${lm_data}
wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/lm_dict/dict.txt -O ${lm_data}/dict.txt
k2=10
lenpen=0.16
lm_wt=0.14
fairseq-generate ${binarized_data} \
--user-dir examples/fast_noisy_channel \
--beam 5 \
--path ${direct_model} \
--lm-model ${lm_model} \
--lm-data ${lm_data} \
--k2 ${k2} \
--combine-method lm_only \
--task noisy_channel_translation \
--lenpen ${lenpen} \
--lm-wt ${lm_wt} \
--gen-subset valid \
--remove-bpe \
--fp16 \
--batch-size 10
```
### Noisy Channel Generation for German-English translation with fairseq
Here are instructions for noisy channel generation with a direct model, channel model and language model as explained in section [Noisy Channel Modeling](#noisy-channel-modeling).
Note:
- Download and install fairseq as per instructions [here](https://github.com/pytorch/fairseq)
- Preprocess and binarize the dataset as per instructions in section [Test Data Preprocessing](#test-data-preprocessing)
```sh
binarized_data=data_dir/binarized
direct_model=de_en_seed4.pt
lm_model=en_lm.pt
lm_data=lm_data
ch_model=en_de.big.seed4.pt
wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/direct_models/seed4.pt -O ${direct_model}
wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/transformer_lm.pt -O ${lm_model}
mkdir -p ${lm_data}
wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/lm_dict/dict.txt -O ${lm_data}/dict.txt
wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/big.seed4.pt -O ${ch_model}
k2=10
lenpen=0.21
lm_wt=0.50
bw_wt=0.30
fairseq-generate ${binarized_data} \
--user-dir examples/fast_noisy_channel \
--beam 5 \
--path ${direct_model} \
--lm-model ${lm_model} \
--lm-data ${lm_data} \
--channel-model ${ch_model} \
--k2 ${k2} \
--combine-method noisy_channel \
--task noisy_channel_translation \
--lenpen ${lenpen} \
--lm-wt ${lm_wt} \
--ch-wt ${bw_wt} \
--gen-subset test \
--remove-bpe \
--fp16 \
--batch-size 1
```
## Fast Noisy Channel Modeling
[Bhosale et al. (2020)](http://www.statmt.org/wmt20/pdf/2020.wmt-1.68.pdf) introduces 3 approximations that speed up online noisy channel decoding -
- Smaller channel models (`Tranformer Base` with 1 encoder and decoder layer each vs. `Transformer Big`)
- This involves training a channel model that is possibly smaller and less accurate in terms of BLEU than a channel model of the same size as the direct model.
- Since the role of the channel model is mainly to assign low scores to generations from the language model if they don't translate back to the source, we may not need the most accurate channel model for this purpose.
- Smaller output vocabulary size for the channel model (~30,000 -> ~1000)
- The channel model doesn't need to score the full output vocabulary, it just needs to score the source tokens, which are completely known.
- This is specified using the arguments `--channel-scoring-type src_vocab --top-k-vocab 500`
- This means that the output vocabulary for the channel model will be the source tokens for all examples in the batch and the top-K most frequent tokens in the vocabulary
- This reduces the memory consumption needed to store channel model scores significantly
- Smaller number of candidates (`k2`) scored per beam
- This is specified by reducing the argument `--k2`
### Fast Noisy Channel Generation for German-English translation with fairseq
Here are instructions for **fast** noisy channel generation with a direct model, channel model and language model as explained in section [Fast Noisy Channel Modeling](#fast-noisy-channel-modeling). The main differences are that we use a smaller channel model, reduce `--k2`, set `--channel-scoring-type src_vocab --top-k-vocab 500` and increase the `--batch-size`.
Note:
- Download and install fairseq as per instructions [here](https://github.com/pytorch/fairseq)
- Preprocess and binarize the dataset as per instructions in section [Test Data Preprocessing](#test-data-preprocessing)
```sh
binarized_data=data_dir/binarized
direct_model=de_en_seed4.pt
lm_model=en_lm.pt
lm_data=lm_data
small_ch_model=en_de.base_1_1.seed4.pt
wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/direct_models/seed4.pt -O ${direct_model}
wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/transformer_lm.pt -O ${lm_model}
mkdir -p ${lm_data}
wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/lm_dict/dict.txt -O ${lm_data}/dict.txt
wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/base_1_1.seed4.pt -O ${small_ch_model}
k2=3
lenpen=0.23
lm_wt=0.58
bw_wt=0.26
fairseq-generate ${binarized_data} \
--user-dir examples/fast_noisy_channel \
--beam 5 \
--path ${direct_model} \
--lm-model ${lm_model} \
--lm-data ${lm_data} \
--channel-model ${small_ch_model} \
--k2 ${k2} \
--combine-method noisy_channel \
--task noisy_channel_translation \
--lenpen ${lenpen} \
--lm-wt ${lm_wt} \
--ch-wt ${bw_wt} \
--gen-subset test \
--remove-bpe \
--fp16 \
--batch-size 50 \
--channel-scoring-type src_vocab --top-k-vocab 500
```
## Test Data Preprocessing
For preprocessing and binarizing the test sets for Romanian-English and German-English translation, we use the following script -
```sh
FAIRSEQ=/path/to/fairseq
cd $FAIRSEQ
SCRIPTS=$FAIRSEQ/mosesdecoder/scripts
if [ ! -d "${SCRIPTS}" ]; then
echo 'Cloning Moses github repository (for tokenization scripts)...'
git clone https://github.com/moses-smt/mosesdecoder.git
fi
TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
NORMALIZE=$SCRIPTS/tokenizer/normalize-punctuation.perl
s=de
t=en
test=wmt18
mkdir -p data_dir
# Tokenization
if [ $s == "ro" ] ; then
# Note: Get normalise-romanian.py and remove-diacritics.py from
# https://github.com/rsennrich/wmt16-scripts/tree/master/preprocess
sacrebleu -t $test -l $s-$t --echo src | \
$NORMALIZE -l $s | \
python normalise-romanian.py | \
python remove-diacritics.py | \
$TOKENIZER -l $s -a -q > data_dir/$test.$s-$t.$s
else
sacrebleu -t $test -l $s-$t --echo src | perl $NORMALIZE -l $s | perl $TOKENIZER -threads 8 -a -l $s > data_dir/$test.$s-$t.$s
fi
sacrebleu -t $test -l $s-$t --echo ref | perl $NORMALIZE -l $t | perl $TOKENIZER -threads 8 -a -l $t > data_dir/$test.$s-$t.$t
# Applying BPE
src_bpe_code=/path/to/source/language/bpe/code
tgt_bpe_code=/path/to/target/language/bpe/code
src_dict=/path/to/source/language/dict
tgt_dict=/path/to/target/language/dict
FASTBPE=$FAIRSEQ/fastBPE
if [ ! -d "${FASTBPE}" ] ; then
git clone https://github.com/glample/fastBPE.git
# Follow compilation instructions at https://github.com/glample/fastBPE
g++ -std=c++11 -pthread -O3 fastBPE/main.cc -IfastBPE -o fast
fi
${FASTBPE}/fast applybpe data_dir/bpe.$test.$s-$t.$s data_dir/$test.$s-$t.$s ${src_bpe_code}
${FASTBPE}/fast applybpe data_dir/bpe.$test.$s-$t.$s data_dir/$test.$s-$t.$s ${tgt_bpe_code}
fairseq-preprocess -s $s -t $t \
--testpref data_dir/bpe.$test.$s-$t \
--destdir data_dir/binarized \
--srcdict ${src_dict} \
--tgtdict ${tgt_dict}
```
## Calculating BLEU
```sh
DETOKENIZER=$SCRIPTS/tokenizer/detokenizer.perl
cat ${generation_output} | grep -P "^H" | sort -V | cut -f 3- | $DETOKENIZER -l $t -q -a | sacrebleu -t $test -l $s-$t
```
## Romanian-English Translation
The direct and channel models are trained using bitext data (WMT16) combined with backtranslated data (The monolingual data used for backtranslation comes from http://data.statmt.org/rsennrich/wmt16_backtranslations/ (Sennrich et al., 2016c))
The backtranslated data is generated using an ensemble of 3 English-Romanian models trained on bitext training data (WMT16) with unrestricted sampling.
### BPE Codes and Dictionary
We learn a joint BPE vocabulary of 18K types on the bitext training data which is used for both the source and target.
||Path|
|----------|------|
| BPE Code | [joint_bpe_18k](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/bpe_18k) |
| Dictionary | [dict](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/dict) |
### Direct Models
For Ro-En with backtranslation, the direct and channel models use a Transformer-Big architecture.
| Seed | Model |
|----|----|
| 2 | [ro_en_seed2.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/direct_models/seed2.pt)
| 4 | [ro_en_seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/direct_models/seed4.pt)
| 6 | [ro_en_seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/direct_models/seed6.pt)
### Channel Models
For channel models, we follow the same steps as for the direct models. But backtranslated data is generated in the opposite direction using [this Romanian monolingual data](http://data.statmt.org/rsennrich/wmt16_backtranslations/).
The best lenpen, LM weight and CH weight are obtained by sweeping over the validation set (wmt16/dev) using beam 5.
| Model Size | Lenpen | LM Weight | CH Weight | Seed 2 | Seed 4 | Seed 6 |
|----|----|----|----|----|----|----|
| `big` | 0.84 | 0.64 | 0.56 | [big.seed2.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/channel_models/big.seed2.pt) | [big.seed2.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/channel_models/big.seed2.pt) | [big.seed2.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/channel_models/big.seed2.pt) |
| `base_1_1` | 0.63 | 0.40 | 0.37 | [base_1_1.seed2.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/channel_models/base_1_1.seed2.pt) | [base_1_1.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/channel_models/base_1_1.seed4.pt) | [base_1_1.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/channel_models/base_1_1.seed6.pt) |
### Language Model
The model is trained on de-duplicated English Newscrawl data from 2007-2018 comprising 186 million sentences or 4.5B words after normalization and tokenization.
| | Path |
|----|----|
| `--lm-model` | [transformer_en_lm](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/lm_model/transformer_lm.pt) |
| `--lm-data` | [lm_data](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/lm_model/lm_dict)
## German-English Translation
### BPE Codes and Dictionaries
| | Path|
|----------|------|
| Source BPE Code | [de_bpe_code_24K](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/de_bpe_code_24K) |
| Target BPE Code | [en_bpe_code_24K](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/en_bpe_code_24K)
| Source Dictionary | [de_dict](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/de_dict) |
| Target Dictionary | [en_dict](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/en_dict) |
### Direct Models
We train on WMT’19 training data. Following [Ng et al., 2019](http://statmt.org/wmt19/pdf/53/WMT33.pdf), we apply language identification filtering and remove sentences longer than 250 tokens as well as sentence pairs with a source/target length ratio exceeding 1.5. This results in 26.8M sentence pairs.
We use the Transformer-Big architecture for the direct model.
| Seed | Model |
|:----:|----|
| 4 | [de_en_seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/direct_models/seed4.pt)
| 5 | [de_en_seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/direct_models/seed5.pt)
| 6 | [de_en_seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/direct_models/seed6.pt)
### Channel Models
We train on WMT’19 training data. Following [Ng et al., 2019](http://statmt.org/wmt19/pdf/53/WMT33.pdf), we apply language identification filtering and remove sentences longer than 250 tokens as well as sentence pairs with a source/target length ratio exceeding 1.5. This results in 26.8M sentence pairs.
| Model Size | Seed 4 | Seed 5 | Seed 6 |
|----|----|----|----|
| `big` | [big.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/big.seed4.pt) | [big.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/big.seed5.pt) | [big.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/big.seed6.pt) |
| `big_1_1` | [big_1_1.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/big_1_1.seed4.pt) | [big_1_1.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/big_1_1.seed5.pt) | [big_1_1.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/big_1_1.seed6.pt) |
| `base` | [base.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/base.seed4.pt) | [base.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/base.seed5.pt) | [base.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/base.seed6.pt) |
| `base_1_1` | [base_1_1.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/base_1_1.seed4.pt) | [base_1_1.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/base_1_1.seed5.pt) | [base_1_1.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/base_1_1.seed6.pt) |
| `half` | [half.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/half.seed4.pt) | [half.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/half.seed5.pt) | [half.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/half.seed6.pt) |
| `half_1_1` | [half_1_1.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/half_1_1.seed4.pt) | [half_1_1.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/half_1_1.seed5.pt) | [half_1_1.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/half_1_1.seed6.pt) |
| `quarter` | [quarter.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/quarter.seed4.pt) | [quarter.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/quarter.seed5.pt) | [quarter.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/quarter.seed6.pt) |
| `quarter_1_1` | [quarter_1_1.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/quarter_1_1.seed4.pt) | [quarter_1_1.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/quarter_1_1.seed5.pt) | [quarter_1_1.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/quarter_1_1.seed6.pt) |
| `8th` | [8th.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/8th.seed4.pt) | [8th.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/8th.seed5.pt) | [8th.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/8th.seed6.pt) |
| `8th_1_1` | [8th_1_1.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/8th_1_1.seed4.pt) | [8th_1_1.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/8th_1_1.seed5.pt) | [8th_1_1.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/8th_1_1.seed6.pt) |
| `16th` | [16th.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/16th.seed4.pt) | [16th.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/16th.seed5.pt) | [16th.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/16th.seed6.pt) |
| `16th_1_1` | [16th_1_1.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/16th_1_1.seed4.pt) | [16th_1_1.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/16th_1_1.seed5.pt) | [16th_1_1.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/16th_1_1.seed6.pt) |
### Language Model
The model is trained on de-duplicated English Newscrawl data from 2007-2018 comprising 186 million sentences or 4.5B words after normalization and tokenization.
| | Path |
|----|----|
| `--lm-model` | [transformer_en_lm](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/transformer_lm.pt) |
| `--lm-data` | [lm_data](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/lm_dict/)
## Citation
```bibtex
@inproceedings{bhosale2020language,
title={Language Models not just for Pre-training: Fast Online Neural Noisy Channel Modeling},
author={Shruti Bhosale and Kyra Yee and Sergey Edunov and Michael Auli},
booktitle={Proceedings of the Fifth Conference on Machine Translation (WMT)},
year={2020},
}
@inproceedings{yee2019simple,
title={Simple and Effective Noisy Channel Modeling for Neural Machine Translation},
author={Yee, Kyra and Dauphin, Yann and Auli, Michael},
booktitle={Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP)},
pages={5700--5705},
year={2019}
}
```
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from . import noisy_channel_translation # noqa
from . import noisy_channel_sequence_generator # noqa
from . import noisy_channel_beam_search # noqa
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from fairseq.search import Search
class NoisyChannelBeamSearch(Search):
def __init__(self, tgt_dict):
super().__init__(tgt_dict)
self.fw_scores_buf = None
self.lm_scores_buf = None
def _init_buffers(self, t):
# super()._init_buffers(t)
if self.fw_scores_buf is None:
self.scores_buf = t.new()
self.indices_buf = torch.LongTensor().to(device=t.device)
self.beams_buf = torch.LongTensor().to(device=t.device)
self.fw_scores_buf = t.new()
self.lm_scores_buf = t.new()
def combine_fw_bw(self, combine_method, fw_cum, bw, step):
if combine_method == "noisy_channel":
fw_norm = fw_cum.div(step + 1)
lprobs = bw + fw_norm
elif combine_method == "lm_only":
lprobs = bw + fw_cum
return lprobs
def step(self, step, fw_lprobs, scores, bw_lprobs, lm_lprobs, combine_method):
self._init_buffers(fw_lprobs)
bsz, beam_size, vocab_size = fw_lprobs.size()
if step == 0:
# at the first step all hypotheses are equally likely, so use
# only the first beam
fw_lprobs = fw_lprobs[:, ::beam_size, :].contiguous()
bw_lprobs = bw_lprobs[:, ::beam_size, :].contiguous()
# nothing to add since we are at the first step
fw_lprobs_cum = fw_lprobs
else:
# make probs contain cumulative scores for each hypothesis
raw_scores = (scores[:, :, step - 1].unsqueeze(-1))
fw_lprobs_cum = (fw_lprobs.add(raw_scores))
combined_lprobs = self.combine_fw_bw(combine_method, fw_lprobs_cum, bw_lprobs, step)
# choose the top k according to the combined noisy channel model score
torch.topk(
combined_lprobs.view(bsz, -1),
k=min(
# Take the best 2 x beam_size predictions. We'll choose the first
# beam_size of these which don't predict eos to continue with.
beam_size * 2,
combined_lprobs.view(bsz, -1).size(1) - 1, # -1 so we never select pad
),
out=(self.scores_buf, self.indices_buf),
)
# save corresponding fw and lm scores
self.fw_scores_buf = torch.gather(fw_lprobs_cum.view(bsz, -1), 1, self.indices_buf)
self.lm_scores_buf = torch.gather(lm_lprobs.view(bsz, -1), 1, self.indices_buf)
# Project back into relative indices and beams
self.beams_buf = self.indices_buf // vocab_size
self.indices_buf.fmod_(vocab_size)
return self.scores_buf, self.fw_scores_buf, self.lm_scores_buf, self.indices_buf, self.beams_buf
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, List, Optional
import math
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from .noisy_channel_beam_search import NoisyChannelBeamSearch
from fairseq.sequence_generator import EnsembleModel
class NoisyChannelSequenceGenerator(object):
def __init__(
self,
combine_method,
tgt_dict,
src_dict=None,
beam_size=1,
max_len_a=0,
max_len_b=200,
min_len=1,
len_penalty=1.0,
unk_penalty=0.0,
retain_dropout=False,
temperature=1.0,
match_source_len=False,
no_repeat_ngram_size=0,
normalize_scores=True,
channel_models=None,
k2=10,
ch_weight=1.0,
channel_scoring_type='log_norm',
top_k_vocab=0,
lm_models=None,
lm_dict=None,
lm_weight=1.0,
normalize_lm_scores_by_tgt_len=False,
):
"""Generates translations of a given source sentence,
using beam search with noisy channel decoding.
Args:
combine_method (string, optional): Method to combine direct, LM and
channel model scores (default: None)
tgt_dict (~fairseq.data.Dictionary): target dictionary
src_dict (~fairseq.data.Dictionary): source dictionary
beam_size (int, optional): beam width (default: 1)
max_len_a/b (int, optional): generate sequences of maximum length
ax + b, where x is the source length
min_len (int, optional): the minimum length of the generated output
(not including end-of-sentence)
len_penalty (float, optional): length penalty, where <1.0 favors
shorter, >1.0 favors longer sentences (default: 1.0)
unk_penalty (float, optional): unknown word penalty, where <0
produces more unks, >0 produces fewer (default: 0.0)
retain_dropout (bool, optional): use dropout when generating
(default: False)
temperature (float, optional): temperature, where values
>1.0 produce more uniform samples and values <1.0 produce
sharper samples (default: 1.0)
match_source_len (bool, optional): outputs should match the source
length (default: False)
no_repeat_ngram_size (int, optional): Size of n-grams that we avoid
repeating in the generation (default: 0)
normalize_scores (bool, optional): normalize scores by the length
of the output (default: True)
channel_models (List[~fairseq.models.FairseqModel]): ensemble of models
translating from the target to the source
k2 (int, optional): Top K2 candidates to score per beam at each step (default:10)
ch_weight (int, optional): Weight associated with the channel model score
assuming that the direct model score has weight 1.0 (default: 1.0)
channel_scoring_type (str, optional): String specifying how to score
the channel model (default: 'log_norm')
top_k_vocab (int, optional): If `channel_scoring_type` is `'src_vocab'` or
`'src_vocab_batched'`, then this parameter specifies the number of
most frequent tokens to include in the channel model output vocabulary,
in addition to the source tokens in the input batch (default: 0)
lm_models (List[~fairseq.models.FairseqModel]): ensemble of models
generating text in the target language
lm_dict (~fairseq.data.Dictionary): LM Model dictionary
lm_weight (int, optional): Weight associated with the LM model score
assuming that the direct model score has weight 1.0 (default: 1.0)
normalize_lm_scores_by_tgt_len (bool, optional): Should we normalize LM scores
by the target length? By default, we normalize the combination of
LM and channel model scores by the source length
"""
self.pad = tgt_dict.pad()
self.unk = tgt_dict.unk()
self.eos = tgt_dict.eos()
self.vocab_size = len(tgt_dict)
self.beam_size = beam_size
# the max beam size is the dictionary size - 1, since we never select pad
self.beam_size = min(beam_size, self.vocab_size - 1)
self.max_len_a = max_len_a
self.max_len_b = max_len_b
self.min_len = min_len
self.normalize_scores = normalize_scores
self.len_penalty = len_penalty
self.unk_penalty = unk_penalty
self.retain_dropout = retain_dropout
self.temperature = temperature
self.match_source_len = match_source_len
self.no_repeat_ngram_size = no_repeat_ngram_size
self.channel_models = channel_models
self.src_dict = src_dict
self.tgt_dict = tgt_dict
self.combine_method = combine_method
self.k2 = k2
self.ch_weight = ch_weight
self.channel_scoring_type = channel_scoring_type
self.top_k_vocab = top_k_vocab
self.lm_models = lm_models
self.lm_dict = lm_dict
self.lm_weight = lm_weight
self.log_softmax_fn = torch.nn.LogSoftmax(dim=1)
self.normalize_lm_scores_by_tgt_len = normalize_lm_scores_by_tgt_len
self.share_tgt_dict = (self.lm_dict == self.tgt_dict)
self.tgt_to_lm = make_dict2dict(tgt_dict, lm_dict)
self.ch_scoring_bsz = 3072
assert temperature > 0, '--temperature must be greater than 0'
self.search = NoisyChannelBeamSearch(tgt_dict)
@torch.no_grad()
def generate(
self,
models,
sample,
prefix_tokens=None,
bos_token=None,
**kwargs
):
"""Generate a batch of translations.
Args:
models (List[~fairseq.models.FairseqModel]): ensemble of models
sample (dict): batch
prefix_tokens (torch.LongTensor, optional): force decoder to begin
with these tokens
"""
model = EnsembleModel(models)
incremental_states = torch.jit.annotate(
List[Dict[str, Dict[str, Optional[Tensor]]]],
[
torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {})
for i in range(model.models_size)
],
)
if not self.retain_dropout:
model.eval()
# model.forward normally channels prev_output_tokens into the decoder
# separately, but SequenceGenerator directly calls model.encoder
encoder_input = {
k: v for k, v in sample['net_input'].items()
if k != 'prev_output_tokens'
}
src_tokens = encoder_input['src_tokens']
src_lengths_no_eos = (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1)
input_size = src_tokens.size()
# batch dimension goes first followed by source lengths
bsz = input_size[0]
src_len = input_size[1]
beam_size = self.beam_size
if self.match_source_len:
max_len = src_lengths_no_eos.max().item()
else:
max_len = min(
int(self.max_len_a * src_len + self.max_len_b),
# exclude the EOS marker
model.max_decoder_positions() - 1,
)
# compute the encoder output for each beam
encoder_outs = model.forward_encoder(encoder_input)
new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
new_order = new_order.to(src_tokens.device).long()
encoder_outs = model.reorder_encoder_out(encoder_outs, new_order)
src_lengths = encoder_input['src_lengths']
# initialize buffers
scores = src_tokens.new(bsz * beam_size, max_len + 1).float().fill_(0)
lm_prefix_scores = src_tokens.new(bsz * beam_size).float().fill_(0)
scores_buf = scores.clone()
tokens = src_tokens.new(bsz * beam_size, max_len + 2).long().fill_(self.pad)
tokens_buf = tokens.clone()
tokens[:, 0] = self.eos if bos_token is None else bos_token
# reorder source tokens so they may be used as a reference in generating P(S|T)
src_tokens = reorder_all_tokens(src_tokens, src_lengths, self.src_dict.eos_index)
src_tokens = src_tokens.repeat(1, beam_size).view(-1, src_len)
src_lengths = src_lengths.view(bsz, -1).repeat(1, beam_size).view(bsz*beam_size, -1)
attn, attn_buf = None, None
nonpad_idxs = None
# The cands_to_ignore indicates candidates that should be ignored.
# For example, suppose we're sampling and have already finalized 2/5
# samples. Then the cands_to_ignore would mark 2 positions as being ignored,
# so that we only finalize the remaining 3 samples.
cands_to_ignore = src_tokens.new_zeros(bsz, beam_size).eq(-1) # forward and backward-compatible False mask
# list of completed sentences
finalized = [[] for i in range(bsz)]
finished = [False for i in range(bsz)]
num_remaining_sent = bsz
# number of candidate hypos per step
cand_size = 2 * beam_size # 2 x beam size in case half are EOS
# offset arrays for converting between different indexing schemes
bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1).type_as(tokens)
cand_offsets = torch.arange(0, cand_size).type_as(tokens)
# helper function for allocating buffers on the fly
buffers = {}
def buffer(name, type_of=tokens): # noqa
if name not in buffers:
buffers[name] = type_of.new()
return buffers[name]
def is_finished(sent, step, unfin_idx):
"""
Check whether we've finished generation for a given sentence, by
comparing the worst score among finalized hypotheses to the best
possible score among unfinalized hypotheses.
"""
assert len(finalized[sent]) <= beam_size
if len(finalized[sent]) == beam_size:
return True
return False
def finalize_hypos(step, bbsz_idx, eos_scores, combined_noisy_channel_eos_scores):
"""
Finalize the given hypotheses at this step, while keeping the total
number of finalized hypotheses per sentence <= beam_size.
Note: the input must be in the desired finalization order, so that
hypotheses that appear earlier in the input are preferred to those
that appear later.
Args:
step: current time step
bbsz_idx: A vector of indices in the range [0, bsz*beam_size),
indicating which hypotheses to finalize
eos_scores: A vector of the same size as bbsz_idx containing
fw scores for each hypothesis
combined_noisy_channel_eos_scores: A vector of the same size as bbsz_idx containing
combined noisy channel scores for each hypothesis
"""
assert bbsz_idx.numel() == eos_scores.numel()
# clone relevant token and attention tensors
tokens_clone = tokens.index_select(0, bbsz_idx)
tokens_clone = tokens_clone[:, 1:step + 2] # skip the first index, which is EOS
assert not tokens_clone.eq(self.eos).any()
tokens_clone[:, step] = self.eos
attn_clone = attn.index_select(0, bbsz_idx)[:, :, 1:step+2] if attn is not None else None
# compute scores per token position
pos_scores = scores.index_select(0, bbsz_idx)[:, :step+1]
pos_scores[:, step] = eos_scores
# convert from cumulative to per-position scores
pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]
# normalize sentence-level scores
if self.normalize_scores:
combined_noisy_channel_eos_scores /= (step + 1) ** self.len_penalty
cum_unfin = []
prev = 0
for f in finished:
if f:
prev += 1
else:
cum_unfin.append(prev)
sents_seen = set()
for i, (idx, score) in enumerate(zip(bbsz_idx.tolist(), combined_noisy_channel_eos_scores.tolist())):
unfin_idx = idx // beam_size
sent = unfin_idx + cum_unfin[unfin_idx]
sents_seen.add((sent, unfin_idx))
if self.match_source_len and step > src_lengths_no_eos[unfin_idx]:
score = -math.inf
def get_hypo():
if attn_clone is not None:
# remove padding tokens from attn scores
hypo_attn = attn_clone[i][nonpad_idxs[sent]]
_, alignment = hypo_attn.max(dim=0)
else:
hypo_attn = None
alignment = None
return {
'tokens': tokens_clone[i],
'score': score,
'attention': hypo_attn, # src_len x tgt_len
'alignment': alignment,
'positional_scores': pos_scores[i],
}
if len(finalized[sent]) < beam_size:
finalized[sent].append(get_hypo())
newly_finished = []
for sent, unfin_idx in sents_seen:
# check termination conditions for this sentence
if not finished[sent] and is_finished(sent, step, unfin_idx):
finished[sent] = True
newly_finished.append(unfin_idx)
return newly_finished
def noisy_channel_rescoring(lprobs, beam_size, bsz, src_tokens, tokens, k):
"""Rescore the top k hypothesis from each beam using noisy channel modeling
Returns:
new_fw_lprobs: the direct model probabilities after pruning the top k
new_ch_lm_lprobs: the combined channel and language model probabilities
new_lm_lprobs: the language model probabilities after pruning the top k
"""
with torch.no_grad():
lprobs_size = lprobs.size()
if prefix_tokens is not None and step < prefix_tokens.size(1):
probs_slice = lprobs.view(bsz, -1, lprobs.size(-1))[:, 0, :]
cand_scores = torch.gather(
probs_slice, dim=1,
index=prefix_tokens[:, step].view(-1, 1).data
).expand(-1, beam_size).contiguous().view(bsz*beam_size, 1)
cand_indices = prefix_tokens[:, step].view(-1, 1).expand(bsz, beam_size).data.contiguous().view(bsz*beam_size, 1)
# need to calculate and save fw and lm probs for prefix tokens
fw_top_k = cand_scores
fw_top_k_idx = cand_indices
k = 1
else:
# take the top k best words for every sentence in batch*beam
fw_top_k, fw_top_k_idx = torch.topk(lprobs.view(beam_size*bsz, -1), k=k)
eos_idx = torch.nonzero(fw_top_k_idx.view(bsz*beam_size*k, -1) == self.eos)[:, 0]
ch_scores = fw_top_k.new_full((beam_size*bsz*k, ), 0)
src_size = torch.sum(src_tokens[:, :] != self.src_dict.pad_index, dim=1, keepdim=True, dtype=fw_top_k.dtype)
if self.combine_method != "lm_only":
temp_src_tokens_full = src_tokens[:, :].repeat(1, k).view(bsz*beam_size*k, -1)
not_padding = temp_src_tokens_full[:, 1:] != self.src_dict.pad_index
cur_tgt_size = step+2
# add eos to all candidate sentences except those that already end in eos
eos_tokens = tokens[:, 0].repeat(1, k).view(-1, 1)
eos_tokens[eos_idx] = self.tgt_dict.pad_index
if step == 0:
channel_input = torch.cat((fw_top_k_idx.view(-1, 1), eos_tokens), 1)
else:
# move eos from beginning to end of target sentence
channel_input = torch.cat((tokens[:, 1:step + 1].repeat(1, k).view(-1, step), fw_top_k_idx.view(-1, 1), eos_tokens), 1)
ch_input_lengths = torch.tensor(np.full(channel_input.size(0), cur_tgt_size))
ch_input_lengths[eos_idx] = cur_tgt_size-1
if self.channel_scoring_type == "unnormalized":
ch_encoder_output = channel_model.encoder(channel_input, src_lengths=ch_input_lengths)
ch_decoder_output, _ = channel_model.decoder(temp_src_tokens_full, encoder_out=ch_encoder_output, features_only=True)
del ch_encoder_output
ch_intermed_scores = channel_model.decoder.unnormalized_scores_given_target(ch_decoder_output, target_ids=temp_src_tokens_full[:, 1:])
ch_intermed_scores = ch_intermed_scores.float()
ch_intermed_scores *= not_padding.float()
ch_scores = torch.sum(ch_intermed_scores, dim=1)
elif self.channel_scoring_type == "k2_separate":
for k_idx in range(k):
k_eos_tokens = eos_tokens[k_idx::k, :]
if step == 0:
k_ch_input = torch.cat((fw_top_k_idx[:, k_idx:k_idx+1], k_eos_tokens), 1)
else:
# move eos from beginning to end of target sentence
k_ch_input = torch.cat((tokens[:, 1:step + 1], fw_top_k_idx[:, k_idx:k_idx+1], k_eos_tokens), 1)
k_ch_input_lengths = ch_input_lengths[k_idx::k]
k_ch_output = channel_model(k_ch_input, k_ch_input_lengths, src_tokens)
k_ch_lprobs = channel_model.get_normalized_probs(k_ch_output, log_probs=True)
k_ch_intermed_scores = torch.gather(k_ch_lprobs[:, :-1, :], 2, src_tokens[:, 1:].unsqueeze(2)).squeeze(2)
k_ch_intermed_scores *= not_padding.float()
ch_scores[k_idx::k] = torch.sum(k_ch_intermed_scores, dim=1)
elif self.channel_scoring_type == "src_vocab":
ch_encoder_output = channel_model.encoder(channel_input, src_lengths=ch_input_lengths)
ch_decoder_output, _ = channel_model.decoder(temp_src_tokens_full, encoder_out=ch_encoder_output, features_only=True)
del ch_encoder_output
ch_lprobs = normalized_scores_with_batch_vocab(
channel_model.decoder,
ch_decoder_output, src_tokens, k, bsz, beam_size,
self.src_dict.pad_index, top_k=self.top_k_vocab)
ch_scores = torch.sum(ch_lprobs, dim=1)
elif self.channel_scoring_type == "src_vocab_batched":
ch_bsz_size = temp_src_tokens_full.shape[0]
ch_lprobs_list = [None] * len(range(0, ch_bsz_size, self.ch_scoring_bsz))
for i, start_idx in enumerate(range(0, ch_bsz_size, self.ch_scoring_bsz)):
end_idx = min(start_idx + self.ch_scoring_bsz, ch_bsz_size)
temp_src_tokens_full_batch = temp_src_tokens_full[start_idx:end_idx, :]
channel_input_batch = channel_input[start_idx:end_idx, :]
ch_input_lengths_batch = ch_input_lengths[start_idx:end_idx]
ch_encoder_output_batch = channel_model.encoder(channel_input_batch, src_lengths=ch_input_lengths_batch)
ch_decoder_output_batch, _ = channel_model.decoder(temp_src_tokens_full_batch, encoder_out=ch_encoder_output_batch, features_only=True)
ch_lprobs_list[i] = normalized_scores_with_batch_vocab(
channel_model.decoder,
ch_decoder_output_batch, src_tokens, k, bsz, beam_size,
self.src_dict.pad_index, top_k=self.top_k_vocab,
start_idx=start_idx, end_idx=end_idx)
ch_lprobs = torch.cat(ch_lprobs_list, dim=0)
ch_scores = torch.sum(ch_lprobs, dim=1)
else:
ch_output = channel_model(channel_input, ch_input_lengths, temp_src_tokens_full)
ch_lprobs = channel_model.get_normalized_probs(ch_output, log_probs=True)
ch_intermed_scores = torch.gather(ch_lprobs[:, :-1, :], 2, temp_src_tokens_full[:, 1:].unsqueeze(2)).squeeze().view(bsz*beam_size*k, -1)
ch_intermed_scores *= not_padding.float()
ch_scores = torch.sum(ch_intermed_scores, dim=1)
else:
cur_tgt_size = 0
ch_scores = ch_scores.view(bsz*beam_size, k)
expanded_lm_prefix_scores = lm_prefix_scores.unsqueeze(1).expand(-1, k).flatten()
if self.share_tgt_dict:
lm_scores = get_lm_scores(lm, tokens[:, :step + 1].view(-1, step+1), lm_incremental_states, fw_top_k_idx.view(-1, 1), torch.tensor(np.full(tokens.size(0), step+1)), k)
else:
new_lm_input = dict2dict(tokens[:, :step + 1].view(-1, step+1), self.tgt_to_lm)
new_cands = dict2dict(fw_top_k_idx.view(-1, 1), self.tgt_to_lm)
lm_scores = get_lm_scores(lm, new_lm_input, lm_incremental_states, new_cands, torch.tensor(np.full(tokens.size(0), step+1)), k)
lm_scores.add_(expanded_lm_prefix_scores)
ch_lm_scores = combine_ch_lm(self.combine_method, ch_scores, lm_scores, src_size, cur_tgt_size)
# initialize all as min value
new_fw_lprobs = ch_scores.new(lprobs_size).fill_(-1e17).view(bsz*beam_size, -1)
new_ch_lm_lprobs = ch_scores.new(lprobs_size).fill_(-1e17).view(bsz*beam_size, -1)
new_lm_lprobs = ch_scores.new(lprobs_size).fill_(-1e17).view(bsz*beam_size, -1)
new_fw_lprobs[:, self.pad] = -math.inf
new_ch_lm_lprobs[:, self.pad] = -math.inf
new_lm_lprobs[:, self.pad] = -math.inf
new_fw_lprobs.scatter_(1, fw_top_k_idx, fw_top_k)
new_ch_lm_lprobs.scatter_(1, fw_top_k_idx, ch_lm_scores)
new_lm_lprobs.scatter_(1, fw_top_k_idx, lm_scores.view(-1, k))
return new_fw_lprobs, new_ch_lm_lprobs, new_lm_lprobs
def combine_ch_lm(combine_type, ch_scores, lm_scores1, src_size, tgt_size):
if self.channel_scoring_type == "unnormalized":
ch_scores = self.log_softmax_fn(
ch_scores.view(-1, self.beam_size * self.k2)
).view(ch_scores.shape)
ch_scores = ch_scores * self.ch_weight
lm_scores1 = lm_scores1 * self.lm_weight
if combine_type == "lm_only":
# log P(T|S) + log P(T)
ch_scores = lm_scores1.view(ch_scores.size())
elif combine_type == "noisy_channel":
# 1/t log P(T|S) + 1/s log P(S|T) + 1/t log P(T)
if self.normalize_lm_scores_by_tgt_len:
ch_scores.div_(src_size)
lm_scores_norm = lm_scores1.view(ch_scores.size()).div(tgt_size)
ch_scores.add_(lm_scores_norm)
# 1/t log P(T|S) + 1/s log P(S|T) + 1/s log P(T)
else:
ch_scores.add_(lm_scores1.view(ch_scores.size()))
ch_scores.div_(src_size)
return ch_scores
if self.channel_models is not None:
channel_model = self.channel_models[0] # assume only one channel_model model
else:
channel_model = None
lm = EnsembleModel(self.lm_models)
lm_incremental_states = torch.jit.annotate(
List[Dict[str, Dict[str, Optional[Tensor]]]],
[
torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {})
for i in range(lm.models_size)
],
)
reorder_state = None
batch_idxs = None
for step in range(max_len + 1): # one extra step for EOS marker
# reorder decoder internal states based on the prev choice of beams
if reorder_state is not None:
if batch_idxs is not None:
# update beam indices to take into account removed sentences
corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(batch_idxs)
reorder_state.view(-1, beam_size).add_(corr.unsqueeze(-1) * beam_size)
model.reorder_incremental_state(incremental_states, reorder_state)
encoder_outs = model.reorder_encoder_out(encoder_outs, reorder_state)
lm.reorder_incremental_state(lm_incremental_states, reorder_state)
fw_lprobs, avg_attn_scores = model.forward_decoder(
tokens[:, :step + 1], encoder_outs, incremental_states, temperature=self.temperature,
)
fw_lprobs[:, self.pad] = -math.inf # never select pad
fw_lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty
fw_lprobs, ch_lm_lprobs, lm_lprobs = noisy_channel_rescoring(fw_lprobs, beam_size, bsz, src_tokens, tokens, self.k2)
# handle min and max length constraints
if step >= max_len:
fw_lprobs[:, :self.eos] = -math.inf
fw_lprobs[:, self.eos + 1:] = -math.inf
elif step < self.min_len:
fw_lprobs[:, self.eos] = -math.inf
# handle prefix tokens (possibly with different lengths)
if prefix_tokens is not None and step < prefix_tokens.size(1):
prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat(1, beam_size).view(-1)
prefix_mask = prefix_toks.ne(self.pad)
prefix_fw_lprobs = fw_lprobs.gather(-1, prefix_toks.unsqueeze(-1))
fw_lprobs[prefix_mask] = -math.inf
fw_lprobs[prefix_mask] = fw_lprobs[prefix_mask].scatter_(
-1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_fw_lprobs
)
prefix_ch_lm_lprobs = ch_lm_lprobs.gather(-1, prefix_toks.unsqueeze(-1))
ch_lm_lprobs[prefix_mask] = -math.inf
ch_lm_lprobs[prefix_mask] = ch_lm_lprobs[prefix_mask].scatter_(
-1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_ch_lm_lprobs
)
prefix_lm_lprobs = lm_lprobs.gather(-1, prefix_toks.unsqueeze(-1))
lm_lprobs[prefix_mask] = -math.inf
lm_lprobs[prefix_mask] = lm_lprobs[prefix_mask].scatter_(
-1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lm_lprobs
)
# if prefix includes eos, then we should make sure tokens and
# scores are the same across all beams
eos_mask = prefix_toks.eq(self.eos)
if eos_mask.any():
# validate that the first beam matches the prefix
first_beam = tokens[eos_mask].view(-1, beam_size, tokens.size(-1))[:, 0, 1:step + 1]
eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0]
target_prefix = prefix_tokens[eos_mask_batch_dim][:, :step]
assert (first_beam == target_prefix).all()
def replicate_first_beam(tensor, mask):
tensor = tensor.view(-1, beam_size, tensor.size(-1))
tensor[mask] = tensor[mask][:, :1, :]
return tensor.view(-1, tensor.size(-1))
# copy tokens, scores and lprobs from the first beam to all beams
tokens = replicate_first_beam(tokens, eos_mask_batch_dim)
scores = replicate_first_beam(scores, eos_mask_batch_dim)
fw_lprobs = replicate_first_beam(fw_lprobs, eos_mask_batch_dim)
ch_lm_lprobs = replicate_first_beam(ch_lm_lprobs, eos_mask_batch_dim)
lm_lprobs = replicate_first_beam(lm_lprobs, eos_mask_batch_dim)
if self.no_repeat_ngram_size > 0:
# for each beam and batch sentence, generate a list of previous ngrams
gen_ngrams = [{} for bbsz_idx in range(bsz * beam_size)]
for bbsz_idx in range(bsz * beam_size):
gen_tokens = tokens[bbsz_idx].tolist()
for ngram in zip(*[gen_tokens[i:] for i in range(self.no_repeat_ngram_size)]):
gen_ngrams[bbsz_idx][tuple(ngram[:-1])] = \
gen_ngrams[bbsz_idx].get(tuple(ngram[:-1]), []) + [ngram[-1]]
# Record attention scores
if avg_attn_scores is not None:
if attn is None:
attn = scores.new(bsz * beam_size, src_tokens.size(1), max_len + 2)
attn_buf = attn.clone()
nonpad_idxs = src_tokens.ne(self.pad)
attn[:, :, step + 1].copy_(avg_attn_scores)
scores = scores.type_as(fw_lprobs)
scores_buf = scores_buf.type_as(fw_lprobs)
self.search.set_src_lengths(src_lengths_no_eos)
if self.no_repeat_ngram_size > 0:
def calculate_banned_tokens(bbsz_idx):
# before decoding the next token, prevent decoding of ngrams that have already appeared
ngram_index = tuple(tokens[bbsz_idx, step + 2 - self.no_repeat_ngram_size:step + 1].tolist())
return gen_ngrams[bbsz_idx].get(ngram_index, [])
if step + 2 - self.no_repeat_ngram_size >= 0:
# no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
banned_tokens = [calculate_banned_tokens(bbsz_idx) for bbsz_idx in range(bsz * beam_size)]
else:
banned_tokens = [[] for bbsz_idx in range(bsz * beam_size)]
for bbsz_idx in range(bsz * beam_size):
fw_lprobs[bbsz_idx, banned_tokens[bbsz_idx]] = -math.inf
combined_noisy_channel_scores, fw_lprobs_top_k, lm_lprobs_top_k, cand_indices, cand_beams = self.search.step(
step,
fw_lprobs.view(bsz, -1, self.vocab_size),
scores.view(bsz, beam_size, -1)[:, :, :step], ch_lm_lprobs.view(bsz, -1, self.vocab_size),
lm_lprobs.view(bsz, -1, self.vocab_size), self.combine_method
)
# cand_bbsz_idx contains beam indices for the top candidate
# hypotheses, with a range of values: [0, bsz*beam_size),
# and dimensions: [bsz, cand_size]
cand_bbsz_idx = cand_beams.add(bbsz_offsets)
# finalize hypotheses that end in eos (except for candidates to be ignored)
eos_mask = cand_indices.eq(self.eos)
eos_mask[:, :beam_size] &= ~cands_to_ignore
# only consider eos when it's among the top beam_size indices
eos_bbsz_idx = torch.masked_select(
cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size]
)
finalized_sents = set()
if eos_bbsz_idx.numel() > 0:
eos_scores = torch.masked_select(
fw_lprobs_top_k[:, :beam_size], mask=eos_mask[:, :beam_size]
)
combined_noisy_channel_eos_scores = torch.masked_select(
combined_noisy_channel_scores[:, :beam_size],
mask=eos_mask[:, :beam_size],
)
# finalize hypo using channel model score
finalized_sents = finalize_hypos(
step, eos_bbsz_idx, eos_scores, combined_noisy_channel_eos_scores)
num_remaining_sent -= len(finalized_sents)
assert num_remaining_sent >= 0
if num_remaining_sent == 0:
break
if len(finalized_sents) > 0:
new_bsz = bsz - len(finalized_sents)
# construct batch_idxs which holds indices of batches to keep for the next pass
batch_mask = cand_indices.new_ones(bsz)
batch_mask[cand_indices.new(finalized_sents)] = 0
batch_idxs = torch.nonzero(batch_mask).squeeze(-1)
eos_mask = eos_mask[batch_idxs]
cand_beams = cand_beams[batch_idxs]
bbsz_offsets.resize_(new_bsz, 1)
cand_bbsz_idx = cand_beams.add(bbsz_offsets)
lm_lprobs_top_k = lm_lprobs_top_k[batch_idxs]
fw_lprobs_top_k = fw_lprobs_top_k[batch_idxs]
cand_indices = cand_indices[batch_idxs]
if prefix_tokens is not None:
prefix_tokens = prefix_tokens[batch_idxs]
src_lengths_no_eos = src_lengths_no_eos[batch_idxs]
cands_to_ignore = cands_to_ignore[batch_idxs]
scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
scores_buf.resize_as_(scores)
tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
tokens_buf.resize_as_(tokens)
src_tokens = src_tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
src_lengths = src_lengths.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
lm_prefix_scores = lm_prefix_scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1).squeeze()
if attn is not None:
attn = attn.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, attn.size(1), -1)
attn_buf.resize_as_(attn)
bsz = new_bsz
else:
batch_idxs = None
# Set active_mask so that values > cand_size indicate eos or
# ignored hypos and values < cand_size indicate candidate
# active hypos. After this, the min values per row are the top
# candidate active hypos.
eos_mask[:, :beam_size] |= cands_to_ignore
active_mask = torch.add(
eos_mask.type_as(cand_offsets) * cand_size,
cand_offsets[: eos_mask.size(1)],
)
# get the top beam_size active hypotheses, which are just the hypos
# with the smallest values in active_mask
active_hypos, new_cands_to_ignore = buffer('active_hypos'), buffer('new_cands_to_ignore')
torch.topk(
active_mask, k=beam_size, dim=1, largest=False,
out=(new_cands_to_ignore, active_hypos)
)
# update cands_to_ignore to ignore any finalized hypos
cands_to_ignore = new_cands_to_ignore.ge(cand_size)[:, :beam_size]
assert (~cands_to_ignore).any(dim=1).all()
active_bbsz_idx = buffer('active_bbsz_idx')
torch.gather(
cand_bbsz_idx, dim=1, index=active_hypos,
out=active_bbsz_idx,
)
active_scores = torch.gather(
fw_lprobs_top_k, dim=1, index=active_hypos,
out=scores[:, step].view(bsz, beam_size),
)
active_bbsz_idx = active_bbsz_idx.view(-1)
active_scores = active_scores.view(-1)
# copy tokens and scores for active hypotheses
torch.index_select(
tokens[:, :step + 1], dim=0, index=active_bbsz_idx,
out=tokens_buf[:, :step + 1],
)
torch.gather(
cand_indices, dim=1, index=active_hypos,
out=tokens_buf.view(bsz, beam_size, -1)[:, :, step + 1],
)
if step > 0:
torch.index_select(
scores[:, :step], dim=0, index=active_bbsz_idx,
out=scores_buf[:, :step],
)
torch.gather(
fw_lprobs_top_k, dim=1, index=active_hypos,
out=scores_buf.view(bsz, beam_size, -1)[:, :, step],
)
torch.gather(
lm_lprobs_top_k, dim=1, index=active_hypos,
out=lm_prefix_scores.view(bsz, beam_size)
)
# copy attention for active hypotheses
if attn is not None:
torch.index_select(
attn[:, :, :step + 2], dim=0, index=active_bbsz_idx,
out=attn_buf[:, :, :step + 2],
)
# swap buffers
tokens, tokens_buf = tokens_buf, tokens
scores, scores_buf = scores_buf, scores
if attn is not None:
attn, attn_buf = attn_buf, attn
# reorder incremental state in decoder
reorder_state = active_bbsz_idx
# sort by score descending
for sent in range(len(finalized)):
finalized[sent] = sorted(finalized[sent], key=lambda r: r['score'], reverse=True)
return finalized
def get_lm_scores(model, input_tokens, incremental_states, cand_tokens, input_len, k):
with torch.no_grad():
lm_lprobs, avg_attn_scores = model.forward_decoder(
input_tokens, encoder_outs=None, incremental_states=incremental_states,
)
lm_lprobs_size = lm_lprobs.size(0)
probs_next_wrd = torch.gather(lm_lprobs.repeat(1, k).view(lm_lprobs_size*k, -1), 1, cand_tokens).squeeze().view(-1)
return probs_next_wrd
def make_dict2dict(old_dict, new_dict):
dict2dict_map = {}
for sym in old_dict.symbols:
dict2dict_map[old_dict.index(sym)] = new_dict.index(sym)
return dict2dict_map
def dict2dict(tokens, dict2dict_map):
if tokens.device == torch.device('cpu'):
tokens_tmp = tokens
else:
tokens_tmp = tokens.cpu()
return tokens_tmp.map_(
tokens_tmp,
lambda _, val, dict2dict_map=dict2dict_map : dict2dict_map[float(val)]
).to(tokens.device)
def reorder_tokens(tokens, lengths, eos):
# reorder source tokens so they may be used as reference for P(S|T)
return torch.cat((tokens.new([eos]), tokens[-lengths:-1], tokens[:-lengths]), 0)
def reorder_all_tokens(tokens, lengths, eos):
# used to reorder src tokens from [<pad> <w1> <w2> .. <eos>] to [<eos> <w1> <w2>...<pad>]
# so source tokens can be used to predict P(S|T)
return torch.stack([reorder_tokens(token, length, eos) for token, length in zip(tokens, lengths)])
def normalized_scores_with_batch_vocab(
model_decoder, features, target_ids, k, bsz, beam_size,
pad_idx, top_k=0, vocab_size_meter=None, start_idx=None,
end_idx=None, **kwargs):
"""
Get normalized probabilities (or log probs) from a net's output
w.r.t. vocab consisting of target IDs in the batch
"""
if model_decoder.adaptive_softmax is None:
weight = model_decoder.output_projection.weight
vocab_ids = torch.unique(
torch.cat(
(torch.unique(target_ids), torch.arange(top_k, device=target_ids.device))
)
)
id_map = dict(zip(vocab_ids.tolist(), range(len(vocab_ids))))
mapped_target_ids = target_ids.cpu().apply_(
lambda x, id_map=id_map: id_map[x]
).to(target_ids.device)
expanded_target_ids = mapped_target_ids[:, :].repeat(1, k).view(bsz*beam_size*k, -1)
if start_idx is not None and end_idx is not None:
expanded_target_ids = expanded_target_ids[start_idx:end_idx, :]
logits = F.linear(features, weight[vocab_ids, :])
log_softmax = F.log_softmax(logits, dim=-1, dtype=torch.float32)
intermed_scores = torch.gather(
log_softmax[:, :-1, :],
2,
expanded_target_ids[:, 1:].unsqueeze(2),
).squeeze()
not_padding = expanded_target_ids[:, 1:] != pad_idx
intermed_scores *= not_padding.float()
return intermed_scores
else:
raise ValueError("adaptive softmax doesn't work with " +
"`normalized_scores_with_batch_vocab()`")
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq.tasks.translation import TranslationTask
from fairseq.tasks.language_modeling import LanguageModelingTask
from fairseq import checkpoint_utils
import argparse
from fairseq.tasks import register_task
import torch
@register_task("noisy_channel_translation")
class NoisyChannelTranslation(TranslationTask):
"""
Rescore the top k candidates from each beam using noisy channel modeling
"""
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
TranslationTask.add_args(parser)
# fmt: off
parser.add_argument('--channel-model', metavar='FILE',
help='path to P(S|T) model. P(S|T) and P(T|S) must share source and target dictionaries.')
parser.add_argument('--combine-method', default='lm_only',
choices=['lm_only', 'noisy_channel'],
help="""method for combining direct and channel model scores.
lm_only: decode with P(T|S)P(T)
noisy_channel: decode with 1/t P(T|S) + 1/s(P(S|T)P(T))""")
parser.add_argument('--normalize-lm-scores-by-tgt-len', action='store_true', default=False,
help='normalize lm score by target length instead of source length')
parser.add_argument('--channel-scoring-type', default='log_norm', choices=['unnormalized', 'log_norm', 'k2_separate', 'src_vocab', 'src_vocab_batched'],
help="Normalize bw scores with log softmax or return bw scores without log softmax")
parser.add_argument('--top-k-vocab', default=0, type=int,
help='top k vocab IDs to use with `src_vocab` in channel model scoring')
parser.add_argument('--k2', default=50, type=int,
help='the top k2 candidates to rescore with the noisy channel model for each beam')
parser.add_argument('--ch-wt', default=1, type=float,
help='weight for the channel model')
parser.add_argument('--lm-model', metavar='FILE',
help='path to lm model file, to model P(T). P(T) must share the same vocab as the direct model on the target side')
parser.add_argument('--lm-data', metavar='FILE',
help='path to lm model training data for target language, used to properly load LM with correct dictionary')
parser.add_argument('--lm-wt', default=1, type=float,
help='the weight of the lm in joint decoding')
# fmt: on
def build_generator(
self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None
):
if getattr(args, "score_reference", False):
raise NotImplementedError()
else:
from .noisy_channel_sequence_generator import NoisyChannelSequenceGenerator
use_cuda = torch.cuda.is_available() and not self.args.cpu
assert self.args.lm_model is not None, '--lm-model required for noisy channel generation!'
assert self.args.lm_data is not None, '--lm-data required for noisy channel generation to map between LM and bitext vocabs'
if self.args.channel_model is not None:
import copy
ch_args_task = copy.deepcopy(self.args)
tmp = ch_args_task.source_lang
ch_args_task.source_lang = ch_args_task.target_lang
ch_args_task.target_lang = tmp
ch_args_task._name = 'translation'
channel_task = TranslationTask.setup_task(ch_args_task)
arg_dict = {}
arg_dict['task'] = 'language_modeling'
arg_dict['sample_break_mode'] = 'eos'
arg_dict['data'] = self.args.lm_data
arg_dict['output_dictionary_size'] = -1
lm_args = argparse.Namespace(**arg_dict)
lm_task = LanguageModelingTask.setup_task(lm_args)
lm_dict = lm_task.output_dictionary
if self.args.channel_model is not None:
channel_models, _ = checkpoint_utils.load_model_ensemble(self.args.channel_model.split(':'), task=channel_task)
for model in channel_models:
model.make_generation_fast_(
beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
need_attn=args.print_alignment,
)
if self.args.fp16:
model.half()
if use_cuda:
model.cuda()
else:
channel_models = None
lm_models, _ = checkpoint_utils.load_model_ensemble(self.args.lm_model.split(':'), task=lm_task)
for model in lm_models:
model.make_generation_fast_(
beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
need_attn=args.print_alignment,
)
if self.args.fp16:
model.half()
if use_cuda:
model.cuda()
return NoisyChannelSequenceGenerator(
combine_method=self.args.combine_method,
tgt_dict=self.target_dictionary,
src_dict=self.source_dictionary,
beam_size=getattr(args, 'beam', 5),
max_len_a=getattr(args, 'max_len_a', 0),
max_len_b=getattr(args, 'max_len_b', 200),
min_len=getattr(args, 'min_len', 1),
len_penalty=getattr(args, 'lenpen', 1),
unk_penalty=getattr(args, 'unkpen', 0),
temperature=getattr(args, 'temperature', 1.),
match_source_len=getattr(args, 'match_source_len', False),
no_repeat_ngram_size=getattr(args, 'no_repeat_ngram_size', 0),
normalize_scores=(not getattr(args, 'unnormalized', False)),
channel_models=channel_models,
k2=getattr(self.args, 'k2', 50),
ch_weight=getattr(self.args, 'ch_wt', 1),
channel_scoring_type=self.args.channel_scoring_type,
top_k_vocab=self.args.top_k_vocab,
lm_models=lm_models,
lm_dict=lm_dict,
lm_weight=getattr(self.args, 'lm_wt', 1),
normalize_lm_scores_by_tgt_len=getattr(self.args, 'normalize_lm_scores_by_tgt_len', False),
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment