#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import csv
import os
import os.path as op
import zipfile
from functools import reduce
from glob import glob
from multiprocessing import cpu_count
from typing import Any, Dict, List
import numpy as np
import sentencepiece as sp
from fairseq.data.audio.audio_utils import _get_kaldi_fbank, _get_torchaudio_fbank
from fairseq.data.audio.feature_transforms.utterance_cmvn import UtteranceCMVN
from tqdm import tqdm
UNK_TOKEN, UNK_TOKEN_ID = "<unk>", 3
BOS_TOKEN, BOS_TOKEN_ID = "<s>", 0
EOS_TOKEN, EOS_TOKEN_ID = "</s>", 2
PAD_TOKEN, PAD_TOKEN_ID = "<pad>", 1
def gen_vocab(
input_path: str,
output_path_prefix: str,
model_type="bpe",
vocab_size=1000,
):
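    """Train a SentencePiece model on the text in `input_path` and export a
    fairseq dictionary (`<output_path_prefix>.txt`) with the special tokens
    excluded."""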
# Train SentencePiece Model
arguments = [
f"--input={input_path}",
f"--model_prefix={output_path_prefix}",
f"--model_type={model_type}",
f"--vocab_size={vocab_size}",
"--character_coverage=1.0",
f"--num_threads={cpu_count()}",
f"--unk_id={UNK_TOKEN_ID}",
f"--bos_id={BOS_TOKEN_ID}",
f"--eos_id={EOS_TOKEN_ID}",
f"--pad_id={PAD_TOKEN_ID}",
]
sp.SentencePieceTrainer.Train(" ".join(arguments))
# Export fairseq dictionary
spm = sp.SentencePieceProcessor()
spm.Load(output_path_prefix + ".model")
vocab = {i: spm.IdToPiece(i) for i in range(spm.GetPieceSize())}
assert (
vocab.get(UNK_TOKEN_ID) == UNK_TOKEN
and vocab.get(PAD_TOKEN_ID) == PAD_TOKEN
and vocab.get(BOS_TOKEN_ID) == BOS_TOKEN
and vocab.get(EOS_TOKEN_ID) == EOS_TOKEN
)
vocab = {
i: s
for i, s in vocab.items()
if s not in {UNK_TOKEN, BOS_TOKEN, EOS_TOKEN, PAD_TOKEN}
}
with open(output_path_prefix + ".txt", "w") as f_out:
for _, s in sorted(vocab.items(), key=lambda x: x[0]):
f_out.write(f"{s} 1\n")
def extract_fbank_features(
waveform,
sample_rate,
output_path=None,
n_mel_bins=80,
apply_utterance_cmvn=True,
overwrite=False,
):
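    """Extract log mel-filterbank features (`n_mel_bins` dimensions) with
    pyKaldi, falling back to torchaudio; optionally apply utterance-level CMVN;
    save to `output_path` (.npy) if given, otherwise return the features."""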
if output_path is not None and op.exists(output_path) and not overwrite:
return
_waveform = waveform * (2 ** 15) # Kaldi compliance: 16-bit signed integers
_waveform = _waveform.squeeze().numpy()
features = _get_kaldi_fbank(_waveform, sample_rate, n_mel_bins)
if features is None:
features = _get_torchaudio_fbank(_waveform, sample_rate, n_mel_bins)
if features is None:
raise ImportError(
"Please install pyKaldi or torchaudio to enable "
"online filterbank feature extraction"
)
if apply_utterance_cmvn:
cmvn = UtteranceCMVN(norm_means=True, norm_vars=True)
features = cmvn(features)
if output_path is not None:
np.save(output_path, features)
else:
return features
def create_zip(data_root, zip_path):
cwd = os.path.abspath(os.curdir)
os.chdir(data_root)
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_STORED) as f:
for filename in tqdm(glob("*.npy")):
f.write(filename)
os.chdir(cwd)
def is_npy_data(data: bytes) -> bool:
    # .npy files begin with the magic string "\x93NUMPY" (0x93 == 147, "N" == 78)
    return data[0] == 147 and data[1] == 78
def get_zip_manifest(zip_root, zip_filename):
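    """Map each utterance id in the (uncompressed) feature ZIP to a
    `<zip_filename>:<byte_offset>:<byte_size>` string for direct file reads."""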
zip_path = op.join(zip_root, zip_filename)
with zipfile.ZipFile(zip_path, mode="r") as f:
info = f.infolist()
manifest = {}
for i in tqdm(info):
utt_id = op.splitext(i.filename)[0]
        # payload starts after the 30-byte local file header plus the filename
        # (assumes no extra field in the local header)
        offset, file_size = i.header_offset + 30 + len(i.filename), i.file_size
manifest[utt_id] = f"{zip_filename}:{offset}:{file_size}"
with open(zip_path, "rb") as f:
f.seek(offset)
data = f.read(file_size)
assert len(data) > 1 and is_npy_data(data)
return manifest
def gen_config_yaml(
data_root, spm_filename, yaml_filename="config.yaml", specaugment_policy="lb"
):
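    """Write the data config YAML consumed by fairseq's speech_to_text task
    (audio root, vocabulary, SpecAugment policy, BPE tokenizer and train-split
    feature transforms)."""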
assert specaugment_policy in {"lb", "ld"}
data_root = op.abspath(data_root)
writer = S2TDataConfigWriter(op.join(data_root, yaml_filename))
writer.set_audio_root(op.abspath(data_root))
writer.set_vocab_filename(spm_filename.replace(".model", ".txt"))
writer.set_input_channels(1)
writer.set_input_feat_per_channel(80)
if specaugment_policy == "lb":
writer.set_specaugment_lb_policy()
else:
writer.set_specaugment_ld_policy()
writer.set_bpe_tokenizer(
{
"bpe": "sentencepiece",
"sentencepiece_model": op.join(data_root, spm_filename),
}
)
writer.set_feature_transforms("_train", ["specaugment"])
writer.flush()
def save_df_to_tsv(dataframe, path):
dataframe.to_csv(
path,
sep="\t",
header=True,
index=False,
encoding="utf-8",
escapechar="\\",
quoting=csv.QUOTE_NONE,
)
def filter_manifest_df(
df, is_train_split=False, extra_filters=None, min_n_frames=5, max_n_frames=3000
):
filters = {
"no speech": df["audio"] == "",
f"short speech (<{min_n_frames} frames)": df["n_frames"] < min_n_frames,
"empty sentence": df["tgt_text"] == "",
}
if is_train_split:
filters[f"long speech (>{max_n_frames} frames)"] = df["n_frames"] > max_n_frames
if extra_filters is not None:
filters.update(extra_filters)
invalid = reduce(lambda x, y: x | y, filters.values())
valid = ~invalid
print(
"| "
+ ", ".join(f"{n}: {f.sum()}" for n, f in filters.items())
+ f", total {invalid.sum()} filtered, {valid.sum()} remained."
)
return df[valid]
class S2TDataConfigWriter(object):
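    """Helper that accumulates S2T data config entries and dumps them to YAML."""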
DEFAULT_VOCAB_FILENAME = "dict.txt"
DEFAULT_INPUT_FEAT_PER_CHANNEL = 80
DEFAULT_INPUT_CHANNELS = 1
def __init__(self, yaml_path):
try:
import yaml
        except ImportError:
            raise ImportError(
                "Please install PyYAML to load YAML files for S2T data config"
            )
self.yaml = yaml
self.yaml_path = yaml_path
self.config = {}
def flush(self):
with open(self.yaml_path, "w") as f:
self.yaml.dump(self.config, f)
def set_audio_root(self, audio_root=""):
self.config["audio_root"] = audio_root
def set_vocab_filename(self, vocab_filename="dict.txt"):
self.config["vocab_filename"] = vocab_filename
    def set_specaugment(
        self,
        time_warp_w: int,
        freq_mask_n: int,
        freq_mask_f: int,
        time_mask_n: int,
        time_mask_t: int,
        time_mask_p: float,
    ):
        self.config["specaugment"] = {
            "time_warp_W": time_warp_w,
            "freq_mask_N": freq_mask_n,
            "freq_mask_F": freq_mask_f,
            "time_mask_N": time_mask_n,
            "time_mask_T": time_mask_t,
            "time_mask_p": time_mask_p,
        }
def set_specaugment_lb_policy(self):
self.set_specaugment(
            time_warp_w=0,
freq_mask_n=1,
freq_mask_f=27,
time_mask_n=1,
time_mask_t=100,
time_mask_p=1.0,
)
def set_specaugment_ld_policy(self):
self.set_specaugment(
            time_warp_w=0,
freq_mask_n=2,
freq_mask_f=27,
time_mask_n=2,
time_mask_t=100,
time_mask_p=1.0,
)
def set_input_channels(self, input_channels=1):
self.config["input_channels"] = input_channels
def set_input_feat_per_channel(self, input_feat_per_channel=80):
self.config["input_feat_per_channel"] = input_feat_per_channel
def set_bpe_tokenizer(self, bpe_tokenizer: Dict[str, Any]):
self.config["bpe_tokenizer"] = bpe_tokenizer
def set_feature_transforms(self, split, transforms: List[str]):
if "transforms" not in self.config:
self.config["transforms"] = {}
self.config["transforms"][split] = transforms
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import csv
import logging
import os
import os.path as op
import shutil
from tempfile import NamedTemporaryFile
from typing import Optional, Tuple
import pandas as pd
import torchaudio
from examples.speech_to_text.data_utils import (
create_zip,
extract_fbank_features,
filter_manifest_df,
gen_config_yaml,
gen_vocab,
get_zip_manifest,
save_df_to_tsv,
)
from torch import Tensor
from torch.utils.data import Dataset
from torchaudio.datasets.utils import download_url, extract_archive
from tqdm import tqdm
log = logging.getLogger(__name__)
MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker"]
class CoVoST(Dataset):
"""Create a Dataset for CoVoST (https://github.com/facebookresearch/covost).
Args:
root (str): root path to the dataset and generated manifests/features
source_language (str): source (audio) language
target_language (str, optional): target (text) language,
None for no translation (default: None)
version (int, optional): CoVoST version. (default: 2)
download (bool, optional): Whether to download the dataset if it is not
found at root path. (default: ``False``).
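
    Example (illustrative; assumes the Common Voice and CoVoST data can be
    downloaded under ``root``)::

        >>> dataset = CoVoST("data/covost", "train", "fr", "en", download=True)
        >>> waveform, sample_rate, sentence, translation, speaker, utt_id = dataset[0]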
"""
CV_URL_TEMPLATE = (
"https://voice-prod-bundler-ee1969a6ce8178826482b88"
"e843c335139bd3fb4.s3.amazonaws.com/{ver}/{lang}.tar.gz"
)
COVOST_URL_TEMPLATE = (
"https://dl.fbaipublicfiles.com/covost/"
"covost_v2.{src_lang}_{tgt_lang}.tsv.tar.gz"
)
VERSIONS = {2}
SPLITS = ["train", "dev", "test"]
CV_VERSION_ID = {1: "cv-corpus-3", 2: "cv-corpus-4-2019-12-10"}
XX_EN_LANGUAGES = {
1: ["fr", "de", "nl", "ru", "es", "it", "tr", "fa", "sv-SE", "mn", "zh-CN"],
2: [
"fr",
"de",
"es",
"ca",
"it",
"ru",
"zh-CN",
"pt",
"fa",
"et",
"mn",
"nl",
"tr",
"ar",
"sv-SE",
"lv",
"sl",
"ta",
"ja",
"id",
"cy",
],
}
EN_XX_LANGUAGES = {
1: [],
2: [
"de",
"tr",
"fa",
"sv-SE",
"mn",
"zh-CN",
"cy",
"ca",
"sl",
"et",
"id",
"ar",
"ta",
"lv",
"ja",
],
}
def __init__(
self,
root: str,
split: str,
source_language: str,
target_language: Optional[str] = None,
version: int = 2,
download: bool = False,
) -> None:
assert version in self.VERSIONS and split in self.SPLITS
assert source_language is not None
self.no_translation = target_language is None
if not self.no_translation:
assert "en" in {source_language, target_language}
if source_language == "en":
assert target_language in self.EN_XX_LANGUAGES[version]
else:
assert source_language in self.XX_EN_LANGUAGES[version]
else:
# Hack here so that we can get "split" column from CoVoST TSV.
# Note that we use CoVoST train split for ASR which is an extension
# to Common Voice train split.
target_language = "de" if source_language == "en" else "en"
self.root = os.path.join(root, "raw")
os.makedirs(self.root, exist_ok=True)
cv_url = self.CV_URL_TEMPLATE.format(
ver=self.CV_VERSION_ID[version], lang=source_language
)
cv_archive = os.path.join(self.root, os.path.basename(cv_url))
if download:
if not os.path.isfile(cv_archive):
download_url(cv_url, self.root, hash_value=None)
extract_archive(cv_archive)
covost_url = self.COVOST_URL_TEMPLATE.format(
src_lang=source_language, tgt_lang=target_language
)
covost_archive = os.path.join(self.root, os.path.basename(covost_url))
if download:
if not os.path.isfile(covost_archive):
download_url(covost_url, self.root, hash_value=None)
extract_archive(covost_archive)
cv_tsv = self.load_from_tsv(os.path.join(self.root, "validated.tsv"))
covost_tsv = self.load_from_tsv(
os.path.join(self.root, os.path.basename(covost_url).replace(".tar.gz", ""))
)
df = pd.merge(
left=cv_tsv[["path", "sentence", "client_id"]],
right=covost_tsv[["path", "translation", "split"]],
how="inner",
on="path",
)
if split == "train":
df = df[(df["split"] == split) | (df["split"] == f"{split}_covost")]
else:
df = df[df["split"] == split]
self.data = df.to_dict(orient="index").items()
self.data = [v for k, v in sorted(self.data, key=lambda x: x[0])]
@classmethod
def load_from_tsv(cls, path: str):
return pd.read_csv(
path,
sep="\t",
header=0,
encoding="utf-8",
escapechar="\\",
quoting=csv.QUOTE_NONE,
na_filter=False,
)
    def __getitem__(
        self, n: int
    ) -> Tuple[Tensor, int, str, Optional[str], str, str]:
"""Load the n-th sample from the dataset.
Args:
n (int): The index of the sample to be loaded
Returns:
tuple: ``(waveform, sample_rate, sentence, translation, speaker_id,
sample_id)``
"""
data = self.data[n]
path = os.path.join(self.root, "clips", data["path"])
waveform, sample_rate = torchaudio.load(path)
sentence = data["sentence"]
translation = None if self.no_translation else data["translation"]
speaker_id = data["client_id"]
_id = data["path"].replace(".mp3", "")
return waveform, sample_rate, sentence, translation, speaker_id, _id
def __len__(self) -> int:
return len(self.data)
def process(args):
root = op.join(args.data_root, args.src_lang)
os.makedirs(root, exist_ok=True)
# Extract features
feature_root = op.join(root, "fbank80")
os.makedirs(feature_root, exist_ok=True)
for split in CoVoST.SPLITS:
print(f"Fetching split {split}...")
dataset = CoVoST(root, split, args.src_lang, args.tgt_lang, download=True)
print("Extracting log mel filter bank features...")
for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset):
extract_fbank_features(
waveform, sample_rate, op.join(feature_root, f"{utt_id}.npy")
)
# Pack features into ZIP
zip_filename = "fbank80.zip"
zip_path = op.join(root, zip_filename)
print("ZIPing features...")
create_zip(feature_root, zip_path)
print("Fetching ZIP manifest...")
zip_manifest = get_zip_manifest(args.data_root, f"{args.src_lang}/{zip_filename}")
# Generate TSV manifest
print("Generating manifest...")
train_text = []
task = f"asr_{args.src_lang}"
if args.tgt_lang is not None:
task = f"st_{args.src_lang}_{args.tgt_lang}"
for split in CoVoST.SPLITS:
manifest = {c: [] for c in MANIFEST_COLUMNS}
dataset = CoVoST(root, split, args.src_lang, args.tgt_lang)
for wav, sr, src_utt, tgt_utt, speaker_id, utt_id in tqdm(dataset):
manifest["id"].append(utt_id)
manifest["audio"].append(zip_manifest[utt_id])
            duration_ms = int(wav.size(1) / sr * 1000)
            # number of frames for a 25ms analysis window with a 10ms shift
            manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10))
manifest["tgt_text"].append(src_utt if args.tgt_lang is None else tgt_utt)
manifest["speaker"].append(speaker_id)
is_train_split = split.startswith("train")
if is_train_split:
train_text.extend(manifest["tgt_text"])
df = pd.DataFrame.from_dict(manifest)
df = filter_manifest_df(df, is_train_split=is_train_split)
save_df_to_tsv(df, op.join(root, f"{split}_{task}.tsv"))
# Generate vocab
vocab_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
spm_filename_prefix = f"spm_{args.vocab_type}{vocab_size_str}_{task}"
with NamedTemporaryFile(mode="w") as f:
for t in train_text:
            f.write(t + "\n")
        f.flush()  # flush buffered text so SentencePiece sees the full file
gen_vocab(
f.name, op.join(root, spm_filename_prefix), args.vocab_type, args.vocab_size
)
# Generate config YAML
gen_config_yaml(
root,
spm_filename_prefix + ".model",
yaml_filename=f"config_{task}.yaml",
specaugment_policy="lb",
)
# Clean up
shutil.rmtree(feature_root)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--data-root", "-d", required=True, type=str)
    parser.add_argument(
        "--vocab-type",
        default="unigram",
        type=str,
        choices=["bpe", "unigram", "char"],
    )
parser.add_argument("--vocab-size", default=1000, type=int)
parser.add_argument("--src-lang", "-s", required=True, type=str)
parser.add_argument("--tgt-lang", "-t", type=str)
args = parser.parse_args()
process(args)
if __name__ == "__main__":
main()
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import logging
import os
import os.path as op
import shutil
from tempfile import NamedTemporaryFile
import pandas as pd
from examples.speech_to_text.data_utils import (
create_zip,
extract_fbank_features,
gen_config_yaml,
gen_vocab,
get_zip_manifest,
save_df_to_tsv,
)
from torchaudio.datasets import LIBRISPEECH
from tqdm import tqdm
log = logging.getLogger(__name__)
SPLITS = [
"train-clean-100",
"train-clean-360",
"train-other-500",
"dev-clean",
"dev-other",
"test-clean",
"test-other",
]
MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker"]
def process(args):
os.makedirs(args.output_root, exist_ok=True)
# Extract features
feature_root = op.join(args.output_root, "fbank80")
os.makedirs(feature_root, exist_ok=True)
for split in SPLITS:
print(f"Fetching split {split}...")
dataset = LIBRISPEECH(args.output_root, url=split, download=True)
print("Extracting log mel filter bank features...")
for wav, sample_rate, _, spk_id, chapter_id, utt_id in tqdm(dataset):
sample_id = f"{spk_id}-{chapter_id}-{utt_id}"
extract_fbank_features(
wav, sample_rate, op.join(feature_root, f"{sample_id}.npy")
)
# Pack features into ZIP
zip_filename = "fbank80.zip"
zip_path = op.join(args.output_root, zip_filename)
print("ZIPing features...")
create_zip(feature_root, zip_path)
print("Fetching ZIP manifest...")
zip_manifest = get_zip_manifest(args.output_root, zip_filename)
# Generate TSV manifest
print("Generating manifest...")
train_text = []
for split in SPLITS:
manifest = {c: [] for c in MANIFEST_COLUMNS}
dataset = LIBRISPEECH(args.output_root, url=split)
for wav, sample_rate, utt, spk_id, chapter_id, utt_id in tqdm(dataset):
sample_id = f"{spk_id}-{chapter_id}-{utt_id}"
manifest["id"].append(sample_id)
manifest["audio"].append(zip_manifest[sample_id])
duration_ms = int(wav.size(1) / sample_rate * 1000)
manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10))
manifest["tgt_text"].append(utt)
manifest["speaker"].append(spk_id)
save_df_to_tsv(
pd.DataFrame.from_dict(manifest), op.join(args.output_root, f"{split}.tsv")
)
if split.startswith("train"):
train_text.extend(manifest["tgt_text"])
# Generate vocab
vocab_size = "" if args.vocab_type == "char" else str(args.vocab_size)
spm_filename_prefix = f"spm_{args.vocab_type}{vocab_size}"
with NamedTemporaryFile(mode="w") as f:
for t in train_text:
            f.write(t + "\n")
        f.flush()  # flush buffered text so SentencePiece sees the full file
gen_vocab(
f.name,
op.join(args.output_root, spm_filename_prefix),
args.vocab_type,
args.vocab_size,
)
# Generate config YAML
gen_config_yaml(
args.output_root, spm_filename_prefix + ".model", specaugment_policy="ld"
)
# Clean up
shutil.rmtree(feature_root)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--output-root", "-o", required=True, type=str)
    parser.add_argument(
        "--vocab-type",
        default="unigram",
        type=str,
        choices=["bpe", "unigram", "char"],
    )
parser.add_argument("--vocab-size", default=10000, type=int)
args = parser.parse_args()
process(args)
if __name__ == "__main__":
main()
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import logging
import os
import os.path as op
import shutil
from itertools import groupby
from tempfile import NamedTemporaryFile
from typing import Tuple
import pandas as pd
import torchaudio
from examples.speech_to_text.data_utils import (
create_zip,
extract_fbank_features,
filter_manifest_df,
gen_config_yaml,
gen_vocab,
get_zip_manifest,
save_df_to_tsv,
)
from torch import Tensor
from torch.utils.data import Dataset
from tqdm import tqdm
log = logging.getLogger(__name__)
MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker"]
TASKS = ["asr", "st"]
class MUSTC(Dataset):
"""
Create a Dataset for MuST-C. Each item is a tuple of the form:
waveform, sample_rate, source utterance, target utterance, speaker_id,
utterance_id
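
    Example (illustrative; assumes the MuST-C en-de release is unpacked under
    ``root``)::

        >>> dataset = MUSTC("data/mustc", "de", "train")
        >>> waveform, sample_rate, src_utt, tgt_utt, speaker, utt_id = dataset[0]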
"""
SPLITS = ["train", "dev", "tst-COMMON", "tst-HE"]
LANGUAGES = ["de", "es", "fr", "it", "nl", "pt", "ro", "ru"]
def __init__(self, root: str, lang: str, split: str) -> None:
assert split in self.SPLITS and lang in self.LANGUAGES
_root = op.join(root, f"en-{lang}", "data", split)
wav_root, txt_root = op.join(_root, "wav"), op.join(_root, "txt")
assert op.isdir(_root) and op.isdir(wav_root) and op.isdir(txt_root)
# Load audio segments
try:
import yaml
        except ImportError:
            raise ImportError(
                "Please install PyYAML to load YAML files for the MuST-C dataset"
            )
with open(op.join(txt_root, f"{split}.yaml")) as f:
segments = yaml.load(f, Loader=yaml.BaseLoader)
# Load source and target utterances
for _lang in ["en", lang]:
with open(op.join(txt_root, f"{split}.{_lang}")) as f:
utterances = [r.strip() for r in f]
assert len(segments) == len(utterances)
for i, u in enumerate(utterances):
segments[i][_lang] = u
# Gather info
self.data = []
for wav_filename, _seg_group in groupby(segments, lambda x: x["wav"]):
wav_path = op.join(wav_root, wav_filename)
sample_rate = torchaudio.info(wav_path)[0].rate
seg_group = sorted(_seg_group, key=lambda x: x["offset"])
for i, segment in enumerate(seg_group):
offset = int(float(segment["offset"]) * sample_rate)
n_frames = int(float(segment["duration"]) * sample_rate)
_id = f"{op.splitext(wav_filename)[0]}_{i}"
self.data.append(
(
wav_path,
offset,
n_frames,
sample_rate,
segment["en"],
segment[lang],
segment["speaker_id"],
_id,
)
)
def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, str, str]:
wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, utt_id = self.data[n]
waveform, _ = torchaudio.load(wav_path, offset=offset, num_frames=n_frames)
return waveform, sr, src_utt, tgt_utt, spk_id, utt_id
def __len__(self) -> int:
return len(self.data)
def process(args):
for lang in MUSTC.LANGUAGES:
cur_root = op.join(args.data_root, f"en-{lang}")
if not op.isdir(cur_root):
print(f"{cur_root} does not exist. Skipped.")
continue
# Extract features
feature_root = op.join(cur_root, "fbank80")
os.makedirs(feature_root, exist_ok=True)
for split in MUSTC.SPLITS:
print(f"Fetching split {split}...")
dataset = MUSTC(args.data_root, lang, split)
print("Extracting log mel filter bank features...")
for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset):
extract_fbank_features(
waveform, sample_rate, op.join(feature_root, f"{utt_id}.npy")
)
# Pack features into ZIP
zip_filename = "fbank80.zip"
zip_path = op.join(cur_root, zip_filename)
print("ZIPing features...")
create_zip(feature_root, zip_path)
print("Fetching ZIP manifest...")
zip_manifest = get_zip_manifest(args.data_root, f"en-{lang}/{zip_filename}")
# Generate TSV manifest
print("Generating manifest...")
train_text = {task: [] for task in TASKS}
for split in MUSTC.SPLITS:
is_train_split = split.startswith("train")
manifest = {c: [] for c in MANIFEST_COLUMNS}
text = {task: [] for task in TASKS}
dataset = MUSTC(args.data_root, lang, split)
for wav, sr, src_utt, tgt_utt, speaker_id, utt_id in tqdm(dataset):
manifest["id"].append(utt_id)
manifest["audio"].append(zip_manifest[utt_id])
duration_ms = int(wav.size(1) / sr * 1000)
manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10))
text["asr"].append(src_utt)
text["st"].append(tgt_utt)
manifest["speaker"].append(speaker_id)
if is_train_split:
for task in TASKS:
train_text[task].extend(text[task])
for task in TASKS:
manifest["tgt_text"] = text[task]
df = pd.DataFrame.from_dict(manifest)
df = filter_manifest_df(df, is_train_split=is_train_split)
save_df_to_tsv(df, op.join(cur_root, f"{split}_{task}.tsv"))
# Generate vocab
for task in TASKS:
vocab_type, vocab_size = args.asr_vocab_type, args.asr_vocab_size
if task == "st":
vocab_type, vocab_size = args.st_vocab_type, args.st_vocab_size
vocab_size_str = "" if vocab_type == "char" else str(vocab_size)
spm_filename_prefix = f"spm_{vocab_type}{vocab_size_str}_{task}"
with NamedTemporaryFile(mode="w") as f:
for t in train_text[task]:
                f.write(t + "\n")
            f.flush()  # flush buffered text so SentencePiece sees the full file
gen_vocab(
f.name,
op.join(cur_root, spm_filename_prefix),
vocab_type,
vocab_size,
)
# Generate config YAML
gen_config_yaml(
cur_root,
spm_filename_prefix + ".model",
yaml_filename=f"config_{task}.yaml",
specaugment_policy="lb",
)
# Clean up
shutil.rmtree(feature_root)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--data-root", "-d", required=True, type=str)
    parser.add_argument(
        "--asr-vocab-type",
        default="unigram",
        type=str,
        choices=["bpe", "unigram", "char"],
    )
    parser.add_argument(
        "--st-vocab-type",
        default="unigram",
        type=str,
        choices=["bpe", "unigram", "char"],
    )
parser.add_argument("--asr-vocab-size", default=5000, type=int)
parser.add_argument("--st-vocab-size", default=8000, type=int)
args = parser.parse_args()
process(args)
if __name__ == "__main__":
main()
# Hierarchical Neural Story Generation (Fan et al., 2018)
The following commands provide an example of pre-processing data, training a model, and generating text for story generation with the WritingPrompts dataset.
## Pre-trained models
Description | Dataset | Model | Test set(s)
---|---|---|---
Stories with Convolutional Model <br> ([Fan et al., 2018](https://arxiv.org/abs/1805.04833)) | [WritingPrompts](https://dl.fbaipublicfiles.com/fairseq/data/writingPrompts.tar.gz) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/stories_checkpoint.tar.bz2) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/stories_test.tar.bz2)
We provide sample stories generated by the [convolutional seq2seq model](https://dl.fbaipublicfiles.com/fairseq/data/seq2seq_stories.txt) and [fusion model](https://dl.fbaipublicfiles.com/fairseq/data/fusion_stories.txt) from [Fan et al., 2018](https://arxiv.org/abs/1805.04833). The corresponding prompts for the fusion model can be found [here](https://dl.fbaipublicfiles.com/fairseq/data/fusion_prompts.txt). Note that there are `<unk>` tokens in the files, as we modeled a small word-level vocabulary (no BPE or pre-training). We did not use these `<unk>`-containing prompts for human evaluation.
## Dataset
The dataset can be downloaded like this:
```bash
cd examples/stories
curl https://dl.fbaipublicfiles.com/fairseq/data/writingPrompts.tar.gz | tar xvzf -
```
and contains train, test, and valid splits. The dataset is described here: https://arxiv.org/abs/1805.04833. We model only the first 1000 words of each story, including one newline token.
## Example usage
First we will preprocess the dataset. Note that the released dataset is the full data, but the paper models only the first 1000 words of each story. The following example code trims each story in place to its first 1000 words:
```python
data = ["train", "test", "valid"]
for name in data:
with open(name + ".wp_target") as f:
stories = f.readlines()
stories = [" ".join(i.split()[0:1000]) for i in stories]
with open(name + ".wp_target", "w") as o:
for line in stories:
o.write(line.strip() + "\n")
```
Once we've trimmed the data we can binarize it and train our model:
```bash
# Binarize the dataset:
export TEXT=examples/stories/writingPrompts
fairseq-preprocess --source-lang wp_source --target-lang wp_target \
--trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
--destdir data-bin/writingPrompts --padding-factor 1 --thresholdtgt 10 --thresholdsrc 10
# Train the model:
fairseq-train data-bin/writingPrompts -a fconv_self_att_wp \
    --lr 0.25 --optimizer nag --clip-norm 0.1 --max-tokens 1500 \
    --lr-scheduler reduce_lr_on_plateau \
    --decoder-attention True --encoder-attention False \
    --criterion label_smoothed_cross_entropy --weight-decay .0000001 --label-smoothing 0 \
    --source-lang wp_source --target-lang wp_target \
    --gated-attention True --self-attention True --project-input True \
    --pretrained False
# Train a fusion model:
# add the arguments: --pretrained True --pretrained-checkpoint path/to/checkpoint
# Generate:
# Note: to load the pretrained model at generation time, you need to pass the
# --model-overrides argument to tell the fusion model where you have placed the
# pretrained checkpoint. By default, it will load the exact path of the fusion
# model's pretrained model from training time. Use --model-overrides if you have
# moved the pretrained model (or are using our provided models). If you are
# generating from a non-fusion model, --model-overrides is not necessary.
fairseq-generate data-bin/writingPrompts \
    --path /path/to/trained/model/checkpoint_best.pt \
    --batch-size 32 --beam 1 --sampling --sampling-topk 10 --temperature 0.8 \
    --nbest 1 --model-overrides "{'pretrained_checkpoint':'/path/to/pretrained/model/checkpoint'}"
```
## Citation
```bibtex
@inproceedings{fan2018hierarchical,
title = {Hierarchical Neural Story Generation},
author = {Fan, Angela and Lewis, Mike and Dauphin, Yann},
booktitle = {Conference of the Association for Computational Linguistics (ACL)},
year = 2018,
}
```
# Neural Machine Translation
This README contains instructions for [using pretrained translation models](#example-usage-torchhub)
as well as [training new models](#training-a-new-model).
## Pre-trained models
Model | Description | Dataset | Download
---|---|---|---
`conv.wmt14.en-fr` | Convolutional <br> ([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | model: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2) <br> newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.v2.en-fr.newstest2014.tar.bz2) <br> newstest2012/2013: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.v2.en-fr.ntst1213.tar.bz2)
`conv.wmt14.en-de` | Convolutional <br> ([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT14 English-German](http://statmt.org/wmt14/translation-task.html#Download) | model: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-de.fconv-py.tar.bz2) <br> newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.en-de.newstest2014.tar.bz2)
`conv.wmt17.en-de` | Convolutional <br> ([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT17 English-German](http://statmt.org/wmt17/translation-task.html#Download) | model: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt17.v2.en-de.fconv-py.tar.bz2) <br> newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt17.v2.en-de.newstest2014.tar.bz2)
`transformer.wmt14.en-fr` | Transformer <br> ([Ott et al., 2018](https://arxiv.org/abs/1806.00187)) | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | model: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-fr.joined-dict.transformer.tar.bz2) <br> newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.en-fr.joined-dict.newstest2014.tar.bz2)
`transformer.wmt16.en-de` | Transformer <br> ([Ott et al., 2018](https://arxiv.org/abs/1806.00187)) | [WMT16 English-German](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8) | model: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt16.en-de.joined-dict.transformer.tar.bz2) <br> newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt16.en-de.joined-dict.newstest2014.tar.bz2)
`transformer.wmt18.en-de` | Transformer <br> ([Edunov et al., 2018](https://arxiv.org/abs/1808.09381)) <br> WMT'18 winner | [WMT'18 English-German](http://www.statmt.org/wmt18/translation-task.html) | model: <br> [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt18.en-de.ensemble.tar.gz) <br> See NOTE in the archive
`transformer.wmt19.en-de` | Transformer <br> ([Ng et al., 2019](https://arxiv.org/abs/1907.06616)) <br> WMT'19 winner | [WMT'19 English-German](http://www.statmt.org/wmt19/translation-task.html) | model: <br> [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-de.joined-dict.ensemble.tar.gz)
`transformer.wmt19.de-en` | Transformer <br> ([Ng et al., 2019](https://arxiv.org/abs/1907.06616)) <br> WMT'19 winner | [WMT'19 German-English](http://www.statmt.org/wmt19/translation-task.html) | model: <br> [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt19.de-en.joined-dict.ensemble.tar.gz)
`transformer.wmt19.en-ru` | Transformer <br> ([Ng et al., 2019](https://arxiv.org/abs/1907.06616)) <br> WMT'19 winner | [WMT'19 English-Russian](http://www.statmt.org/wmt19/translation-task.html) | model: <br> [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.ensemble.tar.gz)
`transformer.wmt19.ru-en` | Transformer <br> ([Ng et al., 2019](https://arxiv.org/abs/1907.06616)) <br> WMT'19 winner | [WMT'19 Russian-English](http://www.statmt.org/wmt19/translation-task.html) | model: <br> [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt19.ru-en.ensemble.tar.gz)
## Example usage (torch.hub)
We require a few additional Python dependencies for preprocessing:
```bash
pip install fastBPE sacremoses subword_nmt
```
Interactive translation via PyTorch Hub:
```python
import torch
# List available models
torch.hub.list('pytorch/fairseq') # [..., 'transformer.wmt16.en-de', ... ]
# Load a transformer trained on WMT'16 En-De
# Note: WMT'19 models use fastBPE instead of subword_nmt, see instructions below
en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt16.en-de',
tokenizer='moses', bpe='subword_nmt')
en2de.eval() # disable dropout
# The underlying model is available under the *models* attribute
import fairseq
assert isinstance(en2de.models[0], fairseq.models.transformer.TransformerModel)
# Move model to GPU for faster translation
en2de.cuda()
# Translate a sentence
en2de.translate('Hello world!')
# 'Hallo Welt!'
# Batched translation
en2de.translate(['Hello world!', 'The cat sat on the mat.'])
# ['Hallo Welt!', 'Die Katze saß auf der Matte.']
```
Loading custom models:
```python
from fairseq.models.transformer import TransformerModel
zh2en = TransformerModel.from_pretrained(
'/path/to/checkpoints',
checkpoint_file='checkpoint_best.pt',
data_name_or_path='data-bin/wmt17_zh_en_full',
bpe='subword_nmt',
bpe_codes='data-bin/wmt17_zh_en_full/zh.code'
)
zh2en.translate('你好 世界')
# 'Hello World'
```
If you are using one of the `transformer.wmt19` models, you will need to set
the `bpe` argument to `'fastbpe'` and (optionally) load the 4-model ensemble:
```python
en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de',
checkpoint_file='model1.pt:model2.pt:model3.pt:model4.pt',
tokenizer='moses', bpe='fastbpe')
en2de.eval() # disable dropout
```
## Example usage (CLI tools)
Generation with the binarized test sets can be run in batch mode as follows, e.g. for WMT 2014 English-French on a GTX-1080ti:
```bash
mkdir -p data-bin
curl https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2 | tar xvjf - -C data-bin
curl https://dl.fbaipublicfiles.com/fairseq/data/wmt14.v2.en-fr.newstest2014.tar.bz2 | tar xvjf - -C data-bin
fairseq-generate data-bin/wmt14.en-fr.newstest2014 \
--path data-bin/wmt14.en-fr.fconv-py/model.pt \
--beam 5 --batch-size 128 --remove-bpe | tee /tmp/gen.out
# ...
# | Translated 3003 sentences (96311 tokens) in 166.0s (580.04 tokens/s)
# | Generate test with beam=5: BLEU4 = 40.83, 67.5/46.9/34.4/25.5 (BP=1.000, ratio=1.006, syslen=83262, reflen=82787)
# Compute BLEU score
grep ^H /tmp/gen.out | cut -f3- > /tmp/gen.out.sys
grep ^T /tmp/gen.out | cut -f2- > /tmp/gen.out.ref
fairseq-score --sys /tmp/gen.out.sys --ref /tmp/gen.out.ref
# BLEU4 = 40.83, 67.5/46.9/34.4/25.5 (BP=1.000, ratio=1.006, syslen=83262, reflen=82787)
```
## Training a new model
### IWSLT'14 German to English (Transformer)
The following instructions can be used to train a Transformer model on the [IWSLT'14 German to English dataset](http://workshop2014.iwslt.org/downloads/proceeding.pdf).
First download and preprocess the data:
```bash
# Download and prepare the data
cd examples/translation/
bash prepare-iwslt14.sh
cd ../..
# Preprocess/binarize the data
TEXT=examples/translation/iwslt14.tokenized.de-en
fairseq-preprocess --source-lang de --target-lang en \
--trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
--destdir data-bin/iwslt14.tokenized.de-en \
--workers 20
```
Next we'll train a Transformer translation model over this data:
```bash
CUDA_VISIBLE_DEVICES=0 fairseq-train \
data-bin/iwslt14.tokenized.de-en \
--arch transformer_iwslt_de_en --share-decoder-input-output-embed \
--optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
--lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
--dropout 0.3 --weight-decay 0.0001 \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--max-tokens 4096 \
--eval-bleu \
--eval-bleu-args '{"beam": 5, "max_len_a": 1.2, "max_len_b": 10}' \
--eval-bleu-detok moses \
--eval-bleu-remove-bpe \
--eval-bleu-print-samples \
--best-checkpoint-metric bleu --maximize-best-checkpoint-metric
```
Finally we can evaluate our trained model:
```bash
fairseq-generate data-bin/iwslt14.tokenized.de-en \
--path checkpoints/checkpoint_best.pt \
--batch-size 128 --beam 5 --remove-bpe
```
### WMT'14 English to German (Convolutional)
The following instructions can be used to train a Convolutional translation model on the WMT English to German dataset.
See the [Scaling NMT README](../scaling_nmt/README.md) for instructions to train a Transformer translation model on this data.
The WMT English to German dataset can be preprocessed using the `prepare-wmt14en2de.sh` script.
By default it will produce a dataset that was modeled after [Attention Is All You Need (Vaswani et al., 2017)](https://arxiv.org/abs/1706.03762), but with additional news-commentary-v12 data from WMT'17.
To use only data available in WMT'14 or to replicate results obtained in the original [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](https://arxiv.org/abs/1705.03122) paper, please use the `--icml17` option.
```bash
# Download and prepare the data
cd examples/translation/
# WMT'17 data:
bash prepare-wmt14en2de.sh
# or to use WMT'14 data:
# bash prepare-wmt14en2de.sh --icml17
cd ../..
# Binarize the dataset
TEXT=examples/translation/wmt17_en_de
fairseq-preprocess \
--source-lang en --target-lang de \
--trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
--destdir data-bin/wmt17_en_de --thresholdtgt 0 --thresholdsrc 0 \
--workers 20
# Train the model
mkdir -p checkpoints/fconv_wmt_en_de
fairseq-train \
data-bin/wmt17_en_de \
--arch fconv_wmt_en_de \
--dropout 0.2 \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--optimizer nag --clip-norm 0.1 \
--lr 0.5 --lr-scheduler fixed --force-anneal 50 \
--max-tokens 4000 \
--save-dir checkpoints/fconv_wmt_en_de
# Evaluate
fairseq-generate data-bin/wmt17_en_de \
--path checkpoints/fconv_wmt_en_de/checkpoint_best.pt \
--beam 5 --remove-bpe
```
### WMT'14 English to French
```bash
# Download and prepare the data
cd examples/translation/
bash prepare-wmt14en2fr.sh
cd ../..
# Binarize the dataset
TEXT=examples/translation/wmt14_en_fr
fairseq-preprocess \
--source-lang en --target-lang fr \
--trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
--destdir data-bin/wmt14_en_fr --thresholdtgt 0 --thresholdsrc 0 \
--workers 60
# Train the model
mkdir -p checkpoints/fconv_wmt_en_fr
fairseq-train \
data-bin/wmt14_en_fr \
--arch fconv_wmt_en_fr \
--dropout 0.1 \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--optimizer nag --clip-norm 0.1 \
--lr 0.5 --lr-scheduler fixed --force-anneal 50 \
--max-tokens 3000 \
--save-dir checkpoints/fconv_wmt_en_fr
# Evaluate
fairseq-generate \
data-bin/fconv_wmt_en_fr \
--path checkpoints/fconv_wmt_en_fr/checkpoint_best.pt \
--beam 5 --remove-bpe
```
## Multilingual Translation
We also support training multilingual translation models. In this example we'll
train a multilingual `{de,fr}-en` translation model using the IWSLT'17 datasets.
Note that we use slightly different preprocessing here than for the IWSLT'14
En-De data above. In particular we learn a joint BPE code for all three
languages and use fairseq-interactive and sacrebleu for scoring the test set.
```bash
# First install sacrebleu and sentencepiece
pip install sacrebleu sentencepiece
# Then download and preprocess the data
cd examples/translation/
bash prepare-iwslt17-multilingual.sh
cd ../..
# Binarize the de-en dataset
TEXT=examples/translation/iwslt17.de_fr.en.bpe16k
fairseq-preprocess --source-lang de --target-lang en \
--trainpref $TEXT/train.bpe.de-en \
--validpref $TEXT/valid0.bpe.de-en,$TEXT/valid1.bpe.de-en,$TEXT/valid2.bpe.de-en,$TEXT/valid3.bpe.de-en,$TEXT/valid4.bpe.de-en,$TEXT/valid5.bpe.de-en \
--destdir data-bin/iwslt17.de_fr.en.bpe16k \
--workers 10
# Binarize the fr-en dataset
# NOTE: it's important to reuse the en dictionary from the previous step
fairseq-preprocess --source-lang fr --target-lang en \
--trainpref $TEXT/train.bpe.fr-en \
--validpref $TEXT/valid0.bpe.fr-en,$TEXT/valid1.bpe.fr-en,$TEXT/valid2.bpe.fr-en,$TEXT/valid3.bpe.fr-en,$TEXT/valid4.bpe.fr-en,$TEXT/valid5.bpe.fr-en \
--tgtdict data-bin/iwslt17.de_fr.en.bpe16k/dict.en.txt \
--destdir data-bin/iwslt17.de_fr.en.bpe16k \
--workers 10
# Train a multilingual transformer model
# NOTE: the command below assumes 1 GPU, but accumulates gradients from
# 8 fwd/bwd passes to simulate training on 8 GPUs
mkdir -p checkpoints/multilingual_transformer
CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/iwslt17.de_fr.en.bpe16k/ \
--max-epoch 50 \
--ddp-backend=no_c10d \
--task multilingual_translation --lang-pairs de-en,fr-en \
--arch multilingual_transformer_iwslt_de_en \
--share-decoders --share-decoder-input-output-embed \
--optimizer adam --adam-betas '(0.9, 0.98)' \
--lr 0.0005 --lr-scheduler inverse_sqrt --min-lr '1e-09' \
--warmup-updates 4000 --warmup-init-lr '1e-07' \
--label-smoothing 0.1 --criterion label_smoothed_cross_entropy \
--dropout 0.3 --weight-decay 0.0001 \
--save-dir checkpoints/multilingual_transformer \
--max-tokens 4000 \
--update-freq 8
# Generate and score the test set with sacrebleu
SRC=de
sacrebleu --test-set iwslt17 --language-pair ${SRC}-en --echo src \
| python scripts/spm_encode.py --model examples/translation/iwslt17.de_fr.en.bpe16k/sentencepiece.bpe.model \
> iwslt17.test.${SRC}-en.${SRC}.bpe
cat iwslt17.test.${SRC}-en.${SRC}.bpe \
| fairseq-interactive data-bin/iwslt17.de_fr.en.bpe16k/ \
--task multilingual_translation --lang-pairs de-en,fr-en \
--source-lang ${SRC} --target-lang en \
--path checkpoints/multilingual_transformer/checkpoint_best.pt \
--buffer-size 2000 --batch-size 128 \
--beam 5 --remove-bpe=sentencepiece \
> iwslt17.test.${SRC}-en.en.sys
grep ^H iwslt17.test.${SRC}-en.en.sys | cut -f3 \
| sacrebleu --test-set iwslt17 --language-pair ${SRC}-en
```
##### Argument format during inference
During inference you must specify a single `--source-lang` and
`--target-lang` pair, which indicates the inference language direction.
`--lang-pairs`, `--encoder-langtok` and `--decoder-langtok` have to be set to
the same values that were used during training.
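For example, to translate French instead of German with the multilingual model
trained above, keep `--lang-pairs` exactly as during training and select the
direction with `--source-lang`/`--target-lang` (input file produced by re-running
the sacrebleu step above with `SRC=fr`; paths illustrative):
```bash
cat iwslt17.test.fr-en.fr.bpe \
    | fairseq-interactive data-bin/iwslt17.de_fr.en.bpe16k/ \
        --task multilingual_translation --lang-pairs de-en,fr-en \
        --source-lang fr --target-lang en \
        --path checkpoints/multilingual_transformer/checkpoint_best.pt \
        --buffer-size 2000 --batch-size 128 \
        --beam 5 --remove-bpe=sentencepiece \
    > iwslt17.test.fr-en.en.sys
```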
#!/usr/bin/env bash
#
# Adapted from https://github.com/facebookresearch/MIXER/blob/master/prepareData.sh
echo 'Cloning Moses github repository (for tokenization scripts)...'
git clone https://github.com/moses-smt/mosesdecoder.git
echo 'Cloning Subword NMT repository (for BPE pre-processing)...'
git clone https://github.com/rsennrich/subword-nmt.git
SCRIPTS=mosesdecoder/scripts
TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
LC=$SCRIPTS/tokenizer/lowercase.perl
CLEAN=$SCRIPTS/training/clean-corpus-n.perl
BPEROOT=subword-nmt/subword_nmt
BPE_TOKENS=10000
URL="https://wit3.fbk.eu/archive/2014-01/texts/de/en/de-en.tgz"
GZ=de-en.tgz
if [ ! -d "$SCRIPTS" ]; then
echo "Please set SCRIPTS variable correctly to point to Moses scripts."
    exit 1
fi
src=de
tgt=en
lang=de-en
prep=iwslt14.tokenized.de-en
tmp=$prep/tmp
orig=orig
mkdir -p $orig $tmp $prep
echo "Downloading data from ${URL}..."
cd $orig
wget "$URL"
if [ -f $GZ ]; then
echo "Data successfully downloaded."
else
echo "Data not successfully downloaded."
    exit 1
fi
tar zxvf $GZ
cd ..
echo "pre-processing train data..."
for l in $src $tgt; do
f=train.tags.$lang.$l
tok=train.tags.$lang.tok.$l
cat $orig/$lang/$f | \
grep -v '<url>' | \
grep -v '<talkid>' | \
grep -v '<keywords>' | \
sed -e 's/<title>//g' | \
sed -e 's/<\/title>//g' | \
sed -e 's/<description>//g' | \
sed -e 's/<\/description>//g' | \
perl $TOKENIZER -threads 8 -l $l > $tmp/$tok
echo ""
done
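# remove sentence pairs with a source/target length ratio above 1.5
# or with fewer than 1 or more than 175 tokens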
perl $CLEAN -ratio 1.5 $tmp/train.tags.$lang.tok $src $tgt $tmp/train.tags.$lang.clean 1 175
for l in $src $tgt; do
perl $LC < $tmp/train.tags.$lang.clean.$l > $tmp/train.tags.$lang.$l
done
echo "pre-processing valid/test data..."
for l in $src $tgt; do
for o in `ls $orig/$lang/IWSLT14.TED*.$l.xml`; do
fname=${o##*/}
f=$tmp/${fname%.*}
echo $o $f
grep '<seg id' $o | \
sed -e 's/<seg id="[0-9]*">\s*//g' | \
sed -e 's/\s*<\/seg>\s*//g' | \
sed -e "s/\’/\'/g" | \
perl $TOKENIZER -threads 8 -l $l | \
perl $LC > $f
echo ""
done
done
echo "creating train, valid, test..."
for l in $src $tgt; do
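    # hold out every 23rd sentence pair for the validation set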
awk '{if (NR%23 == 0) print $0; }' $tmp/train.tags.de-en.$l > $tmp/valid.$l
awk '{if (NR%23 != 0) print $0; }' $tmp/train.tags.de-en.$l > $tmp/train.$l
cat $tmp/IWSLT14.TED.dev2010.de-en.$l \
$tmp/IWSLT14.TEDX.dev2012.de-en.$l \
$tmp/IWSLT14.TED.tst2010.de-en.$l \
$tmp/IWSLT14.TED.tst2011.de-en.$l \
$tmp/IWSLT14.TED.tst2012.de-en.$l \
> $tmp/test.$l
done
TRAIN=$tmp/train.en-de
BPE_CODE=$prep/code
rm -f $TRAIN
for l in $src $tgt; do
cat $tmp/train.$l >> $TRAIN
done
echo "learn_bpe.py on ${TRAIN}..."
python $BPEROOT/learn_bpe.py -s $BPE_TOKENS < $TRAIN > $BPE_CODE
for L in $src $tgt; do
for f in train.$L valid.$L test.$L; do
echo "apply_bpe.py to ${f}..."
python $BPEROOT/apply_bpe.py -c $BPE_CODE < $tmp/$f > $prep/$f
done
done
#!/bin/bash
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
SRCS=(
"de"
"fr"
)
TGT=en
ROOT=$(dirname "$0")
SCRIPTS=$ROOT/../../scripts
SPM_TRAIN=$SCRIPTS/spm_train.py
SPM_ENCODE=$SCRIPTS/spm_encode.py
BPESIZE=16384
ORIG=$ROOT/iwslt17_orig
DATA=$ROOT/iwslt17.de_fr.en.bpe16k
mkdir -p "$ORIG" "$DATA"
TRAIN_MINLEN=1 # remove sentences with <1 BPE token
TRAIN_MAXLEN=250 # remove sentences with >250 BPE tokens
URLS=(
"https://wit3.fbk.eu/archive/2017-01-trnted/texts/de/en/de-en.tgz"
"https://wit3.fbk.eu/archive/2017-01-trnted/texts/fr/en/fr-en.tgz"
)
ARCHIVES=(
"de-en.tgz"
"fr-en.tgz"
)
VALID_SETS=(
"IWSLT17.TED.dev2010.de-en IWSLT17.TED.tst2010.de-en IWSLT17.TED.tst2011.de-en IWSLT17.TED.tst2012.de-en IWSLT17.TED.tst2013.de-en IWSLT17.TED.tst2014.de-en IWSLT17.TED.tst2015.de-en"
"IWSLT17.TED.dev2010.fr-en IWSLT17.TED.tst2010.fr-en IWSLT17.TED.tst2011.fr-en IWSLT17.TED.tst2012.fr-en IWSLT17.TED.tst2013.fr-en IWSLT17.TED.tst2014.fr-en IWSLT17.TED.tst2015.fr-en"
)
# download and extract data
for ((i=0;i<${#URLS[@]};++i)); do
ARCHIVE=$ORIG/${ARCHIVES[i]}
if [ -f "$ARCHIVE" ]; then
echo "$ARCHIVE already exists, skipping download"
else
URL=${URLS[i]}
wget -P "$ORIG" "$URL"
if [ -f "$ARCHIVE" ]; then
echo "$URL successfully downloaded."
else
echo "$URL not successfully downloaded."
exit 1
fi
fi
    FILE=${ARCHIVE%.tgz}  # directory created by extraction, e.g. $ORIG/de-en
if [ -e "$FILE" ]; then
echo "$FILE already exists, skipping extraction"
else
tar -C "$ORIG" -xzvf "$ARCHIVE"
fi
done
echo "pre-processing train data..."
for SRC in "${SRCS[@]}"; do
for LANG in "${SRC}" "${TGT}"; do
cat "$ORIG/${SRC}-${TGT}/train.tags.${SRC}-${TGT}.${LANG}" \
| grep -v '<url>' \
| grep -v '<talkid>' \
| grep -v '<keywords>' \
| grep -v '<speaker>' \
| grep -v '<reviewer' \
| grep -v '<translator' \
| grep -v '<doc' \
| grep -v '</doc>' \
| sed -e 's/<title>//g' \
| sed -e 's/<\/title>//g' \
| sed -e 's/<description>//g' \
| sed -e 's/<\/description>//g' \
| sed 's/^\s*//g' \
| sed 's/\s*$//g' \
> "$DATA/train.${SRC}-${TGT}.${LANG}"
done
done
echo "pre-processing valid data..."
for ((i=0;i<${#SRCS[@]};++i)); do
SRC=${SRCS[i]}
VALID_SET=(${VALID_SETS[i]})
for ((j=0;j<${#VALID_SET[@]};++j)); do
FILE=${VALID_SET[j]}
for LANG in "$SRC" "$TGT"; do
grep '<seg id' "$ORIG/${SRC}-${TGT}/${FILE}.${LANG}.xml" \
| sed -e 's/<seg id="[0-9]*">\s*//g' \
| sed -e 's/\s*<\/seg>\s*//g' \
| sed -e "s/\’/\'/g" \
> "$DATA/valid${j}.${SRC}-${TGT}.${LANG}"
done
done
done
# learn BPE with sentencepiece
TRAIN_FILES=$(for SRC in "${SRCS[@]}"; do echo $DATA/train.${SRC}-${TGT}.${SRC}; echo $DATA/train.${SRC}-${TGT}.${TGT}; done | tr "\n" ",")
echo "learning joint BPE over ${TRAIN_FILES}..."
python "$SPM_TRAIN" \
--input=$TRAIN_FILES \
--model_prefix=$DATA/sentencepiece.bpe \
--vocab_size=$BPESIZE \
--character_coverage=1.0 \
--model_type=bpe
# encode train/valid
echo "encoding train with learned BPE..."
for SRC in "${SRCS[@]}"; do
python "$SPM_ENCODE" \
--model "$DATA/sentencepiece.bpe.model" \
--output_format=piece \
--inputs $DATA/train.${SRC}-${TGT}.${SRC} $DATA/train.${SRC}-${TGT}.${TGT} \
--outputs $DATA/train.bpe.${SRC}-${TGT}.${SRC} $DATA/train.bpe.${SRC}-${TGT}.${TGT} \
--min-len $TRAIN_MINLEN --max-len $TRAIN_MAXLEN
done
echo "encoding valid with learned BPE..."
for ((i=0;i<${#SRCS[@]};++i)); do
SRC=${SRCS[i]}
VALID_SET=(${VALID_SETS[i]})
for ((j=0;j<${#VALID_SET[@]};++j)); do
python "$SPM_ENCODE" \
--model "$DATA/sentencepiece.bpe.model" \
--output_format=piece \
--inputs $DATA/valid${j}.${SRC}-${TGT}.${SRC} $DATA/valid${j}.${SRC}-${TGT}.${TGT} \
--outputs $DATA/valid${j}.bpe.${SRC}-${TGT}.${SRC} $DATA/valid${j}.bpe.${SRC}-${TGT}.${TGT}
done
done
#!/bin/bash
# Adapted from https://github.com/facebookresearch/MIXER/blob/master/prepareData.sh
echo 'Cloning Moses github repository (for tokenization scripts)...'
git clone https://github.com/moses-smt/mosesdecoder.git
echo 'Cloning Subword NMT repository (for BPE pre-processing)...'
git clone https://github.com/rsennrich/subword-nmt.git
SCRIPTS=mosesdecoder/scripts
TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
CLEAN=$SCRIPTS/training/clean-corpus-n.perl
NORM_PUNC=$SCRIPTS/tokenizer/normalize-punctuation.perl
REM_NON_PRINT_CHAR=$SCRIPTS/tokenizer/remove-non-printing-char.perl
BPEROOT=subword-nmt/subword_nmt
BPE_TOKENS=40000
URLS=(
"http://statmt.org/wmt13/training-parallel-europarl-v7.tgz"
"http://statmt.org/wmt13/training-parallel-commoncrawl.tgz"
"http://data.statmt.org/wmt17/translation-task/training-parallel-nc-v12.tgz"
"http://data.statmt.org/wmt17/translation-task/dev.tgz"
"http://statmt.org/wmt14/test-full.tgz"
)
FILES=(
"training-parallel-europarl-v7.tgz"
"training-parallel-commoncrawl.tgz"
"training-parallel-nc-v12.tgz"
"dev.tgz"
"test-full.tgz"
)
CORPORA=(
"training/europarl-v7.de-en"
"commoncrawl.de-en"
"training/news-commentary-v12.de-en"
)
# This will make the dataset compatible to the one used in "Convolutional Sequence to Sequence Learning"
# https://arxiv.org/abs/1705.03122
if [ "$1" == "--icml17" ]; then
URLS[2]="http://statmt.org/wmt14/training-parallel-nc-v9.tgz"
FILES[2]="training-parallel-nc-v9.tgz"
CORPORA[2]="training/news-commentary-v9.de-en"
OUTDIR=wmt14_en_de
else
OUTDIR=wmt17_en_de
fi
if [ ! -d "$SCRIPTS" ]; then
echo "Please set SCRIPTS variable correctly to point to Moses scripts."
    exit 1
fi
src=en
tgt=de
lang=en-de
prep=$OUTDIR
tmp=$prep/tmp
orig=orig
dev=dev/newstest2013
mkdir -p $orig $tmp $prep
cd $orig
for ((i=0;i<${#URLS[@]};++i)); do
file=${FILES[i]}
if [ -f $file ]; then
echo "$file already exists, skipping download"
else
url=${URLS[i]}
wget "$url"
if [ -f $file ]; then
echo "$url successfully downloaded."
else
echo "$url not successfully downloaded."
            exit 1
fi
if [ ${file: -4} == ".tgz" ]; then
tar zxvf $file
elif [ ${file: -4} == ".tar" ]; then
tar xvf $file
fi
fi
done
cd ..
echo "pre-processing train data..."
for l in $src $tgt; do
    rm -f $tmp/train.tags.$lang.tok.$l
for f in "${CORPORA[@]}"; do
cat $orig/$f.$l | \
perl $NORM_PUNC $l | \
perl $REM_NON_PRINT_CHAR | \
perl $TOKENIZER -threads 8 -a -l $l >> $tmp/train.tags.$lang.tok.$l
done
done
echo "pre-processing test data..."
for l in $src $tgt; do
if [ "$l" == "$src" ]; then
t="src"
else
t="ref"
fi
grep '<seg id' $orig/test-full/newstest2014-deen-$t.$l.sgm | \
sed -e 's/<seg id="[0-9]*">\s*//g' | \
sed -e 's/\s*<\/seg>\s*//g' | \
sed -e "s/\’/\'/g" | \
perl $TOKENIZER -threads 8 -a -l $l > $tmp/test.$l
echo ""
done
echo "splitting train and valid..."
for l in $src $tgt; do
awk '{if (NR%100 == 0) print $0; }' $tmp/train.tags.$lang.tok.$l > $tmp/valid.$l
awk '{if (NR%100 != 0) print $0; }' $tmp/train.tags.$lang.tok.$l > $tmp/train.$l
done
TRAIN=$tmp/train.de-en
BPE_CODE=$prep/code
rm -f $TRAIN
for l in $src $tgt; do
cat $tmp/train.$l >> $TRAIN
done
echo "learn_bpe.py on ${TRAIN}..."
python $BPEROOT/learn_bpe.py -s $BPE_TOKENS < $TRAIN > $BPE_CODE
for L in $src $tgt; do
for f in train.$L valid.$L test.$L; do
echo "apply_bpe.py to ${f}..."
python $BPEROOT/apply_bpe.py -c $BPE_CODE < $tmp/$f > $tmp/bpe.$f
done
done
perl $CLEAN -ratio 1.5 $tmp/bpe.train $src $tgt $prep/train 1 250
perl $CLEAN -ratio 1.5 $tmp/bpe.valid $src $tgt $prep/valid 1 250
for L in $src $tgt; do
cp $tmp/bpe.test.$L $prep/test.$L
done
#!/bin/bash
# Adapted from https://github.com/facebookresearch/MIXER/blob/master/prepareData.sh
echo 'Cloning Moses github repository (for tokenization scripts)...'
git clone https://github.com/moses-smt/mosesdecoder.git
echo 'Cloning Subword NMT repository (for BPE pre-processing)...'
git clone https://github.com/rsennrich/subword-nmt.git
SCRIPTS=mosesdecoder/scripts
TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
CLEAN=$SCRIPTS/training/clean-corpus-n.perl
NORM_PUNC=$SCRIPTS/tokenizer/normalize-punctuation.perl
REM_NON_PRINT_CHAR=$SCRIPTS/tokenizer/remove-non-printing-char.perl
BPEROOT=subword-nmt/subword_nmt
BPE_TOKENS=40000
URLS=(
"http://statmt.org/wmt13/training-parallel-europarl-v7.tgz"
"http://statmt.org/wmt13/training-parallel-commoncrawl.tgz"
"http://statmt.org/wmt13/training-parallel-un.tgz"
"http://statmt.org/wmt14/training-parallel-nc-v9.tgz"
"http://statmt.org/wmt10/training-giga-fren.tar"
"http://statmt.org/wmt14/test-full.tgz"
)
FILES=(
"training-parallel-europarl-v7.tgz"
"training-parallel-commoncrawl.tgz"
"training-parallel-un.tgz"
"training-parallel-nc-v9.tgz"
"training-giga-fren.tar"
"test-full.tgz"
)
CORPORA=(
"training/europarl-v7.fr-en"
"commoncrawl.fr-en"
"un/undoc.2000.fr-en"
"training/news-commentary-v9.fr-en"
"giga-fren.release2.fixed"
)
if [ ! -d "$SCRIPTS" ]; then
echo "Please set SCRIPTS variable correctly to point to Moses scripts."
    exit 1
fi
src=en
tgt=fr
lang=en-fr
prep=wmt14_en_fr
tmp=$prep/tmp
orig=orig
mkdir -p $orig $tmp $prep
cd $orig
for ((i=0;i<${#URLS[@]};++i)); do
file=${FILES[i]}
if [ -f $file ]; then
echo "$file already exists, skipping download"
else
url=${URLS[i]}
wget "$url"
if [ -f $file ]; then
echo "$url successfully downloaded."
else
echo "$url not successfully downloaded."
            exit 1
fi
if [ ${file: -4} == ".tgz" ]; then
tar zxvf $file
elif [ ${file: -4} == ".tar" ]; then
tar xvf $file
fi
fi
done
gunzip giga-fren.release2.fixed.*.gz
cd ..
echo "pre-processing train data..."
for l in $src $tgt; do
    rm -f $tmp/train.tags.$lang.tok.$l
for f in "${CORPORA[@]}"; do
cat $orig/$f.$l | \
perl $NORM_PUNC $l | \
perl $REM_NON_PRINT_CHAR | \
perl $TOKENIZER -threads 8 -a -l $l >> $tmp/train.tags.$lang.tok.$l
done
done
echo "pre-processing test data..."
for l in $src $tgt; do
if [ "$l" == "$src" ]; then
t="src"
else
t="ref"
fi
grep '<seg id' $orig/test-full/newstest2014-fren-$t.$l.sgm | \
sed -e 's/<seg id="[0-9]*">\s*//g' | \
sed -e 's/\s*<\/seg>\s*//g' | \
sed -e "s/\’/\'/g" | \
perl $TOKENIZER -threads 8 -a -l $l > $tmp/test.$l
echo ""
done
echo "splitting train and valid..."
for l in $src $tgt; do
awk '{if (NR%1333 == 0) print $0; }' $tmp/train.tags.$lang.tok.$l > $tmp/valid.$l
awk '{if (NR%1333 != 0) print $0; }' $tmp/train.tags.$lang.tok.$l > $tmp/train.$l
done
TRAIN=$tmp/train.fr-en
BPE_CODE=$prep/code
rm -f $TRAIN
for l in $src $tgt; do
cat $tmp/train.$l >> $TRAIN
done
echo "learn_bpe.py on ${TRAIN}..."
python $BPEROOT/learn_bpe.py -s $BPE_TOKENS < $TRAIN > $BPE_CODE
for L in $src $tgt; do
for f in train.$L valid.$L test.$L; do
echo "apply_bpe.py to ${f}..."
python $BPEROOT/apply_bpe.py -c $BPE_CODE < $tmp/$f > $tmp/bpe.$f
done
done
perl $CLEAN -ratio 1.5 $tmp/bpe.train $src $tgt $prep/train 1 250
perl $CLEAN -ratio 1.5 $tmp/bpe.valid $src $tgt $prep/valid 1 250
for L in $src $tgt; do
cp $tmp/bpe.test.$L $prep/test.$L
done
# Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)
This page includes instructions for reproducing results from the paper [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](https://arxiv.org/abs/1902.07816).
## Download data
First, follow the [instructions to download and preprocess the WMT'17 En-De dataset](../translation#prepare-wmt14en2desh).
Make sure to learn a joint vocabulary by passing the `--joined-dictionary` option to `fairseq-preprocess`.
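For example, after running the WMT'17 preprocessing script from the translation
README, binarization with a joined dictionary could look like this (a sketch;
paths follow that README):
```bash
TEXT=examples/translation/wmt17_en_de
fairseq-preprocess --source-lang en --target-lang de \
    --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
    --destdir data-bin/wmt17_en_de --joined-dictionary \
    --workers 20
```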
## Train a model
Then we can train a mixture of experts model using the `translation_moe` task.
Use the `--method` flag to choose the MoE variant; we support hard mixtures with a learned or uniform prior (`--method hMoElp` and `hMoEup`, respectively) and soft mixtures (`--method sMoElp` and `sMoEup`).
The model is trained with online responsibility assignment and shared parameterization.
The following command will train a `hMoElp` model with `3` experts:
```bash
fairseq-train --ddp-backend='no_c10d' \
data-bin/wmt17_en_de \
--max-update 100000 \
--task translation_moe --user-dir examples/translation_moe/translation_moe_src \
--method hMoElp --mean-pool-gating-network \
--num-experts 3 \
--arch transformer_wmt_en_de --share-all-embeddings \
--optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
--lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \
--lr 0.0007 --min-lr 1e-09 \
--dropout 0.1 --weight-decay 0.0 --criterion cross_entropy \
--max-tokens 3584
```
## Translate
Once a model is trained, we can generate translations from different experts using the `--gen-expert` option.
For example, to generate from expert 0:
```bash
fairseq-generate data-bin/wmt17_en_de \
--path checkpoints/checkpoint_best.pt \
--beam 1 --remove-bpe \
--task translation_moe --user-dir examples/translation_moe/translation_moe_src \
--method hMoElp --mean-pool-gating-network \
--num-experts 3 \
--gen-expert 0
```
## Evaluate
First download a tokenized version of the WMT'14 En-De test set with multiple references:
```bash
wget dl.fbaipublicfiles.com/fairseq/data/wmt14-en-de.extra_refs.tok
```
Next apply BPE on the fly and run generation for each expert:
```bash
BPE_CODE=examples/translation/wmt17_en_de/code
for EXPERT in $(seq 0 2); do \
cat wmt14-en-de.extra_refs.tok \
| grep ^S | cut -f 2 \
| fairseq-interactive data-bin/wmt17_en_de \
--path checkpoints/checkpoint_best.pt \
--beam 1 \
--bpe subword_nmt --bpe-codes $BPE_CODE \
--buffer-size 500 --max-tokens 6000 \
--task translation_moe --user-dir examples/translation_moe/translation_moe_src \
--method hMoElp --mean-pool-gating-network \
--num-experts 3 \
--gen-expert $EXPERT ; \
done > wmt14-en-de.extra_refs.tok.gen.3experts
```
Finally use `score.py` to compute pairwise BLEU and average oracle BLEU:
```bash
python examples/translation_moe/score.py --sys wmt14-en-de.extra_refs.tok.gen.3experts --ref wmt14-en-de.extra_refs.tok
# pairwise BLEU: 48.26
# #refs covered: 2.11
# multi-reference BLEU (leave-one-out): 59.46
```
This matches row 3 from Table 7 in the paper.
## Citation
```bibtex
@article{shen2019mixture,
title = {Mixture Models for Diverse Machine Translation: Tricks of the Trade},
author = {Tianxiao Shen and Myle Ott and Michael Auli and Marc'Aurelio Ranzato},
journal = {International Conference on Machine Learning},
year = 2019,
}
```
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Scoring script for computing pairwise BLEU and multi-ref BLEU over a set of
candidate hypotheses.
See `"Mixture Models for Diverse Machine Translation: Tricks of the Trade"
(Shen et al., 2019) <https://arxiv.org/abs/1902.07816>`_.
"""
import argparse
import random
import sys
from itertools import chain
import numpy as np
from sacrebleu import compute_bleu, corpus_bleu as _corpus_bleu
def main():
parser = argparse.ArgumentParser(sys.argv[0])
parser.add_argument(
"--sys", nargs="*", default="", metavar="FILE", help="path to system output"
)
parser.add_argument("--ref", default="", metavar="FILE", help="path to references")
parser.add_argument(
"--output",
default="",
metavar="FILE",
help="print outputs into a pretty format",
)
args = parser.parse_args()
if args.sys:
src, tgt, hypos, log_probs = load_sys(args.sys)
print("pairwise BLEU: %.2f" % pairwise(hypos))
if args.output:
merge(src, tgt, hypos, log_probs, args.output)
if args.ref:
_, _, refs = load_ref(args.ref)
if args.sys:
multi_ref(refs, hypos)
else:
intra_ref(refs)
def dictolist(d):
a = sorted(d.items(), key=lambda i: i[0])
return [i[1] for i in a]
def load_sys(paths):
src, tgt, hypos, log_probs = {}, {}, {}, {}
for path in paths:
with open(path) as f:
for line in f:
line = line.rstrip()
# S: source
# T: target
# D: detokenized system output
if line.startswith(("S-", "T-", "D-")):
i = int(line[line.find("-") + 1 : line.find("\t")])
if line.startswith("S-"):
src[i] = line.split("\t")[1]
if line.startswith("T-"):
tgt[i] = line.split("\t")[1]
if line.startswith("D-"):
if i not in hypos:
hypos[i] = []
log_probs[i] = []
hypos[i].append(line.split("\t")[2])
log_probs[i].append(float(line.split("\t")[1]))
return dictolist(src), dictolist(tgt), dictolist(hypos), dictolist(log_probs)
def load_ref(path):
with open(path) as f:
lines = f.readlines()
src, tgt, refs = [], [], []
i = 0
while i < len(lines):
if lines[i].startswith("S-"):
src.append(lines[i].split("\t")[1].rstrip())
i += 1
elif lines[i].startswith("T-"):
tgt.append(lines[i].split("\t")[1].rstrip())
i += 1
else:
a = []
while i < len(lines) and lines[i].startswith("R"):
a.append(lines[i].split("\t")[1].rstrip())
i += 1
refs.append(a)
return src, tgt, refs
def merge(src, tgt, hypos, log_probs, path):
with open(path, "w") as f:
for s, t, hs, lps in zip(src, tgt, hypos, log_probs):
f.write(s + "\n")
f.write(t + "\n")
f.write("\n")
for h, lp in zip(hs, lps):
f.write("\t%f\t%s\n" % (lp, h.strip()))
f.write("------------------------------------------------------\n")
def corpus_bleu(sys_stream, ref_streams):
bleu = _corpus_bleu(sys_stream, ref_streams, tokenize="none")
return bleu.score
def sentence_bleu(hypothesis, reference):
bleu = _corpus_bleu(hypothesis, reference)
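    # add-one smoothing of the 2-/3-/4-gram match and total counts so that a
    # single sentence with no higher-order matches does not zero out BLEU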
for i in range(1, 4):
bleu.counts[i] += 1
bleu.totals[i] += 1
bleu = compute_bleu(
bleu.counts,
bleu.totals,
bleu.sys_len,
bleu.ref_len,
smooth_method="exp",
)
return bleu.score
def pairwise(sents):
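    """Corpus BLEU of every hypothesis scored against every other hypothesis
    for the same source; lower pairwise BLEU indicates more diverse outputs."""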
_ref, _hypo = [], []
for s in sents:
for i in range(len(s)):
for j in range(len(s)):
if i != j:
_ref.append(s[i])
_hypo.append(s[j])
return corpus_bleu(_hypo, [_ref])
def multi_ref(refs, hypos):
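    """Match each hypothesis to its best reference by sentence BLEU, report the
    average number of distinct references covered per source, and compute a
    leave-one-out multi-reference corpus BLEU (comparable to intra_ref)."""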
_ref, _hypo = [], []
ref_cnt = 0
assert len(refs) == len(hypos)
# count number of refs covered
for rs, hs in zip(refs, hypos):
a = set()
for h in hs:
s = [sentence_bleu(h, r) for r in rs]
j = np.argmax(s)
_ref.append(rs[j])
_hypo.append(h)
best = [k for k in range(len(rs)) if s[k] == s[j]]
a.add(random.choice(best))
ref_cnt += len(a)
print("#refs covered: %.2f" % (ref_cnt / len(refs)))
# transpose refs and hypos
refs = list(zip(*refs))
hypos = list(zip(*hypos))
# compute multi-ref corpus BLEU (leave-one-out to be comparable to intra_ref)
k = len(hypos)
m = len(refs)
flat_hypos = [hypos[j][i] for i in range(len(hypos[0])) for j in range(k)]
duplicated_refs = [[ref for ref in refs_i for _ in range(k)] for refs_i in refs]
loo_bleus = []
for held_out_ref in range(m):
remaining_refs = (
duplicated_refs[:held_out_ref] + duplicated_refs[held_out_ref + 1 :]
)
assert len(remaining_refs) == m - 1
loo_bleus.append(corpus_bleu(flat_hypos, remaining_refs))
print("average multi-reference BLEU (leave-one-out): %.2f" % np.mean(loo_bleus))
def intra_ref(refs):
print("ref pairwise BLEU: %.2f" % pairwise(refs))
refs = list(zip(*refs))
m = len(refs)
concat_h = []
concat_rest = [[] for j in range(m - 1)]
for i, h in enumerate(refs):
rest = refs[:i] + refs[i + 1 :]
concat_h.append(h)
for j in range(m - 1):
concat_rest[j].extend(rest[j])
concat_h = list(chain.from_iterable(concat_h))
bleu = corpus_bleu(concat_h, concat_rest)
print("multi-reference BLEU (leave-one-out): %.2f" % bleu)
if __name__ == "__main__":
main()
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from . import translation_moe # noqa
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
class LogSumExpMoE(torch.autograd.Function):
"""Standard LogSumExp forward pass, but use *posterior* for the backward.
See `"Mixture Models for Diverse Machine Translation: Tricks of the Trade"
(Shen et al., 2019) <https://arxiv.org/abs/1902.07816>`_.
"""
@staticmethod
def forward(ctx, logp, posterior, dim=-1):
ctx.save_for_backward(posterior)
ctx.dim = dim
return torch.logsumexp(logp, dim=dim)
@staticmethod
def backward(ctx, grad_output):
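        # the exact gradient of logsumexp w.r.t. logp would be softmax(logp);
        # substituting the saved responsibilities p(z|x,y) gives the EM-style
        # update described in the paper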
(posterior,) = ctx.saved_tensors
grad_logp = grad_output.unsqueeze(ctx.dim) * posterior
return grad_logp, None, None
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn.functional as F
class MeanPoolGatingNetwork(torch.nn.Module):
"""A simple mean-pooling gating network for selecting experts.
This module applies mean pooling over an encoder's output and returns
responsibilities for each expert. The encoder format is expected to match
:class:`fairseq.models.transformer.TransformerEncoder`.
"""
def __init__(self, embed_dim, num_experts, dropout=None):
super().__init__()
self.embed_dim = embed_dim
self.num_experts = num_experts
self.fc1 = torch.nn.Linear(embed_dim, embed_dim)
self.dropout = torch.nn.Dropout(dropout) if dropout is not None else None
self.fc2 = torch.nn.Linear(embed_dim, num_experts)
def forward(self, encoder_out):
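        # encoder_out.encoder_out: T x B x C with C == embed_dim; returns
        # expert log-probabilities of shape B x num_experts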
if not (
hasattr(encoder_out, "encoder_out")
and hasattr(encoder_out, "encoder_padding_mask")
and encoder_out.encoder_out.size(2) == self.embed_dim
):
raise ValueError("Unexpected format for encoder_out")
# mean pooling over time
encoder_padding_mask = encoder_out.encoder_padding_mask # B x T
encoder_out = encoder_out.encoder_out.transpose(0, 1) # B x T x C
if encoder_padding_mask is not None:
encoder_out = encoder_out.clone() # required because of transpose above
encoder_out[encoder_padding_mask] = 0
ntokens = torch.sum(~encoder_padding_mask, dim=1, keepdim=True)
x = torch.sum(encoder_out, dim=1) / ntokens.type_as(encoder_out)
else:
x = torch.mean(encoder_out, dim=1)
x = torch.tanh(self.fc1(x))
if self.dropout is not None:
x = self.dropout(x)
x = self.fc2(x)
return F.log_softmax(x, dim=-1, dtype=torch.float32).type_as(x)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from fairseq import metrics, utils
from fairseq.tasks import register_task
from fairseq.tasks.translation import TranslationTask
from .logsumexp_moe import LogSumExpMoE
from .mean_pool_gating_network import MeanPoolGatingNetwork
@register_task("translation_moe")
class TranslationMoETask(TranslationTask):
"""
Translation task for Mixture of Experts (MoE) models.
See `"Mixture Models for Diverse Machine Translation: Tricks of the Trade"
(Shen et al., 2019) <https://arxiv.org/abs/1902.07816>`_.
Args:
src_dict (~fairseq.data.Dictionary): dictionary for the source language
tgt_dict (~fairseq.data.Dictionary): dictionary for the target language
.. note::
The translation task is compatible with :mod:`fairseq-train`,
:mod:`fairseq-generate` and :mod:`fairseq-interactive`.
The translation task provides the following additional command-line
arguments:
.. argparse::
:ref: fairseq.tasks.translation_parser
:prog:
"""
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
# fmt: off
TranslationTask.add_args(parser)
parser.add_argument('--method', default='hMoEup',
choices=['sMoElp', 'sMoEup', 'hMoElp', 'hMoEup'])
parser.add_argument('--num-experts', default=3, type=int, metavar='N',
help='number of experts')
parser.add_argument('--mean-pool-gating-network', action='store_true',
help='use a simple mean-pooling gating network')
parser.add_argument('--mean-pool-gating-network-dropout', type=float,
help='dropout for mean-pooling gating network')
parser.add_argument('--mean-pool-gating-network-encoder-dim', type=int,
help='encoder output dim for mean-pooling gating network')
parser.add_argument('--gen-expert', type=int, default=0,
help='which expert to use for generation')
# fmt: on
def __init__(self, args, src_dict, tgt_dict):
if args.method == "sMoElp":
# soft MoE with learned prior
self.uniform_prior = False
self.hard_selection = False
elif args.method == "sMoEup":
# soft MoE with uniform prior
self.uniform_prior = True
self.hard_selection = False
elif args.method == "hMoElp":
# hard MoE with learned prior
self.uniform_prior = False
self.hard_selection = True
elif args.method == "hMoEup":
# hard MoE with uniform prior
self.uniform_prior = True
self.hard_selection = True
# add indicator tokens for each expert
for i in range(args.num_experts):
# add to both dictionaries in case we're sharing embeddings
src_dict.add_symbol("<expert_{}>".format(i))
tgt_dict.add_symbol("<expert_{}>".format(i))
super().__init__(args, src_dict, tgt_dict)
def build_model(self, args):
from fairseq import models
model = models.build_model(args, self)
if not self.uniform_prior and not hasattr(model, "gating_network"):
if self.args.mean_pool_gating_network:
if getattr(args, "mean_pool_gating_network_encoder_dim", None):
encoder_dim = args.mean_pool_gating_network_encoder_dim
elif getattr(args, "encoder_embed_dim", None):
# assume that encoder_embed_dim is the encoder's output dimension
encoder_dim = args.encoder_embed_dim
else:
raise ValueError(
"Must specify --mean-pool-gating-network-encoder-dim"
)
if getattr(args, "mean_pool_gating_network_dropout", None):
dropout = args.mean_pool_gating_network_dropout
elif getattr(args, "dropout", None):
dropout = args.dropout
else:
raise ValueError("Must specify --mean-pool-gating-network-dropout")
model.gating_network = MeanPoolGatingNetwork(
encoder_dim,
args.num_experts,
dropout,
)
else:
raise ValueError(
"translation_moe task with learned prior requires the model to "
"have a gating network; try using --mean-pool-gating-network"
)
return model
def expert_index(self, i):
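        # expert tokens were appended to the dictionary in __init__, so expert
        # i maps to the index of its <expert_i> symbol (used as a forced BOS)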
return i + self.tgt_dict.index("<expert_0>")
def _get_loss(self, sample, model, criterion):
assert hasattr(
criterion, "compute_loss"
), "translation_moe task requires the criterion to implement the compute_loss() method"
k = self.args.num_experts
bsz = sample["target"].size(0)
def get_lprob_y(encoder_out, prev_output_tokens_k):
net_output = model.decoder(
prev_output_tokens=prev_output_tokens_k,
encoder_out=encoder_out,
)
loss, _ = criterion.compute_loss(model, net_output, sample, reduce=False)
loss = loss.view(bsz, -1)
return -loss.sum(dim=1, keepdim=True) # -> B x 1
def get_lprob_yz(winners=None):
encoder_out = model.encoder(
src_tokens=sample["net_input"]["src_tokens"],
src_lengths=sample["net_input"]["src_lengths"],
)
if winners is None:
lprob_y = []
for i in range(k):
prev_output_tokens_k = sample["net_input"][
"prev_output_tokens"
].clone()
assert not prev_output_tokens_k.requires_grad
prev_output_tokens_k[:, 0] = self.expert_index(i)
lprob_y.append(get_lprob_y(encoder_out, prev_output_tokens_k))
lprob_y = torch.cat(lprob_y, dim=1) # -> B x K
else:
prev_output_tokens_k = sample["net_input"]["prev_output_tokens"].clone()
prev_output_tokens_k[:, 0] = self.expert_index(winners)
lprob_y = get_lprob_y(encoder_out, prev_output_tokens_k)  # -> B x 1
if self.uniform_prior:
lprob_yz = lprob_y
else:
lprob_z = model.gating_network(encoder_out) # B x K
if winners is not None:
lprob_z = lprob_z.gather(dim=1, index=winners.unsqueeze(-1))
lprob_yz = lprob_y + lprob_z.type_as(lprob_y) # B x K
return lprob_yz
# compute responsibilities without dropout
with utils.model_eval(model): # disable dropout
with torch.no_grad(): # disable autograd
lprob_yz = get_lprob_yz() # B x K
prob_z_xy = torch.nn.functional.softmax(lprob_yz, dim=1)
assert not prob_z_xy.requires_grad
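        # E-step: the responsibilities p(z|x,y) above are held fixed; the
        # M-step below maximizes the (hard or soft) expected log-likelihood
        # under them, this time with dropout enabled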
# compute loss with dropout
if self.hard_selection:
winners = prob_z_xy.max(dim=1)[1]
loss = -get_lprob_yz(winners)
else:
lprob_yz = get_lprob_yz() # B x K
loss = -LogSumExpMoE.apply(lprob_yz, prob_z_xy, 1)
loss = loss.sum()
sample_size = (
sample["target"].size(0) if self.args.sentence_avg else sample["ntokens"]
)
logging_output = {
"loss": utils.item(loss.data),
"ntokens": sample["ntokens"],
"nsentences": bsz,
"sample_size": sample_size,
"posterior": prob_z_xy.float().sum(dim=0).cpu(),
}
return loss, sample_size, logging_output
def train_step(
self, sample, model, criterion, optimizer, update_num, ignore_grad=False
):
model.train()
loss, sample_size, logging_output = self._get_loss(sample, model, criterion)
if ignore_grad:
loss *= 0
optimizer.backward(loss)
return loss, sample_size, logging_output
def valid_step(self, sample, model, criterion):
model.eval()
with torch.no_grad():
loss, sample_size, logging_output = self._get_loss(sample, model, criterion)
return loss, sample_size, logging_output
def inference_step(
self,
generator,
models,
sample,
prefix_tokens=None,
expert=None,
constraints=None,
):
expert = expert if expert is not None else self.args.gen_expert
with torch.no_grad():
return generator.generate(
models,
sample,
prefix_tokens=prefix_tokens,
constraints=constraints,
bos_token=self.expert_index(expert),
)
def reduce_metrics(self, logging_outputs, criterion):
super().reduce_metrics(logging_outputs, criterion)
metrics.log_scalar(
"posterior",
sum(log["posterior"] for log in logging_outputs if "posterior" in log),
)