Commit 1ed330b5 authored by G. Sun's avatar G. Sun Committed by Facebook GitHub Bot

Add TCPGen context-biasing Conformer RNN-T (#2890)

Summary:
This commit adds the implementation of the tree-constrained pointer generator (TCPGen) for contextual biasing.

An example for Librispeech can be found in audio/examples/asr/librispeech_biasing.

Maintainer's note (mthrok):
It seems that TrieNode would be better typed as a tuple, but changing the implementation from list to tuple
could introduce issues that are hard to verify without running the code, so the implementation is left unchanged even though the annotation uses tuple.

Pull Request resolved: https://github.com/pytorch/audio/pull/2890

Reviewed By: nateanl

Differential Revision: D43171447

Pulled By: mthrok

fbshipit-source-id: 372bb077d997d720401dbf2dbfa131e6a958e37e
parent d3c9295c
# Contextual Conformer RNN-T with TCPGen Example
This directory contains sample implementations of training and evaluation pipelines for the Conformer RNN-T model with tree-constrained pointer generator (TCPGen) for contextual biasing, as described in the paper: [Tree-Constrained Pointer Generator for End-to-End Contextual Speech Recognition](https://ieeexplore.ieee.org/abstract/document/9687915)
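At a high level, TCPGen builds a prefix tree (trie) over a list of biasing words, tracks the set of valid wordpiece continuations at each decoding step, and interpolates a pointer distribution over those continuations with the model's output distribution, roughly
```
P(y) = p_gen * P_ptr(y) + (1 - p_gen) * P_mdl(y)
```
where `p_gen` is a learned, per-step generation probability (the exact formulation, including how out-of-list probability mass is redistributed, follows the paper and is mirrored in [`lightning.py`](./lightning.py)).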
## Setup
### Install PyTorch and TorchAudio nightly or from source
Because Conformer RNN-T is currently a prototype feature, you will need to either use the TorchAudio nightly build or build TorchAudio from source. Note also that GPU support is required for training.
To install the nightly, follow the directions at <https://pytorch.org/>.
To build TorchAudio from source, refer to the [contributing guidelines](https://github.com/pytorch/audio/blob/main/CONTRIBUTING.md).
### Install additional dependencies
```bash
pip install pytorch-lightning sentencepiece
```
## Usage
### Training
[`train.py`](./train.py) trains a Conformer RNN-T model with TCPGen on LibriSpeech using PyTorch Lightning. Note that the script expects users to have the following:
- Access to GPU nodes for training.
- Full LibriSpeech dataset.
- SentencePiece model to be used to encode targets; the model can be generated using [`train_spm.py`](./train_spm.py). **Note that suffix-based wordpieces are used in this example**. [`run_spm.sh`](./run_spm.sh) will generate the 600 suffix-based wordpieces used in the paper (see the sketch after this list).
- File (`--global-stats-path`) that contains training set feature statistics; this file can be generated using [`global_stats.py`](../emformer_rnnt/global_stats.py). [`global_stats_100.json`](./global_stats_100.json) has already been generated for train-clean-100.
- Biasing list: please download [`rareword_f15.txt`](https://download.pytorch.org/torchaudio/pipeline-assets/tcpgen/rareword_f15.txt), [`rareword_f30.txt`](https://download.pytorch.org/torchaudio/pipeline-assets/tcpgen/rareword_f30.txt) and [`all_rare_words.txt`](https://download.pytorch.org/torchaudio/pipeline-assets/tcpgen/all_rare_words.txt) and put them in the [`blists`](./blists) directory. [`rareword_f15.txt`](https://download.pytorch.org/torchaudio/pipeline-assets/tcpgen/rareword_f15.txt) is the biasing list used for training with the clean-100 data; words appearing fewer than 15 times in the training set were treated as rare words. For inference, use [`all_rare_words.txt`](https://download.pytorch.org/torchaudio/pipeline-assets/tcpgen/all_rare_words.txt), which is the same list used in [https://github.com/facebookresearch/fbai-speech/tree/main/is21_deep_bias](https://github.com/facebookresearch/fbai-speech/tree/main/is21_deep_bias).
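For reference, a minimal sketch of generating the suffix-based SentencePiece model with [`train_spm.py`](./train_spm.py) (flags taken from the script itself; [`run_spm.sh`](./run_spm.sh) is expected to wrap an equivalent command):
```bash
python train_spm.py --librispeech-path <Path_to_librispeech_data> --suffix \
    --output-file ./spm_unigram_600_100suffix.model
```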
Additional training options:
- `--droprate` is the probability of dropping biasing words that appear in the reference text, to avoid over-confidence in the biasing component
- `--maxsize` is the size of the biasing list used for training, i.e. the number of biasing words from the references plus sampled distractors (see the sketch below)
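The per-batch biasing list is assembled in `_extract_tries` in [`transforms.py`](./transforms.py); a simplified sketch of that logic (the helper name here is illustrative, not part of the codebase):
```python
import random

def build_biasing_list(reference_words, full_list, droprate, maxsize):
    # Randomly keep ~(1 - droprate) of the biasing words found in the reference text.
    kept = random.sample(reference_words, k=int(len(reference_words) * (1 - droprate)))
    # Pad with distractors sampled from the full rare-word list up to maxsize.
    distractors = random.sample(full_list, k=max(0, maxsize - len(kept)))
    return kept + [w for w in distractors if w not in kept]
```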
Sample SLURM command:
```bash
srun --cpus-per-task=16 --gpus-per-node=1 -N 1 --ntasks-per-node=1 python train.py --exp-dir <Path_to_exp> --librispeech-path <Path_to_librispeech_data> --sp-model-path ./spm_unigram_600_100suffix.model --biasing --biasing-list ./blists/rareword_f15.txt --droprate 0.1 --maxsize 200 --epochs 90
```
### Evaluation
[`eval.py`](./eval.py) evaluates a trained Conformer RNN-T model with TCPGen on LibriSpeech test-clean.
Additional decoding options:
- `--biasing-list` should be [`all_rare_words.txt`](blists/all_rare_words.txt) for LibriSpeech experiments
- `--droprate` should normally be 0, since the reference biasing words are assumed to be included at inference time
- `--maxsize` is the size of the biasing list used for decoding; 1000 was used in the paper
Sample SLURM command:
```bash
srun --cpus-per-task=16 --gpus-per-node=1 -N 1 --ntasks-per-node=1 python eval.py --checkpoint-path <Path_to_model_checkpoint> --librispeech-path <Path_to_librispeech_data> --sp-model-path ./spm_unigram_600_100suffix.model --expdir <Path_to_exp> --use-cuda --biasing --biasing-list ./blists/all_rare_words.txt --droprate 0.0 --maxsize 1000
```
### Scoring
First install SCTK, the NIST Scoring Toolkit, following [https://github.com/usnistgov/SCTK/blob/master/README.md](https://github.com/usnistgov/SCTK/blob/master/README.md).
Example scoring script using sclite is in [`score.sh`](./score.sh).
```bash
./score.sh <path_to_decoding_dir>
```
Note that this will generate a file named `result.wrd.txt`, which is in the format consumed by the following script. To calculate the rare word error rate:
```bash
cd error_analysis
python get_error_word_count.py <path_to_result.wrd.txt>
```
Note that the `word_freq.txt` file contains word frequencies for train-clean-100 only. For the full 960-hour set it should be recomputed, though in this case that only slightly affects the OOV word error rate calculation.
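A minimal sketch for regenerating `word_freq.txt` for a different training set (the `word freq` per-line format matches what `get_error_word_count.py` parses; the dataset path is a placeholder):
```python
from collections import Counter
from pathlib import Path

counter = Counter()
for trans in Path("<Path_to_librispeech_data>/LibriSpeech/train-clean-100").glob("*/*/*.trans.txt"):
    with open(trans) as f:
        for line in f:
            # Each line is "<utterance-id> <transcript>".
            counter.update(line.strip().split(" ", 1)[1].split())

with open("word_freq.txt", "w") as fout:
    for word, freq in counter.most_common():
        fout.write(f"{word} {freq}\n")
```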
The table below contains WER results for the test-clean set using clean-100 training data. R-WER is the rare word error rate, computed over words in the biasing list.
| | WER | R-WER |
|:-------------------:|-------------:|-----------:|
| test-clean | 0.0836 | 0.2366 |
This is the default directory where rare word list files should be found.
To train or evaluate a model, please download the following files, and save them here.
- [`rareword_f15.txt`](https://download.pytorch.org/torchaudio/pipeline-assets/tcpgen/rareword_f15.txt)
- [`rareword_f30.txt`](https://download.pytorch.org/torchaudio/pipeline-assets/tcpgen/rareword_f30.txt)
- [`all_rare_words.txt`](https://download.pytorch.org/torchaudio/pipeline-assets/tcpgen/all_rare_words.txt)
import os
import random
import torch
import torchaudio
from pytorch_lightning import LightningDataModule
def _batch_by_token_count(idx_target_lengths, max_tokens, batch_size=None):
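"""Group sample indices into batches whose total target length stays under ``max_tokens`` (and, if given, whose size stays under ``batch_size``)."""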
batches = []
current_batch = []
current_token_count = 0
for idx, target_length in idx_target_lengths:
if current_token_count + target_length > max_tokens or (batch_size and len(current_batch) == batch_size):
batches.append(current_batch)
current_batch = [idx]
current_token_count = target_length
else:
current_batch.append(idx)
current_token_count += target_length
if current_batch:
batches.append(current_batch)
return batches
def get_sample_lengths(librispeech_dataset):
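"""Return the transcript length (in characters) of each sample in the dataset, caching each chapter's transcript file on first read."""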
fileid_to_target_length = {}
def _target_length(fileid):
if fileid not in fileid_to_target_length:
speaker_id, chapter_id, _ = fileid.split("-")
file_text = speaker_id + "-" + chapter_id + librispeech_dataset._ext_txt
file_text = os.path.join(librispeech_dataset._path, speaker_id, chapter_id, file_text)
with open(file_text) as ft:
for line in ft:
fileid_text, transcript = line.strip().split(" ", 1)
fileid_to_target_length[fileid_text] = len(transcript)
return fileid_to_target_length[fileid]
return [_target_length(fileid) for fileid in librispeech_dataset._walker]
class CustomBucketDataset(torch.utils.data.Dataset):
def __init__(
self,
dataset,
lengths,
max_tokens,
num_buckets,
shuffle=False,
batch_size=None,
):
super().__init__()
assert len(dataset) == len(lengths)
self.dataset = dataset
max_length = max(lengths)
min_length = min(lengths)
assert max_tokens >= max_length
buckets = torch.linspace(min_length, max_length, num_buckets)
lengths = torch.tensor(lengths)
bucket_assignments = torch.bucketize(lengths, buckets)
idx_length_buckets = [(idx, length, bucket_assignments[idx]) for idx, length in enumerate(lengths)]
if shuffle:
idx_length_buckets = random.sample(idx_length_buckets, len(idx_length_buckets))
else:
idx_length_buckets = sorted(idx_length_buckets, key=lambda x: x[1], reverse=True)
sorted_idx_length_buckets = sorted(idx_length_buckets, key=lambda x: x[2])
self.batches = _batch_by_token_count(
[(idx, length) for idx, length, _ in sorted_idx_length_buckets],
max_tokens,
batch_size=batch_size,
)
def __getitem__(self, idx):
return [self.dataset[subidx] for subidx in self.batches[idx]]
def __len__(self):
return len(self.batches)
class TransformDataset(torch.utils.data.Dataset):
def __init__(self, dataset, transform_fn):
self.dataset = dataset
self.transform_fn = transform_fn
def __getitem__(self, idx):
return self.transform_fn(self.dataset[idx])
def __len__(self):
return len(self.dataset)
class LibriSpeechDataModule(LightningDataModule):
librispeech_cls = torchaudio.datasets.LibriSpeechBiasing
def __init__(
self,
*,
librispeech_path,
train_transform,
val_transform,
test_transform,
max_tokens=3200,
batch_size=16,
train_num_buckets=50,
train_shuffle=True,
num_workers=10,
subset=None,
fullbiasinglist=None,
):
super().__init__()
self.librispeech_path = librispeech_path
self.train_dataset_lengths = None
self.val_dataset_lengths = None
self.train_transform = train_transform
self.val_transform = val_transform
self.test_transform = test_transform
self.max_tokens = max_tokens
self.batch_size = batch_size
self.train_num_buckets = train_num_buckets
self.train_shuffle = train_shuffle
self.num_workers = num_workers
if subset is not None and subset != "train-clean-100":
raise ValueError(f'subset must be ``None`` or `"train-clean-100"`. Found: {subset}')
self.subset = subset
self.fullbiasinglist = fullbiasinglist or []
def train_dataloader(self):
if self.subset is None:
datasets = [
self.librispeech_cls(self.librispeech_path, url="train-clean-360"),
self.librispeech_cls(self.librispeech_path, url="train-clean-100"),
self.librispeech_cls(self.librispeech_path, url="train-other-500"),
]
elif self.subset == "train-clean-100":
datasets = [self.librispeech_cls(self.librispeech_path, url="train-clean-100", blist=self.fullbiasinglist)]
if not self.train_dataset_lengths:
self.train_dataset_lengths = [get_sample_lengths(dataset) for dataset in datasets]
dataset = torch.utils.data.ConcatDataset(
[
CustomBucketDataset(
dataset,
lengths,
self.max_tokens,
self.train_num_buckets,
batch_size=self.batch_size,
)
for dataset, lengths in zip(datasets, self.train_dataset_lengths)
]
)
dataset = TransformDataset(dataset, self.train_transform)
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=self.num_workers,
batch_size=None,
shuffle=self.train_shuffle,
)
return dataloader
def val_dataloader(self):
datasets = [
self.librispeech_cls(self.librispeech_path, url="dev-clean", blist=self.fullbiasinglist),
self.librispeech_cls(self.librispeech_path, url="dev-other", blist=self.fullbiasinglist),
]
if not self.val_dataset_lengths:
self.val_dataset_lengths = [get_sample_lengths(dataset) for dataset in datasets]
dataset = torch.utils.data.ConcatDataset(
[
CustomBucketDataset(
dataset,
lengths,
self.max_tokens,
1,
batch_size=self.batch_size,
)
for dataset, lengths in zip(datasets, self.val_dataset_lengths)
]
)
dataset = TransformDataset(dataset, self.val_transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=None, num_workers=self.num_workers)
return dataloader
def test_dataloader(self):
dataset = self.librispeech_cls(self.librispeech_path, url="test-clean", blist=self.fullbiasinglist)
dataset = TransformDataset(dataset, self.test_transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=None)
return dataloader
import sys
error_words_freqs = {}
infile = sys.argv[1]
# setname = sys.argv[2]
insert_error = 0
insert_rare = 0
freqlist_test = {}
freqlist = {}
# TODO: Change to path to your word frequency file
with open("word_freq.txt") as fin:
for line in fin:
word, freq = line.split()
freqlist[word.upper()] = int(freq)
with open("../blists/all_rare_words.txt") as fin:
rareset = set()
for line in fin:
rareset.add(line.strip().upper())
project_set = set()
with open(infile) as fin:
lines = fin.readlines()
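# sclite alignment conventions (assumed here): correctly recognized words are
# lowercase, errors are uppercase, and "*" in the REF line marks insertions.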
for i, line in enumerate(lines):
if line.startswith("id:"):
project = line.strip(")\n").split("-")[-3:]
project = "-".join(project)
if "REF:" in line:
nextline = lines[i + 1].split()
for j, word in enumerate(line.split()):
if "*" in word:
insert_error += 1
if nextline[j].upper() in rareset:
insert_rare += 1
line = line.replace("*", "")
line.replace("%BCACK", "")
for word in line.split()[1:]:
if not word.startswith("("):
if word.upper() not in freqlist_test:
freqlist_test[word.upper()] = 1
else:
freqlist_test[word.upper()] += 1
if word != word.lower() and word.upper() in error_words_freqs:
error_words_freqs[word.upper()] += 1
elif word != word.lower() and word.upper() not in error_words_freqs:
error_words_freqs[word.upper()] = 1
elif word == word.lower() and word.upper() not in error_words_freqs:
if word == word.upper():
print("special token found in: {}".format(project))
error_words_freqs[word.upper()] = 0
elif word == word.upper():
print("special token found in: {}".format(project))
print(len(error_words_freqs.keys()))
print(insert_rare)
commonwords = []
rarewords = []
oovwords = []
common_freq = 0
rare_freq = 0
oov_freq = 0
common_error = 0
rare_error = 0
oov_error = 0
partial_error = 0
partial_freq = 0
very_common_error = 0
very_common_words = 0
words_error_freq = {}
words_total_freq = {}
low_freq_error = 0
low_freq_total = 0
for word, error in error_words_freqs.items():
if word in rareset:
rarewords.append(word)
rare_freq += freqlist_test[word]
rare_error += error
elif word not in freqlist:
oovwords.append(word)
oov_freq += freqlist_test[word] if word in freqlist_test else 1
oov_error += error
else:
commonwords.append(word)
common_freq += freqlist_test[word]
common_error += error
total_words = common_freq + rare_freq + oov_freq
insert_common = insert_error - insert_rare
total_errors = common_error + rare_error + oov_error + insert_error
WER = total_errors / total_words
print("=" * 89)
print(
"Common words error freq: {} / {} = {}".format(
common_error + insert_common, common_freq, (common_error + insert_common) / common_freq
)
)
print(
"Rare words error freq: {} / {} = {}".format(
rare_error + insert_rare, rare_freq, (rare_error + insert_rare) / rare_freq
)
)
print("OOV words error freq: {} / {} = {}".format(oov_error, oov_freq, oov_error / max(oov_freq, 1)))
print("WER estimate: {} / {} = {}".format(total_errors, total_words, WER))
print(
"Insert error: {} / {} = {}".format(
insert_error - insert_rare, total_words, (insert_error - insert_rare) / total_words
)
)
print("Insertion + OOV error {}".format((insert_error + oov_error - insert_rare) / total_words))
print("=" * 89)
import logging
import os
import pathlib
from argparse import ArgumentParser
import torch
import torchaudio
from lightning import ConformerRNNTModule
from transforms import get_data_module
logger = logging.getLogger()
def compute_word_level_distance(seq1, seq2):
return torchaudio.functional.edit_distance(seq1.lower().split(), seq2.lower().split())
def run_eval(args):
model = ConformerRNNTModule.load_from_checkpoint(
args.checkpoint_path, sp_model=str(args.sp_model_path), biasing=args.biasing
).eval()
data_module = get_data_module(
str(args.librispeech_path),
str(args.global_stats_path),
str(args.sp_model_path),
biasinglist=args.biasing_list,
droprate=args.droprate,
maxsize=args.maxsize,
)
if args.use_cuda:
model = model.to(device="cuda")
total_edit_distance = 0
total_length = 0
dataloader = data_module.test_dataloader()
hypout = []
refout = []
with torch.no_grad():
for idx, (batch, sample) in enumerate(dataloader):
filename = "librispeech_clean_100_{}".format(idx)
actual = sample[0][2]
predicted = model(batch)
hypout.append("{} ({})\n".format(predicted.upper().strip(), filename))
refout.append("{} ({})\n".format(actual.upper().strip(), filename))
total_edit_distance += compute_word_level_distance(actual, predicted)
total_length += len(actual.split())
if idx % 100 == 0:
logger.warning(f"Processed elem {idx}; WER: {total_edit_distance / total_length}")
logger.warning(f"Final WER: {total_edit_distance / total_length}")
with open(os.path.join(args.expdir, "hyp.trn.txt"), "w") as fout:
fout.writelines(hypout)
with open(os.path.join(args.expdir, "ref.trn.txt"), "w") as fout:
fout.writelines(refout)
def cli_main():
parser = ArgumentParser()
parser.add_argument(
"--checkpoint-path",
type=pathlib.Path,
help="Path to checkpoint to use for evaluation.",
required=True,
)
parser.add_argument(
"--global-stats-path",
default=pathlib.Path("global_stats_100.json"),
type=pathlib.Path,
help="Path to JSON file containing feature means and stddevs.",
)
parser.add_argument(
"--librispeech-path",
type=pathlib.Path,
help="Path to LibriSpeech datasets.",
required=True,
)
parser.add_argument(
"--sp-model-path",
type=pathlib.Path,
help="Path to SentencePiece model.",
required=True,
)
parser.add_argument(
"--expdir",
type=pathlib.Path,
help="Output path.",
required=True,
)
parser.add_argument(
"--use-cuda",
action="store_true",
default=False,
help="Run using CUDA.",
)
parser.add_argument(
"--biasing-list",
type=str,
default="",
help="Path to the biasing list used for inference.",
)
parser.add_argument(
"--droprate",
type=float,
default=0.0,
help="biasing list true entry drop rate",
)
parser.add_argument(
"--maxsize",
type=int,
default=0,
help="biasing list size",
)
parser.add_argument(
"--biasing",
action="store_true",
help="Use biasing",
)
args = parser.parse_args()
run_eval(args)
if __name__ == "__main__":
cli_main()
{
"mean": [
15.637551307678223,
16.923139572143555,
16.822391510009766,
16.71995735168457,
17.798818588256836,
17.773509979248047,
17.83729362487793,
18.358478546142578,
17.9212646484375,
17.89328956604004,
17.39158821105957,
17.29935646057129,
17.368602752685547,
17.506956100463867,
17.485977172851562,
17.350055694580078,
17.144203186035156,
16.917232513427734,
16.664018630981445,
16.391685485839844,
16.14568328857422,
15.940634727478027,
15.770298957824707,
15.61804485321045,
15.464021682739258,
15.357192039489746,
15.381829261779785,
15.079191207885742,
15.000809669494629,
15.170866012573242,
14.815556526184082,
14.997357368469238,
14.849116325378418,
15.036391258239746,
14.997495651245117,
15.095179557800293,
15.101740837097168,
15.14519214630127,
15.180743217468262,
15.156403541564941,
15.10532283782959,
15.052924156188965,
15.000785827636719,
14.922870635986328,
14.845956802368164,
14.808100700378418,
14.871763229370117,
14.811545372009277,
14.901408195495605,
14.831536293029785,
14.855134010314941,
14.778738975524902,
14.771122932434082,
14.732138633728027,
14.647477149963379,
14.561445236206055,
14.481091499328613,
14.422819137573242,
14.360650062561035,
14.320985794067383,
14.269508361816406,
14.194745063781738,
14.015007972717285,
13.864882469177246,
13.71847915649414,
13.572352409362793,
13.416520118713379,
13.26490306854248,
13.141557693481445,
13.059102058410645,
12.958319664001465,
12.893214225769043,
12.862544059753418,
12.809120178222656,
12.764240264892578,
12.69538688659668,
12.663688659667969,
12.577835083007812,
12.516019821166992,
12.431295394897461
],
"invstddev": [
0.2907441258430481,
0.2902139723300934,
0.2692812979221344,
0.26928815245628357,
0.23503832519054413,
0.2255384773015976,
0.21647363901138306,
0.20862863957881927,
0.2129201591014862,
0.21450363099575043,
0.21588537096977234,
0.21638962626457214,
0.21542198956012726,
0.2138291299343109,
0.21290543675422668,
0.21341055631637573,
0.2152717262506485,
0.21785758435726166,
0.22014079988002777,
0.22229592502117157,
0.2258182317018509,
0.229590505361557,
0.23325258493423462,
0.23702110350131989,
0.24055296182632446,
0.24432097375392914,
0.2479754239320755,
0.25079545378685,
0.25225189328193665,
0.2529183626174927,
0.2534293234348297,
0.25444915890693665,
0.25508981943130493,
0.26440101861953735,
0.27133750915527344,
0.2644997835159302,
0.257159560918808,
0.2571694552898407,
0.25754109025001526,
0.2585192620754242,
0.2599046528339386,
0.26100826263427734,
0.2632879614830017,
0.2646331489086151,
0.26438355445861816,
0.263375461101532,
0.2620519995689392,
0.259790301322937,
0.25840267539024353,
0.2579191029071808,
0.25915762782096863,
0.2594946622848511,
0.260376900434494,
0.2614041864871979,
0.262315034866333,
0.26391837000846863,
0.26464715600013733,
0.2644098997116089,
0.26398345828056335,
0.2636142671108246,
0.2644863724708557,
0.26914820075035095,
0.2653448283672333,
0.2650754153728485,
0.2656823992729187,
0.2668116092681885,
0.2679266929626465,
0.2683681547641754,
0.26889851689338684,
0.2706693708896637,
0.27155616879463196,
0.2738206088542938,
0.27583277225494385,
0.2766948342323303,
0.27835148572921753,
0.2804813086986542,
0.2821349799633026,
0.2801983952522278,
0.2796038091182709,
0.27942603826522827
]
}
import logging
import math
from collections import namedtuple
from typing import List, Tuple
import sentencepiece as spm
import torch
import torchaudio
from pytorch_lightning import LightningModule
from torchaudio.prototype.models import conformer_rnnt_biasing_base, Hypothesis, RNNTBeamSearchBiasing
logger = logging.getLogger()
_expected_spm_vocab_size = 600
Batch = namedtuple("Batch", ["features", "feature_lengths", "targets", "target_lengths", "tries"])
class WarmupLR(torch.optim.lr_scheduler._LRScheduler):
r"""Learning rate scheduler that performs linear warmup and exponential annealing.
Args:
optimizer (torch.optim.Optimizer): optimizer to use.
warmup_steps (int): number of scheduler steps for which to warm up learning rate.
force_anneal_step (int): scheduler step at which annealing of learning rate begins.
anneal_factor (float): factor to scale base learning rate by at each annealing step.
last_epoch (int, optional): The index of last epoch. (Default: -1)
verbose (bool, optional): If ``True``, prints a message to stdout for
each update. (Default: ``False``)
"""
def __init__(
self,
optimizer: torch.optim.Optimizer,
warmup_steps: int,
force_anneal_step: int,
anneal_factor: float,
last_epoch=-1,
verbose=False,
):
self.warmup_steps = warmup_steps
self.force_anneal_step = force_anneal_step
self.anneal_factor = anneal_factor
super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose)
def get_lr(self):
if self._step_count < self.force_anneal_step:
return [(min(1.0, self._step_count / self.warmup_steps)) * base_lr for base_lr in self.base_lrs]
else:
scaling_factor = self.anneal_factor ** (self._step_count - self.force_anneal_step)
return [scaling_factor * base_lr for base_lr in self.base_lrs]
def post_process_hypos(
hypos: List[Hypothesis], sp_model: spm.SentencePieceProcessor
) -> List[Tuple[str, List[float], List[int]]]:
tokens_idx = 0
score_idx = 3
post_process_remove_list = [
sp_model.unk_id(),
sp_model.eos_id(),
sp_model.pad_id(),
]
filtered_hypo_tokens = [
[token_index for token_index in h[tokens_idx][1:] if token_index not in post_process_remove_list] for h in hypos
]
hypos_str = [sp_model.decode(s) for s in filtered_hypo_tokens]
hypos_ids = [h[tokens_idx][1:] for h in hypos]
hypos_score = [[math.exp(h[score_idx])] for h in hypos]
nbest_batch = list(zip(hypos_str, hypos_score, hypos_ids))
return nbest_batch
class ConformerRNNTModule(LightningModule):
def __init__(self, sp_model, biasing=False):
super().__init__()
self.sp_model = sp_model
self.sp_model = spm.SentencePieceProcessor(model_file=self.sp_model)
spm_vocab_size = self.sp_model.get_piece_size()
self.char_list = [self.sp_model.id_to_piece(idx) for idx in range(spm_vocab_size)]
assert spm_vocab_size == _expected_spm_vocab_size, (
"The model returned by conformer_rnnt_base expects a SentencePiece model of "
f"vocabulary size {_expected_spm_vocab_size}, but the given SentencePiece model has a vocabulary size "
f"of {spm_vocab_size}. Please provide a correctly configured SentencePiece model."
)
self.blank_idx = spm_vocab_size
self.char_list.append("<blank>")
# ``conformer_rnnt_biasing_base`` hardcodes a specific Conformer RNN-T configuration.
# For greater customizability, please refer to ``conformer_rnnt_biasing``.
self.biasing = biasing
self.model = conformer_rnnt_biasing_base(charlist=self.char_list, biasing=self.biasing)
self.loss = torchaudio.transforms.RNNTLoss(reduction="sum", fused_log_softmax=False)
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=8e-4, betas=(0.9, 0.98), eps=1e-9)
# This scheduler is tuned for train-clean-100 with 90 epochs; change it when training longer
self.warmup_lr_scheduler = WarmupLR(self.optimizer, 35, 60, 0.92)
# The epoch from which the TCPGen starts to train
self.tcpsche = self.model.tcpsche
self.automatic_optimization = False
def _step(self, batch, _, step_type):
if batch is None:
return None
prepended_targets = batch.targets.new_empty([batch.targets.size(0), batch.targets.size(1) + 1])
prepended_targets[:, 1:] = batch.targets
prepended_targets[:, 0] = self.blank_idx
prepended_target_lengths = batch.target_lengths + 1
output, src_lengths, _, _, tcpgen_dist, p_gen = self.model(
batch.features,
batch.feature_lengths,
prepended_targets,
prepended_target_lengths,
batch.tries,
self.current_epoch,
)
if self.biasing and self.current_epoch >= self.tcpsche and p_gen is not None:
# Assuming blank is the last token
model_output = torch.softmax(output, dim=-1)
p_not_null = 1.0 - model_output[:, :, :, -1:]
# Exclude blank prob
ptr_dist_fact = torch.cat([tcpgen_dist[:, :, :, :-2], tcpgen_dist[:, :, :, -1:]], dim=-1) * p_not_null
ptr_gen_complement = tcpgen_dist[:, :, :, -1:] * p_gen
# Interpolate between the TCPGen distribution and the model distribution
p_partial = ptr_dist_fact[:, :, :, :-1] * p_gen + model_output[:, :, :, :-1] * (
1 - p_gen + ptr_gen_complement
)
# Add blank back
p_final = torch.cat([p_partial, model_output[:, :, :, -1:]], dim=-1)
# Add a small epsilon for numerical stability (the ESPnet implementation did not need this)
logsmax_output = torch.log(p_final + 1e-12)
else:
logsmax_output = torch.log_softmax(output, dim=-1)
loss = self.loss(logsmax_output, batch.targets, src_lengths, batch.target_lengths)
self.log(f"Losses/{step_type}_loss", loss, on_step=True, on_epoch=True, batch_size=batch.targets.size(0))
return loss
def configure_optimizers(self):
return (
[self.optimizer],
[{"scheduler": self.warmup_lr_scheduler, "interval": "epoch"}],
)
def forward(self, batch: Batch):
decoder = RNNTBeamSearchBiasing(self.model, self.blank_idx, trie=batch.tries, biasing=self.biasing)
hypotheses = decoder(batch.features.to(self.device), batch.feature_lengths.to(self.device), 10)
return post_process_hypos(hypotheses, self.sp_model)[0][0]
def training_step(self, batch: Batch, batch_idx):
"""Custom training step.
By default, DDP does the following on each train step:
- For each GPU, compute loss and gradient on shard of training data.
- Sync and average gradients across all GPUs. The final gradient
is (sum of gradients across all GPUs) / N, where N is the world
size (total number of GPUs).
- Update parameters on each GPU.
Here, we do the following:
- For k-th GPU, compute loss and scale it by (N / B_total), where B_total is
the sum of batch sizes across all GPUs. Compute gradient from scaled loss.
- Sync and average gradients across all GPUs. The final gradient
is (sum of gradients across all GPUs) / B_total.
- Update parameters on each GPU.
Doing so allows us to account for the variability in batch sizes that
variable-length sequential data commonly yields.
"""
opt = self.optimizers()
opt.zero_grad()
loss = self._step(batch, batch_idx, "train")
batch_size = batch.features.size(0)
batch_sizes = self.all_gather(batch_size)
self.log(
"Gathered batch size",
batch_sizes.sum(),
on_step=True,
on_epoch=True,
batch_size=batch.targets.size(0),
)
loss *= batch_sizes.size(0) / batch_sizes.sum() # world size / batch size
self.manual_backward(loss)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 10.0)
opt.step()
# step every epoch
sch = self.lr_schedulers()
if self.trainer.is_last_batch:
sch.step()
return loss
def validation_step(self, batch, batch_idx):
return self._step(batch, batch_idx, "val")
def test_step(self, batch, batch_idx):
return self._step(batch, batch_idx, "test")
#!/usr/bin/env bash
if [ $# -ne 1 ]; then
echo "Usage: $0 <DECODING_DIR>"
exit 1
fi
dir=$1 # the path to the decoding dir, e.g. experiments/librispeech_clean100_suffix600_tcpgen500_sche30_nodrop/decode_test_clean_b10_KB1000/
sclite -r "${dir}/ref.trn.txt" trn -h "${dir}/hyp.trn.txt" trn -i rm -o all stdout > "${dir}/result.wrd.txt"
import os
import pathlib
from argparse import ArgumentParser
from lightning import ConformerRNNTModule
from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.strategies import DDPStrategy
from transforms import get_data_module
def run_train(args):
seed_everything(1)
checkpoint_dir = args.exp_dir / "checkpoints"
checkpoint = ModelCheckpoint(
checkpoint_dir,
monitor="Losses/val_loss",
mode="min",
save_top_k=5,
save_weights_only=False,
verbose=True,
)
train_checkpoint = ModelCheckpoint(
checkpoint_dir,
monitor="Losses/train_loss",
mode="min",
save_top_k=5,
save_weights_only=False,
verbose=True,
)
lr_monitor = LearningRateMonitor(logging_interval="step")
callbacks = [
checkpoint,
train_checkpoint,
lr_monitor,
]
if os.path.exists(args.resume) and args.resume != "":
trainer = Trainer(
default_root_dir=args.exp_dir,
max_epochs=args.epochs,
num_nodes=args.nodes,
gpus=args.gpus,
accelerator="gpu",
strategy=DDPStrategy(find_unused_parameters=False),
callbacks=callbacks,
reload_dataloaders_every_n_epochs=1,
resume_from_checkpoint=args.resume,
)
else:
trainer = Trainer(
default_root_dir=args.exp_dir,
max_epochs=args.epochs,
num_nodes=args.nodes,
gpus=args.gpus,
accelerator="gpu",
strategy=DDPStrategy(find_unused_parameters=False),
callbacks=callbacks,
reload_dataloaders_every_n_epochs=1,
)
model = ConformerRNNTModule(str(args.sp_model_path), args.biasing)
data_module = get_data_module(
str(args.librispeech_path),
str(args.global_stats_path),
str(args.sp_model_path),
subset=args.subset,
biasinglist=args.biasing_list,
droprate=args.droprate,
maxsize=args.maxsize,
)
trainer.fit(model, data_module, ckpt_path=args.checkpoint_path)
def cli_main():
parser = ArgumentParser()
parser.add_argument(
"--checkpoint-path",
default=None,
type=pathlib.Path,
help="Path to checkpoint to use for evaluation.",
)
parser.add_argument(
"--exp-dir",
default=pathlib.Path("./exp"),
type=pathlib.Path,
help="Directory to save checkpoints and logs to. (Default: './exp')",
)
parser.add_argument(
"--global-stats-path",
default=pathlib.Path("global_stats_100.json"),
type=pathlib.Path,
help="Path to JSON file containing feature means and stddevs.",
)
parser.add_argument(
"--librispeech-path",
type=pathlib.Path,
help="Path to LibriSpeech datasets.",
required=True,
)
parser.add_argument(
"--sp-model-path",
type=pathlib.Path,
help="Path to SentencePiece model.",
required=True,
)
parser.add_argument(
"--nodes",
default=1,
type=int,
help="Number of nodes to use for training. (Default: 4)",
)
parser.add_argument(
"--gpus",
default=1,
type=int,
help="Number of GPUs per node to use for training. (Default: 8)",
)
parser.add_argument(
"--epochs",
default=120,
type=int,
help="Number of epochs to train for. (Default: 120)",
)
parser.add_argument(
"--subset",
default="train-clean-100",
type=str,
help="Train on subset of librispeech.",
)
parser.add_argument(
"--biasing",
action="store_true",
help="Use biasing",
)
parser.add_argument(
"--biasing-list",
type=pathlib.Path,
help="Path to the biasing list.",
required=True,
)
parser.add_argument("--maxsize", default=1000, type=int, help="Size of biasing lists")
parser.add_argument("--droprate", default=0.0, type=float, help="Biasing component regularisation drop rate")
parser.add_argument(
"--resume",
default="",
type=str,
help="Path to resume model.",
)
args = parser.parse_args()
run_train(args)
if __name__ == "__main__":
cli_main()
#!/usr/bin/env python3
"""Trains a SentencePiece model on transcripts across LibriSpeech train-clean-100, train-clean-360, and train-other-500.
Using unigram wordpiece model and suffix-based wordpieces
Example:
python train_spm.py --librispeech-path ./datasets
"""
import io
import pathlib
from argparse import ArgumentParser, RawTextHelpFormatter
import sentencepiece as spm
def get_transcript_text(transcript_path):
with open(transcript_path) as f:
return [line.strip().split(" ", 1)[1].lower() for line in f]
def get_transcripts(dataset_path):
transcript_paths = dataset_path.glob("*/*/*.trans.txt")
merged_transcripts = []
for path in transcript_paths:
merged_transcripts += get_transcript_text(path)
return merged_transcripts
def train_spm(input, suffix=False):
model_writer = io.BytesIO()
spm.SentencePieceTrainer.train(
sentence_iterator=iter(input),
model_writer=model_writer,
vocab_size=600,
model_type="unigram",
input_sentence_size=-1,
character_coverage=1.0,
treat_whitespace_as_suffix=suffix,
bos_id=0,
pad_id=1,
eos_id=2,
unk_id=3,
)
return model_writer.getvalue()
def parse_args():
default_output_path = "./spm_unigram_600_100suffix.model"
parser = ArgumentParser(description=__doc__, formatter_class=RawTextHelpFormatter)
parser.add_argument(
"--librispeech-path",
required=True,
type=pathlib.Path,
help="Path to LibriSpeech dataset.",
)
parser.add_argument(
"--output-file",
default=pathlib.Path(default_output_path),
type=pathlib.Path,
help=f"File to save model to. (Default: '{default_output_path}')",
)
parser.add_argument(
"--suffix",
action="store_true",
help="whether to use suffix-based wordpieces",
)
return parser.parse_args()
def run_cli():
args = parse_args()
root = args.librispeech_path / "LibriSpeech"
# Uncomment this to train the wordpiece model on the full 960-hour data
# splits = ["train-clean-100", "train-clean-360", "train-other-500"]
splits = ["train-clean-100"]
merged_transcripts = []
for split in splits:
path = pathlib.Path(root) / split
merged_transcripts += get_transcripts(path)
model = train_spm(merged_transcripts, suffix=args.suffix)
with open(args.output_file, "wb") as f:
f.write(model)
if __name__ == "__main__":
run_cli()
import json
import math
import random
from functools import partial
from typing import List
import sentencepiece as spm
import torch
import torchaudio
from data_module import LibriSpeechDataModule
from lightning import Batch
_decibel = 2 * 20 * math.log10(torch.iinfo(torch.int16).max)
_gain = pow(10, 0.05 * _decibel)
_spectrogram_transform = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_fft=400, n_mels=80, hop_length=160)
random.seed(999)
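# Note: _gain == (2 ** 15 - 1) ** 2, i.e. mel power values computed on
# [-1.0, 1.0]-normalized waveforms are rescaled to the int16 range before
# the piecewise-linear log below is applied.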
def _piecewise_linear_log(x):
x = x * _gain
x[x > math.e] = torch.log(x[x > math.e])
x[x <= math.e] = x[x <= math.e] / math.e
return x
class FunctionalModule(torch.nn.Module):
def __init__(self, functional):
super().__init__()
self.functional = functional
def forward(self, input):
return self.functional(input)
class GlobalStatsNormalization(torch.nn.Module):
def __init__(self, global_stats_path):
super().__init__()
with open(global_stats_path) as f:
blob = json.loads(f.read())
self.mean = torch.tensor(blob["mean"])
self.invstddev = torch.tensor(blob["invstddev"])
def forward(self, input):
return (input - self.mean) * self.invstddev
def _extract_labels(sp_model, samples: List):
targets = [sp_model.encode(sample[2].lower()) for sample in samples]
biasingwords = []
for sample in samples:
for word in sample[6]:
if word not in biasingwords:
biasingwords.append(word)
lengths = torch.tensor([len(elem) for elem in targets]).to(dtype=torch.int32)
targets = torch.nn.utils.rnn.pad_sequence(
[torch.tensor(elem) for elem in targets],
batch_first=True,
padding_value=1.0,
).to(dtype=torch.int32)
return targets, lengths, biasingwords
def _extract_features(data_pipeline, samples: List):
mel_features = [_spectrogram_transform(sample[0].squeeze()).transpose(1, 0) for sample in samples]
features = torch.nn.utils.rnn.pad_sequence(mel_features, batch_first=True)
features = data_pipeline(features)
lengths = torch.tensor([elem.shape[0] for elem in mel_features], dtype=torch.int32)
return features, lengths
def _extract_tries(sp_model, biasingwords, blist, droprate, maxsize):
if len(biasingwords) > 0 and droprate > 0:
biasingwords = random.sample(biasingwords, k=int(len(biasingwords) * (1 - droprate)))
if len(biasingwords) < maxsize:
distractors = random.sample(blist, k=max(0, maxsize - len(biasingwords)))
for word in distractors:
if word not in biasingwords:
biasingwords.append(word)
biasingwords = [sp_model.encode(word.lower()) for word in biasingwords]
biasingwords = sorted(biasingwords)
worddict = {tuple(word): i + 1 for i, word in enumerate(biasingwords)}
lextree = make_lexical_tree(worddict, -1)
return lextree, biasingwords
def make_lexical_tree(word_dict, word_unk):
"""Make a prefix tree"""
# node [dict(subword_id -> node), word_id, word_set[start-1, end]]
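# Example: worddict = {(5, 7): 1, (5, 9): 2} produces
# [{5: [{7: [{}, 1, (0, 1)], 9: [{}, 2, (1, 2)]}, -1, (0, 2)]}, -1, None]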
root = [{}, -1, None]
for w, wid in word_dict.items():
if wid > 0 and wid != word_unk:
succ = root[0]
for i, cid in enumerate(w):
if cid not in succ:
succ[cid] = [{}, -1, (wid - 1, wid)]
else:
prev = succ[cid][2]
succ[cid][2] = (min(prev[0], wid - 1), max(prev[1], wid))
if i == len(w) - 1:
succ[cid][1] = wid
succ = succ[cid][0]
return root
class TrainTransform:
def __init__(self, global_stats_path: str, sp_model_path: str, blist: List[str], droprate: float, maxsize: int):
self.sp_model = spm.SentencePieceProcessor(model_file=sp_model_path)
self.train_data_pipeline = torch.nn.Sequential(
FunctionalModule(_piecewise_linear_log),
GlobalStatsNormalization(global_stats_path),
FunctionalModule(partial(torch.transpose, dim0=1, dim1=2)),
torchaudio.transforms.FrequencyMasking(27),
torchaudio.transforms.FrequencyMasking(27),
torchaudio.transforms.TimeMasking(100, p=0.2),
torchaudio.transforms.TimeMasking(100, p=0.2),
FunctionalModule(partial(torch.transpose, dim0=1, dim1=2)),
)
self.blist = blist
self.droprate = droprate
self.maxsize = maxsize
def __call__(self, samples: List):
features, feature_lengths = _extract_features(self.train_data_pipeline, samples)
targets, target_lengths, biasingwords = _extract_labels(self.sp_model, samples)
tries, biasingwords = _extract_tries(self.sp_model, biasingwords, self.blist, self.droprate, self.maxsize)
return Batch(features, feature_lengths, targets, target_lengths, tries)
class ValTransform:
def __init__(self, global_stats_path: str, sp_model_path: str, blist: List[str], droprate: float, maxsize: int):
self.sp_model = spm.SentencePieceProcessor(model_file=sp_model_path)
self.valid_data_pipeline = torch.nn.Sequential(
FunctionalModule(_piecewise_linear_log),
GlobalStatsNormalization(global_stats_path),
)
self.blist = blist
self.droprate = droprate
self.maxsize = maxsize
def __call__(self, samples: List):
features, feature_lengths = _extract_features(self.valid_data_pipeline, samples)
targets, target_lengths, biasingwords = _extract_labels(self.sp_model, samples)
if self.blist:
tries, biasingwords = _extract_tries(self.sp_model, biasingwords, self.blist, self.droprate, self.maxsize)
else:
tries = []
return Batch(features, feature_lengths, targets, target_lengths, tries)
class TestTransform:
def __init__(self, global_stats_path: str, sp_model_path: str, blist: List[str], droprate: float, maxsize: int):
self.val_transforms = ValTransform(global_stats_path, sp_model_path, blist, droprate, maxsize)
def __call__(self, sample):
return self.val_transforms([sample]), [sample]
def get_data_module(
librispeech_path, global_stats_path, sp_model_path, subset=None, biasinglist=None, droprate=0.0, maxsize=1000
):
fullbiasinglist = []
if biasinglist:
with open(biasinglist) as fin:
fullbiasinglist = [line.strip() for line in fin]
train_transform = TrainTransform(
global_stats_path=global_stats_path,
sp_model_path=sp_model_path,
blist=fullbiasinglist,
droprate=droprate,
maxsize=maxsize,
)
val_transform = ValTransform(
global_stats_path=global_stats_path,
sp_model_path=sp_model_path,
blist=fullbiasinglist,
droprate=droprate,
maxsize=maxsize,
)
test_transform = TestTransform(
global_stats_path=global_stats_path,
sp_model_path=sp_model_path,
blist=fullbiasinglist,
droprate=droprate,
maxsize=maxsize,
)
return LibriSpeechDataModule(
librispeech_path=librispeech_path,
train_transform=train_transform,
val_transform=val_transform,
test_transform=test_transform,
subset=subset,
fullbiasinglist=fullbiasinglist,
)
@@ -8,6 +8,7 @@ from .iemocap import IEMOCAP
 from .librilight_limited import LibriLightLimited
 from .librimix import LibriMix
 from .librispeech import LIBRISPEECH
+from .librispeech_biasing import LibriSpeechBiasing
 from .libritts import LIBRITTS
 from .ljspeech import LJSPEECH
 from .musdb_hq import MUSDB_HQ
@@ -23,6 +24,7 @@ from .yesno import YESNO
 __all__ = [
     "COMMONVOICE",
     "LIBRISPEECH",
+    "LibriSpeechBiasing",
     "LibriLightLimited",
     "SPEECHCOMMANDS",
     "VCTK_092",
import os
from pathlib import Path
from typing import List, Optional, Tuple, Union
from torch import Tensor
from torch.hub import download_url_to_file
from torch.utils.data import Dataset
from torchaudio.datasets.utils import _extract_tar, _load_waveform
URL = "train-clean-100"
FOLDER_IN_ARCHIVE = "LibriSpeech"
SAMPLE_RATE = 16000
_DATA_SUBSETS = [
"dev-clean",
"dev-other",
"test-clean",
"test-other",
"train-clean-100",
"train-clean-360",
"train-other-500",
]
_CHECKSUMS = {
"http://www.openslr.org/resources/12/dev-clean.tar.gz": "76f87d090650617fca0cac8f88b9416e0ebf80350acb97b343a85fa903728ab3", # noqa: E501
"http://www.openslr.org/resources/12/dev-other.tar.gz": "12661c48e8c3fe1de2c1caa4c3e135193bfb1811584f11f569dd12645aa84365", # noqa: E501
"http://www.openslr.org/resources/12/test-clean.tar.gz": "39fde525e59672dc6d1551919b1478f724438a95aa55f874b576be21967e6c23", # noqa: E501
"http://www.openslr.org/resources/12/test-other.tar.gz": "d09c181bba5cf717b3dee7d4d592af11a3ee3a09e08ae025c5506f6ebe961c29", # noqa: E501
"http://www.openslr.org/resources/12/train-clean-100.tar.gz": "d4ddd1d5a6ab303066f14971d768ee43278a5f2a0aa43dc716b0e64ecbbbf6e2", # noqa: E501
"http://www.openslr.org/resources/12/train-clean-360.tar.gz": "146a56496217e96c14334a160df97fffedd6e0a04e66b9c5af0d40be3c792ecf", # noqa: E501
"http://www.openslr.org/resources/12/train-other-500.tar.gz": "ddb22f27f96ec163645d53215559df6aa36515f26e01dd70798188350adcb6d2", # noqa: E501
}
def _download_librispeech(root, url):
base_url = "http://www.openslr.org/resources/12/"
ext_archive = ".tar.gz"
filename = url + ext_archive
archive = os.path.join(root, filename)
download_url = os.path.join(base_url, filename)
if not os.path.isfile(archive):
checksum = _CHECKSUMS.get(download_url, None)
download_url_to_file(download_url, archive, hash_prefix=checksum)
_extract_tar(archive)
def _get_librispeech_metadata(
fileid: str, root: str, folder: str, ext_audio: str, ext_txt: str, blist: Optional[List[str]]
) -> Tuple[str, int, str, int, int, int, List[str]]:
blist = blist or []
speaker_id, chapter_id, utterance_id = fileid.split("-")
# Get audio path and sample rate
fileid_audio = f"{speaker_id}-{chapter_id}-{utterance_id}"
filepath = os.path.join(folder, speaker_id, chapter_id, f"{fileid_audio}{ext_audio}")
# Load text
file_text = f"{speaker_id}-{chapter_id}{ext_txt}"
file_text = os.path.join(root, folder, speaker_id, chapter_id, file_text)
uttblist = []
with open(file_text) as ft:
for line in ft:
fileid_text, transcript = line.strip().split(" ", 1)
if fileid_audio == fileid_text:
# get utterance biasing list
for word in transcript.split():
if word in blist and word not in uttblist:
uttblist.append(word)
break
else:
# Transcript not found
raise FileNotFoundError(f"Transcript not found for {fileid_audio}")
return (
filepath,
SAMPLE_RATE,
transcript,
int(speaker_id),
int(chapter_id),
int(utterance_id),
uttblist,
)
class LibriSpeechBiasing(Dataset):
"""*LibriSpeech* :cite:`7178964` dataset with prefix-tree construction and biasing support.
Args:
root (str or Path): Path to the directory where the dataset is found or downloaded.
url (str, optional): The URL to download the dataset from,
or the type of the dataset to download.
Allowed type values are ``"dev-clean"``, ``"dev-other"``, ``"test-clean"``,
``"test-other"``, ``"train-clean-100"``, ``"train-clean-360"`` and
``"train-other-500"``. (default: ``"train-clean-100"``)
folder_in_archive (str, optional):
The top-level directory of the dataset. (default: ``"LibriSpeech"``)
download (bool, optional):
Whether to download the dataset if it is not found at root path. (default: ``False``).
blist (list, optional):
The list of biasing words (default: ``[]``).
"""
_ext_txt = ".trans.txt"
_ext_audio = ".flac"
def __init__(
self,
root: Union[str, Path],
url: str = URL,
folder_in_archive: str = FOLDER_IN_ARCHIVE,
download: bool = False,
blist: Optional[List[str]] = None,
) -> None:
self._url = url
if url not in _DATA_SUBSETS:
raise ValueError(f"Invalid url '{url}' given; please provide one of {_DATA_SUBSETS}.")
root = os.fspath(root)
self._archive = os.path.join(root, folder_in_archive)
self._path = os.path.join(root, folder_in_archive, url)
if not os.path.isdir(self._path):
if download:
_download_librispeech(root, url)
else:
raise RuntimeError(
f"Dataset not found at {self._path}. Please set `download=True` to download the dataset."
)
self._walker = sorted(str(p.stem) for p in Path(self._path).glob("*/*/*" + self._ext_audio))
self.blist = blist
def get_metadata(self, n: int) -> Tuple[str, int, str, int, int, int, List[str]]:
"""Get metadata for the n-th sample from the dataset. Returns filepath instead of waveform,
but otherwise returns the same fields as :py:func:`__getitem__`.
Args:
n (int): The index of the sample to be loaded
Returns:
Tuple of the following items;
str:
Path to audio
int:
Sample rate
str:
Transcript
int:
Speaker ID
int:
Chapter ID
int:
Utterance ID
list:
List of biasing words in the utterance
"""
fileid = self._walker[n]
return _get_librispeech_metadata(fileid, self._archive, self._url, self._ext_audio, self._ext_txt, self.blist)
def __getitem__(self, n: int) -> Tuple[Tensor, int, str, int, int, int, List[str]]:
"""Load the n-th sample from the dataset.
Args:
n (int): The index of the sample to be loaded
Returns:
Tuple of the following items;
Tensor:
Waveform
int:
Sample rate
str:
Transcript
int:
Speaker ID
int:
Chapter ID
int:
Utterance ID
list:
List of biasing words in the utterance
"""
metadata = self.get_metadata(n)
waveform = _load_waveform(self._archive, metadata[0], metadata[1])
return (waveform,) + metadata[1:]
def __len__(self) -> int:
return len(self._walker)
@@ -9,12 +9,16 @@ from ._conformer_wav2vec2 import (
 from ._emformer_hubert import emformer_hubert_base, emformer_hubert_model
 from .conv_emformer import ConvEmformer
 from .hifi_gan import hifigan_vocoder, hifigan_vocoder_v1, hifigan_vocoder_v2, hifigan_vocoder_v3, HiFiGANVocoder
-from .rnnt import conformer_rnnt_base, conformer_rnnt_model
+from .rnnt import conformer_rnnt_base, conformer_rnnt_biasing, conformer_rnnt_biasing_base, conformer_rnnt_model
+from .rnnt_decoder import Hypothesis, RNNTBeamSearchBiasing
 from .squim import squim_objective_base, squim_objective_model, SquimObjective
 __all__ = [
     "conformer_rnnt_base",
     "conformer_rnnt_model",
+    "conformer_rnnt_biasing",
+    "conformer_rnnt_biasing_base",
+    "conv_tasnet_base",
     "ConvEmformer",
     "conformer_wav2vec2_model",
     "conformer_wav2vec2_base",
@@ -24,6 +28,8 @@ __all__ = [
     "ConformerWav2Vec2PretrainModel",
     "emformer_hubert_base",
     "emformer_hubert_model",
+    "Hypothesis",
+    "RNNTBeamSearchBiasing",
     "HiFiGANVocoder",
     "hifigan_vocoder_v1",
     "hifigan_vocoder_v2",
from typing import Callable, Dict, List, Optional, Tuple
import torch
from torchaudio.models import RNNT
from torchaudio.prototype.models.rnnt import TrieNode
__all__ = ["Hypothesis", "RNNTBeamSearchBiasing"]
Hypothesis = Tuple[List[int], torch.Tensor, List[List[torch.Tensor]], float, list]
Hypothesis.__doc__ = """Hypothesis generated by the RNN-T beam search decoder,
represented as a tuple of (tokens, prediction network output, prediction network state, score, biasing trie).
"""
def _get_hypo_tokens(hypo: Hypothesis) -> List[int]:
return hypo[0]
def _get_hypo_predictor_out(hypo: Hypothesis) -> torch.Tensor:
return hypo[1]
def _get_hypo_state(hypo: Hypothesis) -> List[List[torch.Tensor]]:
return hypo[2]
def _get_hypo_score(hypo: Hypothesis) -> float:
return hypo[3]
def _get_hypo_trie(hypo: Hypothesis) -> TrieNode:
return hypo[4]
def _set_hypo_trie(hypo: Hypothesis, trie: TrieNode) -> None:
hypo[4] = trie
def _get_hypo_key(hypo: Hypothesis) -> str:
return str(hypo[0])
def _batch_state(hypos: List[Hypothesis]) -> List[List[torch.Tensor]]:
states: List[List[torch.Tensor]] = []
for i in range(len(_get_hypo_state(hypos[0]))):
batched_state_components: List[torch.Tensor] = []
for j in range(len(_get_hypo_state(hypos[0])[i])):
batched_state_components.append(torch.cat([_get_hypo_state(hypo)[i][j] for hypo in hypos]))
states.append(batched_state_components)
return states
def _slice_state(states: List[List[torch.Tensor]], idx: int, device: torch.device) -> List[List[torch.Tensor]]:
idx_tensor = torch.tensor([idx], device=device)
return [[state.index_select(0, idx_tensor) for state in state_tuple] for state_tuple in states]
def _default_hypo_sort_key(hypo: Hypothesis) -> float:
return _get_hypo_score(hypo) / (len(_get_hypo_tokens(hypo)) + 1)
def _compute_updated_scores(
hypos: List[Hypothesis],
next_token_probs: torch.Tensor,
beam_width: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
hypo_scores = torch.tensor([_get_hypo_score(h) for h in hypos]).unsqueeze(1)
nonblank_scores = hypo_scores + next_token_probs[:, :-1] # [beam_width, num_tokens - 1]
nonblank_nbest_scores, nonblank_nbest_idx = nonblank_scores.reshape(-1).topk(beam_width)
nonblank_nbest_hypo_idx = nonblank_nbest_idx.div(nonblank_scores.shape[1], rounding_mode="trunc")
nonblank_nbest_token = nonblank_nbest_idx % nonblank_scores.shape[1]
return nonblank_nbest_scores, nonblank_nbest_hypo_idx, nonblank_nbest_token
def _remove_hypo(hypo: Hypothesis, hypo_list: List[Hypothesis]) -> None:
for i, elem in enumerate(hypo_list):
if _get_hypo_key(hypo) == _get_hypo_key(elem):
del hypo_list[i]
break
class RNNTBeamSearchBiasing(torch.nn.Module):
r"""Beam search decoder for RNN-T model with biasing support.
Args:
model (RNNT): RNN-T model to use.
blank (int): index of blank token in vocabulary.
temperature (float, optional): temperature to apply to joint network output.
Larger values yield more uniform samples. (Default: 1.0)
hypo_sort_key (Callable[[Hypothesis], float] or None, optional): callable that computes a score
for a given hypothesis to rank hypotheses by. If ``None``, defaults to callable that returns
hypothesis score normalized by token sequence length. (Default: None)
step_max_tokens (int, optional): maximum number of tokens to emit per input time step. (Default: 100)
trie (TrieNode, optional): the prefix tree for TCPGen biasing. (Default: ``None``)
biasing (bool, optional): if ``True``, use TCPGen biasing during decoding; otherwise run standard RNN-T beam search. (Default: ``False``)
"""
def __init__(
self,
model: RNNT,
blank: int,
temperature: float = 1.0,
hypo_sort_key: Optional[Callable[[Hypothesis], float]] = None,
step_max_tokens: int = 100,
trie: TrieNode = None,
biasing: bool = False,
) -> None:
super().__init__()
self.model = model
self.blank = blank
self.temperature = temperature
self.resettrie = trie or []
self.dobiasing = biasing
if hypo_sort_key is None:
self.hypo_sort_key = _default_hypo_sort_key
else:
self.hypo_sort_key = hypo_sort_key
self.step_max_tokens = step_max_tokens
def _init_b_hypos(self, hypo: Optional[Hypothesis], device: torch.device) -> List[Hypothesis]:
if hypo is not None:
token = _get_hypo_tokens(hypo)[-1]
state = _get_hypo_state(hypo)
else:
token = self.blank
state = None
one_tensor = torch.tensor([1], device=device)
pred_out, _, pred_state = self.model.predict(torch.tensor([[token]], device=device), one_tensor, state)
init_hypo = ([token], pred_out[0].detach(), pred_state, 0.0, self.resettrie)
return [init_hypo]
def _get_trie_mask(self, trie):
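# Wordpieces that are valid continuations at the current trie node get mask
# value 0 (kept); all other wordpieces stay masked with 1.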
step_mask = torch.ones(len(self.model.char_list) + 1)
step_mask[list(trie[0].keys())] = 0
# step_mask[-1] = 0
return step_mask
def _get_generation_prob(self, trie):
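# True (i.e. p_gen is forced to zero downstream) when the current trie node
# has no outgoing arcs left.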
return len(trie[0]) == 0
def _gen_next_token_probs(
self, enc_out: torch.Tensor, hypos: List[Hypothesis], device: torch.device
) -> torch.Tensor:
one_tensor = torch.tensor([1], device=device)
predictor_out = torch.stack([_get_hypo_predictor_out(h) for h in hypos], dim=0)
if self.dobiasing:
# Get valid subset of wordpieces
trie_masks = torch.stack([self._get_trie_mask(_get_hypo_trie(h)) for h in hypos], dim=0)
trie_masks = trie_masks.to(enc_out.device).unsqueeze(1) # beam_width, 1, nchars
# Determine whether any paths remain on the trie
genprob_masks = torch.tensor([self._get_generation_prob(_get_hypo_trie(h)) for h in hypos]) # beam_width
genprob_masks = genprob_masks.to(enc_out.device)
# Forward TCPGen component
last_tokens = torch.tensor([_get_hypo_tokens(h)[-1] for h in hypos]).unsqueeze(-1).to(enc_out.device)
hptr, tcpgen_dist = self.model.forward_tcpgen(last_tokens, trie_masks, enc_out)
else:
hptr = None
# hptr is sent to the joiner; if deepbiasing is enabled, the joiner will use it
joined_out, _, joined_activation = self.model.join(
enc_out,
one_tensor,
predictor_out,
torch.tensor([1] * len(hypos), device=device),
hptr=hptr,
) # [beam_width, 1, 1, num_tokens]
if self.dobiasing:
p_gen = torch.sigmoid(self.model.pointer_gate(torch.cat((joined_activation, hptr), dim=-1)))
p_gen = p_gen.masked_fill(genprob_masks.view(p_gen.size(0), 1, 1, 1), 0)
model_tu = torch.softmax(joined_out / self.temperature, dim=3)
# assuming last token is blank
p_not_null = 1.0 - model_tu[:, :, :, -1:]
ptr_dist_fact = torch.cat([tcpgen_dist[:, :, :, :-2], tcpgen_dist[:, :, :, -1:]], dim=-1) * p_not_null
ptr_gen_complement = tcpgen_dist[:, :, :, -1:] * p_gen
p_partial = ptr_dist_fact[:, :, :, :-1] * p_gen + model_tu[:, :, :, :-1] * (1 - p_gen + ptr_gen_complement)
p_final = torch.cat([p_partial, model_tu[:, :, :, -1:]], dim=-1)
joined_out = torch.log(p_final)
else:
joined_out = torch.nn.functional.log_softmax(joined_out / self.temperature, dim=3)
return joined_out[:, 0, 0]
def _gen_b_hypos(
self,
b_hypos: List[Hypothesis],
a_hypos: List[Hypothesis],
next_token_probs: torch.Tensor,
key_to_b_hypo: Dict[str, Hypothesis],
) -> List[Hypothesis]:
for i in range(len(a_hypos)):
h_a = a_hypos[i]
append_blank_score = _get_hypo_score(h_a) + next_token_probs[i, -1]
if _get_hypo_key(h_a) in key_to_b_hypo:
h_b = key_to_b_hypo[_get_hypo_key(h_a)]
_remove_hypo(h_b, b_hypos)
score = float(torch.tensor(_get_hypo_score(h_b)).logaddexp(append_blank_score))
else:
score = float(append_blank_score)
h_b = (
_get_hypo_tokens(h_a),
_get_hypo_predictor_out(h_a),
_get_hypo_state(h_a),
score,
_get_hypo_trie(h_a),
)
b_hypos.append(h_b)
key_to_b_hypo[_get_hypo_key(h_b)] = h_b
_, sorted_idx = torch.tensor([_get_hypo_score(hypo) for hypo in b_hypos]).sort()
return [b_hypos[idx] for idx in sorted_idx]
def _gen_a_hypos(
self,
a_hypos: List[Hypothesis],
b_hypos: List[Hypothesis],
next_token_probs: torch.Tensor,
t: int,
beam_width: int,
device: torch.device,
) -> List[Hypothesis]:
(
nonblank_nbest_scores,
nonblank_nbest_hypo_idx,
nonblank_nbest_token,
) = _compute_updated_scores(a_hypos, next_token_probs, beam_width)
if len(b_hypos) < beam_width:
b_nbest_score = -float("inf")
else:
b_nbest_score = _get_hypo_score(b_hypos[-beam_width])
base_hypos: List[Hypothesis] = []
new_tokens: List[int] = []
new_scores: List[float] = []
for i in range(beam_width):
score = float(nonblank_nbest_scores[i])
if score > b_nbest_score:
a_hypo_idx = int(nonblank_nbest_hypo_idx[i])
base_hypos.append(a_hypos[a_hypo_idx])
new_tokens.append(int(nonblank_nbest_token[i]))
new_scores.append(score)
if base_hypos:
new_hypos = self._gen_new_hypos(base_hypos, new_tokens, new_scores, t, device)
else:
new_hypos: List[Hypothesis] = []
return new_hypos
def _gen_new_hypos(
self,
base_hypos: List[Hypothesis],
tokens: List[int],
scores: List[float],
t: int,
device: torch.device,
) -> List[Hypothesis]:
tgt_tokens = torch.tensor([[token] for token in tokens], device=device)
states = _batch_state(base_hypos)
pred_out, _, pred_states = self.model.predict(
tgt_tokens,
torch.tensor([1] * len(base_hypos), device=device),
states,
)
new_hypos: List[Hypothesis] = []
for i, h_a in enumerate(base_hypos):
new_tokens = _get_hypo_tokens(h_a) + [tokens[i]]
if self.dobiasing:
new_trie = self.model.get_tcpgen_step(tokens[i], _get_hypo_trie(h_a), self.resettrie)
else:
new_trie = self.resettrie
new_hypos.append(
(new_tokens, pred_out[i].detach(), _slice_state(pred_states, i, device), scores[i], new_trie)
)
return new_hypos
def _search(
self,
enc_out: torch.Tensor,
hypo: Optional[Hypothesis],
beam_width: int,
) -> List[Hypothesis]:
n_time_steps = enc_out.shape[1]
device = enc_out.device
a_hypos: List[Hypothesis] = []
b_hypos = self._init_b_hypos(hypo, device)
for t in range(n_time_steps):
a_hypos = b_hypos
b_hypos = torch.jit.annotate(List[Hypothesis], [])
key_to_b_hypo: Dict[str, Hypothesis] = {}
symbols_current_t = 0
while a_hypos:
next_token_probs = self._gen_next_token_probs(enc_out[:, t : t + 1], a_hypos, device)
next_token_probs = next_token_probs.cpu()
b_hypos = self._gen_b_hypos(b_hypos, a_hypos, next_token_probs, key_to_b_hypo)
if symbols_current_t == self.step_max_tokens:
break
a_hypos = self._gen_a_hypos(
a_hypos,
b_hypos,
next_token_probs,
t,
beam_width,
device,
)
if a_hypos:
symbols_current_t += 1
_, sorted_idx = torch.tensor([self.hypo_sort_key(hypo) for hypo in b_hypos]).topk(beam_width)
b_hypos = [b_hypos[idx] for idx in sorted_idx]
return b_hypos
def forward(
self,
input: torch.Tensor,
length: torch.Tensor,
beam_width: int,
) -> List[Hypothesis]:
r"""Performs beam search for the given input sequence.
T: number of frames;
D: feature dimension of each frame.
Args:
input (torch.Tensor): sequence of input frames, with shape (T, D) or (1, T, D).
length (torch.Tensor): number of valid frames in input
sequence, with shape () or (1,).
beam_width (int): beam size to use during search.
Returns:
List[Hypothesis]: top-``beam_width`` hypotheses found by beam search.
"""
if input.dim() != 2 and not (input.dim() == 3 and input.shape[0] == 1):
raise ValueError("input must be of shape (T, D) or (1, T, D)")
if input.dim() == 2:
input = input.unsqueeze(0)
if length.shape != () and length.shape != (1,):
raise ValueError("length must be of shape () or (1,)")
if length.dim() == 0:
length = length.unsqueeze(0)
enc_out, _ = self.model.transcribe(input, length)
return self._search(enc_out, None, beam_width)
@torch.jit.export
def infer(
self,
input: torch.Tensor,
length: torch.Tensor,
beam_width: int,
state: Optional[List[List[torch.Tensor]]] = None,
hypothesis: Optional[Hypothesis] = None,
) -> Tuple[List[Hypothesis], List[List[torch.Tensor]]]:
r"""Performs beam search for the given input sequence in streaming mode.
T: number of frames;
D: feature dimension of each frame.
Args:
input (torch.Tensor): sequence of input frames, with shape (T, D) or (1, T, D).
length (torch.Tensor): number of valid frames in input
sequence, with shape () or (1,).
beam_width (int): beam size to use during search.
state (List[List[torch.Tensor]] or None, optional): list of lists of tensors
representing transcription network internal state generated in preceding
invocation. (Default: ``None``)
hypothesis (Hypothesis or None): hypothesis from preceding invocation to seed
search with. (Default: ``None``)
Returns:
(List[Hypothesis], List[List[torch.Tensor]]):
List[Hypothesis]
top-``beam_width`` hypotheses found by beam search.
List[List[torch.Tensor]]
list of lists of tensors representing transcription network
internal state generated in current invocation.
"""
if input.dim() != 2 and not (input.dim() == 3 and input.shape[0] == 1):
raise ValueError("input must be of shape (T, D) or (1, T, D)")
if input.dim() == 2:
input = input.unsqueeze(0)
if length.shape != () and length.shape != (1,):
raise ValueError("length must be of shape () or (1,)")
if length.dim() == 0:
length = length.unsqueeze(0)
enc_out, _, state = self.model.transcribe_streaming(input, length, state)
return self._search(enc_out, hypothesis, beam_width), state