"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "1e216be8950d22813b975a4d026e58011cd78162"
Commit 33646ac9 authored by Jeff Cai, committed by Facebook Github Bot

wav2letter integration

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/846

Reviewed By: jcai1

Differential Revision: D17845996

Pulled By: okhonko

fbshipit-source-id: 3826fd9a4418496916bf1835c319dd85c89945cc
parent b6e001f6
@@ -7,6 +7,7 @@ On top of main fairseq dependencies there are a couple more additional requirements.
1) Please follow the instructions to install [torchaudio](https://github.com/pytorch/audio). This is required to compute audio fbank features.
2) [Sclite](http://www1.icsi.berkeley.edu/Speech/docs/sctk-1.2/sclite.htm#sclite_name_0) is used to measure WER. Sclite can be downloaded and installed from source from the sctk package [here](http://www.openslr.org/4/). Training and inference don't require the Sclite dependency.
3) [sentencepiece](https://github.com/google/sentencepiece) is required in order to create datasets with word-piece targets.
## Preparing librispeech data
```
@@ -30,3 +31,76 @@ python examples/speech_recognition/infer.py $DIR_FOR_PREPROCESSED_DATA --task sp
sclite -r ${RES_DIR}/ref.word-checkpoint_last.pt-${SET}.txt -h ${RES_DIR}/hypo.word-checkpoint_last.pt-${SET}.txt -i rm -o all stdout > $RES_REPORT
```
The `Sum/Avg` row from the first table of the report has the WER.
## Using wav2letter components
[wav2letter](https://github.com/facebookresearch/wav2letter) now has integration with fairseq. Currently this includes:
* AutoSegmentationCriterion (ASG)
* wav2letter-style Conv/GLU model
* wav2letter's beam search decoder
To use these, follow the instructions at the bottom of [this page](https://github.com/facebookresearch/wav2letter/blob/master/docs/installation.md) to install the python bindings. Please note that the python bindings cover a *subset* of wav2letter and don't require its full dependencies (notably, `flashlight` and `ArrayFire` are *not* required).
To quickly summarize the instructions: first, install [CUDA](https://developer.nvidia.com/cuda-downloads). Then follow these steps:
```
# additional prerequisites - use equivalents for your distro
sudo apt-get install build-essential cmake libatlas-base-dev libfftw3-dev liblzma-dev libbz2-dev libzstd-dev
# install KenLM from source
git clone https://github.com/kpu/kenlm.git
cd kenlm
mkdir -p build && cd build
cmake .. -DCMAKE_BUILD_TYPE=Release -DCMAKE_POSITION_INDEPENDENT_CODE=ON
make -j16
cd ..
export KENLM_ROOT_DIR=$(pwd)
cd ..
# install wav2letter python bindings
git clone https://github.com/facebookresearch/wav2letter.git
cd wav2letter/bindings/python
# make sure your python environment is active at this point
pip install torch packaging
pip install -e .
# try some examples to verify installation succeeded
python ./examples/criterion_example.py
python ./examples/decoder_example.py ../../src/decoder/test
python ./examples/feature_example.py ../../src/feature/test/data
```
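If the build succeeded, the modules used by the integration below should import cleanly. A minimal Python sanity check (an illustration only, not part of the repo):
```
# verify the wav2letter python bindings are importable (sketch)
from wav2letter.criterion import ASGLoss, CriterionScaleMode
from wav2letter.decoder import CriterionType

print("wav2letter bindings OK:", ASGLoss, CriterionScaleMode, CriterionType)
```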
## Training librispeech data (wav2letter style, Conv/GLU + ASG loss)
Training command:
```
python train.py $DIR_FOR_PREPROCESSED_DATA --save-dir $MODEL_PATH --max-epoch 100 --task speech_recognition --arch w2l_conv_glu_enc --batch-size 4 --optimizer sgd --lr 0.3,0.8 --momentum 0.8 --clip-norm 0.2 --max-tokens 50000 --log-format json --log-interval 100 --num-workers 0 --sentence-avg --criterion asg_loss --asg-transitions-init 5 --max-replabel 2 --linseg-updates 8789 --user-dir examples/speech_recognition
```
Note that ASG loss currently doesn't do well with word-pieces. You should prepare a dataset with character targets by setting `nbpe=31` in `prepare-librispeech.sh`.
## Inference for librispeech (wav2letter decoder, n-gram LM)
Inference command:
```
python examples/speech_recognition/infer.py $DIR_FOR_PREPROCESSED_DATA --task speech_recognition --seed 1 --nbest 1 --path $MODEL_PATH/checkpoint_last.pt --gen-subset $SET --results-path $RES_DIR --w2l-decoder kenlm --kenlm-model $KENLM_MODEL_PATH --lexicon $LEXICON_PATH --beam 200 --beam-threshold 15 --lm-weight 1.5 --word-score 1.5 --sil-weight -0.3 --criterion asg_loss --max-replabel 2 --user-dir examples/speech_recognition
```
`$KENLM_MODEL_PATH` should be a standard n-gram language model file. `$LEXICON_PATH` should be a wav2letter-style lexicon (a list of known words and their spellings). For ASG inference, a lexicon line should look like this (note the repetition labels; a short sketch of how they are derived follows the example):
```
doorbell D O 1 R B E L 1 ▁
```
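These spellings follow the same replabel convention as `pack_replabels` in `replabels.py` below: up to `--max-replabel` consecutive repeats of a symbol are collapsed into a numeral token. A minimal sketch of deriving such a spelling for one word, assuming `max_replabel=2` and a trailing `▁` word boundary as in the example above (this helper is not part of the repo):
```
def asg_spelling(word, max_replabel=2):
    """Collapse runs of repeated letters into replabels ("1", "2", ...)."""
    spelling = []
    prev, reps = None, 0
    for ch in word.upper():
        if ch == prev and reps < max_replabel:
            reps += 1
        else:
            if reps > 0:
                spelling.append(str(reps))
                reps = 0
            spelling.append(ch)
            prev = ch
    if reps > 0:
        spelling.append(str(reps))
    return spelling + ["▁"]

print("doorbell", " ".join(asg_spelling("doorbell")))
# doorbell D O 1 R B E L 1 ▁
```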
For CTC inference with word-pieces, repetition labels are not used and the lexicon should contain the most common spellings for each word (one can use sentencepiece's `NBestEncodeAsPieces` for this; a sketch follows the example below):
```
doorbell ▁DOOR BE LL
doorbell ▁DOOR B E LL
doorbell ▁DO OR BE LL
doorbell ▁DOOR B EL L
doorbell ▁DOOR BE L L
doorbell ▁DO OR B E LL
doorbell ▁DOOR B E L L
doorbell ▁DO OR B EL L
doorbell ▁DO O R BE LL
doorbell ▁DO OR BE L L
```
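One way to generate these lines is sentencepiece's Python API. A minimal sketch, assuming the `spm.model` word-piece model built during data preparation (found under `$DIR_FOR_PREPROCESSED_DATA`) and 10 spellings per word:
```
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.Load("spm.model")  # word-piece model from data preparation (path is an assumption)
word = "doorbell"
for pieces in sp.NBestEncodeAsPieces(word, 10):
    print(word, " ".join(pieces))
```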
Lowercase vs. uppercase matters: the *word* should match the case of the n-gram language model (i.e. `$KENLM_MODEL_PATH`), while the *spelling* should match the case of the token dictionary (i.e. `$DIR_FOR_PREPROCESSED_DATA/dict.txt`).
## Inference for librispeech (wav2letter decoder, viterbi only)
Inference command:
```
python examples/speech_recognition/infer.py $DIR_FOR_PREPROCESSED_DATA --task speech_recognition --seed 1 --nbest 1 --path $MODEL_PATH/checkpoint_last.pt --gen-subset $SET --results-path $RES_DIR --w2l-decoder viterbi --criterion asg_loss --max-replabel 2 --user-dir examples/speech_recognition
```
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import numpy as np
import torch
from fairseq import utils
from fairseq.criterions import FairseqCriterion, register_criterion
from examples.speech_recognition.data.replabels import pack_replabels
from wav2letter.criterion import ASGLoss, CriterionScaleMode
@register_criterion("asg_loss")
class ASGCriterion(FairseqCriterion):
@staticmethod
def add_args(parser):
group = parser.add_argument_group("ASG Loss")
group.add_argument(
"--asg-transitions-init",
help="initial diagonal value of transition matrix",
type=float,
default=0.0,
)
group.add_argument(
"--max-replabel", help="maximum # of replabels", type=int, default=2
)
group.add_argument(
"--linseg-updates",
help="# of training updates to use LinSeg initialization",
type=int,
default=0,
)
group.add_argument(
"--hide-linseg-messages",
help="hide messages about LinSeg initialization",
action="store_true",
)
def __init__(self, args, task):
super().__init__(args, task)
self.tgt_dict = task.target_dictionary
self.eos = self.tgt_dict.eos()
self.silence = (
self.tgt_dict.index(args.silence_token)
if args.silence_token in self.tgt_dict
else None
)
self.max_replabel = args.max_replabel
num_labels = len(self.tgt_dict)
self.asg = ASGLoss(num_labels, scale_mode=CriterionScaleMode.TARGET_SZ_SQRT)
self.asg.trans = torch.nn.Parameter(
args.asg_transitions_init * torch.eye(num_labels), requires_grad=True
)
self.linseg_progress = torch.nn.Parameter(
torch.tensor([0], dtype=torch.int), requires_grad=False
)
self.linseg_maximum = args.linseg_updates
self.linseg_message_state = "none" if args.hide_linseg_messages else "start"
def linseg_step(self):
if not self.training:
return False
if self.linseg_progress.item() < self.linseg_maximum:
if self.linseg_message_state == "start":
print("| using LinSeg to initialize ASG")
self.linseg_message_state = "finish"
self.linseg_progress.add_(1)
return True
elif self.linseg_message_state == "finish":
print("| finished LinSeg initialization")
self.linseg_message_state = "none"
return False
def replace_eos_with_silence(self, tgt):
if tgt[-1] != self.eos:
return tgt
elif self.silence is None or (len(tgt) > 1 and tgt[-2] == self.silence):
return tgt[:-1]
else:
return tgt[:-1] + [self.silence]
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
net_output = model(**sample["net_input"])
emissions = net_output["encoder_out"].transpose(0, 1).contiguous()
B = emissions.size(0)
T = emissions.size(1)
device = emissions.device
target = torch.IntTensor(B, T)
target_size = torch.IntTensor(B)
using_linseg = self.linseg_step()
for b in range(B):
initial_target_size = sample["target_lengths"][b].item()
if initial_target_size == 0:
raise ValueError("target size cannot be zero")
tgt = sample["target"][b, :initial_target_size].tolist()
tgt = self.replace_eos_with_silence(tgt)
tgt = pack_replabels(tgt, self.tgt_dict, self.max_replabel)
tgt = tgt[:T]
if using_linseg:
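                # LinSeg: spread the (shorter) target linearly across all T output frames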
tgt = [tgt[t * len(tgt) // T] for t in range(T)]
target[b][: len(tgt)] = torch.IntTensor(tgt)
target_size[b] = len(tgt)
loss = self.asg.forward(emissions, target.to(device), target_size.to(device))
if reduce:
loss = torch.sum(loss)
sample_size = (
sample["target"].size(0) if self.args.sentence_avg else sample["ntokens"]
)
logging_output = {
"loss": utils.item(loss.data) if reduce else loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample["target"].size(0),
"sample_size": sample_size,
}
return loss, sample_size, logging_output
@staticmethod
def aggregate_logging_outputs(logging_outputs):
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
agg_output = {
"loss": loss_sum / nsentences,
"ntokens": ntokens,
"nsentences": nsentences,
"sample_size": sample_size,
}
return agg_output
 import importlib
 import os
+# ASG loss requires wav2letter
+blacklist = set()
+try:
+    import wav2letter
+except ImportError:
+    blacklist.add("ASG_loss.py")
 for file in os.listdir(os.path.dirname(__file__)):
-    if file.endswith('.py') and not file.startswith('_'):
-        criterion_name = file[:file.find('.py')]
-        importlib.import_module('examples.speech_recognition.criterions.' + criterion_name)
+    if file.endswith(".py") and not file.startswith("_") and file not in blacklist:
+        criterion_name = file[: file.find(".py")]
+        importlib.import_module(
+            "examples.speech_recognition.criterions." + criterion_name
+        )
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Replabel transforms for use with wav2letter's ASG criterion.
"""
def replabel_symbol(i):
"""
Replabel symbols used in wav2letter, currently just "1", "2", ...
This prevents training with numeral tokens, so this might change in the future
"""
return str(i)
def pack_replabels(tokens, dictionary, max_reps):
"""
Pack a token sequence so that repeated symbols are replaced by replabels
"""
if len(tokens) == 0 or max_reps <= 0:
return tokens
replabel_value_to_idx = [0] * (max_reps + 1)
for i in range(1, max_reps + 1):
replabel_value_to_idx[i] = dictionary.index(replabel_symbol(i))
result = []
prev_token = -1
num_reps = 0
for token in tokens:
if token == prev_token and num_reps < max_reps:
num_reps += 1
else:
if num_reps > 0:
result.append(replabel_value_to_idx[num_reps])
num_reps = 0
result.append(token)
prev_token = token
if num_reps > 0:
result.append(replabel_value_to_idx[num_reps])
return result
def unpack_replabels(tokens, dictionary, max_reps):
"""
Unpack a token sequence so that replabels are replaced by repeated symbols
"""
if len(tokens) == 0 or max_reps <= 0:
return tokens
replabel_idx_to_value = {}
for i in range(1, max_reps + 1):
replabel_idx_to_value[dictionary.index(replabel_symbol(i))] = i
result = []
prev_token = -1
for token in tokens:
try:
for _ in range(replabel_idx_to_value[token]):
result.append(prev_token)
prev_token = -1
except KeyError:
result.append(token)
prev_token = token
return result
@@ -9,11 +9,12 @@ Run inference for pre-processed data with a trained model.
 """
 import logging
+import math
 import os
 import sentencepiece as spm
 import torch
-from fairseq import options, progress_bar, utils, tasks
+from fairseq import checkpoint_utils, options, progress_bar, utils, tasks
 from fairseq.meters import StopwatchMeter, TimeMeter
 from fairseq.utils import import_user_module
@@ -23,8 +24,6 @@ logger.setLevel(logging.INFO)
 def add_asr_eval_argument(parser):
-    parser.add_argument("--ctc", action="store_true", help="decode a ctc model")
-    parser.add_argument("--rnnt", default=False, help="decode a rnnt model")
     parser.add_argument("--kspmodel", default=None, help="sentence piece model")
     parser.add_argument(
         "--wfstlm", default=None, help="wfstlm on dictonary output units"
@@ -36,14 +35,24 @@ def add_asr_eval_argument(parser):
             output units",
     )
     parser.add_argument(
+        "--lm-weight",
         "--lm_weight",
+        type=float,
         default=0.2,
-        help="weight for wfstlm while interpolating\
-            with neural score",
+        help="weight for lm while interpolating with neural score",
     )
     parser.add_argument(
         "--rnnt_len_penalty", default=-0.5, help="rnnt length penalty on word level"
     )
+    parser.add_argument(
+        "--w2l-decoder", choices=["viterbi", "kenlm"], help="use a w2l decoder"
+    )
+    parser.add_argument("--lexicon", help="lexicon for w2l decoder")
+    parser.add_argument("--kenlm-model", help="kenlm model for w2l decoder")
+    parser.add_argument("--beam-threshold", type=float, default=25.0)
+    parser.add_argument("--word-score", type=float, default=1.0)
+    parser.add_argument("--unk-weight", type=float, default=-math.inf)
+    parser.add_argument("--sil-weight", type=float, default=0.0)
     return parser
@@ -72,29 +81,21 @@ def get_dataset_itr(args, task):
     ).next_epoch_itr(shuffle=False)
-def process_predictions(args, hypos, sp, tgt_dict, target_tokens, res_files, speaker, id):
+def process_predictions(
+    args, hypos, sp, tgt_dict, target_tokens, res_files, speaker, id
+):
     for hypo in hypos[: min(len(hypos), args.nbest)]:
         hyp_pieces = tgt_dict.string(hypo["tokens"].int().cpu())
         hyp_words = sp.DecodePieces(hyp_pieces.split())
         print(
-            "{} ({}-{})".format(hyp_pieces, speaker, id),
-            file=res_files["hypo.units"],
+            "{} ({}-{})".format(hyp_pieces, speaker, id), file=res_files["hypo.units"]
         )
-        print(
-            "{} ({}-{})".format(hyp_words, speaker, id),
-            file=res_files["hypo.words"],
-        )
+        print("{} ({}-{})".format(hyp_words, speaker, id), file=res_files["hypo.words"])
         tgt_pieces = tgt_dict.string(target_tokens)
         tgt_words = sp.DecodePieces(tgt_pieces.split())
-        print(
-            "{} ({}-{})".format(tgt_pieces, speaker, id),
-            file=res_files["ref.units"],
-        )
-        print(
-            "{} ({}-{})".format(tgt_words, speaker, id),
-            file=res_files["ref.words"],
-        )
+        print("{} ({}-{})".format(tgt_pieces, speaker, id), file=res_files["ref.units"])
+        print("{} ({}-{})".format(tgt_words, speaker, id), file=res_files["ref.words"])
         # only score top hypothesis
         if not args.quiet:
             logger.debug("HYPO:" + hyp_words)
@@ -120,6 +121,30 @@ def prepare_result_files(args):
     }
+def load_models_and_criterions(filenames, arg_overrides=None, task=None):
+    models = []
+    criterions = []
+    for filename in filenames:
+        if not os.path.exists(filename):
+            raise IOError("Model file not found: {}".format(filename))
+        state = checkpoint_utils.load_checkpoint_to_cpu(filename, arg_overrides)
+        args = state["args"]
+        if task is None:
+            task = tasks.setup_task(args)
+        # build model for ensemble
+        model = task.build_model(args)
+        model.load_state_dict(state["model"], strict=True)
+        models.append(model)
+        criterion = task.build_criterion(args)
+        if "criterion" in state:
+            criterion.load_state_dict(state["criterion"], strict=True)
+        criterions.append(criterion)
+    return models, criterions, args
 def optimize_models(args, use_cuda, models):
     """Optimize ensemble for generation
     """
@@ -156,22 +181,22 @@ def main(args):
     # Set dictionary
     tgt_dict = task.target_dictionary
-    if args.ctc or args.rnnt:
-        tgt_dict.add_symbol("<ctc_blank>")
-        if args.ctc:
-            logger.info("| decoding a ctc model")
-        if args.rnnt:
-            logger.info("| decoding a rnnt model")
+    logger.info("| decoding with criterion {}".format(args.criterion))
     # Load ensemble
     logger.info("| loading model(s) from {}".format(args.path))
-    models, _model_args = utils.load_ensemble_for_inference(
+    models, criterions, _model_args = load_models_and_criterions(
         args.path.split(":"),
-        task,
-        model_arg_overrides=eval(args.model_overrides),  # noqa
+        arg_overrides=eval(args.model_overrides),  # noqa
+        task=task,
     )
     optimize_models(args, use_cuda, models)
+    # hack to pass transitions to W2lDecoder
+    if args.criterion == "asg_loss":
+        trans = criterions[0].asg.trans.data
+        args.asg_transitions = torch.flatten(trans).tolist()
     # Load dataset (possibly sharded)
     itr = get_dataset_itr(args, task)
@@ -185,7 +210,7 @@ def main(args):
         os.makedirs(args.results_path)
     sp = spm.SentencePieceProcessor()
-    sp.Load(os.path.join(args.data, 'spm.model'))
+    sp.Load(os.path.join(args.data, "spm.model"))
     res_files = prepare_result_files(args)
     with progress_bar.build_progress_bar(args, itr) as t:
@@ -204,7 +229,7 @@ def main(args):
             num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
             gen_timer.stop(num_generated_tokens)
-            for i, sample_id in enumerate(sample['id'].tolist()):
+            for i, sample_id in enumerate(sample["id"].tolist()):
                 speaker = task.dataset(args.gen_subset).speakers[int(sample_id)]
                 id = task.dataset(args.gen_subset).ids[int(sample_id)]
                 target_tokens = (
...
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq.models import (
FairseqEncoder,
FairseqEncoderModel,
register_model,
register_model_architecture,
)
default_conv_enc_config = """[
(400, 13, 170, 0.2),
(440, 14, 0, 0.214),
(484, 15, 0, 0.22898),
(532, 16, 0, 0.2450086),
(584, 17, 0, 0.262159202),
(642, 18, 0, 0.28051034614),
(706, 19, 0, 0.30014607037),
(776, 20, 0, 0.321156295296),
(852, 21, 0, 0.343637235966),
(936, 22, 0, 0.367691842484),
(1028, 23, 0, 0.393430271458),
(1130, 24, 0, 0.42097039046),
(1242, 25, 0, 0.450438317792),
(1366, 26, 0, 0.481969000038),
(1502, 27, 0, 0.51570683004),
(1652, 28, 0, 0.551806308143),
(1816, 29, 0, 0.590432749713),
]"""
@register_model("asr_w2l_conv_glu_encoder")
class W2lConvGluEncoderModel(FairseqEncoderModel):
def __init__(self, encoder):
super().__init__(encoder)
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
parser.add_argument(
"--input-feat-per-channel",
type=int,
metavar="N",
help="encoder input dimension per input channel",
)
parser.add_argument(
"--in-channels",
type=int,
metavar="N",
help="number of encoder input channels",
)
parser.add_argument(
"--conv-enc-config",
type=str,
metavar="EXPR",
help="""
an array of tuples each containing the configuration of one conv layer
[(out_channels, kernel_size, padding, dropout), ...]
""",
)
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
conv_enc_config = getattr(args, "conv_enc_config", default_conv_enc_config)
encoder = W2lConvGluEncoder(
vocab_size=len(task.target_dictionary),
input_feat_per_channel=args.input_feat_per_channel,
in_channels=args.in_channels,
conv_enc_config=eval(conv_enc_config),
)
return cls(encoder)
def get_normalized_probs(self, net_output, log_probs, sample=None):
lprobs = super().get_normalized_probs(net_output, log_probs, sample)
lprobs.batch_first = False
return lprobs
class W2lConvGluEncoder(FairseqEncoder):
def __init__(
self, vocab_size, input_feat_per_channel, in_channels, conv_enc_config
):
super().__init__(None)
self.input_dim = input_feat_per_channel
if in_channels != 1:
raise ValueError("only 1 input channel is currently supported")
self.conv_layers = nn.ModuleList()
self.linear_layers = nn.ModuleList()
self.dropouts = []
cur_channels = input_feat_per_channel
for out_channels, kernel_size, padding, dropout in conv_enc_config:
layer = nn.Conv1d(cur_channels, out_channels, kernel_size, padding=padding)
layer.weight.data.mul_(math.sqrt(3)) # match wav2letter init
self.conv_layers.append(nn.utils.weight_norm(layer))
self.dropouts.append(dropout)
if out_channels % 2 != 0:
raise ValueError("odd # of out_channels is incompatible with GLU")
cur_channels = out_channels // 2 # halved by GLU
for out_channels in [2 * cur_channels, vocab_size]:
layer = nn.Linear(cur_channels, out_channels)
layer.weight.data.mul_(math.sqrt(3))
self.linear_layers.append(nn.utils.weight_norm(layer))
cur_channels = out_channels // 2
def forward(self, src_tokens, src_lengths, **kwargs):
"""
src_tokens: padded tensor (B, T, C * feat)
src_lengths: tensor of original lengths of input utterances (B,)
"""
B, T, _ = src_tokens.size()
x = src_tokens.transpose(1, 2).contiguous() # (B, feat, T) assuming C == 1
for layer_idx in range(len(self.conv_layers)):
x = self.conv_layers[layer_idx](x)
x = F.glu(x, dim=1)
x = F.dropout(x, p=self.dropouts[layer_idx], training=self.training)
x = x.transpose(1, 2).contiguous() # (B, T, 908)
x = self.linear_layers[0](x)
x = F.glu(x, dim=2)
x = F.dropout(x, p=self.dropouts[-1])
x = self.linear_layers[1](x)
assert x.size(0) == B
assert x.size(1) == T
encoder_out = x.transpose(0, 1) # (T, B, vocab_size)
# need to debug this -- find a simpler/elegant way in pytorch APIs
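        # mask is True at padded time steps (positions >= the utterance's true length)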
encoder_padding_mask = (
torch.arange(T).view(1, T).expand(B, -1).to(x.device)
>= src_lengths.view(B, 1).expand(-1, T)
).t() # (B x T) -> (T x B)
return {
"encoder_out": encoder_out, # (T, B, vocab_size)
"encoder_padding_mask": encoder_padding_mask, # (T, B)
}
def reorder_encoder_out(self, encoder_out, new_order):
encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select(
1, new_order
)
encoder_out["encoder_padding_mask"] = encoder_out[
"encoder_padding_mask"
].index_select(1, new_order)
return encoder_out
def max_positions(self):
"""Maximum input length supported by the encoder."""
return (1e6, 1e6) # an arbitrary large number
@register_model_architecture("asr_w2l_conv_glu_encoder", "w2l_conv_glu_enc")
def w2l_conv_glu_enc(args):
args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80)
args.in_channels = getattr(args, "in_channels", 1)
args.conv_enc_config = getattr(args, "conv_enc_config", default_conv_enc_config)
@@ -11,6 +11,7 @@ import torch
 from fairseq.data import Dictionary
 from fairseq.tasks import FairseqTask, register_task
 from examples.speech_recognition.data import AsrDataset
+from examples.speech_recognition.data.replabels import replabel_symbol
 def get_asr_dataset_from_json(data_json_path, tgt_dict):
@@ -55,16 +56,12 @@ def get_asr_dataset_from_json(data_json_path, tgt_dict):
         speakers.append(m.group(1) + "_" + m.group(2))
     frame_sizes = [s[1]["input"]["length_ms"] for s in sorted_samples]
     tgt = [
-        torch.LongTensor(
-            [int(i) for i in s[1]["output"]["tokenid"].split(", ")]
-        )
+        torch.LongTensor([int(i) for i in s[1]["output"]["tokenid"].split(", ")])
         for s in sorted_samples
     ]
     # append eos
     tgt = [torch.cat([t, torch.LongTensor([tgt_dict.eos()])]) for t in tgt]
-    return AsrDataset(
-        aud_paths, frame_sizes, tgt, tgt_dict, ids, speakers
-    )
+    return AsrDataset(aud_paths, frame_sizes, tgt, tgt_dict, ids, speakers)
 @register_task("speech_recognition")
@@ -77,6 +74,9 @@ class SpeechRecognitionTask(FairseqTask):
     def add_args(parser):
         """Add task-specific arguments to the parser."""
         parser.add_argument("data", help="path to data directory")
+        parser.add_argument(
+            "--silence-token", default="\u2581", help="token for silence (used by w2l)"
+        )
     def __init__(self, args, tgt_dict):
         super().__init__(args)
@@ -90,6 +90,12 @@ class SpeechRecognitionTask(FairseqTask):
             raise FileNotFoundError("Dict not found: {}".format(dict_path))
         tgt_dict = Dictionary.load(dict_path)
+        if args.criterion == "ctc_loss":
+            tgt_dict.add_symbol("<ctc_blank>")
+        elif args.criterion == "asg_loss":
+            for i in range(1, args.max_replabel + 1):
+                tgt_dict.add_symbol(replabel_symbol(i))
         print("| dictionary: {} types".format(len(tgt_dict)))
         return cls(args, tgt_dict)
@@ -100,8 +106,20 @@ class SpeechRecognitionTask(FairseqTask):
             split (str): name of the split (e.g., train, valid, test)
         """
         data_json_path = os.path.join(self.args.data, "{}.json".format(split))
-        self.datasets[split] = get_asr_dataset_from_json(
-            data_json_path, self.tgt_dict)
+        self.datasets[split] = get_asr_dataset_from_json(data_json_path, self.tgt_dict)
+    def build_generator(self, args):
+        w2l_decoder = getattr(args, "w2l_decoder", None)
+        if w2l_decoder == "viterbi":
+            from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder
+            return W2lViterbiDecoder(args, self.target_dictionary)
+        elif w2l_decoder == "kenlm":
+            from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder
+            return W2lKenLMDecoder(args, self.target_dictionary)
+        else:
+            return super().build_generator(args)
     @property
     def target_dictionary(self):
...
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Wav2letter decoders.
"""
import math
import itertools as it
import torch
from fairseq import utils
from examples.speech_recognition.data.replabels import unpack_replabels
from wav2letter.common import create_word_dict, load_words
from wav2letter.criterion import CpuViterbiPath, get_data_ptr_as_bytes
from wav2letter.decoder import (
CriterionType,
DecoderOptions,
KenLM,
SmearingMode,
Trie,
WordLMDecoder,
)
class W2lDecoder(object):
def __init__(self, args, tgt_dict):
self.tgt_dict = tgt_dict
self.vocab_size = len(tgt_dict)
self.nbest = args.nbest
# criterion-specific init
if args.criterion == "ctc_loss":
self.criterion_type = CriterionType.CTC
self.blank = tgt_dict.index("<ctc_blank>")
self.asg_transitions = None
elif args.criterion == "asg_loss":
self.criterion_type = CriterionType.ASG
self.blank = -1
self.asg_transitions = args.asg_transitions
self.max_replabel = args.max_replabel
assert len(self.asg_transitions) == self.vocab_size ** 2
else:
raise RuntimeError(f"unknown criterion: {args.criterion}")
def generate(self, models, sample, prefix_tokens=None):
"""Generate a batch of inferences."""
# model.forward normally channels prev_output_tokens into the decoder
# separately, but SequenceGenerator directly calls model.encoder
encoder_input = {
k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
}
emissions = self.get_emissions(models, encoder_input)
return self.decode(emissions)
def get_emissions(self, models, encoder_input):
"""Run encoder and normalize emissions"""
encoder_out = models[0].encoder(**encoder_input)
if self.criterion_type == CriterionType.CTC:
emissions = models[0].get_normalized_probs(encoder_out, log_probs=True)
elif self.criterion_type == CriterionType.ASG:
emissions = encoder_out["encoder_out"]
return emissions.transpose(0, 1).float().cpu().contiguous()
def get_tokens(self, idxs):
"""Normalize tokens by handling CTC blank, ASG replabels, etc."""
idxs = (g[0] for g in it.groupby(idxs))
idxs = filter(lambda x: x >= 0, idxs)
if self.criterion_type == CriterionType.CTC:
idxs = filter(lambda x: x != self.blank, idxs)
elif self.criterion_type == CriterionType.ASG:
idxs = unpack_replabels(list(idxs), self.tgt_dict, self.max_replabel)
return torch.LongTensor(list(idxs))
class W2lViterbiDecoder(W2lDecoder):
def __init__(self, args, tgt_dict):
super().__init__(args, tgt_dict)
def decode(self, emissions):
B, T, N = emissions.size()
hypos = []
if self.asg_transitions is None:
transitions = torch.FloatTensor(N, N).zero_()
else:
transitions = torch.FloatTensor(self.asg_transitions).view(N, N)
viterbi_path = torch.IntTensor(B, T)
workspace = torch.ByteTensor(CpuViterbiPath.get_workspace_size(B, T, N))
CpuViterbiPath.compute(
B,
T,
N,
get_data_ptr_as_bytes(emissions),
get_data_ptr_as_bytes(transitions),
get_data_ptr_as_bytes(viterbi_path),
get_data_ptr_as_bytes(workspace),
)
return [
[{"tokens": self.get_tokens(viterbi_path[b].tolist()), "score": 0}]
for b in range(B)
]
class W2lKenLMDecoder(W2lDecoder):
def __init__(self, args, tgt_dict):
super().__init__(args, tgt_dict)
self.silence = tgt_dict.index(args.silence_token)
self.lexicon = load_words(args.lexicon)
self.word_dict = create_word_dict(self.lexicon)
self.unk_word = self.word_dict.get_index("<unk>")
self.lm = KenLM(args.kenlm_model, self.word_dict)
self.trie = Trie(self.vocab_size, self.silence)
start_state = self.lm.start(False)
for word, spellings in self.lexicon.items():
word_idx = self.word_dict.get_index(word)
_, score = self.lm.score(start_state, word_idx)
for spelling in spellings:
spelling_idxs = [tgt_dict.index(token) for token in spelling]
self.trie.insert(spelling_idxs, word_idx, score)
self.trie.smear(SmearingMode.MAX)
self.decoder_opts = DecoderOptions(
args.beam,
args.beam_threshold,
args.lm_weight,
args.word_score,
args.unk_weight,
False,
args.sil_weight,
self.criterion_type,
)
self.decoder = WordLMDecoder(
self.decoder_opts,
self.trie,
self.lm,
self.silence,
self.blank,
self.unk_word,
self.asg_transitions,
)
def decode(self, emissions):
B, T, N = emissions.size()
hypos = []
for b in range(B):
emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0)
nbest_results = self.decoder.decode(emissions_ptr, T, N)[: self.nbest]
hypos.append(
[
{"tokens": self.get_tokens(result.tokens), "score": result.score}
for result in nbest_results
]
)
return hypos