Commit c394d7d1 ("init")
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Replabel transforms for use with flashlight's ASG criterion.
"""
def replabel_symbol(i):
"""
Replabel symbols used in flashlight, currently just "1", "2", ...
This prevents training with numeral tokens, so this might change in the future
"""
return str(i)
def pack_replabels(tokens, dictionary, max_reps):
"""
Pack a token sequence so that repeated symbols are replaced by replabels
"""
if len(tokens) == 0 or max_reps <= 0:
return tokens
replabel_value_to_idx = [0] * (max_reps + 1)
for i in range(1, max_reps + 1):
replabel_value_to_idx[i] = dictionary.index(replabel_symbol(i))
result = []
prev_token = -1
num_reps = 0
for token in tokens:
if token == prev_token and num_reps < max_reps:
num_reps += 1
else:
if num_reps > 0:
result.append(replabel_value_to_idx[num_reps])
num_reps = 0
result.append(token)
prev_token = token
if num_reps > 0:
result.append(replabel_value_to_idx[num_reps])
return result
def unpack_replabels(tokens, dictionary, max_reps):
"""
Unpack a token sequence so that replabels are replaced by repeated symbols
"""
if len(tokens) == 0 or max_reps <= 0:
return tokens
replabel_idx_to_value = {}
for i in range(1, max_reps + 1):
replabel_idx_to_value[dictionary.index(replabel_symbol(i))] = i
result = []
prev_token = -1
for token in tokens:
try:
for _ in range(replabel_idx_to_value[token]):
result.append(prev_token)
prev_token = -1
except KeyError:
result.append(token)
prev_token = token
return result
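
# ---------------------------------------------------------------------------
# Illustrative round-trip (not part of the original file). `_StubDictionary`
# is a hypothetical stand-in for fairseq's Dictionary; the only behavior
# assumed is that `index(symbol)` maps a replabel symbol to an integer id.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class _StubDictionary:
        def __init__(self, symbols):
            self.symbols = list(symbols)

        def index(self, symbol):
            return self.symbols.index(symbol)

    d = _StubDictionary(["a", "b", "1", "2"])  # ids: a=0, b=1, "1"=2, "2"=3
    packed = pack_replabels([0, 0, 0, 1], d, max_reps=2)
    assert packed == [0, 3, 1]  # "a a a b" -> "a <2> b"
    assert unpack_replabels(packed, d, max_reps=2) == [0, 0, 0, 1]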
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import absolute_import, division, print_function, unicode_literals
import argparse
import concurrent.futures
import json
import multiprocessing
import os
from collections import namedtuple
from itertools import chain
import sentencepiece as spm
from fairseq.data import Dictionary
MILLISECONDS_TO_SECONDS = 0.001
def process_sample(aud_path, label, utt_id, sp, tgt_dict):
    import torchaudio

    input = {}
    output = {}
    # pre-0.8 torchaudio API: info() returns (signal_info, encoding_info)
    si, ei = torchaudio.info(aud_path)
    # samples / channels / rate = seconds; dividing by 0.001 converts to ms
    input["length_ms"] = int(
        si.length / si.channels / si.rate / MILLISECONDS_TO_SECONDS
    )
    input["path"] = aud_path
    token = " ".join(sp.EncodeAsPieces(label))
    ids = tgt_dict.encode_line(token, append_eos=False)
    output["text"] = label
    output["token"] = token
    output["tokenid"] = ", ".join(map(str, [t.tolist() for t in ids]))
return {utt_id: {"input": input, "output": output}}
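
# The returned entry has the following shape (illustrative values; "length_ms"
# is computed from the torchaudio signal info above):
#
#   {
#       "<utt_id>": {
#           "input": {"length_ms": 12340, "path": "/path/to/<utt_id>.flac"},
#           "output": {"text": "...", "token": "...", "tokenid": "..."},
#       }
#   }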
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--audio-dirs",
nargs="+",
default=["-"],
required=True,
help="input directories with audio files",
)
parser.add_argument(
"--labels",
required=True,
help="aggregated input labels with format <ID LABEL> per line",
type=argparse.FileType("r", encoding="UTF-8"),
)
parser.add_argument(
"--spm-model",
required=True,
help="sentencepiece model to use for encoding",
type=argparse.FileType("r", encoding="UTF-8"),
)
parser.add_argument(
"--dictionary",
required=True,
help="file to load fairseq dictionary from",
type=argparse.FileType("r", encoding="UTF-8"),
)
parser.add_argument("--audio-format", choices=["flac", "wav"], default="wav")
parser.add_argument(
"--output",
required=True,
type=argparse.FileType("w"),
help="path to save json output",
)
args = parser.parse_args()
sp = spm.SentencePieceProcessor()
sp.Load(args.spm_model.name)
tgt_dict = Dictionary.load(args.dictionary)
labels = {}
for line in args.labels:
(utt_id, label) = line.split(" ", 1)
labels[utt_id] = label
    if len(labels) == 0:
        raise Exception("No labels found in ", args.labels.name)
Sample = namedtuple("Sample", "aud_path utt_id")
samples = []
for path, _, files in chain.from_iterable(
os.walk(path) for path in args.audio_dirs
):
for f in files:
if f.endswith(args.audio_format):
if len(os.path.splitext(f)) != 2:
raise Exception("Expect <utt_id.extension> file name. Got: ", f)
utt_id = os.path.splitext(f)[0]
if utt_id not in labels:
continue
samples.append(Sample(os.path.join(path, f), utt_id))
utts = {}
num_cpu = multiprocessing.cpu_count()
with concurrent.futures.ThreadPoolExecutor(max_workers=num_cpu) as executor:
future_to_sample = {
executor.submit(
process_sample, s.aud_path, labels[s.utt_id], s.utt_id, sp, tgt_dict
): s
for s in samples
}
for future in concurrent.futures.as_completed(future_to_sample):
try:
data = future.result()
except Exception as exc:
print("generated an exception: ", exc)
else:
utts.update(data)
json.dump({"utts": utts}, args.output, indent=4)
if __name__ == "__main__":
main()
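
# Example invocation (paths are hypothetical; this mirrors the calls made by
# the LibriSpeech preparation script below):
#
#   python asr_prep_json.py \
#       --audio-dirs /data/LibriSpeech/test-clean \
#       --labels /data/LibriSpeech/test-clean/text \
#       --spm-model spm.model --dictionary dict.txt \
#       --audio-format flac --output test-clean.json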
#!/usr/bin/env bash
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Prepare librispeech dataset
base_url=www.openslr.org/resources/12
train_dir=train_960
if [ "$#" -ne 2 ]; then
echo "Usage: $0 <download_dir> <out_dir>"
echo "e.g.: $0 /tmp/librispeech_raw/ ~/data/librispeech_final"
exit 1
fi
download_dir=${1%/}
out_dir=${2%/}
fairseq_root=~/fairseq-py/
mkdir -p ${out_dir}
cd ${out_dir} || exit
nbpe=5000
bpemode=unigram
if [ ! -d "$fairseq_root" ]; then
echo "$0: Please set correct fairseq_root"
exit 1
fi
echo "Data Download"
for part in dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500; do
url=$base_url/$part.tar.gz
if ! wget -P $download_dir $url; then
echo "$0: wget failed for $url"
exit 1
fi
if ! tar -C $download_dir -xvzf $download_dir/$part.tar.gz; then
echo "$0: error un-tarring archive $download_dir/$part.tar.gz"
exit 1
fi
done
echo "Merge all train packs into one"
mkdir -p ${download_dir}/LibriSpeech/${train_dir}/
for part in train-clean-100 train-clean-360 train-other-500; do
mv ${download_dir}/LibriSpeech/${part}/* $download_dir/LibriSpeech/${train_dir}/
done
echo "Merge train text"
find ${download_dir}/LibriSpeech/${train_dir}/ -name '*.txt' -exec cat {} \; >> ${download_dir}/LibriSpeech/${train_dir}/text
# Use combined dev-clean and dev-other as validation set
find ${download_dir}/LibriSpeech/dev-clean/ ${download_dir}/LibriSpeech/dev-other/ -name '*.txt' -exec cat {} \; >> ${download_dir}/LibriSpeech/valid_text
find ${download_dir}/LibriSpeech/test-clean/ -name '*.txt' -exec cat {} \; >> ${download_dir}/LibriSpeech/test-clean/text
find ${download_dir}/LibriSpeech/test-other/ -name '*.txt' -exec cat {} \; >> ${download_dir}/LibriSpeech/test-other/text
dict=data/lang_char/${train_dir}_${bpemode}${nbpe}_units.txt
encoded=data/lang_char/${train_dir}_${bpemode}${nbpe}_encoded.txt
fairseq_dict=data/lang_char/${train_dir}_${bpemode}${nbpe}_fairseq_dict.txt
bpemodel=data/lang_char/${train_dir}_${bpemode}${nbpe}
echo "dictionary: ${dict}"
echo "Dictionary preparation"
mkdir -p data/lang_char/
echo "<unk> 3" > ${dict}
echo "</s> 2" >> ${dict}
echo "<pad> 1" >> ${dict}
cut -f 2- -d" " ${download_dir}/LibriSpeech/${train_dir}/text > data/lang_char/input.txt
spm_train --input=data/lang_char/input.txt --vocab_size=${nbpe} --model_type=${bpemode} --model_prefix=${bpemodel} --input_sentence_size=100000000 --unk_id=3 --eos_id=2 --pad_id=1 --bos_id=-1 --character_coverage=1
spm_encode --model=${bpemodel}.model --output_format=piece < data/lang_char/input.txt > ${encoded}
cat ${encoded} | tr ' ' '\n' | sort | uniq | awk '{print $0 " " NR+3}' >> ${dict}
cat ${encoded} | tr ' ' '\n' | sort | uniq -c | awk '{print $2 " " $1}' > ${fairseq_dict}
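# Illustrative line formats (tokens/ids/counts are made up): ${dict} holds
# "<piece> <id>" pairs such as "▁the 4", while ${fairseq_dict} holds
# "<piece> <count>" pairs such as "▁the 95671".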
wc -l ${dict}
echo "Prepare train and test jsons"
for part in train_960 test-other test-clean; do
python ${fairseq_root}/examples/speech_recognition/datasets/asr_prep_json.py --audio-dirs ${download_dir}/LibriSpeech/${part} --labels ${download_dir}/LibriSpeech/${part}/text --spm-model ${bpemodel}.model --audio-format flac --dictionary ${fairseq_dict} --output ${part}.json
done
# fairseq expects to find train.json and valid.json during training
mv train_960.json train.json
echo "Prepare valid json"
python ${fairseq_root}/examples/speech_recognition/datasets/asr_prep_json.py --audio-dirs ${download_dir}/LibriSpeech/dev-clean ${download_dir}/LibriSpeech/dev-other --labels ${download_dir}/LibriSpeech/valid_text --spm-model ${bpemodel}.model --audio-format flac --dictionary ${fairseq_dict} --output valid.json
cp ${fairseq_dict} ./dict.txt
cp ${bpemodel}.model ./spm.model
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Run inference for pre-processed data with a trained model.
"""
import ast
import logging
import math
import os
import sys
import editdistance
import numpy as np
import torch
from fairseq import checkpoint_utils, options, progress_bar, tasks, utils
from fairseq.data.data_utils import post_process
from fairseq.logging.meters import StopwatchMeter, TimeMeter
logging.basicConfig(level=logging.INFO)
logging.root.setLevel(logging.INFO)
logger = logging.getLogger(__name__)
def add_asr_eval_argument(parser):
parser.add_argument("--kspmodel", default=None, help="sentence piece model")
    parser.add_argument(
        "--wfstlm", default=None, help="wfstlm on dictionary output units"
    )
    parser.add_argument(
        "--rnnt_decoding_type",
        default="greedy",
        help="rnnt decoding type (e.g. greedy)",
    )
try:
parser.add_argument(
"--lm-weight",
"--lm_weight",
type=float,
default=0.2,
help="weight for lm while interpolating with neural score",
)
    except argparse.ArgumentError:
        # --lm-weight may already be registered by a parent parser
        pass
parser.add_argument(
"--rnnt_len_penalty", default=-0.5, help="rnnt length penalty on word level"
)
parser.add_argument(
"--w2l-decoder",
choices=["viterbi", "kenlm", "fairseqlm"],
help="use a w2l decoder",
)
parser.add_argument("--lexicon", help="lexicon for w2l decoder")
parser.add_argument("--unit-lm", action="store_true", help="if using a unit lm")
parser.add_argument("--kenlm-model", "--lm-model", help="lm model for w2l decoder")
parser.add_argument("--beam-threshold", type=float, default=25.0)
parser.add_argument("--beam-size-token", type=float, default=100)
parser.add_argument("--word-score", type=float, default=1.0)
parser.add_argument("--unk-weight", type=float, default=-math.inf)
parser.add_argument("--sil-weight", type=float, default=0.0)
parser.add_argument(
"--dump-emissions",
type=str,
default=None,
help="if present, dumps emissions into this file and exits",
)
parser.add_argument(
"--dump-features",
type=str,
default=None,
help="if present, dumps features into this file and exits",
)
parser.add_argument(
"--load-emissions",
type=str,
default=None,
help="if present, loads emissions from this file",
)
return parser
def check_args(args):
# assert args.path is not None, "--path required for generation!"
# assert args.results_path is not None, "--results_path required for generation!"
assert (
not args.sampling or args.nbest == args.beam
), "--sampling requires --nbest to be equal to --beam"
assert (
args.replace_unk is None or args.raw_text
), "--replace-unk requires a raw text dataset (--raw-text)"
def get_dataset_itr(args, task, models):
return task.get_batch_iterator(
dataset=task.dataset(args.gen_subset),
max_tokens=args.max_tokens,
max_sentences=args.batch_size,
max_positions=(sys.maxsize, sys.maxsize),
ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
required_batch_size_multiple=args.required_batch_size_multiple,
num_shards=args.num_shards,
shard_id=args.shard_id,
num_workers=args.num_workers,
data_buffer_size=args.data_buffer_size,
).next_epoch_itr(shuffle=False)
def process_predictions(
args, hypos, sp, tgt_dict, target_tokens, res_files, speaker, id
):
for hypo in hypos[: min(len(hypos), args.nbest)]:
hyp_pieces = tgt_dict.string(hypo["tokens"].int().cpu())
if "words" in hypo:
hyp_words = " ".join(hypo["words"])
else:
hyp_words = post_process(hyp_pieces, args.post_process)
if res_files is not None:
print(
"{} ({}-{})".format(hyp_pieces, speaker, id),
file=res_files["hypo.units"],
)
print(
"{} ({}-{})".format(hyp_words, speaker, id),
file=res_files["hypo.words"],
)
tgt_pieces = tgt_dict.string(target_tokens)
tgt_words = post_process(tgt_pieces, args.post_process)
if res_files is not None:
print(
"{} ({}-{})".format(tgt_pieces, speaker, id),
file=res_files["ref.units"],
)
print(
"{} ({}-{})".format(tgt_words, speaker, id), file=res_files["ref.words"]
)
if not args.quiet:
logger.info("HYPO:" + hyp_words)
logger.info("TARGET:" + tgt_words)
logger.info("___________________")
hyp_words = hyp_words.split()
tgt_words = tgt_words.split()
return editdistance.eval(hyp_words, tgt_words), len(tgt_words)
def prepare_result_files(args):
def get_res_file(file_prefix):
if args.num_shards > 1:
file_prefix = f"{args.shard_id}_{file_prefix}"
path = os.path.join(
args.results_path,
"{}-{}-{}.txt".format(
file_prefix, os.path.basename(args.path), args.gen_subset
),
)
return open(path, "w", buffering=1)
if not args.results_path:
return None
return {
"hypo.words": get_res_file("hypo.word"),
"hypo.units": get_res_file("hypo.units"),
"ref.words": get_res_file("ref.word"),
"ref.units": get_res_file("ref.units"),
}
def optimize_models(args, use_cuda, models):
"""Optimize ensemble for generation"""
for model in models:
model.make_generation_fast_(
beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
need_attn=args.print_alignment,
)
if args.fp16:
model.half()
if use_cuda:
model.cuda()
class ExistingEmissionsDecoder(object):
def __init__(self, decoder, emissions):
self.decoder = decoder
self.emissions = emissions
def generate(self, models, sample, **unused):
ids = sample["id"].cpu().numpy()
        try:
            emissions = np.stack(self.emissions[ids])
        except ValueError:
            print([x.shape for x in self.emissions[ids]])
            raise Exception("invalid sizes")
emissions = torch.from_numpy(emissions)
return self.decoder.decode(emissions)
def main(args, task=None, model_state=None):
check_args(args)
if args.max_tokens is None and args.batch_size is None:
args.max_tokens = 4000000
logger.info(args)
use_cuda = torch.cuda.is_available() and not args.cpu
logger.info("| decoding with criterion {}".format(args.criterion))
task = tasks.setup_task(args)
# Load ensemble
if args.load_emissions:
models, criterions = [], []
task.load_dataset(args.gen_subset)
else:
logger.info("| loading model(s) from {}".format(args.path))
models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
utils.split_paths(args.path, separator="\\"),
arg_overrides=ast.literal_eval(args.model_overrides),
task=task,
suffix=args.checkpoint_suffix,
strict=(args.checkpoint_shard_count == 1),
num_shards=args.checkpoint_shard_count,
state=model_state,
)
optimize_models(args, use_cuda, models)
task.load_dataset(args.gen_subset, task_cfg=saved_cfg.task)
# Set dictionary
tgt_dict = task.target_dictionary
logger.info(
"| {} {} {} examples".format(
args.data, args.gen_subset, len(task.dataset(args.gen_subset))
)
)
# hack to pass transitions to W2lDecoder
if args.criterion == "asg_loss":
raise NotImplementedError("asg_loss is currently not supported")
# trans = criterions[0].asg.trans.data
# args.asg_transitions = torch.flatten(trans).tolist()
# Load dataset (possibly sharded)
itr = get_dataset_itr(args, task, models)
# Initialize generator
gen_timer = StopwatchMeter()
def build_generator(args):
w2l_decoder = getattr(args, "w2l_decoder", None)
if w2l_decoder == "viterbi":
from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder
return W2lViterbiDecoder(args, task.target_dictionary)
elif w2l_decoder == "kenlm":
from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder
return W2lKenLMDecoder(args, task.target_dictionary)
elif w2l_decoder == "fairseqlm":
from examples.speech_recognition.w2l_decoder import W2lFairseqLMDecoder
return W2lFairseqLMDecoder(args, task.target_dictionary)
else:
print(
"only flashlight decoders with (viterbi, kenlm, fairseqlm) options are supported at the moment"
)
# please do not touch this unless you test both generate.py and infer.py with audio_pretraining task
generator = build_generator(args)
if args.load_emissions:
generator = ExistingEmissionsDecoder(
generator, np.load(args.load_emissions, allow_pickle=True)
)
logger.info("loaded emissions from " + args.load_emissions)
num_sentences = 0
if args.results_path is not None and not os.path.exists(args.results_path):
os.makedirs(args.results_path)
    max_source_pos = utils.resolve_max_positions(
        task.max_positions(), *[model.max_positions() for model in models]
    )
    if max_source_pos is not None:
        max_source_pos = max_source_pos[0] - 1
if args.dump_emissions:
emissions = {}
if args.dump_features:
features = {}
models[0].bert.proj = None
else:
res_files = prepare_result_files(args)
errs_t = 0
lengths_t = 0
with progress_bar.build_progress_bar(args, itr) as t:
wps_meter = TimeMeter()
for sample in t:
sample = utils.move_to_cuda(sample) if use_cuda else sample
if "net_input" not in sample:
continue
prefix_tokens = None
if args.prefix_size > 0:
prefix_tokens = sample["target"][:, : args.prefix_size]
gen_timer.start()
if args.dump_emissions:
with torch.no_grad():
encoder_out = models[0](**sample["net_input"])
emm = models[0].get_normalized_probs(encoder_out, log_probs=True)
emm = emm.transpose(0, 1).cpu().numpy()
for i, id in enumerate(sample["id"]):
emissions[id.item()] = emm[i]
continue
elif args.dump_features:
with torch.no_grad():
encoder_out = models[0](**sample["net_input"])
feat = encoder_out["encoder_out"].transpose(0, 1).cpu().numpy()
for i, id in enumerate(sample["id"]):
padding = (
encoder_out["encoder_padding_mask"][i].cpu().numpy()
if encoder_out["encoder_padding_mask"] is not None
else None
)
features[id.item()] = (feat[i], padding)
continue
hypos = task.inference_step(generator, models, sample, prefix_tokens)
num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
gen_timer.stop(num_generated_tokens)
for i, sample_id in enumerate(sample["id"].tolist()):
speaker = None
# id = task.dataset(args.gen_subset).ids[int(sample_id)]
id = sample_id
toks = (
sample["target"][i, :]
if "target_label" not in sample
else sample["target_label"][i, :]
)
target_tokens = utils.strip_pad(toks, tgt_dict.pad()).int().cpu()
# Process top predictions
errs, length = process_predictions(
args,
hypos[i],
None,
tgt_dict,
target_tokens,
res_files,
speaker,
id,
)
errs_t += errs
lengths_t += length
wps_meter.update(num_generated_tokens)
t.log({"wps": round(wps_meter.avg)})
num_sentences += (
sample["nsentences"] if "nsentences" in sample else sample["id"].numel()
)
wer = None
if args.dump_emissions:
emm_arr = []
for i in range(len(emissions)):
emm_arr.append(emissions[i])
np.save(args.dump_emissions, emm_arr)
logger.info(f"saved {len(emissions)} emissions to {args.dump_emissions}")
elif args.dump_features:
feat_arr = []
for i in range(len(features)):
feat_arr.append(features[i])
np.save(args.dump_features, feat_arr)
logger.info(f"saved {len(features)} emissions to {args.dump_features}")
else:
if lengths_t > 0:
wer = errs_t * 100.0 / lengths_t
logger.info(f"WER: {wer}")
        logger.info(
            "| Processed {} sentences ({} tokens) in {:.1f}s ({:.2f} "
            "sentences/s, {:.2f} tokens/s)".format(
num_sentences,
gen_timer.n,
gen_timer.sum,
num_sentences / gen_timer.sum,
1.0 / gen_timer.avg,
)
)
logger.info("| Generate {} with beam={}".format(args.gen_subset, args.beam))
return task, wer
def make_parser():
parser = options.get_generation_parser()
parser = add_asr_eval_argument(parser)
return parser
def cli_main():
parser = make_parser()
args = options.parse_args_and_arch(parser)
main(args)
if __name__ == "__main__":
cli_main()
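
# Example invocation (a sketch only; paths, task, and decoder flags are
# illustrative and depend on the model being evaluated):
#
#   python infer.py /path/to/data --task audio_pretraining \
#       --path /path/to/checkpoint.pt --gen-subset test-clean \
#       --w2l-decoder kenlm --lexicon lexicon.txt \
#       --kenlm-model lm.bin --lm-weight 0.2 --results-path /tmp/decode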
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <iostream>
#include "fstext/fstext-lib.h" // @manual
#include "util/common-utils.h" // @manual
/*
 * This program modifies an FST that has no self-loops: for each incoming
 * arc with a non-eps input symbol, it adds a self-loop arc with that
 * non-eps symbol as input and eps as output.
 *
 * This ensures the resulting FST can deduplicate repeated symbols, which
 * are very common in acoustic model outputs.
 */
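/*
 * Illustrative example (not from the original source): given the arcs
 *   0 --a:a--> 1 --b:b--> 2
 * the transform adds, at every state with an incoming non-eps arc, a
 * self-loop consuming that symbol and emitting eps:
 *   1 --a:eps--> 1,   2 --b:eps--> 2
 * so the inputs "a a a b b" and "a b" map to the same output "a b".
 */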
namespace {
int32 AddSelfLoopsSimple(fst::StdVectorFst* fst) {
typedef fst::MutableArcIterator<fst::StdVectorFst> IterType;
int32 num_states_before = fst->NumStates();
fst::MakePrecedingInputSymbolsSame(false, fst);
int32 num_states_after = fst->NumStates();
KALDI_LOG << "There are " << num_states_before
<< " states in the original FST; "
<< " after MakePrecedingInputSymbolsSame, there are "
<< num_states_after << " states " << std::endl;
auto weight_one = fst::StdArc::Weight::One();
int32 num_arc_added = 0;
fst::StdArc self_loop_arc;
self_loop_arc.weight = weight_one;
int32 num_states = fst->NumStates();
std::vector<std::set<int32>> incoming_non_eps_label_per_state(num_states);
for (int32 state = 0; state < num_states; state++) {
for (IterType aiter(fst, state); !aiter.Done(); aiter.Next()) {
fst::StdArc arc(aiter.Value());
if (arc.ilabel != 0) {
incoming_non_eps_label_per_state[arc.nextstate].insert(arc.ilabel);
}
}
}
for (int32 state = 0; state < num_states; state++) {
if (!incoming_non_eps_label_per_state[state].empty()) {
auto& ilabel_set = incoming_non_eps_label_per_state[state];
for (auto it = ilabel_set.begin(); it != ilabel_set.end(); it++) {
self_loop_arc.ilabel = *it;
self_loop_arc.olabel = 0;
self_loop_arc.nextstate = state;
fst->AddArc(state, self_loop_arc);
num_arc_added++;
}
}
}
return num_arc_added;
}
void print_usage() {
std::cout << "add-self-loop-simple usage:\n"
"\tadd-self-loop-simple <in-fst> <out-fst> \n";
}
} // namespace
int main(int argc, char** argv) {
if (argc != 3) {
print_usage();
exit(1);
}
auto input = argv[1];
auto output = argv[2];
auto fst = fst::ReadFstKaldi(input);
auto num_states = fst->NumStates();
KALDI_LOG << "Loading FST from " << input << " with " << num_states
<< " states." << std::endl;
int32 num_arc_added = AddSelfLoopsSimple(fst);
KALDI_LOG << "Adding " << num_arc_added << " self-loop arcs " << std::endl;
fst::WriteFstKaldi(*fst, std::string(output));
KALDI_LOG << "Writing FST to " << output << std::endl;
delete fst;
}
# @package _group_
data_dir: ???
fst_dir: ???
in_labels: ???
kaldi_root: ???
lm_arpa: ???
blank_symbol: <s>
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from concurrent.futures import ThreadPoolExecutor
import logging
from omegaconf import MISSING
import os
import torch
from typing import Optional
import warnings
from dataclasses import dataclass
from fairseq.dataclass import FairseqDataclass
from .kaldi_initializer import KaldiInitializerConfig, initalize_kaldi
logger = logging.getLogger(__name__)
@dataclass
class KaldiDecoderConfig(FairseqDataclass):
hlg_graph_path: Optional[str] = None
output_dict: str = MISSING
kaldi_initializer_config: Optional[KaldiInitializerConfig] = None
acoustic_scale: float = 0.5
max_active: int = 10000
beam_delta: float = 0.5
hash_ratio: float = 2.0
is_lattice: bool = False
lattice_beam: float = 10.0
prune_interval: int = 25
determinize_lattice: bool = True
prune_scale: float = 0.1
max_mem: int = 0
phone_determinize: bool = True
word_determinize: bool = True
minimize: bool = True
num_threads: int = 1
class KaldiDecoder(object):
def __init__(
self,
cfg: KaldiDecoderConfig,
beam: int,
nbest: int = 1,
):
try:
from kaldi.asr import FasterRecognizer, LatticeFasterRecognizer
from kaldi.base import set_verbose_level
from kaldi.decoder import (
FasterDecoder,
FasterDecoderOptions,
LatticeFasterDecoder,
LatticeFasterDecoderOptions,
)
from kaldi.lat.functions import DeterminizeLatticePhonePrunedOptions
from kaldi.fstext import read_fst_kaldi, SymbolTable
        except ImportError:
            warnings.warn(
                "pykaldi is required for this functionality. Please install from https://github.com/pykaldi/pykaldi"
            )
# set_verbose_level(2)
self.acoustic_scale = cfg.acoustic_scale
self.nbest = nbest
if cfg.hlg_graph_path is None:
assert (
cfg.kaldi_initializer_config is not None
), "Must provide hlg graph path or kaldi initializer config"
cfg.hlg_graph_path = initalize_kaldi(cfg.kaldi_initializer_config)
assert os.path.exists(cfg.hlg_graph_path), cfg.hlg_graph_path
if cfg.is_lattice:
self.dec_cls = LatticeFasterDecoder
opt_cls = LatticeFasterDecoderOptions
self.rec_cls = LatticeFasterRecognizer
else:
assert self.nbest == 1, "nbest > 1 requires lattice decoder"
self.dec_cls = FasterDecoder
opt_cls = FasterDecoderOptions
self.rec_cls = FasterRecognizer
self.decoder_options = opt_cls()
self.decoder_options.beam = beam
self.decoder_options.max_active = cfg.max_active
self.decoder_options.beam_delta = cfg.beam_delta
self.decoder_options.hash_ratio = cfg.hash_ratio
if cfg.is_lattice:
self.decoder_options.lattice_beam = cfg.lattice_beam
self.decoder_options.prune_interval = cfg.prune_interval
self.decoder_options.determinize_lattice = cfg.determinize_lattice
self.decoder_options.prune_scale = cfg.prune_scale
det_opts = DeterminizeLatticePhonePrunedOptions()
det_opts.max_mem = cfg.max_mem
det_opts.phone_determinize = cfg.phone_determinize
det_opts.word_determinize = cfg.word_determinize
det_opts.minimize = cfg.minimize
self.decoder_options.det_opts = det_opts
self.output_symbols = {}
with open(cfg.output_dict, "r") as f:
for line in f:
items = line.rstrip().split()
assert len(items) == 2
self.output_symbols[int(items[1])] = items[0]
logger.info(f"Loading FST from {cfg.hlg_graph_path}")
self.fst = read_fst_kaldi(cfg.hlg_graph_path)
self.symbol_table = SymbolTable.read_text(cfg.output_dict)
self.executor = ThreadPoolExecutor(max_workers=cfg.num_threads)
def generate(self, models, sample, **unused):
"""Generate a batch of inferences."""
# model.forward normally channels prev_output_tokens into the decoder
# separately, but SequenceGenerator directly calls model.encoder
encoder_input = {
k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
}
emissions, padding = self.get_emissions(models, encoder_input)
return self.decode(emissions, padding)
def get_emissions(self, models, encoder_input):
"""Run encoder and normalize emissions"""
model = models[0]
all_encoder_out = [m(**encoder_input) for m in models]
if len(all_encoder_out) > 1:
if "encoder_out" in all_encoder_out[0]:
encoder_out = {
"encoder_out": sum(e["encoder_out"] for e in all_encoder_out)
/ len(all_encoder_out),
"encoder_padding_mask": all_encoder_out[0]["encoder_padding_mask"],
}
padding = encoder_out["encoder_padding_mask"]
else:
encoder_out = {
"logits": sum(e["logits"] for e in all_encoder_out)
/ len(all_encoder_out),
"padding_mask": all_encoder_out[0]["padding_mask"],
}
padding = encoder_out["padding_mask"]
else:
encoder_out = all_encoder_out[0]
padding = (
encoder_out["padding_mask"]
if "padding_mask" in encoder_out
else encoder_out["encoder_padding_mask"]
)
if hasattr(model, "get_logits"):
emissions = model.get_logits(encoder_out, normalize=True)
else:
emissions = model.get_normalized_probs(encoder_out, log_probs=True)
return (
emissions.cpu().float().transpose(0, 1),
padding.cpu() if padding is not None and padding.any() else None,
)
def decode_one(self, logits, padding):
from kaldi.matrix import Matrix
decoder = self.dec_cls(self.fst, self.decoder_options)
asr = self.rec_cls(
decoder, self.symbol_table, acoustic_scale=self.acoustic_scale
)
if padding is not None:
logits = logits[~padding]
mat = Matrix(logits.numpy())
out = asr.decode(mat)
if self.nbest > 1:
from kaldi.fstext import shortestpath
from kaldi.fstext.utils import (
convert_compact_lattice_to_lattice,
convert_lattice_to_std,
convert_nbest_to_list,
get_linear_symbol_sequence,
)
lat = out["lattice"]
sp = shortestpath(lat, nshortest=self.nbest)
sp = convert_compact_lattice_to_lattice(sp)
sp = convert_lattice_to_std(sp)
seq = convert_nbest_to_list(sp)
results = []
for s in seq:
_, o, w = get_linear_symbol_sequence(s)
words = list(self.output_symbols[z] for z in o)
results.append(
{
"tokens": words,
"words": words,
"score": w.value,
"emissions": logits,
}
)
return results
else:
words = out["text"].split()
return [
{
"tokens": words,
"words": words,
"score": out["likelihood"],
"emissions": logits,
}
]
def decode(self, emissions, padding):
if padding is None:
padding = [None] * len(emissions)
ret = list(
map(
lambda e, p: self.executor.submit(self.decode_one, e, p),
emissions,
padding,
)
)
return ret
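
# Illustrative usage (a minimal sketch, not part of the original file):
# decode() submits one job per utterance to the thread pool and returns a
# list of concurrent.futures.Future objects; callers block on .result():
#
#   futures = decoder.decode(emissions, padding)
#   hypos = [f.result() for f in futures]
#   # each element is a list of dicts with "tokens", "words", "score",
#   # and "emissions" keys, as produced by decode_one() above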
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
import hydra
from hydra.core.config_store import ConfigStore
import logging
from omegaconf import MISSING, OmegaConf
import os
import os.path as osp
from pathlib import Path
import subprocess
from typing import Optional, Tuple
from fairseq.data.dictionary import Dictionary
from fairseq.dataclass import FairseqDataclass
script_dir = Path(__file__).resolve().parent
config_path = script_dir / "config"
logger = logging.getLogger(__name__)
@dataclass
class KaldiInitializerConfig(FairseqDataclass):
data_dir: str = MISSING
fst_dir: Optional[str] = None
in_labels: str = MISSING
out_labels: Optional[str] = None
wav2letter_lexicon: Optional[str] = None
lm_arpa: str = MISSING
kaldi_root: str = MISSING
blank_symbol: str = "<s>"
silence_symbol: Optional[str] = None
def create_units(fst_dir: Path, in_labels: str, vocab: Dictionary) -> Path:
in_units_file = fst_dir / f"kaldi_dict.{in_labels}.txt"
if not in_units_file.exists():
logger.info(f"Creating {in_units_file}")
with open(in_units_file, "w") as f:
print("<eps> 0", file=f)
i = 1
for symb in vocab.symbols[vocab.nspecial :]:
if not symb.startswith("madeupword"):
print(f"{symb} {i}", file=f)
i += 1
return in_units_file
def create_lexicon(
cfg: KaldiInitializerConfig,
fst_dir: Path,
unique_label: str,
in_units_file: Path,
out_words_file: Path,
) -> Tuple[Path, Path]:
disambig_in_units_file = fst_dir / f"kaldi_dict.{cfg.in_labels}_disambig.txt"
lexicon_file = fst_dir / f"kaldi_lexicon.{unique_label}.txt"
disambig_lexicon_file = fst_dir / f"kaldi_lexicon.{unique_label}_disambig.txt"
if (
not lexicon_file.exists()
or not disambig_lexicon_file.exists()
or not disambig_in_units_file.exists()
):
logger.info(f"Creating {lexicon_file} (in units file: {in_units_file})")
assert cfg.wav2letter_lexicon is not None or cfg.in_labels == cfg.out_labels
if cfg.wav2letter_lexicon is not None:
lm_words = set()
with open(out_words_file, "r") as lm_dict_f:
for line in lm_dict_f:
lm_words.add(line.split()[0])
num_skipped = 0
total = 0
with open(cfg.wav2letter_lexicon, "r") as w2l_lex_f, open(
lexicon_file, "w"
) as out_f:
for line in w2l_lex_f:
items = line.rstrip().split("\t")
assert len(items) == 2, items
if items[0] in lm_words:
print(items[0], items[1], file=out_f)
else:
num_skipped += 1
logger.debug(
f"Skipping word {items[0]} as it was not found in LM"
)
total += 1
if num_skipped > 0:
logger.warning(
f"Skipped {num_skipped} out of {total} words as they were not found in LM"
)
else:
with open(in_units_file, "r") as in_f, open(lexicon_file, "w") as out_f:
for line in in_f:
symb = line.split()[0]
if symb != "<eps>" and symb != "<ctc_blank>" and symb != "<SIL>":
print(symb, symb, file=out_f)
lex_disambig_path = (
Path(cfg.kaldi_root) / "egs/wsj/s5/utils/add_lex_disambig.pl"
)
res = subprocess.run(
[lex_disambig_path, lexicon_file, disambig_lexicon_file],
check=True,
capture_output=True,
)
ndisambig = int(res.stdout)
        disambig_path = Path(cfg.kaldi_root) / "egs/wsj/s5/utils/add_disambig.pl"
        res = subprocess.run(
            [disambig_path, "--include-zero", in_units_file, str(ndisambig)],
check=True,
capture_output=True,
)
with open(disambig_in_units_file, "wb") as f:
f.write(res.stdout)
return disambig_lexicon_file, disambig_in_units_file
def create_G(
kaldi_root: Path, fst_dir: Path, lm_arpa: Path, arpa_base: str
) -> Tuple[Path, Path]:
out_words_file = fst_dir / f"kaldi_dict.{arpa_base}.txt"
grammar_graph = fst_dir / f"G_{arpa_base}.fst"
if not grammar_graph.exists() or not out_words_file.exists():
logger.info(f"Creating {grammar_graph}")
arpa2fst = kaldi_root / "src/lmbin/arpa2fst"
subprocess.run(
[
arpa2fst,
"--disambig-symbol=#0",
f"--write-symbol-table={out_words_file}",
lm_arpa,
grammar_graph,
],
check=True,
)
return grammar_graph, out_words_file
def create_L(
kaldi_root: Path,
fst_dir: Path,
unique_label: str,
lexicon_file: Path,
in_units_file: Path,
out_words_file: Path,
) -> Path:
lexicon_graph = fst_dir / f"L.{unique_label}.fst"
if not lexicon_graph.exists():
logger.info(f"Creating {lexicon_graph} (in units: {in_units_file})")
make_lex = kaldi_root / "egs/wsj/s5/utils/make_lexicon_fst.pl"
fstcompile = kaldi_root / "tools/openfst-1.6.7/bin/fstcompile"
fstaddselfloops = kaldi_root / "src/fstbin/fstaddselfloops"
fstarcsort = kaldi_root / "tools/openfst-1.6.7/bin/fstarcsort"
def write_disambig_symbol(file):
with open(file, "r") as f:
for line in f:
items = line.rstrip().split()
if items[0] == "#0":
                        out_path = str(file) + "_disambig"
with open(out_path, "w") as out_f:
print(items[1], file=out_f)
return out_path
return None
in_disambig_sym = write_disambig_symbol(in_units_file)
assert in_disambig_sym is not None
out_disambig_sym = write_disambig_symbol(out_words_file)
assert out_disambig_sym is not None
try:
with open(lexicon_graph, "wb") as out_f:
res = subprocess.run(
[make_lex, lexicon_file], capture_output=True, check=True
)
assert len(res.stderr) == 0, res.stderr.decode("utf-8")
res = subprocess.run(
[
fstcompile,
f"--isymbols={in_units_file}",
f"--osymbols={out_words_file}",
"--keep_isymbols=false",
"--keep_osymbols=false",
],
input=res.stdout,
capture_output=True,
)
assert len(res.stderr) == 0, res.stderr.decode("utf-8")
res = subprocess.run(
[fstaddselfloops, in_disambig_sym, out_disambig_sym],
input=res.stdout,
capture_output=True,
check=True,
)
res = subprocess.run(
[fstarcsort, "--sort_type=olabel"],
input=res.stdout,
capture_output=True,
check=True,
)
out_f.write(res.stdout)
except subprocess.CalledProcessError as e:
logger.error(f"cmd: {e.cmd}, err: {e.stderr.decode('utf-8')}")
os.remove(lexicon_graph)
raise
except AssertionError:
os.remove(lexicon_graph)
raise
return lexicon_graph
def create_LG(
kaldi_root: Path,
fst_dir: Path,
unique_label: str,
lexicon_graph: Path,
grammar_graph: Path,
) -> Path:
lg_graph = fst_dir / f"LG.{unique_label}.fst"
if not lg_graph.exists():
logger.info(f"Creating {lg_graph}")
fsttablecompose = kaldi_root / "src/fstbin/fsttablecompose"
fstdeterminizestar = kaldi_root / "src/fstbin/fstdeterminizestar"
fstminimizeencoded = kaldi_root / "src/fstbin/fstminimizeencoded"
fstpushspecial = kaldi_root / "src/fstbin/fstpushspecial"
fstarcsort = kaldi_root / "tools/openfst-1.6.7/bin/fstarcsort"
try:
with open(lg_graph, "wb") as out_f:
res = subprocess.run(
[fsttablecompose, lexicon_graph, grammar_graph],
capture_output=True,
check=True,
)
res = subprocess.run(
[
fstdeterminizestar,
"--use-log=true",
],
input=res.stdout,
capture_output=True,
)
res = subprocess.run(
[fstminimizeencoded],
input=res.stdout,
capture_output=True,
check=True,
)
res = subprocess.run(
[fstpushspecial],
input=res.stdout,
capture_output=True,
check=True,
)
res = subprocess.run(
[fstarcsort, "--sort_type=ilabel"],
input=res.stdout,
capture_output=True,
check=True,
)
out_f.write(res.stdout)
except subprocess.CalledProcessError as e:
logger.error(f"cmd: {e.cmd}, err: {e.stderr.decode('utf-8')}")
os.remove(lg_graph)
raise
return lg_graph
def create_H(
kaldi_root: Path,
fst_dir: Path,
disambig_out_units_file: Path,
in_labels: str,
vocab: Dictionary,
blk_sym: str,
silence_symbol: Optional[str],
) -> Tuple[Path, Path, Path]:
h_graph = (
fst_dir / f"H.{in_labels}{'_' + silence_symbol if silence_symbol else ''}.fst"
)
h_out_units_file = fst_dir / f"kaldi_dict.h_out.{in_labels}.txt"
disambig_in_units_file_int = Path(str(h_graph) + "isym_disambig.int")
disambig_out_units_file_int = Path(str(disambig_out_units_file) + ".int")
if (
not h_graph.exists()
or not h_out_units_file.exists()
or not disambig_in_units_file_int.exists()
):
logger.info(f"Creating {h_graph}")
eps_sym = "<eps>"
num_disambig = 0
osymbols = []
with open(disambig_out_units_file, "r") as f, open(
disambig_out_units_file_int, "w"
) as out_f:
for line in f:
symb, id = line.rstrip().split()
if line.startswith("#"):
num_disambig += 1
print(id, file=out_f)
else:
if len(osymbols) == 0:
assert symb == eps_sym, symb
osymbols.append((symb, id))
i_idx = 0
isymbols = [(eps_sym, 0)]
imap = {}
for i, s in enumerate(vocab.symbols):
i_idx += 1
isymbols.append((s, i_idx))
imap[s] = i_idx
fst_str = []
node_idx = 0
root_node = node_idx
special_symbols = [blk_sym]
if silence_symbol is not None:
special_symbols.append(silence_symbol)
for ss in special_symbols:
fst_str.append("{} {} {} {}".format(root_node, root_node, ss, eps_sym))
for symbol, _ in osymbols:
if symbol == eps_sym or symbol.startswith("#"):
continue
node_idx += 1
# 1. from root to emitting state
fst_str.append("{} {} {} {}".format(root_node, node_idx, symbol, symbol))
# 2. from emitting state back to root
fst_str.append("{} {} {} {}".format(node_idx, root_node, eps_sym, eps_sym))
# 3. from emitting state to optional blank state
pre_node = node_idx
node_idx += 1
for ss in special_symbols:
fst_str.append("{} {} {} {}".format(pre_node, node_idx, ss, eps_sym))
# 4. from blank state back to root
fst_str.append("{} {} {} {}".format(node_idx, root_node, eps_sym, eps_sym))
fst_str.append("{}".format(root_node))
fst_str = "\n".join(fst_str)
h_str = str(h_graph)
isym_file = h_str + ".isym"
with open(isym_file, "w") as f:
for sym, id in isymbols:
f.write("{} {}\n".format(sym, id))
with open(h_out_units_file, "w") as f:
for sym, id in osymbols:
f.write("{} {}\n".format(sym, id))
with open(disambig_in_units_file_int, "w") as f:
disam_sym_id = len(isymbols)
for _ in range(num_disambig):
f.write("{}\n".format(disam_sym_id))
disam_sym_id += 1
fstcompile = kaldi_root / "tools/openfst-1.6.7/bin/fstcompile"
fstaddselfloops = kaldi_root / "src/fstbin/fstaddselfloops"
fstarcsort = kaldi_root / "tools/openfst-1.6.7/bin/fstarcsort"
try:
with open(h_graph, "wb") as out_f:
res = subprocess.run(
[
fstcompile,
f"--isymbols={isym_file}",
f"--osymbols={h_out_units_file}",
"--keep_isymbols=false",
"--keep_osymbols=false",
],
input=str.encode(fst_str),
capture_output=True,
check=True,
)
res = subprocess.run(
[
fstaddselfloops,
disambig_in_units_file_int,
disambig_out_units_file_int,
],
input=res.stdout,
capture_output=True,
check=True,
)
res = subprocess.run(
[fstarcsort, "--sort_type=olabel"],
input=res.stdout,
capture_output=True,
check=True,
)
out_f.write(res.stdout)
except subprocess.CalledProcessError as e:
logger.error(f"cmd: {e.cmd}, err: {e.stderr.decode('utf-8')}")
os.remove(h_graph)
raise
return h_graph, h_out_units_file, disambig_in_units_file_int
def create_HLGa(
kaldi_root: Path,
fst_dir: Path,
unique_label: str,
h_graph: Path,
lg_graph: Path,
disambig_in_words_file_int: Path,
) -> Path:
hlga_graph = fst_dir / f"HLGa.{unique_label}.fst"
if not hlga_graph.exists():
logger.info(f"Creating {hlga_graph}")
fsttablecompose = kaldi_root / "src/fstbin/fsttablecompose"
fstdeterminizestar = kaldi_root / "src/fstbin/fstdeterminizestar"
fstrmsymbols = kaldi_root / "src/fstbin/fstrmsymbols"
fstrmepslocal = kaldi_root / "src/fstbin/fstrmepslocal"
fstminimizeencoded = kaldi_root / "src/fstbin/fstminimizeencoded"
try:
with open(hlga_graph, "wb") as out_f:
res = subprocess.run(
[
fsttablecompose,
h_graph,
lg_graph,
],
capture_output=True,
check=True,
)
res = subprocess.run(
[fstdeterminizestar, "--use-log=true"],
input=res.stdout,
capture_output=True,
check=True,
)
res = subprocess.run(
[fstrmsymbols, disambig_in_words_file_int],
input=res.stdout,
capture_output=True,
check=True,
)
res = subprocess.run(
[fstrmepslocal],
input=res.stdout,
capture_output=True,
check=True,
)
res = subprocess.run(
[fstminimizeencoded],
input=res.stdout,
capture_output=True,
check=True,
)
out_f.write(res.stdout)
except subprocess.CalledProcessError as e:
logger.error(f"cmd: {e.cmd}, err: {e.stderr.decode('utf-8')}")
os.remove(hlga_graph)
raise
return hlga_graph
def create_HLa(
kaldi_root: Path,
fst_dir: Path,
unique_label: str,
h_graph: Path,
l_graph: Path,
disambig_in_words_file_int: Path,
) -> Path:
hla_graph = fst_dir / f"HLa.{unique_label}.fst"
if not hla_graph.exists():
logger.info(f"Creating {hla_graph}")
fsttablecompose = kaldi_root / "src/fstbin/fsttablecompose"
fstdeterminizestar = kaldi_root / "src/fstbin/fstdeterminizestar"
fstrmsymbols = kaldi_root / "src/fstbin/fstrmsymbols"
fstrmepslocal = kaldi_root / "src/fstbin/fstrmepslocal"
fstminimizeencoded = kaldi_root / "src/fstbin/fstminimizeencoded"
try:
with open(hla_graph, "wb") as out_f:
res = subprocess.run(
[
fsttablecompose,
h_graph,
l_graph,
],
capture_output=True,
check=True,
)
res = subprocess.run(
[fstdeterminizestar, "--use-log=true"],
input=res.stdout,
capture_output=True,
check=True,
)
res = subprocess.run(
[fstrmsymbols, disambig_in_words_file_int],
input=res.stdout,
capture_output=True,
check=True,
)
res = subprocess.run(
[fstrmepslocal],
input=res.stdout,
capture_output=True,
check=True,
)
res = subprocess.run(
[fstminimizeencoded],
input=res.stdout,
capture_output=True,
check=True,
)
out_f.write(res.stdout)
except subprocess.CalledProcessError as e:
logger.error(f"cmd: {e.cmd}, err: {e.stderr.decode('utf-8')}")
os.remove(hla_graph)
raise
return hla_graph
def create_HLG(
kaldi_root: Path,
fst_dir: Path,
unique_label: str,
hlga_graph: Path,
prefix: str = "HLG",
) -> Path:
hlg_graph = fst_dir / f"{prefix}.{unique_label}.fst"
if not hlg_graph.exists():
logger.info(f"Creating {hlg_graph}")
add_self_loop = script_dir / "add-self-loop-simple"
kaldi_src = kaldi_root / "src"
kaldi_lib = kaldi_src / "lib"
try:
if not add_self_loop.exists():
fst_include = kaldi_root / "tools/openfst-1.6.7/include"
add_self_loop_src = script_dir / "add-self-loop-simple.cc"
subprocess.run(
[
"c++",
f"-I{kaldi_src}",
f"-I{fst_include}",
f"-L{kaldi_lib}",
add_self_loop_src,
"-lkaldi-base",
"-lkaldi-fstext",
"-o",
add_self_loop,
],
check=True,
)
my_env = os.environ.copy()
my_env["LD_LIBRARY_PATH"] = f"{kaldi_lib}:{my_env['LD_LIBRARY_PATH']}"
subprocess.run(
[
add_self_loop,
hlga_graph,
hlg_graph,
],
check=True,
capture_output=True,
env=my_env,
)
except subprocess.CalledProcessError as e:
logger.error(f"cmd: {e.cmd}, err: {e.stderr.decode('utf-8')}")
raise
return hlg_graph
def initalize_kaldi(cfg: KaldiInitializerConfig) -> Path:
if cfg.fst_dir is None:
cfg.fst_dir = osp.join(cfg.data_dir, "kaldi")
if cfg.out_labels is None:
cfg.out_labels = cfg.in_labels
kaldi_root = Path(cfg.kaldi_root)
data_dir = Path(cfg.data_dir)
fst_dir = Path(cfg.fst_dir)
fst_dir.mkdir(parents=True, exist_ok=True)
arpa_base = osp.splitext(osp.basename(cfg.lm_arpa))[0]
unique_label = f"{cfg.in_labels}.{arpa_base}"
with open(data_dir / f"dict.{cfg.in_labels}.txt", "r") as f:
vocab = Dictionary.load(f)
in_units_file = create_units(fst_dir, cfg.in_labels, vocab)
grammar_graph, out_words_file = create_G(
kaldi_root, fst_dir, Path(cfg.lm_arpa), arpa_base
)
disambig_lexicon_file, disambig_L_in_units_file = create_lexicon(
cfg, fst_dir, unique_label, in_units_file, out_words_file
)
h_graph, h_out_units_file, disambig_in_units_file_int = create_H(
kaldi_root,
fst_dir,
disambig_L_in_units_file,
cfg.in_labels,
vocab,
cfg.blank_symbol,
cfg.silence_symbol,
)
lexicon_graph = create_L(
kaldi_root,
fst_dir,
unique_label,
disambig_lexicon_file,
disambig_L_in_units_file,
out_words_file,
)
lg_graph = create_LG(
kaldi_root, fst_dir, unique_label, lexicon_graph, grammar_graph
)
hlga_graph = create_HLGa(
kaldi_root, fst_dir, unique_label, h_graph, lg_graph, disambig_in_units_file_int
)
hlg_graph = create_HLG(kaldi_root, fst_dir, unique_label, hlga_graph)
# for debugging
# hla_graph = create_HLa(kaldi_root, fst_dir, unique_label, h_graph, lexicon_graph, disambig_in_units_file_int)
# hl_graph = create_HLG(kaldi_root, fst_dir, unique_label, hla_graph, prefix="HL_looped")
# create_HLG(kaldi_root, fst_dir, "phnc", h_graph, prefix="H_looped")
return hlg_graph
@hydra.main(config_path=config_path, config_name="kaldi_initializer")
def cli_main(cfg: KaldiInitializerConfig) -> None:
container = OmegaConf.to_container(cfg, resolve=True, enum_to_str=True)
cfg = OmegaConf.create(container)
OmegaConf.set_struct(cfg, True)
initalize_kaldi(cfg)
if __name__ == "__main__":
logging.root.setLevel(logging.INFO)
logging.basicConfig(level=logging.INFO)
try:
from hydra._internal.utils import (
get_args,
) # pylint: disable=import-outside-toplevel
cfg_name = get_args().config_name or "kaldi_initializer"
except ImportError:
logger.warning("Failed to get config name from hydra args")
cfg_name = "kaldi_initializer"
cs = ConfigStore.instance()
cs.store(name=cfg_name, node=KaldiInitializerConfig)
cli_main()
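
# Example invocation (hydra override syntax; paths are hypothetical):
#
#   python kaldi_initializer.py data_dir=/data/ltr in_labels=ltr \
#       kaldi_root=/opt/kaldi lm_arpa=/data/lm/4gram.arpa fst_dir=/data/kaldi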
import importlib
import os
for file in sorted(os.listdir(os.path.dirname(__file__))):
if file.endswith(".py") and not file.startswith("_"):
model_name = file[: file.find(".py")]
importlib.import_module("examples.speech_recognition.models." + model_name)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import math
from collections.abc import Iterable
import torch
import torch.nn as nn
from examples.speech_recognition.data.data_utils import lengths_to_encoder_padding_mask
from fairseq import utils
from fairseq.models import (
FairseqEncoder,
FairseqEncoderDecoderModel,
FairseqEncoderModel,
FairseqIncrementalDecoder,
register_model,
register_model_architecture,
)
from fairseq.modules import (
LinearizedConvolution,
TransformerDecoderLayer,
TransformerEncoderLayer,
VGGBlock,
)
@register_model("asr_vggtransformer")
class VGGTransformerModel(FairseqEncoderDecoderModel):
"""
Transformers with convolutional context for ASR
https://arxiv.org/abs/1904.11660
"""
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
parser.add_argument(
"--input-feat-per-channel",
type=int,
metavar="N",
help="encoder input dimension per input channel",
)
parser.add_argument(
"--vggblock-enc-config",
type=str,
metavar="EXPR",
help="""
an array of tuples each containing the configuration of one vggblock:
[(out_channels,
conv_kernel_size,
pooling_kernel_size,
num_conv_layers,
      use_layer_norm), ...]
""",
)
parser.add_argument(
"--transformer-enc-config",
type=str,
metavar="EXPR",
help=""""
a tuple containing the configuration of the encoder transformer layers
configurations:
[(input_dim,
num_heads,
ffn_dim,
normalize_before,
dropout,
attention_dropout,
      relu_dropout), ...]
""",
)
parser.add_argument(
"--enc-output-dim",
type=int,
metavar="N",
help="""
encoder output dimension, can be None. If specified, projecting the
transformer output to the specified dimension""",
)
parser.add_argument(
"--in-channels",
type=int,
metavar="N",
help="number of encoder input channels",
)
parser.add_argument(
"--tgt-embed-dim",
type=int,
metavar="N",
help="embedding dimension of the decoder target tokens",
)
parser.add_argument(
"--transformer-dec-config",
type=str,
metavar="EXPR",
help="""
    a tuple containing the configuration of the decoder transformer layers:
[(input_dim,
num_heads,
ffn_dim,
normalize_before,
dropout,
attention_dropout,
relu_dropout), ...]
""",
)
parser.add_argument(
"--conv-dec-config",
type=str,
metavar="EXPR",
help="""
an array of tuples for the decoder 1-D convolution config
[(out_channels, conv_kernel_size, use_layer_norm), ...]""",
)
@classmethod
def build_encoder(cls, args, task):
return VGGTransformerEncoder(
input_feat_per_channel=args.input_feat_per_channel,
vggblock_config=eval(args.vggblock_enc_config),
transformer_config=eval(args.transformer_enc_config),
encoder_output_dim=args.enc_output_dim,
in_channels=args.in_channels,
)
@classmethod
def build_decoder(cls, args, task):
return TransformerDecoder(
dictionary=task.target_dictionary,
embed_dim=args.tgt_embed_dim,
transformer_config=eval(args.transformer_dec_config),
conv_config=eval(args.conv_dec_config),
encoder_output_dim=args.enc_output_dim,
)
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
# make sure that all args are properly defaulted
# (in case there are any new ones)
base_architecture(args)
encoder = cls.build_encoder(args, task)
decoder = cls.build_decoder(args, task)
return cls(encoder, decoder)
def get_normalized_probs(self, net_output, log_probs, sample=None):
# net_output['encoder_out'] is a (B, T, D) tensor
lprobs = super().get_normalized_probs(net_output, log_probs, sample)
lprobs.batch_first = True
return lprobs
DEFAULT_ENC_VGGBLOCK_CONFIG = ((32, 3, 2, 2, False),) * 2
DEFAULT_ENC_TRANSFORMER_CONFIG = ((256, 4, 1024, True, 0.2, 0.2, 0.2),) * 2
# 256: embedding dimension
# 4: number of attention heads
# 1024: FFN dimension
# True: apply LayerNorm before (dropout + residual) instead of after
# 0.2 (dropout): dropout after MultiheadAttention and second FC
# 0.2 (attention_dropout): dropout in MultiheadAttention
# 0.2 (relu_dropout): dropout after ReLU
DEFAULT_DEC_TRANSFORMER_CONFIG = ((256, 2, 1024, True, 0.2, 0.2, 0.2),) * 2
DEFAULT_DEC_CONV_CONFIG = ((256, 3, True),) * 2
# TODO: replace the one-liner transformer encoder config
# with explicit args to get rid of this transformation
def prepare_transformer_encoder_params(
input_dim,
num_heads,
ffn_dim,
normalize_before,
dropout,
attention_dropout,
relu_dropout,
):
args = argparse.Namespace()
args.encoder_embed_dim = input_dim
args.encoder_attention_heads = num_heads
args.attention_dropout = attention_dropout
args.dropout = dropout
args.activation_dropout = relu_dropout
args.encoder_normalize_before = normalize_before
args.encoder_ffn_embed_dim = ffn_dim
return args
def prepare_transformer_decoder_params(
input_dim,
num_heads,
ffn_dim,
normalize_before,
dropout,
attention_dropout,
relu_dropout,
):
args = argparse.Namespace()
args.decoder_embed_dim = input_dim
args.decoder_attention_heads = num_heads
args.attention_dropout = attention_dropout
args.dropout = dropout
args.activation_dropout = relu_dropout
args.decoder_normalize_before = normalize_before
args.decoder_ffn_embed_dim = ffn_dim
return args
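# For example (values from the DEFAULT_*_TRANSFORMER_CONFIG tuples above),
# prepare_transformer_encoder_params(256, 4, 1024, True, 0.2, 0.2, 0.2)
# returns a Namespace with encoder_embed_dim=256, encoder_attention_heads=4,
# encoder_ffn_embed_dim=1024, encoder_normalize_before=True, and
# dropout/attention_dropout/activation_dropout all set to 0.2; the decoder
# variant fills the corresponding decoder_* attributes.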
class VGGTransformerEncoder(FairseqEncoder):
"""VGG + Transformer encoder"""
def __init__(
self,
input_feat_per_channel,
vggblock_config=DEFAULT_ENC_VGGBLOCK_CONFIG,
transformer_config=DEFAULT_ENC_TRANSFORMER_CONFIG,
encoder_output_dim=512,
in_channels=1,
transformer_context=None,
transformer_sampling=None,
):
"""constructor for VGGTransformerEncoder
Args:
- input_feat_per_channel: feature dim (not including stacked,
just base feature)
- in_channel: # input channels (e.g., if stack 8 feature vector
together, this is 8)
- vggblock_config: configuration of vggblock, see comments on
DEFAULT_ENC_VGGBLOCK_CONFIG
- transformer_config: configuration of transformer layer, see comments
on DEFAULT_ENC_TRANSFORMER_CONFIG
- encoder_output_dim: final transformer output embedding dimension
- transformer_context: (left, right) if set, self-attention will be focused
on (t-left, t+right)
- transformer_sampling: an iterable of int, must match with
len(transformer_config), transformer_sampling[i] indicates sampling
          factor for i-th transformer layer, after multihead att and feedforward
part
"""
super().__init__(None)
self.num_vggblocks = 0
if vggblock_config is not None:
if not isinstance(vggblock_config, Iterable):
raise ValueError("vggblock_config is not iterable")
self.num_vggblocks = len(vggblock_config)
self.conv_layers = nn.ModuleList()
self.in_channels = in_channels
self.input_dim = input_feat_per_channel
self.pooling_kernel_sizes = []
if vggblock_config is not None:
for _, config in enumerate(vggblock_config):
(
out_channels,
conv_kernel_size,
pooling_kernel_size,
num_conv_layers,
layer_norm,
) = config
self.conv_layers.append(
VGGBlock(
in_channels,
out_channels,
conv_kernel_size,
pooling_kernel_size,
num_conv_layers,
input_dim=input_feat_per_channel,
layer_norm=layer_norm,
)
)
self.pooling_kernel_sizes.append(pooling_kernel_size)
in_channels = out_channels
input_feat_per_channel = self.conv_layers[-1].output_dim
transformer_input_dim = self.infer_conv_output_dim(
self.in_channels, self.input_dim
)
# transformer_input_dim is the output dimension of VGG part
self.validate_transformer_config(transformer_config)
self.transformer_context = self.parse_transformer_context(transformer_context)
self.transformer_sampling = self.parse_transformer_sampling(
transformer_sampling, len(transformer_config)
)
self.transformer_layers = nn.ModuleList()
if transformer_input_dim != transformer_config[0][0]:
self.transformer_layers.append(
Linear(transformer_input_dim, transformer_config[0][0])
)
self.transformer_layers.append(
TransformerEncoderLayer(
prepare_transformer_encoder_params(*transformer_config[0])
)
)
for i in range(1, len(transformer_config)):
if transformer_config[i - 1][0] != transformer_config[i][0]:
self.transformer_layers.append(
Linear(transformer_config[i - 1][0], transformer_config[i][0])
)
self.transformer_layers.append(
TransformerEncoderLayer(
prepare_transformer_encoder_params(*transformer_config[i])
)
)
self.encoder_output_dim = encoder_output_dim
self.transformer_layers.extend(
[
Linear(transformer_config[-1][0], encoder_output_dim),
LayerNorm(encoder_output_dim),
]
)
def forward(self, src_tokens, src_lengths, **kwargs):
"""
src_tokens: padded tensor (B, T, C * feat)
src_lengths: tensor of original lengths of input utterances (B,)
"""
bsz, max_seq_len, _ = src_tokens.size()
x = src_tokens.view(bsz, max_seq_len, self.in_channels, self.input_dim)
x = x.transpose(1, 2).contiguous()
# (B, C, T, feat)
for layer_idx in range(len(self.conv_layers)):
x = self.conv_layers[layer_idx](x)
bsz, _, output_seq_len, _ = x.size()
# (B, C, T, feat) -> (B, T, C, feat) -> (T, B, C, feat) -> (T, B, C * feat)
x = x.transpose(1, 2).transpose(0, 1)
x = x.contiguous().view(output_seq_len, bsz, -1)
input_lengths = src_lengths.clone()
for s in self.pooling_kernel_sizes:
input_lengths = (input_lengths.float() / s).ceil().long()
encoder_padding_mask, _ = lengths_to_encoder_padding_mask(
input_lengths, batch_first=True
)
if not encoder_padding_mask.any():
encoder_padding_mask = None
subsampling_factor = int(max_seq_len * 1.0 / output_seq_len + 0.5)
attn_mask = self.lengths_to_attn_mask(input_lengths, subsampling_factor)
transformer_layer_idx = 0
for layer_idx in range(len(self.transformer_layers)):
if isinstance(self.transformer_layers[layer_idx], TransformerEncoderLayer):
x = self.transformer_layers[layer_idx](
x, encoder_padding_mask, attn_mask
)
if self.transformer_sampling[transformer_layer_idx] != 1:
sampling_factor = self.transformer_sampling[transformer_layer_idx]
x, encoder_padding_mask, attn_mask = self.slice(
x, encoder_padding_mask, attn_mask, sampling_factor
)
transformer_layer_idx += 1
else:
x = self.transformer_layers[layer_idx](x)
        # encoder_padding_mask is a (T x B) tensor, its [t, b] element indicates
        # whether encoder_output[t, b] is valid or not (valid=0, invalid=1)
return {
"encoder_out": x, # (T, B, C)
"encoder_padding_mask": encoder_padding_mask.t()
if encoder_padding_mask is not None
else None,
# (B, T) --> (T, B)
}
def infer_conv_output_dim(self, in_channels, input_dim):
sample_seq_len = 200
sample_bsz = 10
x = torch.randn(sample_bsz, in_channels, sample_seq_len, input_dim)
for i, _ in enumerate(self.conv_layers):
x = self.conv_layers[i](x)
x = x.transpose(1, 2)
mb, seq = x.size()[:2]
return x.contiguous().view(mb, seq, -1).size(-1)
def validate_transformer_config(self, transformer_config):
for config in transformer_config:
input_dim, num_heads = config[:2]
if input_dim % num_heads != 0:
msg = (
"ERROR in transformer config {}: ".format(config)
+ "input dimension {} ".format(input_dim)
+ "not dividable by number of heads {}".format(num_heads)
)
raise ValueError(msg)
def parse_transformer_context(self, transformer_context):
"""
transformer_context can be the following:
- None; indicates no context is used, i.e.,
transformer can access full context
- a tuple/list of two int; indicates left and right context,
any number <0 indicates infinite context
* e.g., (5, 6) indicates that for query at x_t, transformer can
access [t-5, t+6] (inclusive)
* e.g., (-1, 6) indicates that for query at x_t, transformer can
access [0, t+6] (inclusive)
"""
if transformer_context is None:
return None
if not isinstance(transformer_context, Iterable):
raise ValueError("transformer context must be Iterable if it is not None")
if len(transformer_context) != 2:
raise ValueError("transformer context must have length 2")
left_context = transformer_context[0]
if left_context < 0:
left_context = None
right_context = transformer_context[1]
if right_context < 0:
right_context = None
if left_context is None and right_context is None:
return None
return (left_context, right_context)
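    # Illustrative sketch (derived from the code above, not part of the original):
    #   parse_transformer_context(None)     -> None        (full context)
    #   parse_transformer_context((5, 6))   -> (5, 6)      (bounded both sides)
    #   parse_transformer_context((-1, 6))  -> (None, 6)   (unbounded left)
    #   parse_transformer_context((-1, -1)) -> None        (both unbounded = full)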
def parse_transformer_sampling(self, transformer_sampling, num_layers):
"""
parsing transformer sampling configuration
Args:
- transformer_sampling, accepted input:
* None, indicating no sampling
* an Iterable with int (>0) as element
- num_layers, expected number of transformer layers, must match with
the length of transformer_sampling if it is not None
Returns:
- A tuple with length num_layers
"""
if transformer_sampling is None:
return (1,) * num_layers
if not isinstance(transformer_sampling, Iterable):
raise ValueError(
"transformer_sampling must be an iterable if it is not None"
)
if len(transformer_sampling) != num_layers:
raise ValueError(
"transformer_sampling {} does not match with the number "
"of layers {}".format(transformer_sampling, num_layers)
)
        for layer, value in enumerate(transformer_sampling):
            if not isinstance(value, int):
                raise ValueError(
                    "Invalid value in transformer_sampling: {}".format(value)
                )
            if value < 1:
                raise ValueError(
                    "Layer {} has subsampling factor {}; ".format(layer, value)
                    + "only values >= 1 are allowed"
                )
return transformer_sampling
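    # Illustrative sketch: parse_transformer_sampling(None, 3) yields (1, 1, 1)
    # (no subsampling), while parse_transformer_sampling((1, 2, 2), 3) returns
    # the input unchanged, meaning the second and third transformer layers each
    # halve the sequence length via slice() below.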
def slice(self, embedding, padding_mask, attn_mask, sampling_factor):
"""
embedding is a (T, B, D) tensor
padding_mask is a (B, T) tensor or None
attn_mask is a (T, T) tensor or None
"""
embedding = embedding[::sampling_factor, :, :]
if padding_mask is not None:
padding_mask = padding_mask[:, ::sampling_factor]
if attn_mask is not None:
attn_mask = attn_mask[::sampling_factor, ::sampling_factor]
return embedding, padding_mask, attn_mask
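    # e.g. (sketch): with sampling_factor == 2, a (6, B, D) embedding becomes
    # (3, B, D) by keeping frames 0, 2 and 4; the padding and attention masks
    # are sliced with the same stride so they stay aligned with the embedding.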
def lengths_to_attn_mask(self, input_lengths, subsampling_factor=1):
"""
create attention mask according to sequence lengths and transformer
context
Args:
- input_lengths: (B, )-shape Int/Long tensor; input_lengths[b] is
the length of b-th sequence
- subsampling_factor: int
            * Note that left_context and right_context are specified at the
              input frame level, while the transformer input may already have
              been subsampled (e.g., by striding in the vggblocks);
              subsampling_factor is used to rescale the left/right context
              accordingly
Return:
- a (T, T) binary tensor or None, where T is max(input_lengths)
* if self.transformer_context is None, None
* if left_context is None,
* attn_mask[t, t + right_context + 1:] = 1
* others = 0
* if right_context is None,
* attn_mask[t, 0:t - left_context] = 1
* others = 0
            * otherwise
* attn_mask[t, t - left_context: t + right_context + 1] = 0
* others = 1
"""
if self.transformer_context is None:
return None
maxT = torch.max(input_lengths).item()
attn_mask = torch.zeros(maxT, maxT)
left_context = self.transformer_context[0]
right_context = self.transformer_context[1]
if left_context is not None:
left_context = math.ceil(self.transformer_context[0] / subsampling_factor)
if right_context is not None:
right_context = math.ceil(self.transformer_context[1] / subsampling_factor)
for t in range(maxT):
if left_context is not None:
st = 0
en = max(st, t - left_context)
attn_mask[t, st:en] = 1
if right_context is not None:
st = t + right_context + 1
                st = min(st, maxT)  # an out-of-range start simply yields an empty slice
attn_mask[t, st:] = 1
return attn_mask.to(input_lengths.device)
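    # Worked example (sketch): with transformer_context == (2, 1),
    # subsampling_factor == 1 and max(input_lengths) == 4, row 0 of the mask is
    # [0, 0, 1, 1] (frame 0 may see frames 0..1) and row 1 is [0, 0, 0, 1]
    # (frame 1 may see frames 0..2); 1 marks positions a query cannot attend to.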
def reorder_encoder_out(self, encoder_out, new_order):
encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select(
1, new_order
)
if encoder_out["encoder_padding_mask"] is not None:
encoder_out["encoder_padding_mask"] = encoder_out[
"encoder_padding_mask"
].index_select(1, new_order)
return encoder_out
class TransformerDecoder(FairseqIncrementalDecoder):
"""
Transformer decoder consisting of *args.decoder_layers* layers. Each layer
is a :class:`TransformerDecoderLayer`.
    Args:
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_dim (int, optional): dimension of the target token embeddings.
            Default: ``512``
        transformer_config: per-layer transformer configuration, in the same
            format as the encoder's
        conv_config: configuration of the convolutional front-end, a sequence
            of (out_channels, kernel_size, layer_norm) tuples
        encoder_output_dim (int, optional): dimension of the encoder output.
            Default: ``512``
    """
def __init__(
self,
dictionary,
embed_dim=512,
transformer_config=DEFAULT_ENC_TRANSFORMER_CONFIG,
conv_config=DEFAULT_DEC_CONV_CONFIG,
encoder_output_dim=512,
):
super().__init__(dictionary)
vocab_size = len(dictionary)
self.padding_idx = dictionary.pad()
self.embed_tokens = Embedding(vocab_size, embed_dim, self.padding_idx)
self.conv_layers = nn.ModuleList()
for i in range(len(conv_config)):
out_channels, kernel_size, layer_norm = conv_config[i]
if i == 0:
conv_layer = LinearizedConv1d(
embed_dim, out_channels, kernel_size, padding=kernel_size - 1
)
else:
conv_layer = LinearizedConv1d(
conv_config[i - 1][0],
out_channels,
kernel_size,
padding=kernel_size - 1,
)
self.conv_layers.append(conv_layer)
if layer_norm:
self.conv_layers.append(nn.LayerNorm(out_channels))
self.conv_layers.append(nn.ReLU())
self.layers = nn.ModuleList()
if conv_config[-1][0] != transformer_config[0][0]:
self.layers.append(Linear(conv_config[-1][0], transformer_config[0][0]))
self.layers.append(
TransformerDecoderLayer(
prepare_transformer_decoder_params(*transformer_config[0])
)
)
for i in range(1, len(transformer_config)):
if transformer_config[i - 1][0] != transformer_config[i][0]:
self.layers.append(
Linear(transformer_config[i - 1][0], transformer_config[i][0])
)
self.layers.append(
TransformerDecoderLayer(
prepare_transformer_decoder_params(*transformer_config[i])
)
)
self.fc_out = Linear(transformer_config[-1][0], vocab_size)
def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
"""
Args:
prev_output_tokens (LongTensor): previous decoder outputs of shape
`(batch, tgt_len)`, for input feeding/teacher forcing
encoder_out (Tensor, optional): output from the encoder, used for
encoder-side attention
incremental_state (dict): dictionary used for storing state during
:ref:`Incremental decoding`
Returns:
tuple:
- the last decoder layer's output of shape `(batch, tgt_len,
vocab)`
- the last decoder layer's attention weights of shape `(batch,
tgt_len, src_len)`
"""
target_padding_mask = (
(prev_output_tokens == self.padding_idx).to(prev_output_tokens.device)
if incremental_state is None
else None
)
if incremental_state is not None:
prev_output_tokens = prev_output_tokens[:, -1:]
# embed tokens
x = self.embed_tokens(prev_output_tokens)
# B x T x C -> T x B x C
x = self._transpose_if_training(x, incremental_state)
for layer in self.conv_layers:
if isinstance(layer, LinearizedConvolution):
x = layer(x, incremental_state)
else:
x = layer(x)
# B x T x C -> T x B x C
x = self._transpose_if_inference(x, incremental_state)
# decoder layers
for layer in self.layers:
if isinstance(layer, TransformerDecoderLayer):
x, *_ = layer(
x,
(encoder_out["encoder_out"] if encoder_out is not None else None),
                    (
                        encoder_out["encoder_padding_mask"].t()
                        if encoder_out is not None
                        and encoder_out["encoder_padding_mask"] is not None
                        else None
                    ),
incremental_state,
self_attn_mask=(
self.buffered_future_mask(x)
if incremental_state is None
else None
),
self_attn_padding_mask=(
target_padding_mask if incremental_state is None else None
),
)
else:
x = layer(x)
# T x B x C -> B x T x C
x = x.transpose(0, 1)
x = self.fc_out(x)
return x, None
def buffered_future_mask(self, tensor):
dim = tensor.size(0)
if (
not hasattr(self, "_future_mask")
or self._future_mask is None
or self._future_mask.device != tensor.device
):
self._future_mask = torch.triu(
utils.fill_with_neg_inf(tensor.new(dim, dim)), 1
)
if self._future_mask.size(0) < dim:
self._future_mask = torch.triu(
utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1
)
return self._future_mask[:dim, :dim]
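    # Sketch: for dim == 3 the cached mask is
    #   [[0., -inf, -inf],
    #    [0.,   0., -inf],
    #    [0.,   0.,   0.]]
    # so during training each target position may only attend to itself and
    # earlier positions.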
def _transpose_if_training(self, x, incremental_state):
if incremental_state is None:
x = x.transpose(0, 1)
return x
def _transpose_if_inference(self, x, incremental_state):
if incremental_state:
x = x.transpose(0, 1)
return x
@register_model("asr_vggtransformer_encoder")
class VGGTransformerEncoderModel(FairseqEncoderModel):
def __init__(self, encoder):
super().__init__(encoder)
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
parser.add_argument(
"--input-feat-per-channel",
type=int,
metavar="N",
help="encoder input dimension per input channel",
)
parser.add_argument(
"--vggblock-enc-config",
type=str,
metavar="EXPR",
help="""
an array of tuples each containing the configuration of one vggblock
            [(out_channels, conv_kernel_size, pooling_kernel_size, num_conv_layers, layer_norm), ...]
""",
)
parser.add_argument(
"--transformer-enc-config",
type=str,
metavar="EXPR",
help="""
            a tuple containing the configuration of each Transformer layer:
[(input_dim,
num_heads,
ffn_dim,
normalize_before,
dropout,
attention_dropout,
relu_dropout), ]""",
)
parser.add_argument(
"--enc-output-dim",
type=int,
metavar="N",
help="encoder output dimension, projecting the LSTM output",
)
parser.add_argument(
"--in-channels",
type=int,
metavar="N",
help="number of encoder input channels",
)
parser.add_argument(
"--transformer-context",
type=str,
metavar="EXPR",
help="""
either None or a tuple of two ints, indicating left/right context a
transformer can have access to""",
)
parser.add_argument(
"--transformer-sampling",
type=str,
metavar="EXPR",
help="""
either None or a tuple of ints, indicating sampling factor in each layer""",
)
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
base_architecture_enconly(args)
encoder = VGGTransformerEncoderOnly(
vocab_size=len(task.target_dictionary),
input_feat_per_channel=args.input_feat_per_channel,
vggblock_config=eval(args.vggblock_enc_config),
transformer_config=eval(args.transformer_enc_config),
encoder_output_dim=args.enc_output_dim,
in_channels=args.in_channels,
transformer_context=eval(args.transformer_context),
transformer_sampling=eval(args.transformer_sampling),
)
return cls(encoder)
def get_normalized_probs(self, net_output, log_probs, sample=None):
# net_output['encoder_out'] is a (T, B, D) tensor
lprobs = super().get_normalized_probs(net_output, log_probs, sample)
# lprobs is a (T, B, D) tensor
        # we need to transpose to get a (B, T, D) tensor
lprobs = lprobs.transpose(0, 1).contiguous()
lprobs.batch_first = True
return lprobs
class VGGTransformerEncoderOnly(VGGTransformerEncoder):
def __init__(
self,
vocab_size,
input_feat_per_channel,
vggblock_config=DEFAULT_ENC_VGGBLOCK_CONFIG,
transformer_config=DEFAULT_ENC_TRANSFORMER_CONFIG,
encoder_output_dim=512,
in_channels=1,
transformer_context=None,
transformer_sampling=None,
):
super().__init__(
input_feat_per_channel=input_feat_per_channel,
vggblock_config=vggblock_config,
transformer_config=transformer_config,
encoder_output_dim=encoder_output_dim,
in_channels=in_channels,
transformer_context=transformer_context,
transformer_sampling=transformer_sampling,
)
self.fc_out = Linear(self.encoder_output_dim, vocab_size)
def forward(self, src_tokens, src_lengths, **kwargs):
"""
src_tokens: padded tensor (B, T, C * feat)
src_lengths: tensor of original lengths of input utterances (B,)
"""
enc_out = super().forward(src_tokens, src_lengths)
x = self.fc_out(enc_out["encoder_out"])
# x = F.log_softmax(x, dim=-1)
        # Note: this line is not needed because model.get_normalized_probs
        # will apply log_softmax
return {
"encoder_out": x, # (T, B, C)
"encoder_padding_mask": enc_out["encoder_padding_mask"], # (T, B)
}
def max_positions(self):
"""Maximum input length supported by the encoder."""
        return (1e6, 1e6)  # an arbitrarily large number
def Embedding(num_embeddings, embedding_dim, padding_idx):
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
# nn.init.uniform_(m.weight, -0.1, 0.1)
# nn.init.constant_(m.weight[padding_idx], 0)
return m
def Linear(in_features, out_features, bias=True, dropout=0):
"""Linear layer (input: N x T x C)"""
m = nn.Linear(in_features, out_features, bias=bias)
# m.weight.data.uniform_(-0.1, 0.1)
# if bias:
# m.bias.data.uniform_(-0.1, 0.1)
return m
def LinearizedConv1d(in_channels, out_channels, kernel_size, dropout=0, **kwargs):
"""Weight-normalized Conv1d layer optimized for decoding"""
m = LinearizedConvolution(in_channels, out_channels, kernel_size, **kwargs)
std = math.sqrt((4 * (1.0 - dropout)) / (m.kernel_size[0] * in_channels))
nn.init.normal_(m.weight, mean=0, std=std)
nn.init.constant_(m.bias, 0)
return nn.utils.weight_norm(m, dim=2)
def LayerNorm(embedding_dim):
m = nn.LayerNorm(embedding_dim)
return m
# seq2seq models
def base_architecture(args):
args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 40)
args.vggblock_enc_config = getattr(
args, "vggblock_enc_config", DEFAULT_ENC_VGGBLOCK_CONFIG
)
args.transformer_enc_config = getattr(
args, "transformer_enc_config", DEFAULT_ENC_TRANSFORMER_CONFIG
)
args.enc_output_dim = getattr(args, "enc_output_dim", 512)
args.in_channels = getattr(args, "in_channels", 1)
args.tgt_embed_dim = getattr(args, "tgt_embed_dim", 128)
args.transformer_dec_config = getattr(
args, "transformer_dec_config", DEFAULT_ENC_TRANSFORMER_CONFIG
)
args.conv_dec_config = getattr(args, "conv_dec_config", DEFAULT_DEC_CONV_CONFIG)
args.transformer_context = getattr(args, "transformer_context", "None")
@register_model_architecture("asr_vggtransformer", "vggtransformer_1")
def vggtransformer_1(args):
args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80)
args.vggblock_enc_config = getattr(
args, "vggblock_enc_config", "[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]"
)
args.transformer_enc_config = getattr(
args,
"transformer_enc_config",
"((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 14",
)
args.enc_output_dim = getattr(args, "enc_output_dim", 1024)
args.tgt_embed_dim = getattr(args, "tgt_embed_dim", 128)
args.conv_dec_config = getattr(args, "conv_dec_config", "((256, 3, True),) * 4")
args.transformer_dec_config = getattr(
args,
"transformer_dec_config",
"((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 4",
)
@register_model_architecture("asr_vggtransformer", "vggtransformer_2")
def vggtransformer_2(args):
args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80)
args.vggblock_enc_config = getattr(
args, "vggblock_enc_config", "[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]"
)
args.transformer_enc_config = getattr(
args,
"transformer_enc_config",
"((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 16",
)
args.enc_output_dim = getattr(args, "enc_output_dim", 1024)
args.tgt_embed_dim = getattr(args, "tgt_embed_dim", 512)
args.conv_dec_config = getattr(args, "conv_dec_config", "((256, 3, True),) * 4")
args.transformer_dec_config = getattr(
args,
"transformer_dec_config",
"((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 6",
)
@register_model_architecture("asr_vggtransformer", "vggtransformer_base")
def vggtransformer_base(args):
args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80)
args.vggblock_enc_config = getattr(
args, "vggblock_enc_config", "[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]"
)
args.transformer_enc_config = getattr(
args, "transformer_enc_config", "((512, 8, 2048, True, 0.15, 0.15, 0.15),) * 12"
)
args.enc_output_dim = getattr(args, "enc_output_dim", 512)
args.tgt_embed_dim = getattr(args, "tgt_embed_dim", 512)
args.conv_dec_config = getattr(args, "conv_dec_config", "((256, 3, True),) * 4")
args.transformer_dec_config = getattr(
args, "transformer_dec_config", "((512, 8, 2048, True, 0.15, 0.15, 0.15),) * 6"
)
# Size estimations:
# Encoder:
# - vggblock param: 64*1*3*3 + 64*64*3*3 + 128*64*3*3 + 128*128*3*3 = 258K
# Transformer:
# - input dimension adapter: 2560 x 512 -> 1.31M
# - transformer_layers (x12) --> 37.74M
# * MultiheadAttention: 512*512*3 (in_proj) + 512*512 (out_proj) = 1.048M
# * FFN weight: 512*2048*2 = 2.097M
# - output dimension adapter: 512 x 512 -> 0.26 M
# Decoder:
# - LinearizedConv1d: 512 * 256 * 3 + 256 * 256 * 3 * 3
# - transformer_layer: (x6) --> 25.16M
# * MultiheadAttention (self-attention): 512*512*3 + 512*512 = 1.048M
# * MultiheadAttention (encoder-attention): 512*512*3 + 512*512 = 1.048M
# * FFN: 512*2048*2 = 2.097M
# Final FC:
# - FC: 512*5000 = 2.56M (assuming vocab size 5K)
# In total:
# ~68 M
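# (Sanity-check sketch: assuming `model` is a constructed vggtransformer_base
# instance, sum(p.numel() for p in model.parameters()) should land in the same
# ballpark; embedding tables are not counted in the estimate above.)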
# CTC models
def base_architecture_enconly(args):
args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 40)
args.vggblock_enc_config = getattr(
args, "vggblock_enc_config", "[(32, 3, 2, 2, True)] * 2"
)
args.transformer_enc_config = getattr(
args, "transformer_enc_config", "((256, 4, 1024, True, 0.2, 0.2, 0.2),) * 2"
)
args.enc_output_dim = getattr(args, "enc_output_dim", 512)
args.in_channels = getattr(args, "in_channels", 1)
args.transformer_context = getattr(args, "transformer_context", "None")
args.transformer_sampling = getattr(args, "transformer_sampling", "None")
@register_model_architecture("asr_vggtransformer_encoder", "vggtransformer_enc_1")
def vggtransformer_enc_1(args):
    # vggtransformer_1 is the same as vggtransformer_enc_big, except that the
    # number of layers is increased to 16
    # kept here for backward compatibility
args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80)
args.vggblock_enc_config = getattr(
args, "vggblock_enc_config", "[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]"
)
args.transformer_enc_config = getattr(
args,
"transformer_enc_config",
"((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 16",
)
args.enc_output_dim = getattr(args, "enc_output_dim", 1024)
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq.models import (
FairseqEncoder,
FairseqEncoderModel,
register_model,
register_model_architecture,
)
from fairseq.modules.fairseq_dropout import FairseqDropout
default_conv_enc_config = """[
(400, 13, 170, 0.2),
(440, 14, 0, 0.214),
(484, 15, 0, 0.22898),
(532, 16, 0, 0.2450086),
(584, 17, 0, 0.262159202),
(642, 18, 0, 0.28051034614),
(706, 19, 0, 0.30014607037),
(776, 20, 0, 0.321156295296),
(852, 21, 0, 0.343637235966),
(936, 22, 0, 0.367691842484),
(1028, 23, 0, 0.393430271458),
(1130, 24, 0, 0.42097039046),
(1242, 25, 0, 0.450438317792),
(1366, 26, 0, 0.481969000038),
(1502, 27, 0, 0.51570683004),
(1652, 28, 0, 0.551806308143),
(1816, 29, 0, 0.590432749713),
]"""
@register_model("asr_w2l_conv_glu_encoder")
class W2lConvGluEncoderModel(FairseqEncoderModel):
def __init__(self, encoder):
super().__init__(encoder)
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
parser.add_argument(
"--input-feat-per-channel",
type=int,
metavar="N",
help="encoder input dimension per input channel",
)
parser.add_argument(
"--in-channels",
type=int,
metavar="N",
help="number of encoder input channels",
)
parser.add_argument(
"--conv-enc-config",
type=str,
metavar="EXPR",
help="""
an array of tuples each containing the configuration of one conv layer
[(out_channels, kernel_size, padding, dropout), ...]
""",
)
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
conv_enc_config = getattr(args, "conv_enc_config", default_conv_enc_config)
encoder = W2lConvGluEncoder(
vocab_size=len(task.target_dictionary),
input_feat_per_channel=args.input_feat_per_channel,
in_channels=args.in_channels,
conv_enc_config=eval(conv_enc_config),
)
return cls(encoder)
def get_normalized_probs(self, net_output, log_probs, sample=None):
lprobs = super().get_normalized_probs(net_output, log_probs, sample)
lprobs.batch_first = False
return lprobs
class W2lConvGluEncoder(FairseqEncoder):
def __init__(
self, vocab_size, input_feat_per_channel, in_channels, conv_enc_config
):
super().__init__(None)
self.input_dim = input_feat_per_channel
if in_channels != 1:
raise ValueError("only 1 input channel is currently supported")
self.conv_layers = nn.ModuleList()
self.linear_layers = nn.ModuleList()
self.dropouts = []
cur_channels = input_feat_per_channel
for out_channels, kernel_size, padding, dropout in conv_enc_config:
layer = nn.Conv1d(cur_channels, out_channels, kernel_size, padding=padding)
layer.weight.data.mul_(math.sqrt(3)) # match wav2letter init
self.conv_layers.append(nn.utils.weight_norm(layer))
self.dropouts.append(
FairseqDropout(dropout, module_name=self.__class__.__name__)
)
if out_channels % 2 != 0:
raise ValueError("odd # of out_channels is incompatible with GLU")
cur_channels = out_channels // 2 # halved by GLU
for out_channels in [2 * cur_channels, vocab_size]:
layer = nn.Linear(cur_channels, out_channels)
layer.weight.data.mul_(math.sqrt(3))
self.linear_layers.append(nn.utils.weight_norm(layer))
cur_channels = out_channels // 2
def forward(self, src_tokens, src_lengths, **kwargs):
"""
src_tokens: padded tensor (B, T, C * feat)
src_lengths: tensor of original lengths of input utterances (B,)
"""
B, T, _ = src_tokens.size()
x = src_tokens.transpose(1, 2).contiguous() # (B, feat, T) assuming C == 1
for layer_idx in range(len(self.conv_layers)):
x = self.conv_layers[layer_idx](x)
x = F.glu(x, dim=1)
x = self.dropouts[layer_idx](x)
        x = x.transpose(1, 2).contiguous()  # (B, T, C_out // 2), i.e. (B, T, 908) with the default config
x = self.linear_layers[0](x)
x = F.glu(x, dim=2)
x = self.dropouts[-1](x)
x = self.linear_layers[1](x)
assert x.size(0) == B
assert x.size(1) == T
encoder_out = x.transpose(0, 1) # (T, B, vocab_size)
        # TODO: find a simpler/more elegant way to do this with PyTorch APIs
encoder_padding_mask = (
torch.arange(T).view(1, T).expand(B, -1).to(x.device)
>= src_lengths.view(B, 1).expand(-1, T)
).t() # (B x T) -> (T x B)
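        # e.g. (sketch): src_lengths == [3, 2] with T == 4 yields mask columns
        # [F, F, F, T] and [F, F, T, T] (True marks padded frames).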
return {
"encoder_out": encoder_out, # (T, B, vocab_size)
"encoder_padding_mask": encoder_padding_mask, # (T, B)
}
def reorder_encoder_out(self, encoder_out, new_order):
encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select(
1, new_order
)
encoder_out["encoder_padding_mask"] = encoder_out[
"encoder_padding_mask"
].index_select(1, new_order)
return encoder_out
def max_positions(self):
"""Maximum input length supported by the encoder."""
        return (1e6, 1e6)  # an arbitrarily large number
@register_model_architecture("asr_w2l_conv_glu_encoder", "w2l_conv_glu_enc")
def w2l_conv_glu_enc(args):
args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80)
args.in_channels = getattr(args, "in_channels", 1)
args.conv_enc_config = getattr(args, "conv_enc_config", default_conv_enc_config)
# Flashlight Decoder
This script runs decoding for pre-trained speech recognition models.
## Usage
Assuming a few variables:
```bash
checkpoint=<path-to-checkpoint>
data=<path-to-data-directory>
lm_model=<path-to-language-model>
lexicon=<path-to-lexicon>
```
Example usage for decoding a fine-tuned Wav2Vec model:
```bash
python $FAIRSEQ_ROOT/examples/speech_recognition/new/infer.py --multirun \
task=audio_pretraining \
task.data=$data \
task.labels=ltr \
common_eval.path=$checkpoint \
decoding.type=kenlm \
decoding.lexicon=$lexicon \
decoding.lmpath=$lm_model \
dataset.gen_subset=dev_clean,dev_other,test_clean,test_other
```
Example of using Ax to sweep decoding parameters for WER (requires `pip install hydra-ax-sweeper`):
```bash
python $FAIRSEQ_ROOT/examples/speech_recognition/new/infer.py --multirun \
hydra/sweeper=ax \
task=audio_pretraining \
task.data=$data \
task.labels=ltr \
common_eval.path=$checkpoint \
decoding.type=kenlm \
decoding.lexicon=$lexicon \
decoding.lmpath=$lm_model \
dataset.gen_subset=dev_other
```
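For a quick pass without an external language model, the same entry point can run viterbi decoding (a minimal sketch, reusing the variables above):
```bash
python $FAIRSEQ_ROOT/examples/speech_recognition/new/infer.py \
    task=audio_pretraining \
    task.data=$data \
    task.labels=ltr \
    common_eval.path=$checkpoint \
    decoding.type=viterbi \
    dataset.gen_subset=dev_other
```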
# @package hydra.sweeper
_target_: hydra_plugins.hydra_ax_sweeper.ax_sweeper.AxSweeper
max_batch_size: null
ax_config:
max_trials: 128
early_stop:
minimize: true
max_epochs_without_improvement: 32
epsilon: 1.0e-05
experiment:
name: ${dataset.gen_subset}
objective_name: wer
minimize: true
parameter_constraints: null
outcome_constraints: null
status_quo: null
client:
verbose_logging: false
random_seed: null
params:
decoding.lmweight:
type: range
bounds: [0.0, 5.0]
decoding.wordscore:
type: range
bounds: [-5.0, 5.0]
# @package _group_
defaults:
- task: null
- model: null
hydra:
run:
dir: ${common_eval.results_path}/${dataset.gen_subset}
sweep:
dir: ${common_eval.results_path}
subdir: ${dataset.gen_subset}
common_eval:
results_path: null
path: null
post_process: letter
quiet: true
dataset:
max_tokens: 1000000
gen_subset: test
distributed_training:
distributed_world_size: 1
decoding:
beam: 5
type: viterbi
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import itertools as it
from typing import Any, Dict, List
import torch
from fairseq.data.dictionary import Dictionary
from fairseq.models.fairseq_model import FairseqModel
class BaseDecoder:
def __init__(self, tgt_dict: Dictionary) -> None:
self.tgt_dict = tgt_dict
self.vocab_size = len(tgt_dict)
self.blank = (
tgt_dict.index("<ctc_blank>")
if "<ctc_blank>" in tgt_dict.indices
else tgt_dict.bos()
)
if "<sep>" in tgt_dict.indices:
self.silence = tgt_dict.index("<sep>")
elif "|" in tgt_dict.indices:
self.silence = tgt_dict.index("|")
else:
self.silence = tgt_dict.eos()
def generate(
self, models: List[FairseqModel], sample: Dict[str, Any], **unused
) -> List[List[Dict[str, torch.LongTensor]]]:
encoder_input = {
k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
}
emissions = self.get_emissions(models, encoder_input)
return self.decode(emissions)
def get_emissions(
self,
models: List[FairseqModel],
encoder_input: Dict[str, Any],
) -> torch.FloatTensor:
model = models[0]
encoder_out = model(**encoder_input)
if hasattr(model, "get_logits"):
emissions = model.get_logits(encoder_out)
else:
emissions = model.get_normalized_probs(encoder_out, log_probs=True)
return emissions.transpose(0, 1).float().cpu().contiguous()
def get_tokens(self, idxs: torch.IntTensor) -> torch.LongTensor:
idxs = (g[0] for g in it.groupby(idxs))
idxs = filter(lambda x: x != self.blank, idxs)
return torch.LongTensor(list(idxs))
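    # e.g. (sketch): with self.blank == 0, idxs == [5, 5, 0, 5, 7, 7] collapses
    # to [5, 5, 7]: groupby drops consecutive repeats, then blanks are filtered.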
def decode(
self,
emissions: torch.FloatTensor,
) -> List[List[Dict[str, torch.LongTensor]]]:
raise NotImplementedError
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Union
from fairseq.data.dictionary import Dictionary
from .decoder_config import DecoderConfig, FlashlightDecoderConfig
from .base_decoder import BaseDecoder
def Decoder(
cfg: Union[DecoderConfig, FlashlightDecoderConfig], tgt_dict: Dictionary
) -> BaseDecoder:
if cfg.type == "viterbi":
from .viterbi_decoder import ViterbiDecoder
return ViterbiDecoder(tgt_dict)
if cfg.type == "kenlm":
from .flashlight_decoder import KenLMDecoder
return KenLMDecoder(cfg, tgt_dict)
if cfg.type == "fairseqlm":
from .flashlight_decoder import FairseqLMDecoder
return FairseqLMDecoder(cfg, tgt_dict)
raise NotImplementedError(f"Invalid decoder name: {cfg.name}")
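# Usage sketch (illustrative, not part of the original file): given a fairseq
# Dictionary `tgt_dict`,
#   decoder = Decoder(DecoderConfig(type="viterbi"), tgt_dict)
# returns a ViterbiDecoder, while a config whose .type is "kenlm" (with lmpath
# and lexicon set) selects the KenLM-backed flashlight decoder instead.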
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from dataclasses import dataclass, field
from typing import Optional
from fairseq.dataclass.configs import FairseqDataclass
from fairseq.dataclass.constants import ChoiceEnum
from omegaconf import MISSING
DECODER_CHOICES = ChoiceEnum(["viterbi", "kenlm", "fairseqlm"])
@dataclass
class DecoderConfig(FairseqDataclass):
type: DECODER_CHOICES = field(
default="viterbi",
metadata={"help": "The type of decoder to use"},
)
@dataclass
class FlashlightDecoderConfig(FairseqDataclass):
nbest: int = field(
default=1,
metadata={"help": "Number of decodings to return"},
)
unitlm: bool = field(
default=False,
metadata={"help": "If set, use unit language model"},
)
lmpath: str = field(
default=MISSING,
metadata={"help": "Language model for KenLM decoder"},
)
lexicon: Optional[str] = field(
default=None,
metadata={"help": "Lexicon for Flashlight decoder"},
)
beam: int = field(
default=50,
metadata={"help": "Number of beams to use for decoding"},
)
beamthreshold: float = field(
default=50.0,
metadata={"help": "Threshold for beam search decoding"},
)
beamsizetoken: Optional[int] = field(
default=None, metadata={"help": "Beam size to use"}
)
wordscore: float = field(
default=-1,
metadata={"help": "Word score for KenLM decoder"},
)
unkweight: float = field(
default=-math.inf,
metadata={"help": "Unknown weight for KenLM decoder"},
)
silweight: float = field(
default=0,
metadata={"help": "Silence weight for KenLM decoder"},
)
lmweight: float = field(
default=2,
metadata={"help": "Weight for LM while interpolating score"},
)