Commit 33646ac9 authored by Jeff Cai, committed by Facebook GitHub Bot

wav2letter integration

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/846

Reviewed By: jcai1

Differential Revision: D17845996

Pulled By: okhonko

fbshipit-source-id: 3826fd9a4418496916bf1835c319dd85c89945cc
parent b6e001f6
@@ -7,6 +7,7 @@ On top of the main fairseq dependencies there are a couple of additional requirements:
1) Please follow the instructions to install [torchaudio](https://github.com/pytorch/audio). This is required to compute audio fbank features (a minimal sketch follows this list).
2) [Sclite](http://www1.icsi.berkeley.edu/Speech/docs/sctk-1.2/sclite.htm#sclite_name_0) is used to measure WER. Sclite can be built from source as part of the sctk package, available [here](http://www.openslr.org/4/). Training and inference do not require the Sclite dependency.
3) [sentencepiece](https://github.com/google/sentencepiece) is required to create datasets with word-piece targets.
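For reference, here is a minimal sketch (not part of this example's scripts) of the fbank extraction that torchaudio enables; the input path is hypothetical, and the 80 mel bins match this example's default `--input-feat-per-channel`:
```
# Sketch only: compute log-mel filterbank features with torchaudio's
# Kaldi-compatible frontend. "audio.flac" is a placeholder input file.
import torchaudio
import torchaudio.compliance.kaldi as kaldi

waveform, sample_rate = torchaudio.load("audio.flac")
fbank = kaldi.fbank(waveform, num_mel_bins=80, sample_frequency=sample_rate)
print(fbank.shape)  # (num_frames, 80)
```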
## Preparing librispeech data
```
@@ -30,3 +31,76 @@ python examples/speech_recognition/infer.py $DIR_FOR_PREPROCESSED_DATA --task sp
sclite -r ${RES_DIR}/ref.word-checkpoint_last.pt-${SET}.txt -h ${RES_DIR}/hypo.word-checkpoint_last.pt-${SET}.txt -i rm -o all stdout > $RES_REPORT
```
The `Sum/Avg` row of the first table in the report contains the WER.
## Using wav2letter components
[wav2letter](https://github.com/facebookresearch/wav2letter) now has integration with fairseq. Currently this includes:
* AutoSegmentationCriterion (ASG)
* wav2letter-style Conv/GLU model
* wav2letter's beam search decoder
To use these, follow the instructions at the bottom of [this page](https://github.com/facebookresearch/wav2letter/blob/master/docs/installation.md) to install the python bindings. Note that the python bindings cover a *subset* of wav2letter and do not require its full dependencies (notably, `flashlight` and `ArrayFire` are *not* required).
To quickly summarize the instructions: first, install [CUDA](https://developer.nvidia.com/cuda-downloads). Then follow these steps:
```
# additional prerequisites - use equivalents for your distro
sudo apt-get install build-essential cmake libatlas-base-dev libfftw3-dev liblzma-dev libbz2-dev libzstd-dev
# install KenLM from source
git clone https://github.com/kpu/kenlm.git
cd kenlm
mkdir -p build && cd build
cmake .. -DCMAKE_BUILD_TYPE=Release -DCMAKE_POSITION_INDEPENDENT_CODE=ON
make -j16
cd ..
export KENLM_ROOT_DIR=$(pwd)
cd ..
# install wav2letter python bindings
git clone https://github.com/facebookresearch/wav2letter.git
cd wav2letter/bindings/python
# make sure your python environment is active at this point
pip install torch packaging
pip install -e .
# try some examples to verify installation succeeded
python ./examples/criterion_example.py
python ./examples/decoder_example.py ../../src/decoder/test
python ./examples/feature_example.py ../../src/feature/test/data
```
## Training librispeech data (wav2letter style, Conv/GLU + ASG loss)
Training command:
```
python train.py $DIR_FOR_PREPROCESSED_DATA --save-dir $MODEL_PATH --max-epoch 100 --task speech_recognition --arch w2l_conv_glu_enc --batch-size 4 --optimizer sgd --lr 0.3,0.8 --momentum 0.8 --clip-norm 0.2 --max-tokens 50000 --log-format json --log-interval 100 --num-workers 0 --sentence-avg --criterion asg_loss --asg-transitions-init 5 --max-replabel 2 --linseg-updates 8789 --user-dir examples/speech_recognition
```
Note that ASG loss currently doesn't work well with word-pieces. You should prepare a dataset with character targets by setting `nbpe=31` in `prepare-librispeech.sh`.
## Inference for librispeech (wav2letter decoder, n-gram LM)
Inference command:
```
python examples/speech_recognition/infer.py $DIR_FOR_PREPROCESSED_DATA --task speech_recognition --seed 1 --nbest 1 --path $MODEL_PATH/checkpoint_last.pt --gen-subset $SET --results-path $RES_DIR --w2l-decoder kenlm --kenlm-model $KENLM_MODEL_PATH --lexicon $LEXICON_PATH --beam 200 --beam-threshold 15 --lm-weight 1.5 --word-score 1.5 --sil-weight -0.3 --criterion asg_loss --max-replabel 2 --user-dir examples/speech_recognition
```
`$KENLM_MODEL_PATH` should be a standard n-gram language model file. `$LEXICON_PATH` should be a wav2letter-style lexicon (list of known words and their spellings). For ASG inference, a lexicon line should look like this (note the repetition labels):
```
doorbell D O 1 R B E L 1 ▁
```
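To make the replabel convention concrete, here is an illustrative sketch (assuming `--max-replabel 2`, as in the commands above) of how a plain spelling collapses into the replabel form; it mirrors the logic of `pack_replabels` in `examples/speech_recognition/data/replabels.py`:
```
def spell_with_replabels(word, max_reps=2):
    # Collapse runs of repeated letters into replabels "1", "2", ...
    tokens = []
    prev, reps = None, 0
    for ch in word.upper():
        if ch == prev and reps < max_reps:
            reps += 1  # extend the current run of repeats
        else:
            if reps > 0:
                tokens.append(str(reps))  # close out the previous run
                reps = 0
            tokens.append(ch)
            prev = ch
    if reps > 0:
        tokens.append(str(reps))
    return tokens + ["\u2581"]  # trailing word-boundary symbol

print(" ".join(spell_with_replabels("doorbell")))  # D O 1 R B E L 1 ▁
```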
For CTC inference with word-pieces, repetition labels are not used and the lexicon should contain the most common spellings of each word (one can use sentencepiece's `NBestEncodeAsPieces` for this; see the sketch after the example below):
```
doorbell ▁DOOR BE LL
doorbell ▁DOOR B E LL
doorbell ▁DO OR BE LL
doorbell ▁DOOR B EL L
doorbell ▁DOOR BE L L
doorbell ▁DO OR B E LL
doorbell ▁DOOR B E L L
doorbell ▁DO OR B EL L
doorbell ▁DO O R BE LL
doorbell ▁DO OR BE L L
```
Lowercase vs. uppercase matters: the *word* should match the case of the n-gram language model (i.e. `$KENLM_MODEL_PATH`), while the *spelling* should match the case of the token dictionary (i.e. `$DIR_FOR_PREPROCESSED_DATA/dict.txt`).
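A sketch of generating such multi-spelling entries with sentencepiece (paths are hypothetical; the model is the one created during data preparation):
```
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.Load("spm.model")  # assumed: model produced by prepare-librispeech.sh

with open("lexicon.txt", "w") as f:  # hypothetical output path
    for word in ["doorbell"]:  # in practice, iterate over your word list
        for pieces in sp.NBestEncodeAsPieces(word, 10):
            f.write("{} {}\n".format(word, " ".join(pieces)))
```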
## Inference for librispeech (wav2letter decoder, viterbi only)
Inference command:
```
python examples/speech_recognition/infer.py $DIR_FOR_PREPROCESSED_DATA --task speech_recognition --seed 1 --nbest 1 --path $MODEL_PATH/checkpoint_last.pt --gen-subset $SET --results-path $RES_DIR --w2l-decoder viterbi --criterion asg_loss --max-replabel 2 --user-dir examples/speech_recognition
```
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import numpy as np
import torch
from fairseq import utils
from fairseq.criterions import FairseqCriterion, register_criterion
from examples.speech_recognition.data.replabels import pack_replabels
from wav2letter.criterion import ASGLoss, CriterionScaleMode
@register_criterion("asg_loss")
class ASGCriterion(FairseqCriterion):
@staticmethod
def add_args(parser):
group = parser.add_argument_group("ASG Loss")
group.add_argument(
"--asg-transitions-init",
help="initial diagonal value of transition matrix",
type=float,
default=0.0,
)
group.add_argument(
"--max-replabel", help="maximum # of replabels", type=int, default=2
)
group.add_argument(
"--linseg-updates",
help="# of training updates to use LinSeg initialization",
type=int,
default=0,
)
group.add_argument(
"--hide-linseg-messages",
help="hide messages about LinSeg initialization",
action="store_true",
)
def __init__(self, args, task):
super().__init__(args, task)
self.tgt_dict = task.target_dictionary
self.eos = self.tgt_dict.eos()
self.silence = (
self.tgt_dict.index(args.silence_token)
if args.silence_token in self.tgt_dict
else None
)
self.max_replabel = args.max_replabel
num_labels = len(self.tgt_dict)
self.asg = ASGLoss(num_labels, scale_mode=CriterionScaleMode.TARGET_SZ_SQRT)
self.asg.trans = torch.nn.Parameter(
args.asg_transitions_init * torch.eye(num_labels), requires_grad=True
)
self.linseg_progress = torch.nn.Parameter(
torch.tensor([0], dtype=torch.int), requires_grad=False
)
self.linseg_maximum = args.linseg_updates
self.linseg_message_state = "none" if args.hide_linseg_messages else "start"
def linseg_step(self):
if not self.training:
return False
if self.linseg_progress.item() < self.linseg_maximum:
if self.linseg_message_state == "start":
print("| using LinSeg to initialize ASG")
self.linseg_message_state = "finish"
self.linseg_progress.add_(1)
return True
elif self.linseg_message_state == "finish":
print("| finished LinSeg initialization")
self.linseg_message_state = "none"
return False
def replace_eos_with_silence(self, tgt):
if tgt[-1] != self.eos:
return tgt
elif self.silence is None or (len(tgt) > 1 and tgt[-2] == self.silence):
return tgt[:-1]
else:
return tgt[:-1] + [self.silence]
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
net_output = model(**sample["net_input"])
emissions = net_output["encoder_out"].transpose(0, 1).contiguous()
B = emissions.size(0)
T = emissions.size(1)
device = emissions.device
target = torch.IntTensor(B, T)
target_size = torch.IntTensor(B)
using_linseg = self.linseg_step()
for b in range(B):
initial_target_size = sample["target_lengths"][b].item()
if initial_target_size == 0:
raise ValueError("target size cannot be zero")
tgt = sample["target"][b, :initial_target_size].tolist()
tgt = self.replace_eos_with_silence(tgt)
tgt = pack_replabels(tgt, self.tgt_dict, self.max_replabel)
tgt = tgt[:T]
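            # LinSeg warm-up: stretch the packed target linearly across all T frames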
if using_linseg:
tgt = [tgt[t * len(tgt) // T] for t in range(T)]
target[b][: len(tgt)] = torch.IntTensor(tgt)
target_size[b] = len(tgt)
loss = self.asg.forward(emissions, target.to(device), target_size.to(device))
if reduce:
loss = torch.sum(loss)
sample_size = (
sample["target"].size(0) if self.args.sentence_avg else sample["ntokens"]
)
logging_output = {
"loss": utils.item(loss.data) if reduce else loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample["target"].size(0),
"sample_size": sample_size,
}
return loss, sample_size, logging_output
@staticmethod
def aggregate_logging_outputs(logging_outputs):
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
agg_output = {
"loss": loss_sum / nsentences,
"ntokens": ntokens,
"nsentences": nsentences,
"sample_size": sample_size,
}
return agg_output
import importlib
import os
# ASG loss requires wav2letter
blacklist = set()
try:
import wav2letter
except ImportError:
blacklist.add("ASG_loss.py")
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith(".py") and not file.startswith("_") and file not in blacklist:
criterion_name = file[: file.find(".py")]
importlib.import_module(
"examples.speech_recognition.criterions." + criterion_name
)
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Replabel transforms for use with wav2letter's ASG criterion.
"""
def replabel_symbol(i):
"""
Replabel symbols used in wav2letter, currently just "1", "2", ...
This prevents training with numeral tokens, so this might change in the future
"""
return str(i)
def pack_replabels(tokens, dictionary, max_reps):
"""
Pack a token sequence so that repeated symbols are replaced by replabels
"""
if len(tokens) == 0 or max_reps <= 0:
return tokens
replabel_value_to_idx = [0] * (max_reps + 1)
for i in range(1, max_reps + 1):
replabel_value_to_idx[i] = dictionary.index(replabel_symbol(i))
result = []
prev_token = -1
num_reps = 0
for token in tokens:
if token == prev_token and num_reps < max_reps:
num_reps += 1
else:
if num_reps > 0:
result.append(replabel_value_to_idx[num_reps])
num_reps = 0
result.append(token)
prev_token = token
if num_reps > 0:
result.append(replabel_value_to_idx[num_reps])
return result
def unpack_replabels(tokens, dictionary, max_reps):
"""
Unpack a token sequence so that replabels are replaced by repeated symbols
"""
if len(tokens) == 0 or max_reps <= 0:
return tokens
replabel_idx_to_value = {}
for i in range(1, max_reps + 1):
replabel_idx_to_value[dictionary.index(replabel_symbol(i))] = i
result = []
prev_token = -1
for token in tokens:
try:
for _ in range(replabel_idx_to_value[token]):
result.append(prev_token)
prev_token = -1
except KeyError:
result.append(token)
prev_token = token
return result
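# Example with hypothetical dictionary indices: with max_reps=2,
# pack_replabels turns [a, a, a, b] into [a, r2, b], where r2 is the
# dictionary index of replabel_symbol(2); unpack_replabels inverts this.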
@@ -9,11 +9,12 @@ Run inference for pre-processed data with a trained model.
"""
import logging
import math
import os
import sentencepiece as spm
import torch
from fairseq import options, progress_bar, utils, tasks
from fairseq import checkpoint_utils, options, progress_bar, utils, tasks
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.utils import import_user_module
@@ -23,8 +24,6 @@ logger.setLevel(logging.INFO)
def add_asr_eval_argument(parser):
parser.add_argument("--ctc", action="store_true", help="decode a ctc model")
parser.add_argument("--rnnt", default=False, help="decode a rnnt model")
parser.add_argument("--kspmodel", default=None, help="sentence piece model")
parser.add_argument(
"--wfstlm", default=None, help="wfstlm on dictonary output units"
@@ -36,14 +35,24 @@ def add_asr_eval_argument(parser):
output units",
)
parser.add_argument(
"--lm-weight",
"--lm_weight",
type=float,
default=0.2,
help="weight for wfstlm while interpolating\
with neural score",
help="weight for lm while interpolating with neural score",
)
parser.add_argument(
"--rnnt_len_penalty", default=-0.5, help="rnnt length penalty on word level"
)
parser.add_argument(
"--w2l-decoder", choices=["viterbi", "kenlm"], help="use a w2l decoder"
)
parser.add_argument("--lexicon", help="lexicon for w2l decoder")
parser.add_argument("--kenlm-model", help="kenlm model for w2l decoder")
parser.add_argument("--beam-threshold", type=float, default=25.0)
parser.add_argument("--word-score", type=float, default=1.0)
parser.add_argument("--unk-weight", type=float, default=-math.inf)
parser.add_argument("--sil-weight", type=float, default=0.0)
return parser
@@ -72,29 +81,21 @@ def get_dataset_itr(args, task):
).next_epoch_itr(shuffle=False)
def process_predictions(
    args, hypos, sp, tgt_dict, target_tokens, res_files, speaker, id
):
for hypo in hypos[: min(len(hypos), args.nbest)]:
hyp_pieces = tgt_dict.string(hypo["tokens"].int().cpu())
hyp_words = sp.DecodePieces(hyp_pieces.split())
        print(
            "{} ({}-{})".format(hyp_pieces, speaker, id), file=res_files["hypo.units"]
        )
        print("{} ({}-{})".format(hyp_words, speaker, id), file=res_files["hypo.words"])
tgt_pieces = tgt_dict.string(target_tokens)
tgt_words = sp.DecodePieces(tgt_pieces.split())
        print("{} ({}-{})".format(tgt_pieces, speaker, id), file=res_files["ref.units"])
        print("{} ({}-{})".format(tgt_words, speaker, id), file=res_files["ref.words"])
# only score top hypothesis
if not args.quiet:
logger.debug("HYPO:" + hyp_words)
@@ -120,6 +121,30 @@ def prepare_result_files(args):
}
def load_models_and_criterions(filenames, arg_overrides=None, task=None):
models = []
criterions = []
for filename in filenames:
if not os.path.exists(filename):
raise IOError("Model file not found: {}".format(filename))
state = checkpoint_utils.load_checkpoint_to_cpu(filename, arg_overrides)
args = state["args"]
if task is None:
task = tasks.setup_task(args)
# build model for ensemble
model = task.build_model(args)
model.load_state_dict(state["model"], strict=True)
models.append(model)
criterion = task.build_criterion(args)
if "criterion" in state:
criterion.load_state_dict(state["criterion"], strict=True)
criterions.append(criterion)
return models, criterions, args
def optimize_models(args, use_cuda, models):
"""Optimize ensemble for generation
"""
@@ -156,22 +181,22 @@ def main(args):
# Set dictionary
tgt_dict = task.target_dictionary
logger.info("| decoding with criterion {}".format(args.criterion))
# Load ensemble
logger.info("| loading model(s) from {}".format(args.path))
    models, criterions, _model_args = load_models_and_criterions(
        args.path.split(":"),
        arg_overrides=eval(args.model_overrides),  # noqa
        task=task,
    )
optimize_models(args, use_cuda, models)
# hack to pass transitions to W2lDecoder
if args.criterion == "asg_loss":
trans = criterions[0].asg.trans.data
args.asg_transitions = torch.flatten(trans).tolist()
# Load dataset (possibly sharded)
itr = get_dataset_itr(args, task)
@@ -185,7 +210,7 @@ def main(args):
os.makedirs(args.results_path)
sp = spm.SentencePieceProcessor()
sp.Load(os.path.join(args.data, "spm.model"))
res_files = prepare_result_files(args)
with progress_bar.build_progress_bar(args, itr) as t:
@@ -204,7 +229,7 @@ def main(args):
num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
gen_timer.stop(num_generated_tokens)
for i, sample_id in enumerate(sample["id"].tolist()):
speaker = task.dataset(args.gen_subset).speakers[int(sample_id)]
id = task.dataset(args.gen_subset).ids[int(sample_id)]
target_tokens = (
......
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq.models import (
FairseqEncoder,
FairseqEncoderModel,
register_model,
register_model_architecture,
)
default_conv_enc_config = """[
(400, 13, 170, 0.2),
(440, 14, 0, 0.214),
(484, 15, 0, 0.22898),
(532, 16, 0, 0.2450086),
(584, 17, 0, 0.262159202),
(642, 18, 0, 0.28051034614),
(706, 19, 0, 0.30014607037),
(776, 20, 0, 0.321156295296),
(852, 21, 0, 0.343637235966),
(936, 22, 0, 0.367691842484),
(1028, 23, 0, 0.393430271458),
(1130, 24, 0, 0.42097039046),
(1242, 25, 0, 0.450438317792),
(1366, 26, 0, 0.481969000038),
(1502, 27, 0, 0.51570683004),
(1652, 28, 0, 0.551806308143),
(1816, 29, 0, 0.590432749713),
]"""
@register_model("asr_w2l_conv_glu_encoder")
class W2lConvGluEncoderModel(FairseqEncoderModel):
def __init__(self, encoder):
super().__init__(encoder)
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
parser.add_argument(
"--input-feat-per-channel",
type=int,
metavar="N",
help="encoder input dimension per input channel",
)
parser.add_argument(
"--in-channels",
type=int,
metavar="N",
help="number of encoder input channels",
)
parser.add_argument(
"--conv-enc-config",
type=str,
metavar="EXPR",
help="""
an array of tuples each containing the configuration of one conv layer
[(out_channels, kernel_size, padding, dropout), ...]
""",
)
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
conv_enc_config = getattr(args, "conv_enc_config", default_conv_enc_config)
encoder = W2lConvGluEncoder(
vocab_size=len(task.target_dictionary),
input_feat_per_channel=args.input_feat_per_channel,
in_channels=args.in_channels,
conv_enc_config=eval(conv_enc_config),
)
return cls(encoder)
def get_normalized_probs(self, net_output, log_probs, sample=None):
lprobs = super().get_normalized_probs(net_output, log_probs, sample)
lprobs.batch_first = False
return lprobs
class W2lConvGluEncoder(FairseqEncoder):
def __init__(
self, vocab_size, input_feat_per_channel, in_channels, conv_enc_config
):
super().__init__(None)
self.input_dim = input_feat_per_channel
if in_channels != 1:
raise ValueError("only 1 input channel is currently supported")
self.conv_layers = nn.ModuleList()
self.linear_layers = nn.ModuleList()
self.dropouts = []
cur_channels = input_feat_per_channel
for out_channels, kernel_size, padding, dropout in conv_enc_config:
layer = nn.Conv1d(cur_channels, out_channels, kernel_size, padding=padding)
layer.weight.data.mul_(math.sqrt(3)) # match wav2letter init
self.conv_layers.append(nn.utils.weight_norm(layer))
self.dropouts.append(dropout)
if out_channels % 2 != 0:
raise ValueError("odd # of out_channels is incompatible with GLU")
cur_channels = out_channels // 2 # halved by GLU
for out_channels in [2 * cur_channels, vocab_size]:
layer = nn.Linear(cur_channels, out_channels)
layer.weight.data.mul_(math.sqrt(3))
self.linear_layers.append(nn.utils.weight_norm(layer))
cur_channels = out_channels // 2
def forward(self, src_tokens, src_lengths, **kwargs):
"""
src_tokens: padded tensor (B, T, C * feat)
src_lengths: tensor of original lengths of input utterances (B,)
"""
B, T, _ = src_tokens.size()
x = src_tokens.transpose(1, 2).contiguous() # (B, feat, T) assuming C == 1
for layer_idx in range(len(self.conv_layers)):
x = self.conv_layers[layer_idx](x)
x = F.glu(x, dim=1)
x = F.dropout(x, p=self.dropouts[layer_idx], training=self.training)
x = x.transpose(1, 2).contiguous() # (B, T, 908)
x = self.linear_layers[0](x)
x = F.glu(x, dim=2)
        x = F.dropout(x, p=self.dropouts[-1], training=self.training)
x = self.linear_layers[1](x)
assert x.size(0) == B
assert x.size(1) == T
encoder_out = x.transpose(0, 1) # (T, B, vocab_size)
# need to debug this -- find a simpler/elegant way in pytorch APIs
encoder_padding_mask = (
torch.arange(T).view(1, T).expand(B, -1).to(x.device)
>= src_lengths.view(B, 1).expand(-1, T)
).t() # (B x T) -> (T x B)
return {
"encoder_out": encoder_out, # (T, B, vocab_size)
"encoder_padding_mask": encoder_padding_mask, # (T, B)
}
def reorder_encoder_out(self, encoder_out, new_order):
encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select(
1, new_order
)
encoder_out["encoder_padding_mask"] = encoder_out[
"encoder_padding_mask"
].index_select(1, new_order)
return encoder_out
def max_positions(self):
"""Maximum input length supported by the encoder."""
return (1e6, 1e6) # an arbitrary large number
@register_model_architecture("asr_w2l_conv_glu_encoder", "w2l_conv_glu_enc")
def w2l_conv_glu_enc(args):
args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80)
args.in_channels = getattr(args, "in_channels", 1)
args.conv_enc_config = getattr(args, "conv_enc_config", default_conv_enc_config)
@@ -11,6 +11,7 @@ import torch
from fairseq.data import Dictionary
from fairseq.tasks import FairseqTask, register_task
from examples.speech_recognition.data import AsrDataset
from examples.speech_recognition.data.replabels import replabel_symbol
def get_asr_dataset_from_json(data_json_path, tgt_dict):
@@ -55,16 +56,12 @@ def get_asr_dataset_from_json(data_json_path, tgt_dict):
speakers.append(m.group(1) + "_" + m.group(2))
frame_sizes = [s[1]["input"]["length_ms"] for s in sorted_samples]
tgt = [
torch.LongTensor([int(i) for i in s[1]["output"]["tokenid"].split(", ")])
for s in sorted_samples
]
# append eos
tgt = [torch.cat([t, torch.LongTensor([tgt_dict.eos()])]) for t in tgt]
return AsrDataset(aud_paths, frame_sizes, tgt, tgt_dict, ids, speakers)
@register_task("speech_recognition")
@@ -77,6 +74,9 @@ class SpeechRecognitionTask(FairseqTask):
def add_args(parser):
"""Add task-specific arguments to the parser."""
parser.add_argument("data", help="path to data directory")
parser.add_argument(
"--silence-token", default="\u2581", help="token for silence (used by w2l)"
)
def __init__(self, args, tgt_dict):
super().__init__(args)
@@ -90,6 +90,12 @@ class SpeechRecognitionTask(FairseqTask):
raise FileNotFoundError("Dict not found: {}".format(dict_path))
tgt_dict = Dictionary.load(dict_path)
if args.criterion == "ctc_loss":
tgt_dict.add_symbol("<ctc_blank>")
elif args.criterion == "asg_loss":
for i in range(1, args.max_replabel + 1):
tgt_dict.add_symbol(replabel_symbol(i))
print("| dictionary: {} types".format(len(tgt_dict)))
return cls(args, tgt_dict)
@@ -100,8 +106,20 @@ class SpeechRecognitionTask(FairseqTask):
split (str): name of the split (e.g., train, valid, test)
"""
data_json_path = os.path.join(self.args.data, "{}.json".format(split))
self.datasets[split] = get_asr_dataset_from_json(data_json_path, self.tgt_dict)
def build_generator(self, args):
w2l_decoder = getattr(args, "w2l_decoder", None)
if w2l_decoder == "viterbi":
from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder
return W2lViterbiDecoder(args, self.target_dictionary)
elif w2l_decoder == "kenlm":
from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder
return W2lKenLMDecoder(args, self.target_dictionary)
else:
return super().build_generator(args)
@property
def target_dictionary(self):
......
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Wav2letter decoders.
"""
import math
import itertools as it
import torch
from fairseq import utils
from examples.speech_recognition.data.replabels import unpack_replabels
from wav2letter.common import create_word_dict, load_words
from wav2letter.criterion import CpuViterbiPath, get_data_ptr_as_bytes
from wav2letter.decoder import (
CriterionType,
DecoderOptions,
KenLM,
SmearingMode,
Trie,
WordLMDecoder,
)
class W2lDecoder(object):
def __init__(self, args, tgt_dict):
self.tgt_dict = tgt_dict
self.vocab_size = len(tgt_dict)
self.nbest = args.nbest
# criterion-specific init
if args.criterion == "ctc_loss":
self.criterion_type = CriterionType.CTC
self.blank = tgt_dict.index("<ctc_blank>")
self.asg_transitions = None
elif args.criterion == "asg_loss":
self.criterion_type = CriterionType.ASG
self.blank = -1
self.asg_transitions = args.asg_transitions
self.max_replabel = args.max_replabel
assert len(self.asg_transitions) == self.vocab_size ** 2
else:
raise RuntimeError(f"unknown criterion: {args.criterion}")
def generate(self, models, sample, prefix_tokens=None):
"""Generate a batch of inferences."""
# model.forward normally channels prev_output_tokens into the decoder
# separately, but SequenceGenerator directly calls model.encoder
encoder_input = {
k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
}
emissions = self.get_emissions(models, encoder_input)
return self.decode(emissions)
def get_emissions(self, models, encoder_input):
"""Run encoder and normalize emissions"""
encoder_out = models[0].encoder(**encoder_input)
if self.criterion_type == CriterionType.CTC:
emissions = models[0].get_normalized_probs(encoder_out, log_probs=True)
elif self.criterion_type == CriterionType.ASG:
emissions = encoder_out["encoder_out"]
return emissions.transpose(0, 1).float().cpu().contiguous()
def get_tokens(self, idxs):
"""Normalize tokens by handling CTC blank, ASG replabels, etc."""
idxs = (g[0] for g in it.groupby(idxs))
idxs = filter(lambda x: x >= 0, idxs)
if self.criterion_type == CriterionType.CTC:
idxs = filter(lambda x: x != self.blank, idxs)
elif self.criterion_type == CriterionType.ASG:
idxs = unpack_replabels(list(idxs), self.tgt_dict, self.max_replabel)
return torch.LongTensor(list(idxs))
class W2lViterbiDecoder(W2lDecoder):
def __init__(self, args, tgt_dict):
super().__init__(args, tgt_dict)
def decode(self, emissions):
B, T, N = emissions.size()
hypos = []
if self.asg_transitions is None:
transitions = torch.FloatTensor(N, N).zero_()
else:
transitions = torch.FloatTensor(self.asg_transitions).view(N, N)
viterbi_path = torch.IntTensor(B, T)
workspace = torch.ByteTensor(CpuViterbiPath.get_workspace_size(B, T, N))
CpuViterbiPath.compute(
B,
T,
N,
get_data_ptr_as_bytes(emissions),
get_data_ptr_as_bytes(transitions),
get_data_ptr_as_bytes(viterbi_path),
get_data_ptr_as_bytes(workspace),
)
return [
[{"tokens": self.get_tokens(viterbi_path[b].tolist()), "score": 0}]
for b in range(B)
]
class W2lKenLMDecoder(W2lDecoder):
def __init__(self, args, tgt_dict):
super().__init__(args, tgt_dict)
self.silence = tgt_dict.index(args.silence_token)
self.lexicon = load_words(args.lexicon)
self.word_dict = create_word_dict(self.lexicon)
self.unk_word = self.word_dict.get_index("<unk>")
self.lm = KenLM(args.kenlm_model, self.word_dict)
self.trie = Trie(self.vocab_size, self.silence)
start_state = self.lm.start(False)
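        # Insert every spelling of every word into the trie, scored with the
        # word's LM score from the start state; smearing below propagates
        # maxima for beam-search pruning.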
for word, spellings in self.lexicon.items():
word_idx = self.word_dict.get_index(word)
_, score = self.lm.score(start_state, word_idx)
for spelling in spellings:
spelling_idxs = [tgt_dict.index(token) for token in spelling]
self.trie.insert(spelling_idxs, word_idx, score)
self.trie.smear(SmearingMode.MAX)
self.decoder_opts = DecoderOptions(
args.beam,
args.beam_threshold,
args.lm_weight,
args.word_score,
args.unk_weight,
False,
args.sil_weight,
self.criterion_type,
)
self.decoder = WordLMDecoder(
self.decoder_opts,
self.trie,
self.lm,
self.silence,
self.blank,
self.unk_word,
self.asg_transitions,
)
def decode(self, emissions):
B, T, N = emissions.size()
hypos = []
for b in range(B):
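            # advance the raw float32 pointer to batch element b (4 bytes per float)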
emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0)
nbest_results = self.decoder.decode(emissions_ptr, T, N)[: self.nbest]
hypos.append(
[
{"tokens": self.get_tokens(result.tokens), "score": result.score}
for result in nbest_results
]
)
return hypos