Commit 18d27e00 authored by wangwei990215's avatar wangwei990215
Browse files

initial commit

parent 541f4c7a
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq.models import (
FairseqEncoder,
FairseqEncoderModel,
register_model,
register_model_architecture,
)
from fairseq.modules.fairseq_dropout import FairseqDropout
# Default encoder layout, stored as a string because build_model() passes the
# configuration through eval(). Each tuple describes one conv layer as
# (out_channels, kernel_size, padding, dropout).
default_conv_enc_config = """[
(400, 13, 170, 0.2),
(440, 14, 0, 0.214),
(484, 15, 0, 0.22898),
(532, 16, 0, 0.2450086),
(584, 17, 0, 0.262159202),
(642, 18, 0, 0.28051034614),
(706, 19, 0, 0.30014607037),
(776, 20, 0, 0.321156295296),
(852, 21, 0, 0.343637235966),
(936, 22, 0, 0.367691842484),
(1028, 23, 0, 0.393430271458),
(1130, 24, 0, 0.42097039046),
(1242, 25, 0, 0.450438317792),
(1366, 26, 0, 0.481969000038),
(1502, 27, 0, 0.51570683004),
(1652, 28, 0, 0.551806308143),
(1816, 29, 0, 0.590432749713),
]"""
@register_model("asr_w2l_conv_glu_encoder")
class W2lConvGluEncoderModel(FairseqEncoderModel):
    """Encoder-only model that wraps a :class:`W2lConvGluEncoder`."""

    def __init__(self, encoder):
        super().__init__(encoder)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        integer_options = (
            ("--input-feat-per-channel", "encoder input dimension per input channel"),
            ("--in-channels", "number of encoder input channels"),
        )
        for flag, description in integer_options:
            parser.add_argument(flag, type=int, metavar="N", help=description)
        parser.add_argument(
            "--conv-enc-config",
            type=str,
            metavar="EXPR",
            help=(
                "\nan array of tuples each containing the configuration of one conv layer\n"
                "[(out_channels, kernel_size, padding, dropout), ...]\n"
            ),
        )

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance.

        The layer configuration is taken from ``args.conv_enc_config`` when
        that attribute exists, otherwise from ``default_conv_enc_config``.
        """
        conv_enc_config = getattr(args, "conv_enc_config", default_conv_enc_config)
        return cls(
            W2lConvGluEncoder(
                vocab_size=len(task.target_dictionary),
                input_feat_per_channel=args.input_feat_per_channel,
                in_channels=args.in_channels,
                # NOTE(review): eval() executes the option string as Python
                # code; only pass configuration strings from trusted sources.
                conv_enc_config=eval(conv_enc_config),
            )
        )

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        """Return the parent's normalized probabilities tagged as time-major."""
        normalized = super().get_normalized_probs(net_output, log_probs, sample)
        # Mark the result as not batch-first for downstream consumers.
        normalized.batch_first = False
        return normalized
class W2lConvGluEncoder(FairseqEncoder):
    """Stack of weight-normalised 1-D convolutions with GLU activations.

    The convolution stack is followed by two linear layers; the last one
    projects to ``vocab_size``.

    Args:
        vocab_size: size of the output vocabulary (width of the final layer).
        input_feat_per_channel: number of features per input frame; used as
            the channel count of the first convolution.
        in_channels: number of input channels; only ``1`` is supported.
        conv_enc_config: iterable of
            ``(out_channels, kernel_size, padding, dropout)`` tuples, one per
            conv layer. Each ``out_channels`` must be even because GLU halves
            the channel dimension.

    Raises:
        ValueError: if ``in_channels != 1`` or a layer has an odd
            ``out_channels``.
    """

    def __init__(
        self, vocab_size, input_feat_per_channel, in_channels, conv_enc_config
    ):
        super().__init__(None)
        self.input_dim = input_feat_per_channel
        if in_channels != 1:
            raise ValueError("only 1 input channel is currently supported")

        self.conv_layers = nn.ModuleList()
        self.linear_layers = nn.ModuleList()
        # Kept in a ModuleList rather than a plain list so the dropout modules
        # are registered submodules: a plain list hides them from
        # train()/eval(), which would leave them in training mode after
        # model.eval(). Dropout has no parameters, so state_dict keys are
        # unaffected.
        self.dropouts = nn.ModuleList()

        cur_channels = input_feat_per_channel
        for out_channels, kernel_size, padding, dropout in conv_enc_config:
            if out_channels % 2 != 0:
                raise ValueError("odd # of out_channels is incompatible with GLU")
            layer = nn.Conv1d(cur_channels, out_channels, kernel_size, padding=padding)
            layer.weight.data.mul_(math.sqrt(3))  # match wav2letter init
            self.conv_layers.append(nn.utils.weight_norm(layer))
            self.dropouts.append(
                FairseqDropout(dropout, module_name=self.__class__.__name__)
            )
            cur_channels = out_channels // 2  # halved by GLU

        # Two linear layers: a GLU-gated hidden layer, then the projection to
        # the vocabulary.
        for out_channels in [2 * cur_channels, vocab_size]:
            layer = nn.Linear(cur_channels, out_channels)
            layer.weight.data.mul_(math.sqrt(3))
            self.linear_layers.append(nn.utils.weight_norm(layer))
            cur_channels = out_channels // 2

    def forward(self, src_tokens, src_lengths, **kwargs):
        """Run the encoder.

        Args:
            src_tokens: padded tensor (B, T, C * feat)
            src_lengths: tensor of original lengths of input utterances (B,)

        Returns:
            dict with ``encoder_out`` of shape (T, B, vocab_size) and
            ``encoder_padding_mask`` of shape (T, B), true at padded frames.
        """
        B, T, _ = src_tokens.size()
        x = src_tokens.transpose(1, 2).contiguous()  # (B, feat, T) assuming C == 1
        for conv, dropout in zip(self.conv_layers, self.dropouts):
            x = dropout(F.glu(conv(x), dim=1))

        x = x.transpose(1, 2).contiguous()  # (B, T, channels)
        x = self.linear_layers[0](x)
        x = F.glu(x, dim=2)
        # The dropout of the last conv layer is reused after the first linear
        # layer.
        x = self.dropouts[-1](x)
        x = self.linear_layers[1](x)

        # The conv stack must preserve both batch size and sequence length.
        assert x.size(0) == B
        assert x.size(1) == T

        encoder_out = x.transpose(0, 1)  # (T, B, vocab_size)

        # Frame t of utterance b is padding when t >= src_lengths[b];
        # broadcasting (1, T) against (B, 1) gives the (B, T) mask directly.
        positions = torch.arange(T, device=x.device).view(1, T)
        encoder_padding_mask = (positions >= src_lengths.view(B, 1)).t()  # (T, B)

        return {
            "encoder_out": encoder_out,  # (T, B, vocab_size)
            "encoder_padding_mask": encoder_padding_mask,  # (T, B)
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        """Reorder the batch dimension (dim 1) of both outputs in place."""
        encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select(
            1, new_order
        )
        encoder_out["encoder_padding_mask"] = encoder_out[
            "encoder_padding_mask"
        ].index_select(1, new_order)
        return encoder_out

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return (1e6, 1e6)  # an arbitrary large number
@register_model_architecture("asr_w2l_conv_glu_encoder", "w2l_conv_glu_enc")
def w2l_conv_glu_enc(args):
    """Fill in a default for every architecture option missing from ``args``."""
    fallbacks = (
        ("input_feat_per_channel", 80),
        ("in_channels", 1),
        ("conv_enc_config", default_conv_enc_config),
    )
    for option, fallback in fallbacks:
        # Values already present on args always win over the fallback.
        setattr(args, option, getattr(args, option, fallback))
import importlib
import os
# Import every public Python module that sits next to this file, presumably so
# that registration decorators in those modules run at package import time.
for file in os.listdir(os.path.dirname(__file__)):
    if file.endswith(".py") and not file.startswith("_"):
        # Drop only the trailing ".py"; str.find() would cut at the first
        # ".py" occurrence and truncate names that contain it earlier.
        task_name = file[: -len(".py")]
        importlib.import_module("examples.speech_recognition.tasks." + task_name)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import json
import os
import re
import sys
import torch
from examples.speech_recognition.data import AsrDataset
from examples.speech_recognition.data.replabels import replabel_symbol
from fairseq.data import Dictionary
from fairseq.tasks import LegacyFairseqTask, register_task
def get_asr_dataset_from_json(data_json_path, tgt_dict):
    r"""
    Parse a data json file and build an ``AsrDataset`` from it.

    Utterances are ordered by decreasing ``length_ms``. The speaker of each
    utterance is derived from the first two dash-separated fields of its id,
    and ``tgt_dict.eos()`` is appended to every target sequence.

    See scripts/asr_prep_json.py which pack json from raw files
    Json example:
    {
        "utts": {
            "4771-29403-0025": {
                "input": {
                    "length_ms": 170,
                    "path": "/tmp/file1.flac"
                },
                "output": {
                    "text": "HELLO \n",
                    "token": "HE LLO",
                    "tokenid": "4815, 861"
                }
            },
            "1564-142299-0096": {
                ...
            }
        }
    }

    Raises:
        FileNotFoundError: if ``data_json_path`` is not an existing file.
    """
    if not os.path.isfile(data_json_path):
        raise FileNotFoundError("Dataset not found: {}".format(data_json_path))

    with open(data_json_path, "rb") as json_file:
        data_samples = json.load(json_file)["utts"]
    assert len(data_samples) != 0

    # Longest utterances first.
    ordered = sorted(
        data_samples.items(),
        key=lambda item: int(item[1]["input"]["length_ms"]),
        reverse=True,
    )

    aud_paths = [info["input"]["path"] for _, info in ordered]
    ids = [utt_id for utt_id, _ in ordered]

    speakers = []
    for utt_id, _ in ordered:
        id_match = re.search("(.+?)-(.+?)-(.+?)", utt_id)
        speakers.append("_".join(id_match.group(1, 2)))

    frame_sizes = [info["input"]["length_ms"] for _, info in ordered]

    token_ids = [
        [int(piece) for piece in info["output"]["tokenid"].split(", ")]
        for _, info in ordered
    ]
    # append eos
    tgt = [seq + [tgt_dict.eos()] for seq in token_ids]

    return AsrDataset(aud_paths, frame_sizes, tgt, tgt_dict, ids, speakers)
@register_task("speech_recognition")
class SpeechRecognitionTask(LegacyFairseqTask):
    """
    Task for training speech recognition model.

    Datasets are read from ``<data>/<split>.json`` and the target dictionary
    from ``<data>/dict.txt``.
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        parser.add_argument("data", help="path to data directory")
        parser.add_argument(
            "--silence-token", default="\u2581", help="token for silence (used by w2l)"
        )
        length_options = (
            (
                "--max-source-positions",
                sys.maxsize,
                "max number of frames in the source sequence",
            ),
            (
                "--max-target-positions",
                1024,
                "max number of tokens in the target sequence",
            ),
        )
        for flag, default, description in length_options:
            parser.add_argument(
                flag, default=default, type=int, metavar="N", help=description
            )

    def __init__(self, args, tgt_dict):
        super().__init__(args)
        self.tgt_dict = tgt_dict

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Raises:
            FileNotFoundError: if ``dict.txt`` is missing from ``args.data``.
        """
        dict_path = os.path.join(args.data, "dict.txt")
        if not os.path.isfile(dict_path):
            raise FileNotFoundError("Dict not found: {}".format(dict_path))
        tgt_dict = Dictionary.load(dict_path)

        # Extend the dictionary with the symbols the chosen criterion needs.
        criterion = args.criterion
        if criterion == "ctc_loss":
            tgt_dict.add_symbol("<ctc_blank>")
        elif criterion == "asg_loss":
            for repeat in range(1, args.max_replabel + 1):
                tgt_dict.add_symbol(replabel_symbol(repeat))

        print("| dictionary: {} types".format(len(tgt_dict)))
        return cls(args, tgt_dict)

    def load_dataset(self, split, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        json_path = os.path.join(self.args.data, "{}.json".format(split))
        self.datasets[split] = get_asr_dataset_from_json(json_path, self.tgt_dict)

    def build_generator(self, models, args, **unused):
        """Return the decoder selected by ``args.w2l_decoder``.

        Falls back to the parent implementation when no known wav2letter
        decoder is requested. Decoder classes are imported lazily.
        """
        w2l_decoder = getattr(args, "w2l_decoder", None)
        if w2l_decoder == "viterbi":
            from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder

            return W2lViterbiDecoder(args, self.target_dictionary)
        if w2l_decoder == "kenlm":
            from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder

            return W2lKenLMDecoder(args, self.target_dictionary)
        if w2l_decoder == "fairseqlm":
            from examples.speech_recognition.w2l_decoder import W2lFairseqLMDecoder

            return W2lFairseqLMDecoder(args, self.target_dictionary)
        return super().build_generator(models, args)

    @property
    def target_dictionary(self):
        """Return the :class:`~fairseq.data.Dictionary` for the language
        model."""
        return self.tgt_dict

    @property
    def source_dictionary(self):
        """Return the source :class:`~fairseq.data.Dictionary` (if applicable
        for this task)."""
        return None

    def max_positions(self):
        """Return the max speech and sentence length allowed by the task."""
        task_args = self.args
        return (task_args.max_source_positions, task_args.max_target_positions)
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import absolute_import, division, print_function, unicode_literals
import re
from collections import deque
from enum import Enum
import numpy as np
"""
Utility modules for computation of Word Error Rate,
Alignments, as well as more granular metrics like
deletion, insertion and substitutions.
"""
class Code(Enum):
    """Edit operation assigned to one step of an alignment."""

    match = 1
    substitution = 2
    insertion = 3
    deletion = 4
class Token(object):
    """A labelled span with a start and an end time.

    When ``st`` is NaN (the default) the token is blank: the label is
    discarded and both times are set to 0.0, whatever ``lbl`` and ``en`` were.
    """

    def __init__(self, lbl="", st=np.nan, en=np.nan):
        has_start = not np.isnan(st)
        self.label = lbl if has_start else ""
        self.start = st if has_start else 0.0
        self.end = en if has_start else 0.0
class AlignmentResult(object):
    """Container for the outcome of one alignment.

    Attributes are stored exactly as given: ``refs`` and ``hyps`` hold token
    indices, ``codes`` the per-step edit codes, and ``score`` the total cost.
    """

    def __init__(self, refs, hyps, codes, score):
        self.refs, self.hyps = refs, hyps
        self.codes, self.score = codes, score
def coordinate_to_offset(row, col, ncols):
    """Flatten a (row, col) cell into a row-major offset for ``ncols`` columns."""
    flat = row * ncols + col
    return int(flat)
def offset_to_row(offset, ncols):
    """Return the row of a row-major ``offset`` in a grid with ``ncols`` columns.

    Floor division is used instead of true division so that large integer
    offsets are not rounded through a float, and so the result pairs with the
    ``%`` used by ``offset_to_col``. Float offsets are still accepted.
    """
    return int(offset // ncols)
def offset_to_col(offset, ncols):
    """Return the column of a row-major ``offset`` in a grid with ``ncols`` columns."""
    remainder = offset % ncols
    return int(remainder)
def trimWhitespace(str):
    """Strip leading/trailing spaces and collapse runs of spaces to one.

    Only the space character is affected; other whitespace is left alone.
    """
    without_leading = re.sub("^ *", "", str)
    without_trailing = re.sub(" *$", "", without_leading)
    return re.sub(" +", " ", without_trailing)
def str2toks(str):
    """Split a space-separated string into ``Token`` objects with zero times."""
    return [Token(piece, 0.0, 0.0) for piece in trimWhitespace(str).split(" ")]
class EditDistance(object):
    """Dynamic-programming alignment between reference and hypothesis tokens.

    With ``time_mediated`` set, costs are derived from the tokens' start/end
    times; otherwise fixed costs are used (match 0, insertion/deletion 3,
    substitution 4). Substitution pairs seen while backtracking are counted
    in ``confusion_pairs_``.
    """

    def __init__(self, time_mediated):
        self.time_mediated_ = time_mediated
        # Placeholders until align() allocates the real matrices.
        self.scores_ = np.nan
        self.backtraces_ = np.nan
        self.confusion_pairs_ = {}

    def cost(self, ref, hyp, code):
        """Return the cost of applying ``code`` to the ``ref``/``hyp`` pair.

        ``ref`` is not read for insertions and ``hyp`` is not read for
        deletions, so callers may pass ``None`` there.
        """
        if not self.time_mediated_:
            if code == Code.match:
                return 0
            if code == Code.insertion or code == Code.deletion:
                return 3
            return 4  # substitution

        if code == Code.match:
            return abs(ref.start - hyp.start) + abs(ref.end - hyp.end)
        if code == Code.insertion:
            return hyp.end - hyp.start
        if code == Code.deletion:
            return ref.end - ref.start
        # substitution
        return abs(ref.start - hyp.start) + abs(ref.end - hyp.end) + 0.1

    def get_result(self, refs, hyps):
        """Backtrack from the bottom-right cell and build the alignment.

        Must be called after the matrices were filled by :meth:`align`.
        Updates ``confusion_pairs_`` for every substitution on the path.
        """
        n_rows, n_cols = self.scores_.shape
        result = AlignmentResult(
            refs=deque(),
            hyps=deque(),
            codes=deque(),
            score=self.scores_[n_rows - 1, n_cols - 1],
        )

        offset = coordinate_to_offset(n_rows - 1, n_cols - 1, n_cols)
        while offset != 0:
            row = offset_to_row(offset, n_cols)
            col = offset_to_col(offset, n_cols)
            parent = self.backtraces_[row, col]
            parent_row = offset_to_row(parent, n_cols)
            parent_col = offset_to_col(parent, n_cols)

            # The path is walked backwards, so prepend to keep forward order.
            result.refs.appendleft(row - 1)
            result.hyps.appendleft(col - 1)

            step = (row - parent_row, col - parent_col)
            if step == (1, 0):
                result.codes.appendleft(Code.deletion)
            elif step == (0, 1):
                result.codes.appendleft(Code.insertion)
            else:
                # Diagonal step: either a match or a substitution.
                ref_label = refs[result.refs[0]].label
                hyp_label = hyps[result.hyps[0]].label
                if ref_label == hyp_label:
                    result.codes.appendleft(Code.match)
                else:
                    result.codes.appendleft(Code.substitution)
                    pair = "%s -> %s" % (ref_label, hyp_label)
                    self.confusion_pairs_[pair] = (
                        self.confusion_pairs_.get(pair, 0) + 1
                    )

            offset = parent

        return result

    def align(self, refs, hyps):
        """Align ``refs`` against ``hyps`` and return an ``AlignmentResult``.

        Returns ``np.nan`` when both sequences are empty. On ties the
        diagonal move is preferred, then insertion, then deletion.
        """
        if len(refs) == 0 and len(hyps) == 0:
            return np.nan

        shape = (len(refs) + 1, len(hyps) + 1)
        self.scores_ = np.zeros(shape)
        self.backtraces_ = np.zeros(shape)
        n_rows, n_cols = self.scores_.shape

        for i in range(n_rows):
            for j in range(n_cols):
                if i == 0 and j == 0:
                    score, parent = 0.0, 0
                elif i == 0:
                    # First row: only insertions are possible.
                    score = self.scores_[i, j - 1] + self.cost(
                        None, hyps[j - 1], Code.insertion
                    )
                    parent = coordinate_to_offset(i, j - 1, n_cols)
                elif j == 0:
                    # First column: only deletions are possible.
                    score = self.scores_[i - 1, j] + self.cost(
                        refs[i - 1], None, Code.deletion
                    )
                    parent = coordinate_to_offset(i - 1, j, n_cols)
                else:
                    ref = refs[i - 1]
                    hyp = hyps[j - 1]
                    diagonal = (
                        Code.match if ref.label == hyp.label else Code.substitution
                    )
                    # min() keeps the first of equal scores, which gives the
                    # diagonal > insertion > deletion preference on ties.
                    candidates = (
                        (
                            self.scores_[i - 1, j - 1] + self.cost(ref, hyp, diagonal),
                            i - 1,
                            j - 1,
                        ),
                        (
                            self.scores_[i, j - 1]
                            + self.cost(None, hyp, Code.insertion),
                            i,
                            j - 1,
                        ),
                        (
                            self.scores_[i - 1, j]
                            + self.cost(ref, None, Code.deletion),
                            i - 1,
                            j,
                        ),
                    )
                    score, from_row, from_col = min(candidates, key=lambda c: c[0])
                    parent = coordinate_to_offset(from_row, from_col, n_cols)

                self.scores_[i, j] = score
                self.backtraces_[i, j] = parent

        return self.get_result(refs, hyps)
class WERTransformer(object):
    """Accumulate word-error statistics for hypothesis/reference pairs.

    The constructor immediately scores ``hyp_str`` against ``ref_str`` and,
    when ``verbose`` is truthy, prints both strings and a summary line.
    """

    def __init__(self, hyp_str, ref_str, verbose=True):
        self.ed_ = EditDistance(False)
        self.id2oracle_errs_ = {}
        self.utts_ = 0
        self.words_ = 0
        self.insertions_ = 0
        self.deletions_ = 0
        self.substitutions_ = 0

        self.process(["dummy_str", hyp_str, ref_str])

        if verbose:
            print("'%s' vs '%s'" % (hyp_str, ref_str))
            self.report_result()

    def process(self, input):  # std::vector<std::string>&& input
        """Score one ``[<id>, ..., <hypo>, <ref>]`` row and update the totals.

        The hypothesis and reference are the last two entries. The row is
        extended in place with the reference length and the insertion,
        deletion and substitution counts (as strings). Totals are only
        updated the first time an id is seen; later rows with the same id
        only replace the stored per-id errors when they have fewer errors.

        Returns ``None`` for a too-short row, ``np.nan`` on alignment or id
        failures, and ``0`` on success.
        """
        if len(input) < 3:
            print(
                "Input must be of the form <id> ... <hypo> <ref> , got ",
                len(input),
                " inputs:",
            )
            return None

        # Align
        hyps = str2toks(input[-2])
        refs = str2toks(input[-1])
        alignment = self.ed_.align(refs, hyps)
        if alignment is None:
            print("Alignment is null")
            return np.nan

        # Tally errors
        tally = {Code.substitution: 0, Code.insertion: 0, Code.deletion: 0}
        for code in alignment.codes:
            if code in tally:
                tally[code] += 1
        ins = tally[Code.insertion]
        dels = tally[Code.deletion]
        subs = tally[Code.substitution]

        # Output: the statistics are appended to the caller's row.
        input.extend([str(len(refs)), str(ins), str(dels), str(subs)])

        # Accumulate
        kIdIndex = 0
        kNBestSep = "/"
        pieces = input[kIdIndex].split(kNBestSep)
        if len(pieces) == 0:
            print(
                "Error splitting ",
                input[kIdIndex],
                " on '",
                kNBestSep,
                "', got empty list",
            )
            return np.nan

        utt_id = pieces[0]
        if utt_id in self.id2oracle_errs_:
            # Same utterance seen again: keep whichever has fewer errors.
            curr_err = ins + dels + subs
            prev_err = np.sum(self.id2oracle_errs_[utt_id])
            if curr_err < prev_err:
                self.id2oracle_errs_[utt_id] = [ins, dels, subs]
        else:
            self.utts_ += 1
            self.words_ += len(refs)
            self.insertions_ += ins
            self.deletions_ += dels
            self.substitutions_ += subs
            self.id2oracle_errs_[utt_id] = [ins, dels, subs]

        return 0

    def report_result(self):
        """Print a one-line WER summary, or a notice when no words were seen."""
        if self.words_ == 0:
            print("No words counted")
            return

        # 1-best
        best_wer = (
            100.0
            * (self.insertions_ + self.deletions_ + self.substitutions_)
            / self.words_
        )
        summary = (
            best_wer,
            self.utts_,
            self.words_,
            100.0 * self.insertions_ / self.words_,
            100.0 * self.deletions_ / self.words_,
            100.0 * self.substitutions_ / self.words_,
        )
        print(
            "\tWER = %0.2f%% (%i utts, %i words, %0.2f%% ins, "
            "%0.2f%% dels, %0.2f%% subs)" % summary
        )

    def wer(self):
        """Return the WER in percent, or ``np.nan`` when no words were counted."""
        if self.words_ == 0:
            return np.nan
        return (
            100.0
            * (self.insertions_ + self.deletions_ + self.substitutions_)
            / self.words_
        )

    def stats(self):
        """Return a dict of accumulated statistics (empty when no words were counted)."""
        if self.words_ == 0:
            return {}
        wer = (
            100.0
            * (self.insertions_ + self.deletions_ + self.substitutions_)
            / self.words_
        )
        return {
            "wer": wer,
            "utts": self.utts_,
            "numwords": self.words_,
            "ins": self.insertions_,
            "dels": self.deletions_,
            "subs": self.substitutions_,
            "confusion_pairs": self.ed_.confusion_pairs_,
        }
def calc_wer(hyp_str, ref_str):
    """Return the word error rate (in percent) of ``hyp_str`` against ``ref_str``."""
    return WERTransformer(hyp_str, ref_str, verbose=0).wer()
def calc_wer_stats(hyp_str, ref_str):
    """Return the statistics dict of ``hyp_str`` scored against ``ref_str``."""
    return WERTransformer(hyp_str, ref_str, verbose=0).stats()
def get_wer_alignment_codes(hyp_str, ref_str):
    """
    INPUT: hypothesis string, reference string
    OUTPUT: List of alignment codes (intermediate results from WER computation)
    """
    # Align once with a fresh, non-time-mediated EditDistance. Building a
    # WERTransformer first would run the same alignment a second time only to
    # throw its totals away.
    return EditDistance(False).align(str2toks(ref_str), str2toks(hyp_str)).codes
def merge_counts(x, y):
    """Add the counts in ``y`` onto ``x`` in place and return ``x``.

    Both arguments map keys to counts; keys missing from ``x`` start at 0.
    This can be used for example to merge confusion pair counts:
    conf_pairs = merge_counts(conf_pairs, stats['confusion_pairs'])
    """
    for key, count in y.items():
        x[key] = x.get(key, 0) + count
    return x
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Wav2letter decoders.
"""
import gc
import itertools as it
import os.path as osp
import warnings
from collections import deque, namedtuple
import numpy as np
import torch
from examples.speech_recognition.data.replabels import unpack_replabels
from fairseq import tasks
from fairseq.utils import apply_to_sample
# The wav2letter bindings are optional: without them this module can still be
# imported, but the decoders below cannot be used.
try:
    from wav2letter.common import create_word_dict, load_words
    from wav2letter.criterion import CpuViterbiPath, get_data_ptr_as_bytes
    from wav2letter.decoder import (
        CriterionType,
        DecoderOptions,
        KenLM,
        LM,
        LMState,
        SmearingMode,
        Trie,
        LexiconDecoder,
        LexiconFreeDecoder,
    )
except ImportError:
    # Only a failed import is tolerated; a bare ``except`` would also swallow
    # KeyboardInterrupt and SystemExit raised while importing.
    warnings.warn(
        "wav2letter python bindings are required to use this functionality. Please install from https://github.com/facebookresearch/wav2letter/wiki/Python-bindings"
    )
    # Stand-ins so that classes deriving from LM / annotated with LMState can
    # still be defined.
    LM = object
    LMState = object
class W2lDecoder(object):
    """Shared base for the wav2letter decoders.

    Handles criterion-specific setup (CTC or ASG), runs the model to obtain
    emissions and post-processes decoded index sequences. Subclasses provide
    ``decode(emissions)``.

    Raises:
        RuntimeError: if ``args.criterion`` is neither ``"ctc"`` nor
            ``"asg_loss"``.
    """

    def __init__(self, args, tgt_dict):
        self.tgt_dict = tgt_dict
        self.vocab_size = len(tgt_dict)
        self.nbest = args.nbest

        # criterion-specific init
        criterion = args.criterion
        if criterion == "ctc":
            self.criterion_type = CriterionType.CTC
            # Fall back to bos when the dictionary has no explicit blank.
            if "<ctc_blank>" in tgt_dict.indices:
                self.blank = tgt_dict.index("<ctc_blank>")
            else:
                self.blank = tgt_dict.bos()
            self.asg_transitions = None
        elif criterion == "asg_loss":
            self.criterion_type = CriterionType.ASG
            self.blank = -1
            self.asg_transitions = args.asg_transitions
            self.max_replabel = args.max_replabel
            assert len(self.asg_transitions) == self.vocab_size ** 2
        else:
            raise RuntimeError(f"unknown criterion: {args.criterion}")

    def generate(self, models, sample, **unused):
        """Generate a batch of inferences."""
        # model.forward normally channels prev_output_tokens into the decoder
        # separately, but SequenceGenerator directly calls model.encoder
        encoder_input = dict(sample["net_input"])
        encoder_input.pop("prev_output_tokens", None)
        return self.decode(self.get_emissions(models, encoder_input))

    def get_emissions(self, models, encoder_input):
        """Run encoder and normalize emissions"""
        model = models[0]  # only the first model is used
        encoder_out = model(**encoder_input)
        if self.criterion_type == CriterionType.CTC:
            emissions = model.get_normalized_probs(encoder_out, log_probs=True)
        elif self.criterion_type == CriterionType.ASG:
            emissions = encoder_out["encoder_out"]
        # Swap the first two dims and hand back a contiguous float CPU tensor.
        return emissions.transpose(0, 1).float().cpu().contiguous()

    def get_tokens(self, idxs):
        """Normalize tokens by handling CTC blank, ASG replabels, etc."""
        # Collapse runs of identical consecutive indices first.
        collapsed = (value for value, _ in it.groupby(idxs))
        if self.criterion_type == CriterionType.CTC:
            tokens = [idx for idx in collapsed if idx != self.blank]
        elif self.criterion_type == CriterionType.ASG:
            non_negative = [idx for idx in collapsed if idx >= 0]
            tokens = unpack_replabels(non_negative, self.tgt_dict, self.max_replabel)
        else:
            tokens = collapsed
        return torch.LongTensor(list(tokens))
class W2lViterbiDecoder(W2lDecoder):
    """Decoder that returns the single Viterbi path for each utterance."""

    def __init__(self, args, tgt_dict):
        super().__init__(args, tgt_dict)

    def decode(self, emissions):
        """Return one hypothesis (with score 0) per batch element.

        ``emissions`` is unpacked as (B, T, N). Without ASG transitions an
        all-zero N x N transition matrix is used.
        """
        B, T, N = emissions.size()
        transitions = (
            torch.FloatTensor(N, N).zero_()
            if self.asg_transitions is None
            else torch.FloatTensor(self.asg_transitions).view(N, N)
        )
        # Output and scratch buffers filled by the native routine.
        viterbi_path = torch.IntTensor(B, T)
        workspace = torch.ByteTensor(CpuViterbiPath.get_workspace_size(B, T, N))
        CpuViterbiPath.compute(
            B,
            T,
            N,
            get_data_ptr_as_bytes(emissions),
            get_data_ptr_as_bytes(transitions),
            get_data_ptr_as_bytes(viterbi_path),
            get_data_ptr_as_bytes(workspace),
        )

        hypotheses = []
        for b in range(B):
            path_tokens = self.get_tokens(viterbi_path[b].tolist())
            hypotheses.append([{"tokens": path_tokens, "score": 0}])
        return hypotheses
class W2lKenLMDecoder(W2lDecoder):
    """Lexicon-constrained beam-search decoder scored by a KenLM model.

    Reads ``args.lexicon`` and ``args.kenlm_model`` and builds a trie of all
    lexicon spellings, each scored by the language model from its start
    state.
    """

    def __init__(self, args, tgt_dict):
        super().__init__(args, tgt_dict)

        # Silence uses the CTC blank when present, otherwise bos.
        if "<ctc_blank>" in tgt_dict.indices:
            self.silence = tgt_dict.index("<ctc_blank>")
        else:
            self.silence = tgt_dict.bos()

        self.lexicon = load_words(args.lexicon)
        self.word_dict = create_word_dict(self.lexicon)
        self.unk_word = self.word_dict.get_index("<unk>")

        self.lm = KenLM(args.kenlm_model, self.word_dict)
        self.trie = Trie(self.vocab_size, self.silence)

        start_state = self.lm.start(False)
        for word, spellings in self.lexicon.items():
            word_idx = self.word_dict.get_index(word)
            _, score = self.lm.score(start_state, word_idx)
            for spelling in spellings:
                spelling_idxs = [tgt_dict.index(token) for token in spelling]
                # Every spelling token must be known to the target dictionary.
                assert (
                    tgt_dict.unk() not in spelling_idxs
                ), f"{spelling} {spelling_idxs}"
                self.trie.insert(spelling_idxs, word_idx, score)
        self.trie.smear(SmearingMode.MAX)

        self.decoder_opts = DecoderOptions(
            args.beam,
            int(getattr(args, "beam_size_token", len(tgt_dict))),
            args.beam_threshold,
            args.lm_weight,
            args.word_score,
            args.unk_weight,
            args.sil_weight,
            0,
            False,
            self.criterion_type,
        )

        # An empty transition list is passed when no ASG transitions exist.
        if self.asg_transitions is None:
            self.asg_transitions = []

        self.decoder = LexiconDecoder(
            self.decoder_opts,
            self.trie,
            self.lm,
            self.silence,
            self.blank,
            self.unk_word,
            self.asg_transitions,
            False,
        )

    def decode(self, emissions):
        """Return up to ``nbest`` hypotheses per batch element.

        Each hypothesis is a dict with ``tokens``, ``score`` and ``words``
        (negative word indices are skipped).
        """
        B, T, N = emissions.size()

        def to_hypo(result):
            words = [self.word_dict.get_entry(w) for w in result.words if w >= 0]
            return {
                "tokens": self.get_tokens(result.tokens),
                "score": result.score,
                "words": words,
            }

        hypos = []
        batch_stride = emissions.stride(0)
        for b in range(B):
            # Raw address of row b; the factor 4 is the byte size per element
            # (assumes float32 emissions -- TODO confirm against the caller).
            emissions_ptr = emissions.data_ptr() + 4 * b * batch_stride
            results = self.decoder.decode(emissions_ptr, T, N)
            hypos.append([to_hypo(result) for result in results[: self.nbest]])
        return hypos
# Record stored per LM state by FairseqLM: the token prefix, the optional
# incremental decoding state, and the log-probabilities taken from the model
# output (None until computed, and again after the entry is trimmed).
FairseqLMState = namedtuple("FairseqLMState", ["prefix", "incremental_state", "probs"])
class FairseqLM(LM):
    """Expose a fairseq language model through the wav2letter ``LM`` interface.

    For each ``LMState`` a ``FairseqLMState`` record is kept in
    ``self.states``; ``self.stateq`` tracks which states have computed
    probabilities so they can be dropped once more than ``max_cache`` exist.
    The model is moved to CUDA and put in eval mode on construction.
    """

    def __init__(self, dictionary, model):
        LM.__init__(self)
        self.dictionary = dictionary
        self.model = model
        self.unk = self.dictionary.unk()

        self.save_incremental = False  # this currently does not work properly
        self.max_cache = 20_000

        model.cuda()
        model.eval()
        model.make_generation_fast_()

        self.states = {}
        self.stateq = deque()

    def start(self, start_with_nothing):
        """Create the initial state, whose prefix is a single eos token.

        ``start_with_nothing`` is accepted but not used.
        """
        state = LMState()
        prefix = torch.LongTensor([[self.dictionary.eos()]])
        incremental_state = {} if self.save_incremental else None

        with torch.no_grad():
            res = self.model(prefix.cuda(), incremental_state=incremental_state)
            probs = self.model.get_normalized_probs(res, log_probs=True, sample=None)

        if incremental_state is not None:
            incremental_state = apply_to_sample(lambda x: x.cpu(), incremental_state)

        self.states[state] = FairseqLMState(
            prefix.numpy(), incremental_state, probs[0, -1].cpu().numpy()
        )
        self.stateq.append(state)
        return state

    def score(self, state: LMState, token_index: int, no_cache: bool = False):
        """
        Evaluate language model based on the current lm state and new word
        Parameters:
        -----------
        state: current lm state
        token_index: index of the word
                     (can be lexicon index then you should store inside LM the
                      mapping between indices of lexicon and lm, or lm index of a word)
        no_cache: when true, neither the recomputed probabilities nor the
                  child state are stored
        Returns:
        --------
        (LMState, float): pair of (new state, score for the current word)
        """
        cached = self.states[state]

        if cached.probs is None:
            # Probabilities are missing (never computed or trimmed): run the
            # model on the stored prefix again.
            incremental = cached.incremental_state
            if incremental is not None:
                incremental = incremental.copy()

            with torch.no_grad():
                if incremental is not None:
                    incremental = apply_to_sample(lambda x: x.cuda(), incremental)
                elif self.save_incremental:
                    incremental = {}

                res = self.model(
                    torch.from_numpy(cached.prefix).cuda(),
                    incremental_state=incremental,
                )
                probs = self.model.get_normalized_probs(
                    res, log_probs=True, sample=None
                )

                if incremental is not None:
                    incremental = apply_to_sample(lambda x: x.cpu(), incremental)

                cached = FairseqLMState(
                    cached.prefix, incremental, probs[0, -1].cpu().numpy()
                )

            if not no_cache:
                self.states[state] = cached
                self.stateq.append(state)

        # Read the score before trimming, which may strip this very state.
        score = cached.probs[token_index].item()

        # Keep at most max_cache states with probabilities; older ones keep
        # only their prefix so they can be recomputed later.
        while len(self.stateq) > self.max_cache:
            evicted = self.stateq.popleft()
            self.states[evicted] = FairseqLMState(
                self.states[evicted].prefix, None, None
            )

        outstate = state.child(token_index)
        if outstate not in self.states and not no_cache:
            child_prefix = np.concatenate(
                [cached.prefix, torch.LongTensor([[token_index]])], -1
            )
            self.states[outstate] = FairseqLMState(
                child_prefix, cached.incremental_state, None
            )

        # The unknown token is never allowed.
        if token_index == self.unk:
            score = float("-inf")

        return outstate, score

    def finish(self, state: LMState):
        """
        Evaluate eos for language model based on the current lm state
        Returns:
        --------
        (LMState, float): pair of (new state, score for the current word)
        """
        eos_index = self.dictionary.eos()
        return self.score(state, eos_index)

    def empty_cache(self):
        """Drop all stored states and trigger a garbage collection."""
        self.states = {}
        self.stateq = deque()
        gc.collect()
class W2lFairseqLMDecoder(W2lDecoder):
    """Beam-search decoder scored by a fairseq language model.

    The checkpoint at ``args.kenlm_model`` is loaded with ``torch.load`` and
    its stored ``args`` are used to set up the LM task and model. With a
    lexicon a ``LexiconDecoder`` is built, otherwise a ``LexiconFreeDecoder``.
    """

    def __init__(self, args, tgt_dict):
        super().__init__(args, tgt_dict)

        self.silence = tgt_dict.bos()
        self.unit_lm = getattr(args, "unit_lm", False)
        self.lexicon = load_words(args.lexicon) if args.lexicon else None
        self.idx_to_wrd = {}

        # Rebuild the LM from its checkpoint; its data dir is taken to be the
        # directory that holds the checkpoint file.
        checkpoint = torch.load(args.kenlm_model, map_location="cpu")
        lm_args = checkpoint["args"]
        lm_args.data = osp.dirname(args.kenlm_model)
        print(lm_args)
        task = tasks.setup_task(lm_args)
        model = task.build_model(lm_args)
        model.load_state_dict(checkpoint["model"], strict=False)

        self.trie = Trie(self.vocab_size, self.silence)

        self.word_dict = task.dictionary
        self.unk_word = self.word_dict.unk()
        self.lm = FairseqLM(self.word_dict, model)

        self.decoder_opts = DecoderOptions(
            args.beam,
            int(getattr(args, "beam_size_token", len(tgt_dict))),
            args.beam_threshold,
            args.lm_weight,
            args.word_score,
            args.unk_weight,
            args.sil_weight,
            0,
            False,
            self.criterion_type,
        )

        if not self.lexicon:
            self.decoder = LexiconFreeDecoder(
                self.decoder_opts, self.lm, self.silence, self.blank, []
            )
            return

        start_state = self.lm.start(False)
        for position, (word, spellings) in enumerate(self.lexicon.items()):
            if self.unit_lm:
                # Unit LM: words are identified by their lexicon position and
                # carry no LM score in the trie.
                word_idx = position
                self.idx_to_wrd[position] = word
                score = 0
            else:
                word_idx = self.word_dict.index(word)
                _, score = self.lm.score(start_state, word_idx, no_cache=True)

            for spelling in spellings:
                spelling_idxs = [tgt_dict.index(token) for token in spelling]
                assert (
                    tgt_dict.unk() not in spelling_idxs
                ), f"{spelling} {spelling_idxs}"
                self.trie.insert(spelling_idxs, word_idx, score)
        self.trie.smear(SmearingMode.MAX)

        self.decoder = LexiconDecoder(
            self.decoder_opts,
            self.trie,
            self.lm,
            self.silence,
            self.blank,
            self.unk_word,
            [],
            self.unit_lm,
        )

    def decode(self, emissions):
        """Return up to ``nbest`` hypotheses per batch element.

        Each hypothesis has ``tokens`` and ``score``; ``words`` is added when
        a lexicon is in use. The LM cache is emptied after every utterance.
        """
        B, T, N = emissions.size()

        def idx_to_word(idx):
            return self.idx_to_wrd[idx] if self.unit_lm else self.word_dict[idx]

        def make_hypo(result):
            hypo = {"tokens": self.get_tokens(result.tokens), "score": result.score}
            if self.lexicon:
                hypo["words"] = [idx_to_word(w) for w in result.words if w >= 0]
            return hypo

        hypos = []
        batch_stride = emissions.stride(0)
        for b in range(B):
            # Raw address of row b; the factor 4 is the byte size per element
            # (assumes float32 emissions -- TODO confirm against the caller).
            emissions_ptr = emissions.data_ptr() + 4 * b * batch_stride
            results = self.decoder.decode(emissions_ptr, T, N)
            hypos.append([make_hypo(result) for result in results[: self.nbest]])
            self.lm.empty_cache()
        return hypos
# Speech-to-Text (S2T) Modeling
## Data Preparation
S2T modeling data consists of source speech features, target text and other optional information
(source text, speaker id, etc.). Fairseq S2T uses per-dataset-split TSV manifest files
to store this information. Each data field is represented by a column in the TSV file.
Unlike text token embeddings, speech features (e.g. log mel-filter banks) are usually fixed
during model training and can be pre-computed. The manifest file contains the path to
either the feature file in NumPy format or the WAV/FLAC audio file. For the latter,
features will be extracted on-the-fly by fairseq S2T. Optionally, feature/audio files can be packed
into uncompressed ZIP files (then accessed via byte offset and length) to improve I/O performance.
Fairseq S2T also employs a YAML file for data related configurations: tokenizer type and dictionary path
for the target text, feature transforms such as CMVN (cepstral mean and variance normalization) and SpecAugment,
temperature-based resampling, etc.
## Model Training & Evaluation
Fairseq S2T uses the unified `fairseq-train`/`fairseq-generate` interface for model training and evaluation.
It requires arguments `--task speech_to_text` and `--arch <arch in fairseq.models.speech_to_text.*>`.
## Example 1: Speech Recognition (ASR) on LibriSpeech
#### Data preparation
Download and preprocess LibriSpeech data with
```bash
python examples/speech_to_text/prep_librispeech_data.py \
--output-root ${LS_ROOT} --vocab-type unigram --vocab-size 10000
```
where `LS_ROOT` is the root path for downloaded data as well as generated manifest and feature files.
#### Training
```bash
fairseq-train ${LS_ROOT} --train-subset train --valid-subset dev --save-dir ${SAVE_DIR} --num-workers 4 \
--max-tokens 40000 --task speech_to_text --criterion label_smoothed_cross_entropy --max-update 300000 \
--arch s2t_transformer_s --optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt --warmup-updates 10000 \
--clip-norm 10.0 --seed 1 --update-freq 8
```
where `SAVE_DIR` is the checkpoint root path. Here we use `--arch s2t_transformer_s` (31M parameters) as an example.
You may switch to `s2t_transformer_m` (71M) or `s2t_transformer_l` (268M) for better performance. We set
`--update-freq 8` to simulate 8 GPUs with 1 GPU. You may want to update it accordingly when using more than 1 GPU.
#### Inference & Evaluation
Average the last 10 checkpoints and evaluate on the 4 splits
(`dev-clean`, `dev-other`, `test-clean` and `test-other`):
```bash
CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt
python scripts/average_checkpoints.py \
--inputs ${SAVE_DIR} --num-epoch-checkpoints 10 --output "${SAVE_DIR}/${CHECKPOINT_FILENAME}"
for SUBSET in dev-clean dev-other test-clean test-other; do
fairseq-generate ${LS_ROOT} --gen-subset ${SUBSET} --task speech_to_text \
--path ${SAVE_DIR}/${CHECKPOINT_FILENAME} --max-tokens 50000 --beam 5 --scoring wer
done
```
#### Result
| --arch | Params | dev-clean | dev-other | test-clean | test-other |
|---|---|---|---|---|---|
| s2t_transformer_s | 30M | 4.1 | 9.3 | 4.4 | 9.2 |
| s2t_transformer_sp | 35M | 3.9 | 9.3 | 4.3 | 8.8 |
| s2t_transformer_m | 71M | 3.5 | 8.1 | 3.7 | 8.1 |
| s2t_transformer_mp | 84M | 3.3 | 7.8 | 3.7 | 8.2 |
| s2t_transformer_l | 268M | 3.3 | 7.7 | 3.5 | 7.8 |
| s2t_transformer_lp | 318M | 3.1 | 7.5 | 3.4 | 7.6 |
## Example 2: Speech Translation (ST) on MuST-C
#### Data Preparation
[Download](https://ict.fbk.eu/must-c) and unpack MuST-C data to a path `MUSTC_ROOT`, then preprocess it with
```bash
python examples/speech_to_text/prep_mustc_data.py --data-root ${MUSTC_ROOT} \
--asr-vocab-type unigram --asr-vocab-size 5000 \
--st-vocab-type unigram --st-vocab-size 8000
```
The generated manifest and feature files will be available under `MUSTC_ROOT`.
#### ASR
###### Training
```bash
fairseq-train ${MUSTC_ROOT} --train-subset train_asr --valid-subset dev_asr --save-dir ${ASR_SAVE_DIR} \
--num-workers 4 --max-tokens 40000 --task speech_to_text --criterion label_smoothed_cross_entropy \
--report-accuracy --max-update 100000 --arch s2t_transformer_s --optimizer adam --lr 1e-3 \
--lr-scheduler inverse_sqrt --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8
```
where `ASR_SAVE_DIR` is the checkpoint root path. We set `--update-freq 8` to simulate 8 GPUs with 1 GPU.
You may want to update it accordingly when using more than 1 GPU.
###### Inference & Evaluation
```bash
CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt
python scripts/average_checkpoints.py \
--inputs ${ASR_SAVE_DIR} --num-epoch-checkpoints 10 --output "${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME}"
fairseq-generate ${MUSTC_ROOT} --gen-subset tst-COMMON_asr --task speech_to_text \
--path ${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME} --max-tokens 50000 --beam 5 \
--scoring wer --wer-tokenizer 13a --wer-lowercase --wer-remove-punct
```
###### Result
| --arch | Params | En-De | En-Nl | En-Es | En-Fr | En-It | En-Pt | En-Ro | En-Ru |
|---|---|---|---|---|---|---|---|---|---|
| s2t_transformer_s | 31M | 18.2 | 17.6 | 17.7 | 17.2 | 17.9 | 19.1 | 18.1 | 17.7 |
#### ST
###### Training
```bash
fairseq-train ${MUSTC_ROOT} --train-subset train_st --valid-subset dev_st --save-dir ${ST_SAVE_DIR} \
--num-workers 4 --max-tokens 40000 --task speech_to_text --criterion label_smoothed_cross_entropy \
--report-accuracy --max-update 100000 --arch s2t_transformer_s --optimizer adam --lr 2e-3 \
--lr-scheduler inverse_sqrt --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8 \
--load-pretrained-encoder-from ${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME}
```
where `ST_SAVE_DIR` is the checkpoint root path. The ST encoder is pre-trained by ASR for faster training and better
performance: `--load-pretrained-encoder-from <ASR checkpoint path>`. We set `--update-freq 8` to simulate 8 GPUs with 1 GPU.
You may want to update it accordingly when using more than 1 GPU.
###### Inference & Evaluation
Average the last 10 checkpoints and evaluate on the `tst-COMMON` split:
```bash
CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt
python scripts/average_checkpoints.py \
--inputs ${ST_SAVE_DIR} --num-epoch-checkpoints 10 --output "${ST_SAVE_DIR}/${CHECKPOINT_FILENAME}"
fairseq-generate ${MUSTC_ROOT} --gen-subset tst-COMMON_st --task speech_to_text \
--path ${ST_SAVE_DIR}/${CHECKPOINT_FILENAME} --max-tokens 50000 --beam 5 --scoring sacrebleu
```
###### Result
| --arch | Params | En-De | En-Nl | En-Es | En-Fr | En-It | En-Pt | En-Ro | En-Ru |
|---|---|---|---|---|---|---|---|---|---|
| s2t_transformer_s | 31M | 22.7 | 27.3 | 27.2 | 32.9 | 22.7 | 28.1 | 21.9 | 15.3 |
## Example 3: ST on CoVoST
#### Data Preparation
Download and preprocess CoVoST data with
```bash
# En ASR
python examples/speech_to_text/prep_covost_data.py --data-root ${COVOST_ROOT} \
--vocab-type char --src-lang en
# ST
python examples/speech_to_text/prep_covost_data.py --data-root ${COVOST_ROOT} \
--vocab-type char --src-lang fr --tgt-lang en
```
where `COVOST_ROOT` is the root path for downloaded data as well as generated manifest and feature files.
#### ASR
###### Training
```bash
fairseq-train ${COVOST_ROOT} --train-subset train_asr --valid-subset dev_asr --save-dir ${ASR_SAVE_DIR} \
--num-workers 4 --max-tokens 40000 --task speech_to_text --criterion label_smoothed_cross_entropy \
--report-accuracy --max-update 100000 --arch s2t_transformer_s --optimizer adam --lr 1e-3 \
--lr-scheduler inverse_sqrt --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8
```
where `ASR_SAVE_DIR` is the checkpoint root path. We set `--update-freq 8` to simulate 8 GPUs with 1 GPU.
You may want to update it accordingly when using more than 1 GPU.
###### Inference & Evaluation
```bash
CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt
python scripts/average_checkpoints.py \
--inputs ${ASR_SAVE_DIR} --num-epoch-checkpoints 10 --output "${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME}"
fairseq-generate ${COVOST_ROOT} --gen-subset test_asr_en --task speech_to_text \
--path ${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME} --max-tokens 50000 --beam 5 \
--scoring wer --wer-tokenizer 13a --wer-lowercase --wer-remove-punct
```
###### Result
| --arch | Params | En |
|---|---|---|
| s2t_transformer_s | 31M | 25.6 |
#### ST
###### Training
```bash
fairseq-train ${COVOST_ROOT} --train-subset train_st_fr_en --valid-subset dev_st_fr_en --save-dir ${ST_SAVE_DIR} \
--num-workers 4 --max-tokens 40000 --task speech_to_text --criterion label_smoothed_cross_entropy \
--report-accuracy --max-update 100000 --arch s2t_transformer_s --optimizer adam --lr 2e-3 \
--lr-scheduler inverse_sqrt --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8 \
--load-pretrained-encoder-from ${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME}
```
where `ST_SAVE_DIR` is the checkpoint root path. The ST encoder is pre-trained by En ASR for faster training and better
performance: `--load-pretrained-encoder-from <ASR checkpoint path>`. We set `--update-freq 8` to simulate 8 GPUs with 1 GPU.
You may want to update it accordingly when using more than 1 GPU.
###### Inference & Evaluation
Average the last 10 checkpoints and evaluate on the test split:
```bash
CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt
python scripts/average_checkpoints.py \
--inputs ${ST_SAVE_DIR} --num-epoch-checkpoints 10 --output "${ST_SAVE_DIR}/${CHECKPOINT_FILENAME}"
fairseq-generate ${COVOST_ROOT} --gen-subset test_st_fr_en --task speech_to_text \
--path ${ST_SAVE_DIR}/${CHECKPOINT_FILENAME} --max-tokens 50000 --beam 5 --scoring sacrebleu
```
###### Result
| --arch | Params | Fr-En | De-En | Es-En | Ca-En | En-De | En-Ca | En-Fa | En-Et |
|---|---|---|---|---|---|---|---|---|---|
| s2t_transformer_s | 31M | 26.3 | 17.1 | 23.0 | 18.8 | 16.3 | 21.8 | 13.1 | 13.2 |
## Citation
Please cite as:
```
@inproceedings{wang2020fairseqs2t,
title = {fairseq S2T: Fast Speech-to-Text Modeling with fairseq},
author = {Changhan Wang and Yun Tang and Xutai Ma and Anne Wu and Dmytro Okhonko and Juan Pino},
booktitle = {Proceedings of the 2020 Conference of the Asian Chapter of the Association for Computational Linguistics (AACL): System Demonstrations},
year = {2020},
}
@inproceedings{ott2019fairseq,
title = {fairseq: A Fast, Extensible Toolkit for Sequence Modeling},
author = {Myle Ott and Sergey Edunov and Alexei Baevski and Angela Fan and Sam Gross and Nathan Ng and David Grangier and Michael Auli},
booktitle = {Proceedings of NAACL-HLT 2019: Demonstrations},
year = {2019},
}
```
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import csv
import os
import os.path as op
import zipfile
from functools import reduce
from glob import glob
from multiprocessing import cpu_count
from typing import Any, Dict, List
import numpy as np
import sentencepiece as sp
from fairseq.data.audio.audio_utils import _get_kaldi_fbank, _get_torchaudio_fbank
from fairseq.data.audio.feature_transforms.utterance_cmvn import UtteranceCMVN
from tqdm import tqdm
# Special symbols and the ids pinned for them when training SentencePiece;
# gen_vocab passes these ids to the trainer and asserts the trained model agrees.
UNK_TOKEN, UNK_TOKEN_ID = "<unk>", 3
BOS_TOKEN, BOS_TOKEN_ID = "<s>", 0
EOS_TOKEN, EOS_TOKEN_ID = "</s>", 2
PAD_TOKEN, PAD_TOKEN_ID = "<pad>", 1
def gen_vocab(
    input_path: str,
    output_path_prefix: str,
    model_type="bpe",
    vocab_size=1000,
):
    """Train a SentencePiece model and export its pieces as a fairseq dictionary.

    Args:
        input_path: text file handed to the SentencePiece trainer as ``--input``.
        output_path_prefix: used as the trainer's ``--model_prefix``; the
            dictionary is written to ``output_path_prefix + ".txt"``.
        model_type: SentencePiece ``--model_type`` value.
        vocab_size: SentencePiece ``--vocab_size`` value.

    The dictionary has one ``<piece> 1`` line per piece, in id order, with the
    four special symbols left out.
    """
    # Train SentencePiece Model
    trainer_flags = (
        f"--input={input_path}",
        f"--model_prefix={output_path_prefix}",
        f"--model_type={model_type}",
        f"--vocab_size={vocab_size}",
        "--character_coverage=1.0",
        f"--num_threads={cpu_count()}",
        f"--unk_id={UNK_TOKEN_ID}",
        f"--bos_id={BOS_TOKEN_ID}",
        f"--eos_id={EOS_TOKEN_ID}",
        f"--pad_id={PAD_TOKEN_ID}",
    )
    sp.SentencePieceTrainer.Train(" ".join(trainer_flags))

    # Export fairseq dictionary
    processor = sp.SentencePieceProcessor()
    processor.Load(output_path_prefix + ".model")
    id_to_piece = {}
    for piece_id in range(processor.GetPieceSize()):
        id_to_piece[piece_id] = processor.IdToPiece(piece_id)

    # The trained model must have honoured the requested special-symbol ids.
    expected_specials = (
        (UNK_TOKEN_ID, UNK_TOKEN),
        (PAD_TOKEN_ID, PAD_TOKEN),
        (BOS_TOKEN_ID, BOS_TOKEN),
        (EOS_TOKEN_ID, EOS_TOKEN),
    )
    assert all(
        id_to_piece.get(token_id) == token for token_id, token in expected_specials
    )

    special_symbols = {UNK_TOKEN, BOS_TOKEN, EOS_TOKEN, PAD_TOKEN}
    with open(output_path_prefix + ".txt", "w") as f_out:
        for piece_id in sorted(id_to_piece):
            piece = id_to_piece[piece_id]
            if piece in special_symbols:
                continue
            f_out.write(f"{piece} 1\n")
def extract_fbank_features(
    waveform,
    sample_rate,
    output_path=None,
    n_mel_bins=80,
    apply_utterance_cmvn=True,
    overwrite=False,
):
    """Compute filterbank features for one waveform.

    Args:
        waveform: audio samples; ``.squeeze().numpy()`` is called on the scaled
            value, so presumably a torch tensor - TODO confirm.
        sample_rate: passed through to the fbank back-ends.
        output_path: when given, features are saved there with ``np.save`` and
            nothing is returned.
        n_mel_bins: passed through to the fbank back-ends.
        apply_utterance_cmvn: apply ``UtteranceCMVN`` (means and variances)
            to the features.
        overwrite: recompute even if ``output_path`` already exists.

    Returns:
        The features when ``output_path`` is None, otherwise None.

    Raises:
        ImportError: if neither fbank back-end yields features.
    """
    target_exists = output_path is not None and op.exists(output_path)
    if target_exists and not overwrite:
        return

    # Kaldi compliance: 16-bit signed integers
    scaled = (waveform * (2 ** 15)).squeeze().numpy()

    # Try the Kaldi back-end first, then fall back to torchaudio.
    features = _get_kaldi_fbank(scaled, sample_rate, n_mel_bins)
    if features is None:
        features = _get_torchaudio_fbank(scaled, sample_rate, n_mel_bins)
        if features is None:
            raise ImportError(
                "Please install pyKaldi or torchaudio to enable "
                "online filterbank feature extraction"
            )

    if apply_utterance_cmvn:
        features = UtteranceCMVN(norm_means=True, norm_vars=True)(features)

    if output_path is None:
        return features
    np.save(output_path, features)
def create_zip(data_root, zip_path):
    """Pack the ``*.npy`` files found directly in ``data_root`` into an uncompressed ZIP.

    Args:
        data_root: directory that is globbed for ``*.npy`` files; entries are
            stored under their bare file names.
        zip_path: destination archive. A relative path is interpreted against
            the caller's working directory.
    """
    # Resolve before changing directory, otherwise a relative zip_path would be
    # looked up inside data_root.
    zip_path = os.path.abspath(zip_path)
    cwd = os.path.abspath(os.curdir)
    os.chdir(data_root)
    try:
        with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_STORED) as f:
            for filename in tqdm(glob("*.npy")):
                f.write(filename)
    finally:
        # Always restore the working directory, even if zipping fails.
        os.chdir(cwd)
def is_npy_data(data: bytes) -> bool:
    """Return True if ``data`` begins with the bytes 0x93 (147) and ``N`` (78).

    Inputs shorter than two bytes yield False instead of raising IndexError.
    """
    return data[:2] == b"\x93N"
def get_zip_manifest(zip_root, zip_filename):
    """Map each archive member to a ``"<zip_filename>:<offset>:<size>"`` locator.

    Args:
        zip_root: directory that ``zip_filename`` is joined onto to open the file.
        zip_filename: archive path relative to ``zip_root``; it is embedded
            verbatim in every locator.

    Returns:
        dict keyed by member file name without extension.

    Raises:
        AssertionError: if the bytes at a computed offset do not look like
            ``.npy`` data.
    """
    zip_path = op.join(zip_root, zip_filename)
    with zipfile.ZipFile(zip_path, mode="r") as f:
        info = f.infolist()
    manifest = {}
    # One raw handle serves the whole scan instead of reopening per member.
    with open(zip_path, "rb") as f:
        for i in tqdm(info):
            utt_id = op.splitext(i.filename)[0]
            # Data is taken to start right after a 30-byte fixed local header
            # plus the file name.
            # NOTE(review): this assumes an empty extra field and a file name
            # whose byte length equals its character length - confirm for
            # archives not produced by create_zip.
            offset, file_size = i.header_offset + 30 + len(i.filename), i.file_size
            manifest[utt_id] = f"{zip_filename}:{offset}:{file_size}"
            # Sanity-check that the computed span really holds .npy data.
            f.seek(offset)
            data = f.read(file_size)
            assert len(data) > 1 and is_npy_data(data)
    return manifest
def gen_config_yaml(
    data_root, spm_filename, yaml_filename="config.yaml", specaugment_policy="lb"
):
    """Write the S2T data-config YAML into ``data_root``.

    Args:
        data_root: directory for the YAML file; its absolute form is also
            recorded as the audio root.
        spm_filename: SentencePiece model file name; the vocabulary file name
            is derived by replacing ``.model`` with ``.txt``.
        yaml_filename: name of the YAML file to create.
        specaugment_policy: ``"lb"`` or ``"ld"``.
    """
    assert specaugment_policy in {"lb", "ld"}
    abs_root = op.abspath(data_root)
    writer = S2TDataConfigWriter(op.join(abs_root, yaml_filename))

    writer.set_audio_root(abs_root)
    writer.set_vocab_filename(spm_filename.replace(".model", ".txt"))
    writer.set_input_channels(1)
    writer.set_input_feat_per_channel(80)

    # Pick the SpecAugment preset that matches the requested policy.
    apply_policy = (
        writer.set_specaugment_lb_policy
        if specaugment_policy == "lb"
        else writer.set_specaugment_ld_policy
    )
    apply_policy()

    tokenizer_config = {
        "bpe": "sentencepiece",
        "sentencepiece_model": op.join(abs_root, spm_filename),
    }
    writer.set_bpe_tokenizer(tokenizer_config)
    writer.set_feature_transforms("_train", ["specaugment"])
    writer.flush()
def save_df_to_tsv(dataframe, path):
    """Write ``dataframe`` to ``path`` as a UTF-8, tab-separated file.

    A header row is written, the index is not; fields are never quoted and a
    backslash serves as the escape character.
    """
    tsv_options = {
        "sep": "\t",
        "header": True,
        "index": False,
        "encoding": "utf-8",
        "escapechar": "\\",
        "quoting": csv.QUOTE_NONE,
    }
    dataframe.to_csv(path, **tsv_options)
def filter_manifest_df(
    df, is_train_split=False, extra_filters=None, min_n_frames=5, max_n_frames=3000
):
    """Drop unusable rows from a manifest DataFrame and print per-filter counts.

    Args:
        df: manifest with at least ``audio``, ``n_frames`` and ``tgt_text`` columns.
        is_train_split: additionally drop rows with more than ``max_n_frames`` frames.
        extra_filters: optional mapping of description -> boolean mask of rows
            to drop; merged over the built-in filters.
        min_n_frames: rows with fewer frames are dropped.
        max_n_frames: upper bound used only for train splits.

    Returns:
        ``df`` restricted to the rows that no filter matched.
    """
    filters = {
        "no speech": df["audio"] == "",
        f"short speech (<{min_n_frames} frames)": df["n_frames"] < min_n_frames,
        "empty sentence": df["tgt_text"] == "",
    }
    if is_train_split:
        filters[f"long speech (>{max_n_frames} frames)"] = df["n_frames"] > max_n_frames
    if extra_filters is not None:
        filters.update(extra_filters)

    # A row is invalid as soon as any single filter flags it.
    masks = iter(filters.values())
    invalid = next(masks)
    for mask in masks:
        invalid = invalid | mask
    valid = ~invalid

    per_filter = ", ".join(f"{n}: {f.sum()}" for n, f in filters.items())
    totals = f", total {invalid.sum()} filtered, {valid.sum()} remained."
    print("| " + per_filter + totals)
    return df[valid]
class S2TDataConfigWriter(object):
    """Collect S2T data-config entries in a dict and dump them to a YAML file.

    Setters only update the in-memory ``config`` dict; nothing is written
    until :meth:`flush` is called.
    """

    DEFAULT_VOCAB_FILENAME = "dict.txt"
    DEFAULT_INPUT_FEAT_PER_CHANNEL = 80
    DEFAULT_INPUT_CHANNELS = 1

    def __init__(self, yaml_path):
        """Remember the output path and start with an empty config.

        Raises:
            ImportError: if PyYAML is not installed.
        """
        try:
            import yaml
        except ImportError as err:
            # Fail here with a clear message; merely printing it would be
            # followed by an unrelated unbound-name error on the next line.
            raise ImportError(
                "Please install PyYAML to load YAML files for S2T data config"
            ) from err
        self.yaml = yaml
        self.yaml_path = yaml_path
        self.config = {}

    def flush(self):
        """Dump the current config to ``yaml_path``, replacing the file."""
        with open(self.yaml_path, "w") as f:
            self.yaml.dump(self.config, f)

    def set_audio_root(self, audio_root=""):
        """Store ``audio_root`` under the ``audio_root`` key."""
        self.config["audio_root"] = audio_root

    def set_vocab_filename(self, vocab_filename=DEFAULT_VOCAB_FILENAME):
        """Store ``vocab_filename`` under the ``vocab_filename`` key."""
        self.config["vocab_filename"] = vocab_filename

    def set_specaugment(
        self,
        time_wrap_w: int,
        freq_mask_n: int,
        freq_mask_f: int,
        time_mask_n: int,
        time_mask_t: int,
        time_mask_p: float,
    ):
        """Store the six SpecAugment parameters under the ``specaugment`` key."""
        # NOTE(review): the key is spelled "time_wrap_W" (wrap, not warp) -
        # verify it matches what the consumer of this config reads.
        self.config["specaugment"] = {
            "time_wrap_W": time_wrap_w,
            "freq_mask_N": freq_mask_n,
            "freq_mask_F": freq_mask_f,
            "time_mask_N": time_mask_n,
            "time_mask_T": time_mask_t,
            "time_mask_p": time_mask_p,
        }

    def set_specaugment_lb_policy(self):
        """Apply the "lb" preset: one frequency mask and one time mask."""
        self.set_specaugment(
            time_wrap_w=0,
            freq_mask_n=1,
            freq_mask_f=27,
            time_mask_n=1,
            time_mask_t=100,
            time_mask_p=1.0,
        )

    def set_specaugment_ld_policy(self):
        """Apply the "ld" preset: two frequency masks and two time masks."""
        self.set_specaugment(
            time_wrap_w=0,
            freq_mask_n=2,
            freq_mask_f=27,
            time_mask_n=2,
            time_mask_t=100,
            time_mask_p=1.0,
        )

    def set_input_channels(self, input_channels=DEFAULT_INPUT_CHANNELS):
        """Store ``input_channels`` under the ``input_channels`` key."""
        self.config["input_channels"] = input_channels

    def set_input_feat_per_channel(
        self, input_feat_per_channel=DEFAULT_INPUT_FEAT_PER_CHANNEL
    ):
        """Store ``input_feat_per_channel`` under the key of the same name."""
        self.config["input_feat_per_channel"] = input_feat_per_channel

    def set_bpe_tokenizer(self, bpe_tokenizer: Dict[str, Any]):
        """Store the tokenizer settings dict under the ``bpe_tokenizer`` key."""
        self.config["bpe_tokenizer"] = bpe_tokenizer

    def set_feature_transforms(self, split, transforms: List[str]):
        """Record the transform names for ``split`` under ``transforms``."""
        if "transforms" not in self.config:
            self.config["transforms"] = {}
        self.config["transforms"][split] = transforms
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import csv
import logging
import os
import os.path as op
import shutil
from tempfile import NamedTemporaryFile
from typing import Optional, Tuple
import pandas as pd
import torchaudio
from examples.speech_to_text.data_utils import (
create_zip,
extract_fbank_features,
filter_manifest_df,
gen_config_yaml,
gen_vocab,
get_zip_manifest,
save_df_to_tsv,
)
from torch import Tensor
from torch.utils.data import Dataset
from torchaudio.datasets.utils import download_url, extract_archive
from tqdm import tqdm
# Module-level logger (not referenced elsewhere in this script).
log = logging.getLogger(__name__)
# Columns of the per-split TSV manifests built in process(), in output order.
MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker"]
# NOTE: the constructor also takes a required `split` argument (one of SPLITS)
# that the Args section of the docstring below does not list.
class CoVoST(Dataset):
"""Create a Dataset for CoVoST (https://github.com/facebookresearch/covost).
Args:
root (str): root path to the dataset and generated manifests/features
source_language (str): source (audio) language
target_language (str, optional): target (text) language,
None for no translation (default: None)
version (int, optional): CoVoST version. (default: 2)
download (bool, optional): Whether to download the dataset if it is not
found at root path. (default: ``False``).
"""
# Download location of the Common Voice audio archive for one language.
CV_URL_TEMPLATE = (
"https://voice-prod-bundler-ee1969a6ce8178826482b88"
"e843c335139bd3fb4.s3.amazonaws.com/{ver}/{lang}.tar.gz"
)
# Download location of the CoVoST translation TSV for one language pair.
COVOST_URL_TEMPLATE = (
"https://dl.fbaipublicfiles.com/covost/"
"covost_v2.{src_lang}_{tgt_lang}.tsv.tar.gz"
)
# Only version 2 passes the constructor assert, although the tables below
# also carry entries for version 1.
VERSIONS = {2}
SPLITS = ["train", "dev", "test"]
# Common Voice release directory used in CV_URL_TEMPLATE, per CoVoST version.
CV_VERSION_ID = {1: "cv-corpus-3", 2: "cv-corpus-4-2019-12-10"}
# Source languages accepted when translating into English, per version.
XX_EN_LANGUAGES = {
1: ["fr", "de", "nl", "ru", "es", "it", "tr", "fa", "sv-SE", "mn", "zh-CN"],
2: [
"fr",
"de",
"es",
"ca",
"it",
"ru",
"zh-CN",
"pt",
"fa",
"et",
"mn",
"nl",
"tr",
"ar",
"sv-SE",
"lv",
"sl",
"ta",
"ja",
"id",
"cy",
],
}
# Target languages accepted when translating from English, per version.
EN_XX_LANGUAGES = {
1: [],
2: [
"de",
"tr",
"fa",
"sv-SE",
"mn",
"zh-CN",
"cy",
"ca",
"sl",
"et",
"id",
"ar",
"ta",
"lv",
"ja",
],
}
# Validates the arguments, optionally downloads and extracts both archives
# into <root>/raw, then joins the Common Voice and CoVoST TSVs and keeps the
# rows of the requested split in self.data.
def __init__(
self,
root: str,
split: str,
source_language: str,
target_language: Optional[str] = None,
version: int = 2,
download: bool = False,
) -> None:
assert version in self.VERSIONS and split in self.SPLITS
assert source_language is not None
# No target language means the translation field is returned as None.
self.no_translation = target_language is None
if not self.no_translation:
# One side of the language pair has to be English.
assert "en" in {source_language, target_language}
if source_language == "en":
assert target_language in self.EN_XX_LANGUAGES[version]
else:
assert source_language in self.XX_EN_LANGUAGES[version]
else:
# Hack here so that we can get "split" column from CoVoST TSV.
# Note that we use CoVoST train split for ASR which is an extension
# to Common Voice train split.
target_language = "de" if source_language == "en" else "en"
# All raw downloads live in a "raw" sub-directory of the given root.
self.root = os.path.join(root, "raw")
os.makedirs(self.root, exist_ok=True)
cv_url = self.CV_URL_TEMPLATE.format(
ver=self.CV_VERSION_ID[version], lang=source_language
)
cv_archive = os.path.join(self.root, os.path.basename(cv_url))
# NOTE(review): indentation is lost in this listing; extract_archive
# presumably runs whenever `download` is set, not only after a fresh
# download - confirm against the original file.
if download:
if not os.path.isfile(cv_archive):
download_url(cv_url, self.root, hash_value=None)
extract_archive(cv_archive)
covost_url = self.COVOST_URL_TEMPLATE.format(
src_lang=source_language, tgt_lang=target_language
)
covost_archive = os.path.join(self.root, os.path.basename(covost_url))
# Same download/extract pattern (and same caveat) as for the audio archive.
if download:
if not os.path.isfile(covost_archive):
download_url(covost_url, self.root, hash_value=None)
extract_archive(covost_archive)
# validated.tsv is expected directly under self.root.
cv_tsv = self.load_from_tsv(os.path.join(self.root, "validated.tsv"))
# The CoVoST TSV is named like its archive without the ".tar.gz" suffix.
covost_tsv = self.load_from_tsv(
os.path.join(self.root, os.path.basename(covost_url).replace(".tar.gz", ""))
)
# Inner join on "path": only clips present in both tables are kept.
df = pd.merge(
left=cv_tsv[["path", "sentence", "client_id"]],
right=covost_tsv[["path", "translation", "split"]],
how="inner",
on="path",
)
if split == "train":
# The train split also takes the rows marked "train_covost".
df = df[(df["split"] == split) | (df["split"] == f"{split}_covost")]
else:
df = df[df["split"] == split]
# Materialise the rows as a list of dicts ordered by DataFrame index.
self.data = df.to_dict(orient="index").items()
self.data = [v for k, v in sorted(self.data, key=lambda x: x[0])]
# Read a TSV with a header row, no quoting, backslash as escape character,
# and empty fields kept as empty strings (na_filter=False).
@classmethod
def load_from_tsv(cls, path: str):
return pd.read_csv(
path,
sep="\t",
header=0,
encoding="utf-8",
escapechar="\\",
quoting=csv.QUOTE_NONE,
na_filter=False,
)
# NOTE(review): the return annotation lists seven fields but six values are
# returned, and `translation` may be None.
def __getitem__(
self, n: int
) -> Tuple[Tensor, int, str, str, Optional[str], str, str]:
"""Load the n-th sample from the dataset.
Args:
n (int): The index of the sample to be loaded
Returns:
tuple: ``(waveform, sample_rate, sentence, translation, speaker_id,
sample_id)``
"""
data = self.data[n]
# Audio clips are read from the "clips" directory under self.root.
path = os.path.join(self.root, "clips", data["path"])
waveform, sample_rate = torchaudio.load(path)
sentence = data["sentence"]
translation = None if self.no_translation else data["translation"]
speaker_id = data["client_id"]
# The sample id is the clip file name without its ".mp3" extension.
_id = data["path"].replace(".mp3", "")
return waveform, sample_rate, sentence, translation, speaker_id, _id
# Number of samples in the selected split.
def __len__(self) -> int:
return len(self.data)
def process(args):
    """Prepare CoVoST features, manifests, vocabulary and config for one task.

    Reads ``args.data_root``, ``args.src_lang``, ``args.tgt_lang``,
    ``args.vocab_type`` and ``args.vocab_size``. Everything is written under
    ``<data_root>/<src_lang>``: a ``fbank80.zip`` feature archive, one
    ``<split>_<task>.tsv`` manifest per split, the SentencePiece files and a
    ``config_<task>.yaml``. The intermediate ``fbank80`` directory is removed
    at the end.
    """
    root = op.join(args.data_root, args.src_lang)
    os.makedirs(root, exist_ok=True)
    # Extract features
    feature_root = op.join(root, "fbank80")
    os.makedirs(feature_root, exist_ok=True)
    for split in CoVoST.SPLITS:
        print(f"Fetching split {split}...")
        dataset = CoVoST(root, split, args.src_lang, args.tgt_lang, download=True)
        print("Extracting log mel filter bank features...")
        for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset):
            extract_fbank_features(
                waveform, sample_rate, op.join(feature_root, f"{utt_id}.npy")
            )
    # Pack features into ZIP
    zip_filename = "fbank80.zip"
    zip_path = op.join(root, zip_filename)
    print("ZIPing features...")
    create_zip(feature_root, zip_path)
    print("Fetching ZIP manifest...")
    zip_manifest = get_zip_manifest(args.data_root, f"{args.src_lang}/{zip_filename}")
    # Generate TSV manifest
    print("Generating manifest...")
    train_text = []
    task = f"asr_{args.src_lang}"
    if args.tgt_lang is not None:
        task = f"st_{args.src_lang}_{args.tgt_lang}"
    for split in CoVoST.SPLITS:
        manifest = {c: [] for c in MANIFEST_COLUMNS}
        dataset = CoVoST(root, split, args.src_lang, args.tgt_lang)
        for wav, sr, src_utt, tgt_utt, speaker_id, utt_id in tqdm(dataset):
            manifest["id"].append(utt_id)
            manifest["audio"].append(zip_manifest[utt_id])
            duration_ms = int(wav.size(1) / sr * 1000)
            # Frame count derived from the duration; 25 and 10 are presumably
            # the window length and shift in ms - TODO confirm.
            manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10))
            # ASR targets the transcript, ST targets the translation.
            manifest["tgt_text"].append(src_utt if args.tgt_lang is None else tgt_utt)
            manifest["speaker"].append(speaker_id)
        is_train_split = split.startswith("train")
        if is_train_split:
            # Unfiltered training text feeds the vocabulary below.
            train_text.extend(manifest["tgt_text"])
        df = pd.DataFrame.from_dict(manifest)
        df = filter_manifest_df(df, is_train_split=is_train_split)
        save_df_to_tsv(df, op.join(root, f"{split}_{task}.tsv"))
    # Generate vocab
    vocab_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
    spm_filename_prefix = f"spm_{args.vocab_type}{vocab_size_str}_{task}"
    with NamedTemporaryFile(mode="w") as f:
        for t in train_text:
            f.write(t + "\n")
        # The trainer reads the file by name, so buffered text must reach disk
        # first; without this the tail of the corpus can be missing.
        f.flush()
        gen_vocab(
            f.name, op.join(root, spm_filename_prefix), args.vocab_type, args.vocab_size
        )
    # Generate config YAML
    gen_config_yaml(
        root,
        spm_filename_prefix + ".model",
        yaml_filename=f"config_{task}.yaml",
        specaugment_policy="lb",
    )
    # Clean up
    shutil.rmtree(feature_root)
def main():
    """Parse the command-line options and run :func:`process`."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-root", "-d", required=True, type=str)
    # Not marked required: a required option never falls back to its default,
    # which left default="unigram" without effect. Passing the option
    # explicitly works exactly as before.
    parser.add_argument(
        "--vocab-type",
        default="unigram",
        type=str,
        choices=["bpe", "unigram", "char"],
    )
    parser.add_argument("--vocab-size", default=1000, type=int)
    parser.add_argument("--src-lang", "-s", required=True, type=str)
    parser.add_argument("--tgt-lang", "-t", type=str)
    args = parser.parse_args()
    process(args)


if __name__ == "__main__":
    main()
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import logging
import os
import os.path as op
import shutil
from tempfile import NamedTemporaryFile
import pandas as pd
from examples.speech_to_text.data_utils import (
create_zip,
extract_fbank_features,
gen_config_yaml,
gen_vocab,
get_zip_manifest,
save_df_to_tsv,
)
from torchaudio.datasets import LIBRISPEECH
from tqdm import tqdm
# Module-level logger (not referenced elsewhere in this script).
log = logging.getLogger(__name__)
# Split names; each is passed as `url=` to LIBRISPEECH and names a TSV manifest.
# Splits whose name starts with "train" contribute text to the vocabulary.
SPLITS = [
"train-clean-100",
"train-clean-360",
"train-other-500",
"dev-clean",
"dev-other",
"test-clean",
"test-other",
]
# Columns of the per-split TSV manifests built in process(), in output order.
MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker"]
def process(args):
    """Prepare LibriSpeech features, manifests, vocabulary and config.

    Reads ``args.output_root``, ``args.vocab_type`` and ``args.vocab_size``.
    Everything is written under ``args.output_root``: a ``fbank80.zip``
    feature archive, one ``<split>.tsv`` manifest per split, the SentencePiece
    files and the config YAML. The intermediate ``fbank80`` directory is
    removed at the end.
    """
    os.makedirs(args.output_root, exist_ok=True)
    # Extract features
    feature_root = op.join(args.output_root, "fbank80")
    os.makedirs(feature_root, exist_ok=True)
    for split in SPLITS:
        print(f"Fetching split {split}...")
        dataset = LIBRISPEECH(args.output_root, url=split, download=True)
        print("Extracting log mel filter bank features...")
        for wav, sample_rate, _, spk_id, chapter_id, utt_id in tqdm(dataset):
            sample_id = f"{spk_id}-{chapter_id}-{utt_id}"
            extract_fbank_features(
                wav, sample_rate, op.join(feature_root, f"{sample_id}.npy")
            )
    # Pack features into ZIP
    zip_filename = "fbank80.zip"
    zip_path = op.join(args.output_root, zip_filename)
    print("ZIPing features...")
    create_zip(feature_root, zip_path)
    print("Fetching ZIP manifest...")
    zip_manifest = get_zip_manifest(args.output_root, zip_filename)
    # Generate TSV manifest
    print("Generating manifest...")
    train_text = []
    for split in SPLITS:
        manifest = {c: [] for c in MANIFEST_COLUMNS}
        dataset = LIBRISPEECH(args.output_root, url=split)
        for wav, sample_rate, utt, spk_id, chapter_id, utt_id in tqdm(dataset):
            sample_id = f"{spk_id}-{chapter_id}-{utt_id}"
            manifest["id"].append(sample_id)
            manifest["audio"].append(zip_manifest[sample_id])
            duration_ms = int(wav.size(1) / sample_rate * 1000)
            # Frame count derived from the duration; 25 and 10 are presumably
            # the window length and shift in ms - TODO confirm.
            manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10))
            manifest["tgt_text"].append(utt)
            manifest["speaker"].append(spk_id)
        save_df_to_tsv(
            pd.DataFrame.from_dict(manifest), op.join(args.output_root, f"{split}.tsv")
        )
        if split.startswith("train"):
            # Training transcripts feed the vocabulary below.
            train_text.extend(manifest["tgt_text"])
    # Generate vocab
    vocab_size = "" if args.vocab_type == "char" else str(args.vocab_size)
    spm_filename_prefix = f"spm_{args.vocab_type}{vocab_size}"
    with NamedTemporaryFile(mode="w") as f:
        for t in train_text:
            f.write(t + "\n")
        # The trainer reads the file by name, so buffered text must reach disk
        # first; without this the tail of the corpus can be missing.
        f.flush()
        gen_vocab(
            f.name,
            op.join(args.output_root, spm_filename_prefix),
            args.vocab_type,
            args.vocab_size,
        )
    # Generate config YAML
    gen_config_yaml(
        args.output_root, spm_filename_prefix + ".model", specaugment_policy="ld"
    )
    # Clean up
    shutil.rmtree(feature_root)
def main():
    """Parse the command-line options and run :func:`process`."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--output-root", "-o", required=True, type=str)
    # Not marked required: a required option never falls back to its default,
    # which left default="unigram" without effect. Passing the option
    # explicitly works exactly as before.
    parser.add_argument(
        "--vocab-type",
        default="unigram",
        type=str,
        choices=["bpe", "unigram", "char"],
    )
    parser.add_argument("--vocab-size", default=10000, type=int)
    args = parser.parse_args()
    process(args)


if __name__ == "__main__":
    main()
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import logging
import os
import os.path as op
import shutil
from itertools import groupby
from tempfile import NamedTemporaryFile
from typing import Tuple
import pandas as pd
import torchaudio
from examples.speech_to_text.data_utils import (
create_zip,
extract_fbank_features,
filter_manifest_df,
gen_config_yaml,
gen_vocab,
get_zip_manifest,
save_df_to_tsv,
)
from torch import Tensor
from torch.utils.data import Dataset
from tqdm import tqdm
# Module-level logger.
log = logging.getLogger(__name__)
# Columns of the per-split TSV manifests built in process(), in output order.
MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker"]
# Tasks prepared per language pair: "asr" uses the source utterances as target
# text, "st" uses the target-language utterances.
TASKS = ["asr", "st"]
class MUSTC(Dataset):
    """
    Create a Dataset for MuST-C. Each item is a tuple of the form:
    waveform, sample_rate, source utterance, target utterance, speaker_id,
    utterance_id
    """

    SPLITS = ["train", "dev", "tst-COMMON", "tst-HE"]
    LANGUAGES = ["de", "es", "fr", "it", "nl", "pt", "ro", "ru"]

    def __init__(self, root: str, lang: str, split: str) -> None:
        """Index all segments of one split of the ``en-<lang>`` corpus.

        Args:
            root: directory that contains the ``en-<lang>`` folders.
            lang: target language, one of ``LANGUAGES``.
            split: one of ``SPLITS``.

        Raises:
            AssertionError: on an unknown split/language, a missing data
                directory, or a segment/utterance count mismatch.
            ImportError: if PyYAML is not installed.
        """
        assert split in self.SPLITS and lang in self.LANGUAGES
        _root = op.join(root, f"en-{lang}", "data", split)
        wav_root, txt_root = op.join(_root, "wav"), op.join(_root, "txt")
        assert op.isdir(_root) and op.isdir(wav_root) and op.isdir(txt_root)
        # Load audio segments
        try:
            import yaml
        except ImportError as err:
            # Fail here with a clear message; merely printing it would be
            # followed by an unrelated unbound-name error on yaml.load below.
            raise ImportError(
                "Please install PyYAML to load YAML files for the MuST-C dataset"
            ) from err
        with open(op.join(txt_root, f"{split}.yaml")) as f:
            segments = yaml.load(f, Loader=yaml.BaseLoader)
        # Load source and target utterances
        for _lang in ["en", lang]:
            with open(op.join(txt_root, f"{split}.{_lang}")) as f:
                utterances = [r.strip() for r in f]
            # Text files are line-aligned with the YAML segment list.
            assert len(segments) == len(utterances)
            for i, u in enumerate(utterances):
                segments[i][_lang] = u
        # Gather info
        self.data = []
        # groupby only merges adjacent items, so segments of one wav file are
        # expected to be contiguous in the YAML.
        for wav_filename, _seg_group in groupby(segments, lambda x: x["wav"]):
            wav_path = op.join(wav_root, wav_filename)
            sample_rate = torchaudio.info(wav_path)[0].rate
            # NOTE(review): offsets are converted with float() below, so they
            # are presumably strings here and this sort is lexicographic.
            # Changing the key would renumber the utterance ids - confirm
            # intent before touching it.
            seg_group = sorted(_seg_group, key=lambda x: x["offset"])
            for i, segment in enumerate(seg_group):
                # Offset and duration are converted to sample counts.
                offset = int(float(segment["offset"]) * sample_rate)
                n_frames = int(float(segment["duration"]) * sample_rate)
                _id = f"{op.splitext(wav_filename)[0]}_{i}"
                self.data.append(
                    (
                        wav_path,
                        offset,
                        n_frames,
                        sample_rate,
                        segment["en"],
                        segment[lang],
                        segment["speaker_id"],
                        _id,
                    )
                )

    def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, str, str]:
        """Load the n-th segment's audio and return it with its metadata."""
        wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, utt_id = self.data[n]
        # Only the segment's own sample range is read from the wav file.
        waveform, _ = torchaudio.load(wav_path, offset=offset, num_frames=n_frames)
        return waveform, sr, src_utt, tgt_utt, spk_id, utt_id

    def __len__(self) -> int:
        """Return the number of indexed segments."""
        return len(self.data)
def process(args):
    """Prepare every available MuST-C language pair under ``args.data_root``.

    For each ``en-<lang>`` directory that exists this extracts filter bank
    features for all splits, packs them into ``fbank80.zip``, writes one
    ``<split>_<task>.tsv`` manifest per split and task, trains a vocabulary
    per task on the training text and writes ``config_<task>.yaml``.
    The temporary feature directory is removed at the end.

    Args:
        args: parsed command line namespace providing ``data_root``,
            ``asr_vocab_type``, ``asr_vocab_size``, ``st_vocab_type`` and
            ``st_vocab_size``.
    """
    for lang in MUSTC.LANGUAGES:
        cur_root = op.join(args.data_root, f"en-{lang}")
        if not op.isdir(cur_root):
            print(f"{cur_root} does not exist. Skipped.")
            continue
        # Extract features
        feature_root = op.join(cur_root, "fbank80")
        os.makedirs(feature_root, exist_ok=True)
        for split in MUSTC.SPLITS:
            print(f"Fetching split {split}...")
            dataset = MUSTC(args.data_root, lang, split)
            print("Extracting log mel filter bank features...")
            for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset):
                extract_fbank_features(
                    waveform, sample_rate, op.join(feature_root, f"{utt_id}.npy")
                )
        # Pack features into ZIP
        zip_filename = "fbank80.zip"
        zip_path = op.join(cur_root, zip_filename)
        print("ZIPing features...")
        create_zip(feature_root, zip_path)
        print("Fetching ZIP manifest...")
        zip_manifest = get_zip_manifest(args.data_root, f"en-{lang}/{zip_filename}")
        # Generate TSV manifest
        print("Generating manifest...")
        # Training text per task, collected for vocabulary learning below.
        train_text = {task: [] for task in TASKS}
        for split in MUSTC.SPLITS:
            is_train_split = split.startswith("train")
            manifest = {c: [] for c in MANIFEST_COLUMNS}
            text = {task: [] for task in TASKS}
            dataset = MUSTC(args.data_root, lang, split)
            for wav, sr, src_utt, tgt_utt, speaker_id, utt_id in tqdm(dataset):
                manifest["id"].append(utt_id)
                manifest["audio"].append(zip_manifest[utt_id])
                duration_ms = int(wav.size(1) / sr * 1000)
                # NOTE(review): presumably the frame count for a 25 ms window
                # with a 10 ms shift -- confirm against extract_fbank_features.
                manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10))
                text["asr"].append(src_utt)
                text["st"].append(tgt_utt)
                manifest["speaker"].append(speaker_id)
            if is_train_split:
                for task in TASKS:
                    train_text[task].extend(text[task])
            for task in TASKS:
                # Same manifest for both tasks; only the target text differs.
                manifest["tgt_text"] = text[task]
                df = pd.DataFrame.from_dict(manifest)
                df = filter_manifest_df(df, is_train_split=is_train_split)
                save_df_to_tsv(df, op.join(cur_root, f"{split}_{task}.tsv"))
        # Generate vocab
        for task in TASKS:
            vocab_type, vocab_size = args.asr_vocab_type, args.asr_vocab_size
            if task == "st":
                vocab_type, vocab_size = args.st_vocab_type, args.st_vocab_size
            vocab_size_str = "" if vocab_type == "char" else str(vocab_size)
            spm_filename_prefix = f"spm_{vocab_type}{vocab_size_str}_{task}"
            with NamedTemporaryFile(mode="w") as f:
                for t in train_text[task]:
                    f.write(t + "\n")
                # gen_vocab receives only the file *name*, so push buffered
                # lines to disk first; otherwise the tail of the training
                # text may be missing when the file is re-opened.
                f.flush()
                gen_vocab(
                    f.name,
                    op.join(cur_root, spm_filename_prefix),
                    vocab_type,
                    vocab_size,
                )
            # Generate config YAML
            gen_config_yaml(
                cur_root,
                spm_filename_prefix + ".model",
                yaml_filename=f"config_{task}.yaml",
                specaugment_policy="lb",
            )
        # Clean up
        shutil.rmtree(feature_root)
def main():
    """Parse the command line and run the MuST-C preparation.

    Options:
        --data-root/-d: required root directory of the MuST-C data.
        --asr-vocab-type / --st-vocab-type: vocabulary type for each task
            ("bpe", "unigram" or "char"); defaults to "unigram".
        --asr-vocab-size / --st-vocab-size: vocabulary size for each task
            (defaults 5000 and 8000).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-root", "-d", required=True, type=str)
    # The vocab-type options used to combine default="unigram" with
    # required=True, which made the default unreachable. They are optional
    # now, so the declared default applies; passing them explicitly behaves
    # exactly as before.
    parser.add_argument(
        "--asr-vocab-type",
        default="unigram",
        type=str,
        choices=["bpe", "unigram", "char"],
    )
    parser.add_argument(
        "--st-vocab-type",
        default="unigram",
        type=str,
        choices=["bpe", "unigram", "char"],
    )
    parser.add_argument("--asr-vocab-size", default=5000, type=int)
    parser.add_argument("--st-vocab-size", default=8000, type=int)
    args = parser.parse_args()
    process(args)


if __name__ == "__main__":
    main()
# Hierarchical Neural Story Generation (Fan et al., 2018)
The following commands provide an example of pre-processing data, training a model, and generating text for story generation with the WritingPrompts dataset.
## Pre-trained models
Description | Dataset | Model | Test set(s)
---|---|---|---
Stories with Convolutional Model <br> ([Fan et al., 2018](https://arxiv.org/abs/1805.04833)) | [WritingPrompts](https://dl.fbaipublicfiles.com/fairseq/data/writingPrompts.tar.gz) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/stories_checkpoint.tar.bz2) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/stories_test.tar.bz2)
We provide sample stories generated by the [convolutional seq2seq model](https://dl.fbaipublicfiles.com/fairseq/data/seq2seq_stories.txt) and [fusion model](https://dl.fbaipublicfiles.com/fairseq/data/fusion_stories.txt) from [Fan et al., 2018](https://arxiv.org/abs/1805.04833). The corresponding prompts for the fusion model can be found [here](https://dl.fbaipublicfiles.com/fairseq/data/fusion_prompts.txt). Note that the file contains `unk` tokens, as we modeled a small full vocabulary (no BPE or pre-training). We did not use the prompts containing `unk` for human evaluation.
## Dataset
The dataset can be downloaded like this:
```bash
cd examples/stories
curl https://dl.fbaipublicfiles.com/fairseq/data/writingPrompts.tar.gz | tar xvzf -
```
and contains a train, test, and valid split. The dataset is described here: https://arxiv.org/abs/1805.04833. We model only the first 1000 words of each story, including one newLine token.
## Example usage
First we will preprocess the dataset. Note that the dataset release is the full data, but the paper models the first 1000 words of each story. Here is example code that trims the dataset to the first 1000 words of each story:
```python
data = ["train", "test", "valid"]
for name in data:
with open(name + ".wp_target") as f:
stories = f.readlines()
stories = [" ".join(i.split()[0:1000]) for i in stories]
with open(name + ".wp_target", "w") as o:
for line in stories:
o.write(line.strip() + "\n")
```
Once we've trimmed the data we can binarize it and train our model:
```bash
# Binarize the dataset:
export TEXT=examples/stories/writingPrompts
fairseq-preprocess --source-lang wp_source --target-lang wp_target \
--trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
--destdir data-bin/writingPrompts --padding-factor 1 --thresholdtgt 10 --thresholdsrc 10
# Train the model:
fairseq-train data-bin/writingPrompts -a fconv_self_att_wp --lr 0.25 --optimizer nag --clip-norm 0.1 --max-tokens 1500 --lr-scheduler reduce_lr_on_plateau --decoder-attention True --encoder-attention False --criterion label_smoothed_cross_entropy --weight-decay .0000001 --label-smoothing 0 --source-lang wp_source --target-lang wp_target --gated-attention True --self-attention True --project-input True --pretrained False
# Train a fusion model:
# add the arguments: --pretrained True --pretrained-checkpoint path/to/checkpoint
# Generate:
# Note: to load the pretrained model at generation time, you need to pass in a model-override argument to communicate to the fusion model at generation time where you have placed the pretrained checkpoint. By default, it will load the exact path of the fusion model's pretrained model from training time. You should use model-override if you have moved the pretrained model (or are using our provided models). If you are generating from a non-fusion model, the model-override argument is not necessary.
fairseq-generate data-bin/writingPrompts --path /path/to/trained/model/checkpoint_best.pt --batch-size 32 --beam 1 --sampling --sampling-topk 10 --temperature 0.8 --nbest 1 --model-overrides "{'pretrained_checkpoint':'/path/to/pretrained/model/checkpoint'}"
```
## Citation
```bibtex
@inproceedings{fan2018hierarchical,
title = {Hierarchical Neural Story Generation},
author = {Fan, Angela and Lewis, Mike and Dauphin, Yann},
booktitle = {Conference of the Association for Computational Linguistics (ACL)},
year = 2018,
}
```
# Neural Machine Translation
This README contains instructions for [using pretrained translation models](#example-usage-torchhub)
as well as [training new models](#training-a-new-model).
## Pre-trained models
Model | Description | Dataset | Download
---|---|---|---
`conv.wmt14.en-fr` | Convolutional <br> ([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | model: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2) <br> newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.v2.en-fr.newstest2014.tar.bz2) <br> newstest2012/2013: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.v2.en-fr.ntst1213.tar.bz2)
`conv.wmt14.en-de` | Convolutional <br> ([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT14 English-German](http://statmt.org/wmt14/translation-task.html#Download) | model: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-de.fconv-py.tar.bz2) <br> newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.en-de.newstest2014.tar.bz2)
`conv.wmt17.en-de` | Convolutional <br> ([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT17 English-German](http://statmt.org/wmt17/translation-task.html#Download) | model: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt17.v2.en-de.fconv-py.tar.bz2) <br> newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt17.v2.en-de.newstest2014.tar.bz2)
`transformer.wmt14.en-fr` | Transformer <br> ([Ott et al., 2018](https://arxiv.org/abs/1806.00187)) | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | model: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-fr.joined-dict.transformer.tar.bz2) <br> newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.en-fr.joined-dict.newstest2014.tar.bz2)
`transformer.wmt16.en-de` | Transformer <br> ([Ott et al., 2018](https://arxiv.org/abs/1806.00187)) | [WMT16 English-German](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8) | model: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt16.en-de.joined-dict.transformer.tar.bz2) <br> newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt16.en-de.joined-dict.newstest2014.tar.bz2)
`transformer.wmt18.en-de` | Transformer <br> ([Edunov et al., 2018](https://arxiv.org/abs/1808.09381)) <br> WMT'18 winner | [WMT'18 English-German](http://www.statmt.org/wmt18/translation-task.html) | model: <br> [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt18.en-de.ensemble.tar.gz) <br> See NOTE in the archive
`transformer.wmt19.en-de` | Transformer <br> ([Ng et al., 2019](https://arxiv.org/abs/1907.06616)) <br> WMT'19 winner | [WMT'19 English-German](http://www.statmt.org/wmt19/translation-task.html) | model: <br> [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-de.joined-dict.ensemble.tar.gz)
`transformer.wmt19.de-en` | Transformer <br> ([Ng et al., 2019](https://arxiv.org/abs/1907.06616)) <br> WMT'19 winner | [WMT'19 German-English](http://www.statmt.org/wmt19/translation-task.html) | model: <br> [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt19.de-en.joined-dict.ensemble.tar.gz)
`transformer.wmt19.en-ru` | Transformer <br> ([Ng et al., 2019](https://arxiv.org/abs/1907.06616)) <br> WMT'19 winner | [WMT'19 English-Russian](http://www.statmt.org/wmt19/translation-task.html) | model: <br> [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.ensemble.tar.gz)
`transformer.wmt19.ru-en` | Transformer <br> ([Ng et al., 2019](https://arxiv.org/abs/1907.06616)) <br> WMT'19 winner | [WMT'19 Russian-English](http://www.statmt.org/wmt19/translation-task.html) | model: <br> [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt19.ru-en.ensemble.tar.gz)
## Example usage (torch.hub)
We require a few additional Python dependencies for preprocessing:
```bash
pip install fastBPE sacremoses subword_nmt
```
Interactive translation via PyTorch Hub:
```python
import torch
# List available models
torch.hub.list('pytorch/fairseq') # [..., 'transformer.wmt16.en-de', ... ]
# Load a transformer trained on WMT'16 En-De
# Note: WMT'19 models use fastBPE instead of subword_nmt, see instructions below
en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt16.en-de',
tokenizer='moses', bpe='subword_nmt')
en2de.eval() # disable dropout
# The underlying model is available under the *models* attribute
assert isinstance(en2de.models[0], fairseq.models.transformer.TransformerModel)
# Move model to GPU for faster translation
en2de.cuda()
# Translate a sentence
en2de.translate('Hello world!')
# 'Hallo Welt!'
# Batched translation
en2de.translate(['Hello world!', 'The cat sat on the mat.'])
# ['Hallo Welt!', 'Die Katze saß auf der Matte.']
```
Loading custom models:
```python
from fairseq.models.transformer import TransformerModel
zh2en = TransformerModel.from_pretrained(
'/path/to/checkpoints',
checkpoint_file='checkpoint_best.pt',
data_name_or_path='data-bin/wmt17_zh_en_full',
bpe='subword_nmt',
bpe_codes='data-bin/wmt17_zh_en_full/zh.code'
)
zh2en.translate('你好 世界')
# 'Hello World'
```
If you are using one of the `transformer.wmt19` models, you will need to set the `bpe`
argument to `'fastbpe'` and (optionally) load the 4-model ensemble:
```python
en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de',
checkpoint_file='model1.pt:model2.pt:model3.pt:model4.pt',
tokenizer='moses', bpe='fastbpe')
en2de.eval() # disable dropout
```
## Example usage (CLI tools)
Generation with the binarized test sets can be run in batch mode as follows, e.g. for WMT 2014 English-French on a GTX-1080ti:
```bash
mkdir -p data-bin
curl https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2 | tar xvjf - -C data-bin
curl https://dl.fbaipublicfiles.com/fairseq/data/wmt14.v2.en-fr.newstest2014.tar.bz2 | tar xvjf - -C data-bin
fairseq-generate data-bin/wmt14.en-fr.newstest2014 \
--path data-bin/wmt14.en-fr.fconv-py/model.pt \
--beam 5 --batch-size 128 --remove-bpe | tee /tmp/gen.out
# ...
# | Translated 3003 sentences (96311 tokens) in 166.0s (580.04 tokens/s)
# | Generate test with beam=5: BLEU4 = 40.83, 67.5/46.9/34.4/25.5 (BP=1.000, ratio=1.006, syslen=83262, reflen=82787)
# Compute BLEU score
grep ^H /tmp/gen.out | cut -f3- > /tmp/gen.out.sys
grep ^T /tmp/gen.out | cut -f2- > /tmp/gen.out.ref
fairseq-score --sys /tmp/gen.out.sys --ref /tmp/gen.out.ref
# BLEU4 = 40.83, 67.5/46.9/34.4/25.5 (BP=1.000, ratio=1.006, syslen=83262, reflen=82787)
```
## Training a new model
### IWSLT'14 German to English (Transformer)
The following instructions can be used to train a Transformer model on the [IWSLT'14 German to English dataset](http://workshop2014.iwslt.org/downloads/proceeding.pdf).
First download and preprocess the data:
```bash
# Download and prepare the data
cd examples/translation/
bash prepare-iwslt14.sh
cd ../..
# Preprocess/binarize the data
TEXT=examples/translation/iwslt14.tokenized.de-en
fairseq-preprocess --source-lang de --target-lang en \
--trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
--destdir data-bin/iwslt14.tokenized.de-en \
--workers 20
```
Next we'll train a Transformer translation model over this data:
```bash
CUDA_VISIBLE_DEVICES=0 fairseq-train \
data-bin/iwslt14.tokenized.de-en \
--arch transformer_iwslt_de_en --share-decoder-input-output-embed \
--optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
--lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
--dropout 0.3 --weight-decay 0.0001 \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--max-tokens 4096 \
--eval-bleu \
--eval-bleu-args '{"beam": 5, "max_len_a": 1.2, "max_len_b": 10}' \
--eval-bleu-detok moses \
--eval-bleu-remove-bpe \
--eval-bleu-print-samples \
--best-checkpoint-metric bleu --maximize-best-checkpoint-metric
```
Finally we can evaluate our trained model:
```bash
fairseq-generate data-bin/iwslt14.tokenized.de-en \
--path checkpoints/checkpoint_best.pt \
--batch-size 128 --beam 5 --remove-bpe
```
### WMT'14 English to German (Convolutional)
The following instructions can be used to train a Convolutional translation model on the WMT English to German dataset.
See the [Scaling NMT README](../scaling_nmt/README.md) for instructions to train a Transformer translation model on this data.
The WMT English to German dataset can be preprocessed using the `prepare-wmt14en2de.sh` script.
By default it will produce a dataset that was modeled after [Attention Is All You Need (Vaswani et al., 2017)](https://arxiv.org/abs/1706.03762), but with additional news-commentary-v12 data from WMT'17.
To use only data available in WMT'14 or to replicate results obtained in the original [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](https://arxiv.org/abs/1705.03122) paper, please use the `--icml17` option.
```bash
# Download and prepare the data
cd examples/translation/
# WMT'17 data:
bash prepare-wmt14en2de.sh
# or to use WMT'14 data:
# bash prepare-wmt14en2de.sh --icml17
cd ../..
# Binarize the dataset
TEXT=examples/translation/wmt17_en_de
fairseq-preprocess \
--source-lang en --target-lang de \
--trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
--destdir data-bin/wmt17_en_de --thresholdtgt 0 --thresholdsrc 0 \
--workers 20
# Train the model
mkdir -p checkpoints/fconv_wmt_en_de
fairseq-train \
data-bin/wmt17_en_de \
--arch fconv_wmt_en_de \
--dropout 0.2 \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--optimizer nag --clip-norm 0.1 \
--lr 0.5 --lr-scheduler fixed --force-anneal 50 \
--max-tokens 4000 \
--save-dir checkpoints/fconv_wmt_en_de
# Evaluate
fairseq-generate data-bin/wmt17_en_de \
--path checkpoints/fconv_wmt_en_de/checkpoint_best.pt \
--beam 5 --remove-bpe
```
### WMT'14 English to French
```bash
# Download and prepare the data
cd examples/translation/
bash prepare-wmt14en2fr.sh
cd ../..
# Binarize the dataset
TEXT=examples/translation/wmt14_en_fr
fairseq-preprocess \
--source-lang en --target-lang fr \
--trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
--destdir data-bin/wmt14_en_fr --thresholdtgt 0 --thresholdsrc 0 \
--workers 60
# Train the model
mkdir -p checkpoints/fconv_wmt_en_fr
fairseq-train \
data-bin/wmt14_en_fr \
--arch fconv_wmt_en_fr \
--dropout 0.1 \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--optimizer nag --clip-norm 0.1 \
--lr 0.5 --lr-scheduler fixed --force-anneal 50 \
--max-tokens 3000 \
--save-dir checkpoints/fconv_wmt_en_fr
# Evaluate
fairseq-generate \
data-bin/wmt14_en_fr \
--path checkpoints/fconv_wmt_en_fr/checkpoint_best.pt \
--beam 5 --remove-bpe
```
## Multilingual Translation
We also support training multilingual translation models. In this example we'll
train a multilingual `{de,fr}-en` translation model using the IWSLT'17 datasets.
Note that we use slightly different preprocessing here than for the IWSLT'14
En-De data above. In particular we learn a joint BPE code for all three
languages and use fairseq-interactive and sacrebleu for scoring the test set.
```bash
# First install sacrebleu and sentencepiece
pip install sacrebleu sentencepiece
# Then download and preprocess the data
cd examples/translation/
bash prepare-iwslt17-multilingual.sh
cd ../..
# Binarize the de-en dataset
TEXT=examples/translation/iwslt17.de_fr.en.bpe16k
fairseq-preprocess --source-lang de --target-lang en \
--trainpref $TEXT/train.bpe.de-en \
--validpref $TEXT/valid0.bpe.de-en,$TEXT/valid1.bpe.de-en,$TEXT/valid2.bpe.de-en,$TEXT/valid3.bpe.de-en,$TEXT/valid4.bpe.de-en,$TEXT/valid5.bpe.de-en \
--destdir data-bin/iwslt17.de_fr.en.bpe16k \
--workers 10
# Binarize the fr-en dataset
# NOTE: it's important to reuse the en dictionary from the previous step
fairseq-preprocess --source-lang fr --target-lang en \
--trainpref $TEXT/train.bpe.fr-en \
--validpref $TEXT/valid0.bpe.fr-en,$TEXT/valid1.bpe.fr-en,$TEXT/valid2.bpe.fr-en,$TEXT/valid3.bpe.fr-en,$TEXT/valid4.bpe.fr-en,$TEXT/valid5.bpe.fr-en \
--tgtdict data-bin/iwslt17.de_fr.en.bpe16k/dict.en.txt \
--destdir data-bin/iwslt17.de_fr.en.bpe16k \
--workers 10
# Train a multilingual transformer model
# NOTE: the command below assumes 1 GPU, but accumulates gradients from
# 8 fwd/bwd passes to simulate training on 8 GPUs
mkdir -p checkpoints/multilingual_transformer
CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/iwslt17.de_fr.en.bpe16k/ \
--max-epoch 50 \
--ddp-backend=no_c10d \
--task multilingual_translation --lang-pairs de-en,fr-en \
--arch multilingual_transformer_iwslt_de_en \
--share-decoders --share-decoder-input-output-embed \
--optimizer adam --adam-betas '(0.9, 0.98)' \
--lr 0.0005 --lr-scheduler inverse_sqrt --min-lr '1e-09' \
--warmup-updates 4000 --warmup-init-lr '1e-07' \
--label-smoothing 0.1 --criterion label_smoothed_cross_entropy \
--dropout 0.3 --weight-decay 0.0001 \
--save-dir checkpoints/multilingual_transformer \
--max-tokens 4000 \
--update-freq 8
# Generate and score the test set with sacrebleu
SRC=de
sacrebleu --test-set iwslt17 --language-pair ${SRC}-en --echo src \
| python scripts/spm_encode.py --model examples/translation/iwslt17.de_fr.en.bpe16k/sentencepiece.bpe.model \
> iwslt17.test.${SRC}-en.${SRC}.bpe
cat iwslt17.test.${SRC}-en.${SRC}.bpe \
| fairseq-interactive data-bin/iwslt17.de_fr.en.bpe16k/ \
--task multilingual_translation --lang-pairs de-en,fr-en \
--source-lang ${SRC} --target-lang en \
--path checkpoints/multilingual_transformer/checkpoint_best.pt \
--buffer-size 2000 --batch-size 128 \
--beam 5 --remove-bpe=sentencepiece \
> iwslt17.test.${SRC}-en.en.sys
grep ^H iwslt17.test.${SRC}-en.en.sys | cut -f3 \
| sacrebleu --test-set iwslt17 --language-pair ${SRC}-en
```
##### Argument format during inference
During inference it is required to specify a single `--source-lang` and
`--target-lang`, which indicate the inference language direction.
`--lang-pairs`, `--encoder-langtok`, `--decoder-langtok` have to be set to
the same values as in training.
#!/usr/bin/env bash
#
# Adapted from https://github.com/facebookresearch/MIXER/blob/master/prepareData.sh
#
# Prepares the IWSLT'14 German-English data in the current directory:
#   1. clones the Moses and Subword NMT repositories,
#   2. downloads and extracts the corpus into orig/,
#   3. tokenizes, cleans and lowercases the training data,
#   4. tokenizes and lowercases the dev/test XML files,
#   5. splits train/valid/test and applies a jointly learned BPE.
# Final files are written to iwslt14.tokenized.de-en/.
# Exits with status 1 if the Moses scripts or the download are missing.

echo 'Cloning Moses github repository (for tokenization scripts)...'
git clone https://github.com/moses-smt/mosesdecoder.git

echo 'Cloning Subword NMT repository (for BPE pre-processing)...'
git clone https://github.com/rsennrich/subword-nmt.git

SCRIPTS=mosesdecoder/scripts
TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
LC=$SCRIPTS/tokenizer/lowercase.perl
CLEAN=$SCRIPTS/training/clean-corpus-n.perl
BPEROOT=subword-nmt/subword_nmt
BPE_TOKENS=10000

URL="https://wit3.fbk.eu/archive/2014-01/texts/de/en/de-en.tgz"
GZ=de-en.tgz

if [ ! -d "$SCRIPTS" ]; then
    echo "Please set SCRIPTS variable correctly to point to Moses scripts."
    # Signal the failure to the caller (a bare `exit` would return 0 here).
    exit 1
fi

src=de
tgt=en
lang=de-en
prep=iwslt14.tokenized.de-en
tmp=$prep/tmp
orig=orig

mkdir -p $orig $tmp $prep

echo "Downloading data from ${URL}..."
cd $orig
wget "$URL"

if [ -f $GZ ]; then
    echo "Data successfully downloaded."
else
    echo "Data not successfully downloaded."
    # Non-zero status so that calling scripts can detect the failed download.
    exit 1
fi

tar zxvf $GZ
cd ..

echo "pre-processing train data..."
for l in $src $tgt; do
    f=train.tags.$lang.$l
    tok=train.tags.$lang.tok.$l

    # Drop metadata lines, strip <title>/<description> tags, then tokenize.
    cat $orig/$lang/$f | \
    grep -v '<url>' | \
    grep -v '<talkid>' | \
    grep -v '<keywords>' | \
    sed -e 's/<title>//g' | \
    sed -e 's/<\/title>//g' | \
    sed -e 's/<description>//g' | \
    sed -e 's/<\/description>//g' | \
    perl $TOKENIZER -threads 8 -l $l > $tmp/$tok
    echo ""
done
perl $CLEAN -ratio 1.5 $tmp/train.tags.$lang.tok $src $tgt $tmp/train.tags.$lang.clean 1 175
for l in $src $tgt; do
    perl $LC < $tmp/train.tags.$lang.clean.$l > $tmp/train.tags.$lang.$l
done

echo "pre-processing valid/test data..."
for l in $src $tgt; do
    for o in `ls $orig/$lang/IWSLT14.TED*.$l.xml`; do
    fname=${o##*/}
    f=$tmp/${fname%.*}
    echo $o $f
    # Keep only <seg> lines, strip the tags, normalise the apostrophe,
    # then tokenize and lowercase.
    grep '<seg id' $o | \
        sed -e 's/<seg id="[0-9]*">\s*//g' | \
        sed -e 's/\s*<\/seg>\s*//g' | \
        sed -e "s/\’/\'/g" | \
    perl $TOKENIZER -threads 8 -l $l | \
    perl $LC > $f
    echo ""
    done
done

echo "creating train, valid, test..."
for l in $src $tgt; do
    # Every 23rd training line goes to valid, the rest stays in train.
    awk '{if (NR%23 == 0) print $0; }' $tmp/train.tags.de-en.$l > $tmp/valid.$l
    awk '{if (NR%23 != 0) print $0; }' $tmp/train.tags.de-en.$l > $tmp/train.$l

    cat $tmp/IWSLT14.TED.dev2010.de-en.$l \
        $tmp/IWSLT14.TEDX.dev2012.de-en.$l \
        $tmp/IWSLT14.TED.tst2010.de-en.$l \
        $tmp/IWSLT14.TED.tst2011.de-en.$l \
        $tmp/IWSLT14.TED.tst2012.de-en.$l \
        > $tmp/test.$l
done

TRAIN=$tmp/train.en-de
BPE_CODE=$prep/code
rm -f $TRAIN
# Concatenate both languages so that a single joint BPE is learned.
for l in $src $tgt; do
    cat $tmp/train.$l >> $TRAIN
done

echo "learn_bpe.py on ${TRAIN}..."
python $BPEROOT/learn_bpe.py -s $BPE_TOKENS < $TRAIN > $BPE_CODE

for L in $src $tgt; do
    for f in train.$L valid.$L test.$L; do
        echo "apply_bpe.py to ${f}..."
        python $BPEROOT/apply_bpe.py -c $BPE_CODE < $tmp/$f > $prep/$f
    done
done
#!/bin/bash
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# Prepares the IWSLT'17 {de,fr}-en data for multilingual training:
# downloads and extracts the archives, strips markup from the train and
# valid files, learns a joint sentencepiece BPE model over all training
# files and encodes train/valid with it.
# Output is written to iwslt17.de_fr.en.bpe16k/ next to this script.

SRCS=(
    "de"
    "fr"
)
TGT=en

ROOT=$(dirname "$0")
SCRIPTS=$ROOT/../../scripts
SPM_TRAIN=$SCRIPTS/spm_train.py
SPM_ENCODE=$SCRIPTS/spm_encode.py

BPESIZE=16384
ORIG=$ROOT/iwslt17_orig
DATA=$ROOT/iwslt17.de_fr.en.bpe16k
mkdir -p "$ORIG" "$DATA"

TRAIN_MINLEN=1  # remove sentences with <1 BPE token
TRAIN_MAXLEN=250  # remove sentences with >250 BPE tokens

URLS=(
    "https://wit3.fbk.eu/archive/2017-01-trnted/texts/de/en/de-en.tgz"
    "https://wit3.fbk.eu/archive/2017-01-trnted/texts/fr/en/fr-en.tgz"
)
ARCHIVES=(
    "de-en.tgz"
    "fr-en.tgz"
)
# One whitespace-separated list of validation sets per entry of SRCS.
VALID_SETS=(
    "IWSLT17.TED.dev2010.de-en IWSLT17.TED.tst2010.de-en IWSLT17.TED.tst2011.de-en IWSLT17.TED.tst2012.de-en IWSLT17.TED.tst2013.de-en IWSLT17.TED.tst2014.de-en IWSLT17.TED.tst2015.de-en"
    "IWSLT17.TED.dev2010.fr-en IWSLT17.TED.tst2010.fr-en IWSLT17.TED.tst2011.fr-en IWSLT17.TED.tst2012.fr-en IWSLT17.TED.tst2013.fr-en IWSLT17.TED.tst2014.fr-en IWSLT17.TED.tst2015.fr-en"
)

# download and extract data
for ((i=0;i<${#URLS[@]};++i)); do
    ARCHIVE=$ORIG/${ARCHIVES[i]}
    if [ -f "$ARCHIVE" ]; then
        echo "$ARCHIVE already exists, skipping download"
    else
        URL=${URLS[i]}
        wget -P "$ORIG" "$URL"
        if [ -f "$ARCHIVE" ]; then
            echo "$URL successfully downloaded."
        else
            echo "$URL not successfully downloaded."
            exit 1
        fi
    fi
    # Path the archive extracts to: the archive name without ".tgz".
    # (`${ARCHIVE: -4}` returned the literal ".tgz" suffix instead, so the
    # existence check below could never succeed.)
    FILE=${ARCHIVE%.tgz}
    if [ -e "$FILE" ]; then
        echo "$FILE already exists, skipping extraction"
    else
        tar -C "$ORIG" -xzvf "$ARCHIVE"
    fi
done

echo "pre-processing train data..."
for SRC in "${SRCS[@]}"; do
    for LANG in "${SRC}" "${TGT}"; do
        # Drop metadata lines, strip <title>/<description> tags and trim
        # leading/trailing whitespace.
        cat "$ORIG/${SRC}-${TGT}/train.tags.${SRC}-${TGT}.${LANG}" \
            | grep -v '<url>' \
            | grep -v '<talkid>' \
            | grep -v '<keywords>' \
            | grep -v '<speaker>' \
            | grep -v '<reviewer' \
            | grep -v '<translator' \
            | grep -v '<doc' \
            | grep -v '</doc>' \
            | sed -e 's/<title>//g' \
            | sed -e 's/<\/title>//g' \
            | sed -e 's/<description>//g' \
            | sed -e 's/<\/description>//g' \
            | sed 's/^\s*//g' \
            | sed 's/\s*$//g' \
            > "$DATA/train.${SRC}-${TGT}.${LANG}"
    done
done

echo "pre-processing valid data..."
for ((i=0;i<${#SRCS[@]};++i)); do
    SRC=${SRCS[i]}
    # Intentionally unquoted: split the list into one array element per set.
    VALID_SET=(${VALID_SETS[i]})
    for ((j=0;j<${#VALID_SET[@]};++j)); do
        FILE=${VALID_SET[j]}
        for LANG in "$SRC" "$TGT"; do
            grep '<seg id' "$ORIG/${SRC}-${TGT}/${FILE}.${LANG}.xml" \
                | sed -e 's/<seg id="[0-9]*">\s*//g' \
                | sed -e 's/\s*<\/seg>\s*//g' \
                | sed -e "s/\’/\'/g" \
                > "$DATA/valid${j}.${SRC}-${TGT}.${LANG}"
        done
    done
done

# learn BPE with sentencepiece
TRAIN_FILES=$(for SRC in "${SRCS[@]}"; do echo $DATA/train.${SRC}-${TGT}.${SRC}; echo $DATA/train.${SRC}-${TGT}.${TGT}; done | tr "\n" ",")
echo "learning joint BPE over ${TRAIN_FILES}..."
python "$SPM_TRAIN" \
    --input=$TRAIN_FILES \
    --model_prefix=$DATA/sentencepiece.bpe \
    --vocab_size=$BPESIZE \
    --character_coverage=1.0 \
    --model_type=bpe

# encode train/valid
echo "encoding train with learned BPE..."
for SRC in "${SRCS[@]}"; do
    python "$SPM_ENCODE" \
        --model "$DATA/sentencepiece.bpe.model" \
        --output_format=piece \
        --inputs $DATA/train.${SRC}-${TGT}.${SRC} $DATA/train.${SRC}-${TGT}.${TGT} \
        --outputs $DATA/train.bpe.${SRC}-${TGT}.${SRC} $DATA/train.bpe.${SRC}-${TGT}.${TGT} \
        --min-len $TRAIN_MINLEN --max-len $TRAIN_MAXLEN
done

echo "encoding valid with learned BPE..."
for ((i=0;i<${#SRCS[@]};++i)); do
    SRC=${SRCS[i]}
    VALID_SET=(${VALID_SETS[i]})
    for ((j=0;j<${#VALID_SET[@]};++j)); do
        python "$SPM_ENCODE" \
            --model "$DATA/sentencepiece.bpe.model" \
            --output_format=piece \
            --inputs $DATA/valid${j}.${SRC}-${TGT}.${SRC} $DATA/valid${j}.${SRC}-${TGT}.${TGT} \
            --outputs $DATA/valid${j}.bpe.${SRC}-${TGT}.${SRC} $DATA/valid${j}.bpe.${SRC}-${TGT}.${TGT}
    done
done
#!/bin/bash
# Adapted from https://github.com/facebookresearch/MIXER/blob/master/prepareData.sh
#
# Prepares the WMT English-German data in the current directory:
# clones Moses and Subword NMT, downloads and extracts the corpora into
# orig/, normalizes and tokenizes train/test, splits off a validation set,
# learns and applies a joint BPE and cleans the result.
# Output goes to wmt17_en_de/, or to wmt14_en_de/ when called with --icml17.
# Exits with status 1 if the Moses scripts or a download are missing.

echo 'Cloning Moses github repository (for tokenization scripts)...'
git clone https://github.com/moses-smt/mosesdecoder.git

echo 'Cloning Subword NMT repository (for BPE pre-processing)...'
git clone https://github.com/rsennrich/subword-nmt.git

SCRIPTS=mosesdecoder/scripts
TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
CLEAN=$SCRIPTS/training/clean-corpus-n.perl
NORM_PUNC=$SCRIPTS/tokenizer/normalize-punctuation.perl
REM_NON_PRINT_CHAR=$SCRIPTS/tokenizer/remove-non-printing-char.perl
BPEROOT=subword-nmt/subword_nmt
BPE_TOKENS=40000

URLS=(
    "http://statmt.org/wmt13/training-parallel-europarl-v7.tgz"
    "http://statmt.org/wmt13/training-parallel-commoncrawl.tgz"
    "http://data.statmt.org/wmt17/translation-task/training-parallel-nc-v12.tgz"
    "http://data.statmt.org/wmt17/translation-task/dev.tgz"
    "http://statmt.org/wmt14/test-full.tgz"
)
FILES=(
    "training-parallel-europarl-v7.tgz"
    "training-parallel-commoncrawl.tgz"
    "training-parallel-nc-v12.tgz"
    "dev.tgz"
    "test-full.tgz"
)
CORPORA=(
    "training/europarl-v7.de-en"
    "commoncrawl.de-en"
    "training/news-commentary-v12.de-en"
)

# This will make the dataset compatible to the one used in "Convolutional Sequence to Sequence Learning"
# https://arxiv.org/abs/1705.03122
if [ "$1" == "--icml17" ]; then
    URLS[2]="http://statmt.org/wmt14/training-parallel-nc-v9.tgz"
    FILES[2]="training-parallel-nc-v9.tgz"
    CORPORA[2]="training/news-commentary-v9.de-en"
    OUTDIR=wmt14_en_de
else
    OUTDIR=wmt17_en_de
fi

if [ ! -d "$SCRIPTS" ]; then
    echo "Please set SCRIPTS variable correctly to point to Moses scripts."
    # Signal the failure to the caller (a bare `exit` would return 0 here).
    exit 1
fi

src=en
tgt=de
lang=en-de
prep=$OUTDIR
tmp=$prep/tmp
orig=orig
dev=dev/newstest2013

mkdir -p $orig $tmp $prep

cd $orig

# Archives are extracted only right after a fresh download.
for ((i=0;i<${#URLS[@]};++i)); do
    file=${FILES[i]}
    if [ -f $file ]; then
        echo "$file already exists, skipping download"
    else
        url=${URLS[i]}
        wget "$url"
        if [ -f $file ]; then
            echo "$url successfully downloaded."
        else
            echo "$url not successfully downloaded."
            exit 1
        fi
        if [ ${file: -4} == ".tgz" ]; then
            tar zxvf $file
        elif [ ${file: -4} == ".tar" ]; then
            tar xvf $file
        fi
    fi
done
cd ..

echo "pre-processing train data..."
for l in $src $tgt; do
    # Start from an empty file because the loop below appends;
    # -f keeps the first run (no file yet) from reporting an error.
    rm -f $tmp/train.tags.$lang.tok.$l
    for f in "${CORPORA[@]}"; do
        cat $orig/$f.$l | \
            perl $NORM_PUNC $l | \
            perl $REM_NON_PRINT_CHAR | \
            perl $TOKENIZER -threads 8 -a -l $l >> $tmp/train.tags.$lang.tok.$l
    done
done

echo "pre-processing test data..."
for l in $src $tgt; do
    # The source side uses the "src" SGM file, the target side the "ref" one.
    if [ "$l" == "$src" ]; then
        t="src"
    else
        t="ref"
    fi
    grep '<seg id' $orig/test-full/newstest2014-deen-$t.$l.sgm | \
        sed -e 's/<seg id="[0-9]*">\s*//g' | \
        sed -e 's/\s*<\/seg>\s*//g' | \
        sed -e "s/\’/\'/g" | \
    perl $TOKENIZER -threads 8 -a -l $l > $tmp/test.$l
    echo ""
done

echo "splitting train and valid..."
for l in $src $tgt; do
    # Every 100th training line goes to valid, the rest stays in train.
    awk '{if (NR%100 == 0) print $0; }' $tmp/train.tags.$lang.tok.$l > $tmp/valid.$l
    awk '{if (NR%100 != 0) print $0; }' $tmp/train.tags.$lang.tok.$l > $tmp/train.$l
done

TRAIN=$tmp/train.de-en
BPE_CODE=$prep/code
rm -f $TRAIN
# Concatenate both languages so that a single joint BPE is learned.
for l in $src $tgt; do
    cat $tmp/train.$l >> $TRAIN
done

echo "learn_bpe.py on ${TRAIN}..."
python $BPEROOT/learn_bpe.py -s $BPE_TOKENS < $TRAIN > $BPE_CODE

for L in $src $tgt; do
    for f in train.$L valid.$L test.$L; do
        echo "apply_bpe.py to ${f}..."
        python $BPEROOT/apply_bpe.py -c $BPE_CODE < $tmp/$f > $tmp/bpe.$f
    done
done

perl $CLEAN -ratio 1.5 $tmp/bpe.train $src $tgt $prep/train 1 250
perl $CLEAN -ratio 1.5 $tmp/bpe.valid $src $tgt $prep/valid 1 250

for L in $src $tgt; do
    cp $tmp/bpe.test.$L $prep/test.$L
done
#!/bin/bash
# Adapted from https://github.com/facebookresearch/MIXER/blob/master/prepareData.sh
#
# Downloads the corpora listed in URLS and prepares an English-French
# dataset: tokenization, train/valid split and BPE.

# Fetch the external tools used by the rest of the script.
echo 'Cloning Moses github repository (for tokenization scripts)...'
git clone https://github.com/moses-smt/mosesdecoder.git
echo 'Cloning Subword NMT repository (for BPE pre-processing)...'
git clone https://github.com/rsennrich/subword-nmt.git
# Paths to the Moses helper scripts inside the checkout above.
SCRIPTS=mosesdecoder/scripts
TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
CLEAN=$SCRIPTS/training/clean-corpus-n.perl
NORM_PUNC=$SCRIPTS/tokenizer/normalize-punctuation.perl
REM_NON_PRINT_CHAR=$SCRIPTS/tokenizer/remove-non-printing-char.perl
# Location of learn_bpe.py / apply_bpe.py and the value passed to
# learn_bpe.py through its -s option.
BPEROOT=subword-nmt/subword_nmt
BPE_TOKENS=40000
# URLS and FILES are parallel arrays: FILES[i] is the local file name that
# downloading URLS[i] is expected to produce.
URLS=(
"http://statmt.org/wmt13/training-parallel-europarl-v7.tgz"
"http://statmt.org/wmt13/training-parallel-commoncrawl.tgz"
"http://statmt.org/wmt13/training-parallel-un.tgz"
"http://statmt.org/wmt14/training-parallel-nc-v9.tgz"
"http://statmt.org/wmt10/training-giga-fren.tar"
"http://statmt.org/wmt14/test-full.tgz"
)
FILES=(
"training-parallel-europarl-v7.tgz"
"training-parallel-commoncrawl.tgz"
"training-parallel-un.tgz"
"training-parallel-nc-v9.tgz"
"training-giga-fren.tar"
"test-full.tgz"
)
# Path prefixes of the training corpora; the language code is appended as a
# suffix (".en" / ".fr") when the files are read.
CORPORA=(
"training/europarl-v7.fr-en"
"commoncrawl.fr-en"
"un/undoc.2000.fr-en"
"training/news-commentary-v9.fr-en"
"giga-fren.release2.fixed"
)
# Abort early when the Moses checkout is missing: every later step needs it.
if [ ! -d "$SCRIPTS" ]; then
    echo "Please set SCRIPTS variable correctly to point to Moses scripts."
    # Exit with a non-zero status so callers can detect the failure; a bare
    # `exit` would return the status of the preceding echo, i.e. success.
    exit 1
fi
# Language pair and working directories.
src=en
tgt=fr
lang=en-fr
prep=wmt14_en_fr
tmp=$prep/tmp
orig=orig
mkdir -p $orig $tmp $prep
# Download into $orig; stop if the directory cannot be entered so that files
# are never fetched and unpacked in the wrong place.
cd $orig || exit 1
for ((i=0;i<${#URLS[@]};++i)); do
    file=${FILES[i]}
    if [ -f "$file" ]; then
        echo "$file already exists, skipping download"
    else
        url=${URLS[i]}
        wget "$url"
        if [ -f "$file" ]; then
            echo "$url successfully downloaded."
        else
            echo "$url not successfully downloaded."
            # Exit statuses are limited to 0-255; use a plain failure code
            # instead of -1.
            exit 1
        fi
        # Unpack freshly downloaded archives according to their extension.
        # The operands are quoted so the test stays well-formed even for an
        # unexpectedly short or empty file name.
        if [ "${file: -4}" == ".tgz" ]; then
            tar zxvf "$file"
        elif [ "${file: -4}" == ".tar" ]; then
            tar xvf "$file"
        fi
    fi
done
# The giga-fren tarball contains gzipped files. Only call gunzip when some
# are still compressed, so that re-running the script after a successful
# extraction does not fail on an unmatched glob.
if compgen -G "giga-fren.release2.fixed.*.gz" > /dev/null; then
    gunzip giga-fren.release2.fixed.*.gz
fi
cd ..
echo "pre-processing train data..."
for l in $src $tgt; do
    # Start from an empty per-language output file. -f keeps the first run
    # (when the file does not exist yet) from printing a spurious
    # "No such file or directory" error, matching the `rm -f $TRAIN` used later.
    rm -f $tmp/train.tags.$lang.tok.$l
    for f in "${CORPORA[@]}"; do
        # Normalize punctuation, drop non-printing characters and tokenize,
        # appending every corpus to the same per-language file.
        cat $orig/$f.$l | \
            perl $NORM_PUNC $l | \
            perl $REM_NON_PRINT_CHAR | \
            perl $TOKENIZER -threads 8 -a -l $l >> $tmp/train.tags.$lang.tok.$l
    done
done
echo "pre-processing test data..."
for l in $src $tgt; do
    # The source language is read from the "-src" SGM file, the other
    # language from the "-ref" one.
    t="ref"
    if [ "$l" == "$src" ]; then
        t="src"
    fi
    # Keep only the <seg> lines, strip the surrounding tags, replace the
    # curly apostrophe with a plain one, then tokenize.
    grep '<seg id' $orig/test-full/newstest2014-fren-$t.$l.sgm | \
        sed -e 's/<seg id="[0-9]*">\s*//g' \
            -e 's/\s*<\/seg>\s*//g' \
            -e "s/\’/\'/g" | \
        perl $TOKENIZER -threads 8 -a -l $l > $tmp/test.$l
    echo ""
done
echo "splitting train and valid..."
for l in $src $tgt; do
    # Every 1333rd line becomes validation data; all remaining lines stay in
    # the training set.
    awk 'NR % 1333 == 0' $tmp/train.tags.$lang.tok.$l > $tmp/valid.$l
    awk 'NR % 1333 != 0' $tmp/train.tags.$lang.tok.$l > $tmp/train.$l
done
# Learn one BPE model on the concatenation of both language sides of the
# training data, apply it to every split, then filter and copy the results
# into $prep.
TRAIN=$tmp/train.fr-en
BPE_CODE=$prep/code
# Rebuild the concatenated training text from scratch on every run.
rm -f $TRAIN
for l in $src $tgt; do
cat $tmp/train.$l >> $TRAIN
done
echo "learn_bpe.py on ${TRAIN}..."
python $BPEROOT/learn_bpe.py -s $BPE_TOKENS < $TRAIN > $BPE_CODE
# Apply the learned codes to train/valid/test of both languages.
for L in $src $tgt; do
for f in train.$L valid.$L test.$L; do
echo "apply_bpe.py to ${f}..."
python $BPEROOT/apply_bpe.py -c $BPE_CODE < $tmp/$f > $tmp/bpe.$f
done
done
# Filter train and valid with the $CLEAN script (ratio 1.5; the trailing
# "1 250" are presumably the min/max sentence lengths -- confirm against the
# script's usage text) and write the result to $prep.
perl $CLEAN -ratio 1.5 $tmp/bpe.train $src $tgt $prep/train 1 250
perl $CLEAN -ratio 1.5 $tmp/bpe.valid $src $tgt $prep/valid 1 250
# The test set is copied as-is, without filtering.
for L in $src $tgt; do
cp $tmp/bpe.test.$L $prep/test.$L
done
# Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)
This page includes instructions for reproducing results from the paper [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](https://arxiv.org/abs/1902.07816).
## Download data
First, follow the [instructions to download and preprocess the WMT'17 En-De dataset](../translation#prepare-wmt14en2desh).
Make sure to learn a joint vocabulary by passing the `--joined-dictionary` option to `fairseq-preprocess`.
## Train a model
Then we can train a mixture of experts model using the `translation_moe` task.
Use the `--method` flag to choose the MoE variant; we support hard mixtures with a learned or uniform prior (`--method hMoElp` and `hMoEup`, respectively) and soft mixtures (`--method sMoElp` and `sMoEup`).
The model is trained with online responsibility assignment and shared parameterization.
The following command will train a `hMoElp` model with `3` experts:
```bash
fairseq-train --ddp-backend='no_c10d' \
data-bin/wmt17_en_de \
--max-update 100000 \
--task translation_moe --user-dir examples/translation_moe/translation_moe_src \
--method hMoElp --mean-pool-gating-network \
--num-experts 3 \
--arch transformer_wmt_en_de --share-all-embeddings \
--optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
--lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \
--lr 0.0007 --min-lr 1e-09 \
--dropout 0.1 --weight-decay 0.0 --criterion cross_entropy \
--max-tokens 3584
```
## Translate
Once a model is trained, we can generate translations from different experts using the `--gen-expert` option.
For example, to generate from expert 0:
```bash
fairseq-generate data-bin/wmt17_en_de \
--path checkpoints/checkpoint_best.pt \
--beam 1 --remove-bpe \
--task translation_moe --user-dir examples/translation_moe/translation_moe_src \
--method hMoElp --mean-pool-gating-network \
--num-experts 3 \
--gen-expert 0
```
## Evaluate
First download a tokenized version of the WMT'14 En-De test set with multiple references:
```bash
wget dl.fbaipublicfiles.com/fairseq/data/wmt14-en-de.extra_refs.tok
```
Next apply BPE on the fly and run generation for each expert:
```bash
BPE_CODE=examples/translation/wmt17_en_de/code
for EXPERT in $(seq 0 2); do \
cat wmt14-en-de.extra_refs.tok \
| grep ^S | cut -f 2 \
| fairseq-interactive data-bin/wmt17_en_de \
--path checkpoints/checkpoint_best.pt \
--beam 1 \
--bpe subword_nmt --bpe-codes $BPE_CODE \
--buffer-size 500 --max-tokens 6000 \
--task translation_moe --user-dir examples/translation_moe/translation_moe_src \
--method hMoElp --mean-pool-gating-network \
--num-experts 3 \
--gen-expert $EXPERT ; \
done > wmt14-en-de.extra_refs.tok.gen.3experts
```
Finally use `score.py` to compute pairwise BLEU and average oracle BLEU:
```bash
python examples/translation_moe/score.py --sys wmt14-en-de.extra_refs.tok.gen.3experts --ref wmt14-en-de.extra_refs.tok
# pairwise BLEU: 48.26
# #refs covered: 2.11
# average multi-reference BLEU (leave-one-out): 59.46
```
This matches row 3 from Table 7 in the paper.
## Citation
```bibtex
@article{shen2019mixture,
title = {Mixture Models for Diverse Machine Translation: Tricks of the Trade},
author = {Tianxiao Shen and Myle Ott and Michael Auli and Marc'Aurelio Ranzato},
journal = {International Conference on Machine Learning},
year = 2019,
}
```
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Scoring script for computing pairwise BLEU and multi-ref BLEU over a set of
candidate hypotheses.
See `"Mixture Models for Diverse Machine Translation: Tricks of the Trade"
(Shen et al., 2019) <https://arxiv.org/abs/1902.07816>`_.
"""
import argparse
import random
import sys
from itertools import chain
import numpy as np
from sacrebleu import compute_bleu, corpus_bleu as _corpus_bleu
def main():
    """Command-line entry point: score system outputs and/or references.

    With ``--sys`` the pairwise BLEU between the hypotheses is printed and,
    if ``--output`` is given, a merged human-readable file is written.
    With ``--ref`` the references are loaded and either compared against the
    system hypotheses (when ``--sys`` is also given) or against each other.
    """
    option_specs = (
        (
            "--sys",
            {
                "nargs": "*",
                "default": "",
                "metavar": "FILE",
                "help": "path to system output",
            },
        ),
        (
            "--ref",
            {"default": "", "metavar": "FILE", "help": "path to references"},
        ),
        (
            "--output",
            {
                "default": "",
                "metavar": "FILE",
                "help": "print outputs into a pretty format",
            },
        ),
    )
    parser = argparse.ArgumentParser(sys.argv[0])
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()

    have_sys = bool(args.sys)
    if have_sys:
        src, tgt, hypos, log_probs = load_sys(args.sys)
        print("pairwise BLEU: %.2f" % pairwise(hypos))
        if args.output:
            merge(src, tgt, hypos, log_probs, args.output)

    if not args.ref:
        return
    _ref_src, _ref_tgt, refs = load_ref(args.ref)
    if have_sys:
        multi_ref(refs, hypos)
    else:
        intra_ref(refs)
def dictolist(d):
    """Return the values of ``d`` as a list, ordered by ascending key."""
    return [d[key] for key in sorted(d)]
def load_sys(paths):
    """Read system-output files and collect sources, targets and hypotheses.

    Only lines starting with ``S-`` (source), ``T-`` (target) and ``D-``
    (detokenized system output) are used; the number following the dash is
    the sentence id. Fields are tab-separated: ``S-``/``T-`` lines carry the
    text in the second field, ``D-`` lines carry the score in the second
    field and the hypothesis in the third. Every ``D-`` line seen for an id,
    in any of the files, contributes one more hypothesis for that sentence.

    Args:
        paths: iterable of paths to system-output files.

    Returns:
        A tuple ``(src, tgt, hypos, log_probs)`` of lists, each ordered by
        sentence id. ``hypos`` and ``log_probs`` hold one inner list per
        sentence, in the order the ``D-`` lines were read.

    Raises:
        ValueError: if a sentence id or a score cannot be parsed.
    """
    src, tgt, hypos, log_probs = {}, {}, {}, {}
    for path in paths:
        with open(path) as f:
            for line in f:
                # S: source
                # T: target
                # D: detokenized system output
                if not line.startswith(("S-", "T-", "D-")):
                    continue
                # Strip only the line terminator. A bare rstrip() would also
                # remove the tab in front of an empty sentence, which used to
                # corrupt the id parsing or raise IndexError.
                fields = line.rstrip("\r\n").split("\t")
                tag = fields[0]
                i = int(tag[2:])
                if tag.startswith("D-"):
                    if len(fields) < 2:
                        raise ValueError(
                            "missing score in system output line: %r" % line
                        )
                    score = float(fields[1])
                    text = fields[2].rstrip() if len(fields) > 2 else ""
                    hypos.setdefault(i, []).append(text)
                    log_probs.setdefault(i, []).append(score)
                else:
                    text = fields[1].rstrip() if len(fields) > 1 else ""
                    if tag.startswith("S-"):
                        src[i] = text
                    else:
                        tgt[i] = text
    return dictolist(src), dictolist(tgt), dictolist(hypos), dictolist(log_probs)
def load_ref(path):
    """Read a reference file with ``S-``, ``T-`` and ``R...`` records.

    Each used line is tab-separated and its second field holds the text.
    ``S-`` lines are sources and ``T-`` lines are targets. Lines starting
    with ``R`` are references; all ``R`` lines found between two ``S-``/``T-``
    lines form one group. Lines with any other prefix (for example blank
    lines) are skipped; previously such a line made the parser loop forever.

    Args:
        path: path to the reference file.

    Returns:
        A tuple ``(src, tgt, refs)`` where ``src`` and ``tgt`` are lists of
        strings and ``refs`` is a list of reference groups (lists of strings).

    Raises:
        IndexError: if a used line contains no tab-separated text field.
    """
    src, tgt, refs = [], [], []
    # Reference group currently being filled; None until an "R" line opens one.
    group = None
    with open(path) as f:
        for line in f:
            if line.startswith("S-"):
                src.append(line.split("\t")[1].rstrip())
                group = None
            elif line.startswith("T-"):
                tgt.append(line.split("\t")[1].rstrip())
                group = None
            elif line.startswith("R"):
                if group is None:
                    group = []
                    refs.append(group)
                group.append(line.split("\t")[1].rstrip())
            # Anything else is not a record: ignore it instead of stalling.
    return src, tgt, refs
def merge(src, tgt, hypos, log_probs, path):
    """Write sources, targets and scored hypotheses to ``path`` as text.

    For every sentence the file gets the source line, the target line, an
    empty line, one tab-indented ``score<TAB>hypothesis`` line per
    hypothesis, and finally a dashed separator line.
    """
    separator = "------------------------------------------------------\n"
    with open(path, "w") as out:
        for source, target, candidates, scores in zip(src, tgt, hypos, log_probs):
            out.write(source + "\n")
            out.write(target + "\n")
            out.write("\n")
            out.writelines(
                "\t%f\t%s\n" % (score, candidate.strip())
                for candidate, score in zip(candidates, scores)
            )
            out.write(separator)
def corpus_bleu(sys_stream, ref_streams):
    """Return the sacrebleu corpus BLEU score, with sacrebleu's tokenization disabled."""
    return _corpus_bleu(sys_stream, ref_streams, tokenize="none").score
def sentence_bleu(hypothesis, reference):
    """Return a smoothed BLEU score for one hypothesis against one reference.

    The n-gram statistics come from sacrebleu's ``corpus_bleu``. The match
    counts and totals at indices 1 to 3 are each incremented by one before
    the score is recomputed with ``smooth_method="exp"``.
    """
    stats = _corpus_bleu(hypothesis, reference)
    # Add-one adjustment on every entry except index 0.
    for order in (1, 2, 3):
        stats.counts[order] += 1
        stats.totals[order] += 1
    smoothed = compute_bleu(
        stats.counts,
        stats.totals,
        stats.sys_len,
        stats.ref_len,
        smooth_method="exp",
    )
    return smoothed.score
def pairwise(sents):
    """Return corpus BLEU over all ordered pairs of alternatives per sentence.

    ``sents`` holds one sequence of alternative sentences per item. For each
    item, every ordered pair of distinct positions ``(i, j)`` contributes
    element ``i`` as reference and element ``j`` as hypothesis.
    """
    pairs = [
        (reference, hypothesis)
        for group in sents
        for i, reference in enumerate(group)
        for j, hypothesis in enumerate(group)
        if i != j
    ]
    references = [reference for reference, _ in pairs]
    hypotheses = [hypothesis for _, hypothesis in pairs]
    return corpus_bleu(hypotheses, [references])
def multi_ref(refs, hypos):
    """Print reference coverage and leave-one-out multi-reference BLEU.

    Args:
        refs: one list of reference strings per sentence.
        hypos: one list of hypothesis strings per sentence; must have the
            same length as ``refs``.

    Prints two lines: the average number of distinct references that are the
    best match (by ``sentence_bleu``) of at least one hypothesis, and the
    mean corpus BLEU over all ways of holding out one reference.

    Raises:
        AssertionError: if ``refs`` and ``hypos`` differ in length.
    """
    # NOTE: _ref and _hypo are filled below but never read again in this
    # function.
    _ref, _hypo = [], []
    ref_cnt = 0
    assert len(refs) == len(hypos)

    # count number of refs covered
    for rs, hs in zip(refs, hypos):
        # indices of the references matched by at least one hypothesis
        a = set()
        for h in hs:
            s = [sentence_bleu(h, r) for r in rs]
            j = np.argmax(s)
            _ref.append(rs[j])
            _hypo.append(h)
            # ties for the best score are broken uniformly at random
            best = [k for k in range(len(rs)) if s[k] == s[j]]
            a.add(random.choice(best))
        ref_cnt += len(a)
    print("#refs covered: %.2f" % (ref_cnt / len(refs)))

    # transpose refs and hypos
    # (zip truncates to the shortest inner list, so this presumably expects
    # the same number of references/hypotheses for every sentence -- TODO
    # confirm)
    refs = list(zip(*refs))
    hypos = list(zip(*hypos))

    # compute multi-ref corpus BLEU (leave-one-out to be comparable to intra_ref)
    k = len(hypos)
    m = len(refs)
    # sentence-major order: all k hypotheses of sentence 0, then sentence 1, ...
    flat_hypos = [hypos[j][i] for i in range(len(hypos[0])) for j in range(k)]
    # repeat every reference k times so each stream lines up with flat_hypos
    duplicated_refs = [[ref for ref in refs_i for _ in range(k)] for refs_i in refs]
    loo_bleus = []
    for held_out_ref in range(m):
        remaining_refs = (
            duplicated_refs[:held_out_ref] + duplicated_refs[held_out_ref + 1 :]
        )
        assert len(remaining_refs) == m - 1
        loo_bleus.append(corpus_bleu(flat_hypos, remaining_refs))
    print("average multi-reference BLEU (leave-one-out): %.2f" % np.mean(loo_bleus))
def intra_ref(refs):
    """Print pairwise and leave-one-out BLEU among the references themselves.

    ``refs`` holds one list of reference strings per sentence. Each reference
    position is scored in turn as the hypothesis against the remaining
    positions, and all turns are pooled into a single corpus BLEU.
    """
    print("ref pairwise BLEU: %.2f" % pairwise(refs))
    # One tuple per reference position, each spanning all sentences.
    columns = list(zip(*refs))
    num_refs = len(columns)
    held_out = list(chain.from_iterable(columns))
    others = [[] for _ in range(num_refs - 1)]
    for idx in range(num_refs):
        remaining = columns[:idx] + columns[idx + 1 :]
        for stream, column in zip(others, remaining):
            stream.extend(column)
    bleu = corpus_bleu(held_out, others)
    print("multi-reference BLEU (leave-one-out): %.2f" % bleu)
# Run the command-line interface only when executed as a script, not when
# the module is imported.
if __name__ == "__main__":
    main()
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from . import translation_moe # noqa
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment