Commit 799a38c5 authored by chenzk's avatar chenzk
Browse files

v1.0

parents
Pipeline #616 failed with stages
in 0 seconds
from typing import Optional
import torch
from torch import Tensor
from examples.simultaneous_translation.utils.functions import (
exclusive_cumprod,
prob_check,
moving_sum,
)
def expected_alignment_from_p_choose(
    p_choose: Tensor,
    padding_mask: Optional[Tensor] = None,
    eps: float = 1e-6
):
    """
    Compute the expected alignment from stepwise selection probabilities.

    Reference:
    Online and Linear-Time Attention by Enforcing Monotonic Alignments
    https://arxiv.org/pdf/1704.00784.pdf

    Recurrence:
        q_ij = (1 - p_{i,j-1}) q_{i,j-1} + a_{i-1,j}
        a_ij = p_ij q_ij

    Parallel solution along the source dimension, one target step at a time:
        a_i = p_i * cumprod(1 - p_i) * cumsum(a_{i-1} / cumprod(1 - p_i))

    ============================================================
    Expected input size
    p_choose: bsz, tgt_len, src_len
    padding_mask (optional): broadcast against p_choose after unsqueeze(1);
        masked positions get probability 0
    eps: forwarded to exclusive_cumprod and used as the lower clamp of the
        denominator

    Returns alpha with the same size and dtype as p_choose.
    """
    prob_check(p_choose)

    # p_choose: bsz, tgt_len, src_len
    bsz, tgt_len, src_len = p_choose.size()
    input_dtype = p_choose.dtype

    # Run the recurrence in fp32; the result is cast back at the end.
    p_choose = p_choose.float()
    if padding_mask is not None:
        p_choose = p_choose.masked_fill(padding_mask.unsqueeze(1), 0.0)

    # cumprod_1mp: bsz, tgt_len, src_len
    cumprod_1mp = exclusive_cumprod(1 - p_choose, dim=2, eps=eps)
    # Clamped copy is only used as a denominator.
    cumprod_1mp_clamp = torch.clamp(cumprod_1mp, eps, 1.0)

    # Initial alignment: all mass on the first source position.
    # prev_alpha: bsz, src_len
    prev_alpha = p_choose.new_zeros([bsz, src_len])
    prev_alpha[:, 0] = 1.0

    per_step = []
    for step in range(tgt_len):
        carried = torch.cumsum(prev_alpha / cumprod_1mp_clamp[:, step], dim=1)
        prev_alpha = (
            p_choose[:, step] * cumprod_1mp[:, step] * carried
        ).clamp(0, 1.0)
        per_step.append(prev_alpha.unsqueeze(1))

    # alpha: bsz, tgt_len, src_len
    alpha = torch.cat(per_step, dim=1)

    # Cast back to the caller's dtype (e.g. fp16 under mixed precision).
    alpha = alpha.type(input_dtype)

    prob_check(alpha)
    return alpha
def expected_soft_attention(
    alpha: Tensor,
    soft_energy: Tensor,
    padding_mask: Optional[Tensor] = None,
    chunk_size: Optional[int] = None,
    eps: float = 1e-10
):
    """
    Compute the expected soft attention (beta) from the expected alignment
    and the soft energy, for chunkwise or infinite-lookback monotonic
    attention.

    Reference:
    Monotonic Chunkwise Attention
    https://arxiv.org/abs/1712.05382
    Monotonic Infinite Lookback Attention for Simultaneous Machine Translation
    https://arxiv.org/abs/1906.05218

    alpha: bsz, tgt_len, src_len
    soft_energy: bsz, tgt_len, src_len
    padding_mask: bsz, src_len
    chunk_size: window passed to moving_sum; None selects infinite lookback
    eps: added to the exponentiated energy and to every denominator

    Returns beta with the size and dtype of alpha, clamped to [0, 1].
    """
    if padding_mask is not None:
        expanded_mask = padding_mask.unsqueeze(1)
        alpha = alpha.masked_fill(expanded_mask, 0.0)
        soft_energy = soft_energy.masked_fill(expanded_mask, -float("inf"))

    prob_check(alpha)

    input_dtype = alpha.dtype

    # Work in fp32; cast back before returning.
    alpha = alpha.float()
    soft_energy = soft_energy.float()

    # Subtract the per-row maximum before exponentiating.
    row_max = soft_energy.max(dim=2, keepdim=True)[0]
    soft_energy = soft_energy - row_max
    exp_soft_energy = torch.exp(soft_energy) + eps

    if chunk_size is None:
        # Infinite lookback: the special case of chunkwise attention where
        # the chunk size is unbounded.
        inner_items = alpha / (eps + torch.cumsum(exp_soft_energy, dim=2))
        # Reverse cumulative sum along the source dimension.
        reverse_cumsum = torch.cumsum(
            inner_items.flip(dims=[2]), dim=2
        ).flip(dims=[2])
        beta = exp_soft_energy * reverse_cumsum
    else:
        # Chunkwise
        normalizer = eps + moving_sum(exp_soft_energy, chunk_size, 1)
        beta = exp_soft_energy * moving_sum(alpha / normalizer, 1, chunk_size)

    if padding_mask is not None:
        beta = beta.masked_fill(
            padding_mask.unsqueeze(1).to(torch.bool), 0.0)

    # Cast back to the caller's dtype (e.g. fp16 under mixed precision).
    beta = beta.type(input_dtype)

    beta = beta.clamp(0, 1)

    prob_check(beta)
    return beta
def mass_preservation(
    alpha: Tensor,
    padding_mask: Optional[Tensor] = None,
    left_padding: bool = False
):
    """
    Compute the mass preservation for alpha: the residual weight of each
    target row (1 minus the row sum) is assigned to the last source token.

    Reference:
    Monotonic Infinite Lookback Attention for Simultaneous Machine Translation
    https://arxiv.org/abs/1906.05218

    alpha: bsz, tgt_len, src_len
    padding_mask: bsz, src_len
    left_padding: bool

    Returns a tensor of the same size as alpha. The input tensor is never
    modified in place.

    Raises AssertionError if left_padding is False but the mask marks the
    first source position as padding.
    """
    prob_check(alpha)

    if padding_mask is not None:
        if not left_padding:
            assert not padding_mask[:, 0].any(), (
                "Find padding on the beginning of the sequence."
            )
        # masked_fill is out-of-place, so alpha is a fresh tensor from here.
        alpha = alpha.masked_fill(padding_mask.unsqueeze(1), 0.0)

    if left_padding or padding_mask is None:
        # The last real token sits in the last column.
        residuals = 1 - alpha[:, :, :-1].sum(dim=-1).clamp(0, 1)
        if padding_mask is None:
            # No copy has been made yet: clone so the slice assignment below
            # does not overwrite the caller's tensor.
            alpha = alpha.clone()
        alpha[:, :, -1] = residuals
    else:
        # right padding
        _, tgt_len, src_len = alpha.size()
        residuals = 1 - alpha.sum(dim=-1, keepdim=True).clamp(0, 1)
        # Number of non-padded source tokens per sentence, repeated per row.
        src_lens = src_len - padding_mask.sum(dim=1, keepdim=True)
        src_lens = src_lens.expand(-1, tgt_len).contiguous()
        # add back the last value
        residuals += alpha.gather(2, src_lens.unsqueeze(2) - 1)
        alpha = alpha.scatter(2, src_lens.unsqueeze(2) - 1, residuals)

    prob_check(alpha)

    return alpha
from typing import Optional, Dict
from torch import Tensor
import torch
def waitk_p_choose(
    tgt_len: int,
    src_len: int,
    bsz: int,
    waitk_lagging: int,
    key_padding_mask: Optional[Tensor] = None,
    incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None
):
    """
    Build the deterministic p_choose matrix of a wait-k policy.

    Args:
        tgt_len: number of target steps (ignored when incremental_state is
            given; the step count is then read from
            incremental_state["steps"]["tgt"]).
        src_len: number of source positions.
        bsz: batch size.
        waitk_lagging: k, the number of source tokens read before the first
            target step.
        key_padding_mask: optional (bsz, src_len) mask of padded source
            positions. If any sentence is padded at position 0 the batch is
            treated as left-padded and each row is shifted by its pad count.
        incremental_state: optional incremental decoding state; when given,
            only the row of the latest target step is returned.

    Returns:
        Float tensor of size (bsz, tgt_len, src_len), or (bsz, 1, src_len)
        with incremental_state, holding a single 1 per row at the selected
        source position (rows past the end of the source are all zero).
        The tensor lives on key_padding_mask's device when a mask is given.
    """
    # Build everything on the mask's device so index, mask and output agree.
    device = key_padding_mask.device if key_padding_mask is not None else None

    max_src_len = src_len
    if incremental_state is not None:
        # Retrieve target length from incremental states
        # For inference the length of query is always 1
        max_tgt_len = incremental_state["steps"]["tgt"]
        assert max_tgt_len is not None
        max_tgt_len = int(max_tgt_len)
    else:
        max_tgt_len = tgt_len

    if max_src_len < waitk_lagging:
        # Fewer than k source tokens are available: nothing can be selected.
        if incremental_state is not None:
            max_tgt_len = 1
        return torch.zeros(bsz, max_tgt_len, max_src_len, device=device)

    # Assuming the p_choose looks like this for wait k=3
    # src_len = 7, max_tgt_len = 5
    #   [0, 0, 1, 0, 0, 0, 0]
    #   [0, 0, 0, 1, 0, 0, 0]
    #   [0, 0, 0, 0, 1, 0, 0]
    #   [0, 0, 0, 0, 0, 1, 0]
    #   [0, 0, 0, 0, 0, 0, 1]
    # linearize the p_choose matrix:
    # [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0...]
    # The indices of linearized matrix that equals 1 is
    # 2 + 7 * 0
    # 3 + 7 * 1
    # ...
    # n + src_len * n + k - 1 = n * (src_len + 1) + k - 1
    # n from 0 to max_tgt_len - 1
    #
    # First, generate the indices (activate_indices_offset: bsz, max_tgt_len)
    # Second, scatter a zeros tensor (bsz, max_tgt_len * src_len)
    # with activate_indices_offset
    # Third, resize the tensor to (bsz, max_tgt_len, src_len)
    activate_indices_offset = (
        (
            torch.arange(max_tgt_len, device=device) * (max_src_len + 1)
            + waitk_lagging - 1
        )
        .unsqueeze(0)
        .expand(bsz, max_tgt_len)
        .long()
    )

    if key_padding_mask is not None and key_padding_mask[:, 0].any():
        # Left padding: shift every row by its number of pad tokens.
        # This must be an out-of-place add: the tensor above is an expanded
        # view whose batch rows share memory, and an in-place update on it
        # raises a RuntimeError for bsz > 1.
        activate_indices_offset = (
            activate_indices_offset
            + key_padding_mask.sum(dim=1, keepdim=True).long()
        )

    # Need to clamp the indices that are too large
    max_index = (
        min(max_tgt_len, max_src_len - waitk_lagging + 1) * max_src_len - 1
    )
    activate_indices_offset = activate_indices_offset.clamp(0, max_index)

    p_choose = torch.zeros(bsz, max_tgt_len * max_src_len, device=device)
    p_choose = p_choose.scatter(
        1,
        activate_indices_offset,
        1.0
    ).view(bsz, max_tgt_len, max_src_len)

    if key_padding_mask is not None:
        p_choose = p_choose.masked_fill(
            key_padding_mask.unsqueeze(1).to(torch.bool), 0
        )

    if incremental_state is not None:
        # Only the newest target step is needed during incremental decoding.
        p_choose = p_choose[:, -1:]

    return p_choose.float()
def learnable_p_choose(
    energy,
    noise_mean: float = 0.0,
    noise_var: float = 0.0,
    training: bool = True
):
    """
    Calculating step wise prob for reading and writing
    1 to read, 0 to write

    energy: bsz, tgt_len, src_len
    noise_mean, noise_var: forwarded to torch.normal as its first and second
        arguments (note that torch.normal interprets the second one as a
        standard deviation); only used when training is True

    Returns sigmoid(energy + noise), same size as energy.
    """
    if training:
        # add noise here to encourage discreteness
        sampled = torch.normal(noise_mean, noise_var, energy.size())
        noise = sampled.type_as(energy).to(energy.device)
    else:
        noise = 0

    # p_choose: bsz * self.num_heads, tgt_len, src_len
    return torch.sigmoid(energy + noise)
### 2021 Update: We are merging this example into the [S2T framework](../speech_to_text), which supports more generic speech-to-text tasks (e.g. speech translation) and more flexible data processing pipelines. Please stay tuned.
# Speech Recognition
`examples/speech_recognition` implements the ASR task in Fairseq, along with the features, datasets, models and loss functions needed to train and run inference with the model described in [Transformers with convolutional context for ASR (Abdelrahman Mohamed et al., 2019)](https://arxiv.org/abs/1904.11660).
## Additional dependencies
On top of the main fairseq dependencies, there are a couple of additional requirements.
1) Please follow the instructions to install [torchaudio](https://github.com/pytorch/audio). This is required to compute audio fbank features.
2) [Sclite](http://www1.icsi.berkeley.edu/Speech/docs/sctk-1.2/sclite.htm#sclite_name_0) is used to measure WER. Sclite can be downloaded and installed from source from the sctk package [here](http://www.openslr.org/4/). Training and inference do not require Sclite.
3) [sentencepiece](https://github.com/google/sentencepiece) is required in order to create dataset with word-piece targets.
## Preparing librispeech data
```
./examples/speech_recognition/datasets/prepare-librispeech.sh $DIR_TO_SAVE_RAW_DATA $DIR_FOR_PREPROCESSED_DATA
```
## Training librispeech data
```
python train.py $DIR_FOR_PREPROCESSED_DATA --save-dir $MODEL_PATH --max-epoch 80 --task speech_recognition --arch vggtransformer_2 --optimizer adadelta --lr 1.0 --adadelta-eps 1e-8 --adadelta-rho 0.95 --clip-norm 10.0 --max-tokens 5000 --log-format json --log-interval 1 --criterion cross_entropy_acc --user-dir examples/speech_recognition/
```
## Inference for librispeech
`$SET` can be `test_clean` or `test_other`
Any checkpoint in `$MODEL_PATH` can be selected. In this example we are working with `checkpoint_last.pt`
```
python examples/speech_recognition/infer.py $DIR_FOR_PREPROCESSED_DATA --task speech_recognition --max-tokens 25000 --nbest 1 --path $MODEL_PATH/checkpoint_last.pt --beam 20 --results-path $RES_DIR --batch-size 40 --gen-subset $SET --user-dir examples/speech_recognition/
```
## Computing WER for librispeech with sclite
```
sclite -r ${RES_DIR}/ref.word-checkpoint_last.pt-${SET}.txt -h ${RES_DIR}/hypo.word-checkpoint_last.pt-${SET}.txt -i rm -o all stdout > $RES_REPORT
```
The `Sum/Avg` row of the first table in the report contains the WER.
## Using flashlight (previously called [wav2letter](https://github.com/facebookresearch/wav2letter)) components
[flashlight](https://github.com/facebookresearch/flashlight) now has integration with fairseq. Currently this includes:
* AutoSegmentationCriterion (ASG)
* flashlight-style Conv/GLU model
* flashlight's beam search decoder
To use these, follow the instructions on [this page](https://github.com/facebookresearch/flashlight/tree/master/bindings/python) to install python bindings.
## Training librispeech data (flashlight style, Conv/GLU + ASG loss)
Training command:
```
python train.py $DIR_FOR_PREPROCESSED_DATA --save-dir $MODEL_PATH --max-epoch 100 --task speech_recognition --arch w2l_conv_glu_enc --batch-size 4 --optimizer sgd --lr 0.3,0.8 --momentum 0.8 --clip-norm 0.2 --max-tokens 50000 --log-format json --log-interval 100 --num-workers 0 --sentence-avg --criterion asg_loss --asg-transitions-init 5 --max-replabel 2 --linseg-updates 8789 --user-dir examples/speech_recognition
```
Note that ASG loss currently doesn't do well with word-pieces. You should prepare a dataset with character targets by setting `nbpe=31` in `prepare-librispeech.sh`.
## Inference for librispeech (flashlight decoder, n-gram LM)
Inference command:
```
python examples/speech_recognition/infer.py $DIR_FOR_PREPROCESSED_DATA --task speech_recognition --seed 1 --nbest 1 --path $MODEL_PATH/checkpoint_last.pt --gen-subset $SET --results-path $RES_DIR --w2l-decoder kenlm --kenlm-model $KENLM_MODEL_PATH --lexicon $LEXICON_PATH --beam 200 --beam-threshold 15 --lm-weight 1.5 --word-score 1.5 --sil-weight -0.3 --criterion asg_loss --max-replabel 2 --user-dir examples/speech_recognition
```
`$KENLM_MODEL_PATH` should be a standard n-gram language model file. `$LEXICON_PATH` should be a flashlight-style lexicon (list of known words and their spellings). For ASG inference, a lexicon line should look like this (note the repetition labels):
```
doorbell D O 1 R B E L 1 ▁
```
For CTC inference with word-pieces, repetition labels are not used and the lexicon should have most common spellings for each word (one can use sentencepiece's `NBestEncodeAsPieces` for this):
```
doorbell ▁DOOR BE LL
doorbell ▁DOOR B E LL
doorbell ▁DO OR BE LL
doorbell ▁DOOR B EL L
doorbell ▁DOOR BE L L
doorbell ▁DO OR B E LL
doorbell ▁DOOR B E L L
doorbell ▁DO OR B EL L
doorbell ▁DO O R BE LL
doorbell ▁DO OR BE L L
```
Lowercase vs. uppercase matters: the *word* should match the case of the n-gram language model (i.e. `$KENLM_MODEL_PATH`), while the *spelling* should match the case of the token dictionary (i.e. `$DIR_FOR_PREPROCESSED_DATA/dict.txt`).
## Inference for librispeech (flashlight decoder, viterbi only)
Inference command:
```
python examples/speech_recognition/infer.py $DIR_FOR_PREPROCESSED_DATA --task speech_recognition --seed 1 --nbest 1 --path $MODEL_PATH/checkpoint_last.pt --gen-subset $SET --results-path $RES_DIR --w2l-decoder viterbi --criterion asg_loss --max-replabel 2 --user-dir examples/speech_recognition
```
from . import criterions, models, tasks # noqa
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from examples.speech_recognition.data.replabels import pack_replabels
from fairseq import utils
from fairseq.criterions import FairseqCriterion, register_criterion
@register_criterion("asg_loss")
class ASGCriterion(FairseqCriterion):
    """Criterion wrapping flashlight's ``ASGLoss`` with optional LinSeg warm-up.

    Targets are converted before being scored: a trailing EOS is dropped or
    replaced by the silence token, repeated tokens are packed into replabels,
    and during the first ``linseg_updates`` training steps each target is
    stretched linearly to the emission length (LinSeg initialization).
    """

    @staticmethod
    def add_args(parser):
        """Register the ASG-specific command-line options on ``parser``."""
        group = parser.add_argument_group("ASG Loss")
        group.add_argument(
            "--asg-transitions-init",
            help="initial diagonal value of transition matrix",
            type=float,
            default=0.0,
        )
        group.add_argument(
            "--max-replabel", help="maximum # of replabels", type=int, default=2
        )
        group.add_argument(
            "--linseg-updates",
            help="# of training updates to use LinSeg initialization",
            type=int,
            default=0,
        )
        group.add_argument(
            "--hide-linseg-messages",
            help="hide messages about LinSeg initialization",
            action="store_true",
        )

    def __init__(
        self,
        task,
        silence_token,
        asg_transitions_init,
        max_replabel,
        linseg_updates,
        hide_linseg_messages,
        sentence_avg=False,
    ):
        """
        Args:
            task: task providing ``target_dictionary``.
            silence_token: symbol looked up in the target dictionary; when it
                is absent, no silence index is used.
            asg_transitions_init: value placed on the diagonal of the initial
                transition matrix.
            max_replabel: maximum number of replabels used when packing
                targets.
            linseg_updates: number of training updates that use LinSeg.
            hide_linseg_messages: suppress the LinSeg start/finish messages.
            sentence_avg: if True, ``forward`` reports the number of
                sentences as the sample size, otherwise the token count.
        """
        # Local import: flashlight is only required once the criterion is
        # actually constructed.
        from flashlight.lib.sequence.criterion import ASGLoss, CriterionScaleMode

        super().__init__(task)
        self.tgt_dict = task.target_dictionary
        self.eos = self.tgt_dict.eos()
        self.silence = (
            self.tgt_dict.index(silence_token)
            if silence_token in self.tgt_dict
            else None
        )
        self.max_replabel = max_replabel
        # Stored on the instance because forward() needs it and no ``args``
        # object is kept by this class.
        self.sentence_avg = sentence_avg

        num_labels = len(self.tgt_dict)
        self.asg = ASGLoss(num_labels, scale_mode=CriterionScaleMode.TARGET_SZ_SQRT)
        self.asg.trans = torch.nn.Parameter(
            asg_transitions_init * torch.eye(num_labels), requires_grad=True
        )

        # Update counter kept as a non-trainable Parameter
        # (presumably so that it is saved with the criterion state - confirm).
        self.linseg_progress = torch.nn.Parameter(
            torch.tensor([0], dtype=torch.int), requires_grad=False
        )
        self.linseg_maximum = linseg_updates
        # Message state machine: "start" -> "finish" -> "none".
        self.linseg_message_state = "none" if hide_linseg_messages else "start"

    @classmethod
    def build_criterion(cls, args, task):
        """Construct the criterion from parsed command-line ``args``."""
        return cls(
            task,
            args.silence_token,
            args.asg_transitions_init,
            args.max_replabel,
            args.linseg_updates,
            args.hide_linseg_messages,
            # Not registered by add_args above, so fall back to False when
            # the namespace does not carry it.
            getattr(args, "sentence_avg", False),
        )

    def linseg_step(self):
        """Advance the LinSeg counter and report whether LinSeg is active.

        Returns True (and increments the counter) only while training and
        while fewer than ``linseg_maximum`` steps have been taken.
        """
        if not self.training:
            return False
        if self.linseg_progress.item() < self.linseg_maximum:
            if self.linseg_message_state == "start":
                print("| using LinSeg to initialize ASG")
                self.linseg_message_state = "finish"
            self.linseg_progress.add_(1)
            return True
        elif self.linseg_message_state == "finish":
            print("| finished LinSeg initialization")
            self.linseg_message_state = "none"
        return False

    def replace_eos_with_silence(self, tgt):
        """Drop a trailing EOS from ``tgt`` (a list), appending silence if needed.

        The list is returned unchanged when it does not end in EOS. Silence
        is appended only when a silence index exists and the token before
        EOS is not already silence.
        """
        if tgt[-1] != self.eos:
            return tgt
        elif self.silence is None or (len(tgt) > 1 and tgt[-2] == self.silence):
            return tgt[:-1]
        else:
            return tgt[:-1] + [self.silence]

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training

        Raises ValueError if any target in the batch has length zero.
        """
        net_output = model(**sample["net_input"])
        # assumes encoder_out is time-major, giving (B, T, ...) after the
        # transpose - TODO confirm against the model
        emissions = net_output["encoder_out"].transpose(0, 1).contiguous()
        B = emissions.size(0)
        T = emissions.size(1)
        device = emissions.device

        # Uninitialized buffers; only the first target_size[b] entries of
        # each row are written below.
        target = torch.IntTensor(B, T)
        target_size = torch.IntTensor(B)
        using_linseg = self.linseg_step()

        for b in range(B):
            initial_target_size = sample["target_lengths"][b].item()
            if initial_target_size == 0:
                raise ValueError("target size cannot be zero")

            tgt = sample["target"][b, :initial_target_size].tolist()
            tgt = self.replace_eos_with_silence(tgt)
            tgt = pack_replabels(tgt, self.tgt_dict, self.max_replabel)
            # The target cannot be longer than the emission sequence.
            tgt = tgt[:T]

            if using_linseg:
                # Stretch the target linearly so that it spans all T frames.
                tgt = [tgt[t * len(tgt) // T] for t in range(T)]

            target[b][: len(tgt)] = torch.IntTensor(tgt)
            target_size[b] = len(tgt)

        loss = self.asg.forward(emissions, target.to(device), target_size.to(device))

        if reduce:
            loss = torch.sum(loss)

        sample_size = (
            sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
        )
        logging_output = {
            "loss": utils.item(loss.data) if reduce else loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
        }
        return loss, sample_size, logging_output

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
        agg_output = {
            # Per-sentence loss; guarded so an empty aggregation does not
            # raise ZeroDivisionError.
            "loss": loss_sum / nsentences if nsentences > 0 else 0.0,
            "ntokens": ntokens,
            "nsentences": nsentences,
            "sample_size": sample_size,
        }
        return agg_output
import importlib
import os
# The ASG criterion needs the optional flashlight bindings; when they cannot
# be imported, its module is left out of the automatic import below.
files_to_skip = set()
try:
    import flashlight.lib.sequence.criterion
except ImportError:
    files_to_skip.add("ASG_loss.py")

# Import every public module of this package (presumably so that their
# @register_criterion decorators run).
for file in sorted(os.listdir(os.path.dirname(__file__))):
    if not file.endswith(".py"):
        continue
    if file.startswith("_") or file in files_to_skip:
        continue
    # Module name is everything before the first ".py".
    criterion_name = file.partition(".py")[0]
    importlib.import_module(
        "examples.speech_recognition.criterions." + criterion_name
    )
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import absolute_import, division, print_function, unicode_literals
import logging
import math
import torch
import torch.nn.functional as F
from fairseq import utils
from fairseq.criterions import FairseqCriterion, register_criterion
@register_criterion("cross_entropy_acc")
class CrossEntropyWithAccCriterion(FairseqCriterion):
    """Cross-entropy criterion that also reports token-level accuracy."""

    def __init__(self, task, sentence_avg):
        super().__init__(task)
        # True: sample size is the number of sentences; False: token count.
        self.sentence_avg = sentence_avg

    def compute_loss(self, model, net_output, target, reduction, log_probs):
        """Return ``(flattened_probs, nll_loss)`` for one batch.

        The model output is made batch-first (if it declares otherwise via a
        ``batch_first`` attribute) and flattened to ``(N * T, D)``; the
        target is flattened to ``(N * T,)``. Padding positions are ignored.
        """
        # N, T -> N * T
        flat_target = target.view(-1)
        lprobs = model.get_normalized_probs(net_output, log_probs=log_probs)
        if not hasattr(lprobs, "batch_first"):
            logging.warning(
                "ERROR: we need to know whether "
                "batch first for the net output; "
                "you need to set batch_first attribute for the return value of "
                "model.get_normalized_probs. Now, we assume this is true, but "
                "in the future, we will raise exception instead. "
            )
        # Missing attribute is treated as batch-first.
        if not getattr(lprobs, "batch_first", True):
            lprobs = lprobs.transpose(0, 1)

        # N, T, D -> N * T, D
        flat_lprobs = lprobs.view(-1, lprobs.size(-1))
        loss = F.nll_loss(
            flat_lprobs,
            flat_target,
            ignore_index=self.padding_idx,
            reduction=reduction,
        )
        return flat_lprobs, loss

    def get_logging_output(self, sample, target, lprobs, loss):
        """Return ``(sample_size, logging_output)`` with accuracy counters.

        ``correct`` and ``total`` are counted over non-padding target
        positions only.
        """
        flat_target = target.view(-1)
        valid = flat_target != self.padding_idx
        predictions = lprobs.argmax(1)
        correct = torch.sum(
            predictions.masked_select(valid) == flat_target.masked_select(valid)
        )
        total = torch.sum(valid)

        if self.sentence_avg:
            sample_size = sample["target"].size(0)
        else:
            sample_size = sample["ntokens"]

        logging_output = {
            "loss": utils.item(loss.data),  # * sample['ntokens'],
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
            "correct": utils.item(correct.data),
            "total": utils.item(total.data),
            "nframes": torch.sum(sample["net_input"]["src_lengths"]).item(),
        }
        return sample_size, logging_output

    def forward(self, model, sample, reduction="sum", log_probs=True):
        """Computes the cross entropy with accuracy metric for the given sample.

        This is similar to CrossEntropyCriterion in fairseq, but also
        computes accuracy metrics as part of logging

        Args:
            model: called with ``sample["net_input"]``; must also provide
                ``get_targets`` and ``get_normalized_probs``
            sample: batch dictionary
            reduction: forwarded to ``F.nll_loss``
            log_probs: forwarded to ``model.get_normalized_probs``

        Returns:
            tuple: With three elements:
                1) the loss
                2) the sample size, which is used as the denominator for the gradient
                3) logging outputs to display while training

        TODO:
            * Currently this Criterion will only work with LSTMEncoderModels or
            FairseqModels which have decoder, or Models which return TorchTensor
            as net_output.
            We need to make a change to support all FairseqEncoder models.
        """
        net_output = model(**sample["net_input"])
        target = model.get_targets(sample, net_output)
        lprobs, loss = self.compute_loss(
            model, net_output, target, reduction, log_probs
        )
        sample_size, logging_output = self.get_logging_output(
            sample, target, lprobs, loss
        )
        return loss, sample_size, logging_output

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        """Aggregate logging outputs from data parallel training."""

        def _total(key):
            return sum(log.get(key, 0) for log in logging_outputs)

        correct_sum = _total("correct")
        total_sum = _total("total")
        loss_sum = _total("loss")
        ntokens = _total("ntokens")
        nsentences = _total("nsentences")
        sample_size = _total("sample_size")
        nframes = _total("nframes")

        # "loss" is normalized by sample_size and converted to base 2: with
        # sentence averaging sample_size is the sentence count, otherwise it
        # is the token count.
        if sample_size > 0:
            avg_loss = loss_sum / sample_size / math.log(2)
        else:
            avg_loss = 0.0
        # "total" is the number of non-padding target tokens scored.
        accuracy = correct_sum * 100.0 / total_sum if total_sum > 0 else 0.0

        agg_output = {
            "loss": avg_loss,
            "ntokens": ntokens,
            "nsentences": nsentences,
            "nframes": nframes,
            "sample_size": sample_size,
            "acc": accuracy,
            "correct": correct_sum,
            "total": total_sum,
        }
        if sample_size != ntokens:
            # Additionally report the per-token loss when "loss" is not
            # already normalized by the token count.
            agg_output["nll_loss"] = loss_sum / ntokens / math.log(2)
        return agg_output
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .asr_dataset import AsrDataset
__all__ = [
"AsrDataset",
]
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import numpy as np
from fairseq.data import FairseqDataset
from . import data_utils
from .collaters import Seq2SeqCollater
class AsrDataset(FairseqDataset):
    """
    A dataset representing speech and corresponding transcription.

    Args:
        aud_paths: (List[str]): A list of str with paths to audio files.
        aud_durations_ms (List[int]): A list of int containing the durations of
            audio files.
        tgt (List[torch.LongTensor]): A list of LongTensors containing the indices
            of target transcriptions. May be None for a dataset without
            transcriptions.
        tgt_dict (~fairseq.data.Dictionary): target vocabulary.
        ids (List[str]): A list of utterance IDs.
        speakers (List[str]): A list of speakers corresponding to utterances.
        num_mel_bins (int): Number of triangular mel-frequency bins (default: 80)
        frame_length (float): Frame length in milliseconds (default: 25.0)
        frame_shift (float): Frame shift in milliseconds (default: 10.0)
    """

    def __init__(
        self,
        aud_paths,
        aud_durations_ms,
        tgt,
        tgt_dict,
        ids,
        speakers,
        num_mel_bins=80,
        frame_length=25.0,
        frame_shift=10.0,
    ):
        assert frame_length > 0
        assert frame_shift > 0
        assert all(x > frame_length for x in aud_durations_ms)
        # Number of frames per utterance, derived from its duration.
        self.frame_sizes = [
            int(1 + (d - frame_length) / frame_shift) for d in aud_durations_ms
        ]

        assert len(aud_paths) > 0
        assert len(aud_paths) == len(aud_durations_ms)
        # __getitem__ and size() accept tgt=None, so only check the length
        # when targets were actually supplied.
        if tgt is not None:
            assert len(aud_paths) == len(tgt)
        assert len(aud_paths) == len(ids)
        assert len(aud_paths) == len(speakers)
        self.aud_paths = aud_paths
        self.tgt_dict = tgt_dict
        self.tgt = tgt
        self.ids = ids
        self.speakers = speakers
        self.num_mel_bins = num_mel_bins
        self.frame_length = frame_length
        self.frame_shift = frame_shift

        # Features are element 0 and targets element 1 of each sample's
        # "data" list (see __getitem__).
        self.s2s_collater = Seq2SeqCollater(
            0,
            1,
            pad_index=self.tgt_dict.pad(),
            eos_index=self.tgt_dict.eos(),
            move_eos_to_beginning=True,
        )

    def __getitem__(self, index):
        """Load one utterance and return its normalized fbank features.

        Returns a dict ``{"id": index, "data": [features, target_or_None]}``.

        Raises FileNotFoundError if the audio file does not exist.
        """
        # Imported here so that torchaudio is only needed when audio is read.
        import torchaudio
        import torchaudio.compliance.kaldi as kaldi

        tgt_item = self.tgt[index] if self.tgt is not None else None

        path = self.aud_paths[index]
        if not os.path.exists(path):
            raise FileNotFoundError("Audio file not found: {}".format(path))
        sound, sample_rate = torchaudio.load_wav(path)
        output = kaldi.fbank(
            sound,
            num_mel_bins=self.num_mel_bins,
            frame_length=self.frame_length,
            frame_shift=self.frame_shift,
        )
        # Mean/variance normalization of the features.
        output_cmvn = data_utils.apply_mv_norm(output)

        return {"id": index, "data": [output_cmvn.detach(), tgt_item]}

    def __len__(self):
        return len(self.aud_paths)

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): items as returned by ``__getitem__``

        Returns:
            dict: a mini-batch suitable for forwarding with a Model
        """
        return self.s2s_collater.collate(samples)

    def num_tokens(self, index):
        """Return the number of frames of the utterance at ``index``."""
        return self.frame_sizes[index]

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return (
            self.frame_sizes[index],
            len(self.tgt[index]) if self.tgt is not None else 0,
        )

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        return np.arange(len(self))
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
This module contains collection of classes which implement
collate functionalities for various tasks.
Collaters should know what data to expect for each sample
and they should pack / collate them into batches
"""
from __future__ import absolute_import, division, print_function, unicode_literals
import numpy as np
import torch
from fairseq.data import data_utils as fairseq_data_utils
class Seq2SeqCollater(object):
    """
    Implements collate function mainly for seq2seq tasks
    This expects each sample to contain feature (src_tokens) and
    targets.
    This collator is also used for aligned training task.
    """

    def __init__(
        self,
        feature_index=0,
        label_index=1,
        pad_index=1,
        eos_index=2,
        move_eos_to_beginning=True,
    ):
        # Positions of the feature and the label inside each sample's
        # "data" list.
        self.feature_index = feature_index
        self.label_index = label_index
        self.pad_index = pad_index
        self.eos_index = eos_index
        self.move_eos_to_beginning = move_eos_to_beginning

    def _collate_frames(self, frames):
        """Convert a list of 2d frames into a padded 3d tensor
        Args:
            frames (list): list of 2d frames of size L[i]*f_dim. Where L[i] is
                length of i-th frame and f_dim is static dimension of features
        Returns:
            3d tensor of size len(frames)*len_max*f_dim where len_max is max of L[i]
        """
        len_max = max(frame.size(0) for frame in frames)
        f_dim = frames[0].size(1)
        # Zero-filled buffer with the dtype/device of the first frame.
        res = frames[0].new(len(frames), len_max, f_dim).fill_(0.0)

        for i, v in enumerate(frames):
            res[i, : v.size(0)] = v

        return res

    def collate(self, samples):
        """
        utility function to collate samples into batch for speech recognition.

        Args:
            samples (list): dicts with an "id" and a "data" list holding the
                feature at ``feature_index`` and the label at ``label_index``.
                Samples whose feature is None are skipped.

        Returns:
            dict: the batch, sorted by descending number of frames, or an
            empty dict when there is no valid sample.
        """
        if len(samples) == 0:
            return {}

        # parse samples into torch tensors
        parsed_samples = []
        for s in samples:
            # skip invalid samples
            if s["data"][self.feature_index] is None:
                continue
            source = s["data"][self.feature_index]
            if isinstance(source, (np.ndarray, np.generic)):
                source = torch.from_numpy(source)
            target = s["data"][self.label_index]
            if isinstance(target, (np.ndarray, np.generic)):
                target = torch.from_numpy(target).long()
            elif isinstance(target, list):
                target = torch.LongTensor(target)

            parsed_sample = {"id": s["id"], "source": source, "target": target}
            parsed_samples.append(parsed_sample)
        samples = parsed_samples

        # Every sample was invalid: behave like an empty input instead of
        # failing on max() over an empty sequence.
        if len(samples) == 0:
            return {}

        ids = torch.LongTensor([s["id"] for s in samples])
        frames = self._collate_frames([s["source"] for s in samples])
        # sort samples by descending number of frames
        frames_lengths = torch.LongTensor([s["source"].size(0) for s in samples])
        frames_lengths, sort_order = frames_lengths.sort(descending=True)
        ids = ids.index_select(0, sort_order)
        frames = frames.index_select(0, sort_order)

        target = None
        target_lengths = None
        prev_output_tokens = None
        # Whether the batch has targets is decided from the first sample.
        if samples[0].get("target", None) is not None:
            ntokens = sum(len(s["target"]) for s in samples)
            target = fairseq_data_utils.collate_tokens(
                [s["target"] for s in samples],
                self.pad_index,
                self.eos_index,
                left_pad=False,
                move_eos_to_beginning=False,
            )
            target = target.index_select(0, sort_order)
            target_lengths = torch.LongTensor(
                [s["target"].size(0) for s in samples]
            ).index_select(0, sort_order)
            prev_output_tokens = fairseq_data_utils.collate_tokens(
                [s["target"] for s in samples],
                self.pad_index,
                self.eos_index,
                left_pad=False,
                move_eos_to_beginning=self.move_eos_to_beginning,
            )
            prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
        else:
            # Without targets, the token count is the total number of frames.
            ntokens = sum(len(s["source"]) for s in samples)

        batch = {
            "id": ids,
            "ntokens": ntokens,
            "net_input": {"src_tokens": frames, "src_lengths": frames_lengths},
            "target": target,
            "target_lengths": target_lengths,
            "nsentences": len(samples),
        }
        if prev_output_tokens is not None:
            batch["net_input"]["prev_output_tokens"] = prev_output_tokens
        return batch
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
def calc_mean_invstddev(feature):
    """Return the per-column mean and inverse standard deviation of a 2-D tensor.

    Statistics are taken over dimension 0. If any column's variance is below
    1e-8, the epsilon is added to every column's standard deviation before
    inverting.

    Raises:
        ValueError: if ``feature`` is not 2-dimensional.
    """
    if len(feature.size()) != 2:
        raise ValueError("We expect the input feature to be 2-D tensor")

    mean = feature.mean(0)
    var = feature.var(0)
    stddev = torch.sqrt(var)

    # avoid division by ~zero
    eps = 1e-8
    if (var < eps).any():
        stddev = stddev + eps
    return mean, 1.0 / stddev
def apply_mv_norm(features):
    """Mean/variance-normalize ``features`` along dimension 0.

    With fewer than two rows the variance cannot be computed (it is NaN), so
    the input is returned unchanged.
    """
    if features.size(0) >= 2:
        mean, invstddev = calc_mean_invstddev(features)
        return (features - mean) * invstddev
    return features
def lengths_to_encoder_padding_mask(lengths, batch_first=False):
    """
    Convert lengths (a 1-D Long/Int tensor) to a 2-D binary padding mask.

    Args:
        lengths: a (B, )-shaped tensor
        batch_first: if True the mask is (B, max_length), otherwise it is
            transposed to (max_length, B)

    Return:
        encoder_padding_mask: binary mask where, for sequence b and time t,
            the entry is 0 for t < lengths[b] and 1 otherwise
        max_length: maximum length of B sequences

    TODO:
        kernelize this function if benchmarking shows this function is slow
    """
    max_lengths = torch.max(lengths).item()
    bsz = lengths.size(0)

    # (1, T) row holding [0, ..., T-1] on the same device as lengths
    positions = torch.arange(max_lengths, device=lengths.device).view(
        1, max_lengths
    )
    # Broadcast against the (B, 1) column of lengths -> (B, T) mask.
    encoder_padding_mask = positions >= lengths.view(bsz, 1)

    if batch_first:
        return encoder_padding_mask, max_lengths
    return encoder_padding_mask.t(), max_lengths
def encoder_padding_mask_to_lengths(
    encoder_padding_mask, max_lengths, batch_size, device
):
    """
    Convert encoder_padding_mask (2-D binary tensor) to a 1-D tensor of
    sequence lengths.

    Conventionally, encoder output contains an encoder_padding_mask, which is
    a 2-D mask in a shape (T, B), whose (t, b) element indicates whether
    encoder_out[t, b] is a valid output (=0) or not (=1). Occasionally, we
    need to convert this mask tensor to a 1-D tensor in shape (B, ), where
    [b] denotes the valid length of b-th sequence

    Args:
        encoder_padding_mask: a (T, B)-shaped binary tensor or None; if None,
            indicating all are valid
        max_lengths: maximum length of all sequences; if encoder_padding_mask
            is not None, it must equal encoder_padding_mask.size(0)
        batch_size: batch size; if encoder_padding_mask is not None, it must
            equal encoder_padding_mask.size(1)
        device: which device to put the result on (only used when
            encoder_padding_mask is None)

    Return:
        seq_lengths: a (B,)-shaped tensor, where its (b, )-th element is the
            number of valid elements of b-th sequence
    """
    if encoder_padding_mask is not None:
        assert encoder_padding_mask.size(0) == max_lengths, "max_lengths does not match"
        assert encoder_padding_mask.size(1) == batch_size, "batch_size does not match"
        # Valid length = total length minus the number of padded steps.
        num_padded = torch.sum(encoder_padding_mask, dim=0)
        return max_lengths - num_padded

    # No mask: every sequence has the full length.
    full_lengths = torch.Tensor([max_lengths] * batch_size)
    return full_lengths.to(torch.int32).to(device)
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Replabel transforms for use with flashlight's ASG criterion.
"""
def replabel_symbol(i):
    """
    Return the replabel symbol for repetition count ``i``.

    Flashlight replabels are currently just the decimal strings "1", "2", ...
    This prevents training with numeral tokens, so this might change in the
    future.
    """
    symbol = f"{i}"
    return symbol
def pack_replabels(tokens, dictionary, max_reps):
    """
    Pack a token sequence so that repeated symbols are replaced by replabels.

    A run of identical tokens is emitted as the token followed by the
    dictionary index of the replabel for the number of extra repeats (at most
    ``max_reps`` per replabel; longer runs start a new token + replabel pair).
    The input is returned unchanged when it is empty or ``max_reps <= 0``.
    """
    if len(tokens) == 0 or max_reps <= 0:
        return tokens

    # idx_for_count[k] is the dictionary index of the replabel meaning
    # "k extra repeats"; slot 0 is an unused placeholder.
    idx_for_count = [0] + [
        dictionary.index(replabel_symbol(k)) for k in range(1, max_reps + 1)
    ]

    packed = []
    last = -1
    run = 0
    for tok in tokens:
        if tok == last and run < max_reps:
            run += 1
            continue
        # The run ended (or hit the cap): flush its replabel, then start anew.
        if run > 0:
            packed.append(idx_for_count[run])
            run = 0
        packed.append(tok)
        last = tok
    if run > 0:
        packed.append(idx_for_count[run])
    return packed
def unpack_replabels(tokens, dictionary, max_reps):
    """
    Unpack a token sequence so that replabels are replaced by repeated symbols.

    Each replabel index is expanded into that many copies of the token that
    preceded it. The input is returned unchanged when it is empty or
    ``max_reps <= 0``.
    """
    if len(tokens) == 0 or max_reps <= 0:
        return tokens

    # Maps the dictionary index of each replabel to its repetition count.
    count_for_idx = {
        dictionary.index(replabel_symbol(k)): k for k in range(1, max_reps + 1)
    }

    unpacked = []
    last = -1
    for tok in tokens:
        reps = count_for_idx.get(tok)
        if reps is None:
            # Ordinary token: copy it through and remember it.
            unpacked.append(tok)
            last = tok
        else:
            unpacked.extend([last] * reps)
            # A replabel consumes the remembered token; a second replabel in a
            # row would expand to -1.
            last = -1
    return unpacked
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import absolute_import, division, print_function, unicode_literals
import argparse
import concurrent.futures
import json
import multiprocessing
import os
from collections import namedtuple
from itertools import chain
import sentencepiece as spm
from fairseq.data import Dictionary
# Factor that converts milliseconds to seconds; dividing a duration in seconds
# by it yields the duration in milliseconds.
MILLISECONDS_TO_SECONDS = 0.001
def process_sample(aud_path, lable, utt_id, sp, tgt_dict):
    """
    Build the JSON entry for one utterance.

    Args:
        aud_path: path of the audio file; its metadata is read via torchaudio
        lable: transcript text of the utterance
        utt_id: utterance id used as the key of the returned dict
        sp: object exposing ``EncodeAsPieces`` (a sentencepiece processor)
        tgt_dict: object exposing ``encode_line`` (a fairseq dictionary)

    Returns:
        ``{utt_id: {"input": {...}, "output": {...}}}`` with the audio length
        in milliseconds and path, and the text, pieces and piece ids.
    """
    import torchaudio

    signal_info, _ = torchaudio.info(aud_path)
    # samples / channels / rate gives seconds; convert to whole milliseconds.
    duration_ms = int(
        signal_info.length
        / signal_info.channels
        / signal_info.rate
        / MILLISECONDS_TO_SECONDS
    )

    pieces = " ".join(sp.EncodeAsPieces(lable))
    piece_ids = tgt_dict.encode_line(pieces, append_eos=False)

    audio_entry = {"length_ms": duration_ms, "path": aud_path}
    text_entry = {
        "text": lable,
        "token": pieces,
        "tokenid": ", ".join(str(t.tolist()) for t in piece_ids),
    }
    return {utt_id: {"input": audio_entry, "output": text_entry}}
def main():
    """
    Build a JSON manifest of audio files and their tokenized transcripts.

    Reads ``<ID LABEL>`` lines from ``--labels``, walks ``--audio-dirs`` for
    files ending in ``--audio-format`` whose stem is a known id, processes
    each one with ``process_sample`` on a thread pool, and dumps the merged
    result as ``{"utts": {...}}`` to ``--output``.

    Raises:
        Exception: if the labels file contains no entries.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--audio-dirs",
        nargs="+",
        required=True,
        help="input directories with audio files",
    )
    parser.add_argument(
        "--labels",
        required=True,
        help="aggregated input labels with format <ID LABEL> per line",
        type=argparse.FileType("r", encoding="UTF-8"),
    )
    parser.add_argument(
        "--spm-model",
        required=True,
        help="sentencepiece model to use for encoding",
        type=argparse.FileType("r", encoding="UTF-8"),
    )
    parser.add_argument(
        "--dictionary",
        required=True,
        help="file to load fairseq dictionary from",
        type=argparse.FileType("r", encoding="UTF-8"),
    )
    parser.add_argument("--audio-format", choices=["flac", "wav"], default="wav")
    parser.add_argument(
        "--output",
        required=True,
        type=argparse.FileType("w"),
        help="path to save json output",
    )

    args = parser.parse_args()

    sp = spm.SentencePieceProcessor()
    # Only the file name is needed; the model is loaded by sentencepiece itself.
    sp.Load(args.spm_model.name)

    tgt_dict = Dictionary.load(args.dictionary)

    labels = {}
    for line in args.labels:
        (utt_id, label) = line.split(" ", 1)
        labels[utt_id] = label
    if not labels:
        # The parsed attribute is ``labels`` (an open file), not ``labels_path``.
        raise Exception("No labels found in ", args.labels.name)

    Sample = namedtuple("Sample", "aud_path utt_id")
    samples = []
    for path, _, files in chain.from_iterable(
        os.walk(path) for path in args.audio_dirs
    ):
        for f in files:
            if f.endswith(args.audio_format):
                utt_id = os.path.splitext(f)[0]
                # Audio files without a transcript are silently skipped.
                if utt_id not in labels:
                    continue
                samples.append(Sample(os.path.join(path, f), utt_id))

    utts = {}
    num_cpu = multiprocessing.cpu_count()
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_cpu) as executor:
        future_to_sample = {
            executor.submit(
                process_sample, s.aud_path, labels[s.utt_id], s.utt_id, sp, tgt_dict
            ): s
            for s in samples
        }
        for future in concurrent.futures.as_completed(future_to_sample):
            try:
                data = future.result()
            except Exception as exc:
                # Best effort: report the failing sample and keep going.
                print("generated an exception: ", exc)
            else:
                utts.update(data)
    json.dump({"utts": utts}, args.output, indent=4)


if __name__ == "__main__":
    main()
#!/usr/bin/env bash
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Prepare librispeech dataset
#
# Usage: <script> <download_dir> <out_dir>
# Downloads and unpacks the LibriSpeech parts, merges the three training
# packs, trains a sentencepiece model on the training transcripts, builds the
# dictionaries and writes train/valid/test json manifests into <out_dir>.
# All path expansions are quoted so directories containing spaces work.

base_url=www.openslr.org/resources/12
train_dir=train_960

if [ "$#" -ne 2 ]; then
  echo "Usage: $0 <download_dir> <out_dir>"
  echo "e.g.: $0 /tmp/librispeech_raw/ ~/data/librispeech_final"
  exit 1
fi

# Strip a single trailing slash from both directory arguments.
download_dir=${1%/}
out_dir=${2%/}

fairseq_root=~/fairseq-py/
mkdir -p "${out_dir}"
cd "${out_dir}" || exit 1

nbpe=5000
bpemode=unigram

if [ ! -d "$fairseq_root" ]; then
  echo "$0: Please set correct fairseq_root"
  exit 1
fi

echo "Data Download"
for part in dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500; do
  url="${base_url}/${part}.tar.gz"
  if ! wget -P "${download_dir}" "${url}"; then
    echo "$0: wget failed for $url"
    exit 1
  fi
  if ! tar -C "${download_dir}" -xvzf "${download_dir}/${part}.tar.gz"; then
    echo "$0: error un-tarring archive $download_dir/$part.tar.gz"
    exit 1
  fi
done

echo "Merge all train packs into one"
mkdir -p "${download_dir}/LibriSpeech/${train_dir}/"
for part in train-clean-100 train-clean-360 train-other-500; do
  # The glob must stay outside the quotes so it is expanded.
  mv "${download_dir}/LibriSpeech/${part}"/* "${download_dir}/LibriSpeech/${train_dir}/"
done
echo "Merge train text"
find "${download_dir}/LibriSpeech/${train_dir}/" -name '*.txt' -exec cat {} \; >> "${download_dir}/LibriSpeech/${train_dir}/text"

# Use combined dev-clean and dev-other as validation set
find "${download_dir}/LibriSpeech/dev-clean/" "${download_dir}/LibriSpeech/dev-other/" -name '*.txt' -exec cat {} \; >> "${download_dir}/LibriSpeech/valid_text"
find "${download_dir}/LibriSpeech/test-clean/" -name '*.txt' -exec cat {} \; >> "${download_dir}/LibriSpeech/test-clean/text"
find "${download_dir}/LibriSpeech/test-other/" -name '*.txt' -exec cat {} \; >> "${download_dir}/LibriSpeech/test-other/text"

dict=data/lang_char/${train_dir}_${bpemode}${nbpe}_units.txt
encoded=data/lang_char/${train_dir}_${bpemode}${nbpe}_encoded.txt
fairseq_dict=data/lang_char/${train_dir}_${bpemode}${nbpe}_fairseq_dict.txt
bpemodel=data/lang_char/${train_dir}_${bpemode}${nbpe}
echo "dictionary: ${dict}"
echo "Dictionary preparation"
mkdir -p data/lang_char/
# Special symbols use the same ids that are passed to spm_train below.
echo "<unk> 3" > "${dict}"
echo "</s> 2" >> "${dict}"
echo "<pad> 1" >> "${dict}"
# Drop the utterance id (first field) to obtain raw transcripts.
cut -f 2- -d" " "${download_dir}/LibriSpeech/${train_dir}/text" > data/lang_char/input.txt
spm_train --input=data/lang_char/input.txt --vocab_size="${nbpe}" --model_type="${bpemode}" --model_prefix="${bpemodel}" --input_sentence_size=100000000 --unk_id=3 --eos_id=2 --pad_id=1 --bos_id=-1 --character_coverage=1
spm_encode --model="${bpemodel}.model" --output_format=piece < data/lang_char/input.txt > "${encoded}"
# Unit dictionary: "<piece> <id>", ids continue after the three special symbols.
tr ' ' '\n' < "${encoded}" | sort | uniq | awk '{print $0 " " NR+3}' >> "${dict}"
# fairseq dictionary: "<piece> <count>".
tr ' ' '\n' < "${encoded}" | sort | uniq -c | awk '{print $2 " " $1}' > "${fairseq_dict}"
wc -l "${dict}"

echo "Prepare train and test jsons"
for part in train_960 test-other test-clean; do
  python "${fairseq_root}/examples/speech_recognition/datasets/asr_prep_json.py" --audio-dirs "${download_dir}/LibriSpeech/${part}" --labels "${download_dir}/LibriSpeech/${part}/text" --spm-model "${bpemodel}.model" --audio-format flac --dictionary "${fairseq_dict}" --output "${part}.json"
done
# fairseq expects to find train.json and valid.json during training
mv train_960.json train.json

echo "Prepare valid json"
python "${fairseq_root}/examples/speech_recognition/datasets/asr_prep_json.py" --audio-dirs "${download_dir}/LibriSpeech/dev-clean" "${download_dir}/LibriSpeech/dev-other" --labels "${download_dir}/LibriSpeech/valid_text" --spm-model "${bpemodel}.model" --audio-format flac --dictionary "${fairseq_dict}" --output valid.json

cp "${fairseq_dict}" ./dict.txt
cp "${bpemodel}.model" ./spm.model
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Run inference for pre-processed data with a trained model.
"""
import ast
import logging
import math
import os
import sys
import editdistance
import numpy as np
import torch
from fairseq import checkpoint_utils, options, progress_bar, tasks, utils
from fairseq.data.data_utils import post_process
from fairseq.logging.meters import StopwatchMeter, TimeMeter
# Install the default root handler and make INFO the root log level.
logging.basicConfig()
logging.root.setLevel(logging.INFO)
# NOTE(review): basicConfig() does nothing once the root logger already has
# handlers, so this second call looks redundant after the two lines above.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by the functions in this script.
logger = logging.getLogger(__name__)
def add_asr_eval_argument(parser):
    """
    Register the ASR evaluation / decoding command-line options on ``parser``.

    Args:
        parser: an ``argparse.ArgumentParser`` (or compatible object) to extend

    Returns:
        the same ``parser`` object
    """
    # Local import: only needed to name the conflict exception below.
    import argparse

    parser.add_argument("--kspmodel", default=None, help="sentence piece model")
    parser.add_argument(
        "--wfstlm", default=None, help="wfstlm on dictonary output units"
    )
    parser.add_argument(
        "--rnnt_decoding_type",
        default="greedy",
        help="rnnt decoding type",
    )
    try:
        parser.add_argument(
            "--lm-weight",
            "--lm_weight",
            type=float,
            default=0.2,
            help="weight for lm while interpolating with neural score",
        )
    except argparse.ArgumentError:
        # The option is already registered on this parser; keep the existing
        # definition instead of failing.
        pass
    parser.add_argument(
        "--rnnt_len_penalty",
        type=float,  # without this a value given on the CLI stays a string
        default=-0.5,
        help="rnnt length penalty on word level",
    )
    parser.add_argument(
        "--w2l-decoder",
        choices=["viterbi", "kenlm", "fairseqlm"],
        help="use a w2l decoder",
    )
    parser.add_argument("--lexicon", help="lexicon for w2l decoder")
    parser.add_argument("--unit-lm", action="store_true", help="if using a unit lm")
    parser.add_argument("--kenlm-model", "--lm-model", help="lm model for w2l decoder")
    parser.add_argument("--beam-threshold", type=float, default=25.0)
    parser.add_argument("--beam-size-token", type=float, default=100)
    parser.add_argument("--word-score", type=float, default=1.0)
    parser.add_argument("--unk-weight", type=float, default=-math.inf)
    parser.add_argument("--sil-weight", type=float, default=0.0)
    parser.add_argument(
        "--dump-emissions",
        type=str,
        default=None,
        help="if present, dumps emissions into this file and exits",
    )
    parser.add_argument(
        "--dump-features",
        type=str,
        default=None,
        help="if present, dumps features into this file and exits",
    )
    parser.add_argument(
        "--load-emissions",
        type=str,
        default=None,
        help="if present, loads emissions from this file",
    )
    return parser
def check_args(args):
    """
    Validate option combinations before generation.

    Raises:
        AssertionError: if ``--sampling`` is set while ``nbest != beam``, or
            if ``--replace-unk`` is set without ``--raw-text``.
    """
    # assert args.path is not None, "--path required for generation!"
    # assert args.results_path is not None, "--results_path required for generation!"
    sampling_ok = (not args.sampling) or (args.nbest == args.beam)
    assert sampling_ok, "--sampling requires --nbest to be equal to --beam"

    replace_unk_ok = (args.replace_unk is None) or args.raw_text
    assert replace_unk_ok, "--replace-unk requires a raw text dataset (--raw-text)"
def get_dataset_itr(args, task, models):
    """
    Return an unshuffled epoch iterator over ``args.gen_subset``.

    The batch iterator is built by ``task.get_batch_iterator`` with the
    batching / sharding options taken from ``args``. ``models`` is accepted
    but not used here.
    """
    batch_iterator = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.batch_size,
        # No length limit is applied to either side.
        max_positions=(sys.maxsize, sys.maxsize),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
        data_buffer_size=args.data_buffer_size,
    )
    return batch_iterator.next_epoch_itr(shuffle=False)
def process_predictions(
    args, hypos, sp, tgt_dict, target_tokens, res_files, speaker, id
):
    """
    Write hypothesis / reference lines and score one utterance.

    Args:
        args: options providing ``nbest``, ``post_process`` and ``quiet``
        hypos: list of hypothesis dicts with "tokens" and optionally "words"
        sp: unused
        tgt_dict: dictionary used to turn token ids into a string
        target_tokens: reference token ids
        res_files: dict of open result files, or None to skip writing
        speaker, id: identifiers appended to every written line

    Returns:
        ``(word_edit_distance, number_of_reference_words)``.
    """
    line_template = "{} ({}-{})"

    for hypo in hypos[: min(len(hypos), args.nbest)]:
        hyp_pieces = tgt_dict.string(hypo["tokens"].int().cpu())

        hyp_words = (
            " ".join(hypo["words"])
            if "words" in hypo
            else post_process(hyp_pieces, args.post_process)
        )

        if res_files is not None:
            print(
                line_template.format(hyp_pieces, speaker, id),
                file=res_files["hypo.units"],
            )
            print(
                line_template.format(hyp_words, speaker, id),
                file=res_files["hypo.words"],
            )

        tgt_pieces = tgt_dict.string(target_tokens)
        tgt_words = post_process(tgt_pieces, args.post_process)

        if res_files is not None:
            print(
                line_template.format(tgt_pieces, speaker, id),
                file=res_files["ref.units"],
            )
            print(
                line_template.format(tgt_words, speaker, id),
                file=res_files["ref.words"],
            )

        if not args.quiet:
            logger.info("HYPO:" + hyp_words)
            logger.info("TARGET:" + tgt_words)
            logger.info("___________________")

        hyp_word_list = hyp_words.split()
        tgt_word_list = tgt_words.split()
        # NOTE(review): the return sits inside the loop, so only the first
        # hypothesis is written and scored - confirm this is intended.
        return editdistance.eval(hyp_word_list, tgt_word_list), len(tgt_word_list)
def prepare_result_files(args):
    """
    Open the four result files under ``args.results_path``.

    Returns:
        None when ``args.results_path`` is empty; otherwise a dict with keys
        "hypo.words", "hypo.units", "ref.words" and "ref.units" mapping to
        line-buffered files opened for writing.
    """
    if not args.results_path:
        return None

    def open_result(prefix):
        # Shard runs get their shard id prepended so the files do not collide.
        if args.num_shards > 1:
            prefix = f"{args.shard_id}_{prefix}"
        file_name = "{}-{}-{}.txt".format(
            prefix, os.path.basename(args.path), args.gen_subset
        )
        return open(os.path.join(args.results_path, file_name), "w", buffering=1)

    # (dict key, file-name prefix); the *.words keys use a singular prefix.
    key_to_prefix = (
        ("hypo.words", "hypo.word"),
        ("hypo.units", "hypo.units"),
        ("ref.words", "ref.word"),
        ("ref.units", "ref.units"),
    )
    return {key: open_result(prefix) for key, prefix in key_to_prefix}
def optimize_models(args, use_cuda, models):
    """Prepare every model of the ensemble for generation, in place."""
    for net in models:
        beam_size = None if args.no_beamable_mm else args.beam
        net.make_generation_fast_(
            beamable_mm_beam_size=beam_size,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            net.half()
        if use_cuda:
            net.cuda()
class ExistingEmissionsDecoder(object):
    """
    Decoder wrapper that feeds previously saved emissions to ``decoder``
    instead of running a model.
    """

    def __init__(self, decoder, emissions):
        # decoder: object exposing ``decode(emissions)``
        # emissions: array indexable by an array of sample ids
        self.decoder = decoder
        self.emissions = emissions

    def generate(self, models, sample, **unused):
        """
        Decode the stored emissions of the samples in ``sample["id"]``.

        ``models`` and extra keyword arguments are ignored.

        Raises:
            Exception: "invalid sizes" when the selected emissions cannot be
                stacked; the underlying error is attached as ``__cause__``.
        """
        ids = sample["id"].cpu().numpy()
        try:
            emissions = np.stack(self.emissions[ids])
        except Exception as err:
            # Show the offending shapes, then re-raise with the cause chained
            # (a bare ``except`` would also trap KeyboardInterrupt).
            print([x.shape for x in self.emissions[ids]])
            raise Exception("invalid sizes") from err
        emissions = torch.from_numpy(emissions)
        return self.decoder.decode(emissions)
def main(args, task=None, model_state=None):
    """
    Run inference for the subset ``args.gen_subset`` and report WER.

    Depending on ``args`` this either decodes with a flashlight decoder and
    accumulates word errors, or dumps emissions / encoder features to a file.

    Args:
        args: parsed generation options (see ``make_parser``)
        task: ignored on entry; it is replaced by ``tasks.setup_task(args)``
        model_state: forwarded as ``state`` when loading the model ensemble

    Returns:
        ``(task, wer)`` where ``wer`` is None when nothing was scored.
    """
    check_args(args)

    # Default batching budget when the caller specified neither limit.
    if args.max_tokens is None and args.batch_size is None:
        args.max_tokens = 4000000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    logger.info("| decoding with criterion {}".format(args.criterion))

    task = tasks.setup_task(args)

    # Load ensemble
    if args.load_emissions:
        # Emissions come from a file, so no model is loaded.
        models, criterions = [], []
        task.load_dataset(args.gen_subset)
    else:
        logger.info("| loading model(s) from {}".format(args.path))
        models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
            utils.split_paths(args.path, separator="\\"),
            arg_overrides=ast.literal_eval(args.model_overrides),
            task=task,
            suffix=args.checkpoint_suffix,
            strict=(args.checkpoint_shard_count == 1),
            num_shards=args.checkpoint_shard_count,
            state=model_state,
        )
        optimize_models(args, use_cuda, models)
        task.load_dataset(args.gen_subset, task_cfg=saved_cfg.task)

    # Set dictionary
    tgt_dict = task.target_dictionary

    logger.info(
        "| {} {} {} examples".format(
            args.data, args.gen_subset, len(task.dataset(args.gen_subset))
        )
    )

    # hack to pass transitions to W2lDecoder
    if args.criterion == "asg_loss":
        raise NotImplementedError("asg_loss is currently not supported")
        # trans = criterions[0].asg.trans.data
        # args.asg_transitions = torch.flatten(trans).tolist()

    # Load dataset (possibly sharded)
    itr = get_dataset_itr(args, task, models)

    # Initialize generator
    gen_timer = StopwatchMeter()

    def build_generator(args):
        # Pick the flashlight decoder named by --w2l-decoder. For any other
        # value a message is printed and None is returned implicitly.
        w2l_decoder = getattr(args, "w2l_decoder", None)
        if w2l_decoder == "viterbi":
            from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder

            return W2lViterbiDecoder(args, task.target_dictionary)
        elif w2l_decoder == "kenlm":
            from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder

            return W2lKenLMDecoder(args, task.target_dictionary)
        elif w2l_decoder == "fairseqlm":
            from examples.speech_recognition.w2l_decoder import W2lFairseqLMDecoder

            return W2lFairseqLMDecoder(args, task.target_dictionary)
        else:
            print(
                "only flashlight decoders with (viterbi, kenlm, fairseqlm) options are supported at the moment"
            )

    # please do not touch this unless you test both generate.py and infer.py with audio_pretraining task
    generator = build_generator(args)

    if args.load_emissions:
        # Wrap the decoder so it reads the saved emissions instead of a model.
        generator = ExistingEmissionsDecoder(
            generator, np.load(args.load_emissions, allow_pickle=True)
        )
        logger.info("loaded emissions from " + args.load_emissions)

    num_sentences = 0

    if args.results_path is not None and not os.path.exists(args.results_path):
        os.makedirs(args.results_path)

    # NOTE(review): the trailing comma makes this a 1-tuple, which the two
    # checks below unwrap; the value is not referenced again in this function.
    max_source_pos = (
        utils.resolve_max_positions(
            task.max_positions(), *[model.max_positions() for model in models]
        ),
    )

    if max_source_pos is not None:
        max_source_pos = max_source_pos[0]
        if max_source_pos is not None:
            max_source_pos = max_source_pos[0] - 1

    if args.dump_emissions:
        emissions = {}
    if args.dump_features:
        features = {}
        models[0].bert.proj = None
    else:
        # Also reached in --dump-emissions mode, where the files stay unused.
        res_files = prepare_result_files(args)
    # Running totals of word errors and reference words over the whole subset.
    errs_t = 0
    lengths_t = 0
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            # Skip batches that carry no model input.
            if "net_input" not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample["target"][:, : args.prefix_size]

            gen_timer.start()
            if args.dump_emissions:
                # Store log-probabilities per sample id, then skip decoding.
                with torch.no_grad():
                    encoder_out = models[0](**sample["net_input"])
                    emm = models[0].get_normalized_probs(encoder_out, log_probs=True)
                    emm = emm.transpose(0, 1).cpu().numpy()
                    for i, id in enumerate(sample["id"]):
                        emissions[id.item()] = emm[i]
                    continue
            elif args.dump_features:
                # Store (features, padding) per sample id, then skip decoding.
                with torch.no_grad():
                    encoder_out = models[0](**sample["net_input"])
                    feat = encoder_out["encoder_out"].transpose(0, 1).cpu().numpy()
                    for i, id in enumerate(sample["id"]):
                        padding = (
                            encoder_out["encoder_padding_mask"][i].cpu().numpy()
                            if encoder_out["encoder_padding_mask"] is not None
                            else None
                        )
                        features[id.item()] = (feat[i], padding)
                    continue
            hypos = task.inference_step(generator, models, sample, prefix_tokens)
            # Count the tokens of the best hypothesis of every sample.
            num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample["id"].tolist()):
                speaker = None
                # id = task.dataset(args.gen_subset).ids[int(sample_id)]
                id = sample_id
                # Prefer "target_label" as the reference when the batch has it.
                toks = (
                    sample["target"][i, :]
                    if "target_label" not in sample
                    else sample["target_label"][i, :]
                )
                target_tokens = utils.strip_pad(toks, tgt_dict.pad()).int().cpu()
                # Process top predictions
                errs, length = process_predictions(
                    args,
                    hypos[i],
                    None,
                    tgt_dict,
                    target_tokens,
                    res_files,
                    speaker,
                    id,
                )
                errs_t += errs
                lengths_t += length

            wps_meter.update(num_generated_tokens)
            t.log({"wps": round(wps_meter.avg)})
            num_sentences += (
                sample["nsentences"] if "nsentences" in sample else sample["id"].numel()
            )

    wer = None
    if args.dump_emissions:
        # Save emissions as a list ordered by sample id 0..N-1.
        emm_arr = []
        for i in range(len(emissions)):
            emm_arr.append(emissions[i])
        np.save(args.dump_emissions, emm_arr)
        logger.info(f"saved {len(emissions)} emissions to {args.dump_emissions}")
    elif args.dump_features:
        # Save features as a list ordered by sample id 0..N-1.
        feat_arr = []
        for i in range(len(features)):
            feat_arr.append(features[i])
        np.save(args.dump_features, feat_arr)
        logger.info(f"saved {len(features)} emissions to {args.dump_features}")
    else:
        if lengths_t > 0:
            # Word error rate in percent.
            wer = errs_t * 100.0 / lengths_t
            logger.info(f"WER: {wer}")

        logger.info(
            "| Processed {} sentences ({} tokens) in {:.1f}s ({:.2f}"
            "sentences/s, {:.2f} tokens/s)".format(
                num_sentences,
                gen_timer.n,
                gen_timer.sum,
                num_sentences / gen_timer.sum,
                1.0 / gen_timer.avg,
            )
        )
        logger.info("| Generate {} with beam={}".format(args.gen_subset, args.beam))
    return task, wer
def make_parser():
    """Return the generation parser extended with the ASR evaluation options."""
    return add_asr_eval_argument(options.get_generation_parser())
def cli_main():
    """Command-line entry point: parse the arguments and run ``main``."""
    cli_args = options.parse_args_and_arch(make_parser())
    main(cli_args)


if __name__ == "__main__":
    cli_main()
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <iostream>
#include "fstext/fstext-lib.h" // @manual
#include "util/common-utils.h" // @manual
/*
* This program is to modify a FST without self-loop by:
* for each incoming arc with non-eps input symbol, add a self-loop arc
* with that non-eps symbol as input and eps as output.
*
* This is to make sure the resultant FST can do deduplication for repeated
* symbols, which is very common in acoustic model
*
*/
namespace {

// Adds, for every state, one self-loop arc per distinct non-epsilon input
// label found on its incoming arcs. Each self-loop has that label as input,
// epsilon (0) as output and weight One. The FST is modified in place after
// running MakePrecedingInputSymbolsSame on it.
// Returns the number of self-loop arcs added.
int32 AddSelfLoopsSimple(fst::StdVectorFst* fst) {
  using ArcIter = fst::MutableArcIterator<fst::StdVectorFst>;

  const int32 states_before = fst->NumStates();
  fst::MakePrecedingInputSymbolsSame(false, fst);
  const int32 states_after = fst->NumStates();
  KALDI_LOG << "There are " << states_before
            << " states in the original FST; "
            << " after MakePrecedingInputSymbolsSame, there are "
            << states_after << " states " << std::endl;

  // Collect, per destination state, the set of non-eps incoming input labels.
  const int32 num_states = fst->NumStates();
  std::vector<std::set<int32>> incoming_labels(num_states);
  for (int32 s = 0; s < num_states; ++s) {
    for (ArcIter arcs(fst, s); !arcs.Done(); arcs.Next()) {
      const fst::StdArc arc(arcs.Value());
      if (arc.ilabel != 0) {
        incoming_labels[arc.nextstate].insert(arc.ilabel);
      }
    }
  }

  // Template arc; only the input label and target state change per loop.
  fst::StdArc loop_arc;
  loop_arc.weight = fst::StdArc::Weight::One();
  loop_arc.olabel = 0;

  int32 added = 0;
  for (int32 s = 0; s < num_states; ++s) {
    loop_arc.nextstate = s;
    for (const int32 label : incoming_labels[s]) {
      loop_arc.ilabel = label;
      fst->AddArc(s, loop_arc);
      ++added;
    }
  }
  return added;
}

// Prints the command-line synopsis to stdout.
void print_usage() {
  std::cout << "add-self-loop-simple usage:\n"
               "\tadd-self-loop-simple <in-fst> <out-fst> \n";
}
} // namespace
int main(int argc, char** argv) {
if (argc != 3) {
print_usage();
exit(1);
}
auto input = argv[1];
auto output = argv[2];
auto fst = fst::ReadFstKaldi(input);
auto num_states = fst->NumStates();
KALDI_LOG << "Loading FST from " << input << " with " << num_states
<< " states." << std::endl;
int32 num_arc_added = AddSelfLoopsSimple(fst);
KALDI_LOG << "Adding " << num_arc_added << " self-loop arcs " << std::endl;
fst::WriteFstKaldi(*fst, std::string(output));
KALDI_LOG << "Writing FST to " << output << std::endl;
delete fst;
}
# @package _group_

# "???" marks a value with no default here - presumably the OmegaConf
# mandatory-value marker, so it must be supplied by the user; TODO confirm.
data_dir: ???
fst_dir: ???
in_labels: ???
kaldi_root: ???
lm_arpa: ???
# Symbol used as the blank label.
blank_symbol: <s>
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from concurrent.futures import ThreadPoolExecutor
import logging
from omegaconf import MISSING
import os
import torch
from typing import Optional
import warnings
from dataclasses import dataclass
from fairseq.dataclass import FairseqDataclass
from .kaldi_initializer import KaldiInitializerConfig, initalize_kaldi
logger = logging.getLogger(__name__)
@dataclass
class KaldiDecoderConfig(FairseqDataclass):
    # Configuration consumed by KaldiDecoder.

    # Path of the decoding graph; when None, KaldiDecoder builds it from
    # kaldi_initializer_config.
    hlg_graph_path: Optional[str] = None
    # Text file with "<symbol> <id>" lines, read as the output symbol table.
    output_dict: str = MISSING
    # Required when hlg_graph_path is None.
    kaldi_initializer_config: Optional[KaldiInitializerConfig] = None

    # Passed to the recognizer as acoustic_scale.
    acoustic_scale: float = 0.5
    # The next three are copied onto the decoder options object.
    max_active: int = 10000
    beam_delta: float = 0.5
    hash_ratio: float = 2.0

    # Selects the lattice decoder / recognizer classes instead of the plain
    # "faster" ones; nbest > 1 requires this to be True.
    is_lattice: bool = False
    # The remaining decoder options are only applied when is_lattice is True.
    lattice_beam: float = 10.0
    prune_interval: int = 25
    determinize_lattice: bool = True
    prune_scale: float = 0.1
    # These four are copied onto the lattice determinization options.
    max_mem: int = 0
    phone_determinize: bool = True
    word_determinize: bool = True
    minimize: bool = True

    # Number of worker threads used to decode utterances.
    num_threads: int = 1
class KaldiDecoder(object):
    """
    Decoder that runs emissions through a Kaldi decoding graph via pykaldi.

    Utterances are decoded on a thread pool, so ``decode`` / ``generate``
    return one future per utterance rather than finished results.
    """

    def __init__(
        self,
        cfg: KaldiDecoderConfig,
        beam: int,
        nbest: int = 1,
    ):
        """
        Args:
            cfg: decoder configuration; ``cfg.hlg_graph_path`` is filled in
                when the graph has to be built first
            beam: decoding beam set on the decoder options
            nbest: number of hypotheses per utterance; values above 1 require
                ``cfg.is_lattice``

        Raises:
            ImportError: if pykaldi is not installed.
        """
        try:
            from kaldi.asr import FasterRecognizer, LatticeFasterRecognizer
            from kaldi.base import set_verbose_level
            from kaldi.decoder import (
                FasterDecoder,
                FasterDecoderOptions,
                LatticeFasterDecoder,
                LatticeFasterDecoderOptions,
            )
            from kaldi.lat.functions import DeterminizeLatticePhonePrunedOptions
            from kaldi.fstext import read_fst_kaldi, SymbolTable
        except ImportError as err:
            # Every code path below needs these names, so fail right away with
            # a clear message instead of a later unbound-name error.
            raise ImportError(
                "pykaldi is required for this functionality. Please install from https://github.com/pykaldi/pykaldi"
            ) from err

        # set_verbose_level(2)

        self.acoustic_scale = cfg.acoustic_scale
        self.nbest = nbest

        if cfg.hlg_graph_path is None:
            assert (
                cfg.kaldi_initializer_config is not None
            ), "Must provide hlg graph path or kaldi initializer config"
            cfg.hlg_graph_path = initalize_kaldi(cfg.kaldi_initializer_config)

        assert os.path.exists(cfg.hlg_graph_path), cfg.hlg_graph_path

        if cfg.is_lattice:
            self.dec_cls = LatticeFasterDecoder
            opt_cls = LatticeFasterDecoderOptions
            self.rec_cls = LatticeFasterRecognizer
        else:
            assert self.nbest == 1, "nbest > 1 requires lattice decoder"
            self.dec_cls = FasterDecoder
            opt_cls = FasterDecoderOptions
            self.rec_cls = FasterRecognizer

        self.decoder_options = opt_cls()
        self.decoder_options.beam = beam
        self.decoder_options.max_active = cfg.max_active
        self.decoder_options.beam_delta = cfg.beam_delta
        self.decoder_options.hash_ratio = cfg.hash_ratio

        if cfg.is_lattice:
            # Options that only exist on the lattice decoder.
            self.decoder_options.lattice_beam = cfg.lattice_beam
            self.decoder_options.prune_interval = cfg.prune_interval
            self.decoder_options.determinize_lattice = cfg.determinize_lattice
            self.decoder_options.prune_scale = cfg.prune_scale
            det_opts = DeterminizeLatticePhonePrunedOptions()
            det_opts.max_mem = cfg.max_mem
            det_opts.phone_determinize = cfg.phone_determinize
            det_opts.word_determinize = cfg.word_determinize
            det_opts.minimize = cfg.minimize
            self.decoder_options.det_opts = det_opts

        # id -> symbol map, read from "<symbol> <id>" lines.
        self.output_symbols = {}
        with open(cfg.output_dict, "r") as f:
            for line in f:
                items = line.rstrip().split()
                assert len(items) == 2
                self.output_symbols[int(items[1])] = items[0]

        logger.info(f"Loading FST from {cfg.hlg_graph_path}")
        self.fst = read_fst_kaldi(cfg.hlg_graph_path)
        self.symbol_table = SymbolTable.read_text(cfg.output_dict)

        self.executor = ThreadPoolExecutor(max_workers=cfg.num_threads)

    def generate(self, models, sample, **unused):
        """Generate a batch of inferences."""
        # model.forward normally channels prev_output_tokens into the decoder
        # separately, but SequenceGenerator directly calls model.encoder
        encoder_input = {
            k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
        }
        emissions, padding = self.get_emissions(models, encoder_input)
        return self.decode(emissions, padding)

    def get_emissions(self, models, encoder_input):
        """
        Run encoder and normalize emissions.

        With several models their "encoder_out" (or "logits") outputs are
        averaged and the padding mask of the first model is used.

        Returns:
            ``(emissions, padding)``: emissions moved to CPU as float with
            dims 0 and 1 swapped, and the CPU padding mask, or None when there
            is no mask or it has no True entry.
        """
        model = models[0]

        all_encoder_out = [m(**encoder_input) for m in models]

        if len(all_encoder_out) > 1:
            if "encoder_out" in all_encoder_out[0]:
                encoder_out = {
                    "encoder_out": sum(e["encoder_out"] for e in all_encoder_out)
                    / len(all_encoder_out),
                    "encoder_padding_mask": all_encoder_out[0]["encoder_padding_mask"],
                }
                padding = encoder_out["encoder_padding_mask"]
            else:
                encoder_out = {
                    "logits": sum(e["logits"] for e in all_encoder_out)
                    / len(all_encoder_out),
                    "padding_mask": all_encoder_out[0]["padding_mask"],
                }
                padding = encoder_out["padding_mask"]
        else:
            encoder_out = all_encoder_out[0]
            padding = (
                encoder_out["padding_mask"]
                if "padding_mask" in encoder_out
                else encoder_out["encoder_padding_mask"]
            )

        if hasattr(model, "get_logits"):
            emissions = model.get_logits(encoder_out, normalize=True)
        else:
            emissions = model.get_normalized_probs(encoder_out, log_probs=True)

        return (
            emissions.cpu().float().transpose(0, 1),
            padding.cpu() if padding is not None and padding.any() else None,
        )

    def decode_one(self, logits, padding):
        """
        Decode a single utterance.

        Args:
            logits: emissions of one utterance
            padding: boolean mask of padded frames for it, or None

        Returns:
            a list of hypothesis dicts with keys "tokens", "words", "score"
            and "emissions" (``nbest`` entries from the lattice, otherwise one).
        """
        from kaldi.matrix import Matrix

        decoder = self.dec_cls(self.fst, self.decoder_options)
        asr = self.rec_cls(
            decoder, self.symbol_table, acoustic_scale=self.acoustic_scale
        )

        if padding is not None:
            # Drop the padded frames before decoding.
            logits = logits[~padding]

        mat = Matrix(logits.numpy())

        out = asr.decode(mat)

        if self.nbest > 1:
            from kaldi.fstext import shortestpath
            from kaldi.fstext.utils import (
                convert_compact_lattice_to_lattice,
                convert_lattice_to_std,
                convert_nbest_to_list,
                get_linear_symbol_sequence,
            )

            # Extract the n best paths of the lattice as separate sequences.
            lat = out["lattice"]

            sp = shortestpath(lat, nshortest=self.nbest)

            sp = convert_compact_lattice_to_lattice(sp)

            sp = convert_lattice_to_std(sp)

            seq = convert_nbest_to_list(sp)

            results = []
            for s in seq:
                _, o, w = get_linear_symbol_sequence(s)
                words = [self.output_symbols[z] for z in o]
                results.append(
                    {
                        "tokens": words,
                        "words": words,
                        "score": w.value,
                        "emissions": logits,
                    }
                )
            return results
        else:
            words = out["text"].split()
            return [
                {
                    "tokens": words,
                    "words": words,
                    "score": out["likelihood"],
                    "emissions": logits,
                }
            ]

    def decode(self, emissions, padding):
        """
        Submit every utterance of the batch to the thread pool.

        Returns:
            a list with one future per utterance; each resolves to the list
            returned by ``decode_one``.
        """
        if padding is None:
            padding = [None] * len(emissions)

        return [
            self.executor.submit(self.decode_one, e, p)
            for e, p in zip(emissions, padding)
        ]
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
import hydra
from hydra.core.config_store import ConfigStore
import logging
from omegaconf import MISSING, OmegaConf
import os
import os.path as osp
from pathlib import Path
import subprocess
from typing import Optional
from fairseq.data.dictionary import Dictionary
from fairseq.dataclass import FairseqDataclass
# Directory containing this file, and its "config" subdirectory.
script_dir = Path(__file__).resolve().parent
config_path = script_dir / "config"
# Module-level logger used by the helpers below.
logger = logging.getLogger(__name__)
@dataclass
class KaldiInitializerConfig(FairseqDataclass):
    # Settings for building the Kaldi decoding graph.

    # presumably the directory holding the input dictionaries - TODO confirm
    data_dir: str = MISSING
    # presumably where generated FST files are written - TODO confirm
    fst_dir: Optional[str] = None
    # Label name of the input units; used in the generated file names.
    in_labels: str = MISSING
    # Must equal in_labels when no wav2letter_lexicon is given.
    out_labels: Optional[str] = None
    # Optional tab-separated "<word>\t<spelling>" lexicon file.
    wav2letter_lexicon: Optional[str] = None
    # ARPA language model that is compiled into the grammar FST.
    lm_arpa: str = MISSING
    # Root of the Kaldi checkout; its scripts and binaries are invoked.
    kaldi_root: str = MISSING
    blank_symbol: str = "<s>"
    silence_symbol: Optional[str] = None
def create_units(fst_dir: Path, in_labels: str, vocab: Dictionary) -> Path:
    """
    Write the Kaldi input-unit symbol table unless it already exists.

    The file starts with "<eps> 0" and then lists every non-special symbol of
    ``vocab`` that does not start with "madeupword", numbered from 1.

    Returns:
        the path ``fst_dir / "kaldi_dict.<in_labels>.txt"``.
    """
    in_units_file = fst_dir / f"kaldi_dict.{in_labels}.txt"
    if in_units_file.exists():
        return in_units_file

    logger.info(f"Creating {in_units_file}")

    with open(in_units_file, "w") as f:
        print("<eps> 0", file=f)
        kept_symbols = (
            s
            for s in vocab.symbols[vocab.nspecial :]
            if not s.startswith("madeupword")
        )
        for unit_id, symb in enumerate(kept_symbols, start=1):
            print(f"{symb} {unit_id}", file=f)
    return in_units_file
def create_lexicon(
    cfg: KaldiInitializerConfig,
    fst_dir: Path,
    unique_label: str,
    in_units_file: Path,
    out_words_file: Path,
) -> (Path, Path):
    """
    Create the lexicon, its disambiguated version and the disambiguated
    input-unit table, unless all three files already exist.

    With ``cfg.wav2letter_lexicon`` the lexicon keeps only the entries whose
    word occurs in ``out_words_file``; otherwise every unit of
    ``in_units_file`` (except "<eps>", "<ctc_blank>" and "<SIL>") maps to
    itself. The disambiguation is done by Kaldi's perl utilities.

    Returns:
        ``(disambig_lexicon_file, disambig_in_units_file)``.
    """
    disambig_in_units_file = fst_dir / f"kaldi_dict.{cfg.in_labels}_disambig.txt"
    lexicon_file = fst_dir / f"kaldi_lexicon.{unique_label}.txt"
    disambig_lexicon_file = fst_dir / f"kaldi_lexicon.{unique_label}_disambig.txt"

    if (
        lexicon_file.exists()
        and disambig_lexicon_file.exists()
        and disambig_in_units_file.exists()
    ):
        return disambig_lexicon_file, disambig_in_units_file

    logger.info(f"Creating {lexicon_file} (in units file: {in_units_file})")

    assert cfg.wav2letter_lexicon is not None or cfg.in_labels == cfg.out_labels

    if cfg.wav2letter_lexicon is None:
        # Identity lexicon: every unit is spelled by itself.
        ignored = ("<eps>", "<ctc_blank>", "<SIL>")
        with open(in_units_file, "r") as in_f, open(lexicon_file, "w") as out_f:
            for line in in_f:
                symb = line.split()[0]
                if symb not in ignored:
                    print(symb, symb, file=out_f)
    else:
        with open(out_words_file, "r") as lm_dict_f:
            lm_words = {line.split()[0] for line in lm_dict_f}

        num_skipped = 0
        total = 0
        with open(cfg.wav2letter_lexicon, "r") as w2l_lex_f, open(
            lexicon_file, "w"
        ) as out_f:
            for line in w2l_lex_f:
                items = line.rstrip().split("\t")
                assert len(items) == 2, items
                word, spelling = items[0], items[1]
                if word in lm_words:
                    print(word, spelling, file=out_f)
                else:
                    num_skipped += 1
                    logger.debug(
                        f"Skipping word {items[0]} as it was not found in LM"
                    )
                total += 1
        if num_skipped > 0:
            logger.warning(
                f"Skipped {num_skipped} out of {total} words as they were not found in LM"
            )

    kaldi_utils = Path(cfg.kaldi_root) / "egs/wsj/s5/utils"

    # The script's stdout is parsed as the number of disambiguation symbols.
    lex_disambig_path = kaldi_utils / "add_lex_disambig.pl"
    res = subprocess.run(
        [lex_disambig_path, lexicon_file, disambig_lexicon_file],
        check=True,
        capture_output=True,
    )
    ndisambig = int(res.stdout)

    # Its stdout is the unit table with the disambiguation symbols added.
    disamib_path = kaldi_utils / "add_disambig.pl"
    res = subprocess.run(
        [disamib_path, "--include-zero", in_units_file, str(ndisambig)],
        check=True,
        capture_output=True,
    )
    with open(disambig_in_units_file, "wb") as f:
        f.write(res.stdout)

    return disambig_lexicon_file, disambig_in_units_file
def create_G(
    kaldi_root: Path, fst_dir: Path, lm_arpa: Path, arpa_base: str
) -> (Path, Path):
    """Create ``G_{arpa_base}.fst`` and its word symbol table in ``fst_dir``.

    Runs Kaldi's ``arpa2fst`` on ``lm_arpa`` with ``--disambig-symbol=#0``,
    asking it to write the symbol table to ``kaldi_dict.{arpa_base}.txt``.
    The command is skipped when both output files already exist.

    Returns:
        ``(grammar_graph, out_words_file)``.

    Raises:
        subprocess.CalledProcessError: if ``arpa2fst`` exits with an error.
    """
    out_words_file = fst_dir / f"kaldi_dict.{arpa_base}.txt"
    grammar_graph = fst_dir / f"G_{arpa_base}.fst"

    if grammar_graph.exists() and out_words_file.exists():
        return grammar_graph, out_words_file

    logger.info(f"Creating {grammar_graph}")
    command = [
        kaldi_root / "src/lmbin/arpa2fst",
        "--disambig-symbol=#0",
        f"--write-symbol-table={out_words_file}",
        lm_arpa,
        grammar_graph,
    ]
    subprocess.run(command, check=True)
    return grammar_graph, out_words_file
def create_L(
    kaldi_root: Path,
    fst_dir: Path,
    unique_label: str,
    lexicon_file: Path,
    in_units_file: Path,
    out_words_file: Path,
) -> Path:
    """Create the lexicon graph ``L.{unique_label}.fst`` in ``fst_dir``.

    The graph is the output of the pipeline
    ``make_lexicon_fst.pl | fstcompile | fstaddselfloops | fstarcsort
    --sort_type=olabel`` run on ``lexicon_file``, using ``in_units_file`` and
    ``out_words_file`` as input/output symbol tables. For each of the two
    symbol tables the id of the ``#0`` symbol is written to a side file named
    ``<table>_disamig`` and handed to ``fstaddselfloops``.

    An existing graph is reused. The graph file is only written after the
    whole pipeline has succeeded, so a failed run (including a missing Kaldi
    binary) never leaves an empty file that would later be mistaken for a
    cached result.

    Returns:
        Path of the lexicon graph.

    Raises:
        AssertionError: if a symbol table has no ``#0`` entry, or if
            ``make_lexicon_fst.pl`` / ``fstcompile`` write anything to stderr.
        subprocess.CalledProcessError: if a checked pipeline stage fails.
    """
    lexicon_graph = fst_dir / f"L.{unique_label}.fst"
    if lexicon_graph.exists():
        return lexicon_graph

    logger.info(f"Creating {lexicon_graph} (in units: {in_units_file})")
    make_lex = kaldi_root / "egs/wsj/s5/utils/make_lexicon_fst.pl"
    fstcompile = kaldi_root / "tools/openfst-1.6.7/bin/fstcompile"
    fstaddselfloops = kaldi_root / "src/fstbin/fstaddselfloops"
    fstarcsort = kaldi_root / "tools/openfst-1.6.7/bin/fstarcsort"

    def write_disambig_symbol(file):
        """Write the id of ``#0`` found in ``file`` to ``<file>_disamig``.

        Returns the side-file path, or None when ``file`` has no ``#0`` line.
        """
        with open(file, "r") as f:
            for line in f:
                items = line.rstrip().split()
                # Blank lines carry no symbol; skip them instead of crashing.
                if items and items[0] == "#0":
                    out_path = str(file) + "_disamig"
                    with open(out_path, "w") as out_f:
                        print(items[1], file=out_f)
                    return out_path
        return None

    in_disambig_sym = write_disambig_symbol(in_units_file)
    assert in_disambig_sym is not None
    out_disambig_sym = write_disambig_symbol(out_words_file)
    assert out_disambig_sym is not None

    try:
        res = subprocess.run(
            [make_lex, lexicon_file], capture_output=True, check=True
        )
        assert len(res.stderr) == 0, res.stderr.decode("utf-8")
        # fstcompile is deliberately validated through its stderr rather than
        # its exit status, as in the original implementation.
        res = subprocess.run(
            [
                fstcompile,
                f"--isymbols={in_units_file}",
                f"--osymbols={out_words_file}",
                "--keep_isymbols=false",
                "--keep_osymbols=false",
            ],
            input=res.stdout,
            capture_output=True,
        )
        assert len(res.stderr) == 0, res.stderr.decode("utf-8")
        res = subprocess.run(
            [fstaddselfloops, in_disambig_sym, out_disambig_sym],
            input=res.stdout,
            capture_output=True,
            check=True,
        )
        res = subprocess.run(
            [fstarcsort, "--sort_type=olabel"],
            input=res.stdout,
            capture_output=True,
            check=True,
        )
    except subprocess.CalledProcessError as e:
        logger.error(f"cmd: {e.cmd}, err: {e.stderr.decode('utf-8')}")
        raise

    # Reached only when every stage succeeded.
    with open(lexicon_graph, "wb") as out_f:
        out_f.write(res.stdout)

    return lexicon_graph
def create_LG(
    kaldi_root: Path,
    fst_dir: Path,
    unique_label: str,
    lexicon_graph: Path,
    grammar_graph: Path,
) -> Path:
    """Create ``LG.{unique_label}.fst`` from the lexicon and grammar graphs.

    The graph is the output of the pipeline
    ``fsttablecompose L G | fstdeterminizestar --use-log=true |
    fstminimizeencoded | fstpushspecial | fstarcsort --sort_type=ilabel``.
    An existing graph is reused.

    Every stage is run with ``check=True`` so that a failure is reported
    against the command that actually failed. The graph file is only written
    after the whole pipeline has succeeded, so a failed run (including a
    missing Kaldi binary) never leaves an empty file that would later be
    mistaken for a cached result.

    Returns:
        Path of the LG graph.

    Raises:
        subprocess.CalledProcessError: if any pipeline stage fails.
    """
    lg_graph = fst_dir / f"LG.{unique_label}.fst"
    if lg_graph.exists():
        return lg_graph

    logger.info(f"Creating {lg_graph}")

    fsttablecompose = kaldi_root / "src/fstbin/fsttablecompose"
    fstdeterminizestar = kaldi_root / "src/fstbin/fstdeterminizestar"
    fstminimizeencoded = kaldi_root / "src/fstbin/fstminimizeencoded"
    fstpushspecial = kaldi_root / "src/fstbin/fstpushspecial"
    fstarcsort = kaldi_root / "tools/openfst-1.6.7/bin/fstarcsort"

    # Stages that read the previous stage's stdout on their stdin.
    downstream = [
        [fstdeterminizestar, "--use-log=true"],
        [fstminimizeencoded],
        [fstpushspecial],
        [fstarcsort, "--sort_type=ilabel"],
    ]

    try:
        res = subprocess.run(
            [fsttablecompose, lexicon_graph, grammar_graph],
            capture_output=True,
            check=True,
        )
        for command in downstream:
            res = subprocess.run(
                command,
                input=res.stdout,
                capture_output=True,
                check=True,
            )
    except subprocess.CalledProcessError as e:
        logger.error(f"cmd: {e.cmd}, err: {e.stderr.decode('utf-8')}")
        raise

    # Reached only when every stage succeeded.
    with open(lg_graph, "wb") as out_f:
        out_f.write(res.stdout)

    return lg_graph
def create_H(
    kaldi_root: Path,
    fst_dir: Path,
    disambig_out_units_file: Path,
    in_labels: str,
    vocab: Dictionary,
    blk_sym: str,
    silence_symbol: Optional[str],
) -> (Path, Path, Path):
    """Create the ``H`` graph and its symbol files in ``fst_dir``.

    The text FST built here has a single root state that is also the final
    state. ``blk_sym`` (and ``silence_symbol`` when given) loop on the root
    emitting ``<eps>``. Every regular output symbol gets an emitting state
    reached from the root, an epsilon arc back to the root, and an extra
    state reached through the special symbols that also returns to the root
    on epsilon. The text is compiled with ``fstcompile``, then passed through
    ``fstaddselfloops`` and ``fstarcsort --sort_type=olabel``.

    Side files written along the way:

    * ``<disambig_out_units_file>.int`` -- ids of the lines starting with ``#``;
    * ``<h_graph>.isym`` -- input symbols: ``<eps>`` as 0, then
      ``vocab.symbols`` numbered from 1;
    * ``kaldi_dict.h_out.{in_labels}.txt`` -- the non-``#`` output symbols;
    * ``<h_graph>isym_disambig.int`` -- one fresh input id per ``#`` line,
      numbered from ``len(isymbols)``.

    Existing outputs are reused when the graph and both returned side files
    exist. The graph file is only written after the whole pipeline has
    succeeded, so a failed run (including a missing binary) never leaves an
    empty graph that would later be mistaken for a cached result.

    Returns:
        ``(h_graph, h_out_units_file, disambig_in_units_file_int)``.

    Raises:
        AssertionError: if the first non-``#`` symbol is not ``<eps>``.
        subprocess.CalledProcessError: if any pipeline stage fails.
    """
    h_graph = (
        fst_dir / f"H.{in_labels}{'_' + silence_symbol if silence_symbol else ''}.fst"
    )
    h_out_units_file = fst_dir / f"kaldi_dict.h_out.{in_labels}.txt"
    disambig_in_units_file_int = Path(str(h_graph) + "isym_disambig.int")
    disambig_out_units_file_int = Path(str(disambig_out_units_file) + ".int")

    if (
        h_graph.exists()
        and h_out_units_file.exists()
        and disambig_in_units_file_int.exists()
    ):
        return h_graph, h_out_units_file, disambig_in_units_file_int

    logger.info(f"Creating {h_graph}")
    eps_sym = "<eps>"

    # Split the disambiguated unit table: '#' lines only contribute their id,
    # every other line becomes an output symbol of H.
    num_disambig = 0
    osymbols = []
    with open(disambig_out_units_file, "r") as f, open(
        disambig_out_units_file_int, "w"
    ) as out_f:
        for line in f:
            symb, symb_id = line.rstrip().split()
            if line.startswith("#"):
                num_disambig += 1
                print(symb_id, file=out_f)
            else:
                if not osymbols:
                    assert symb == eps_sym, symb
                osymbols.append((symb, symb_id))

    isymbols = [(eps_sym, 0)]
    isymbols.extend((s, idx) for idx, s in enumerate(vocab.symbols, start=1))

    special_symbols = [blk_sym]
    if silence_symbol is not None:
        special_symbols.append(silence_symbol)

    root_node = 0
    node_idx = root_node
    arcs = [
        "{} {} {} {}".format(root_node, root_node, ss, eps_sym)
        for ss in special_symbols
    ]
    for symbol, _ in osymbols:
        if symbol == eps_sym or symbol.startswith("#"):
            continue

        node_idx += 1
        # 1. from root to emitting state
        arcs.append("{} {} {} {}".format(root_node, node_idx, symbol, symbol))
        # 2. from emitting state back to root
        arcs.append("{} {} {} {}".format(node_idx, root_node, eps_sym, eps_sym))
        # 3. from emitting state to optional blank state
        pre_node = node_idx
        node_idx += 1
        for ss in special_symbols:
            arcs.append("{} {} {} {}".format(pre_node, node_idx, ss, eps_sym))
        # 4. from blank state back to root
        arcs.append("{} {} {} {}".format(node_idx, root_node, eps_sym, eps_sym))

    # A line holding only a state id marks that state as final.
    arcs.append("{}".format(root_node))
    fst_str = "\n".join(arcs)

    isym_file = str(h_graph) + ".isym"
    with open(isym_file, "w") as f:
        for sym, sym_id in isymbols:
            f.write("{} {}\n".format(sym, sym_id))

    with open(h_out_units_file, "w") as f:
        for sym, sym_id in osymbols:
            f.write("{} {}\n".format(sym, sym_id))

    with open(disambig_in_units_file_int, "w") as f:
        first_id = len(isymbols)
        for disam_sym_id in range(first_id, first_id + num_disambig):
            f.write("{}\n".format(disam_sym_id))

    fstcompile = kaldi_root / "tools/openfst-1.6.7/bin/fstcompile"
    fstaddselfloops = kaldi_root / "src/fstbin/fstaddselfloops"
    fstarcsort = kaldi_root / "tools/openfst-1.6.7/bin/fstarcsort"

    try:
        res = subprocess.run(
            [
                fstcompile,
                f"--isymbols={isym_file}",
                f"--osymbols={h_out_units_file}",
                "--keep_isymbols=false",
                "--keep_osymbols=false",
            ],
            input=str.encode(fst_str),
            capture_output=True,
            check=True,
        )
        res = subprocess.run(
            [
                fstaddselfloops,
                disambig_in_units_file_int,
                disambig_out_units_file_int,
            ],
            input=res.stdout,
            capture_output=True,
            check=True,
        )
        res = subprocess.run(
            [fstarcsort, "--sort_type=olabel"],
            input=res.stdout,
            capture_output=True,
            check=True,
        )
    except subprocess.CalledProcessError as e:
        logger.error(f"cmd: {e.cmd}, err: {e.stderr.decode('utf-8')}")
        raise

    # Reached only when every stage succeeded.
    with open(h_graph, "wb") as out_f:
        out_f.write(res.stdout)

    return h_graph, h_out_units_file, disambig_in_units_file_int
def create_HLGa(
    kaldi_root: Path,
    fst_dir: Path,
    unique_label: str,
    h_graph: Path,
    lg_graph: Path,
    disambig_in_words_file_int: Path,
) -> Path:
    """Create ``HLGa.{unique_label}.fst`` by composing ``H`` with ``LG``.

    The graph is the output of the pipeline
    ``fsttablecompose H LG | fstdeterminizestar --use-log=true |
    fstrmsymbols <disambig_in_words_file_int> | fstrmepslocal |
    fstminimizeencoded``. An existing graph is reused.

    The graph file is only written after the whole pipeline has succeeded, so
    a failed run (including a missing Kaldi binary) never leaves an empty
    file that would later be mistaken for a cached result.

    Returns:
        Path of the HLGa graph.

    Raises:
        subprocess.CalledProcessError: if any pipeline stage fails.
    """
    hlga_graph = fst_dir / f"HLGa.{unique_label}.fst"
    if hlga_graph.exists():
        return hlga_graph

    logger.info(f"Creating {hlga_graph}")

    fsttablecompose = kaldi_root / "src/fstbin/fsttablecompose"
    fstdeterminizestar = kaldi_root / "src/fstbin/fstdeterminizestar"
    fstrmsymbols = kaldi_root / "src/fstbin/fstrmsymbols"
    fstrmepslocal = kaldi_root / "src/fstbin/fstrmepslocal"
    fstminimizeencoded = kaldi_root / "src/fstbin/fstminimizeencoded"

    # Stages that read the previous stage's stdout on their stdin.
    downstream = [
        [fstdeterminizestar, "--use-log=true"],
        [fstrmsymbols, disambig_in_words_file_int],
        [fstrmepslocal],
        [fstminimizeencoded],
    ]

    try:
        res = subprocess.run(
            [fsttablecompose, h_graph, lg_graph],
            capture_output=True,
            check=True,
        )
        for command in downstream:
            res = subprocess.run(
                command,
                input=res.stdout,
                capture_output=True,
                check=True,
            )
    except subprocess.CalledProcessError as e:
        logger.error(f"cmd: {e.cmd}, err: {e.stderr.decode('utf-8')}")
        raise

    # Reached only when every stage succeeded.
    with open(hlga_graph, "wb") as out_f:
        out_f.write(res.stdout)

    return hlga_graph
def create_HLa(
    kaldi_root: Path,
    fst_dir: Path,
    unique_label: str,
    h_graph: Path,
    l_graph: Path,
    disambig_in_words_file_int: Path,
) -> Path:
    """Create ``HLa.{unique_label}.fst`` by composing ``H`` with ``L``.

    The graph is the output of the pipeline
    ``fsttablecompose H L | fstdeterminizestar --use-log=true |
    fstrmsymbols <disambig_in_words_file_int> | fstrmepslocal |
    fstminimizeencoded``. An existing graph is reused.

    The graph file is only written after the whole pipeline has succeeded, so
    a failed run (including a missing Kaldi binary) never leaves an empty
    file that would later be mistaken for a cached result.

    Returns:
        Path of the HLa graph.

    Raises:
        subprocess.CalledProcessError: if any pipeline stage fails.
    """
    hla_graph = fst_dir / f"HLa.{unique_label}.fst"
    if hla_graph.exists():
        return hla_graph

    logger.info(f"Creating {hla_graph}")

    fsttablecompose = kaldi_root / "src/fstbin/fsttablecompose"
    fstdeterminizestar = kaldi_root / "src/fstbin/fstdeterminizestar"
    fstrmsymbols = kaldi_root / "src/fstbin/fstrmsymbols"
    fstrmepslocal = kaldi_root / "src/fstbin/fstrmepslocal"
    fstminimizeencoded = kaldi_root / "src/fstbin/fstminimizeencoded"

    try:
        composed = subprocess.run(
            [fsttablecompose, h_graph, l_graph],
            capture_output=True,
            check=True,
        )
        determinized = subprocess.run(
            [fstdeterminizestar, "--use-log=true"],
            input=composed.stdout,
            capture_output=True,
            check=True,
        )
        without_disambig = subprocess.run(
            [fstrmsymbols, disambig_in_words_file_int],
            input=determinized.stdout,
            capture_output=True,
            check=True,
        )
        without_eps = subprocess.run(
            [fstrmepslocal],
            input=without_disambig.stdout,
            capture_output=True,
            check=True,
        )
        minimized = subprocess.run(
            [fstminimizeencoded],
            input=without_eps.stdout,
            capture_output=True,
            check=True,
        )
    except subprocess.CalledProcessError as e:
        logger.error(f"cmd: {e.cmd}, err: {e.stderr.decode('utf-8')}")
        raise

    # Reached only when every stage succeeded.
    with open(hla_graph, "wb") as out_f:
        out_f.write(minimized.stdout)

    return hla_graph
def create_HLG(
    kaldi_root: Path,
    fst_dir: Path,
    unique_label: str,
    hlga_graph: Path,
    prefix: str = "HLG",
) -> Path:
    """Create ``{prefix}.{unique_label}.fst`` by running ``add-self-loop-simple``.

    The ``add-self-loop-simple`` executable is looked up next to this script
    (``script_dir``). When it is missing it is first compiled with ``c++``
    from ``add-self-loop-simple.cc`` against the Kaldi sources and the
    ``kaldi-base`` / ``kaldi-fstext`` libraries. It is then invoked as
    ``add-self-loop-simple <hlga_graph> <output>`` with Kaldi's ``src/lib``
    prepended to ``LD_LIBRARY_PATH``. An existing graph is reused.

    Returns:
        Path of the resulting graph.

    Raises:
        subprocess.CalledProcessError: if compilation or the tool itself fails.
    """
    hlg_graph = fst_dir / f"{prefix}.{unique_label}.fst"
    if hlg_graph.exists():
        return hlg_graph

    logger.info(f"Creating {hlg_graph}")

    add_self_loop = script_dir / "add-self-loop-simple"
    kaldi_src = kaldi_root / "src"
    kaldi_lib = kaldi_src / "lib"

    try:
        if not add_self_loop.exists():
            fst_include = kaldi_root / "tools/openfst-1.6.7/include"
            add_self_loop_src = script_dir / "add-self-loop-simple.cc"

            subprocess.run(
                [
                    "c++",
                    f"-I{kaldi_src}",
                    f"-I{fst_include}",
                    f"-L{kaldi_lib}",
                    add_self_loop_src,
                    "-lkaldi-base",
                    "-lkaldi-fstext",
                    "-o",
                    add_self_loop,
                ],
                check=True,
            )

        my_env = os.environ.copy()
        # LD_LIBRARY_PATH is frequently unset; only append the inherited value
        # when there is one instead of failing with KeyError.
        inherited = my_env.get("LD_LIBRARY_PATH")
        my_env["LD_LIBRARY_PATH"] = (
            f"{kaldi_lib}:{inherited}" if inherited else str(kaldi_lib)
        )

        subprocess.run(
            [
                add_self_loop,
                hlga_graph,
                hlg_graph,
            ],
            check=True,
            capture_output=True,
            env=my_env,
        )
    except subprocess.CalledProcessError as e:
        # The compile step does not capture output, so stderr may be None.
        err = e.stderr.decode("utf-8") if e.stderr is not None else ""
        logger.error(f"cmd: {e.cmd}, err: {err}")
        # Do not let a partially written graph be picked up as cached later.
        if hlg_graph.exists():
            os.remove(hlg_graph)
        raise

    return hlg_graph
def initalize_kaldi(cfg: KaldiInitializerConfig) -> Path:
    """Build every intermediate graph and return the path of the final HLG.

    Missing settings are filled in on ``cfg`` itself: ``fst_dir`` defaults to
    ``<data_dir>/kaldi`` and ``out_labels`` to ``in_labels``. The output
    directory is created if needed, the vocabulary is loaded from
    ``<data_dir>/dict.<in_labels>.txt``, and the ``create_*`` helpers are then
    called in order: units, G, lexicon, H, L, LG, HLGa and finally HLG.
    """
    if cfg.fst_dir is None:
        cfg.fst_dir = osp.join(cfg.data_dir, "kaldi")
    if cfg.out_labels is None:
        cfg.out_labels = cfg.in_labels

    kaldi_root = Path(cfg.kaldi_root)
    data_dir = Path(cfg.data_dir)
    fst_dir = Path(cfg.fst_dir)
    fst_dir.mkdir(parents=True, exist_ok=True)

    # Files derived from both the label set and the LM carry this label.
    arpa_base = osp.splitext(osp.basename(cfg.lm_arpa))[0]
    unique_label = f"{cfg.in_labels}.{arpa_base}"

    with open(data_dir / f"dict.{cfg.in_labels}.txt", "r") as dict_f:
        vocab = Dictionary.load(dict_f)

    in_units_file = create_units(fst_dir, cfg.in_labels, vocab)

    grammar_graph, out_words_file = create_G(
        kaldi_root, fst_dir, Path(cfg.lm_arpa), arpa_base
    )

    disambig_lexicon_file, disambig_L_in_units_file = create_lexicon(
        cfg, fst_dir, unique_label, in_units_file, out_words_file
    )

    h_graph, h_out_units_file, disambig_in_units_file_int = create_H(
        kaldi_root=kaldi_root,
        fst_dir=fst_dir,
        disambig_out_units_file=disambig_L_in_units_file,
        in_labels=cfg.in_labels,
        vocab=vocab,
        blk_sym=cfg.blank_symbol,
        silence_symbol=cfg.silence_symbol,
    )

    lexicon_graph = create_L(
        kaldi_root=kaldi_root,
        fst_dir=fst_dir,
        unique_label=unique_label,
        lexicon_file=disambig_lexicon_file,
        in_units_file=disambig_L_in_units_file,
        out_words_file=out_words_file,
    )

    lg_graph = create_LG(
        kaldi_root, fst_dir, unique_label, lexicon_graph, grammar_graph
    )

    hlga_graph = create_HLGa(
        kaldi_root, fst_dir, unique_label, h_graph, lg_graph, disambig_in_units_file_int
    )

    hlg_graph = create_HLG(kaldi_root, fst_dir, unique_label, hlga_graph)

    # for debugging
    # hla_graph = create_HLa(kaldi_root, fst_dir, unique_label, h_graph, lexicon_graph, disambig_in_units_file_int)
    # hl_graph = create_HLG(kaldi_root, fst_dir, unique_label, hla_graph, prefix="HL_looped")
    # create_HLG(kaldi_root, fst_dir, "phnc", h_graph, prefix="H_looped")

    return hlg_graph
@hydra.main(config_path=config_path, config_name="kaldi_initializer")
def cli_main(cfg: KaldiInitializerConfig) -> None:
    """Hydra entry point: normalise the config and build the Kaldi graphs.

    The incoming config is converted to a plain container (interpolations
    resolved, enums turned into strings), wrapped back into an OmegaConf
    object with struct mode enabled, and passed to ``initalize_kaldi``.
    """
    resolved = OmegaConf.to_container(cfg, resolve=True, enum_to_str=True)
    struct_cfg = OmegaConf.create(resolved)
    OmegaConf.set_struct(struct_cfg, True)
    initalize_kaldi(struct_cfg)
if __name__ == "__main__":
    # Script entry point: enable INFO logging, register the config schema
    # with hydra under the selected config name, then run the CLI.
    logging.root.setLevel(logging.INFO)
    logging.basicConfig(level=logging.INFO)

    cfg_name = "kaldi_initializer"
    try:
        # Private hydra helper; the import may fail, hence the fallback to
        # the default name set above.
        from hydra._internal.utils import (
            get_args,
        )  # pylint: disable=import-outside-toplevel

        cfg_name = get_args().config_name or cfg_name
    except ImportError:
        logger.warning("Failed to get config name from hydra args")

    config_store = ConfigStore.instance()
    config_store.store(name=cfg_name, node=KaldiInitializerConfig)

    cli_main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment