Commit 7df61696 authored by Sugon_ldc

add fairseq0.10.2

#!/bin/bash
path_2_data=$1 # <path to data> which contains binarized data for each direction
lang_list=$2 # <path to a file which contains a list of languages separated by new lines>
lang_pairs=$3 # a list of language pairs to train multilingual models, e.g. "en-fr,en-cs,fr-en,cs-en"
fairseq-train "$path_2_data" \
--encoder-normalize-before --decoder-normalize-before \
--arch transformer --layernorm-embedding \
--task translation_multi_simple_epoch \
--sampling-method "temperature" \
--sampling-temperature 1.5 \
--encoder-langtok "src" \
--decoder-langtok \
--lang-dict "$lang_list" \
--lang-pairs "$lang_pairs" \
--criterion label_smoothed_cross_entropy --label-smoothing 0.2 \
--optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \
--lr-scheduler inverse_sqrt --lr 3e-05 --min-lr -1 --warmup-updates 2500 --max-update 40000 \
--dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \
--max-tokens 1024 --update-freq 2 \
--save-interval 1 --save-interval-updates 5000 --keep-interval-updates 10 --no-epoch-checkpoints \
--seed 222 --log-format simple --log-interval 2
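The training script above selects `--sampling-method "temperature"` with `--sampling-temperature 1.5`. As a minimal sketch of what temperature-based sampling over language pairs usually means, the snippet below raises each pair's share of the data to the power `1/T` and renormalizes. The function name and the dataset sizes are hypothetical, and the exact formula used inside fairseq is an assumption here, not something shown in this commit.

```python
def temperature_sampling_probs(dataset_sizes, temperature):
    """Return per-dataset sampling probabilities smoothed by a temperature."""
    total = float(sum(dataset_sizes))
    # Raw probabilities are proportional to dataset size.
    raw = [size / total for size in dataset_sizes]
    # Raising to 1/T with T > 1 flattens the distribution toward uniform.
    smoothed = [p ** (1.0 / temperature) for p in raw]
    norm = sum(smoothed)
    return [p / norm for p in smoothed]


# Hypothetical sentence counts for three language pairs.
sizes = [1_000_000, 100_000, 10_000]
probs = temperature_sampling_probs(sizes, temperature=1.5)
for size, prob in zip(sizes, probs):
    print(f"{size:>9} sentences -> sampling probability {prob:.3f}")
```

With a temperature above 1, low-resource pairs are sampled more often than their raw share of the data would suggest, and a temperature of exactly 1 leaves the proportions unchanged.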
# Simple and Effective Noisy Channel Modeling for Neural Machine Translation (Yee et al., 2019)
This page contains pointers to pre-trained models as well as instructions on how to run the reranking scripts.
## Citation:
```bibtex
@inproceedings{yee2019simple,
title = {Simple and Effective Noisy Channel Modeling for Neural Machine Translation},
author = {Kyra Yee and Yann Dauphin and Michael Auli},
booktitle = {Conference on Empirical Methods in Natural Language Processing},
year = {2019},
}
```
## Pre-trained Models:
Model | Description | Download
---|---|---
`transformer.noisychannel.de-en` | De->En Forward Model | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/forward_de2en.tar.bz2)
`transformer.noisychannel.en-de` | En->De Channel Model | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/backward_en2de.tar.bz2)
`transformer_lm.noisychannel.en` | En Language model | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/reranking_en_lm.tar.bz2)
Test Data: [newstest_wmt17](https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/wmt17test.tar.bz2)
## Example usage
```bash
mkdir rerank_example
curl https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/forward_de2en.tar.bz2 | tar xvjf - -C rerank_example
curl https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/backward_en2de.tar.bz2 | tar xvjf - -C rerank_example
curl https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/reranking_en_lm.tar.bz2 | tar xvjf - -C rerank_example
curl https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/wmt17test.tar.bz2 | tar xvjf - -C rerank_example
beam=50
num_trials=1000
fw_name=fw_model_ex
bw_name=bw_model_ex
lm_name=lm_ex
data_dir=rerank_example/hyphen-splitting-mixed-case-wmt17test-wmt14bpe
data_dir_name=wmt17
lm=rerank_example/lm/checkpoint_best.pt
lm_bpe_code=rerank_example/lm/bpe32k.code
lm_dict=rerank_example/lm/dict.txt
batch_size=32
bw=rerank_example/backward_en2de.pt
fw=rerank_example/forward_de2en.pt
# reranking with P(T|S) P(S|T) and P(T)
python examples/noisychannel/rerank_tune.py $data_dir --tune-param lenpen weight1 weight3 \
--lower-bound 0 0 0 --upper-bound 3 3 3 --data-dir-name $data_dir_name \
--num-trials $num_trials --source-lang de --target-lang en --gen-model $fw \
-n $beam --batch-size $batch_size --score-model2 $fw --score-model1 $bw \
--backwards1 --weight2 1 \
-lm $lm --lm-dict $lm_dict --lm-name en_newscrawl --lm-bpe-code $lm_bpe_code \
--model2-name $fw_name --model1-name $bw_name --gen-model-name $fw_name
# reranking with P(T|S) and P(T)
python examples/noisychannel/rerank_tune.py $data_dir --tune-param lenpen weight3 \
--lower-bound 0 0 --upper-bound 3 3 --data-dir-name $data_dir_name \
--num-trials $num_trials --source-lang de --target-lang en --gen-model $fw \
-n $beam --batch-size $batch_size --score-model1 $fw \
-lm $lm --lm-dict $lm_dict --lm-name en_newscrawl --lm-bpe-code $lm_bpe_code \
--model1-name $fw_name --gen-model-name $fw_name
# to run with a preconfigured set of hyperparameters for the lenpen and model weights, use rerank.py instead
python examples/noisychannel/rerank.py $data_dir \
--lenpen 0.269 --weight1 1 --weight2 0.929 --weight3 0.831 \
--data-dir-name $data_dir_name --source-lang de --target-lang en --gen-model $fw \
-n $beam --batch-size $batch_size --score-model2 $fw --score-model1 $bw --backwards1 \
-lm $lm --lm-dict $lm_dict --lm-name en_newscrawl --lm-bpe-code $lm_bpe_code \
--model2-name $fw_name --model1-name $bw_name --gen-model-name $fw_name
```
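The comments in the commands above name three quantities: the forward model P(T|S), the channel model P(S|T), and the language model P(T). The sketch below shows one plausible way to combine their log-probabilities with `weight1`, `weight2`, `weight3`, and `lenpen`. It is an illustration under stated assumptions, not the code in `rerank_utils.get_score` (which is not shown here): the function name, the exact form of the length penalty, and the log-probability values are all hypothetical.

```python
def combined_score(fw_logprob, bw_logprob, lm_logprob,
                   weight1, weight2, weight3, target_len, lenpen):
    """Weighted sum of model log-probabilities with a length penalty.

    weight1 scales the channel model log P(S|T), weight2 the forward model
    log P(T|S), and weight3 the language model log P(T), following the
    flag-to-model pairing in the commands above (score-model1 is the
    backward model, score-model2 the forward model).
    """
    total = weight1 * bw_logprob + weight2 * fw_logprob + weight3 * lm_logprob
    # Assumed form of the length penalty: divide by target_len ** lenpen.
    return total / (target_len ** lenpen)


# Hypothetical log-probabilities for two candidate translations of one source.
hypotheses = {
    "A": {"fw": -4.0, "bw": -6.0, "lm": -10.0, "len": 5},
    "B": {"fw": -4.5, "bw": -3.0, "lm": -8.0, "len": 5},
}
scores = {
    name: combined_score(h["fw"], h["bw"], h["lm"], 1.0, 0.929, 0.831, h["len"], 0.269)
    for name, h in hypotheses.items()
}
best_name = max(scores, key=scores.get)
fw_only_best = max(hypotheses, key=lambda name: hypotheses[name]["fw"])
print("forward model alone prefers:", fw_only_best)
print("combined score prefers:", best_name)
```

In this toy case the forward model alone and the combined score pick different candidates, which is the situation reranking is meant to handle.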
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .rerank_options import * # noqa
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from multiprocessing import Pool
import numpy as np
from fairseq import options
from fairseq.data import dictionary
from fairseq.scoring import bleu
from . import (
    rerank_generate,
    rerank_options,
    rerank_score_bw,
    rerank_score_lm,
    rerank_utils,
)
def score_target_hypo(
    args, a, b, c, lenpen, target_outfile, hypo_outfile, write_hypos, normalize
):
    print("lenpen", lenpen, "weight1", a, "weight2", b, "weight3", c)
    gen_output_lst, bitext1_lst, bitext2_lst, lm_res_lst = load_score_files(args)
    dict = dictionary.Dictionary()
    scorer = bleu.Scorer(dict.pad(), dict.eos(), dict.unk())

    ordered_hypos = {}
    ordered_targets = {}

    for shard_id in range(len(bitext1_lst)):
        bitext1 = bitext1_lst[shard_id]
        bitext2 = bitext2_lst[shard_id]
        gen_output = gen_output_lst[shard_id]
        lm_res = lm_res_lst[shard_id]

        total = len(bitext1.rescore_source.keys())
        source_lst = []
        hypo_lst = []
        score_lst = []
        reference_lst = []
        j = 1
        best_score = -math.inf

        for i in range(total):
            # length is measured in terms of words, not bpe tokens, since models may not share the same bpe
            target_len = len(bitext1.rescore_hypo[i].split())

            if lm_res is not None:
                lm_score = lm_res.score[i]
            else:
                lm_score = 0

            if bitext2 is not None:
                bitext2_score = bitext2.rescore_score[i]
                bitext2_backwards = bitext2.backwards
            else:
                bitext2_score = None
                bitext2_backwards = None

            score = rerank_utils.get_score(
                a,
                b,
                c,
                target_len,
                bitext1.rescore_score[i],
                bitext2_score,
                lm_score=lm_score,
                lenpen=lenpen,
                src_len=bitext1.source_lengths[i],
                tgt_len=bitext1.target_lengths[i],
                bitext1_backwards=bitext1.backwards,
                bitext2_backwards=bitext2_backwards,
                normalize=normalize,
            )

            if score > best_score:
                best_score = score
                best_hypo = bitext1.rescore_hypo[i]

            # the last hypothesis of the current source sentence closes the group
            if j == gen_output.num_hypos[i] or j == args.num_rescore:
                j = 1
                hypo_lst.append(best_hypo)
                score_lst.append(best_score)
                source_lst.append(bitext1.rescore_source[i])
                reference_lst.append(bitext1.rescore_target[i])

                best_score = -math.inf
                best_hypo = ""
            else:
                j += 1

        gen_keys = list(sorted(gen_output.no_bpe_target.keys()))

        for key in range(len(gen_keys)):
            if args.prefix_len is None:
                assert hypo_lst[key] in gen_output.no_bpe_hypo[gen_keys[key]], (
                    "pred and rescore hypo mismatch: i: "
                    + str(key)
                    + ", "
                    + str(hypo_lst[key])
                    + str(gen_keys[key])
                    + str(gen_output.no_bpe_hypo[gen_keys[key]])
                )
                sys_tok = dict.encode_line(hypo_lst[key])
                ref_tok = dict.encode_line(gen_output.no_bpe_target[gen_keys[key]])
                scorer.add(ref_tok, sys_tok)
            else:
                full_hypo = rerank_utils.get_full_from_prefix(
                    hypo_lst[key], gen_output.no_bpe_hypo[gen_keys[key]]
                )
                sys_tok = dict.encode_line(full_hypo)
                ref_tok = dict.encode_line(gen_output.no_bpe_target[gen_keys[key]])
                scorer.add(ref_tok, sys_tok)

        # if only one set of hyper parameters is provided, write the predictions to a file
        if write_hypos:
            # recover the original ids from n best list generation
            for key in range(len(gen_output.no_bpe_target)):
                if args.prefix_len is None:
                    assert hypo_lst[key] in gen_output.no_bpe_hypo[gen_keys[key]], (
                        "pred and rescore hypo mismatch:"
                        + "i:"
                        + str(key)
                        + str(hypo_lst[key])
                        + str(gen_output.no_bpe_hypo[gen_keys[key]])
                    )
                    ordered_hypos[gen_keys[key]] = hypo_lst[key]
                    ordered_targets[gen_keys[key]] = gen_output.no_bpe_target[
                        gen_keys[key]
                    ]
                else:
                    full_hypo = rerank_utils.get_full_from_prefix(
                        hypo_lst[key], gen_output.no_bpe_hypo[gen_keys[key]]
                    )
                    ordered_hypos[gen_keys[key]] = full_hypo
                    ordered_targets[gen_keys[key]] = gen_output.no_bpe_target[
                        gen_keys[key]
                    ]

    # write the hypos in the original order from nbest list generation
    if args.num_shards == (len(bitext1_lst)):
        with open(target_outfile, "w") as t:
            with open(hypo_outfile, "w") as h:
                for key in range(len(ordered_hypos)):
                    t.write(ordered_targets[key])
                    h.write(ordered_hypos[key])

    res = scorer.result_string(4)
    if write_hypos:
        print(res)
    score = rerank_utils.parse_bleu_scoring(res)
    return score
def match_target_hypo(args, target_outfile, hypo_outfile):
    """combine scores from the LM and bitext models, and write the top scoring hypothesis to a file"""
    if len(args.weight1) == 1:
        res = score_target_hypo(
            args,
            args.weight1[0],
            args.weight2[0],
            args.weight3[0],
            args.lenpen[0],
            target_outfile,
            hypo_outfile,
            True,
            args.normalize,
        )
        rerank_scores = [res]
    else:
        print("launching pool")
        with Pool(32) as p:
            rerank_scores = p.starmap(
                score_target_hypo,
                [
                    (
                        args,
                        args.weight1[i],
                        args.weight2[i],
                        args.weight3[i],
                        args.lenpen[i],
                        target_outfile,
                        hypo_outfile,
                        False,
                        args.normalize,
                    )
                    for i in range(len(args.weight1))
                ],
            )

    if len(rerank_scores) > 1:
        best_index = np.argmax(rerank_scores)
        best_score = rerank_scores[best_index]
        print("best score", best_score)
        print("best lenpen", args.lenpen[best_index])
        print("best weight1", args.weight1[best_index])
        print("best weight2", args.weight2[best_index])
        print("best weight3", args.weight3[best_index])
        return (
            args.lenpen[best_index],
            args.weight1[best_index],
            args.weight2[best_index],
            args.weight3[best_index],
            best_score,
        )
    else:
        return (
            args.lenpen[0],
            args.weight1[0],
            args.weight2[0],
            args.weight3[0],
            rerank_scores[0],
        )
def load_score_files(args):
    if args.all_shards:
        shard_ids = list(range(args.num_shards))
    else:
        shard_ids = [args.shard_id]

    gen_output_lst = []
    bitext1_lst = []
    bitext2_lst = []
    lm_res1_lst = []

    for shard_id in shard_ids:
        using_nbest = args.nbest_list is not None
        (
            pre_gen,
            left_to_right_preprocessed_dir,
            right_to_left_preprocessed_dir,
            backwards_preprocessed_dir,
            lm_preprocessed_dir,
        ) = rerank_utils.get_directories(
            args.data_dir_name,
            args.num_rescore,
            args.gen_subset,
            args.gen_model_name,
            shard_id,
            args.num_shards,
            args.sampling,
            args.prefix_len,
            args.target_prefix_frac,
            args.source_prefix_frac,
        )

        rerank1_is_gen = (
            args.gen_model == args.score_model1 and args.source_prefix_frac is None
        )
        rerank2_is_gen = (
            args.gen_model == args.score_model2 and args.source_prefix_frac is None
        )

        score1_file = rerank_utils.rescore_file_name(
            pre_gen,
            args.prefix_len,
            args.model1_name,
            target_prefix_frac=args.target_prefix_frac,
            source_prefix_frac=args.source_prefix_frac,
            backwards=args.backwards1,
        )
        if args.score_model2 is not None:
            score2_file = rerank_utils.rescore_file_name(
                pre_gen,
                args.prefix_len,
                args.model2_name,
                target_prefix_frac=args.target_prefix_frac,
                source_prefix_frac=args.source_prefix_frac,
                backwards=args.backwards2,
            )
        if args.language_model is not None:
            lm_score_file = rerank_utils.rescore_file_name(
                pre_gen, args.prefix_len, args.lm_name, lm_file=True
            )

        # get gen output
        predictions_bpe_file = pre_gen + "/generate_output_bpe.txt"
        if using_nbest:
            print("Using predefined n-best list from interactive.py")
            predictions_bpe_file = args.nbest_list
        gen_output = rerank_utils.BitextOutputFromGen(
            predictions_bpe_file,
            bpe_symbol=args.remove_bpe,
            nbest=using_nbest,
            prefix_len=args.prefix_len,
            target_prefix_frac=args.target_prefix_frac,
        )

        if rerank1_is_gen:
            bitext1 = gen_output
        else:
            bitext1 = rerank_utils.BitextOutput(
                score1_file,
                args.backwards1,
                args.right_to_left1,
                args.remove_bpe,
                args.prefix_len,
                args.target_prefix_frac,
                args.source_prefix_frac,
            )

        if args.score_model2 is not None or args.nbest_list is not None:
            if rerank2_is_gen:
                bitext2 = gen_output
            else:
                bitext2 = rerank_utils.BitextOutput(
                    score2_file,
                    args.backwards2,
                    args.right_to_left2,
                    args.remove_bpe,
                    args.prefix_len,
                    args.target_prefix_frac,
                    args.source_prefix_frac,
                )

                assert (
                    bitext2.source_lengths == bitext1.source_lengths
                ), "source lengths for rescoring models do not match"
                assert (
                    bitext2.target_lengths == bitext1.target_lengths
                ), "target lengths for rescoring models do not match"
        else:
            if args.diff_bpe:
                assert args.score_model2 is None
                bitext2 = gen_output
            else:
                bitext2 = None

        if args.language_model is not None:
            lm_res1 = rerank_utils.LMOutput(
                lm_score_file,
                args.lm_dict,
                args.prefix_len,
                args.remove_bpe,
                args.target_prefix_frac,
            )
        else:
            lm_res1 = None

        gen_output_lst.append(gen_output)
        bitext1_lst.append(bitext1)
        bitext2_lst.append(bitext2)
        lm_res1_lst.append(lm_res1)
    return gen_output_lst, bitext1_lst, bitext2_lst, lm_res1_lst
def rerank(args):
    if type(args.lenpen) is not list:
        args.lenpen = [args.lenpen]
    if type(args.weight1) is not list:
        args.weight1 = [args.weight1]
    if type(args.weight2) is not list:
        args.weight2 = [args.weight2]
    if type(args.weight3) is not list:
        args.weight3 = [args.weight3]
    if args.all_shards:
        shard_ids = list(range(args.num_shards))
    else:
        shard_ids = [args.shard_id]

    for shard_id in shard_ids:
        (
            pre_gen,
            left_to_right_preprocessed_dir,
            right_to_left_preprocessed_dir,
            backwards_preprocessed_dir,
            lm_preprocessed_dir,
        ) = rerank_utils.get_directories(
            args.data_dir_name,
            args.num_rescore,
            args.gen_subset,
            args.gen_model_name,
            shard_id,
            args.num_shards,
            args.sampling,
            args.prefix_len,
            args.target_prefix_frac,
            args.source_prefix_frac,
        )
        rerank_generate.gen_and_reprocess_nbest(args)
        rerank_score_bw.score_bw(args)
        rerank_score_lm.score_lm(args)

        if args.write_hypos is None:
            write_targets = pre_gen + "/matched_targets"
            write_hypos = pre_gen + "/matched_hypos"
        else:
            write_targets = args.write_hypos + "_targets" + args.gen_subset
            write_hypos = args.write_hypos + "_hypos" + args.gen_subset

    if args.all_shards:
        write_targets += "_all_shards"
        write_hypos += "_all_shards"

    (
        best_lenpen,
        best_weight1,
        best_weight2,
        best_weight3,
        best_score,
    ) = match_target_hypo(args, write_targets, write_hypos)

    return best_lenpen, best_weight1, best_weight2, best_weight3, best_score


def cli_main():
    parser = rerank_options.get_reranking_parser()
    args = options.parse_args_and_arch(parser)
    rerank(args)


if __name__ == "__main__":
    cli_main()
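The inner loop of `score_target_hypo` keeps a running best score and resets it each time it reaches the last hypothesis of a source sentence. The same idea can be sketched in isolation as below. This is a simplified illustration: it assumes the hypotheses for one source sentence are stored contiguously, and the function name and toy data are hypothetical.

```python
def best_per_source(scores, hypos, group_sizes):
    """Pick the highest-scoring hypothesis within each consecutive group.

    scores and hypos are flat, parallel lists; group_sizes[k] is the number
    of hypotheses that belong to the k-th source sentence.
    """
    best = []
    start = 0
    for size in group_sizes:
        group = range(start, start + size)
        # Index of the best-scoring hypothesis inside this group.
        winner = max(group, key=lambda i: scores[i])
        best.append(hypos[winner])
        start += size
    return best


# Toy n-best list: three hypotheses for the first source, two for the second.
toy_scores = [-1.0, -0.5, -2.0, -0.1, -0.3]
toy_hypos = ["a", "b", "c", "d", "e"]
selected = best_per_source(toy_scores, toy_hypos, group_sizes=[3, 2])
print(selected)
```

Allowing groups of different sizes corresponds to the original loop checking `gen_output.num_hypos[i]` in addition to `args.num_rescore`.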
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Generate n-best translations using a trained model.
"""
import os
import subprocess
from contextlib import redirect_stdout
from fairseq import options
from fairseq_cli import generate, preprocess
from . import rerank_options, rerank_utils
def gen_and_reprocess_nbest(args):
    if args.score_dict_dir is None:
        args.score_dict_dir = args.data
    if args.prefix_len is not None:
        assert (
            args.right_to_left1 is False
        ), "prefix length not compatible with right to left models"
        assert (
            args.right_to_left2 is False
        ), "prefix length not compatible with right to left models"

    if args.nbest_list is not None:
        assert args.score_model2 is None

    if args.backwards1:
        scorer1_src = args.target_lang
        scorer1_tgt = args.source_lang
    else:
        scorer1_src = args.source_lang
        scorer1_tgt = args.target_lang

    store_data = (
        os.path.join(os.path.dirname(__file__)) + "/rerank_data/" + args.data_dir_name
    )
    if not os.path.exists(store_data):
        os.makedirs(store_data)

    (
        pre_gen,
        left_to_right_preprocessed_dir,
        right_to_left_preprocessed_dir,
        backwards_preprocessed_dir,
        lm_preprocessed_dir,
    ) = rerank_utils.get_directories(
        args.data_dir_name,
        args.num_rescore,
        args.gen_subset,
        args.gen_model_name,
        args.shard_id,
        args.num_shards,
        args.sampling,
        args.prefix_len,
        args.target_prefix_frac,
        args.source_prefix_frac,
    )
    assert not (
        args.right_to_left1 and args.backwards1
    ), "backwards right to left not supported"
    assert not (
        args.right_to_left2 and args.backwards2
    ), "backwards right to left not supported"
    assert not (
        args.prefix_len is not None and args.target_prefix_frac is not None
    ), "target prefix frac and target prefix len incompatible"

    # make directory to store generation results
    if not os.path.exists(pre_gen):
        os.makedirs(pre_gen)

    rerank1_is_gen = (
        args.gen_model == args.score_model1 and args.source_prefix_frac is None
    )
    rerank2_is_gen = (
        args.gen_model == args.score_model2 and args.source_prefix_frac is None
    )

    if args.nbest_list is not None:
        rerank2_is_gen = True

    # make directories to store preprocessed nbest list for reranking
    if not os.path.exists(left_to_right_preprocessed_dir):
        os.makedirs(left_to_right_preprocessed_dir)
    if not os.path.exists(right_to_left_preprocessed_dir):
        os.makedirs(right_to_left_preprocessed_dir)
    if not os.path.exists(lm_preprocessed_dir):
        os.makedirs(lm_preprocessed_dir)
    if not os.path.exists(backwards_preprocessed_dir):
        os.makedirs(backwards_preprocessed_dir)

    score1_file = rerank_utils.rescore_file_name(
        pre_gen,
        args.prefix_len,
        args.model1_name,
        target_prefix_frac=args.target_prefix_frac,
        source_prefix_frac=args.source_prefix_frac,
        backwards=args.backwards1,
    )
    if args.score_model2 is not None:
        score2_file = rerank_utils.rescore_file_name(
            pre_gen,
            args.prefix_len,
            args.model2_name,
            target_prefix_frac=args.target_prefix_frac,
            source_prefix_frac=args.source_prefix_frac,
            backwards=args.backwards2,
        )

    predictions_bpe_file = pre_gen + "/generate_output_bpe.txt"

    using_nbest = args.nbest_list is not None

    if using_nbest:
        print("Using predefined n-best list from interactive.py")
        predictions_bpe_file = args.nbest_list
    else:
        if not os.path.isfile(predictions_bpe_file):
            print("STEP 1: generate predictions using the p(T|S) model with bpe")
            print(args.data)
            param1 = [
                args.data,
                "--path",
                args.gen_model,
                "--shard-id",
                str(args.shard_id),
                "--num-shards",
                str(args.num_shards),
                "--nbest",
                str(args.num_rescore),
                "--batch-size",
                str(args.batch_size),
                "--beam",
                str(args.num_rescore),
                # NOTE: repeated flag; argparse keeps the last value, so this
                # overrides the args.batch_size setting passed above
                "--batch-size",
                str(args.num_rescore),
                "--gen-subset",
                args.gen_subset,
                "--source-lang",
                args.source_lang,
                "--target-lang",
                args.target_lang,
            ]
            if args.sampling:
                param1 += ["--sampling"]

            gen_parser = options.get_generation_parser()
            input_args = options.parse_args_and_arch(gen_parser, param1)

            print(input_args)
            with open(predictions_bpe_file, "w") as f:
                with redirect_stdout(f):
                    generate.main(input_args)

    gen_output = rerank_utils.BitextOutputFromGen(
        predictions_bpe_file,
        bpe_symbol=args.remove_bpe,
        nbest=using_nbest,
        prefix_len=args.prefix_len,
        target_prefix_frac=args.target_prefix_frac,
    )

    if args.diff_bpe:
        rerank_utils.write_reprocessed(
            gen_output.no_bpe_source,
            gen_output.no_bpe_hypo,
            gen_output.no_bpe_target,
            pre_gen + "/source_gen_bpe." + args.source_lang,
            pre_gen + "/target_gen_bpe." + args.target_lang,
            pre_gen + "/reference_gen_bpe." + args.target_lang,
        )
        bitext_bpe = args.rescore_bpe_code
        bpe_src_param = [
            "-c",
            bitext_bpe,
            "--input",
            pre_gen + "/source_gen_bpe." + args.source_lang,
            "--output",
            pre_gen + "/rescore_data." + args.source_lang,
        ]
        bpe_tgt_param = [
            "-c",
            bitext_bpe,
            "--input",
            pre_gen + "/target_gen_bpe." + args.target_lang,
            "--output",
            pre_gen + "/rescore_data." + args.target_lang,
        ]

        subprocess.call(
            [
                "python",
                os.path.join(
                    os.path.dirname(__file__), "subword-nmt/subword_nmt/apply_bpe.py"
                ),
            ]
            + bpe_src_param,
            shell=False,
        )

        subprocess.call(
            [
                "python",
                os.path.join(
                    os.path.dirname(__file__), "subword-nmt/subword_nmt/apply_bpe.py"
                ),
            ]
            + bpe_tgt_param,
            shell=False,
        )

    if (not os.path.isfile(score1_file) and not rerank1_is_gen) or (
        args.score_model2 is not None
        and not os.path.isfile(score2_file)
        and not rerank2_is_gen
    ):
        print(
            "STEP 2: process the output of generate.py so we have clean text files with the translations"
        )

        rescore_file = "/rescore_data"
        if args.prefix_len is not None:
            prefix_len_rescore_file = rescore_file + "prefix" + str(args.prefix_len)
        if args.target_prefix_frac is not None:
            target_prefix_frac_rescore_file = (
                rescore_file + "target_prefix_frac" + str(args.target_prefix_frac)
            )
        if args.source_prefix_frac is not None:
            source_prefix_frac_rescore_file = (
                rescore_file + "source_prefix_frac" + str(args.source_prefix_frac)
            )

        if not args.right_to_left1 or not args.right_to_left2:
            if not args.diff_bpe:
                rerank_utils.write_reprocessed(
                    gen_output.source,
                    gen_output.hypo,
                    gen_output.target,
                    pre_gen + rescore_file + "." + args.source_lang,
                    pre_gen + rescore_file + "." + args.target_lang,
                    pre_gen + "/reference_file",
                    bpe_symbol=args.remove_bpe,
                )
                if args.prefix_len is not None:
                    bw_rescore_file = prefix_len_rescore_file
                    rerank_utils.write_reprocessed(
                        gen_output.source,
                        gen_output.hypo,
                        gen_output.target,
                        pre_gen + prefix_len_rescore_file + "." + args.source_lang,
                        pre_gen + prefix_len_rescore_file + "." + args.target_lang,
                        pre_gen + "/reference_file",
                        prefix_len=args.prefix_len,
                        bpe_symbol=args.remove_bpe,
                    )
                elif args.target_prefix_frac is not None:
                    bw_rescore_file = target_prefix_frac_rescore_file
                    rerank_utils.write_reprocessed(
                        gen_output.source,
                        gen_output.hypo,
                        gen_output.target,
                        pre_gen
                        + target_prefix_frac_rescore_file
                        + "."
                        + args.source_lang,
                        pre_gen
                        + target_prefix_frac_rescore_file
                        + "."
                        + args.target_lang,
                        pre_gen + "/reference_file",
                        bpe_symbol=args.remove_bpe,
                        target_prefix_frac=args.target_prefix_frac,
                    )
                else:
                    bw_rescore_file = rescore_file

                if args.source_prefix_frac is not None:
                    fw_rescore_file = source_prefix_frac_rescore_file
                    rerank_utils.write_reprocessed(
                        gen_output.source,
                        gen_output.hypo,
                        gen_output.target,
                        pre_gen
                        + source_prefix_frac_rescore_file
                        + "."
                        + args.source_lang,
                        pre_gen
                        + source_prefix_frac_rescore_file
                        + "."
                        + args.target_lang,
                        pre_gen + "/reference_file",
                        bpe_symbol=args.remove_bpe,
                        source_prefix_frac=args.source_prefix_frac,
                    )
                else:
                    fw_rescore_file = rescore_file

        if args.right_to_left1 or args.right_to_left2:
            rerank_utils.write_reprocessed(
                gen_output.source,
                gen_output.hypo,
                gen_output.target,
                pre_gen + "/right_to_left_rescore_data." + args.source_lang,
                pre_gen + "/right_to_left_rescore_data." + args.target_lang,
                pre_gen + "/right_to_left_reference_file",
                right_to_left=True,
                bpe_symbol=args.remove_bpe,
            )

        print("STEP 3: binarize the translations")
        if (
            not args.right_to_left1
            or args.score_model2 is not None
            and not args.right_to_left2
            or not rerank1_is_gen
        ):
            if args.backwards1 or args.backwards2:
                if args.backwards_score_dict_dir is not None:
                    bw_dict = args.backwards_score_dict_dir
                else:
                    bw_dict = args.score_dict_dir
                bw_preprocess_param = [
                    "--source-lang",
                    scorer1_src,
                    "--target-lang",
                    scorer1_tgt,
                    "--trainpref",
                    pre_gen + bw_rescore_file,
                    "--srcdict",
                    bw_dict + "/dict." + scorer1_src + ".txt",
                    "--tgtdict",
                    bw_dict + "/dict." + scorer1_tgt + ".txt",
                    "--destdir",
                    backwards_preprocessed_dir,
                ]
                preprocess_parser = options.get_preprocessing_parser()
                input_args = preprocess_parser.parse_args(bw_preprocess_param)
                preprocess.main(input_args)

            preprocess_param = [
                "--source-lang",
                scorer1_src,
                "--target-lang",
                scorer1_tgt,
                "--trainpref",
                pre_gen + fw_rescore_file,
                "--srcdict",
                args.score_dict_dir + "/dict." + scorer1_src + ".txt",
                "--tgtdict",
                args.score_dict_dir + "/dict." + scorer1_tgt + ".txt",
                "--destdir",
                left_to_right_preprocessed_dir,
            ]
            preprocess_parser = options.get_preprocessing_parser()
            input_args = preprocess_parser.parse_args(preprocess_param)
            preprocess.main(input_args)

        if args.right_to_left1 or args.right_to_left2:
            preprocess_param = [
                "--source-lang",
                scorer1_src,
                "--target-lang",
                scorer1_tgt,
                "--trainpref",
                pre_gen + "/right_to_left_rescore_data",
                "--srcdict",
                args.score_dict_dir + "/dict." + scorer1_src + ".txt",
                "--tgtdict",
                args.score_dict_dir + "/dict." + scorer1_tgt + ".txt",
                "--destdir",
                right_to_left_preprocessed_dir,
            ]
            preprocess_parser = options.get_preprocessing_parser()
            input_args = preprocess_parser.parse_args(preprocess_param)
            preprocess.main(input_args)

    return gen_output


def cli_main():
    parser = rerank_options.get_reranking_parser()
    args = options.parse_args_and_arch(parser)
    gen_and_reprocess_nbest(args)


if __name__ == "__main__":
    cli_main()
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq import options
def get_reranking_parser(default_task="translation"):
    parser = options.get_parser("Generation and reranking", default_task)
    add_reranking_args(parser)
    return parser


def get_tuning_parser(default_task="translation"):
    parser = options.get_parser("Reranking tuning", default_task)
    add_reranking_args(parser)
    add_tuning_args(parser)
    return parser
def add_reranking_args(parser):
    group = parser.add_argument_group("Reranking")
    # fmt: off
    group.add_argument('--score-model1', '-s1', type=str, metavar='FILE', required=True,
                        help='path to first model or ensemble of models for rescoring')
    group.add_argument('--score-model2', '-s2', type=str, metavar='FILE', required=False,
                        help='path to second model or ensemble of models for rescoring')
    group.add_argument('--num-rescore', '-n', type=int, metavar='N', default=10,
                        help='the number of candidate hypothesis to rescore')
    group.add_argument('-bz', '--batch-size', type=int, metavar='N', default=128,
                        help='batch size for generating the nbest list')
    group.add_argument('--gen-subset', default='test', metavar='SET', choices=['test', 'train', 'valid'],
                        help='data subset to generate (train, valid, test)')
    group.add_argument('--gen-model', default=None, metavar='FILE',
                        help='the model to generate translations')
    group.add_argument('-b1', '--backwards1', action='store_true',
                        help='whether or not the first model group is backwards')
    group.add_argument('-b2', '--backwards2', action='store_true',
                        help='whether or not the second model group is backwards')
    group.add_argument('-a', '--weight1', default=1, nargs='+', type=float,
                        help='the weight(s) of the first model')
    group.add_argument('-b', '--weight2', default=1, nargs='+', type=float,
                        help='the weight(s) of the second model, or the gen model if using nbest from interactive.py')
    group.add_argument('-c', '--weight3', default=1, nargs='+', type=float,
                        help='the weight(s) of the third model')

    # lm arguments
    group.add_argument('-lm', '--language-model', default=None, metavar='FILE',
                        help='language model for target language to rescore translations')
    group.add_argument('--lm-dict', default=None, metavar='FILE',
                        help='the dict of the language model for the target language')
    group.add_argument('--lm-name', default=None,
                        help='the name of the language model for the target language')
    group.add_argument('--lm-bpe-code', default=None, metavar='FILE',
                        help='the bpe code for the language model for the target language')
    group.add_argument('--data-dir-name', default=None,
                        help='name of data directory')
    group.add_argument('--lenpen', default=1, nargs='+', type=float,
                        help='length penalty: <1.0 favors shorter, >1.0 favors longer sentences')
    group.add_argument('--score-dict-dir', default=None,
                        help='the directory with dictionaries for the scoring models')
    group.add_argument('--right-to-left1', action='store_true',
                        help='whether the first model group is a right to left model')
    group.add_argument('--right-to-left2', action='store_true',
                        help='whether the second model group is a right to left model')
    group.add_argument('--remove-bpe', '--post-process', default='@@ ',
                        help='the bpe symbol, used for the bitext and LM')
    group.add_argument('--prefix-len', default=None, type=int,
                        help='the length of the target prefix to use in rescoring (in terms of words wo bpe)')
    group.add_argument('--sampling', action='store_true',
                        help='use sampling instead of beam search for generating n best list')
    group.add_argument('--diff-bpe', action='store_true',
                        help='bpe for rescoring and nbest list not the same')
    group.add_argument('--rescore-bpe-code', default=None,
                        help='bpe code for rescoring models')
    group.add_argument('--nbest-list', default=None,
                        help='use predefined nbest list in interactive.py format')
    group.add_argument('--write-hypos', default=None,
                        help='filename prefix to write hypos to')
    group.add_argument('--ref-translation', default=None,
                        help='reference translation to use with nbest list from interactive.py')
    group.add_argument('--backwards-score-dict-dir', default=None,
                        help='the directory with dictionaries for the backwards model, '
                             'if None then it is assumed the fw and backwards models share dictionaries')

    # extra scaling args
    group.add_argument('--gen-model-name', default=None,
                        help='the name of the models that generated the nbest list')
    group.add_argument('--model1-name', default=None,
                        help='the name of the set for model1 group')
    group.add_argument('--model2-name', default=None,
                        help='the name of the set for model2 group')
    group.add_argument('--shard-id', default=0, type=int,
                        help='the id of the shard to generate')
    group.add_argument('--num-shards', default=1, type=int,
                        help='the number of shards to generate across')
    group.add_argument('--all-shards', action='store_true',
                        help='use all shards')
    group.add_argument('--target-prefix-frac', default=None, type=float,
                        help='the fraction of the target prefix to use in rescoring (in terms of words wo bpe)')
    group.add_argument('--source-prefix-frac', default=None, type=float,
                        help='the fraction of the source prefix to use in rescoring (in terms of words wo bpe)')
    group.add_argument('--normalize', action='store_true',
                        help='whether to normalize by src and target len')
    # fmt: on
    return group
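The weight and length-penalty options above combine `nargs='+'` with a scalar `default=1`, which is why `rerank()` wraps non-list values in a list before use. The standalone sketch below reproduces that behaviour with plain `argparse`. The parser here is a hypothetical minimal one, not the fairseq parser.

```python
import argparse

# Minimal stand-in parser with the same option shape as --weight1 above.
parser = argparse.ArgumentParser()
parser.add_argument("--weight1", default=1, nargs="+", type=float)

no_flag = parser.parse_args([])
with_flag = parser.parse_args(["--weight1", "0.5", "1"])

# Without the flag, the scalar default is passed through unchanged.
print(type(no_flag.weight1).__name__, no_flag.weight1)
# With the flag, the values arrive as a list of floats.
print(type(with_flag.weight1).__name__, with_flag.weight1)
```

Code that indexes these options (`args.weight1[0]`) therefore needs the normalization step when the flag is omitted.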
def add_tuning_args(parser):
    group = parser.add_argument_group("Tuning")

    group.add_argument(
        "--lower-bound",
        default=[-0.7],
        nargs="+",
        type=float,
        help="lower bound of search space",
    )
    group.add_argument(
        "--upper-bound",
        default=[3],
        nargs="+",
        type=float,
        help="upper bound of search space",
    )
    group.add_argument(
        "--tune-param",
        default=["lenpen"],
        nargs="+",
        choices=["lenpen", "weight1", "weight2", "weight3"],
        help="the parameter(s) to tune",
    )
    group.add_argument(
        "--tune-subset",
        default="valid",
        choices=["valid", "test", "train"],
        help="the subset to tune on",
    )
    group.add_argument(
        "--num-trials",
        default=1000,
        type=int,
        help="number of trials to do for random search",
    )
    group.add_argument(
        "--share-weights", action="store_true", help="share weight2 and weight3"
    )
    return group
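The tuning options above describe a random search: a number of trials, each drawing the tuned parameters between a lower and an upper bound. A minimal sketch of that procedure follows. The tuning script itself is not shown here, so the uniform sampling, the function names, and the toy objective standing in for BLEU are all assumptions.

```python
import random


def random_search(objective, lower, upper, num_trials, seed=0):
    """Sample parameters uniformly within bounds and keep the best trial."""
    rng = random.Random(seed)
    best_params, best_score = None, float("-inf")
    for _ in range(num_trials):
        # One value per tuned parameter, drawn between its lower and upper bound.
        params = [rng.uniform(lo, hi) for lo, hi in zip(lower, upper)]
        score = objective(params)
        if score > best_score:
            best_params, best_score = params, score
    return best_params, best_score


def toy_objective(params):
    """Hypothetical stand-in for a BLEU evaluation, peaking at (1.0, 2.0)."""
    lenpen, weight3 = params
    return -((lenpen - 1.0) ** 2 + (weight3 - 2.0) ** 2)


best_params, best_score = random_search(
    toy_objective, lower=[0, 0], upper=[3, 3], num_trials=1000
)
print("best (lenpen, weight3):", best_params)
print("best objective value:", best_score)
```

Each bound list has one entry per tuned parameter, matching how `--lower-bound 0 0 --upper-bound 3 3` pairs with `--tune-param lenpen weight3` in the README commands.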
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
from contextlib import redirect_stdout
from fairseq import options
from fairseq_cli import generate
from . import rerank_options, rerank_utils
def score_bw(args):
    if args.backwards1:
        scorer1_src = args.target_lang
        scorer1_tgt = args.source_lang
    else:
        scorer1_src = args.source_lang
        scorer1_tgt = args.target_lang

    if args.score_model2 is not None:
        if args.backwards2:
            scorer2_src = args.target_lang
            scorer2_tgt = args.source_lang
        else:
            scorer2_src = args.source_lang
            scorer2_tgt = args.target_lang

    rerank1_is_gen = (
        args.gen_model == args.score_model1 and args.source_prefix_frac is None
    )
    rerank2_is_gen = (
        args.gen_model == args.score_model2 and args.source_prefix_frac is None
    )

    (
        pre_gen,
        left_to_right_preprocessed_dir,
        right_to_left_preprocessed_dir,
        backwards_preprocessed_dir,
        lm_preprocessed_dir,
    ) = rerank_utils.get_directories(
        args.data_dir_name,
        args.num_rescore,
        args.gen_subset,
        args.gen_model_name,
        args.shard_id,
        args.num_shards,
        args.sampling,
        args.prefix_len,
        args.target_prefix_frac,
        args.source_prefix_frac,
    )

    score1_file = rerank_utils.rescore_file_name(
        pre_gen,
        args.prefix_len,
        args.model1_name,
        target_prefix_frac=args.target_prefix_frac,
        source_prefix_frac=args.source_prefix_frac,
        backwards=args.backwards1,
    )

    if args.score_model2 is not None:
        score2_file = rerank_utils.rescore_file_name(
            pre_gen,
            args.prefix_len,
            args.model2_name,
            target_prefix_frac=args.target_prefix_frac,
            source_prefix_frac=args.source_prefix_frac,
            backwards=args.backwards2,
        )

    if args.right_to_left1:
        rerank_data1 = right_to_left_preprocessed_dir
    elif args.backwards1:
        rerank_data1 = backwards_preprocessed_dir
    else:
        rerank_data1 = left_to_right_preprocessed_dir

    gen_param = ["--batch-size", str(128), "--score-reference", "--gen-subset", "train"]
    if not rerank1_is_gen and not os.path.isfile(score1_file):
        print("STEP 4: score the translations for model 1")

        model_param1 = [
            "--path",
            args.score_model1,
            "--source-lang",
            scorer1_src,
            "--target-lang",
            scorer1_tgt,
        ]
        gen_model1_param = [rerank_data1] + gen_param + model_param1

        gen_parser = options.get_generation_parser()
        input_args = options.parse_args_and_arch(gen_parser, gen_model1_param)

        with open(score1_file, "w") as f:
            with redirect_stdout(f):
                generate.main(input_args)

    if (
        args.score_model2 is not None
        and not os.path.isfile(score2_file)
        and not rerank2_is_gen
    ):
        print("STEP 4: score the translations for model 2")

        if args.right_to_left2:
            rerank_data2 = right_to_left_preprocessed_dir
        elif args.backwards2:
            rerank_data2 = backwards_preprocessed_dir
        else:
            rerank_data2 = left_to_right_preprocessed_dir

        model_param2 = [
            "--path",
            args.score_model2,
            "--source-lang",
            scorer2_src,
            "--target-lang",
            scorer2_tgt,
        ]
        gen_model2_param = [rerank_data2] + gen_param + model_param2

        gen_parser = options.get_generation_parser()
        input_args = options.parse_args_and_arch(gen_parser, gen_model2_param)

        with open(score2_file, "w") as f:
            with redirect_stdout(f):
                generate.main(input_args)


def cli_main():
    parser = rerank_options.get_reranking_parser()
    args = options.parse_args_and_arch(parser)
    score_bw(args)


if __name__ == "__main__":
    cli_main()
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
from fairseq import options
from . import rerank_options, rerank_utils
def score_lm(args):
using_nbest = args.nbest_list is not None
(
pre_gen,
left_to_right_preprocessed_dir,
right_to_left_preprocessed_dir,
backwards_preprocessed_dir,
lm_preprocessed_dir,
) = rerank_utils.get_directories(
args.data_dir_name,
args.num_rescore,
args.gen_subset,
args.gen_model_name,
args.shard_id,
args.num_shards,
args.sampling,
args.prefix_len,
args.target_prefix_frac,
args.source_prefix_frac,
)
predictions_bpe_file = pre_gen + "/generate_output_bpe.txt"
if using_nbest:
print("Using predefined n-best list from interactive.py")
predictions_bpe_file = args.nbest_list
gen_output = rerank_utils.BitextOutputFromGen(
predictions_bpe_file, bpe_symbol=args.remove_bpe, nbest=using_nbest
)
if args.language_model is not None:
lm_score_file = rerank_utils.rescore_file_name(
pre_gen, args.prefix_len, args.lm_name, lm_file=True
)
if args.language_model is not None and not os.path.isfile(lm_score_file):
print("STEP 4.5: language modeling for P(T)")
if args.lm_bpe_code is None:
bpe_status = "no bpe"
elif args.lm_bpe_code == "shared":
bpe_status = "shared"
else:
bpe_status = "different"
rerank_utils.lm_scoring(
lm_preprocessed_dir,
bpe_status,
gen_output,
pre_gen,
args.lm_dict,
args.lm_name,
args.language_model,
args.lm_bpe_code,
128,
lm_score_file,
args.target_lang,
args.source_lang,
prefix_len=args.prefix_len,
)
def cli_main():
parser = rerank_options.get_reranking_parser()
args = options.parse_args_and_arch(parser)
score_lm(args)
if __name__ == "__main__":
cli_main()
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import random
import numpy as np
from fairseq import options
from . import rerank, rerank_options
def random_search(args):
param_values = []
tuneable_parameters = ["lenpen", "weight1", "weight2", "weight3"]
initial_params = [args.lenpen, args.weight1, args.weight2, args.weight3]
for i, elem in enumerate(initial_params):
if type(elem) is not list:
initial_params[i] = [elem]
else:
initial_params[i] = elem
tune_parameters = args.tune_param.copy()
for i in range(len(args.tune_param)):
assert args.upper_bound[i] >= args.lower_bound[i]
index = tuneable_parameters.index(args.tune_param[i])
del tuneable_parameters[index]
del initial_params[index]
tune_parameters += tuneable_parameters
param_values += initial_params
random.seed(args.seed)
random_params = np.array(
[
[
random.uniform(args.lower_bound[i], args.upper_bound[i])
for i in range(len(args.tune_param))
]
for k in range(args.num_trials)
]
)
set_params = np.array(
[
[initial_params[i][0] for i in range(len(tuneable_parameters))]
for k in range(args.num_trials)
]
)
random_params = np.concatenate((random_params, set_params), 1)
rerank_args = vars(args).copy()
if args.nbest_list:
rerank_args["gen_subset"] = "test"
else:
rerank_args["gen_subset"] = args.tune_subset
for k in range(len(tune_parameters)):
rerank_args[tune_parameters[k]] = list(random_params[:, k])
if args.share_weights:
k = tune_parameters.index("weight2")
rerank_args["weight3"] = list(random_params[:, k])
rerank_args = argparse.Namespace(**rerank_args)
best_lenpen, best_weight1, best_weight2, best_weight3, best_score = rerank.rerank(
rerank_args
)
rerank_args = vars(args).copy()
rerank_args["lenpen"] = [best_lenpen]
rerank_args["weight1"] = [best_weight1]
rerank_args["weight2"] = [best_weight2]
rerank_args["weight3"] = [best_weight3]
# write the hypothesis from the valid set from the best trial
if args.gen_subset != "valid":
rerank_args["gen_subset"] = "valid"
rerank_args = argparse.Namespace(**rerank_args)
rerank.rerank(rerank_args)
# test with the best hyperparameters on gen subset
rerank_args = vars(args).copy()
rerank_args["gen_subset"] = args.gen_subset
rerank_args["lenpen"] = [best_lenpen]
rerank_args["weight1"] = [best_weight1]
rerank_args["weight2"] = [best_weight2]
rerank_args["weight3"] = [best_weight3]
rerank_args = argparse.Namespace(**rerank_args)
rerank.rerank(rerank_args)
def cli_main():
parser = rerank_options.get_tuning_parser()
args = options.parse_args_and_arch(parser)
random_search(args)
if __name__ == "__main__":
cli_main()
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import os
import re
import subprocess
from contextlib import redirect_stdout
from fairseq import options
from fairseq_cli import eval_lm, preprocess
def reprocess(fle):
# Parses a file of generate.py translation output and returns dicts keyed by the
# integer sentence ID: sources, hypotheses, scores, targets and positional scores.
# There may be several translations per source, so the hypothesis, score and
# positional-score values are lists.
with open(fle, "r") as f:
txt = f.read()
"""reprocess generate.py output"""
p = re.compile(r"[STHP][-]\d+\s*")
hp = re.compile(r"(\s*[-]?\d+[.]?\d+\s*)|(\s*(-inf)\s*)")
source_dict = {}
hypothesis_dict = {}
score_dict = {}
target_dict = {}
pos_score_dict = {}
lines = txt.split("\n")
for line in lines:
line += "\n"
prefix = re.search(p, line)
if prefix is not None:
assert len(prefix.group()) > 2, "prefix id not found"
_, j = prefix.span()
id_num = prefix.group()[2:]
id_num = int(id_num)
line_type = prefix.group()[0]
if line_type == "H":
h_txt = line[j:]
hypo = re.search(hp, h_txt)
assert (
hypo is not None
), "regular expression failed to find the hypothesis scoring"
_, i = hypo.span()
score = hypo.group()
if id_num in hypothesis_dict:
hypothesis_dict[id_num].append(h_txt[i:])
score_dict[id_num].append(float(score))
else:
hypothesis_dict[id_num] = [h_txt[i:]]
score_dict[id_num] = [float(score)]
elif line_type == "S":
source_dict[id_num] = line[j:]
elif line_type == "T":
target_dict[id_num] = line[j:]
elif line_type == "P":
pos_scores = (line[j:]).split()
pos_scores = [float(x) for x in pos_scores]
if id_num in pos_score_dict:
pos_score_dict[id_num].append(pos_scores)
else:
pos_score_dict[id_num] = [pos_scores]
return source_dict, hypothesis_dict, score_dict, target_dict, pos_score_dict
def reprocess_nbest(fle):
"""reprocess interactive.py output"""
with open(fle, "r") as f:
txt = f.read()
source_dict = {}
hypothesis_dict = {}
score_dict = {}
target_dict = {}
pos_score_dict = {}
lines = txt.split("\n")
hp = re.compile(r"[-]?\d+[.]?\d+")
j = -1
for _i, line in enumerate(lines):
line += "\n"
line_type = line[0]
if line_type == "H":
hypo = re.search(hp, line)
_, start_index = hypo.span()
score = hypo.group()
if j in score_dict:
score_dict[j].append(float(score))
hypothesis_dict[j].append(line[start_index:].strip("\t"))
else:
score_dict[j] = [float(score)]
hypothesis_dict[j] = [line[start_index:].strip("\t")]
elif line_type == "O":
j += 1
source_dict[j] = line[2:]
# we don't have the targets for interactive.py
target_dict[j] = "filler"
elif line_type == "P":
pos_scores = [float(pos_score) for pos_score in line.split()[1:]]
if j in pos_score_dict:
pos_score_dict[j].append(pos_scores)
else:
pos_score_dict[j] = [pos_scores]
assert source_dict.keys() == hypothesis_dict.keys()
assert source_dict.keys() == pos_score_dict.keys()
assert source_dict.keys() == score_dict.keys()
return source_dict, hypothesis_dict, score_dict, target_dict, pos_score_dict
def write_reprocessed(
sources,
hypos,
targets,
source_outfile,
hypo_outfile,
target_outfile,
right_to_left=False,
prefix_len=None,
bpe_symbol=None,
target_prefix_frac=None,
source_prefix_frac=None,
):
"""writes nbest hypothesis for rescoring"""
assert not (
prefix_len is not None and target_prefix_frac is not None
), "in writing reprocessed, only one type of prefix may be used"
assert not (
prefix_len is not None and source_prefix_frac is not None
), "in writing reprocessed, only one type of prefix may be used"
assert not (
target_prefix_frac is not None and source_prefix_frac is not None
), "in writing reprocessed, only one type of prefix may be used"
with open(source_outfile, "w") as source_file, open(
hypo_outfile, "w"
) as hypo_file, open(target_outfile, "w") as target_file:
assert len(sources) == len(hypos), "sources and hypos list length mismatch"
if right_to_left:
for i in range(len(sources)):
for j in range(len(hypos[i])):
if prefix_len is None:
hypo_file.write(make_right_to_left(hypos[i][j]) + "\n")
else:
raise NotImplementedError()
source_file.write(make_right_to_left(sources[i]) + "\n")
target_file.write(make_right_to_left(targets[i]) + "\n")
else:
for i in sorted(sources.keys()):
for j in range(len(hypos[i])):
if prefix_len is not None:
shortened = (
get_prefix_no_bpe(hypos[i][j], bpe_symbol, prefix_len)
+ "\n"
)
hypo_file.write(shortened)
source_file.write(sources[i])
target_file.write(targets[i])
elif target_prefix_frac is not None:
num_words, shortened, num_bpe_tokens = calc_length_from_frac(
hypos[i][j], target_prefix_frac, bpe_symbol
)
shortened += "\n"
hypo_file.write(shortened)
source_file.write(sources[i])
target_file.write(targets[i])
elif source_prefix_frac is not None:
num_words, shortened, num_bpe_tokens = calc_length_from_frac(
sources[i], source_prefix_frac, bpe_symbol
)
shortened += "\n"
hypo_file.write(hypos[i][j])
source_file.write(shortened)
target_file.write(targets[i])
else:
hypo_file.write(hypos[i][j])
source_file.write(sources[i])
target_file.write(targets[i])
def calc_length_from_frac(bpe_sentence, prefix_frac, bpe_symbol):
    # return number of words, (not bpe tokens) that we want
    no_bpe_sen = remove_bpe(bpe_sentence, bpe_symbol)
    len_sen = len(no_bpe_sen.split())
    num_words = math.ceil(len_sen * prefix_frac)
    prefix = get_prefix_no_bpe(bpe_sentence, bpe_symbol, num_words)
    num_bpe_tokens = len(prefix.split())
    return num_words, prefix, num_bpe_tokens


def get_prefix(sentence, prefix_len):
    """assuming no bpe, gets the prefix of the sentence with prefix_len words"""
    tokens = sentence.strip("\n").split()
    if prefix_len >= len(tokens):
        return sentence.strip("\n")
    else:
        return " ".join(tokens[:prefix_len])


def get_prefix_no_bpe(sentence, bpe_symbol, prefix_len):
    if bpe_symbol is None:
        return get_prefix(sentence, prefix_len)
    else:
        return " ".join(get_prefix_from_len(sentence.split(), bpe_symbol, prefix_len))


def get_prefix_from_len(sentence, bpe_symbol, prefix_len):
    """get the prefix of sentence with bpe, with prefix len in terms of words, not bpe tokens"""
    bpe_count = sum([bpe_symbol.strip(" ") in t for t in sentence[:prefix_len]])
    if bpe_count == 0:
        return sentence[:prefix_len]
    else:
        return sentence[:prefix_len] + get_prefix_from_len(
            sentence[prefix_len:], bpe_symbol, bpe_count
        )


def get_num_bpe_tokens_from_len(sentence, bpe_symbol, prefix_len):
    """given a prefix length in terms of words, return the number of bpe tokens"""
    prefix = get_prefix_no_bpe(sentence, bpe_symbol, prefix_len)
    assert len(remove_bpe(prefix, bpe_symbol).split()) <= prefix_len
    return len(prefix.split(" "))


def make_right_to_left(line):
    tokens = line.split()
    tokens.reverse()
    new_line = " ".join(tokens)
    return new_line


def remove_bpe(line, bpe_symbol):
    line = line.replace("\n", "")
    line = (line + " ").replace(bpe_symbol, "").rstrip()
    return line + ("\n")


def remove_bpe_dict(pred_dict, bpe_symbol):
    new_dict = {}
    for i in pred_dict:
        if type(pred_dict[i]) == list:
            new_list = [remove_bpe(elem, bpe_symbol) for elem in pred_dict[i]]
            new_dict[i] = new_list
        else:
            new_dict[i] = remove_bpe(pred_dict[i], bpe_symbol)
    return new_dict


def parse_bleu_scoring(line):
    p = re.compile(r"(BLEU4 = )\d+[.]\d+")
    res = re.search(p, line)
    assert res is not None, line
    return float(res.group()[8:])


def get_full_from_prefix(hypo_prefix, hypos):
    """given a hypo prefix, recover the first hypo from the list of complete hypos beginning with that prefix"""
    for hypo in hypos:
        hypo_prefix = hypo_prefix.strip("\n")
        len_prefix = len(hypo_prefix)
        if hypo[:len_prefix] == hypo_prefix:
            return hypo
    # no match found
    raise Exception()


def get_score(
    a,
    b,
    c,
    target_len,
    bitext_score1,
    bitext_score2=None,
    lm_score=None,
    lenpen=None,
    src_len=None,
    tgt_len=None,
    bitext1_backwards=False,
    bitext2_backwards=False,
    normalize=False,
):
    if bitext1_backwards:
        bitext1_norm = src_len
    else:
        bitext1_norm = tgt_len
    if bitext_score2 is not None:
        if bitext2_backwards:
            bitext2_norm = src_len
        else:
            bitext2_norm = tgt_len
    else:
        bitext2_norm = 1
        bitext_score2 = 0
    if normalize:
        score = (
            a * bitext_score1 / bitext1_norm
            + b * bitext_score2 / bitext2_norm
            + c * lm_score / src_len
        )
    else:
        score = a * bitext_score1 + b * bitext_score2 + c * lm_score
    if lenpen is not None:
        score /= (target_len) ** float(lenpen)
    return score


class BitextOutput(object):
def __init__(
self,
output_file,
backwards,
right_to_left,
bpe_symbol,
prefix_len=None,
target_prefix_frac=None,
source_prefix_frac=None,
):
"""process output from rescoring"""
source, hypo, score, target, pos_score = reprocess(output_file)
if backwards:
self.hypo_fracs = source_prefix_frac
else:
self.hypo_fracs = target_prefix_frac
# remove length penalty so we can use raw scores
score, num_bpe_tokens = get_score_from_pos(
pos_score, prefix_len, hypo, bpe_symbol, self.hypo_fracs, backwards
)
source_lengths = {}
target_lengths = {}
assert hypo.keys() == source.keys(), "key mismatch"
if backwards:
tmp = hypo
hypo = source
source = tmp
for i in source:
# since we are reranking, there should only be one hypo per source sentence
if backwards:
len_src = len(source[i][0].split())
# record length without <eos>
if len_src == num_bpe_tokens[i][0] - 1:
source_lengths[i] = num_bpe_tokens[i][0] - 1
else:
source_lengths[i] = num_bpe_tokens[i][0]
target_lengths[i] = len(hypo[i].split())
source[i] = remove_bpe(source[i][0], bpe_symbol)
target[i] = remove_bpe(target[i], bpe_symbol)
hypo[i] = remove_bpe(hypo[i], bpe_symbol)
score[i] = float(score[i][0])
pos_score[i] = pos_score[i][0]
else:
len_tgt = len(hypo[i][0].split())
# record length without <eos>
if len_tgt == num_bpe_tokens[i][0] - 1:
target_lengths[i] = num_bpe_tokens[i][0] - 1
else:
target_lengths[i] = num_bpe_tokens[i][0]
source_lengths[i] = len(source[i].split())
if right_to_left:
source[i] = remove_bpe(make_right_to_left(source[i]), bpe_symbol)
target[i] = remove_bpe(make_right_to_left(target[i]), bpe_symbol)
hypo[i] = remove_bpe(make_right_to_left(hypo[i][0]), bpe_symbol)
score[i] = float(score[i][0])
pos_score[i] = pos_score[i][0]
else:
assert (
len(hypo[i]) == 1
), "expected only one hypothesis per source sentence"
source[i] = remove_bpe(source[i], bpe_symbol)
target[i] = remove_bpe(target[i], bpe_symbol)
hypo[i] = remove_bpe(hypo[i][0], bpe_symbol)
score[i] = float(score[i][0])
pos_score[i] = pos_score[i][0]
self.rescore_source = source
self.rescore_hypo = hypo
self.rescore_score = score
self.rescore_target = target
self.rescore_pos_score = pos_score
self.backwards = backwards
self.right_to_left = right_to_left
self.target_lengths = target_lengths
self.source_lengths = source_lengths
class BitextOutputFromGen(object):
def __init__(
self,
predictions_bpe_file,
bpe_symbol=None,
nbest=False,
prefix_len=None,
target_prefix_frac=None,
):
if nbest:
(
pred_source,
pred_hypo,
pred_score,
pred_target,
pred_pos_score,
) = reprocess_nbest(predictions_bpe_file)
else:
pred_source, pred_hypo, pred_score, pred_target, pred_pos_score = reprocess(
predictions_bpe_file
)
assert len(pred_source) == len(pred_hypo)
assert len(pred_source) == len(pred_score)
assert len(pred_source) == len(pred_target)
assert len(pred_source) == len(pred_pos_score)
# remove length penalty so we can use raw scores
pred_score, num_bpe_tokens = get_score_from_pos(
pred_pos_score, prefix_len, pred_hypo, bpe_symbol, target_prefix_frac, False
)
self.source = pred_source
self.target = pred_target
self.score = pred_score
self.pos_score = pred_pos_score
self.hypo = pred_hypo
self.target_lengths = {}
self.source_lengths = {}
self.no_bpe_source = remove_bpe_dict(pred_source.copy(), bpe_symbol)
self.no_bpe_hypo = remove_bpe_dict(pred_hypo.copy(), bpe_symbol)
self.no_bpe_target = remove_bpe_dict(pred_target.copy(), bpe_symbol)
# indexes to match those from the rescoring models
self.rescore_source = {}
self.rescore_target = {}
self.rescore_pos_score = {}
self.rescore_hypo = {}
self.rescore_score = {}
self.num_hypos = {}
self.backwards = False
self.right_to_left = False
index = 0
for i in sorted(pred_source.keys()):
for j in range(len(pred_hypo[i])):
self.target_lengths[index] = len(self.hypo[i][j].split())
self.source_lengths[index] = len(self.source[i].split())
self.rescore_source[index] = self.no_bpe_source[i]
self.rescore_target[index] = self.no_bpe_target[i]
self.rescore_hypo[index] = self.no_bpe_hypo[i][j]
self.rescore_score[index] = float(pred_score[i][j])
self.rescore_pos_score[index] = pred_pos_score[i][j]
self.num_hypos[index] = len(pred_hypo[i])
index += 1
def get_score_from_pos(
pos_score_dict, prefix_len, hypo_dict, bpe_symbol, hypo_frac, backwards
):
score_dict = {}
num_bpe_tokens_dict = {}
assert prefix_len is None or hypo_frac is None
for key in pos_score_dict:
score_dict[key] = []
num_bpe_tokens_dict[key] = []
for i in range(len(pos_score_dict[key])):
if prefix_len is not None and not backwards:
num_bpe_tokens = get_num_bpe_tokens_from_len(
hypo_dict[key][i], bpe_symbol, prefix_len
)
score_dict[key].append(sum(pos_score_dict[key][i][:num_bpe_tokens]))
num_bpe_tokens_dict[key].append(num_bpe_tokens)
elif hypo_frac is not None:
num_words, shortened, hypo_prefix_len = calc_length_from_frac(
hypo_dict[key][i], hypo_frac, bpe_symbol
)
score_dict[key].append(sum(pos_score_dict[key][i][:hypo_prefix_len]))
num_bpe_tokens_dict[key].append(hypo_prefix_len)
else:
score_dict[key].append(sum(pos_score_dict[key][i]))
num_bpe_tokens_dict[key].append(len(pos_score_dict[key][i]))
return score_dict, num_bpe_tokens_dict
class LMOutput(object):
def __init__(
self,
lm_score_file,
lm_dict=None,
prefix_len=None,
bpe_symbol=None,
target_prefix_frac=None,
):
(
lm_sentences,
lm_sen_scores,
lm_sen_pos_scores,
lm_no_bpe_sentences,
lm_bpe_tokens,
) = parse_lm(
lm_score_file,
prefix_len=prefix_len,
bpe_symbol=bpe_symbol,
target_prefix_frac=target_prefix_frac,
)
self.sentences = lm_sentences
self.score = lm_sen_scores
self.pos_score = lm_sen_pos_scores
self.lm_dict = lm_dict
self.no_bpe_sentences = lm_no_bpe_sentences
self.bpe_tokens = lm_bpe_tokens
def parse_lm(input_file, prefix_len=None, bpe_symbol=None, target_prefix_frac=None):
"""parse output of eval_lm"""
with open(input_file, "r") as f:
text = f.readlines()
text = text[7:]
cleaned_text = text[:-2]
sentences = {}
sen_scores = {}
sen_pos_scores = {}
no_bpe_sentences = {}
num_bpe_tokens_dict = {}
for _i, line in enumerate(cleaned_text):
tokens = line.split()
if tokens[0].isdigit():
line_id = int(tokens[0])
scores = [float(x[1:-1]) for x in tokens[2::2]]
sentences[line_id] = " ".join(tokens[1::2][:-1]) + "\n"
if bpe_symbol is not None:
# exclude <eos> symbol to match output from generate.py
bpe_sen = " ".join(tokens[1::2][:-1]) + "\n"
no_bpe_sen = remove_bpe(bpe_sen, bpe_symbol)
no_bpe_sentences[line_id] = no_bpe_sen
if prefix_len is not None:
num_bpe_tokens = get_num_bpe_tokens_from_len(
bpe_sen, bpe_symbol, prefix_len
)
sen_scores[line_id] = sum(scores[:num_bpe_tokens])
num_bpe_tokens_dict[line_id] = num_bpe_tokens
elif target_prefix_frac is not None:
num_words, shortened, target_prefix_len = calc_length_from_frac(
bpe_sen, target_prefix_frac, bpe_symbol
)
sen_scores[line_id] = sum(scores[:target_prefix_len])
num_bpe_tokens_dict[line_id] = target_prefix_len
else:
sen_scores[line_id] = sum(scores)
num_bpe_tokens_dict[line_id] = len(scores)
sen_pos_scores[line_id] = scores
return sentences, sen_scores, sen_pos_scores, no_bpe_sentences, num_bpe_tokens_dict
def get_directories(
data_dir_name,
num_rescore,
gen_subset,
fw_name,
shard_id,
num_shards,
sampling=False,
prefix_len=None,
target_prefix_frac=None,
source_prefix_frac=None,
):
nbest_file_id = (
"nbest_"
+ str(num_rescore)
+ "_subset_"
+ gen_subset
+ "_fw_name_"
+ fw_name
+ "_shard_"
+ str(shard_id)
+ "_of_"
+ str(num_shards)
)
if sampling:
nbest_file_id += "_sampling"
# the directory containing all information for this nbest list
pre_gen = (
os.path.join(os.path.dirname(__file__))
+ "/rerank_data/"
+ data_dir_name
+ "/"
+ nbest_file_id
)
# the directory to store the preprocessed nbest list, for left to right rescoring
left_to_right_preprocessed_dir = pre_gen + "/left_to_right_preprocessed"
if source_prefix_frac is not None:
left_to_right_preprocessed_dir = (
left_to_right_preprocessed_dir + "/prefix_frac" + str(source_prefix_frac)
)
# the directory to store the preprocessed nbest list, for right to left rescoring
right_to_left_preprocessed_dir = pre_gen + "/right_to_left_preprocessed"
# the directory to store the preprocessed nbest list, for backwards rescoring
backwards_preprocessed_dir = pre_gen + "/backwards"
if target_prefix_frac is not None:
backwards_preprocessed_dir = (
backwards_preprocessed_dir + "/prefix_frac" + str(target_prefix_frac)
)
elif prefix_len is not None:
backwards_preprocessed_dir = (
backwards_preprocessed_dir + "/prefix_" + str(prefix_len)
)
# the directory to store the preprocessed nbest list, for rescoring with P(T)
lm_preprocessed_dir = pre_gen + "/lm_preprocessed"
return (
pre_gen,
left_to_right_preprocessed_dir,
right_to_left_preprocessed_dir,
backwards_preprocessed_dir,
lm_preprocessed_dir,
)
def lm_scoring(
preprocess_directory,
bpe_status,
gen_output,
pre_gen,
cur_lm_dict,
cur_lm_name,
cur_language_model,
cur_lm_bpe_code,
batch_size,
lm_score_file,
target_lang,
source_lang,
prefix_len=None,
):
if prefix_len is not None:
assert (
bpe_status == "different"
), "bpe status must be different to use prefix len"
if bpe_status == "no bpe":
# run lm on output without bpe
write_reprocessed(
gen_output.no_bpe_source,
gen_output.no_bpe_hypo,
gen_output.no_bpe_target,
pre_gen + "/rescore_data_no_bpe." + source_lang,
pre_gen + "/rescore_data_no_bpe." + target_lang,
pre_gen + "/reference_file_no_bpe",
)
preprocess_lm_param = [
"--only-source",
"--trainpref",
pre_gen + "/rescore_data_no_bpe." + target_lang,
"--srcdict",
cur_lm_dict,
"--destdir",
preprocess_directory,
]
preprocess_parser = options.get_preprocessing_parser()
input_args = preprocess_parser.parse_args(preprocess_lm_param)
preprocess.main(input_args)
eval_lm_param = [
preprocess_directory,
"--path",
cur_language_model,
"--output-word-probs",
"--batch-size",
str(batch_size),
"--max-tokens",
"1024",
"--sample-break-mode",
"eos",
"--gen-subset",
"train",
]
eval_lm_parser = options.get_eval_lm_parser()
input_args = options.parse_args_and_arch(eval_lm_parser, eval_lm_param)
with open(lm_score_file, "w") as f:
with redirect_stdout(f):
eval_lm.main(input_args)
elif bpe_status == "shared":
preprocess_lm_param = [
"--only-source",
"--trainpref",
pre_gen + "/rescore_data." + target_lang,
"--srcdict",
cur_lm_dict,
"--destdir",
preprocess_directory,
]
preprocess_parser = options.get_preprocessing_parser()
input_args = preprocess_parser.parse_args(preprocess_lm_param)
preprocess.main(input_args)
eval_lm_param = [
preprocess_directory,
"--path",
cur_language_model,
"--output-word-probs",
"--batch-size",
str(batch_size),
"--sample-break-mode",
"eos",
"--gen-subset",
"train",
]
eval_lm_parser = options.get_eval_lm_parser()
input_args = options.parse_args_and_arch(eval_lm_parser, eval_lm_param)
with open(lm_score_file, "w") as f:
with redirect_stdout(f):
eval_lm.main(input_args)
elif bpe_status == "different":
rescore_file = pre_gen + "/rescore_data_no_bpe"
rescore_bpe = pre_gen + "/rescore_data_new_bpe"
rescore_file += "."
rescore_bpe += "."
write_reprocessed(
gen_output.no_bpe_source,
gen_output.no_bpe_hypo,
gen_output.no_bpe_target,
rescore_file + source_lang,
rescore_file + target_lang,
pre_gen + "/reference_file_no_bpe",
bpe_symbol=None,
)
# apply LM bpe to nbest list
bpe_src_param = [
"-c",
cur_lm_bpe_code,
"--input",
rescore_file + target_lang,
"--output",
rescore_bpe + target_lang,
]
subprocess.call(
[
"python",
os.path.join(
os.path.dirname(__file__), "subword-nmt/subword_nmt/apply_bpe.py"
),
]
+ bpe_src_param,
shell=False,
)
# uncomment to use fastbpe instead of subword-nmt bpe
# bpe_src_param = [rescore_bpe+target_lang, rescore_file+target_lang, cur_lm_bpe_code]
# subprocess.call(["/private/home/edunov/fastBPE/fast", "applybpe"] + bpe_src_param, shell=False)
preprocess_dir = preprocess_directory
preprocess_lm_param = [
"--only-source",
"--trainpref",
rescore_bpe + target_lang,
"--srcdict",
cur_lm_dict,
"--destdir",
preprocess_dir,
]
preprocess_parser = options.get_preprocessing_parser()
input_args = preprocess_parser.parse_args(preprocess_lm_param)
preprocess.main(input_args)
eval_lm_param = [
preprocess_dir,
"--path",
cur_language_model,
"--output-word-probs",
"--batch-size",
str(batch_size),
"--max-tokens",
"1024",
"--sample-break-mode",
"eos",
"--gen-subset",
"train",
]
eval_lm_parser = options.get_eval_lm_parser()
input_args = options.parse_args_and_arch(eval_lm_parser, eval_lm_param)
with open(lm_score_file, "w") as f:
with redirect_stdout(f):
eval_lm.main(input_args)
def rescore_file_name(
    nbest_dir,
    prefix_len,
    scorer_name,
    lm_file=False,
    target_prefix_frac=None,
    source_prefix_frac=None,
    backwards=None,
):
    if lm_file:
        score_file = nbest_dir + "/lm_score_translations_model_" + scorer_name + ".txt"
    else:
        score_file = nbest_dir + "/" + scorer_name + "_score_translations.txt"
    if backwards:
        if prefix_len is not None:
            score_file += "prefix_len" + str(prefix_len)
        elif target_prefix_frac is not None:
            score_file += "target_prefix_frac" + str(target_prefix_frac)
    else:
        if source_prefix_frac is not None:
            score_file += "source_prefix_frac" + str(source_prefix_frac)
    return score_file
# Non-autoregressive Neural Machine Translation (NAT)
This page mainly includes instructions for reproducing results from the following papers:
* [Levenshtein Transformer (Gu et al., 2019)](https://arxiv.org/abs/1905.11006).
* [Understanding Knowledge Distillation in Non-autoregressive Machine Translation (Zhou et al., 2019)](https://arxiv.org/abs/1911.02727).
We also provide our own implementations of several popular non-autoregressive models for reference:
* [Non-Autoregressive Neural Machine Translation (Gu et al., 2017)](https://arxiv.org/abs/1711.02281)
* [Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al., 2018)](https://arxiv.org/abs/1802.06901)
* [Insertion Transformer: Flexible Sequence Generation via Insertion Operations (Stern et al., 2019)](https://arxiv.org/abs/1902.03249)
* [Mask-Predict: Parallel Decoding of Conditional Masked Language Models (Ghazvininejad et al., 2019)](https://arxiv.org/abs/1904.09324v2)
* [Fast Structured Decoding for Sequence Models (Sun et al., 2019)](https://arxiv.org/abs/1910.11555)
## Dataset
First, follow the [instructions to download and preprocess the WMT'14 En-De dataset](../translation#wmt14-english-to-german-convolutional).
Make sure to learn a joint vocabulary by passing the `--joined-dictionary` option to `fairseq-preprocess`.
### Knowledge Distillation
As discussed in [Gu et al. 2019](https://arxiv.org/abs/1905.11006), [knowledge distillation](https://arxiv.org/abs/1606.07947) from an autoregressive model can effectively simplify the training data distribution, which is sometimes essential for NAT-based models to learn good translations.
The easiest way to perform distillation is to follow the [instructions for training a standard transformer model](../translation) on the same data, and then decode the training set to produce a distillation dataset for NAT.
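
As a rough sketch of the last step, assume the teacher's `fairseq-generate` output was saved as text with tab-separated `S-<id>` source lines and `H-<id>` hypothesis lines whose second column is the score. The snippet below pairs each source sentence with the first hypothesis printed for it. The function name `build_distillation_pairs` and the sample text are illustrative and not part of fairseq.

```python
import re


def build_distillation_pairs(generate_output):
    """Pair each source sentence with the first hypothesis printed for it."""
    sources, hypos = {}, {}
    for line in generate_output.splitlines():
        match = re.match(r"([SH])-(\d+)\t(.*)", line)
        if match is None:
            continue  # skip every line that is not an S- or H- line
        kind, idx, rest = match.group(1), int(match.group(2)), match.group(3)
        if kind == "S":
            sources[idx] = rest
        elif idx not in hypos:
            # assumed layout: "H-<id>\t<score>\t<text>"; keep only the text
            hypos[idx] = rest.split("\t", 1)[1]
    return [(sources[i], hypos[i]) for i in sorted(sources) if i in hypos]


# Hypothetical teacher output for two sentences, printed out of order.
sample = (
    "S-1\tgood morning\n"
    "H-1\t-0.25\tguten Morgen\n"
    "S-0\thello\n"
    "H-0\t-0.1\thallo\n"
    "P-0\t-0.1 -0.0\n"
)
pairs = build_distillation_pairs(sample)
print(pairs)
```

The resulting source/hypothesis pairs would then be written out as parallel text and binarized like any other training set.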
### Download
We also provide the preprocessed [original](http://dl.fbaipublicfiles.com/nat/original_dataset.zip) and [distillation](http://dl.fbaipublicfiles.com/nat/distill_dataset.zip) datasets. Please build the binarized dataset on your own.
## Train a model
Then we can train a non-autoregressive model using the `translation_lev` task and a new criterion `nat_loss`.
Use the `--noise` flag to specify the input noise used on the target sentences.
By default, we run the task for *Levenshtein Transformer*, with `--noise='random_delete'`. Full scripts to run other models can be found [here](./scripts.md).
The following command will train a *Levenshtein Transformer* on the binarized dataset.
```bash
fairseq-train \
data-bin/wmt14_en_de_distill \
--save-dir checkpoints \
--ddp-backend=no_c10d \
--task translation_lev \
--criterion nat_loss \
--arch levenshtein_transformer \
--noise random_delete \
--share-all-embeddings \
--optimizer adam --adam-betas '(0.9,0.98)' \
--lr 0.0005 --lr-scheduler inverse_sqrt \
--min-lr '1e-09' --warmup-updates 10000 \
--warmup-init-lr '1e-07' --label-smoothing 0.1 \
--dropout 0.3 --weight-decay 0.01 \
--decoder-learned-pos \
--encoder-learned-pos \
--apply-bert-init \
--log-format 'simple' --log-interval 100 \
--fixed-validation-seed 7 \
--max-tokens 8000 \
--save-interval-updates 10000 \
--max-update 300000
```
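
To give a feel for what a deletion-style noise does to a target sentence, here is a minimal sketch that drops a random subset of tokens while always keeping the boundary symbols. It only illustrates the idea and is not fairseq's `random_delete` implementation; the function name, the token strings, and the uniform choice of how many tokens to keep are all assumptions.

```python
import random


def random_delete(tokens, rng):
    """Return a noised copy of `tokens` with a random subset of inner tokens removed."""
    bos, inner, eos = tokens[0], tokens[1:-1], tokens[-1]
    # Assumption: the number of surviving inner tokens is drawn uniformly.
    num_keep = rng.randint(0, len(inner))
    keep_idx = sorted(rng.sample(range(len(inner)), num_keep))
    return [bos] + [inner[i] for i in keep_idx] + [eos]


rng = random.Random(7)
target = ["<s>", "wir", "sehen", "uns", "morgen", "</s>"]
noised = random_delete(target, rng)
print(noised)
```

The noised sequence is always an in-order subsequence of the original target, which is the kind of input an insertion-and-deletion model then has to repair.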
## Translate
Once a model is trained, we can generate translations using an `iterative_refinement_generator`, which starts from the model's initial output and iteratively reads and greedily refines the translation until (1) the model predicts the same translation for two consecutive iterations; or (2) the generator reaches the maximum number of iterations (`--iter-decode-max-iter`). Use `--print-step` to check the actual number of iterations for each sentence.
For *Levenshtein Transformer*, it sometimes helps to apply an `--iter-decode-eos-penalty` (typically 0 to 3) to penalize the model for finishing generation too early and producing translations that are too short.
For example, to generate with `--iter-decode-max-iter=9`:
```bash
fairseq-generate \
data-bin/wmt14_en_de_distill \
--gen-subset test \
--task translation_lev \
--path checkpoints/checkpoint_best.pt \
--iter-decode-max-iter 9 \
--iter-decode-eos-penalty 0 \
--beam 1 --remove-bpe \
--print-step \
--batch-size 400
```
At the end of generation, the tokenized BLEU score for the translation is printed.
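
The two stopping conditions of iterative refinement can be sketched as a plain loop. This is a toy illustration under stated assumptions, not the `iterative_refinement_generator` itself: `iterative_refine` and `toy_step` are hypothetical names, and the way steps are counted here may differ from what `--print-step` reports.

```python
def iterative_refine(initial, refine_step, max_iter):
    """Apply `refine_step` until the output stops changing or `max_iter` is reached."""
    current = initial
    for step in range(1, max_iter + 1):
        refined = refine_step(current)
        if refined == current:
            return current, step  # same output for two consecutive iterations
        current = refined
    return current, max_iter  # iteration budget exhausted


def toy_step(tokens):
    # Hypothetical refinement: resolve the first placeholder token, if any.
    fixes = {"<unk1>": "guten", "<unk2>": "Morgen"}
    for i, tok in enumerate(tokens):
        if tok in fixes:
            return tokens[:i] + [fixes[tok]] + tokens[i + 1:]
    return tokens


result, steps = iterative_refine(["<unk1>", "<unk2>"], toy_step, max_iter=9)
print(result, steps)
```

With a smaller budget such as `max_iter=1`, the loop returns a partially refined output, which is why raising `--iter-decode-max-iter` can trade speed for quality.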
## Advanced Decoding Methods
### Ensemble
The NAT models use special implementations of [ensembling](https://github.com/fairinternal/fairseq-py/blob/b98d88da52f2f21f1b169bab8c70c1c4ca19a768/fairseq/sequence_generator.py#L522) to support iterative refinement and a variety of parallel operations in different models, while sharing the same API as standard autoregressive models:
```bash
fairseq-generate \
data-bin/wmt14_en_de_distill \
--gen-subset test \
--task translation_lev \
--path checkpoint_1.pt:checkpoint_2.pt:checkpoint_3.pt \
--iter-decode-max-iter 9 \
--iter-decode-eos-penalty 0 \
--beam 1 --remove-bpe \
--print-step \
--batch-size 400
```
Use ``:`` to separate multiple model checkpoints. Note that not all NAT models currently support ensembling.
### Length-beam
For models that predict lengths before decoding (e.g. the vanilla NAT, Mask-Predict, etc.), it is possible to improve translation quality by varying the target length around the predicted value and translating the same example multiple times in parallel. The translation with the highest score, as defined by the model's output, is then selected.
Note that not all models support length beams. For models that change the length dynamically (e.g. *Insertion Transformer*, *Levenshtein Transformer*), the same trick does not apply.
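A minimal sketch of the length-beam idea, assuming candidate lengths are centred on the predicted length. The helpers `length_candidates`, `best_by_length_beam`, `toy_decode` and `toy_score` are hypothetical; a real model would decode all candidates in one batch rather than in a Python loop.

```python
def length_candidates(predicted_len, beam=5, min_len=1):
    """Target lengths centred on the predicted one (hypothetical helper)."""
    half = beam // 2
    lengths = [predicted_len + d for d in range(-half, beam - half)]
    return [max(min_len, n) for n in lengths]


def best_by_length_beam(predicted_len, decode, score, beam=5):
    """Decode once per candidate length and keep the highest-scoring output."""
    candidates = [decode(n) for n in length_candidates(predicted_len, beam)]
    return max(candidates, key=score)


def toy_decode(n):
    # Toy decoder: emits exactly n tokens.
    return ["tok"] * n


def toy_score(hyp):
    # Toy score: prefers outputs whose length is close to 7.
    return -abs(len(hyp) - 7)


print(length_candidates(6, beam=5))
print(len(best_by_length_beam(6, toy_decode, toy_score)))
```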
### Re-ranking
If the model generates multiple translations with length beam, we can also use an autoregressive model to rerank them, since scoring a translation with an autoregressive model is much faster than decoding from it.
For example, to generate translations with length beam and reranking:
```bash
fairseq-generate \
data-bin/wmt14_en_de_distill \
--gen-subset test \
--task translation_lev \
--path checkpoints/checkpoint_best.pt:at_checkpoints/checkpoint_best.pt \
--iter-decode-max-iter 9 \
--iter-decode-eos-penalty 0 \
--iter-decode-with-beam 9 \
--iter-decode-with-external-reranker \
--beam 1 --remove-bpe \
--print-step \
--batch-size 100
```
Note that we need to make sure the autoregressive model shares the same vocabulary as our target non-autoregressive model.
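Reranking itself amounts to scoring each finished candidate with the external model and keeping the best one. The sketch below shows one possible way to do that. `rerank` and `toy_ar_score` are hypothetical, and using the non-autoregressive score only as a tie-breaker is an assumption of this example rather than a description of what fairseq does.

```python
def rerank(candidates, nat_scores, ar_score):
    """Pick the candidate the external scorer likes best.

    `ar_score` is a hypothetical callable standing in for an autoregressive
    model scoring a finished translation; `nat_scores` only breaks ties.
    """
    pairs = zip(candidates, nat_scores)
    best, _ = max(pairs, key=lambda p: (ar_score(p[0]), p[1]))
    return best


def toy_ar_score(hyp):
    # Toy scorer: prefers three-token outputs.
    return -abs(len(hyp.split()) - 3)


candidates = ["a b", "a b c", "a b c d"]
nat_scores = [-1.0, -1.2, -0.9]
print(rerank(candidates, nat_scores, toy_ar_score))
```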
## Citation
```bibtex
@incollection{NIPS2019_9297,
title = {Levenshtein Transformer},
author = {Gu, Jiatao and Wang, Changhan and Zhao, Junbo},
booktitle = {Advances in Neural Information Processing Systems 32},
editor = {H. Wallach and H. Larochelle and A. Beygelzimer and F. d\textquotesingle Alch\'{e}-Buc and E. Fox and R. Garnett},
pages = {11179--11189},
year = {2019},
publisher = {Curran Associates, Inc.},
url = {http://papers.nips.cc/paper/9297-levenshtein-transformer.pdf}
}
```
```bibtex
@article{zhou2019understanding,
title={Understanding Knowledge Distillation in Non-autoregressive Machine Translation},
author={Zhou, Chunting and Neubig, Graham and Gu, Jiatao},
journal={arXiv preprint arXiv:1911.02727},
year={2019}
}
```
# Examples of Training scripts for Non-autoregressive Machine Translation models
### Non-autoregressive Transformer (NAT, Gu et al., 2017)
Note that we need to have an additional module to perform "length prediction" (`--length-loss-factor`) before generating the whole sequence.
```bash
fairseq-train \
data-bin/wmt14_en_de_distill \
--save-dir checkpoints \
--ddp-backend=no_c10d \
--task translation_lev \
--criterion nat_loss \
--arch nonautoregressive_transformer \
--noise full_mask \
--share-all-embeddings \
--optimizer adam --adam-betas '(0.9,0.98)' \
--lr 0.0005 --lr-scheduler inverse_sqrt \
--min-lr '1e-09' --warmup-updates 10000 \
--warmup-init-lr '1e-07' --label-smoothing 0.1 \
--dropout 0.3 --weight-decay 0.01 \
--decoder-learned-pos \
--encoder-learned-pos \
--pred-length-offset \
--length-loss-factor 0.1 \
--apply-bert-init \
--log-format 'simple' --log-interval 100 \
--fixed-validation-seed 7 \
--max-tokens 8000 \
--save-interval-updates 10000 \
--max-update 300000
```
### Fast Structured Decoding for Sequence Models (NAT-CRF, Sun et al., 2019)
Note that we implement a low-rank approximated CRF model by setting `--crf-lowrank-approx=32` and `--crf-beam-approx=64` as described in the original paper. All other settings are the same as for the vanilla NAT model.
```bash
fairseq-train \
data-bin/wmt14_en_de_distill \
--save-dir checkpoints \
--ddp-backend=no_c10d \
--task translation_lev \
--criterion nat_loss \
--arch nacrf_transformer \
--noise full_mask \
--share-all-embeddings \
--optimizer adam --adam-betas '(0.9,0.98)' \
--lr 0.0005 --lr-scheduler inverse_sqrt \
--min-lr '1e-09' --warmup-updates 10000 \
--warmup-init-lr '1e-07' --label-smoothing 0.1 \
--dropout 0.3 --weight-decay 0.01 \
--decoder-learned-pos \
--encoder-learned-pos \
--pred-length-offset \
--length-loss-factor 0.1 \
--word-ins-loss-factor 0.5 \
--crf-lowrank-approx 32 \
--crf-beam-approx 64 \
--apply-bert-init \
--log-format 'simple' --log-interval 100 \
--fixed-validation-seed 7 \
--max-tokens 8000 \
--save-interval-updates 10000 \
--max-update 300000
```
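To see why a low-rank approximation matters here, compare the number of transition parameters in a full vocabulary-by-vocabulary matrix with a rank-32 factorization into two thin matrices. This is back-of-the-envelope arithmetic only: the vocabulary size below is a hypothetical value and `crf_transition_params` is not a fairseq function.

```python
def crf_transition_params(vocab_size, rank=None):
    """Number of transition parameters for a full or low-rank CRF."""
    if rank is None:
        return vocab_size * vocab_size  # full V x V matrix
    return 2 * vocab_size * rank  # two V x r factor matrices


vocab_size = 40000  # hypothetical vocabulary size, for illustration only
full = crf_transition_params(vocab_size)
low_rank = crf_transition_params(vocab_size, rank=32)
print(full, low_rank, full // low_rank)
```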
### Non-autoregressive Transformer with Iterative Refinement (iNAT, Lee et al., 2018)
Note that `--train-step` sets how many refinement iterations are used during training, and `--dae-ratio` controls the ratio of denoising auto-encoder training described in the original paper.
```bash
fairseq-train \
data-bin/wmt14_en_de_distill \
--save-dir checkpoints \
--ddp-backend=no_c10d \
--task translation_lev \
--criterion nat_loss \
--arch iterative_nonautoregressive_transformer \
--noise full_mask \
--share-all-embeddings \
--optimizer adam --adam-betas '(0.9,0.98)' \
--lr 0.0005 --lr-scheduler inverse_sqrt \
--min-lr '1e-09' --warmup-updates 10000 \
--warmup-init-lr '1e-07' --label-smoothing 0.1 \
--dropout 0.3 --weight-decay 0.01 \
--decoder-learned-pos \
--encoder-learned-pos \
--pred-length-offset \
--length-loss-factor 0.1 \
--train-step 4 \
--dae-ratio 0.5 \
--stochastic-approx \
--apply-bert-init \
--log-format 'simple' --log-interval 100 \
--fixed-validation-seed 7 \
--max-tokens 8000 \
--save-interval-updates 10000 \
--max-update 300000
```
### Insertion Transformer (InsT, Stern et al., 2019)
Note that we need to specify the "slot-loss" (uniform or balanced tree) described in the original paper. Here we use `--label-tau` to control the temperature.
```bash
fairseq-train \
data-bin/wmt14_en_de_distill \
--save-dir checkpoints \
--ddp-backend=no_c10d \
--task translation_lev \
--criterion nat_loss \
--arch insertion_transformer \
--noise random_delete \
--share-all-embeddings \
--optimizer adam --adam-betas '(0.9,0.98)' \
--lr 0.0005 --lr-scheduler inverse_sqrt \
--min-lr '1e-09' --warmup-updates 10000 \
--warmup-init-lr '1e-07' --label-smoothing 0.1 \
--dropout 0.3 --weight-decay 0.01 \
--decoder-learned-pos \
--encoder-learned-pos \
--apply-bert-init \
--log-format 'simple' --log-interval 100 \
--fixed-validation-seed 7 \
--max-tokens 8000 \
--save-interval-updates 10000 \
--max-update 300000
```
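The effect of the temperature on the slot loss can be illustrated with a small sketch. It assumes the balanced-tree weighting is a softmax over the negative distance from the centre of the missing span, divided by the temperature; this is a reading of the idea in the paper, not a copy of fairseq's code, and `slot_weights` is a hypothetical helper.

```python
import math


def slot_weights(span_len, tau=None):
    """Weights over the positions of a missing span (illustrative only).

    tau=None gives the uniform loss; a finite tau gives the balanced-tree
    style weighting, which favours tokens near the centre of the span.
    """
    if tau is None:
        return [1.0 / span_len] * span_len
    centre = (span_len - 1) / 2
    logits = [-abs(i - centre) / tau for i in range(span_len)]
    total = sum(math.exp(x) for x in logits)
    return [math.exp(x) / total for x in logits]


print(slot_weights(5))  # uniform
print(slot_weights(5, tau=1.0))  # peaked at the centre token
```

A large temperature flattens the weights towards the uniform case, while a small one concentrates them on the centre token.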
### Mask Predict (CMLM, Ghazvininejad et al., 2019)
```bash
fairseq-train \
data-bin/wmt14_en_de_distill \
--save-dir checkpoints \
--ddp-backend=no_c10d \
--task translation_lev \
--criterion nat_loss \
--arch cmlm_transformer \
--noise random_mask \
--share-all-embeddings \
--optimizer adam --adam-betas '(0.9,0.98)' \
--lr 0.0005 --lr-scheduler inverse_sqrt \
--min-lr '1e-09' --warmup-updates 10000 \
--warmup-init-lr '1e-07' --label-smoothing 0.1 \
--dropout 0.3 --weight-decay 0.01 \
--decoder-learned-pos \
--encoder-learned-pos \
--apply-bert-init \
--log-format 'simple' --log-interval 100 \
--fixed-validation-seed 7 \
--max-tokens 8000 \
--save-interval-updates 10000 \
--max-update 300000
```
### Levenshtein Transformer (LevT, Gu et al., 2019)
```bash
fairseq-train \
data-bin/wmt14_en_de_distill \
--save-dir checkpoints \
--ddp-backend=no_c10d \
--task translation_lev \
--criterion nat_loss \
--arch levenshtein_transformer \
--noise random_delete \
--share-all-embeddings \
--optimizer adam --adam-betas '(0.9,0.98)' \
--lr 0.0005 --lr-scheduler inverse_sqrt \
--min-lr '1e-09' --warmup-updates 10000 \
--warmup-init-lr '1e-07' --label-smoothing 0.1 \
--dropout 0.3 --weight-decay 0.01 \
--decoder-learned-pos \
--encoder-learned-pos \
--apply-bert-init \
--log-format 'simple' --log-interval 100 \
--fixed-validation-seed 7 \
--max-tokens 8000 \
--save-interval-updates 10000 \
--max-update 300000
```
# Paraphrasing with round-trip translation and mixture of experts
Machine translation models can be used to paraphrase text by translating it to
an intermediate language and back (round-trip translation).
This example shows how to paraphrase text by first passing it to an
English-French translation model, followed by a French-English [mixture of
experts translation model](/examples/translation_moe).
##### 0. Setup
Clone fairseq from source and install necessary dependencies:
```bash
git clone https://github.com/pytorch/fairseq.git
cd fairseq
pip install --editable .
pip install sacremoses sentencepiece
```
##### 1. Download models
```bash
wget https://dl.fbaipublicfiles.com/fairseq/models/paraphraser.en-fr.tar.gz
wget https://dl.fbaipublicfiles.com/fairseq/models/paraphraser.fr-en.hMoEup.tar.gz
tar -xzvf paraphraser.en-fr.tar.gz
tar -xzvf paraphraser.fr-en.hMoEup.tar.gz
```
##### 2. Paraphrase
```bash
python examples/paraphraser/paraphrase.py \
--en2fr paraphraser.en-fr \
--fr2en paraphraser.fr-en.hMoEup
# Example input:
# The new date for the Games, postponed for a year in response to the coronavirus pandemic, gives athletes time to recalibrate their training schedules.
# Example outputs:
# Delayed one year in response to the coronavirus pandemic, the new date of the Games gives athletes time to rebalance their training schedule.
# The new date of the Games, which was rescheduled one year in response to the coronavirus (CV) pandemic, gives athletes time to rebalance their training schedule.
# The new date of the Games, postponed one year in response to the coronavirus pandemic, provides athletes with time to rebalance their training schedule.
# The Games' new date, postponed one year in response to the coronavirus pandemic, gives athletes time to rebalance their training schedule.
# The new Games date, postponed one year in response to the coronavirus pandemic, gives the athletes time to rebalance their training schedule.
# The new date of the Games, which was postponed one year in response to the coronavirus pandemic, gives the athletes time to rebalance their training schedule.
# The new date of the Games, postponed one year in response to the coronavirus pandemic, gives athletes time to rebalance their training schedule.
# The new date of the Games, postponed one year in response to the coronavirus pandemic, gives athletes time to re-balance their training schedule.
# The new date of the Games, postponed one year in response to the coronavirus pandemic, gives the athletes time to rebalance their schedule of training.
# The new date of the Games, postponed one year in response to the pandemic of coronavirus, gives the athletes time to rebalance their training schedule.
```
#!/usr/bin/env python3 -u

import argparse
import fileinput
import logging
import os
import sys

from fairseq.models.transformer import TransformerModel


logging.getLogger().setLevel(logging.INFO)


def main():
    parser = argparse.ArgumentParser(description="")
    parser.add_argument("--en2fr", required=True, help="path to en2fr model")
    parser.add_argument(
        "--fr2en", required=True, help="path to fr2en mixture of experts model"
    )
    parser.add_argument(
        "--user-dir", help="path to fairseq examples/translation_moe/src directory"
    )
    parser.add_argument(
        "--num-experts",
        type=int,
        default=10,
        help="(keep at 10 unless using a different model)",
    )
    parser.add_argument(
        "files",
        nargs="*",
        default=["-"],
        help='input files to paraphrase; "-" for stdin',
    )
    args = parser.parse_args()

    if args.user_dir is None:
        args.user_dir = os.path.join(
            os.path.dirname(os.path.dirname(os.path.abspath(__file__))),  # examples/
            "translation_moe",
            "src",
        )
        if os.path.exists(args.user_dir):
            logging.info("found user_dir:" + args.user_dir)
        else:
            raise RuntimeError(
                "cannot find fairseq examples/translation_moe/src "
                "(tried looking here: {})".format(args.user_dir)
            )

    logging.info("loading en2fr model from:" + args.en2fr)
    en2fr = TransformerModel.from_pretrained(
        model_name_or_path=args.en2fr,
        tokenizer="moses",
        bpe="sentencepiece",
    ).eval()

    logging.info("loading fr2en model from:" + args.fr2en)
    fr2en = TransformerModel.from_pretrained(
        model_name_or_path=args.fr2en,
        tokenizer="moses",
        bpe="sentencepiece",
        user_dir=args.user_dir,
        task="translation_moe",
    ).eval()

    def gen_paraphrases(en):
        fr = en2fr.translate(en)
        return [
            fr2en.translate(fr, inference_step_args={"expert": i})
            for i in range(args.num_experts)
        ]

    logging.info("Type the input sentence and press return:")
    for line in fileinput.input(args.files):
        line = line.strip()
        if len(line) == 0:
            continue
        for paraphrase in gen_paraphrases(line):
            print(paraphrase)


if __name__ == "__main__":
    main()
# Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)
This page contains pointers to pre-trained models as well as instructions on how to train new models for [our paper](https://arxiv.org/abs/1901.10430).
## Citation:
```bibtex
@inproceedings{wu2018pay,
title = {Pay Less Attention with Lightweight and Dynamic Convolutions},
author = {Felix Wu and Angela Fan and Alexei Baevski and Yann Dauphin and Michael Auli},
booktitle = {International Conference on Learning Representations},
year = {2019},
url = {https://arxiv.org/abs/1901.10430},
}
```
## Translation
### Pre-trained models
For some datasets we release models without GLUs which are faster at inference.
Model | Description | Dataset | Download
---|---|---|---
`lightconv.no_glu.iwslt14.de-en` | LightConv (without GLUs) | [IWSLT14 German-English](https://wit3.fbk.eu/archive/2014-01/texts/de/en/de-en.tgz) | model: <br> [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/iwslt14.de-en.lightconv.tar.gz) <br> IWSLT14 test: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/iwslt14.de-en.test.tar.bz2)
`dynamicconv.no_glu.iwslt14.de-en` | DynamicConv (without GLUs) | [IWSLT14 German-English](https://wit3.fbk.eu/archive/2014-01/texts/de/en/de-en.tgz) | model: <br> [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/iwslt14.de-en.dynamicconv.tar.gz) <br> IWSLT14 test: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/iwslt14.de-en.test.tar.bz2)
`lightconv.no_glu.wmt16.en-de` | LightConv (without GLUs) | [WMT16 English-German](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8) | model: <br> [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt16.en-de.joined-dict.lightconv.tar.gz) <br> newstest2014 (shared vocab): <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt16.en-de.joined-dict.newstest2014.tar.bz2)
`dynamicconv.no_glu.wmt16.en-de` | DynamicConv (without GLUs) | [WMT16 English-German](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8) | model: <br> [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt16.en-de.joined-dict.dynamicconv.tar.gz) <br> newstest2014 (shared vocab): <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt16.en-de.joined-dict.newstest2014.tar.bz2)
`lightconv.glu.wmt16.en-de` | LightConv | [WMT16 English-German](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8) | model: <br> [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt16.en-de.joined-dict.lightconv-glu.tar.gz) <br> newstest2014 (shared vocab): <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt16.en-de.joined-dict.newstest2014.tar.bz2)
`dynamicconv.glu.wmt16.en-de` | DynamicConv | [WMT16 English-German](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8) | model: <br> [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt16.en-de.joined-dict.dynamicconv-glu.tar.gz) <br> newstest2014 (shared vocab): <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt16.en-de.joined-dict.newstest2014.tar.bz2)
`lightconv.glu.wmt14.en-fr` | LightConv | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | model: <br> [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt14.en-fr.joined-dict.lightconv-glu.tar.gz) <br> newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.en-fr.joined-dict.newstest2014.tar.bz2)
`dynamicconv.glu.wmt14.en-fr` | DynamicConv | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | model: <br> [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt14.en-fr.joined-dict.dynamicconv-glu.tar.gz) <br> newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.en-fr.joined-dict.newstest2014.tar.bz2)
`lightconv.glu.wmt17.zh-en` | LightConv | [WMT17 Chinese-English](http://statmt.org/wmt17/translation-task.html#Download) | model: <br> [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt17.zh-en.lightconv-glu.tar.gz) <br> newstest2017: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt17.zh-en.newstest2017.tar.bz2)
`dynamicconv.glu.wmt17.zh-en` | DynamicConv | [WMT17 Chinese-English](http://statmt.org/wmt17/translation-task.html#Download) | model: <br> [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt17.zh-en.dynamicconv-glu.tar.gz) <br> newstest2017: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt17.zh-en.newstest2017.tar.bz2)
### Memory-Efficient CUDA Kernels
Since the PyTorch implementations of Light/Dynamic conv are quite memory intensive, we have developed CUDA kernels that implement the light and dynamic convolution operator in a memory-efficient and performant manner. For large sequence lengths, these kernels save about 50% memory compared to the PyTorch equivalent.
To install the kernels, use the commands below. Once installed, they will automatically be used in place of the PyTorch implementations whenever a light or dynamic convolution is used.
```sh
# to install lightconv
cd fairseq/modules/lightconv_layer
python cuda_function_gen.py
python setup.py install
# to install dynamicconv
cd fairseq/modules/dynamicconv_layer
python cuda_function_gen.py
python setup.py install
```
### Example usage (torch.hub)
We require a few additional Python dependencies for preprocessing:
```bash
pip install sacremoses subword_nmt
```
Interactive translation via PyTorch Hub:
```python
import torch
from fairseq.models.lightconv import LightConvModel
# List available models
torch.hub.list('pytorch/fairseq') # [..., 'lightconv.glu.wmt17.zh-en', ... ]
# Load a LightConv model trained on WMT'17 Zh-En
zh2en = torch.hub.load('pytorch/fairseq', 'lightconv.glu.wmt17.zh-en', tokenizer='moses', bpe='subword_nmt')
# The underlying model is available under the *models* attribute
assert isinstance(zh2en.models[0], LightConvModel)
# Translate a sentence
zh2en.translate('你好 世界')
# 'Hello World'
```
Loading custom models:
```python
from fairseq.models.lightconv import LightConvModel
en2fr = LightConvModel.from_pretrained(
'/path/to/checkpoints',
checkpoint_file='checkpoint_best.pt',
data_name_or_path='data-bin/wmt14_en_fr',
bpe='subword_nmt',
bpe_codes='data-bin/wmt14_en_fr/en.code'
)
en2fr.translate('Hello world!')
# 'Bonjour le monde'
```
### Preprocessing the training datasets
Please follow the instructions in [`examples/translation/README.md`](../translation/README.md) to preprocess the data.
### Training and evaluation options:
To use the model without GLU, please set `--encoder-glu 0 --decoder-glu 0`.
For LightConv, please use `--encoder-conv-type lightweight --decoder-conv-type lightweight`, otherwise the default is DynamicConv.
For best BLEU results, `--lenpen` may need to be tuned manually.
To use the CUDA kernels, first install the PyTorch modules using the commands
above. Once the CUDA modules are installed, they will automatically be used
instead of the PyTorch modules.
### IWSLT14 De-En
Training and evaluating DynamicConv (without GLU) on a GPU:
```sh
# Training
SAVE="save/dynamic_conv_iwslt"
mkdir -p $SAVE
CUDA_VISIBLE_DEVICES=0 $(which fairseq-train) data-bin/iwslt14.tokenized.de-en \
--clip-norm 0 --optimizer adam --lr 0.0005 \
--source-lang de --target-lang en --max-tokens 4000 --no-progress-bar \
--log-interval 100 --min-lr '1e-09' --weight-decay 0.0001 \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--lr-scheduler inverse_sqrt \
--ddp-backend=no_c10d \
--max-update 50000 --warmup-updates 4000 --warmup-init-lr '1e-07' \
--adam-betas '(0.9, 0.98)' --keep-last-epochs 10 \
-a lightconv_iwslt_de_en --save-dir $SAVE \
--dropout 0.3 --attention-dropout 0.1 --weight-dropout 0.1 \
--encoder-glu 0 --decoder-glu 0
python scripts/average_checkpoints.py --inputs $SAVE \
--num-epoch-checkpoints 10 --output "${SAVE}/checkpoint_last10_avg.pt"
# Evaluation
CUDA_VISIBLE_DEVICES=0 fairseq-generate data-bin/iwslt14.tokenized.de-en --path "${SAVE}/checkpoint_last10_avg.pt" --batch-size 128 --beam 4 --remove-bpe --lenpen 1 --gen-subset test --quiet
```
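The `scripts/average_checkpoints.py` call above combines the last 10 epoch checkpoints into a single model. Conceptually this is an element-wise mean over the saved parameters. The toy sketch below shows the idea on plain lists of floats instead of tensors; `average_checkpoints` here is a hypothetical function and does not reproduce the script's actual code.

```python
def average_checkpoints(states):
    """Element-wise mean of several {parameter name: list of floats} dicts."""
    names = states[0].keys()
    if any(s.keys() != names for s in states):
        raise KeyError("all checkpoints must hold the same parameters")
    averaged = {}
    for name in names:
        columns = zip(*(s[name] for s in states))
        averaged[name] = [sum(col) / len(states) for col in columns]
    return averaged


ckpts = [
    {"w": [1.0, 2.0], "b": [0.0]},
    {"w": [3.0, 4.0], "b": [1.0]},
]
print(average_checkpoints(ckpts))
```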
### WMT16 En-De
Training and evaluating DynamicConv (with GLU) on WMT16 En-De using cosine scheduler on one machine with 8 V100 GPUs:
```sh
# Training
SAVE="save/dynamic_conv_wmt16en2de"
mkdir -p $SAVE
python -m torch.distributed.launch --nproc_per_node 8 $(which fairseq-train) \
data-bin/wmt16_en_de_bpe32k --fp16 --log-interval 100 --no-progress-bar \
--max-update 30000 --share-all-embeddings --optimizer adam \
--adam-betas '(0.9, 0.98)' --clip-norm 0.0 --weight-decay 0.0 \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--min-lr 1e-09 --update-freq 16 --attention-dropout 0.1 --keep-last-epochs 10 \
--ddp-backend=no_c10d --max-tokens 3584 \
--lr-scheduler cosine --warmup-init-lr 1e-7 --warmup-updates 10000 \
--lr-shrink 1 --max-lr 0.001 --lr 1e-7 --min-lr 1e-9 --warmup-init-lr 1e-07 \
--t-mult 1 --lr-period-updates 20000 \
--arch lightconv_wmt_en_de_big --save-dir $SAVE \
--dropout 0.3 --attention-dropout 0.1 --weight-dropout 0.1 \
--encoder-glu 1 --decoder-glu 1
# Evaluation
CUDA_VISIBLE_DEVICES=0 fairseq-generate data-bin/wmt16.en-de.joined-dict.newstest2014 --path "${SAVE}/checkpoint_best.pt" --batch-size 128 --beam 5 --remove-bpe --lenpen 0.5 --gen-subset test > wmt16_gen.txt
bash scripts/compound_split_bleu.sh wmt16_gen.txt
```
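The WMT16 training command above combines 8 GPUs (`--nproc_per_node 8`), `--max-tokens 3584` and `--update-freq 16`. The short calculation below gives the resulting upper bound on tokens per optimizer step, which is useful when adapting the recipe to a different number of GPUs. `tokens_per_update` is a hypothetical helper, and because `--max-tokens` is a cap on each batch, real batches may contain fewer tokens.

```python
def tokens_per_update(num_gpus, max_tokens, update_freq):
    """Upper bound on the tokens contributing to one optimizer step."""
    return num_gpus * max_tokens * update_freq


# Settings from the 8-GPU command.
print(tokens_per_update(num_gpus=8, max_tokens=3584, update_freq=16))

# With 2 GPUs, raising update_freq to 64 keeps the same bound.
print(tokens_per_update(num_gpus=2, max_tokens=3584, update_freq=64))
```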
### WMT14 En-Fr
Training DynamicConv (with GLU) on WMT14 En-Fr using cosine scheduler on one machine with 8 V100 GPUs:
```sh
# Training
SAVE="save/dynamic_conv_wmt14en2fr"
mkdir -p $SAVE
python -m torch.distributed.launch --nproc_per_node 8 $(which fairseq-train) \
data-bin/wmt14_en_fr --fp16 --log-interval 100 --no-progress-bar \
--max-update 30000 --share-all-embeddings --optimizer adam \
--adam-betas '(0.9, 0.98)' --clip-norm 0.0 --weight-decay 0.0 \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--min-lr 1e-09 --update-freq 16 --attention-dropout 0.1 --keep-last-epochs 10 \
--ddp-backend=no_c10d --max-tokens 3584 \
--lr-scheduler cosine --warmup-init-lr 1e-7 --warmup-updates 10000 \
--lr-shrink 1 --max-lr 0.001 --lr 1e-7 --min-lr 1e-9 --warmup-init-lr 1e-07 \
--t-mult 1 --lr-period-updates 70000 \
--arch lightconv_wmt_en_fr_big --save-dir $SAVE \
--dropout 0.1 --attention-dropout 0.1 --weight-dropout 0.1 \
--encoder-glu 1 --decoder-glu 1
# Evaluation
CUDA_VISIBLE_DEVICES=0 fairseq-generate data-bin/wmt14.en-fr.joined-dict.newstest2014 --path "${SAVE}/checkpoint_best.pt" --batch-size 128 --beam 5 --remove-bpe --lenpen 0.9 --gen-subset test
```
# Transformer with Pointer-Generator Network
This page describes the `transformer_pointer_generator` model that incorporates
a pointing mechanism in the Transformer model that facilitates copying of input
words to the output. This architecture is described in [Enarvi et al. (2020)](https://www.aclweb.org/anthology/2020.nlpmc-1.4/).
## Background
The pointer-generator network was introduced in [See et al. (2017)](https://arxiv.org/abs/1704.04368)
for RNN encoder-decoder attention models. A similar mechanism can be
incorporated in a Transformer model by reusing one of the many attention
distributions for pointing. The attention distribution over the input words is
interpolated with the normal output distribution over the vocabulary words. This
allows the model to generate words that appear in the input, even if they don't
appear in the vocabulary, helping especially with small vocabularies.
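The interpolation described above can be sketched on plain Python lists. This is an illustration of the general pointer-generator idea, not the code of the fairseq model: `pointer_generator_dist` and its argument names are hypothetical, and `p_gen` denotes the probability of generating from the vocabulary rather than copying.

```python
def pointer_generator_dist(p_vocab, attention, source_ids, p_gen):
    """Interpolate generation and copy distributions.

    p_vocab:    probabilities over the vocabulary (sums to 1)
    attention:  attention weights over source positions (sums to 1)
    source_ids: vocabulary index of the token at each source position
    p_gen:      probability of generating rather than copying, in [0, 1]
    """
    final = [p_gen * p for p in p_vocab]
    for weight, token_id in zip(attention, source_ids):
        # Copy mass is added to whichever index the source token occupies.
        final[token_id] += (1.0 - p_gen) * weight
    return final


p_vocab = [0.1, 0.6, 0.3, 0.0]  # index 3 is never generated
attention = [0.2, 0.8]  # two source positions
source_ids = [1, 3]  # the second source token has index 3
print(pointer_generator_dist(p_vocab, attention, source_ids, p_gen=0.5))
```

Index 3 receives probability only through the copy path, which is how a word the generator never produces can still appear in the output.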
## Implementation
The mechanism for copying out-of-vocabulary words from the input has been
implemented differently to See et al. In their [implementation](https://github.com/abisee/pointer-generator)
they convey the word identities through the model in order to be able to produce
words that appear in the input sequence but not in the vocabulary. A different
approach was taken in the Fairseq implementation to keep it self-contained in
the model file, avoiding any changes to the rest of the code base. Copying
out-of-vocabulary words is possible by pre-processing the input and
post-processing the output. This is described in detail in the next section.
## Usage
The training and evaluation procedure is outlined below. You can also find a
more detailed example for the XSum dataset on [this page](README.xsum.md).
##### 1. Create a vocabulary and extend it with source position markers
The pointing mechanism is especially helpful with small vocabularies, if we are
able to recover the identities of any out-of-vocabulary words that are copied
from the input. For this purpose, the model allows extending the vocabulary with
special tokens that can be used in place of `<unk>` tokens to identify different
input positions. For example, the user may add `<unk-0>`, `<unk-1>`, `<unk-2>`,
etc. to the end of the vocabulary, after the normal words. Below is an example
of how to create a vocabulary of 10000 most common words and add 1000 input
position markers.
```bash
vocab_size=10000
position_markers=1000
export LC_ALL=C
cat train.src train.tgt |
tr -s '[:space:]' '\n' |
sort |
uniq -c |
sort -k1,1bnr -k2 |
head -n "$((vocab_size - 4))" |
awk '{ print $2 " " $1 }' >dict.pg.txt
python3 -c "[print('<unk-{}> 0'.format(n)) for n in range($position_markers)]" >>dict.pg.txt
```
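For readers who prefer Python over the shell pipeline, the same dictionary can be approximated as follows. `build_pg_dict` is a hypothetical helper: it mirrors the pipeline's logic (count words, sort by descending frequency, keep `vocab_size - 4` entries, append the markers), but whitespace splitting and tie-breaking may differ slightly from `tr` and `sort` on unusual input.

```python
from collections import Counter


def build_pg_dict(lines, vocab_size, position_markers):
    """Return dictionary lines: frequent words, then <unk-N> markers."""
    counts = Counter(word for line in lines for word in line.split())
    # Most frequent first; ties broken alphabetically.
    ranked = sorted(counts.items(), key=lambda item: (-item[1], item[0]))
    # Four entries are left free, as in the shell pipeline (vocab_size - 4).
    entries = [f"{word} {count}" for word, count in ranked[: vocab_size - 4]]
    entries += [f"<unk-{n}> 0" for n in range(position_markers)]
    return entries


toy_corpus = ["the cat sat", "the dog sat", "the end"]
print(build_pg_dict(toy_corpus, vocab_size=7, position_markers=2))
```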
##### 2. Preprocess the text data
The idea is that any `<unk>` token in the text is replaced with `<unk-0>` if
it appears in the first input position, `<unk-1>` if it appears in the second
input position, and so on. This can be achieved using the `preprocess.py` script
that is provided in this directory.
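A minimal sketch of this replacement, assuming zero-based source positions. It is not the provided `preprocess.py`: `replace_oovs` is a hypothetical function, and the handling of target words (reusing the marker of the first matching source position, falling back to plain `<unk>`) is an assumption of this example.

```python
def replace_oovs(source_words, vocab, target_words=None):
    """Replace OOV source words with <unk-N>, N being the source position.

    If target words are given, an OOV target word that also occurs in the
    source reuses the marker of its first source position; other OOV target
    words become plain <unk>. This target rule is an assumption of the sketch.
    """
    src_out = [
        w if w in vocab else f"<unk-{i}>" for i, w in enumerate(source_words)
    ]
    if target_words is None:
        return src_out, None
    first_pos = {}
    for i, w in enumerate(source_words):
        first_pos.setdefault(w, i)
    tgt_out = []
    for w in target_words:
        if w in vocab:
            tgt_out.append(w)
        elif w in first_pos:
            tgt_out.append(f"<unk-{first_pos[w]}>")
        else:
            tgt_out.append("<unk>")
    return src_out, tgt_out


vocab = {"de", "moved", "to", "has"}
src, tgt = replace_oovs(
    "de roon moved to teesside".split(), vocab, "de roon has left".split()
)
print(" ".join(src))
print(" ".join(tgt))
```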
##### 3. Train a model
The number of these special tokens is given to the model with the
`--source-position-markers` argument—the model simply maps all of these to the
same word embedding as `<unk>`.
The attention distribution that is used for pointing is selected using the
`--alignment-heads` and `--alignment-layer` command-line arguments in the same
way as with the `transformer_align` model.
##### 4. Generate text and postprocess it
When using the model to generate text, you want to preprocess the input text in
the same way that training data was processed, replacing out-of-vocabulary words
with `<unk-N>` tokens. If any of these tokens are copied to the output, the
actual words can be retrieved from the unprocessed input text. Any `<unk-N>`
token should be replaced with the word at position N in the original input
sequence. This can be achieved using the `postprocess.py` script.
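The reverse mapping can be sketched with a regular expression. This is an illustration rather than the provided `postprocess.py`: `restore_oovs` is a hypothetical function, and leaving out-of-range markers untouched is a choice made for this example.

```python
import re

UNK_MARKER = re.compile(r"<unk-(\d+)>")


def restore_oovs(hypothesis, source):
    """Replace each <unk-N> in `hypothesis` with word N of `source`."""
    source_words = source.split()

    def lookup(match):
        position = int(match.group(1))
        # Leave the marker untouched if it points past the end of the source.
        if position < len(source_words):
            return source_words[position]
        return match.group(0)

    return UNK_MARKER.sub(lookup, hypothesis)


source = "de roon moved to teesside in june 2016"
print(restore_oovs("striker <unk> de <unk-1> has joined", source))
```

Plain `<unk>` tokens carry no position, so they cannot be recovered and are left as they are.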
## Training a pointer-generator model on the Extreme Summarization dataset
##### 1. Download the Extreme Summarization data and preprocess it
Follow the instructions [here](https://github.com/EdinburghNLP/XSum) to obtain
the original Extreme Summarization dataset. You should have six files,
{train,validation,test}.{document,summary}.
##### 2. Create a vocabulary and extend it with source position markers
```bash
vocab_size=10000
position_markers=1000
export LC_ALL=C
cat train.document train.summary |
tr -s '[:space:]' '\n' |
sort |
uniq -c |
sort -k1,1bnr -k2 |
head -n "$((vocab_size - 4))" |
awk '{ print $2 " " $1 }' >dict.pg.txt
python3 -c "[print('<unk-{}> 0'.format(n)) for n in range($position_markers)]" >>dict.pg.txt
```
This creates the file dict.pg.txt that contains the 10k most frequent words,
followed by 1k source position markers:
```
the 4954867
. 4157552
, 3439668
to 2212159
a 1916857
of 1916820
and 1823350
...
<unk-0> 0
<unk-1> 0
<unk-2> 0
<unk-3> 0
<unk-4> 0
...
```
##### 3. Preprocess the text data
```bash
./preprocess.py --source train.document --target train.summary --vocab <(cut -d' ' -f1 dict.pg.txt) --source-out train.pg.src --target-out train.pg.tgt
./preprocess.py --source validation.document --target validation.summary --vocab <(cut -d' ' -f1 dict.pg.txt) --source-out valid.pg.src --target-out valid.pg.tgt
./preprocess.py --source test.document --vocab <(cut -d' ' -f1 dict.pg.txt) --source-out test.pg.src
```
The data should now contain `<unk-N>` tokens in place of out-of-vocabulary words.
##### 4. Binarize the dataset:
```bash
fairseq-preprocess \
--source-lang src \
--target-lang tgt \
--trainpref train.pg \
--validpref valid.pg \
--destdir bin \
--workers 60 \
--srcdict dict.pg.txt \
--joined-dictionary
```
##### 5. Train a model
```bash
total_updates=20000
warmup_updates=500
lr=0.001
max_tokens=4096
update_freq=4
pointer_layer=-2
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 fairseq-train bin \
--user-dir examples/pointer_generator/pointer_generator_src \
--max-tokens "$max_tokens" \
--task translation \
--source-lang src --target-lang tgt \
--truncate-source \
--layernorm-embedding \
--share-all-embeddings \
--encoder-normalize-before \
--decoder-normalize-before \
--required-batch-size-multiple 1 \
--arch transformer_pointer_generator \
--alignment-layer "$pointer_layer" \
--alignment-heads 1 \
--source-position-markers 1000 \
--criterion label_smoothed_cross_entropy \
--label-smoothing 0.1 \
--dropout 0.1 --attention-dropout 0.1 \
--weight-decay 0.01 --optimizer adam --adam-betas "(0.9, 0.999)" --adam-eps 1e-08 \
--clip-norm 0.1 \
--lr-scheduler inverse_sqrt --lr "$lr" --max-update "$total_updates" --warmup-updates "$warmup_updates" \
--update-freq "$update_freq" \
--skip-invalid-size-inputs-valid-test
```
Above we specify that our dictionary contains 1000 source position markers, and
that we want to use one attention head from the penultimate decoder layer for
pointing. It should run in 5.5 hours on one node with eight 32GB V100 GPUs. The
logged messages confirm that dictionary indices from 10000 to 10999 will be
mapped to the `<unk>` embedding:
```
2020-09-24 20:43:53 | INFO | fairseq.tasks.translation | [src] dictionary: 11000 types
2020-09-24 20:43:53 | INFO | fairseq.tasks.translation | [tgt] dictionary: 11000 types
2020-09-24 20:43:53 | INFO | fairseq.data.data_utils | loaded 11332 examples from: bin/valid.src-tgt.src
2020-09-24 20:43:53 | INFO | fairseq.data.data_utils | loaded 11332 examples from: bin/valid.src-tgt.tgt
2020-09-24 20:43:53 | INFO | fairseq.tasks.translation | bin valid src-tgt 11332 examples
2020-09-24 20:43:53 | INFO | fairseq.models.transformer_pg | dictionary indices from 10000 to 10999 will be mapped to 3
```
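The last log line can be read as a simple index remapping applied before the embedding lookup. The sketch below uses the numbers from the log (11000 types, 1000 markers, `<unk>` at index 3). `map_marker_indices` is a hypothetical helper and not the model's actual code.

```python
def map_marker_indices(token_ids, num_types=11000, num_markers=1000, unk_index=3):
    """Map source-position marker indices to the <unk> index before embedding."""
    first_marker = num_types - num_markers
    return [unk_index if i >= first_marker else i for i in token_ids]


print(map_marker_indices([5, 9999, 10000, 10999]))
```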
##### 6. Summarize the test sequences
```bash
batch_size=32
beam_size=6
max_length=60
length_penalty=1.0
fairseq-interactive bin \
--user-dir examples/pointer_generator/pointer_generator_src \
--batch-size "$batch_size" \
--task translation \
--source-lang src --target-lang tgt \
--path checkpoints/checkpoint_last.pt \
--input test.pg.src \
--buffer-size 200 \
--max-len-a 0 \
--max-len-b "$max_length" \
--lenpen "$length_penalty" \
--beam "$beam_size" \
--skip-invalid-size-inputs-valid-test |
tee generate.out
grep ^H generate.out | cut -f 3- >generate.hyp
```
Now you should have the generated sequences in `generate.hyp`. They contain
`<unk-N>` tokens that the model has copied from the source sequence. In order to
retrieve the original words, we need the unprocessed source sequences from
`test.document`.
##### 7. Process the generated output
Since overly long inputs were skipped when producing `generate.hyp`, we also have to
skip the same sequences when reading `test.document`.
```bash
./postprocess.py \
--source <(awk 'NF<1024' test.document) \
--target generate.hyp \
--target-out generate.hyp.processed
```
Now you'll find the final sequences in `generate.hyp.processed`, with each
`<unk-N>` replaced with the original word from the source sequence.
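The replacement step can be sketched as follows. This is not the code of `postprocess.py`. It assumes that `N` in `<unk-N>` is the zero-based position of the word in the whitespace-tokenized source sequence, which matches the example further down (`<unk-1>` corresponds to `roon`). The function names `keep_short` and `restore_unks` are hypothetical.

```python
import re

# Matches a token such as <unk-12> and captures the position 12.
UNK_POSITION = re.compile(r"^<unk-(\d+)>$")


def keep_short(lines, max_tokens=1024):
    """Keep lines with fewer than max_tokens whitespace-separated tokens,
    roughly what `awk 'NF<1024'` does."""
    return [line for line in lines if len(line.split()) < max_tokens]


def restore_unks(source, hypothesis):
    """Replace each <unk-N> in the hypothesis with source token N.

    Tokens that do not match, including plain <unk>, and positions outside
    the source are left unchanged.
    """
    source_tokens = source.split()
    restored = []
    for token in hypothesis.split():
        match = UNK_POSITION.match(token)
        if match and int(match.group(1)) < len(source_tokens):
            restored.append(source_tokens[int(match.group(1))])
        else:
            restored.append(token)
    return " ".join(restored)


result = restore_unks(
    "de roon moved to teesside in june 2016",
    "middlesbrough striker <unk> de <unk-1> has joined",
)
```

A plain `<unk>` carries no position, so it cannot be recovered from the source and stays in the output.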
##### An example of a summarized sequence
The original source document in `test.document`:
> de roon moved to teesside in june 2016 for an initial # 8.8 m fee and played 33 premier league games last term . the netherlands international , 26 , scored five goals in 36 league and cup games during his spell at boro . meanwhile , manager garry monk confirmed the championship club 's interest in signing chelsea midfielder lewis baker . `` he 's a target and one of many that we 've had throughout the summer months , '' said monk . find all the latest football transfers on our dedicated page .
The preprocessed source document in `test.pg.src`:
> de \<unk-1> moved to \<unk-4> in june 2016 for an initial # \<unk-12> m fee and played 33 premier league games last term . the netherlands international , 26 , scored five goals in 36 league and cup games during his spell at boro . meanwhile , manager garry monk confirmed the championship club 's interest in signing chelsea midfielder lewis baker . `` he 's a target and one of many that we 've had throughout the summer months , '' said monk . find all the latest football transfers on our dedicated page .
The generated summary in `generate.hyp`:
> middlesbrough striker \<unk> de \<unk-1> has joined spanish side \<unk> on a season-long loan .
The generated summary after postprocessing in `generate.hyp.processed`:
> middlesbrough striker \<unk> de roon has joined spanish side \<unk> on a season-long loan .
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from . import transformer_pg # noqa