Commit 799a38c5 authored by chenzk's avatar chenzk
Browse files

v1.0

parents
Pipeline #616 failed with stages
in 0 seconds
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Translate pre-processed data with a trained model.
"""
import numpy as np
import torch
from fairseq import checkpoint_utils, options, progress_bar, tasks, utils
from fairseq.sequence_generator import EnsembleModel
from fairseq.utils import safe_hasattr
def get_avg_pool(
models, sample, prefix_tokens, src_dict, remove_bpe, has_langtok=False
):
model = EnsembleModel(models)
# model.forward normally channels prev_output_tokens into the decoder
# separately, but SequenceGenerator directly calls model.encoder
encoder_input = {
k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
}
# compute the encoder output for each beam
encoder_outs = model.forward_encoder(encoder_input)
np_encoder_outs = encoder_outs[0].encoder_out.cpu().numpy().astype(np.float32)
encoder_mask = 1 - encoder_outs[0].encoder_padding_mask.cpu().numpy().astype(
np.float32
)
encoder_mask = np.expand_dims(encoder_mask.T, axis=2)
if has_langtok:
encoder_mask = encoder_mask[1:, :, :]
np_encoder_outs = np_encoder_outs[1, :, :]
masked_encoder_outs = encoder_mask * np_encoder_outs
avg_pool = (masked_encoder_outs / encoder_mask.sum(axis=0)).sum(axis=0)
return avg_pool
def main(args):
assert args.path is not None, "--path required for generation!"
assert (
not args.sampling or args.nbest == args.beam
), "--sampling requires --nbest to be equal to --beam"
assert (
args.replace_unk is None or args.raw_text
), "--replace-unk requires a raw text dataset (--raw-text)"
args.beam = 1
utils.import_user_module(args)
if args.max_tokens is None:
args.max_tokens = 12000
print(args)
use_cuda = torch.cuda.is_available() and not args.cpu
# Load dataset splits
task = tasks.setup_task(args)
task.load_dataset(args.gen_subset)
# Set dictionaries
try:
src_dict = getattr(task, "source_dictionary", None)
except NotImplementedError:
src_dict = None
tgt_dict = task.target_dictionary
# Load ensemble
print("| loading model(s) from {}".format(args.path))
models, _model_args = checkpoint_utils.load_model_ensemble(
args.path.split(":"),
arg_overrides=eval(args.model_overrides),
task=task,
)
# Optimize ensemble for generation
for model in models:
model.make_generation_fast_(
beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
need_attn=args.print_alignment,
)
if args.fp16:
model.half()
if use_cuda:
model.cuda()
# Load alignment dictionary for unknown word replacement
# (None if no unknown word replacement, empty if no path to align dictionary)
align_dict = utils.load_align_dict(args.replace_unk)
# Load dataset (possibly sharded)
itr = task.get_batch_iterator(
dataset=task.dataset(args.gen_subset),
max_tokens=args.max_tokens,
max_positions=utils.resolve_max_positions(
task.max_positions(),
),
ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
required_batch_size_multiple=args.required_batch_size_multiple,
num_shards=args.num_shards,
shard_id=args.shard_id,
num_workers=args.num_workers,
).next_epoch_itr(shuffle=False)
num_sentences = 0
source_sentences = []
shard_id = 0
all_avg_pool = None
encoder_has_langtok = (
safe_hasattr(task.args, "encoder_langtok")
and task.args.encoder_langtok is not None
and safe_hasattr(task.args, "lang_tok_replacing_bos_eos")
and not task.args.lang_tok_replacing_bos_eos
)
with progress_bar.build_progress_bar(args, itr) as t:
for sample in t:
if sample is None:
print("Skipping None")
continue
sample = utils.move_to_cuda(sample) if use_cuda else sample
if "net_input" not in sample:
continue
prefix_tokens = None
if args.prefix_size > 0:
prefix_tokens = sample["target"][:, : args.prefix_size]
with torch.no_grad():
avg_pool = get_avg_pool(
models,
sample,
prefix_tokens,
src_dict,
args.post_process,
has_langtok=encoder_has_langtok,
)
if all_avg_pool is not None:
all_avg_pool = np.concatenate((all_avg_pool, avg_pool))
else:
all_avg_pool = avg_pool
if not isinstance(sample["id"], list):
sample_ids = sample["id"].tolist()
else:
sample_ids = sample["id"]
for i, sample_id in enumerate(sample_ids):
# Remove padding
src_tokens = utils.strip_pad(
sample["net_input"]["src_tokens"][i, :], tgt_dict.pad()
)
# Either retrieve the original sentences or regenerate them from tokens.
if align_dict is not None:
src_str = task.dataset(args.gen_subset).src.get_original_text(
sample_id
)
else:
if src_dict is not None:
src_str = src_dict.string(src_tokens, args.post_process)
else:
src_str = ""
if not args.quiet:
if src_dict is not None:
print("S-{}\t{}".format(sample_id, src_str))
source_sentences.append(f"{sample_id}\t{src_str}")
num_sentences += sample["nsentences"]
if all_avg_pool.shape[0] >= 1000000:
with open(
f"{args.encoder_save_dir}/all_avg_pool.{args.source_lang}.{shard_id}",
"w",
) as avg_pool_file:
all_avg_pool.tofile(avg_pool_file)
with open(
f"{args.encoder_save_dir}/sentences.{args.source_lang}.{shard_id}",
"w",
) as sentence_file:
sentence_file.writelines(f"{line}\n" for line in source_sentences)
all_avg_pool = None
source_sentences = []
shard_id += 1
if all_avg_pool is not None:
with open(
f"{args.encoder_save_dir}/all_avg_pool.{args.source_lang}.{shard_id}", "w"
) as avg_pool_file:
all_avg_pool.tofile(avg_pool_file)
with open(
f"{args.encoder_save_dir}/sentences.{args.source_lang}.{shard_id}", "w"
) as sentence_file:
sentence_file.writelines(f"{line}\n" for line in source_sentences)
return None
def cli_main():
parser = options.get_generation_parser()
parser.add_argument(
"--encoder-save-dir",
default="",
type=str,
metavar="N",
help="directory to save encoder outputs",
)
args = options.parse_args_and_arch(parser)
main(args)
if __name__ == "__main__":
cli_main()
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import glob
import numpy as np
DIM = 1024
def compute_dist(source_embs, target_embs, k=5, return_sim_mat=False):
target_ids = [tid for tid in target_embs]
source_mat = np.stack(source_embs.values(), axis=0)
normalized_source_mat = source_mat / np.linalg.norm(
source_mat, axis=1, keepdims=True
)
target_mat = np.stack(target_embs.values(), axis=0)
normalized_target_mat = target_mat / np.linalg.norm(
target_mat, axis=1, keepdims=True
)
sim_mat = normalized_source_mat.dot(normalized_target_mat.T)
if return_sim_mat:
return sim_mat
neighbors_map = {}
for i, sentence_id in enumerate(source_embs):
idx = np.argsort(sim_mat[i, :])[::-1][:k]
neighbors_map[sentence_id] = [target_ids[tid] for tid in idx]
return neighbors_map
def load_embeddings(directory, LANGS):
sentence_embeddings = {}
sentence_texts = {}
for lang in LANGS:
sentence_embeddings[lang] = {}
sentence_texts[lang] = {}
lang_dir = f"{directory}/{lang}"
embedding_files = glob.glob(f"{lang_dir}/all_avg_pool.{lang}.*")
for embed_file in embedding_files:
shard_id = embed_file.split(".")[-1]
embeddings = np.fromfile(embed_file, dtype=np.float32)
num_rows = embeddings.shape[0] // DIM
embeddings = embeddings.reshape((num_rows, DIM))
with open(f"{lang_dir}/sentences.{lang}.{shard_id}") as sentence_file:
for idx, line in enumerate(sentence_file):
sentence_id, sentence = line.strip().split("\t")
sentence_texts[lang][sentence_id] = sentence
sentence_embeddings[lang][sentence_id] = embeddings[idx, :]
return sentence_embeddings, sentence_texts
def compute_accuracy(directory, LANGS):
sentence_embeddings, sentence_texts = load_embeddings(directory, LANGS)
top_1_accuracy = {}
top1_str = " ".join(LANGS) + "\n"
for source_lang in LANGS:
top_1_accuracy[source_lang] = {}
top1_str += f"{source_lang} "
for target_lang in LANGS:
top1 = 0
top5 = 0
neighbors_map = compute_dist(
sentence_embeddings[source_lang], sentence_embeddings[target_lang]
)
for sentence_id, neighbors in neighbors_map.items():
if sentence_id == neighbors[0]:
top1 += 1
if sentence_id in neighbors[:5]:
top5 += 1
n = len(sentence_embeddings[target_lang])
top1_str += f"{top1/n} "
top1_str += "\n"
print(top1_str)
print(top1_str, file=open(f"{directory}/accuracy", "w"))
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Analyze encoder outputs")
parser.add_argument("directory", help="Source language corpus")
parser.add_argument("--langs", help="List of langs")
args = parser.parse_args()
langs = args.langs.split(",")
compute_accuracy(args.directory, langs)
#!/bin/bash
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
source_lang=kk_KZ
target_lang=en_XX
MODEL=criss_checkpoints/criss.3rd.pt
SPM=criss_checkpoints/sentence.bpe.model
SPLIT=test
LANG_DICT=criss_checkpoints/lang_dict.txt
ENCODER_ANALYSIS=sentence_retrieval/encoder_analysis.py
SAVE_ENCODER=save_encoder.py
ENCODER_SAVE_ROOT=sentence_embeddings/$MODEL
DATA_DIR=data_tmp
INPUT_DIR=$DATA_DIR/${source_lang}-${target_lang}-tatoeba
ENCODER_SAVE_DIR=${ENCODER_SAVE_ROOT}/${source_lang}-${target_lang}
mkdir -p $ENCODER_SAVE_DIR/${target_lang}
mkdir -p $ENCODER_SAVE_DIR/${source_lang}
# Save encoder outputs for source sentences
python $SAVE_ENCODER \
${INPUT_DIR} \
--path ${MODEL} \
--task translation_multi_simple_epoch \
--lang-dict ${LANG_DICT} \
--gen-subset ${SPLIT} \
--bpe 'sentencepiece' \
--lang-pairs ${source_lang}-${target_lang} \
-s ${source_lang} -t ${target_lang} \
--sentencepiece-model ${SPM} \
--remove-bpe 'sentencepiece' \
--beam 1 \
--lang-tok-style mbart \
--encoder-save-dir ${ENCODER_SAVE_DIR}/${source_lang}
# Save encoder outputs for target sentences
python $SAVE_ENCODER \
${INPUT_DIR} \
--path ${MODEL} \
--lang-dict ${LANG_DICT} \
--task translation_multi_simple_epoch \
--gen-subset ${SPLIT} \
--bpe 'sentencepiece' \
--lang-pairs ${target_lang}-${source_lang} \
-t ${source_lang} -s ${target_lang} \
--sentencepiece-model ${SPM} \
--remove-bpe 'sentencepiece' \
--beam 1 \
--lang-tok-style mbart \
--encoder-save-dir ${ENCODER_SAVE_DIR}/${target_lang}
# Analyze sentence retrieval accuracy
python $ENCODER_ANALYSIS --langs "${source_lang},${target_lang}" ${ENCODER_SAVE_DIR}
#!/bin/bash
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
SRC=si_LK
TGT=en_XX
MODEL=criss_checkpoints/criss.3rd.pt
MULTIBLEU=mosesdecoder/scripts/generic/multi-bleu.perl
MOSES=mosesdecoder
REPLACE_UNICODE_PUNCT=$MOSES/scripts/tokenizer/replace-unicode-punctuation.perl
NORM_PUNC=$MOSES/scripts/tokenizer/normalize-punctuation.perl
REM_NON_PRINT_CHAR=$MOSES/scripts/tokenizer/remove-non-printing-char.perl
TOKENIZER=$MOSES/scripts/tokenizer/tokenizer.perl
GEN_TMP_DIR=gen_tmp
LANG_DICT=criss_checkpoints/lang_dict.txt
if [ ! -d "mosesdecoder" ]; then
git clone https://github.com/moses-smt/mosesdecoder
fi
mkdir -p $GEN_TMP_DIR
fairseq-generate data_tmp/${SRC}-${TGT}-flores \
--task translation_multi_simple_epoch \
--max-tokens 2000 \
--path ${MODEL} \
--skip-invalid-size-inputs-valid-test \
--beam 5 --lenpen 1.0 --gen-subset test \
--remove-bpe=sentencepiece \
--source-lang ${SRC} --target-lang ${TGT} \
--decoder-langtok --lang-pairs 'en_XX-ar_AR,en_XX-de_DE,en_XX-es_XX,en_XX-fr_XX,en_XX-hi_IN,en_XX-it_IT,en_XX-ja_XX,en_XX-ko_KR,en_XX-nl_XX,en_XX-ru_RU,en_XX-zh_CN,en_XX-tr_TR,en_XX-vi_VN,en_XX-ro_RO,en_XX-my_MM,en_XX-ne_NP,en_XX-si_LK,en_XX-cs_CZ,en_XX-lt_LT,en_XX-kk_KZ,en_XX-gu_IN,en_XX-fi_FI,en_XX-et_EE,en_XX-lv_LV,ar_AR-en_XX,cs_CZ-en_XX,de_DE-en_XX,es_XX-en_XX,et_EE-en_XX,fi_FI-en_XX,fr_XX-en_XX,gu_IN-en_XX,hi_IN-en_XX,it_IT-en_XX,ja_XX-en_XX,kk_KZ-en_XX,ko_KR-en_XX,lt_LT-en_XX,lv_LV-en_XX,my_MM-en_XX,ne_NP-en_XX,nl_XX-en_XX,ro_RO-en_XX,ru_RU-en_XX,si_LK-en_XX,tr_TR-en_XX,vi_VN-en_XX,zh_CN-en_XX,ar_AR-es_XX,es_XX-ar_AR,ar_AR-hi_IN,hi_IN-ar_AR,ar_AR-zh_CN,zh_CN-ar_AR,cs_CZ-es_XX,es_XX-cs_CZ,cs_CZ-hi_IN,hi_IN-cs_CZ,cs_CZ-zh_CN,zh_CN-cs_CZ,de_DE-es_XX,es_XX-de_DE,de_DE-hi_IN,hi_IN-de_DE,de_DE-zh_CN,zh_CN-de_DE,es_XX-hi_IN,hi_IN-es_XX,es_XX-zh_CN,zh_CN-es_XX,et_EE-es_XX,es_XX-et_EE,et_EE-hi_IN,hi_IN-et_EE,et_EE-zh_CN,zh_CN-et_EE,fi_FI-es_XX,es_XX-fi_FI,fi_FI-hi_IN,hi_IN-fi_FI,fi_FI-zh_CN,zh_CN-fi_FI,fr_XX-es_XX,es_XX-fr_XX,fr_XX-hi_IN,hi_IN-fr_XX,fr_XX-zh_CN,zh_CN-fr_XX,gu_IN-es_XX,es_XX-gu_IN,gu_IN-hi_IN,hi_IN-gu_IN,gu_IN-zh_CN,zh_CN-gu_IN,hi_IN-zh_CN,zh_CN-hi_IN,it_IT-es_XX,es_XX-it_IT,it_IT-hi_IN,hi_IN-it_IT,it_IT-zh_CN,zh_CN-it_IT,ja_XX-es_XX,es_XX-ja_XX,ja_XX-hi_IN,hi_IN-ja_XX,ja_XX-zh_CN,zh_CN-ja_XX,kk_KZ-es_XX,es_XX-kk_KZ,kk_KZ-hi_IN,hi_IN-kk_KZ,kk_KZ-zh_CN,zh_CN-kk_KZ,ko_KR-es_XX,es_XX-ko_KR,ko_KR-hi_IN,hi_IN-ko_KR,ko_KR-zh_CN,zh_CN-ko_KR,lt_LT-es_XX,es_XX-lt_LT,lt_LT-hi_IN,hi_IN-lt_LT,lt_LT-zh_CN,zh_CN-lt_LT,lv_LV-es_XX,es_XX-lv_LV,lv_LV-hi_IN,hi_IN-lv_LV,lv_LV-zh_CN,zh_CN-lv_LV,my_MM-es_XX,es_XX-my_MM,my_MM-hi_IN,hi_IN-my_MM,my_MM-zh_CN,zh_CN-my_MM,ne_NP-es_XX,es_XX-ne_NP,ne_NP-hi_IN,hi_IN-ne_NP,ne_NP-zh_CN,zh_CN-ne_NP,nl_XX-es_XX,es_XX-nl_XX,nl_XX-hi_IN,hi_IN-nl_XX,nl_XX-zh_CN,zh_CN-nl_XX,ro_RO-es_XX,es_XX-ro_RO,ro_RO-hi_IN,hi_IN-ro_RO,ro_RO-zh_CN,zh_CN-ro_RO,ru_RU-es_XX,es_XX-ru_RU,ru_RU-hi_IN,hi_IN-ru_RU,ru_RU-zh_CN,zh_CN-ru_RU,si_LK-es_XX,es_XX-si_LK,si_LK-hi_IN,hi_IN-si_LK,si_LK-zh_CN,zh_CN-si_LK,tr_TR-es_XX,es_XX-tr_TR,tr_TR-hi_IN,hi_IN-tr_TR,tr_TR-zh_CN,zh_CN-tr_TR,vi_VN-es_XX,es_XX-vi_VN,vi_VN-hi_IN,hi_IN-vi_VN,vi_VN-zh_CN,zh_CN-vi_VN' \
--lang-dict ${LANG_DICT} --lang-tok-style 'mbart' --sampling-method 'temperature' --sampling-temperature '1.0' > $GEN_TMP_DIR/${SRC}_${TGT}.gen
cat $GEN_TMP_DIR/${SRC}_${TGT}.gen | grep -P "^T-" | cut -f2 | $REPLACE_UNICODE_PUNCT | $NORM_PUNC -l ${TGT:0:2} | $REM_NON_PRINT_CHAR | $TOKENIZER -no-escape ${TGT:0:2} > $GEN_TMP_DIR/${SRC}_${TGT}.hyp
cat $GEN_TMP_DIR/${SRC}_${TGT}.gen | grep -P "^H-" | cut -f3 | $REPLACE_UNICODE_PUNCT | $NORM_PUNC -l ${TGT:0:2} | $REM_NON_PRINT_CHAR | $TOKENIZER -no-escape ${TGT:0:2} > $GEN_TMP_DIR/${SRC}_${TGT}.ref
${MULTIBLEU} $GEN_TMP_DIR/${SRC}_${TGT}.ref < $GEN_TMP_DIR/${SRC}_${TGT}.hyp
# Cross-Lingual Language Model Pre-training
Below are some details for training Cross-Lingual Language Models (XLM) - similar to the ones presented in [Lample & Conneau, 2019](https://arxiv.org/pdf/1901.07291.pdf) - in Fairseq. The current implementation only supports the Masked Language Model (MLM) from the paper above.
## Downloading and Tokenizing Monolingual Data
Pointers to the monolingual data from wikipedia, used for training the XLM-style MLM model as well as details on processing (tokenization and BPE) it can be found in the [XLM Github Repository](https://github.com/facebookresearch/XLM#download--preprocess-monolingual-data).
Let's assume the following for the code snippets in later sections to work
- Processed data is in the folder: monolingual_data/processed
- Each language has 3 files for train, test and validation. For example we have the following files for English:
train.en, valid.en
- We are training a model for 5 languages: Arabic (ar), German (de), English (en), Hindi (hi) and French (fr)
- The vocabulary file is monolingual_data/processed/vocab_mlm
## Fairseq Pre-processing and Binarization
Pre-process and binarize the data with the MaskedLMDictionary and cross_lingual_lm task
```bash
# Ensure the output directory exists
DATA_DIR=monolingual_data/fairseq_processed
mkdir -p "$DATA_DIR"
for lg in ar de en hi fr
do
fairseq-preprocess \
--task cross_lingual_lm \
--srcdict monolingual_data/processed/vocab_mlm \
--only-source \
--trainpref monolingual_data/processed/train \
--validpref monolingual_data/processed/valid \
--testpref monolingual_data/processed/test \
--destdir monolingual_data/fairseq_processed \
--workers 20 \
--source-lang $lg
# Since we only have a source language, the output file has a None for the
# target language. Remove this
for stage in train test valid
sudo mv "$DATA_DIR/$stage.$lg-None.$lg.bin" "$stage.$lg.bin"
sudo mv "$DATA_DIR/$stage.$lg-None.$lg.idx" "$stage.$lg.idx"
done
done
```
## Train a Cross-lingual Language Model similar to the XLM MLM model
Use the following command to train the model on 5 languages.
```
fairseq-train \
--task cross_lingual_lm monolingual_data/fairseq_processed \
--save-dir checkpoints/mlm \
--max-update 2400000 --save-interval 1 --no-epoch-checkpoints \
--arch xlm_base \
--optimizer adam --lr-scheduler reduce_lr_on_plateau \
--lr-shrink 0.5 --lr 0.0001 --stop-min-lr 1e-09 \
--dropout 0.1 \
--criterion legacy_masked_lm_loss \
--max-tokens 2048 --tokens-per-sample 256 --attention-dropout 0.1 \
--dataset-impl lazy --seed 0 \
--masked-lm-only \
--monolingual-langs 'ar,de,en,hi,fr' --num-segment 5 \
--ddp-backend=legacy_ddp
```
Some Notes:
- Using tokens_per_sample greater than 256 can cause OOM (out-of-memory) issues. Usually since MLM packs in streams of text, this parameter doesn't need much tuning.
- The Evaluation workflow for computing MLM Perplexity on test data is in progress.
- Finetuning this model on a downstream task is something which is not currently available.
# Discriminative Reranking for Neural Machine Translation
https://aclanthology.org/2021.acl-long.563/
This folder contains source code for training DrNMT, a discriminatively trained reranker for neural machine translation.
## Data preparation
1. Follow the instructions under `examples/translation` to build a base MT model. Prepare three files, one with source sentences, one with ground truth target sentences, and one with hypotheses generated from the base MT model. Each line in the file contains one sentence in raw text (i.e. no sentencepiece, etc.). Below is an example of the files with _N_ hypotheses for each source sentence.
```
# Example of the source sentence file: (The file should contain L lines.)
source_sentence_1
source_sentence_2
source_sentence_3
...
source_sentence_L
# Example of the target sentence file: (The file should contain L lines.)
target_sentence_1
target_sentence_2
target_sentence_3
...
target_sentence_L
# Example of the hypotheses file: (The file should contain L*N lines.)
source_sentence_1_hypo_1
source_sentence_1_hypo_2
...
source_sentence_1_hypo_N
source_sentence_2_hypo_1
...
source_sentence_2_hypo_N
...
source_sentence_L_hypo_1
...
source_sentence_L_hypo_N
```
2. Download the [XLMR model](https://github.com/fairinternal/fairseq-py/tree/main/examples/xlmr#pre-trained-models).
```
wget https://dl.fbaipublicfiles.com/fairseq/models/xlmr.base.tar.gz
tar zxvf xlmr.base.tar.gz
# The folder should contain dict.txt, model.pt and sentencepiece.bpe.model.
```
3. Prepare scores and BPE data.
* `N`: Number of hypotheses per each source sentence. We use 50 in the paper.
* `SPLIT`: Name of the data split, i.e. train, valid, test. Use split_name, split_name1, split_name2, ..., if there are multiple datasets for a split, e.g. train, train1, valid, valid1.
* `NUM_SHARDS`: Number of shards. Set this to 1 for non-train splits.
* `METRIC`: The metric for DrNMT to optimize for. We support either `bleu` or `ter`.
```
# For each data split, e.g. train, valid, test, etc., run the following:
SOURCE_FILE=/path/to/source_sentence_file
TARGET_FILE=/path/to/target_sentence_file
HYPO_FILE=/path/to/hypo_file
XLMR_DIR=/path/to/xlmr
OUTPUT_DIR=/path/to/output
python scripts/prep_data.py \
--input-source ${SOURCE_FILE} \
--input-target ${TARGET_FILE} \
--input-hypo ${HYPO_FILE} \
--output-dir ${OUTPUT_DIR} \
--split $SPLIT
--beam $N \
--sentencepiece-model ${XLMR_DIR}/sentencepiece.bpe.model \
--metric $METRIC \
--num-shards ${NUM_SHARDS}
# The script will create ${OUTPUT_DIR}/$METRIC with ${NUM_SHARDS} splits.
# Under split*/input_src, split*/input_tgt and split*/$METRIC, there will be $SPLIT.bpe and $SPLIT.$METRIC files, respectively.
```
4. Pre-process the data into fairseq format.
```
# use comma to separate if there are more than one train or valid set
for suffix in src tgt ; do
fairseq-preprocess --only-source \
--trainpref ${OUTPUT_DIR}/$METRIC/split1/input_${suffix}/train.bpe \
--validpref ${OUTPUT_DIR}/$METRIC/split1/input_${suffix}/valid.bpe \
--destdir ${OUTPUT_DIR}/$METRIC/split1/input_${suffix} \
--workers 60 \
--srcdict ${XLMR_DIR}/dict.txt
done
for i in `seq 2 ${NUM_SHARDS}`; do
for suffix in src tgt ; do
fairseq-preprocess --only-source \
--trainpref ${OUTPUT_DIR}/$METRIC/split${i}/input_${suffix}/train.bpe \
--destdir ${OUTPUT_DIR}/$METRIC/split${i}/input_${suffix} \
--workers 60 \
--srcdict ${XLMR_DIR}/dict.txt
ln -s ${OUTPUT_DIR}/$METRIC/split1/input_${suffix}/valid* ${OUTPUT_DIR}/$METRIC/split${i}/input_${suffix}/.
done
ln -s ${OUTPUT_DIR}/$METRIC/split1/$METRIC/valid* ${OUTPUT_DIR}/$METRIC/split${i}/$METRIC/.
done
```
## Training
```
EXP_DIR=/path/to/exp
# An example of training the model with the config for De-En experiment in the paper.
# The config uses 16 GPUs and 50 hypotheses.
# For training with fewer number of GPUs, set
# distributed_training.distributed_world_size=k +optimization.update_freq='[x]' where x = 16/k
# For training with fewer number of hypotheses, set
# task.mt_beam=N dataset.batch_size=N dataset.required_batch_size_multiple=N
fairseq-hydra-train -m \
--config-dir config/ --config-name deen \
task.data=${OUTPUT_DIR}/$METRIC/split1/ \
task.num_data_splits=${NUM_SHARDS} \
model.pretrained_model=${XLMR_DIR}/model.pt \
common.user_dir=${FAIRSEQ_ROOT}/examples/discriminative_reranking_nmt \
checkpoint.save_dir=${EXP_DIR}
```
## Inference & scoring
Perform DrNMT reranking (fw + reranker score)
1. Tune weights on valid sets.
```
# genrate N hypotheses with the base MT model (fw score)
VALID_SOURCE_FILE=/path/to/source_sentences # one sentence per line, converted to the sentencepiece used by the base MT model
VALID_TARGET_FILE=/path/to/target_sentences # one sentence per line in raw text, i.e. no sentencepiece and tokenization
MT_MODEL=/path/to/mt_model
MT_DATA_PATH=/path/to/mt_data
cat ${VALID_SOURCE_FILE} | \
fairseq-interactive ${MT_DATA_PATH} \
--max-tokens 4000 --buffer-size 16 \
--num-workers 32 --path ${MT_MODEL} \
--beam $N --nbest $N \
--post-process sentencepiece &> valid-hypo.out
# replace "bleu" with "ter" to optimize for TER
python drnmt_rerank.py \
${OUTPUT_DIR}/$METRIC/split1/ \
--path ${EXP_DIR}/checkpoint_best.pt \
--in-text valid-hypo.out \
--results-path ${EXP_DIR} \
--gen-subset valid \
--target-text ${VALID_TARGET_FILE} \
--user-dir ${FAIRSEQ_ROOT}/examples/discriminative_reranking_nmt \
--bpe sentencepiece \
--sentencepiece-model ${XLMR_DIR}/sentencepiece.bpe.model \
--beam $N \
--batch-size $N \
--metric bleu \
--tune
```
2. Apply best weights on test sets
```
# genrate N hypotheses with the base MT model (fw score)
TEST_SOURCE_FILE=/path/to/source_sentences # one sentence per line, converted to the sentencepiece used by the base MT model
cat ${TEST_SOURCE_FILE} | \
fairseq-interactive ${MT_DATA_PATH} \
--max-tokens 4000 --buffer-size 16 \
--num-workers 32 --path ${MT_MODEL} \
--beam $N --nbest $N \
--post-process sentencepiece &> test-hypo.out
# replace "bleu" with "ter" to evaluate TER
# Add --target-text for evaluating BLEU/TER,
# otherwise the script will only generate the hypotheses with the highest scores only.
python drnmt_rerank.py \
${OUTPUT_DIR}/$METRIC/split1/ \
--path ${EXP_DIR}/checkpoint_best.pt \
--in-text test-hypo.out \
--results-path ${EXP_DIR} \
--gen-subset test \
--user-dir ${FAIRSEQ_ROOT}/examples/discriminative_reranking_nmt \
--bpe sentencepiece \
--sentencepiece-model ${XLMR_DIR}/sentencepiece.bpe.model \
--beam $N \
--batch-size $N \
--metric bleu \
--fw-weight ${BEST_FW_WEIGHT} \
--lenpen ${BEST_LENPEN}
```
## Citation
```bibtex
@inproceedings{lee2021discriminative,
title={Discriminative Reranking for Neural Machine Translation},
author={Lee, Ann and Auli, Michael and Ranzato, Marc'Aurelio},
booktitle={ACL},
year={2021}
}
```
from . import criterions, models, tasks # noqa
# @package _group_
common:
fp16: true
log_format: json
log_interval: 50
seed: 2
checkpoint:
no_epoch_checkpoints: true
best_checkpoint_metric: bleu
maximize_best_checkpoint_metric: true
task:
_name: discriminative_reranking_nmt
data: ???
num_data_splits: ???
include_src: true
mt_beam: 50
eval_target_metric: true
target_metric: bleu
dataset:
batch_size: 50
num_workers: 6
required_batch_size_multiple: 50
valid_subset: ???
criterion:
_name: kl_divergence_rereanking
target_dist_norm: minmax
temperature: 0.5
optimization:
max_epoch: 200
lr: [0.00005]
update_freq: [32]
optimizer:
_name: adam
adam_betas: (0.9,0.98)
adam_eps: 1e-06
lr_scheduler:
_name: polynomial_decay
warmup_updates: 8000
total_num_update: 320000
model:
_name: discriminative_nmt_reranker
pretrained_model: ???
classifier_dropout: 0.2
distributed_training:
ddp_backend: no_c10d
distributed_world_size: 16
from .discriminative_reranking_criterion import KLDivergenceRerankingCriterion
__all__ = [
"KLDivergenceRerankingCriterion",
]
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from dataclasses import dataclass, field
import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
_EPSILON = torch.finfo(torch.float32).eps
TARGET_DIST_NORM_CHOICES = ChoiceEnum(["none", "minmax"])
@dataclass
class KLDivergenceRerankingCriterionConfig(FairseqDataclass):
target_dist_norm: TARGET_DIST_NORM_CHOICES = field(
default="none",
metadata={"help": "method to normalize the range of target scores"},
)
temperature: float = field(
default=1.0,
metadata={"help": "temperature in softmax for target distributions"},
)
forward_batch_size: int = field(
default=32,
metadata={
"help": "number of hypotheses per batch for model forward (set a value smaller than --mt-beam to avoid OOM when training with a large beam size)"
},
)
@register_criterion(
"kl_divergence_rereanking", dataclass=KLDivergenceRerankingCriterionConfig
)
class KLDivergenceRerankingCriterion(FairseqCriterion):
def __init__(
self, task, target_dist_norm, temperature, forward_batch_size,
):
super().__init__(task)
self.target_dist_norm = target_dist_norm
self.temperature = temperature
self.forward_batch_size = forward_batch_size
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
sample_size = sample["id"].numel()
assert sample_size % self.task.cfg.mt_beam == 0, (
f"sample_size ({sample_size}) cannot be divided by beam size ({self.task.cfg.mt_beam})."
f"Please set --required-batch-size-multiple={self.task.cfg.mt_beam}."
)
# split into smaller batches for model forward
batch_out = []
for i in range(0, sample_size, self.forward_batch_size):
j = min(i + self.forward_batch_size, sample_size)
out = model(
src_tokens=sample["net_input"]["src_tokens"][i:j, :],
src_lengths=sample["net_input"]["src_lengths"][i:j],
)
batch_out.append(
model.sentence_forward(out, sample["net_input"]["src_tokens"][i:j, :])
)
batch_out = torch.cat(batch_out, dim=0).view(
self.task.cfg.mt_beam, sample_size // self.task.cfg.mt_beam, -1
) # T x B x C
if model.joint_classification == "sent":
batch_out = model.joint_forward(batch_out)
scores = model.classification_forward(batch_out.view(sample_size, 1, -1)).view(
-1, self.task.cfg.mt_beam
) # input: B x T x C
loss = self.compute_kl_loss(
scores, sample["target"][:, 0].view(-1, self.task.cfg.mt_beam)
)
sample_size = sample_size // self.task.cfg.mt_beam
logging_output = {
"loss": loss.detach(),
"ntokens": sample["ntokens"],
"nsentences": sample_size * self.task.cfg.mt_beam,
"sample_size": sample_size,
"scores": scores.detach(),
}
return loss, sample_size, logging_output
def compute_kl_loss(self, logits, target):
norm_target = target
if self.target_dist_norm == "minmax":
min_v = torch.min(target, 1, keepdim=True).values
max_v = torch.max(target, 1, keepdim=True).values
norm_target = (target - min_v) / (max_v - min_v + _EPSILON)
target_dist = F.softmax(
norm_target / self.temperature, dim=-1, dtype=torch.float32
)
model_dist = F.log_softmax(logits, dim=-1, dtype=torch.float32)
loss = -(target_dist * model_dist - target_dist * target_dist.log()).sum()
return loss
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
sample_size = utils.item(
sum(log.get("sample_size", 0) for log in logging_outputs)
)
loss = loss_sum / sample_size / math.log(2)
metrics.log_scalar("loss", loss, sample_size, round=3)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return True
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Score raw text with a trained model.
"""
from collections import namedtuple
import logging
from multiprocessing import Pool
import sys
import os
import random
import numpy as np
import sacrebleu
import torch
from fairseq import checkpoint_utils, options, utils
logger = logging.getLogger("fairseq_cli.drnmt_rerank")
logger.setLevel(logging.INFO)
Batch = namedtuple("Batch", "ids src_tokens src_lengths")
pool_init_variables = {}
def init_loaded_scores(mt_scores, model_scores, hyp, ref):
global pool_init_variables
pool_init_variables["mt_scores"] = mt_scores
pool_init_variables["model_scores"] = model_scores
pool_init_variables["hyp"] = hyp
pool_init_variables["ref"] = ref
def parse_fairseq_gen(filename, task):
source = {}
hypos = {}
scores = {}
with open(filename, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if line.startswith("S-"): # source
uid, text = line.split("\t", 1)
uid = int(uid[2:])
source[uid] = text
elif line.startswith("D-"): # hypo
uid, score, text = line.split("\t", 2)
uid = int(uid[2:])
if uid not in hypos:
hypos[uid] = []
scores[uid] = []
hypos[uid].append(text)
scores[uid].append(float(score))
else:
continue
source_out = [source[i] for i in range(len(hypos))]
hypos_out = [h for i in range(len(hypos)) for h in hypos[i]]
scores_out = [s for i in range(len(scores)) for s in scores[i]]
return source_out, hypos_out, scores_out
def read_target(filename):
with open(filename, "r", encoding="utf-8") as f:
output = [line.strip() for line in f]
return output
def make_batches(args, src, hyp, task, max_positions, encode_fn):
assert len(src) * args.beam == len(
hyp
), f"Expect {len(src) * args.beam} hypotheses for {len(src)} source sentences with beam size {args.beam}. Got {len(hyp)} hypotheses intead."
hyp_encode = [
task.source_dictionary.encode_line(encode_fn(h), add_if_not_exist=False).long()
for h in hyp
]
if task.cfg.include_src:
src_encode = [
task.source_dictionary.encode_line(
encode_fn(s), add_if_not_exist=False
).long()
for s in src
]
tokens = [(src_encode[i // args.beam], h) for i, h in enumerate(hyp_encode)]
lengths = [(t1.numel(), t2.numel()) for t1, t2 in tokens]
else:
tokens = [(h,) for h in hyp_encode]
lengths = [(h.numel(),) for h in hyp_encode]
itr = task.get_batch_iterator(
dataset=task.build_dataset_for_inference(tokens, lengths),
max_tokens=args.max_tokens,
max_sentences=args.batch_size,
max_positions=max_positions,
ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
).next_epoch_itr(shuffle=False)
for batch in itr:
yield Batch(
ids=batch["id"],
src_tokens=batch["net_input"]["src_tokens"],
src_lengths=batch["net_input"]["src_lengths"],
)
def decode_rerank_scores(args):
if args.max_tokens is None and args.batch_size is None:
args.batch_size = 1
logger.info(args)
use_cuda = torch.cuda.is_available() and not args.cpu
# Load ensemble
logger.info("loading model(s) from {}".format(args.path))
models, _model_args, task = checkpoint_utils.load_model_ensemble_and_task(
[args.path], arg_overrides=eval(args.model_overrides),
)
for model in models:
if args.fp16:
model.half()
if use_cuda:
model.cuda()
# Initialize generator
generator = task.build_generator(args)
# Handle tokenization and BPE
tokenizer = task.build_tokenizer(args)
bpe = task.build_bpe(args)
def encode_fn(x):
if tokenizer is not None:
x = tokenizer.encode(x)
if bpe is not None:
x = bpe.encode(x)
return x
max_positions = utils.resolve_max_positions(
task.max_positions(), *[model.max_positions() for model in models]
)
src, hyp, mt_scores = parse_fairseq_gen(args.in_text, task)
model_scores = {}
logger.info("decode reranker score")
for batch in make_batches(args, src, hyp, task, max_positions, encode_fn):
src_tokens = batch.src_tokens
src_lengths = batch.src_lengths
if use_cuda:
src_tokens = src_tokens.cuda()
src_lengths = src_lengths.cuda()
sample = {
"net_input": {"src_tokens": src_tokens, "src_lengths": src_lengths},
}
scores = task.inference_step(generator, models, sample)
for id, sc in zip(batch.ids.tolist(), scores.tolist()):
model_scores[id] = sc[0]
model_scores = [model_scores[i] for i in range(len(model_scores))]
return src, hyp, mt_scores, model_scores
def get_score(mt_s, md_s, w1, lp, tgt_len):
return mt_s / (tgt_len ** lp) * w1 + md_s
def get_best_hyps(mt_scores, md_scores, hypos, fw_weight, lenpen, beam):
assert len(mt_scores) == len(md_scores) and len(mt_scores) == len(hypos)
hypo_scores = []
best_hypos = []
best_scores = []
offset = 0
for i in range(len(hypos)):
tgt_len = len(hypos[i].split())
hypo_scores.append(
get_score(mt_scores[i], md_scores[i], fw_weight, lenpen, tgt_len)
)
if (i + 1) % beam == 0:
max_i = np.argmax(hypo_scores)
best_hypos.append(hypos[offset + max_i])
best_scores.append(hypo_scores[max_i])
hypo_scores = []
offset += beam
return best_hypos, best_scores
def eval_metric(args, hypos, ref):
if args.metric == "bleu":
score = sacrebleu.corpus_bleu(hypos, [ref]).score
else:
score = sacrebleu.corpus_ter(hypos, [ref]).score
return score
def score_target_hypo(args, fw_weight, lp):
mt_scores = pool_init_variables["mt_scores"]
model_scores = pool_init_variables["model_scores"]
hyp = pool_init_variables["hyp"]
ref = pool_init_variables["ref"]
best_hypos, _ = get_best_hyps(
mt_scores, model_scores, hyp, fw_weight, lp, args.beam
)
rerank_eval = None
if ref:
rerank_eval = eval_metric(args, best_hypos, ref)
print(f"fw_weight {fw_weight}, lenpen {lp}, eval {rerank_eval}")
return rerank_eval
def print_result(best_scores, best_hypos, output_file):
for i, (s, h) in enumerate(zip(best_scores, best_hypos)):
print(f"{i}\t{s}\t{h}", file=output_file)
def main(args):
utils.import_user_module(args)
src, hyp, mt_scores, model_scores = decode_rerank_scores(args)
assert (
not args.tune or args.target_text is not None
), "--target-text has to be set when tuning weights"
if args.target_text:
ref = read_target(args.target_text)
assert len(src) == len(
ref
), f"different numbers of source and target sentences ({len(src)} vs. {len(ref)})"
orig_best_hypos = [hyp[i] for i in range(0, len(hyp), args.beam)]
orig_eval = eval_metric(args, orig_best_hypos, ref)
if args.tune:
logger.info("tune weights for reranking")
random_params = np.array(
[
[
random.uniform(
args.lower_bound_fw_weight, args.upper_bound_fw_weight
),
random.uniform(args.lower_bound_lenpen, args.upper_bound_lenpen),
]
for k in range(args.num_trials)
]
)
logger.info("launching pool")
with Pool(
32,
initializer=init_loaded_scores,
initargs=(mt_scores, model_scores, hyp, ref),
) as p:
rerank_scores = p.starmap(
score_target_hypo,
[
(args, random_params[i][0], random_params[i][1],)
for i in range(args.num_trials)
],
)
if args.metric == "bleu":
best_index = np.argmax(rerank_scores)
else:
best_index = np.argmin(rerank_scores)
best_fw_weight = random_params[best_index][0]
best_lenpen = random_params[best_index][1]
else:
assert (
args.lenpen is not None and args.fw_weight is not None
), "--lenpen and --fw-weight should be set"
best_fw_weight, best_lenpen = args.fw_weight, args.lenpen
best_hypos, best_scores = get_best_hyps(
mt_scores, model_scores, hyp, best_fw_weight, best_lenpen, args.beam
)
if args.results_path is not None:
os.makedirs(args.results_path, exist_ok=True)
output_path = os.path.join(
args.results_path, "generate-{}.txt".format(args.gen_subset),
)
with open(output_path, "w", buffering=1, encoding="utf-8") as o:
print_result(best_scores, best_hypos, o)
else:
print_result(best_scores, best_hypos, sys.stdout)
if args.target_text:
rerank_eval = eval_metric(args, best_hypos, ref)
print(f"before reranking, {args.metric.upper()}:", orig_eval)
print(
f"after reranking with fw_weight={best_fw_weight}, lenpen={best_lenpen}, {args.metric.upper()}:",
rerank_eval,
)
def cli_main():
parser = options.get_generation_parser(interactive=True)
parser.add_argument(
"--in-text",
default=None,
required=True,
help="text from fairseq-interactive output, containing source sentences and hypotheses",
)
parser.add_argument("--target-text", default=None, help="reference text")
parser.add_argument("--metric", type=str, choices=["bleu", "ter"], default="bleu")
parser.add_argument(
"--tune",
action="store_true",
help="if set, tune weights on fw scores and lenpen instead of applying fixed weights for reranking",
)
parser.add_argument(
"--lower-bound-fw-weight",
default=0.0,
type=float,
help="lower bound of search space",
)
parser.add_argument(
"--upper-bound-fw-weight",
default=3,
type=float,
help="upper bound of search space",
)
parser.add_argument(
"--lower-bound-lenpen",
default=0.0,
type=float,
help="lower bound of search space",
)
parser.add_argument(
"--upper-bound-lenpen",
default=3,
type=float,
help="upper bound of search space",
)
parser.add_argument(
"--fw-weight", type=float, default=None, help="weight on the fw model score"
)
parser.add_argument(
"--num-trials",
default=1000,
type=int,
help="number of trials to do for random search",
)
args = options.parse_args_and_arch(parser)
main(args)
if __name__ == "__main__":
cli_main()
from .discriminative_reranking_model import DiscriminativeNMTReranker
__all__ = [
"DiscriminativeNMTReranker",
]
from dataclasses import dataclass, field
import os
import torch
import torch.nn as nn
from fairseq import utils
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.models import (
BaseFairseqModel,
register_model,
)
from fairseq.models.roberta.model import RobertaClassificationHead
from fairseq.modules import (
LayerNorm,
TransformerSentenceEncoder,
TransformerSentenceEncoderLayer,
)
ACTIVATION_FN_CHOICES = ChoiceEnum(utils.get_available_activation_fns())
JOINT_CLASSIFICATION_CHOICES = ChoiceEnum(["none", "sent"])
SENTENCE_REP_CHOICES = ChoiceEnum(["head", "meanpool", "maxpool"])
def update_init_roberta_model_state(state):
"""
update the state_dict of a Roberta model for initializing
weights of the BertRanker
"""
for k in list(state.keys()):
if ".lm_head." in k or "version" in k:
del state[k]
continue
# remove 'encoder/decoder.sentence_encoder.' from the key
assert k.startswith("encoder.sentence_encoder.") or k.startswith(
"decoder.sentence_encoder."
), f"Cannot recognize parameter name {k}"
if "layernorm_embedding" in k:
new_k = k.replace(".layernorm_embedding.", ".emb_layer_norm.")
state[new_k[25:]] = state[k]
else:
state[k[25:]] = state[k]
del state[k]
class BaseRanker(nn.Module):
def __init__(self, args, task):
super().__init__()
self.separator_token = task.dictionary.eos()
self.padding_idx = task.dictionary.pad()
def forward(self, src_tokens):
raise NotImplementedError
def get_segment_labels(self, src_tokens):
segment_boundary = (src_tokens == self.separator_token).long()
segment_labels = (
segment_boundary.cumsum(dim=1)
- segment_boundary
- (src_tokens == self.padding_idx).long()
)
return segment_labels
def get_positions(self, src_tokens, segment_labels):
segment_positions = (
torch.arange(src_tokens.shape[1])
.to(src_tokens.device)
.repeat(src_tokens.shape[0], 1)
)
segment_boundary = (src_tokens == self.separator_token).long()
_, col_idx = (segment_positions * segment_boundary).nonzero(as_tuple=True)
col_idx = torch.cat([torch.zeros(1).type_as(col_idx), col_idx])
offset = torch.cat(
[
torch.zeros(1).type_as(segment_boundary),
segment_boundary.sum(dim=1).cumsum(dim=0)[:-1],
]
)
segment_positions -= col_idx[segment_labels + offset.unsqueeze(1)] * (
segment_labels != 0
)
padding_mask = src_tokens.ne(self.padding_idx)
segment_positions = (segment_positions + 1) * padding_mask.type_as(
segment_positions
) + self.padding_idx
return segment_positions
class BertRanker(BaseRanker):
def __init__(self, args, task):
super(BertRanker, self).__init__(args, task)
init_model = getattr(args, "pretrained_model", "")
self.joint_layers = nn.ModuleList()
if os.path.isfile(init_model):
print(f"initialize weight from {init_model}")
from fairseq import hub_utils
x = hub_utils.from_pretrained(
os.path.dirname(init_model),
checkpoint_file=os.path.basename(init_model),
)
in_state_dict = x["models"][0].state_dict()
init_args = x["args"].model
num_positional_emb = init_args.max_positions + task.dictionary.pad() + 1
# follow the setup in roberta
self.model = TransformerSentenceEncoder(
padding_idx=task.dictionary.pad(),
vocab_size=len(task.dictionary),
num_encoder_layers=getattr(
args, "encoder_layers", init_args.encoder_layers
),
embedding_dim=init_args.encoder_embed_dim,
ffn_embedding_dim=init_args.encoder_ffn_embed_dim,
num_attention_heads=init_args.encoder_attention_heads,
dropout=init_args.dropout,
attention_dropout=init_args.attention_dropout,
activation_dropout=init_args.activation_dropout,
num_segments=2, # add language embeddings
max_seq_len=num_positional_emb,
offset_positions_by_padding=False,
encoder_normalize_before=True,
apply_bert_init=True,
activation_fn=init_args.activation_fn,
freeze_embeddings=args.freeze_embeddings,
n_trans_layers_to_freeze=args.n_trans_layers_to_freeze,
)
# still need to learn segment embeddings as we added a second language embedding
if args.freeze_embeddings:
for p in self.model.segment_embeddings.parameters():
p.requires_grad = False
update_init_roberta_model_state(in_state_dict)
print("loading weights from the pretrained model")
self.model.load_state_dict(
in_state_dict, strict=False
) # ignore mismatch in language embeddings
ffn_embedding_dim = init_args.encoder_ffn_embed_dim
num_attention_heads = init_args.encoder_attention_heads
dropout = init_args.dropout
attention_dropout = init_args.attention_dropout
activation_dropout = init_args.activation_dropout
activation_fn = init_args.activation_fn
classifier_embed_dim = getattr(
args, "embed_dim", init_args.encoder_embed_dim
)
if classifier_embed_dim != init_args.encoder_embed_dim:
self.transform_layer = nn.Linear(
init_args.encoder_embed_dim, classifier_embed_dim
)
else:
self.model = TransformerSentenceEncoder(
padding_idx=task.dictionary.pad(),
vocab_size=len(task.dictionary),
num_encoder_layers=args.encoder_layers,
embedding_dim=args.embed_dim,
ffn_embedding_dim=args.ffn_embed_dim,
num_attention_heads=args.attention_heads,
dropout=args.dropout,
attention_dropout=args.attention_dropout,
activation_dropout=args.activation_dropout,
max_seq_len=task.max_positions()
if task.max_positions()
else args.tokens_per_sample,
num_segments=2,
offset_positions_by_padding=False,
encoder_normalize_before=args.encoder_normalize_before,
apply_bert_init=args.apply_bert_init,
activation_fn=args.activation_fn,
)
classifier_embed_dim = args.embed_dim
ffn_embedding_dim = args.ffn_embed_dim
num_attention_heads = args.attention_heads
dropout = args.dropout
attention_dropout = args.attention_dropout
activation_dropout = args.activation_dropout
activation_fn = args.activation_fn
self.joint_classification = args.joint_classification
if args.joint_classification == "sent":
if args.joint_normalize_before:
self.joint_layer_norm = LayerNorm(classifier_embed_dim)
else:
self.joint_layer_norm = None
self.joint_layers = nn.ModuleList(
[
TransformerSentenceEncoderLayer(
embedding_dim=classifier_embed_dim,
ffn_embedding_dim=ffn_embedding_dim,
num_attention_heads=num_attention_heads,
dropout=dropout,
attention_dropout=attention_dropout,
activation_dropout=activation_dropout,
activation_fn=activation_fn,
)
for _ in range(args.num_joint_layers)
]
)
self.classifier = RobertaClassificationHead(
classifier_embed_dim,
classifier_embed_dim,
1, # num_classes
"tanh",
args.classifier_dropout,
)
def forward(self, src_tokens, src_lengths):
segment_labels = self.get_segment_labels(src_tokens)
positions = self.get_positions(src_tokens, segment_labels)
inner_states, _ = self.model(
tokens=src_tokens,
segment_labels=segment_labels,
last_state_only=True,
positions=positions,
)
return inner_states[-1].transpose(0, 1) # T x B x C -> B x T x C
def sentence_forward(self, encoder_out, src_tokens=None, sentence_rep="head"):
# encoder_out: B x T x C
if sentence_rep == "head":
x = encoder_out[:, :1, :]
else: # 'meanpool', 'maxpool'
assert src_tokens is not None, "meanpool requires src_tokens input"
segment_labels = self.get_segment_labels(src_tokens)
padding_mask = src_tokens.ne(self.padding_idx)
encoder_mask = segment_labels * padding_mask.type_as(segment_labels)
if sentence_rep == "meanpool":
ntokens = torch.sum(encoder_mask, dim=1, keepdim=True)
x = torch.sum(
encoder_out * encoder_mask.unsqueeze(2), dim=1, keepdim=True
) / ntokens.unsqueeze(2).type_as(encoder_out)
else: # 'maxpool'
encoder_out[
(encoder_mask == 0).unsqueeze(2).repeat(1, 1, encoder_out.shape[-1])
] = -float("inf")
x, _ = torch.max(encoder_out, dim=1, keepdim=True)
if hasattr(self, "transform_layer"):
x = self.transform_layer(x)
return x # B x 1 x C
def joint_forward(self, x):
# x: T x B x C
if self.joint_layer_norm:
x = self.joint_layer_norm(x.transpose(0, 1))
x = x.transpose(0, 1)
for layer in self.joint_layers:
x, _ = layer(x, self_attn_padding_mask=None)
return x
def classification_forward(self, x):
# x: B x T x C
return self.classifier(x)
@dataclass
class DiscriminativeNMTRerankerConfig(FairseqDataclass):
pretrained_model: str = field(
default="", metadata={"help": "pretrained model to load"}
)
sentence_rep: SENTENCE_REP_CHOICES = field(
default="head",
metadata={
"help": "method to transform the output of the transformer stack to a sentence-level representation"
},
)
dropout: float = field(default=0.1, metadata={"help": "dropout probability"})
attention_dropout: float = field(
default=0.0, metadata={"help": "dropout probability for attention weights"}
)
activation_dropout: float = field(
default=0.0, metadata={"help": "dropout probability after activation in FFN"}
)
classifier_dropout: float = field(
default=0.0, metadata={"help": "classifier dropout probability"}
)
embed_dim: int = field(default=768, metadata={"help": "embedding dimension"})
ffn_embed_dim: int = field(
default=2048, metadata={"help": "embedding dimension for FFN"}
)
encoder_layers: int = field(default=12, metadata={"help": "num encoder layers"})
attention_heads: int = field(default=8, metadata={"help": "num attention heads"})
encoder_normalize_before: bool = field(
default=False, metadata={"help": "apply layernorm before each encoder block"}
)
apply_bert_init: bool = field(
default=False, metadata={"help": "use custom param initialization for BERT"}
)
activation_fn: ACTIVATION_FN_CHOICES = field(
default="relu", metadata={"help": "activation function to use"}
)
freeze_embeddings: bool = field(
default=False, metadata={"help": "freeze embeddings in the pretrained model"}
)
n_trans_layers_to_freeze: int = field(
default=0,
metadata={
"help": "number of layers to freeze in the pretrained transformer model"
},
)
# joint classfication
joint_classification: JOINT_CLASSIFICATION_CHOICES = field(
default="none",
metadata={"help": "method to compute joint features for classification"},
)
num_joint_layers: int = field(
default=1, metadata={"help": "number of joint layers"}
)
joint_normalize_before: bool = field(
default=False,
metadata={"help": "apply layer norm on the input to the joint layer"},
)
@register_model(
"discriminative_nmt_reranker", dataclass=DiscriminativeNMTRerankerConfig
)
class DiscriminativeNMTReranker(BaseFairseqModel):
@classmethod
def build_model(cls, args, task):
model = BertRanker(args, task)
return DiscriminativeNMTReranker(args, model)
def __init__(self, args, model):
super().__init__()
self.model = model
self.sentence_rep = args.sentence_rep
self.joint_classification = args.joint_classification
def forward(self, src_tokens, src_lengths, **kwargs):
return self.model(src_tokens, src_lengths)
def sentence_forward(self, encoder_out, src_tokens):
return self.model.sentence_forward(encoder_out, src_tokens, self.sentence_rep)
def joint_forward(self, x):
return self.model.joint_forward(x)
def classification_forward(self, x):
return self.model.classification_forward(x)
#!/usr/bin/env python
import argparse
from multiprocessing import Pool
from pathlib import Path
import sacrebleu
import sentencepiece as spm
def read_text_file(filename):
with open(filename, "r") as f:
output = [line.strip() for line in f]
return output
def get_bleu(in_sent, target_sent):
bleu = sacrebleu.corpus_bleu([in_sent], [[target_sent]])
out = " ".join(
map(str, [bleu.score, bleu.sys_len, bleu.ref_len] + bleu.counts + bleu.totals)
)
return out
def get_ter(in_sent, target_sent):
ter = sacrebleu.corpus_ter([in_sent], [[target_sent]])
out = " ".join(map(str, [ter.score, ter.num_edits, ter.ref_length]))
return out
def init(sp_model):
global sp
sp = spm.SentencePieceProcessor()
sp.Load(sp_model)
def process(source_sent, target_sent, hypo_sent, metric):
source_bpe = " ".join(sp.EncodeAsPieces(source_sent))
hypo_bpe = [" ".join(sp.EncodeAsPieces(h)) for h in hypo_sent]
if metric == "bleu":
score_str = [get_bleu(h, target_sent) for h in hypo_sent]
else: # ter
score_str = [get_ter(h, target_sent) for h in hypo_sent]
return source_bpe, hypo_bpe, score_str
def main(args):
assert (
args.split.startswith("train") or args.num_shards == 1
), "--num-shards should be set to 1 for valid and test sets"
assert (
args.split.startswith("train")
or args.split.startswith("valid")
or args.split.startswith("test")
), "--split should be set to train[n]/valid[n]/test[n]"
source_sents = read_text_file(args.input_source)
target_sents = read_text_file(args.input_target)
num_sents = len(source_sents)
assert num_sents == len(
target_sents
), f"{args.input_source} and {args.input_target} should have the same number of sentences."
hypo_sents = read_text_file(args.input_hypo)
assert (
len(hypo_sents) % args.beam == 0
), f"Number of hypotheses ({len(hypo_sents)}) cannot be divided by beam size ({args.beam})."
hypo_sents = [
hypo_sents[i : i + args.beam] for i in range(0, len(hypo_sents), args.beam)
]
assert num_sents == len(
hypo_sents
), f"{args.input_hypo} should contain {num_sents * args.beam} hypotheses but only has {len(hypo_sents) * args.beam}. (--beam={args.beam})"
output_dir = args.output_dir / args.metric
for ns in range(args.num_shards):
print(f"processing shard {ns+1}/{args.num_shards}")
shard_output_dir = output_dir / f"split{ns+1}"
source_output_dir = shard_output_dir / "input_src"
hypo_output_dir = shard_output_dir / "input_tgt"
metric_output_dir = shard_output_dir / args.metric
source_output_dir.mkdir(parents=True, exist_ok=True)
hypo_output_dir.mkdir(parents=True, exist_ok=True)
metric_output_dir.mkdir(parents=True, exist_ok=True)
if args.n_proc > 1:
with Pool(
args.n_proc, initializer=init, initargs=(args.sentencepiece_model,)
) as p:
output = p.starmap(
process,
[
(source_sents[i], target_sents[i], hypo_sents[i], args.metric)
for i in range(ns, num_sents, args.num_shards)
],
)
else:
init(args.sentencepiece_model)
output = [
process(source_sents[i], target_sents[i], hypo_sents[i], args.metric)
for i in range(ns, num_sents, args.num_shards)
]
with open(source_output_dir / f"{args.split}.bpe", "w") as s_o, open(
hypo_output_dir / f"{args.split}.bpe", "w"
) as h_o, open(metric_output_dir / f"{args.split}.{args.metric}", "w") as m_o:
for source_bpe, hypo_bpe, score_str in output:
assert len(hypo_bpe) == len(score_str)
for h, m in zip(hypo_bpe, score_str):
s_o.write(f"{source_bpe}\n")
h_o.write(f"{h}\n")
m_o.write(f"{m}\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--input-source", type=Path, required=True)
parser.add_argument("--input-target", type=Path, required=True)
parser.add_argument("--input-hypo", type=Path, required=True)
parser.add_argument("--output-dir", type=Path, required=True)
parser.add_argument("--split", type=str, required=True)
parser.add_argument("--beam", type=int, required=True)
parser.add_argument("--sentencepiece-model", type=str, required=True)
parser.add_argument("--metric", type=str, choices=["bleu", "ter"], default="bleu")
parser.add_argument("--num-shards", type=int, default=1)
parser.add_argument("--n-proc", type=int, default=8)
args = parser.parse_args()
main(args)
from .discriminative_reranking_task import DiscriminativeRerankingNMTTask
__all__ = [
"DiscriminativeRerankingNMTTask",
]
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
import itertools
import logging
import os
import numpy as np
import torch
from fairseq import metrics
from fairseq.data import (
ConcatDataset,
ConcatSentencesDataset,
data_utils,
Dictionary,
IdDataset,
indexed_dataset,
NestedDictionaryDataset,
NumSamplesDataset,
NumelDataset,
PrependTokenDataset,
RawLabelDataset,
RightPadDataset,
SortDataset,
TruncateDataset,
TokenBlockDataset,
)
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.tasks import FairseqTask, register_task
from omegaconf import II, MISSING
EVAL_BLEU_ORDER = 4
TARGET_METRIC_CHOICES = ChoiceEnum(["bleu", "ter"])
logger = logging.getLogger(__name__)
@dataclass
class DiscriminativeRerankingNMTConfig(FairseqDataclass):
data: str = field(default=MISSING, metadata={"help": "path to data directory"})
num_data_splits: int = field(
default=1, metadata={"help": "total number of data splits"}
)
no_shuffle: bool = field(
default=False, metadata={"help": "do not shuffle training data"}
)
max_positions: int = field(
default=512, metadata={"help": "number of positional embeddings to learn"}
)
include_src: bool = field(
default=False, metadata={"help": "include source sentence"}
)
mt_beam: int = field(default=50, metadata={"help": "beam size of input hypotheses"})
eval_target_metric: bool = field(
default=False,
metadata={"help": "evaluation with the target metric during validation"},
)
target_metric: TARGET_METRIC_CHOICES = field(
default="bleu", metadata={"help": "name of the target metric to optimize for"}
)
train_subset: str = field(
default=II("dataset.train_subset"),
metadata={"help": "data subset to use for training (e.g. train, valid, test)"},
)
seed: int = field(
default=II("common.seed"),
metadata={"help": "pseudo random number generator seed"},
)
class RerankerScorer(object):
"""Scores the target for a given (source (optional), target) input."""
def __init__(self, args, mt_beam):
self.mt_beam = mt_beam
@torch.no_grad()
def generate(self, models, sample, **kwargs):
"""Score a batch of translations."""
net_input = sample["net_input"]
assert len(models) == 1, "does not support model ensemble"
model = models[0]
bs = net_input["src_tokens"].shape[0]
assert (
model.joint_classification == "none" or bs % self.mt_beam == 0
), f"invalid batch size ({bs}) for joint classification with beam size ({self.mt_beam})"
model.eval()
logits = model(**net_input)
batch_out = model.sentence_forward(logits, net_input["src_tokens"])
if model.joint_classification == "sent":
batch_out = model.joint_forward(
batch_out.view(self.mt_beam, bs // self.mt_beam, -1)
)
scores = model.classification_forward(
batch_out.view(bs, 1, -1)
) # input: B x T x C
return scores
@register_task(
"discriminative_reranking_nmt", dataclass=DiscriminativeRerankingNMTConfig
)
class DiscriminativeRerankingNMTTask(FairseqTask):
"""
Translation rerank task.
The input can be either (src, tgt) sentence pairs or tgt sentence only.
"""
cfg: DiscriminativeRerankingNMTConfig
def __init__(self, cfg: DiscriminativeRerankingNMTConfig, data_dictionary=None):
super().__init__(cfg)
self.dictionary = data_dictionary
self._max_positions = cfg.max_positions
# args.tokens_per_sample = self._max_positions
# self.num_classes = 1 # for model
@classmethod
def load_dictionary(cls, cfg, filename):
"""Load the dictionary from the filename"""
dictionary = Dictionary.load(filename)
dictionary.add_symbol("<mask>") # for loading pretrained XLMR model
return dictionary
@classmethod
def setup_task(cls, cfg: DiscriminativeRerankingNMTConfig, **kwargs):
# load data dictionary (assume joint dictionary)
data_path = cfg.data
data_dict = cls.load_dictionary(
cfg, os.path.join(data_path, "input_src/dict.txt")
)
logger.info("[input] src dictionary: {} types".format(len(data_dict)))
return DiscriminativeRerankingNMTTask(cfg, data_dict)
def load_dataset(self, split, epoch=0, combine=False, **kwargs):
"""Load a given dataset split (e.g., train, valid, test)."""
if self.cfg.data.endswith("1"):
data_shard = (epoch - 1) % self.cfg.num_data_splits + 1
data_path = self.cfg.data[:-1] + str(data_shard)
else:
data_path = self.cfg.data
def get_path(type, data_split):
return os.path.join(data_path, str(type), data_split)
def make_dataset(type, dictionary, data_split, combine):
split_path = get_path(type, data_split)
dataset = data_utils.load_indexed_dataset(
split_path, dictionary, combine=combine,
)
return dataset
def load_split(data_split, metric):
input_src = None
if self.cfg.include_src:
input_src = make_dataset(
"input_src", self.dictionary, data_split, combine=False
)
assert input_src is not None, "could not find dataset: {}".format(
get_path("input_src", data_split)
)
input_tgt = make_dataset(
"input_tgt", self.dictionary, data_split, combine=False
)
assert input_tgt is not None, "could not find dataset: {}".format(
get_path("input_tgt", data_split)
)
label_path = f"{get_path(metric, data_split)}.{metric}"
assert os.path.exists(label_path), f"could not find dataset: {label_path}"
np_labels = np.loadtxt(label_path)
if self.cfg.target_metric == "ter":
np_labels = -np_labels
label = RawLabelDataset(np_labels)
return input_src, input_tgt, label
src_datasets = []
tgt_datasets = []
label_datasets = []
if split == self.cfg.train_subset:
for k in itertools.count():
split_k = "train" + (str(k) if k > 0 else "")
prefix = os.path.join(data_path, "input_tgt", split_k)
if not indexed_dataset.dataset_exists(prefix, impl=None):
if k > 0:
break
else:
raise FileNotFoundError(f"Dataset not found: {prefix}")
input_src, input_tgt, label = load_split(
split_k, self.cfg.target_metric
)
src_datasets.append(input_src)
tgt_datasets.append(input_tgt)
label_datasets.append(label)
else:
input_src, input_tgt, label = load_split(split, self.cfg.target_metric)
src_datasets.append(input_src)
tgt_datasets.append(input_tgt)
label_datasets.append(label)
if len(tgt_datasets) == 1:
input_tgt, label = tgt_datasets[0], label_datasets[0]
if self.cfg.include_src:
input_src = src_datasets[0]
else:
input_tgt = ConcatDataset(tgt_datasets)
label = ConcatDataset(label_datasets)
if self.cfg.include_src:
input_src = ConcatDataset(src_datasets)
input_tgt = TruncateDataset(input_tgt, self.cfg.max_positions)
if self.cfg.include_src:
input_src = PrependTokenDataset(input_src, self.dictionary.bos())
input_src = TruncateDataset(input_src, self.cfg.max_positions)
src_lengths = NumelDataset(input_src, reduce=False)
src_tokens = ConcatSentencesDataset(input_src, input_tgt)
else:
src_tokens = PrependTokenDataset(input_tgt, self.dictionary.bos())
src_lengths = NumelDataset(src_tokens, reduce=False)
dataset = {
"id": IdDataset(),
"net_input": {
"src_tokens": RightPadDataset(
src_tokens, pad_idx=self.source_dictionary.pad(),
),
"src_lengths": src_lengths,
},
"nsentences": NumSamplesDataset(),
"ntokens": NumelDataset(src_tokens, reduce=True),
"target": label,
}
dataset = NestedDictionaryDataset(dataset, sizes=[src_tokens.sizes],)
assert len(dataset) % self.cfg.mt_beam == 0, (
"dataset size (%d) is not a multiple of beam size (%d)"
% (len(dataset), self.cfg.mt_beam)
)
# no need to shuffle valid/test sets
if not self.cfg.no_shuffle and split == self.cfg.train_subset:
# need to keep all hypothese together
start_idx = np.arange(0, len(dataset), self.cfg.mt_beam)
with data_utils.numpy_seed(self.cfg.seed + epoch):
np.random.shuffle(start_idx)
idx = np.arange(0, self.cfg.mt_beam)
shuffle = np.tile(idx, (len(start_idx), 1)).reshape(-1) + np.tile(
start_idx, (self.cfg.mt_beam, 1)
).transpose().reshape(-1)
dataset = SortDataset(dataset, sort_order=[shuffle],)
logger.info(f"Loaded {split} with #samples: {len(dataset)}")
self.datasets[split] = dataset
return self.datasets[split]
def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs):
assert not self.cfg.include_src or len(src_tokens[0]) == 2
input_src = None
if self.cfg.include_src:
input_src = TokenBlockDataset(
[t[0] for t in src_tokens],
[l[0] for l in src_lengths],
block_size=None, # ignored for "eos" break mode
pad=self.source_dictionary.pad(),
eos=self.source_dictionary.eos(),
break_mode="eos",
)
input_src = PrependTokenDataset(input_src, self.dictionary.bos())
input_src = TruncateDataset(input_src, self.cfg.max_positions)
input_tgt = TokenBlockDataset(
[t[-1] for t in src_tokens],
[l[-1] for l in src_lengths],
block_size=None, # ignored for "eos" break mode
pad=self.source_dictionary.pad(),
eos=self.source_dictionary.eos(),
break_mode="eos",
)
input_tgt = TruncateDataset(input_tgt, self.cfg.max_positions)
if self.cfg.include_src:
src_tokens = ConcatSentencesDataset(input_src, input_tgt)
src_lengths = NumelDataset(input_src, reduce=False)
else:
input_tgt = PrependTokenDataset(input_tgt, self.dictionary.bos())
src_tokens = input_tgt
src_lengths = NumelDataset(src_tokens, reduce=False)
dataset = {
"id": IdDataset(),
"net_input": {
"src_tokens": RightPadDataset(
src_tokens, pad_idx=self.source_dictionary.pad(),
),
"src_lengths": src_lengths,
},
"nsentences": NumSamplesDataset(),
"ntokens": NumelDataset(src_tokens, reduce=True),
}
return NestedDictionaryDataset(dataset, sizes=[src_tokens.sizes],)
def build_model(self, cfg: FairseqDataclass):
return super().build_model(cfg)
def build_generator(self, args):
return RerankerScorer(args, mt_beam=self.cfg.mt_beam)
def max_positions(self):
return self._max_positions
@property
def source_dictionary(self):
return self.dictionary
@property
def target_dictionary(self):
return self.dictionary
def create_dummy_batch(self, device):
dummy_target = (
torch.zeros(self.cfg.mt_beam, EVAL_BLEU_ORDER * 2 + 3).long().to(device)
if not self.cfg.eval_ter
else torch.zeros(self.cfg.mt_beam, 3).long().to(device)
)
return {
"id": torch.zeros(self.cfg.mt_beam, 1).long().to(device),
"net_input": {
"src_tokens": torch.zeros(self.cfg.mt_beam, 4).long().to(device),
"src_lengths": torch.ones(self.cfg.mt_beam, 1).long().to(device),
},
"nsentences": 0,
"ntokens": 0,
"target": dummy_target,
}
def train_step(
self, sample, model, criterion, optimizer, update_num, ignore_grad=False
):
if ignore_grad and sample is None:
sample = self.create_dummy_batch(model.device)
return super().train_step(
sample, model, criterion, optimizer, update_num, ignore_grad
)
def valid_step(self, sample, model, criterion):
if sample is None:
sample = self.create_dummy_batch(model.device)
loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
if not self.cfg.eval_target_metric:
return loss, sample_size, logging_output
scores = logging_output["scores"]
if self.cfg.target_metric == "bleu":
assert sample["target"].shape[1] == EVAL_BLEU_ORDER * 2 + 3, (
"target does not contain enough information ("
+ str(sample["target"].shape[1])
+ "for evaluating BLEU"
)
max_id = torch.argmax(scores, dim=1)
select_id = max_id + torch.arange(
0, sample_size * self.cfg.mt_beam, self.cfg.mt_beam
).to(max_id.device)
bleu_data = sample["target"][select_id, 1:].sum(0).data
logging_output["_bleu_sys_len"] = bleu_data[0]
logging_output["_bleu_ref_len"] = bleu_data[1]
for i in range(EVAL_BLEU_ORDER):
logging_output["_bleu_counts_" + str(i)] = bleu_data[2 + i]
logging_output["_bleu_totals_" + str(i)] = bleu_data[
2 + EVAL_BLEU_ORDER + i
]
elif self.cfg.target_metric == "ter":
assert sample["target"].shape[1] == 3, (
"target does not contain enough information ("
+ str(sample["target"].shape[1])
+ "for evaluating TER"
)
max_id = torch.argmax(scores, dim=1)
select_id = max_id + torch.arange(
0, sample_size * self.cfg.mt_beam, self.cfg.mt_beam
).to(max_id.device)
ter_data = sample["target"][select_id, 1:].sum(0).data
logging_output["_ter_num_edits"] = -ter_data[0]
logging_output["_ter_ref_len"] = -ter_data[1]
return loss, sample_size, logging_output
def reduce_metrics(self, logging_outputs, criterion):
super().reduce_metrics(logging_outputs, criterion)
if not self.cfg.eval_target_metric:
return
def sum_logs(key):
return sum(log.get(key, 0) for log in logging_outputs)
if self.cfg.target_metric == "bleu":
counts, totals = [], []
for i in range(EVAL_BLEU_ORDER):
counts.append(sum_logs("_bleu_counts_" + str(i)))
totals.append(sum_logs("_bleu_totals_" + str(i)))
if max(totals) > 0:
# log counts as numpy arrays -- log_scalar will sum them correctly
metrics.log_scalar("_bleu_counts", np.array(counts))
metrics.log_scalar("_bleu_totals", np.array(totals))
metrics.log_scalar("_bleu_sys_len", sum_logs("_bleu_sys_len"))
metrics.log_scalar("_bleu_ref_len", sum_logs("_bleu_ref_len"))
def compute_bleu(meters):
import inspect
import sacrebleu
fn_sig = inspect.getfullargspec(sacrebleu.compute_bleu)[0]
if "smooth_method" in fn_sig:
smooth = {"smooth_method": "exp"}
else:
smooth = {"smooth": "exp"}
bleu = sacrebleu.compute_bleu(
correct=meters["_bleu_counts"].sum,
total=meters["_bleu_totals"].sum,
sys_len=meters["_bleu_sys_len"].sum,
ref_len=meters["_bleu_ref_len"].sum,
**smooth,
)
return round(bleu.score, 2)
metrics.log_derived("bleu", compute_bleu)
elif self.cfg.target_metric == "ter":
num_edits = sum_logs("_ter_num_edits")
ref_len = sum_logs("_ter_ref_len")
if ref_len > 0:
metrics.log_scalar("_ter_num_edits", num_edits)
metrics.log_scalar("_ter_ref_len", ref_len)
def compute_ter(meters):
score = meters["_ter_num_edits"].sum / meters["_ter_ref_len"].sum
return round(score.item(), 2)
metrics.log_derived("ter", compute_ter)
# Language Models not just for Pre-training: Fast Online Neural Noisy Channel Modeling
## Introduction
- [Yee et al. (2019)](https://www.aclweb.org/anthology/D19-1571.pdf) introduce a simple and effective noisy channel modeling approach for neural machine translation. However, the noisy channel online decoding approach introduced in this paper is too slow to be practical.
- To address this, [Bhosale et al. (2020)](http://www.statmt.org/wmt20/pdf/2020.wmt-1.68.pdf) introduces 3 simple approximations to make this approach very fast and practical without much loss in accuracy.
- This README provides intructions on how to run online decoding or generation with the noisy channel modeling approach, including ways to make it very fast without much loss in accuracy.
## Noisy Channel Modeling
[Yee et al. (2019)](https://www.aclweb.org/anthology/D19-1571.pdf) applies the Bayes Rule to predict `P(y|x)`, the probability of the target `y` given the source `x`.
```P(y|x) = P(x|y) * P(y) / P(x)```
- `P(x|y)` predicts the source `x` given the target `y` and is referred to as the **channel model**
- `P(y)` is a **language model** over the target `y`
- `P(x)` is generally not modeled since it is constant for all `y`.
We use Transformer models to parameterize the direct model `P(y|x)`, the channel model `P(x|y)` and the language model `P(y)`.
During online decoding with beam search, we generate the top `K2` candidates per beam and score them with the following linear combination of the channel model, the language model as well as the direct model scores.
```(1 / t) * log(P(y|x) + (1 / s) * ( λ1 * log(P(x|y)) + λ2 * log(P(y) ) )```
- `t` - Target Prefix Length
- `s` - Source Length
- `λ1` - Channel Model Weight
- `λ2` - Language Model Weight
The top `beam_size` candidates based on the above combined scores are chosen to continue the beams in beam search. In beam search with a direct model alone, the scores from the direct model `P(y|x)` are used to choose the top candidates in beam search.
This framework provides a great way to utlize strong target language models trained on large amounts of unlabeled data. Language models can prefer targets unrelated to the source, so we also need a channel model whose role is to ensure that the target preferred by the language model also translates back to the source.
### Training Translation Models and Language Models
For training Transformer models in fairseq for machine translation, refer to instructions [here](https://github.com/pytorch/fairseq/tree/main/examples/translation)
For training Transformer models in fairseq for language modeling, refer to instructions [here](https://github.com/pytorch/fairseq/tree/main/examples/language_model)
### Generation with Language Model for German-English translation with fairseq
Here are instructions to generate using a direct model and a target-side language model.
Note:
- Download and install fairseq as per instructions [here](https://github.com/pytorch/fairseq)
- Preprocess and binarize the dataset as per instructions in section [Test Data Preprocessing](#test-data-preprocessing)
```sh
binarized_data=data_dir/binarized
direct_model=de_en_seed4.pt
lm_model=en_lm.pt
lm_data=lm_data
wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/direct_models/seed4.pt -O ${direct_model}
wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/transformer_lm.pt -O ${lm_model}
mkdir -p ${lm_data}
wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/lm_dict/dict.txt -O ${lm_data}/dict.txt
k2=10
lenpen=0.16
lm_wt=0.14
fairseq-generate ${binarized_data} \
--user-dir examples/fast_noisy_channel \
--beam 5 \
--path ${direct_model} \
--lm-model ${lm_model} \
--lm-data ${lm_data} \
--k2 ${k2} \
--combine-method lm_only \
--task noisy_channel_translation \
--lenpen ${lenpen} \
--lm-wt ${lm_wt} \
--gen-subset valid \
--remove-bpe \
--fp16 \
--batch-size 10
```
### Noisy Channel Generation for German-English translation with fairseq
Here are instructions for noisy channel generation with a direct model, channel model and language model as explained in section [Noisy Channel Modeling](#noisy-channel-modeling).
Note:
- Download and install fairseq as per instructions [here](https://github.com/pytorch/fairseq)
- Preprocess and binarize the dataset as per instructions in section [Test Data Preprocessing](#test-data-preprocessing)
```sh
binarized_data=data_dir/binarized
direct_model=de_en_seed4.pt
lm_model=en_lm.pt
lm_data=lm_data
ch_model=en_de.big.seed4.pt
wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/direct_models/seed4.pt -O ${direct_model}
wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/transformer_lm.pt -O ${lm_model}
mkdir -p ${lm_data}
wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/lm_dict/dict.txt -O ${lm_data}/dict.txt
wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/big.seed4.pt -O ${ch_model}
k2=10
lenpen=0.21
lm_wt=0.50
bw_wt=0.30
fairseq-generate ${binarized_data} \
--user-dir examples/fast_noisy_channel \
--beam 5 \
--path ${direct_model} \
--lm-model ${lm_model} \
--lm-data ${lm_data} \
--channel-model ${ch_model} \
--k2 ${k2} \
--combine-method noisy_channel \
--task noisy_channel_translation \
--lenpen ${lenpen} \
--lm-wt ${lm_wt} \
--ch-wt ${bw_wt} \
--gen-subset test \
--remove-bpe \
--fp16 \
--batch-size 1
```
## Fast Noisy Channel Modeling
[Bhosale et al. (2020)](http://www.statmt.org/wmt20/pdf/2020.wmt-1.68.pdf) introduces 3 approximations that speed up online noisy channel decoding -
- Smaller channel models (`Tranformer Base` with 1 encoder and decoder layer each vs. `Transformer Big`)
- This involves training a channel model that is possibly smaller and less accurate in terms of BLEU than a channel model of the same size as the direct model.
- Since the role of the channel model is mainly to assign low scores to generations from the language model if they don't translate back to the source, we may not need the most accurate channel model for this purpose.
- Smaller output vocabulary size for the channel model (~30,000 -> ~1000)
- The channel model doesn't need to score the full output vocabulary, it just needs to score the source tokens, which are completely known.
- This is specified using the arguments `--channel-scoring-type src_vocab --top-k-vocab 500`
- This means that the output vocabulary for the channel model will be the source tokens for all examples in the batch and the top-K most frequent tokens in the vocabulary
- This reduces the memory consumption needed to store channel model scores significantly
- Smaller number of candidates (`k2`) scored per beam
- This is specified by reducing the argument `--k2`
### Fast Noisy Channel Generation for German-English translation with fairseq
Here are instructions for **fast** noisy channel generation with a direct model, channel model and language model as explained in section [Fast Noisy Channel Modeling](#fast-noisy-channel-modeling). The main differences are that we use a smaller channel model, reduce `--k2`, set `--channel-scoring-type src_vocab --top-k-vocab 500` and increase the `--batch-size`.
Note:
- Download and install fairseq as per instructions [here](https://github.com/pytorch/fairseq)
- Preprocess and binarize the dataset as per instructions in section [Test Data Preprocessing](#test-data-preprocessing)
```sh
binarized_data=data_dir/binarized
direct_model=de_en_seed4.pt
lm_model=en_lm.pt
lm_data=lm_data
small_ch_model=en_de.base_1_1.seed4.pt
wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/direct_models/seed4.pt -O ${direct_model}
wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/transformer_lm.pt -O ${lm_model}
mkdir -p ${lm_data}
wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/lm_dict/dict.txt -O ${lm_data}/dict.txt
wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/base_1_1.seed4.pt -O ${small_ch_model}
k2=3
lenpen=0.23
lm_wt=0.58
bw_wt=0.26
fairseq-generate ${binarized_data} \
--user-dir examples/fast_noisy_channel \
--beam 5 \
--path ${direct_model} \
--lm-model ${lm_model} \
--lm-data ${lm_data} \
--channel-model ${small_ch_model} \
--k2 ${k2} \
--combine-method noisy_channel \
--task noisy_channel_translation \
--lenpen ${lenpen} \
--lm-wt ${lm_wt} \
--ch-wt ${bw_wt} \
--gen-subset test \
--remove-bpe \
--fp16 \
--batch-size 50 \
--channel-scoring-type src_vocab --top-k-vocab 500
```
## Test Data Preprocessing
For preprocessing and binarizing the test sets for Romanian-English and German-English translation, we use the following script -
```sh
FAIRSEQ=/path/to/fairseq
cd $FAIRSEQ
SCRIPTS=$FAIRSEQ/mosesdecoder/scripts
if [ ! -d "${SCRIPTS}" ]; then
echo 'Cloning Moses github repository (for tokenization scripts)...'
git clone https://github.com/moses-smt/mosesdecoder.git
fi
TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
NORMALIZE=$SCRIPTS/tokenizer/normalize-punctuation.perl
s=de
t=en
test=wmt18
mkdir -p data_dir
# Tokenization
if [ $s == "ro" ] ; then
# Note: Get normalise-romanian.py and remove-diacritics.py from
# https://github.com/rsennrich/wmt16-scripts/tree/master/preprocess
sacrebleu -t $test -l $s-$t --echo src | \
$NORMALIZE -l $s | \
python normalise-romanian.py | \
python remove-diacritics.py | \
$TOKENIZER -l $s -a -q > data_dir/$test.$s-$t.$s
else
sacrebleu -t $test -l $s-$t --echo src | perl $NORMALIZE -l $s | perl $TOKENIZER -threads 8 -a -l $s > data_dir/$test.$s-$t.$s
fi
sacrebleu -t $test -l $s-$t --echo ref | perl $NORMALIZE -l $t | perl $TOKENIZER -threads 8 -a -l $t > data_dir/$test.$s-$t.$t
# Applying BPE
src_bpe_code=/path/to/source/language/bpe/code
tgt_bpe_code=/path/to/target/language/bpe/code
src_dict=/path/to/source/language/dict
tgt_dict=/path/to/target/language/dict
FASTBPE=$FAIRSEQ/fastBPE
if [ ! -d "${FASTBPE}" ] ; then
git clone https://github.com/glample/fastBPE.git
# Follow compilation instructions at https://github.com/glample/fastBPE
g++ -std=c++11 -pthread -O3 fastBPE/main.cc -IfastBPE -o fast
fi
${FASTBPE}/fast applybpe data_dir/bpe.$test.$s-$t.$s data_dir/$test.$s-$t.$s ${src_bpe_code}
${FASTBPE}/fast applybpe data_dir/bpe.$test.$s-$t.$s data_dir/$test.$s-$t.$s ${tgt_bpe_code}
fairseq-preprocess -s $s -t $t \
--testpref data_dir/bpe.$test.$s-$t \
--destdir data_dir/binarized \
--srcdict ${src_dict} \
--tgtdict ${tgt_dict}
```
## Calculating BLEU
```sh
DETOKENIZER=$SCRIPTS/tokenizer/detokenizer.perl
cat ${generation_output} | grep -P "^H" | sort -V | cut -f 3- | $DETOKENIZER -l $t -q -a | sacrebleu -t $test -l $s-$t
```
## Romanian-English Translation
The direct and channel models are trained using bitext data (WMT16) combined with backtranslated data (The monolingual data used for backtranslation comes from http://data.statmt.org/rsennrich/wmt16_backtranslations/ (Sennrich et al., 2016c))
The backtranslated data is generated using an ensemble of 3 English-Romanian models trained on bitext training data (WMT16) with unrestricted sampling.
### BPE Codes and Dictionary
We learn a joint BPE vocabulary of 18K types on the bitext training data which is used for both the source and target.
||Path|
|----------|------|
| BPE Code | [joint_bpe_18k](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/bpe_18k) |
| Dictionary | [dict](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/dict) |
### Direct Models
For Ro-En with backtranslation, the direct and channel models use a Transformer-Big architecture.
| Seed | Model |
|----|----|
| 2 | [ro_en_seed2.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/direct_models/seed2.pt)
| 4 | [ro_en_seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/direct_models/seed4.pt)
| 6 | [ro_en_seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/direct_models/seed6.pt)
### Channel Models
For channel models, we follow the same steps as for the direct models. But backtranslated data is generated in the opposite direction using [this Romanian monolingual data](http://data.statmt.org/rsennrich/wmt16_backtranslations/).
The best lenpen, LM weight and CH weight are obtained by sweeping over the validation set (wmt16/dev) using beam 5.
| Model Size | Lenpen | LM Weight | CH Weight | Seed 2 | Seed 4 | Seed 6 |
|----|----|----|----|----|----|----|
| `big` | 0.84 | 0.64 | 0.56 | [big.seed2.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/channel_models/big.seed2.pt) | [big.seed2.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/channel_models/big.seed2.pt) | [big.seed2.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/channel_models/big.seed2.pt) |
| `base_1_1` | 0.63 | 0.40 | 0.37 | [base_1_1.seed2.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/channel_models/base_1_1.seed2.pt) | [base_1_1.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/channel_models/base_1_1.seed4.pt) | [base_1_1.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/channel_models/base_1_1.seed6.pt) |
### Language Model
The model is trained on de-duplicated English Newscrawl data from 2007-2018 comprising 186 million sentences or 4.5B words after normalization and tokenization.
| | Path |
|----|----|
| `--lm-model` | [transformer_en_lm](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/lm_model/transformer_lm.pt) |
| `--lm-data` | [lm_data](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/lm_model/lm_dict)
## German-English Translation
### BPE Codes and Dictionaries
| | Path|
|----------|------|
| Source BPE Code | [de_bpe_code_24K](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/de_bpe_code_24K) |
| Target BPE Code | [en_bpe_code_24K](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/en_bpe_code_24K)
| Source Dictionary | [de_dict](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/de_dict) |
| Target Dictionary | [en_dict](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/en_dict) |
### Direct Models
We train on WMT’19 training data. Following [Ng et al., 2019](http://statmt.org/wmt19/pdf/53/WMT33.pdf), we apply language identification filtering and remove sentences longer than 250 tokens as well as sentence pairs with a source/target length ratio exceeding 1.5. This results in 26.8M sentence pairs.
We use the Transformer-Big architecture for the direct model.
| Seed | Model |
|:----:|----|
| 4 | [de_en_seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/direct_models/seed4.pt)
| 5 | [de_en_seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/direct_models/seed5.pt)
| 6 | [de_en_seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/direct_models/seed6.pt)
### Channel Models
We train on WMT’19 training data. Following [Ng et al., 2019](http://statmt.org/wmt19/pdf/53/WMT33.pdf), we apply language identification filtering and remove sentences longer than 250 tokens as well as sentence pairs with a source/target length ratio exceeding 1.5. This results in 26.8M sentence pairs.
| Model Size | Seed 4 | Seed 5 | Seed 6 |
|----|----|----|----|
| `big` | [big.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/big.seed4.pt) | [big.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/big.seed5.pt) | [big.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/big.seed6.pt) |
| `big_1_1` | [big_1_1.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/big_1_1.seed4.pt) | [big_1_1.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/big_1_1.seed5.pt) | [big_1_1.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/big_1_1.seed6.pt) |
| `base` | [base.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/base.seed4.pt) | [base.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/base.seed5.pt) | [base.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/base.seed6.pt) |
| `base_1_1` | [base_1_1.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/base_1_1.seed4.pt) | [base_1_1.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/base_1_1.seed5.pt) | [base_1_1.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/base_1_1.seed6.pt) |
| `half` | [half.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/half.seed4.pt) | [half.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/half.seed5.pt) | [half.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/half.seed6.pt) |
| `half_1_1` | [half_1_1.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/half_1_1.seed4.pt) | [half_1_1.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/half_1_1.seed5.pt) | [half_1_1.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/half_1_1.seed6.pt) |
| `quarter` | [quarter.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/quarter.seed4.pt) | [quarter.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/quarter.seed5.pt) | [quarter.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/quarter.seed6.pt) |
| `quarter_1_1` | [quarter_1_1.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/quarter_1_1.seed4.pt) | [quarter_1_1.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/quarter_1_1.seed5.pt) | [quarter_1_1.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/quarter_1_1.seed6.pt) |
| `8th` | [8th.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/8th.seed4.pt) | [8th.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/8th.seed5.pt) | [8th.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/8th.seed6.pt) |
| `8th_1_1` | [8th_1_1.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/8th_1_1.seed4.pt) | [8th_1_1.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/8th_1_1.seed5.pt) | [8th_1_1.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/8th_1_1.seed6.pt) |
| `16th` | [16th.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/16th.seed4.pt) | [16th.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/16th.seed5.pt) | [16th.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/16th.seed6.pt) |
| `16th_1_1` | [16th_1_1.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/16th_1_1.seed4.pt) | [16th_1_1.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/16th_1_1.seed5.pt) | [16th_1_1.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/16th_1_1.seed6.pt) |
### Language Model
The model is trained on de-duplicated English Newscrawl data from 2007-2018 comprising 186 million sentences or 4.5B words after normalization and tokenization.
| | Path |
|----|----|
| `--lm-model` | [transformer_en_lm](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/transformer_lm.pt) |
| `--lm-data` | [lm_data](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/lm_dict/)
## Citation
```bibtex
@inproceedings{bhosale2020language,
title={Language Models not just for Pre-training: Fast Online Neural Noisy Channel Modeling},
author={Shruti Bhosale and Kyra Yee and Sergey Edunov and Michael Auli},
booktitle={Proceedings of the Fifth Conference on Machine Translation (WMT)},
year={2020},
}
@inproceedings{yee2019simple,
title={Simple and Effective Noisy Channel Modeling for Neural Machine Translation},
author={Yee, Kyra and Dauphin, Yann and Auli, Michael},
booktitle={Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP)},
pages={5700--5705},
year={2019}
}
```
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from . import noisy_channel_translation # noqa
from . import noisy_channel_sequence_generator # noqa
from . import noisy_channel_beam_search # noqa
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from fairseq.search import Search
class NoisyChannelBeamSearch(Search):
def __init__(self, tgt_dict):
super().__init__(tgt_dict)
self.fw_scores_buf = None
self.lm_scores_buf = None
def _init_buffers(self, t):
# super()._init_buffers(t)
if self.fw_scores_buf is None:
self.scores_buf = t.new()
self.indices_buf = torch.LongTensor().to(device=t.device)
self.beams_buf = torch.LongTensor().to(device=t.device)
self.fw_scores_buf = t.new()
self.lm_scores_buf = t.new()
def combine_fw_bw(self, combine_method, fw_cum, bw, step):
if combine_method == "noisy_channel":
fw_norm = fw_cum.div(step + 1)
lprobs = bw + fw_norm
elif combine_method == "lm_only":
lprobs = bw + fw_cum
return lprobs
def step(self, step, fw_lprobs, scores, bw_lprobs, lm_lprobs, combine_method):
self._init_buffers(fw_lprobs)
bsz, beam_size, vocab_size = fw_lprobs.size()
if step == 0:
# at the first step all hypotheses are equally likely, so use
# only the first beam
fw_lprobs = fw_lprobs[:, ::beam_size, :].contiguous()
bw_lprobs = bw_lprobs[:, ::beam_size, :].contiguous()
# nothing to add since we are at the first step
fw_lprobs_cum = fw_lprobs
else:
# make probs contain cumulative scores for each hypothesis
raw_scores = (scores[:, :, step - 1].unsqueeze(-1))
fw_lprobs_cum = (fw_lprobs.add(raw_scores))
combined_lprobs = self.combine_fw_bw(combine_method, fw_lprobs_cum, bw_lprobs, step)
# choose the top k according to the combined noisy channel model score
torch.topk(
combined_lprobs.view(bsz, -1),
k=min(
# Take the best 2 x beam_size predictions. We'll choose the first
# beam_size of these which don't predict eos to continue with.
beam_size * 2,
combined_lprobs.view(bsz, -1).size(1) - 1, # -1 so we never select pad
),
out=(self.scores_buf, self.indices_buf),
)
# save corresponding fw and lm scores
self.fw_scores_buf = torch.gather(fw_lprobs_cum.view(bsz, -1), 1, self.indices_buf)
self.lm_scores_buf = torch.gather(lm_lprobs.view(bsz, -1), 1, self.indices_buf)
# Project back into relative indices and beams
self.beams_buf = self.indices_buf // vocab_size
self.indices_buf.fmod_(vocab_size)
return self.scores_buf, self.fw_scores_buf, self.lm_scores_buf, self.indices_buf, self.beams_buf
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, List, Optional
import math
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from .noisy_channel_beam_search import NoisyChannelBeamSearch
from fairseq.sequence_generator import EnsembleModel
class NoisyChannelSequenceGenerator(object):
def __init__(
self,
combine_method,
tgt_dict,
src_dict=None,
beam_size=1,
max_len_a=0,
max_len_b=200,
min_len=1,
len_penalty=1.0,
unk_penalty=0.0,
retain_dropout=False,
temperature=1.0,
match_source_len=False,
no_repeat_ngram_size=0,
normalize_scores=True,
channel_models=None,
k2=10,
ch_weight=1.0,
channel_scoring_type='log_norm',
top_k_vocab=0,
lm_models=None,
lm_dict=None,
lm_weight=1.0,
normalize_lm_scores_by_tgt_len=False,
):
"""Generates translations of a given source sentence,
using beam search with noisy channel decoding.
Args:
combine_method (string, optional): Method to combine direct, LM and
channel model scores (default: None)
tgt_dict (~fairseq.data.Dictionary): target dictionary
src_dict (~fairseq.data.Dictionary): source dictionary
beam_size (int, optional): beam width (default: 1)
max_len_a/b (int, optional): generate sequences of maximum length
ax + b, where x is the source length
min_len (int, optional): the minimum length of the generated output
(not including end-of-sentence)
len_penalty (float, optional): length penalty, where <1.0 favors
shorter, >1.0 favors longer sentences (default: 1.0)
unk_penalty (float, optional): unknown word penalty, where <0
produces more unks, >0 produces fewer (default: 0.0)
retain_dropout (bool, optional): use dropout when generating
(default: False)
temperature (float, optional): temperature, where values
>1.0 produce more uniform samples and values <1.0 produce
sharper samples (default: 1.0)
match_source_len (bool, optional): outputs should match the source
length (default: False)
no_repeat_ngram_size (int, optional): Size of n-grams that we avoid
repeating in the generation (default: 0)
normalize_scores (bool, optional): normalize scores by the length
of the output (default: True)
channel_models (List[~fairseq.models.FairseqModel]): ensemble of models
translating from the target to the source
k2 (int, optional): Top K2 candidates to score per beam at each step (default:10)
ch_weight (int, optional): Weight associated with the channel model score
assuming that the direct model score has weight 1.0 (default: 1.0)
channel_scoring_type (str, optional): String specifying how to score
the channel model (default: 'log_norm')
top_k_vocab (int, optional): If `channel_scoring_type` is `'src_vocab'` or
`'src_vocab_batched'`, then this parameter specifies the number of
most frequent tokens to include in the channel model output vocabulary,
in addition to the source tokens in the input batch (default: 0)
lm_models (List[~fairseq.models.FairseqModel]): ensemble of models
generating text in the target language
lm_dict (~fairseq.data.Dictionary): LM Model dictionary
lm_weight (int, optional): Weight associated with the LM model score
assuming that the direct model score has weight 1.0 (default: 1.0)
normalize_lm_scores_by_tgt_len (bool, optional): Should we normalize LM scores
by the target length? By default, we normalize the combination of
LM and channel model scores by the source length
"""
self.pad = tgt_dict.pad()
self.unk = tgt_dict.unk()
self.eos = tgt_dict.eos()
self.vocab_size = len(tgt_dict)
self.beam_size = beam_size
# the max beam size is the dictionary size - 1, since we never select pad
self.beam_size = min(beam_size, self.vocab_size - 1)
self.max_len_a = max_len_a
self.max_len_b = max_len_b
self.min_len = min_len
self.normalize_scores = normalize_scores
self.len_penalty = len_penalty
self.unk_penalty = unk_penalty
self.retain_dropout = retain_dropout
self.temperature = temperature
self.match_source_len = match_source_len
self.no_repeat_ngram_size = no_repeat_ngram_size
self.channel_models = channel_models
self.src_dict = src_dict
self.tgt_dict = tgt_dict
self.combine_method = combine_method
self.k2 = k2
self.ch_weight = ch_weight
self.channel_scoring_type = channel_scoring_type
self.top_k_vocab = top_k_vocab
self.lm_models = lm_models
self.lm_dict = lm_dict
self.lm_weight = lm_weight
self.log_softmax_fn = torch.nn.LogSoftmax(dim=1)
self.normalize_lm_scores_by_tgt_len = normalize_lm_scores_by_tgt_len
self.share_tgt_dict = (self.lm_dict == self.tgt_dict)
self.tgt_to_lm = make_dict2dict(tgt_dict, lm_dict)
self.ch_scoring_bsz = 3072
assert temperature > 0, '--temperature must be greater than 0'
self.search = NoisyChannelBeamSearch(tgt_dict)
@torch.no_grad()
def generate(
self,
models,
sample,
prefix_tokens=None,
bos_token=None,
**kwargs
):
"""Generate a batch of translations.
Args:
models (List[~fairseq.models.FairseqModel]): ensemble of models
sample (dict): batch
prefix_tokens (torch.LongTensor, optional): force decoder to begin
with these tokens
"""
model = EnsembleModel(models)
incremental_states = torch.jit.annotate(
List[Dict[str, Dict[str, Optional[Tensor]]]],
[
torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {})
for i in range(model.models_size)
],
)
if not self.retain_dropout:
model.eval()
# model.forward normally channels prev_output_tokens into the decoder
# separately, but SequenceGenerator directly calls model.encoder
encoder_input = {
k: v for k, v in sample['net_input'].items()
if k != 'prev_output_tokens'
}
src_tokens = encoder_input['src_tokens']
src_lengths_no_eos = (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1)
input_size = src_tokens.size()
# batch dimension goes first followed by source lengths
bsz = input_size[0]
src_len = input_size[1]
beam_size = self.beam_size
if self.match_source_len:
max_len = src_lengths_no_eos.max().item()
else:
max_len = min(
int(self.max_len_a * src_len + self.max_len_b),
# exclude the EOS marker
model.max_decoder_positions() - 1,
)
# compute the encoder output for each beam
encoder_outs = model.forward_encoder(encoder_input)
new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
new_order = new_order.to(src_tokens.device).long()
encoder_outs = model.reorder_encoder_out(encoder_outs, new_order)
src_lengths = encoder_input['src_lengths']
# initialize buffers
scores = src_tokens.new(bsz * beam_size, max_len + 1).float().fill_(0)
lm_prefix_scores = src_tokens.new(bsz * beam_size).float().fill_(0)
scores_buf = scores.clone()
tokens = src_tokens.new(bsz * beam_size, max_len + 2).long().fill_(self.pad)
tokens_buf = tokens.clone()
tokens[:, 0] = self.eos if bos_token is None else bos_token
# reorder source tokens so they may be used as a reference in generating P(S|T)
src_tokens = reorder_all_tokens(src_tokens, src_lengths, self.src_dict.eos_index)
src_tokens = src_tokens.repeat(1, beam_size).view(-1, src_len)
src_lengths = src_lengths.view(bsz, -1).repeat(1, beam_size).view(bsz*beam_size, -1)
attn, attn_buf = None, None
nonpad_idxs = None
# The cands_to_ignore indicates candidates that should be ignored.
# For example, suppose we're sampling and have already finalized 2/5
# samples. Then the cands_to_ignore would mark 2 positions as being ignored,
# so that we only finalize the remaining 3 samples.
cands_to_ignore = src_tokens.new_zeros(bsz, beam_size).eq(-1) # forward and backward-compatible False mask
# list of completed sentences
finalized = [[] for i in range(bsz)]
finished = [False for i in range(bsz)]
num_remaining_sent = bsz
# number of candidate hypos per step
cand_size = 2 * beam_size # 2 x beam size in case half are EOS
# offset arrays for converting between different indexing schemes
bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1).type_as(tokens)
cand_offsets = torch.arange(0, cand_size).type_as(tokens)
# helper function for allocating buffers on the fly
buffers = {}
def buffer(name, type_of=tokens): # noqa
if name not in buffers:
buffers[name] = type_of.new()
return buffers[name]
def is_finished(sent, step, unfin_idx):
"""
Check whether we've finished generation for a given sentence, by
comparing the worst score among finalized hypotheses to the best
possible score among unfinalized hypotheses.
"""
assert len(finalized[sent]) <= beam_size
if len(finalized[sent]) == beam_size:
return True
return False
def finalize_hypos(step, bbsz_idx, eos_scores, combined_noisy_channel_eos_scores):
"""
Finalize the given hypotheses at this step, while keeping the total
number of finalized hypotheses per sentence <= beam_size.
Note: the input must be in the desired finalization order, so that
hypotheses that appear earlier in the input are preferred to those
that appear later.
Args:
step: current time step
bbsz_idx: A vector of indices in the range [0, bsz*beam_size),
indicating which hypotheses to finalize
eos_scores: A vector of the same size as bbsz_idx containing
fw scores for each hypothesis
combined_noisy_channel_eos_scores: A vector of the same size as bbsz_idx containing
combined noisy channel scores for each hypothesis
"""
assert bbsz_idx.numel() == eos_scores.numel()
# clone relevant token and attention tensors
tokens_clone = tokens.index_select(0, bbsz_idx)
tokens_clone = tokens_clone[:, 1:step + 2] # skip the first index, which is EOS
assert not tokens_clone.eq(self.eos).any()
tokens_clone[:, step] = self.eos
attn_clone = attn.index_select(0, bbsz_idx)[:, :, 1:step+2] if attn is not None else None
# compute scores per token position
pos_scores = scores.index_select(0, bbsz_idx)[:, :step+1]
pos_scores[:, step] = eos_scores
# convert from cumulative to per-position scores
pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]
# normalize sentence-level scores
if self.normalize_scores:
combined_noisy_channel_eos_scores /= (step + 1) ** self.len_penalty
cum_unfin = []
prev = 0
for f in finished:
if f:
prev += 1
else:
cum_unfin.append(prev)
sents_seen = set()
for i, (idx, score) in enumerate(zip(bbsz_idx.tolist(), combined_noisy_channel_eos_scores.tolist())):
unfin_idx = idx // beam_size
sent = unfin_idx + cum_unfin[unfin_idx]
sents_seen.add((sent, unfin_idx))
if self.match_source_len and step > src_lengths_no_eos[unfin_idx]:
score = -math.inf
def get_hypo():
if attn_clone is not None:
# remove padding tokens from attn scores
hypo_attn = attn_clone[i][nonpad_idxs[sent]]
_, alignment = hypo_attn.max(dim=0)
else:
hypo_attn = None
alignment = None
return {
'tokens': tokens_clone[i],
'score': score,
'attention': hypo_attn, # src_len x tgt_len
'alignment': alignment,
'positional_scores': pos_scores[i],
}
if len(finalized[sent]) < beam_size:
finalized[sent].append(get_hypo())
newly_finished = []
for sent, unfin_idx in sents_seen:
# check termination conditions for this sentence
if not finished[sent] and is_finished(sent, step, unfin_idx):
finished[sent] = True
newly_finished.append(unfin_idx)
return newly_finished
def noisy_channel_rescoring(lprobs, beam_size, bsz, src_tokens, tokens, k):
"""Rescore the top k hypothesis from each beam using noisy channel modeling
Returns:
new_fw_lprobs: the direct model probabilities after pruning the top k
new_ch_lm_lprobs: the combined channel and language model probabilities
new_lm_lprobs: the language model probabilities after pruning the top k
"""
with torch.no_grad():
lprobs_size = lprobs.size()
if prefix_tokens is not None and step < prefix_tokens.size(1):
probs_slice = lprobs.view(bsz, -1, lprobs.size(-1))[:, 0, :]
cand_scores = torch.gather(
probs_slice, dim=1,
index=prefix_tokens[:, step].view(-1, 1).data
).expand(-1, beam_size).contiguous().view(bsz*beam_size, 1)
cand_indices = prefix_tokens[:, step].view(-1, 1).expand(bsz, beam_size).data.contiguous().view(bsz*beam_size, 1)
# need to calculate and save fw and lm probs for prefix tokens
fw_top_k = cand_scores
fw_top_k_idx = cand_indices
k = 1
else:
# take the top k best words for every sentence in batch*beam
fw_top_k, fw_top_k_idx = torch.topk(lprobs.view(beam_size*bsz, -1), k=k)
eos_idx = torch.nonzero(fw_top_k_idx.view(bsz*beam_size*k, -1) == self.eos)[:, 0]
ch_scores = fw_top_k.new_full((beam_size*bsz*k, ), 0)
src_size = torch.sum(src_tokens[:, :] != self.src_dict.pad_index, dim=1, keepdim=True, dtype=fw_top_k.dtype)
if self.combine_method != "lm_only":
temp_src_tokens_full = src_tokens[:, :].repeat(1, k).view(bsz*beam_size*k, -1)
not_padding = temp_src_tokens_full[:, 1:] != self.src_dict.pad_index
cur_tgt_size = step+2
# add eos to all candidate sentences except those that already end in eos
eos_tokens = tokens[:, 0].repeat(1, k).view(-1, 1)
eos_tokens[eos_idx] = self.tgt_dict.pad_index
if step == 0:
channel_input = torch.cat((fw_top_k_idx.view(-1, 1), eos_tokens), 1)
else:
# move eos from beginning to end of target sentence
channel_input = torch.cat((tokens[:, 1:step + 1].repeat(1, k).view(-1, step), fw_top_k_idx.view(-1, 1), eos_tokens), 1)
ch_input_lengths = torch.tensor(np.full(channel_input.size(0), cur_tgt_size))
ch_input_lengths[eos_idx] = cur_tgt_size-1
if self.channel_scoring_type == "unnormalized":
ch_encoder_output = channel_model.encoder(channel_input, src_lengths=ch_input_lengths)
ch_decoder_output, _ = channel_model.decoder(temp_src_tokens_full, encoder_out=ch_encoder_output, features_only=True)
del ch_encoder_output
ch_intermed_scores = channel_model.decoder.unnormalized_scores_given_target(ch_decoder_output, target_ids=temp_src_tokens_full[:, 1:])
ch_intermed_scores = ch_intermed_scores.float()
ch_intermed_scores *= not_padding.float()
ch_scores = torch.sum(ch_intermed_scores, dim=1)
elif self.channel_scoring_type == "k2_separate":
for k_idx in range(k):
k_eos_tokens = eos_tokens[k_idx::k, :]
if step == 0:
k_ch_input = torch.cat((fw_top_k_idx[:, k_idx:k_idx+1], k_eos_tokens), 1)
else:
# move eos from beginning to end of target sentence
k_ch_input = torch.cat((tokens[:, 1:step + 1], fw_top_k_idx[:, k_idx:k_idx+1], k_eos_tokens), 1)
k_ch_input_lengths = ch_input_lengths[k_idx::k]
k_ch_output = channel_model(k_ch_input, k_ch_input_lengths, src_tokens)
k_ch_lprobs = channel_model.get_normalized_probs(k_ch_output, log_probs=True)
k_ch_intermed_scores = torch.gather(k_ch_lprobs[:, :-1, :], 2, src_tokens[:, 1:].unsqueeze(2)).squeeze(2)
k_ch_intermed_scores *= not_padding.float()
ch_scores[k_idx::k] = torch.sum(k_ch_intermed_scores, dim=1)
elif self.channel_scoring_type == "src_vocab":
ch_encoder_output = channel_model.encoder(channel_input, src_lengths=ch_input_lengths)
ch_decoder_output, _ = channel_model.decoder(temp_src_tokens_full, encoder_out=ch_encoder_output, features_only=True)
del ch_encoder_output
ch_lprobs = normalized_scores_with_batch_vocab(
channel_model.decoder,
ch_decoder_output, src_tokens, k, bsz, beam_size,
self.src_dict.pad_index, top_k=self.top_k_vocab)
ch_scores = torch.sum(ch_lprobs, dim=1)
elif self.channel_scoring_type == "src_vocab_batched":
ch_bsz_size = temp_src_tokens_full.shape[0]
ch_lprobs_list = [None] * len(range(0, ch_bsz_size, self.ch_scoring_bsz))
for i, start_idx in enumerate(range(0, ch_bsz_size, self.ch_scoring_bsz)):
end_idx = min(start_idx + self.ch_scoring_bsz, ch_bsz_size)
temp_src_tokens_full_batch = temp_src_tokens_full[start_idx:end_idx, :]
channel_input_batch = channel_input[start_idx:end_idx, :]
ch_input_lengths_batch = ch_input_lengths[start_idx:end_idx]
ch_encoder_output_batch = channel_model.encoder(channel_input_batch, src_lengths=ch_input_lengths_batch)
ch_decoder_output_batch, _ = channel_model.decoder(temp_src_tokens_full_batch, encoder_out=ch_encoder_output_batch, features_only=True)
ch_lprobs_list[i] = normalized_scores_with_batch_vocab(
channel_model.decoder,
ch_decoder_output_batch, src_tokens, k, bsz, beam_size,
self.src_dict.pad_index, top_k=self.top_k_vocab,
start_idx=start_idx, end_idx=end_idx)
ch_lprobs = torch.cat(ch_lprobs_list, dim=0)
ch_scores = torch.sum(ch_lprobs, dim=1)
else:
ch_output = channel_model(channel_input, ch_input_lengths, temp_src_tokens_full)
ch_lprobs = channel_model.get_normalized_probs(ch_output, log_probs=True)
ch_intermed_scores = torch.gather(ch_lprobs[:, :-1, :], 2, temp_src_tokens_full[:, 1:].unsqueeze(2)).squeeze().view(bsz*beam_size*k, -1)
ch_intermed_scores *= not_padding.float()
ch_scores = torch.sum(ch_intermed_scores, dim=1)
else:
cur_tgt_size = 0
ch_scores = ch_scores.view(bsz*beam_size, k)
expanded_lm_prefix_scores = lm_prefix_scores.unsqueeze(1).expand(-1, k).flatten()
if self.share_tgt_dict:
lm_scores = get_lm_scores(lm, tokens[:, :step + 1].view(-1, step+1), lm_incremental_states, fw_top_k_idx.view(-1, 1), torch.tensor(np.full(tokens.size(0), step+1)), k)
else:
new_lm_input = dict2dict(tokens[:, :step + 1].view(-1, step+1), self.tgt_to_lm)
new_cands = dict2dict(fw_top_k_idx.view(-1, 1), self.tgt_to_lm)
lm_scores = get_lm_scores(lm, new_lm_input, lm_incremental_states, new_cands, torch.tensor(np.full(tokens.size(0), step+1)), k)
lm_scores.add_(expanded_lm_prefix_scores)
ch_lm_scores = combine_ch_lm(self.combine_method, ch_scores, lm_scores, src_size, cur_tgt_size)
# initialize all as min value
new_fw_lprobs = ch_scores.new(lprobs_size).fill_(-1e17).view(bsz*beam_size, -1)
new_ch_lm_lprobs = ch_scores.new(lprobs_size).fill_(-1e17).view(bsz*beam_size, -1)
new_lm_lprobs = ch_scores.new(lprobs_size).fill_(-1e17).view(bsz*beam_size, -1)
new_fw_lprobs[:, self.pad] = -math.inf
new_ch_lm_lprobs[:, self.pad] = -math.inf
new_lm_lprobs[:, self.pad] = -math.inf
new_fw_lprobs.scatter_(1, fw_top_k_idx, fw_top_k)
new_ch_lm_lprobs.scatter_(1, fw_top_k_idx, ch_lm_scores)
new_lm_lprobs.scatter_(1, fw_top_k_idx, lm_scores.view(-1, k))
return new_fw_lprobs, new_ch_lm_lprobs, new_lm_lprobs
def combine_ch_lm(combine_type, ch_scores, lm_scores1, src_size, tgt_size):
if self.channel_scoring_type == "unnormalized":
ch_scores = self.log_softmax_fn(
ch_scores.view(-1, self.beam_size * self.k2)
).view(ch_scores.shape)
ch_scores = ch_scores * self.ch_weight
lm_scores1 = lm_scores1 * self.lm_weight
if combine_type == "lm_only":
# log P(T|S) + log P(T)
ch_scores = lm_scores1.view(ch_scores.size())
elif combine_type == "noisy_channel":
# 1/t log P(T|S) + 1/s log P(S|T) + 1/t log P(T)
if self.normalize_lm_scores_by_tgt_len:
ch_scores.div_(src_size)
lm_scores_norm = lm_scores1.view(ch_scores.size()).div(tgt_size)
ch_scores.add_(lm_scores_norm)
# 1/t log P(T|S) + 1/s log P(S|T) + 1/s log P(T)
else:
ch_scores.add_(lm_scores1.view(ch_scores.size()))
ch_scores.div_(src_size)
return ch_scores
if self.channel_models is not None:
channel_model = self.channel_models[0] # assume only one channel_model model
else:
channel_model = None
lm = EnsembleModel(self.lm_models)
lm_incremental_states = torch.jit.annotate(
List[Dict[str, Dict[str, Optional[Tensor]]]],
[
torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {})
for i in range(lm.models_size)
],
)
reorder_state = None
batch_idxs = None
for step in range(max_len + 1): # one extra step for EOS marker
# reorder decoder internal states based on the prev choice of beams
if reorder_state is not None:
if batch_idxs is not None:
# update beam indices to take into account removed sentences
corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(batch_idxs)
reorder_state.view(-1, beam_size).add_(corr.unsqueeze(-1) * beam_size)
model.reorder_incremental_state(incremental_states, reorder_state)
encoder_outs = model.reorder_encoder_out(encoder_outs, reorder_state)
lm.reorder_incremental_state(lm_incremental_states, reorder_state)
fw_lprobs, avg_attn_scores = model.forward_decoder(
tokens[:, :step + 1], encoder_outs, incremental_states, temperature=self.temperature,
)
fw_lprobs[:, self.pad] = -math.inf # never select pad
fw_lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty
fw_lprobs, ch_lm_lprobs, lm_lprobs = noisy_channel_rescoring(fw_lprobs, beam_size, bsz, src_tokens, tokens, self.k2)
# handle min and max length constraints
if step >= max_len:
fw_lprobs[:, :self.eos] = -math.inf
fw_lprobs[:, self.eos + 1:] = -math.inf
elif step < self.min_len:
fw_lprobs[:, self.eos] = -math.inf
# handle prefix tokens (possibly with different lengths)
if prefix_tokens is not None and step < prefix_tokens.size(1):
prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat(1, beam_size).view(-1)
prefix_mask = prefix_toks.ne(self.pad)
prefix_fw_lprobs = fw_lprobs.gather(-1, prefix_toks.unsqueeze(-1))
fw_lprobs[prefix_mask] = -math.inf
fw_lprobs[prefix_mask] = fw_lprobs[prefix_mask].scatter_(
-1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_fw_lprobs
)
prefix_ch_lm_lprobs = ch_lm_lprobs.gather(-1, prefix_toks.unsqueeze(-1))
ch_lm_lprobs[prefix_mask] = -math.inf
ch_lm_lprobs[prefix_mask] = ch_lm_lprobs[prefix_mask].scatter_(
-1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_ch_lm_lprobs
)
prefix_lm_lprobs = lm_lprobs.gather(-1, prefix_toks.unsqueeze(-1))
lm_lprobs[prefix_mask] = -math.inf
lm_lprobs[prefix_mask] = lm_lprobs[prefix_mask].scatter_(
-1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lm_lprobs
)
# if prefix includes eos, then we should make sure tokens and
# scores are the same across all beams
eos_mask = prefix_toks.eq(self.eos)
if eos_mask.any():
# validate that the first beam matches the prefix
first_beam = tokens[eos_mask].view(-1, beam_size, tokens.size(-1))[:, 0, 1:step + 1]
eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0]
target_prefix = prefix_tokens[eos_mask_batch_dim][:, :step]
assert (first_beam == target_prefix).all()
def replicate_first_beam(tensor, mask):
tensor = tensor.view(-1, beam_size, tensor.size(-1))
tensor[mask] = tensor[mask][:, :1, :]
return tensor.view(-1, tensor.size(-1))
# copy tokens, scores and lprobs from the first beam to all beams
tokens = replicate_first_beam(tokens, eos_mask_batch_dim)
scores = replicate_first_beam(scores, eos_mask_batch_dim)
fw_lprobs = replicate_first_beam(fw_lprobs, eos_mask_batch_dim)
ch_lm_lprobs = replicate_first_beam(ch_lm_lprobs, eos_mask_batch_dim)
lm_lprobs = replicate_first_beam(lm_lprobs, eos_mask_batch_dim)
if self.no_repeat_ngram_size > 0:
# for each beam and batch sentence, generate a list of previous ngrams
gen_ngrams = [{} for bbsz_idx in range(bsz * beam_size)]
for bbsz_idx in range(bsz * beam_size):
gen_tokens = tokens[bbsz_idx].tolist()
for ngram in zip(*[gen_tokens[i:] for i in range(self.no_repeat_ngram_size)]):
gen_ngrams[bbsz_idx][tuple(ngram[:-1])] = \
gen_ngrams[bbsz_idx].get(tuple(ngram[:-1]), []) + [ngram[-1]]
# Record attention scores
if avg_attn_scores is not None:
if attn is None:
attn = scores.new(bsz * beam_size, src_tokens.size(1), max_len + 2)
attn_buf = attn.clone()
nonpad_idxs = src_tokens.ne(self.pad)
attn[:, :, step + 1].copy_(avg_attn_scores)
scores = scores.type_as(fw_lprobs)
scores_buf = scores_buf.type_as(fw_lprobs)
self.search.set_src_lengths(src_lengths_no_eos)
if self.no_repeat_ngram_size > 0:
def calculate_banned_tokens(bbsz_idx):
# before decoding the next token, prevent decoding of ngrams that have already appeared
ngram_index = tuple(tokens[bbsz_idx, step + 2 - self.no_repeat_ngram_size:step + 1].tolist())
return gen_ngrams[bbsz_idx].get(ngram_index, [])
if step + 2 - self.no_repeat_ngram_size >= 0:
# no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
banned_tokens = [calculate_banned_tokens(bbsz_idx) for bbsz_idx in range(bsz * beam_size)]
else:
banned_tokens = [[] for bbsz_idx in range(bsz * beam_size)]
for bbsz_idx in range(bsz * beam_size):
fw_lprobs[bbsz_idx, banned_tokens[bbsz_idx]] = -math.inf
combined_noisy_channel_scores, fw_lprobs_top_k, lm_lprobs_top_k, cand_indices, cand_beams = self.search.step(
step,
fw_lprobs.view(bsz, -1, self.vocab_size),
scores.view(bsz, beam_size, -1)[:, :, :step], ch_lm_lprobs.view(bsz, -1, self.vocab_size),
lm_lprobs.view(bsz, -1, self.vocab_size), self.combine_method
)
# cand_bbsz_idx contains beam indices for the top candidate
# hypotheses, with a range of values: [0, bsz*beam_size),
# and dimensions: [bsz, cand_size]
cand_bbsz_idx = cand_beams.add(bbsz_offsets)
# finalize hypotheses that end in eos (except for candidates to be ignored)
eos_mask = cand_indices.eq(self.eos)
eos_mask[:, :beam_size] &= ~cands_to_ignore
# only consider eos when it's among the top beam_size indices
eos_bbsz_idx = torch.masked_select(
cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size]
)
finalized_sents = set()
if eos_bbsz_idx.numel() > 0:
eos_scores = torch.masked_select(
fw_lprobs_top_k[:, :beam_size], mask=eos_mask[:, :beam_size]
)
combined_noisy_channel_eos_scores = torch.masked_select(
combined_noisy_channel_scores[:, :beam_size],
mask=eos_mask[:, :beam_size],
)
# finalize hypo using channel model score
finalized_sents = finalize_hypos(
step, eos_bbsz_idx, eos_scores, combined_noisy_channel_eos_scores)
num_remaining_sent -= len(finalized_sents)
assert num_remaining_sent >= 0
if num_remaining_sent == 0:
break
if len(finalized_sents) > 0:
new_bsz = bsz - len(finalized_sents)
# construct batch_idxs which holds indices of batches to keep for the next pass
batch_mask = cand_indices.new_ones(bsz)
batch_mask[cand_indices.new(finalized_sents)] = 0
batch_idxs = torch.nonzero(batch_mask).squeeze(-1)
eos_mask = eos_mask[batch_idxs]
cand_beams = cand_beams[batch_idxs]
bbsz_offsets.resize_(new_bsz, 1)
cand_bbsz_idx = cand_beams.add(bbsz_offsets)
lm_lprobs_top_k = lm_lprobs_top_k[batch_idxs]
fw_lprobs_top_k = fw_lprobs_top_k[batch_idxs]
cand_indices = cand_indices[batch_idxs]
if prefix_tokens is not None:
prefix_tokens = prefix_tokens[batch_idxs]
src_lengths_no_eos = src_lengths_no_eos[batch_idxs]
cands_to_ignore = cands_to_ignore[batch_idxs]
scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
scores_buf.resize_as_(scores)
tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
tokens_buf.resize_as_(tokens)
src_tokens = src_tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
src_lengths = src_lengths.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
lm_prefix_scores = lm_prefix_scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1).squeeze()
if attn is not None:
attn = attn.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, attn.size(1), -1)
attn_buf.resize_as_(attn)
bsz = new_bsz
else:
batch_idxs = None
# Set active_mask so that values > cand_size indicate eos or
# ignored hypos and values < cand_size indicate candidate
# active hypos. After this, the min values per row are the top
# candidate active hypos.
eos_mask[:, :beam_size] |= cands_to_ignore
active_mask = torch.add(
eos_mask.type_as(cand_offsets) * cand_size,
cand_offsets[: eos_mask.size(1)],
)
# get the top beam_size active hypotheses, which are just the hypos
# with the smallest values in active_mask
active_hypos, new_cands_to_ignore = buffer('active_hypos'), buffer('new_cands_to_ignore')
torch.topk(
active_mask, k=beam_size, dim=1, largest=False,
out=(new_cands_to_ignore, active_hypos)
)
# update cands_to_ignore to ignore any finalized hypos
cands_to_ignore = new_cands_to_ignore.ge(cand_size)[:, :beam_size]
assert (~cands_to_ignore).any(dim=1).all()
active_bbsz_idx = buffer('active_bbsz_idx')
torch.gather(
cand_bbsz_idx, dim=1, index=active_hypos,
out=active_bbsz_idx,
)
active_scores = torch.gather(
fw_lprobs_top_k, dim=1, index=active_hypos,
out=scores[:, step].view(bsz, beam_size),
)
active_bbsz_idx = active_bbsz_idx.view(-1)
active_scores = active_scores.view(-1)
# copy tokens and scores for active hypotheses
torch.index_select(
tokens[:, :step + 1], dim=0, index=active_bbsz_idx,
out=tokens_buf[:, :step + 1],
)
torch.gather(
cand_indices, dim=1, index=active_hypos,
out=tokens_buf.view(bsz, beam_size, -1)[:, :, step + 1],
)
if step > 0:
torch.index_select(
scores[:, :step], dim=0, index=active_bbsz_idx,
out=scores_buf[:, :step],
)
torch.gather(
fw_lprobs_top_k, dim=1, index=active_hypos,
out=scores_buf.view(bsz, beam_size, -1)[:, :, step],
)
torch.gather(
lm_lprobs_top_k, dim=1, index=active_hypos,
out=lm_prefix_scores.view(bsz, beam_size)
)
# copy attention for active hypotheses
if attn is not None:
torch.index_select(
attn[:, :, :step + 2], dim=0, index=active_bbsz_idx,
out=attn_buf[:, :, :step + 2],
)
# swap buffers
tokens, tokens_buf = tokens_buf, tokens
scores, scores_buf = scores_buf, scores
if attn is not None:
attn, attn_buf = attn_buf, attn
# reorder incremental state in decoder
reorder_state = active_bbsz_idx
# sort by score descending
for sent in range(len(finalized)):
finalized[sent] = sorted(finalized[sent], key=lambda r: r['score'], reverse=True)
return finalized
def get_lm_scores(model, input_tokens, incremental_states, cand_tokens, input_len, k):
with torch.no_grad():
lm_lprobs, avg_attn_scores = model.forward_decoder(
input_tokens, encoder_outs=None, incremental_states=incremental_states,
)
lm_lprobs_size = lm_lprobs.size(0)
probs_next_wrd = torch.gather(lm_lprobs.repeat(1, k).view(lm_lprobs_size*k, -1), 1, cand_tokens).squeeze().view(-1)
return probs_next_wrd
def make_dict2dict(old_dict, new_dict):
dict2dict_map = {}
for sym in old_dict.symbols:
dict2dict_map[old_dict.index(sym)] = new_dict.index(sym)
return dict2dict_map
def dict2dict(tokens, dict2dict_map):
if tokens.device == torch.device('cpu'):
tokens_tmp = tokens
else:
tokens_tmp = tokens.cpu()
return tokens_tmp.map_(
tokens_tmp,
lambda _, val, dict2dict_map=dict2dict_map : dict2dict_map[float(val)]
).to(tokens.device)
def reorder_tokens(tokens, lengths, eos):
# reorder source tokens so they may be used as reference for P(S|T)
return torch.cat((tokens.new([eos]), tokens[-lengths:-1], tokens[:-lengths]), 0)
def reorder_all_tokens(tokens, lengths, eos):
# used to reorder src tokens from [<pad> <w1> <w2> .. <eos>] to [<eos> <w1> <w2>...<pad>]
# so source tokens can be used to predict P(S|T)
return torch.stack([reorder_tokens(token, length, eos) for token, length in zip(tokens, lengths)])
def normalized_scores_with_batch_vocab(
model_decoder, features, target_ids, k, bsz, beam_size,
pad_idx, top_k=0, vocab_size_meter=None, start_idx=None,
end_idx=None, **kwargs):
"""
Get normalized probabilities (or log probs) from a net's output
w.r.t. vocab consisting of target IDs in the batch
"""
if model_decoder.adaptive_softmax is None:
weight = model_decoder.output_projection.weight
vocab_ids = torch.unique(
torch.cat(
(torch.unique(target_ids), torch.arange(top_k, device=target_ids.device))
)
)
id_map = dict(zip(vocab_ids.tolist(), range(len(vocab_ids))))
mapped_target_ids = target_ids.cpu().apply_(
lambda x, id_map=id_map: id_map[x]
).to(target_ids.device)
expanded_target_ids = mapped_target_ids[:, :].repeat(1, k).view(bsz*beam_size*k, -1)
if start_idx is not None and end_idx is not None:
expanded_target_ids = expanded_target_ids[start_idx:end_idx, :]
logits = F.linear(features, weight[vocab_ids, :])
log_softmax = F.log_softmax(logits, dim=-1, dtype=torch.float32)
intermed_scores = torch.gather(
log_softmax[:, :-1, :],
2,
expanded_target_ids[:, 1:].unsqueeze(2),
).squeeze()
not_padding = expanded_target_ids[:, 1:] != pad_idx
intermed_scores *= not_padding.float()
return intermed_scores
else:
raise ValueError("adaptive softmax doesn't work with " +
"`normalized_scores_with_batch_vocab()`")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment