Commit ffeba11a authored by mayp777

UPDATE

parent 29deb085
# Contextual Conformer RNN-T with TCPGen Example
This directory contains sample implementations of training and evaluation pipelines for the Conformer RNN-T model with tree-constrained pointer generator (TCPGen) for contextual biasing, as described in the paper: [Tree-Constrained Pointer Generator for End-to-End Contextual Speech Recognition](https://ieeexplore.ieee.org/abstract/document/9687915)
## Setup
### Install PyTorch and TorchAudio nightly or from source
Because Conformer RNN-T is currently a prototype feature, you will need to either use the TorchAudio nightly build or build TorchAudio from source. Note also that GPU support is required for training.
To install the nightly, follow the directions at <https://pytorch.org/>.
To build TorchAudio from source, refer to the [contributing guidelines](https://github.com/pytorch/audio/blob/main/CONTRIBUTING.md).
### Install additional dependencies
```bash
pip install pytorch-lightning sentencepiece
```
## Usage
### Training
[`train.py`](./train.py) trains a Conformer RNN-T model with TCPGen on LibriSpeech using PyTorch Lightning. Note that the script expects users to have the following:
- Access to GPU nodes for training.
- Full LibriSpeech dataset.
- SentencePiece model to be used to encode targets; the model can be generated using [`train_spm.py`](./train_spm.py). **Note that suffix-based wordpieces are used in this example** (see the short encoding sketch after this list). [`run_spm.sh`](./run_spm.sh) generates the 600 suffix-based wordpieces used in the paper.
- File (`--global-stats-path`) that contains training set feature statistics; this file can be generated using [`global_stats.py`](../emformer_rnnt/global_stats.py). [`global_stats_100.json`](./global_stats_100.json) has already been generated for train-clean-100.
- Biasing list: Please download [`rareword_f15.txt`](https://download.pytorch.org/torchaudio/pipeline-assets/tcpgen/rareword_f15.txt), [`rareword_f30.txt`](https://download.pytorch.org/torchaudio/pipeline-assets/tcpgen/rareword_f30.txt) and [`all_rare_words.txt`](https://download.pytorch.org/torchaudio/pipeline-assets/tcpgen/all_rare_words.txt) and put them in the [`blists`](./blists) directory. [`rareword_f15.txt`](https://download.pytorch.org/torchaudio/pipeline-assets/tcpgen/rareword_f15.txt) is the biasing list used for training with the clean-100 data: words that appear fewer than 15 times in the training transcripts are treated as rare words. For inference, use [`all_rare_words.txt`](https://download.pytorch.org/torchaudio/pipeline-assets/tcpgen/all_rare_words.txt), which is the same list used in [https://github.com/facebookresearch/fbai-speech/tree/main/is21_deep_bias](https://github.com/facebookresearch/fbai-speech/tree/main/is21_deep_bias).
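For a quick look at what suffix-based wordpieces produce, the following minimal sketch (assuming the SentencePiece model generated by `run_spm.sh` is present in the working directory) encodes a sample string; the segmentation described in the comment is illustrative only.
```python
import sentencepiece as spm

# Assumes spm_unigram_600_100suffix.model has been produced by run_spm.sh.
sp = spm.SentencePieceProcessor(model_file="spm_unigram_600_100suffix.model")

# Because the model is trained with treat_whitespace_as_suffix=True, the
# whitespace marker attaches to the end of a piece (e.g. "ing▁") rather than
# the beginning ("▁ing"); the exact pieces depend on the trained model.
print(sp.encode("contextual biasing", out_type=str))
```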
Additional training options:
- `--droprate` is the probability of dropping biasing words that appear in the reference text, used to avoid over-confidence in the biasing component
- `--maxsize` is the size of the biasing list used for training, i.e. the number of reference biasing words plus distractors (see the sketch below)
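The per-utterance biasing list is assembled roughly as follows (a minimal sketch under assumed names; the actual logic lives in `_extract_tries` in [`transforms.py`](./transforms.py)): the biasing words found in the reference are thinned with `--droprate`, then padded with distractors drawn from the full rare-word list up to `--maxsize` entries.
```python
import random

def build_biasing_list(ref_biasing_words, full_rare_word_list, droprate, maxsize):
    """Illustrative sketch of how the training-time biasing list is sampled."""
    words = list(ref_biasing_words)
    # Randomly drop a fraction of the true biasing words to avoid over-confidence.
    if words and droprate > 0:
        words = random.sample(words, k=int(len(words) * (1 - droprate)))
    # Pad with distractors from the full rare-word list until maxsize entries are reached.
    if len(words) < maxsize:
        n_distractors = min(maxsize - len(words), len(full_rare_word_list))
        for distractor in random.sample(full_rare_word_list, k=n_distractors):
            if distractor not in words:
                words.append(distractor)
    return words
```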
Sample SLURM command:
```
srun --cpus-per-task=16 --gpus-per-node=1 -N 1 --ntasks-per-node=1 python train.py --exp-dir <Path_to_exp> --librispeech-path <Path_to_librispeech_data> --sp-model-path ./spm_unigram_600_100suffix.model --biasing --biasing-list ./blists/rareword_f15.txt --droprate 0.1 --maxsize 200 --epochs 90
```
### Evaluation
[`eval.py`](./eval.py) evaluates a trained Conformer RNN-T model with TCPGen on LibriSpeech test-clean.
Additional decoding options:
- `--biasing-list` should be [`all_rare_words.txt`](blists/all_rare_words.txt) for the LibriSpeech experiments
- `--droprate` should normally be 0, because the reference biasing words are assumed to be included in the list
- `--maxsize` is the size of the biasing list used for decoding; 1000 was used in the paper.
Sample SLURM command:
```
srun --cpus-per-task=16 --gpus-per-node=1 -N 1 --ntasks-per-node=1 python eval.py --checkpoint-path <Path_to_model_checkpoint> --librispeech-path <Path_to_librispeech_data> --sp-model-path ./spm_unigram_600_100suffix.model --expdir <Path_to_exp> --use-cuda --biasing --biasing-list ./blists/all_rare_words.txt --droprate 0.0 --maxsize 1000
```
### Scoring
First install SCTK, the NIST Scoring Toolkit, following [https://github.com/usnistgov/SCTK/blob/master/README.md](https://github.com/usnistgov/SCTK/blob/master/README.md).
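Building SCTK from source typically looks roughly like the following (the linked README is authoritative; treat these commands as a hedged summary):
```bash
git clone https://github.com/usnistgov/SCTK.git
cd SCTK
make config && make all && make check && make install && make doc
# sclite and the other tools are placed in SCTK/bin; add it to PATH so score.sh can find them.
export PATH="$PWD/bin:$PATH"
```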
Example scoring script using sclite is in [`score.sh`](./score.sh).
```
./score.sh <path_to_decoding_dir>
```
Note that this will generate a file named `result.wrd.txt`, which is in the format expected by the script below for calculating the rare word error rate. Follow these steps to calculate it:
```bash
cd error_analysis
python get_error_word_count.py <path_to_result.wrd.txt>
```
Note that the `word_freq.txt` file contains word frequencies for train-clean-100 only. For the full 960-hour set it should be recalculated, which in this case only slightly affects the OOV word error rate calculation.
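If you need to regenerate the frequency file, a hedged sketch of how it could be computed from the LibriSpeech transcripts is shown below (the dataset path and split names are assumptions; adjust them to your setup).
```python
from collections import Counter
from pathlib import Path

# Assumed locations; point these at your LibriSpeech root and the splits you train on.
librispeech_root = Path("./datasets/LibriSpeech")
splits = ["train-clean-100", "train-clean-360", "train-other-500"]

counter = Counter()
for split in splits:
    for trans in (librispeech_root / split).glob("*/*/*.trans.txt"):
        with open(trans) as f:
            for line in f:
                # Each line is "<utterance-id> <transcript>"; count the transcript words.
                counter.update(line.strip().split(" ", 1)[1].lower().split())

# word_freq.txt format expected by get_error_word_count.py: "<word> <count>" per line.
with open("word_freq.txt", "w") as fout:
    for word, freq in counter.most_common():
        fout.write(f"{word} {freq}\n")
```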
The table below contains WER results for the test-clean set using clean-100 training data. R-WER stands for rare word error rate, computed over words in the biasing list.
| | WER | R-WER |
|:-------------------:|-------------:|-----------:|
| test-clean | 0.0836 | 0.2366|
This is the default directory where the rare word list files should be placed.
To train or evaluate a model, please download the following files and save them here.
- [`rareword_f15.txt`](https://download.pytorch.org/torchaudio/pipeline-assets/tcpgen/rareword_f15.txt)
- [`rareword_f30.txt`](https://download.pytorch.org/torchaudio/pipeline-assets/tcpgen/rareword_f30.txt)
- [`all_rare_words.txt`](https://download.pytorch.org/torchaudio/pipeline-assets/tcpgen/all_rare_words.txt)
import os
import random
import torch
import torchaudio
from pytorch_lightning import LightningDataModule
def _batch_by_token_count(idx_target_lengths, max_tokens, batch_size=None):
batches = []
current_batch = []
current_token_count = 0
for idx, target_length in idx_target_lengths:
if current_token_count + target_length > max_tokens or (batch_size and len(current_batch) == batch_size):
batches.append(current_batch)
current_batch = [idx]
current_token_count = target_length
else:
current_batch.append(idx)
current_token_count += target_length
if current_batch:
batches.append(current_batch)
return batches
def get_sample_lengths(librispeech_dataset):
fileid_to_target_length = {}
def _target_length(fileid):
if fileid not in fileid_to_target_length:
speaker_id, chapter_id, _ = fileid.split("-")
file_text = speaker_id + "-" + chapter_id + librispeech_dataset._ext_txt
file_text = os.path.join(librispeech_dataset._path, speaker_id, chapter_id, file_text)
with open(file_text) as ft:
for line in ft:
fileid_text, transcript = line.strip().split(" ", 1)
fileid_to_target_length[fileid_text] = len(transcript)
return fileid_to_target_length[fileid]
return [_target_length(fileid) for fileid in librispeech_dataset._walker]
class CustomBucketDataset(torch.utils.data.Dataset):
def __init__(
self,
dataset,
lengths,
max_tokens,
num_buckets,
shuffle=False,
batch_size=None,
):
super().__init__()
assert len(dataset) == len(lengths)
self.dataset = dataset
max_length = max(lengths)
min_length = min(lengths)
assert max_tokens >= max_length
buckets = torch.linspace(min_length, max_length, num_buckets)
lengths = torch.tensor(lengths)
bucket_assignments = torch.bucketize(lengths, buckets)
idx_length_buckets = [(idx, length, bucket_assignments[idx]) for idx, length in enumerate(lengths)]
if shuffle:
idx_length_buckets = random.sample(idx_length_buckets, len(idx_length_buckets))
else:
idx_length_buckets = sorted(idx_length_buckets, key=lambda x: x[1], reverse=True)
sorted_idx_length_buckets = sorted(idx_length_buckets, key=lambda x: x[2])
self.batches = _batch_by_token_count(
[(idx, length) for idx, length, _ in sorted_idx_length_buckets],
max_tokens,
batch_size=batch_size,
)
def __getitem__(self, idx):
return [self.dataset[subidx] for subidx in self.batches[idx]]
def __len__(self):
return len(self.batches)
class TransformDataset(torch.utils.data.Dataset):
def __init__(self, dataset, transform_fn):
self.dataset = dataset
self.transform_fn = transform_fn
def __getitem__(self, idx):
return self.transform_fn(self.dataset[idx])
def __len__(self):
return len(self.dataset)
class LibriSpeechDataModule(LightningDataModule):
librispeech_cls = torchaudio.datasets.LibriSpeechBiasing
def __init__(
self,
*,
librispeech_path,
train_transform,
val_transform,
test_transform,
max_tokens=3200,
batch_size=16,
train_num_buckets=50,
train_shuffle=True,
num_workers=10,
subset=None,
fullbiasinglist=None,
):
super().__init__()
self.librispeech_path = librispeech_path
self.train_dataset_lengths = None
self.val_dataset_lengths = None
self.train_transform = train_transform
self.val_transform = val_transform
self.test_transform = test_transform
self.max_tokens = max_tokens
self.batch_size = batch_size
self.train_num_buckets = train_num_buckets
self.train_shuffle = train_shuffle
self.num_workers = num_workers
if subset is not None and subset != "train-clean-100":
raise ValueError(f'subset must be ``None`` or `"train-clean-100"`. Found: {subset}')
self.subset = subset
self.fullbiasinglist = fullbiasinglist or []
def train_dataloader(self):
if self.subset is None:
datasets = [
self.librispeech_cls(self.librispeech_path, url="train-clean-360"),
self.librispeech_cls(self.librispeech_path, url="train-clean-100"),
self.librispeech_cls(self.librispeech_path, url="train-other-500"),
]
elif self.subset == "train-clean-100":
datasets = [self.librispeech_cls(self.librispeech_path, url="train-clean-100", blist=self.fullbiasinglist)]
if not self.train_dataset_lengths:
self.train_dataset_lengths = [get_sample_lengths(dataset) for dataset in datasets]
dataset = torch.utils.data.ConcatDataset(
[
CustomBucketDataset(
dataset,
lengths,
self.max_tokens,
self.train_num_buckets,
batch_size=self.batch_size,
)
for dataset, lengths in zip(datasets, self.train_dataset_lengths)
]
)
dataset = TransformDataset(dataset, self.train_transform)
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=self.num_workers,
batch_size=None,
shuffle=self.train_shuffle,
)
return dataloader
def val_dataloader(self):
datasets = [
self.librispeech_cls(self.librispeech_path, url="dev-clean", blist=self.fullbiasinglist),
self.librispeech_cls(self.librispeech_path, url="dev-other", blist=self.fullbiasinglist),
]
if not self.val_dataset_lengths:
self.val_dataset_lengths = [get_sample_lengths(dataset) for dataset in datasets]
dataset = torch.utils.data.ConcatDataset(
[
CustomBucketDataset(
dataset,
lengths,
self.max_tokens,
1,
batch_size=self.batch_size,
)
for dataset, lengths in zip(datasets, self.val_dataset_lengths)
]
)
dataset = TransformDataset(dataset, self.val_transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=None, num_workers=self.num_workers)
return dataloader
def test_dataloader(self):
dataset = self.librispeech_cls(self.librispeech_path, url="test-clean", blist=self.fullbiasinglist)
dataset = TransformDataset(dataset, self.test_transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=None)
return dataloader
import sys
error_words_freqs = {}
infile = sys.argv[1]
# setname = sys.argv[2]
insert_error = 0
insert_rare = 0
freqlist_test = {}
freqlist = {}
# TODO: Change to path to your word frequency file
with open("word_freq.txt") as fin:
for line in fin:
word, freq = line.split()
freqlist[word.upper()] = int(freq)
with open("../blists/all_rare_words.txt") as fin:
rareset = set()
for line in fin:
rareset.add(line.strip().upper())
project_set = set()
with open(infile) as fin:
lines = fin.readlines()
for i, line in enumerate(lines):
if line.startswith("id:"):
project = line.strip(")\n").split("-")[-3:]
project = "-".join(project)
if "REF:" in line:
nextline = lines[i + 1].split()
for j, word in enumerate(line.split()):
if "*" in word:
insert_error += 1
if nextline[j].upper() in rareset:
insert_rare += 1
line = line.replace("*", "")
line = line.replace("%BCACK", "")
for word in line.split()[1:]:
if not word.startswith("("):
if word.upper() not in freqlist_test:
freqlist_test[word.upper()] = 1
else:
freqlist_test[word.upper()] += 1
if word != word.lower() and word.upper() in error_words_freqs:
error_words_freqs[word.upper()] += 1
elif word != word.lower() and word.upper() not in error_words_freqs:
error_words_freqs[word.upper()] = 1
elif word == word.lower() and word.upper() not in error_words_freqs:
if word == word.upper():
print("special token found in: {}".format(project))
error_words_freqs[word.upper()] = 0
elif word == word.upper():
print("special token found in: {}".format(project))
print(len(error_words_freqs.keys()))
print(insert_rare)
commonwords = []
rarewords = []
oovwords = []
common_freq = 0
rare_freq = 0
oov_freq = 0
common_error = 0
rare_error = 0
oov_error = 0
partial_error = 0
partial_freq = 0
very_common_error = 0
very_common_words = 0
words_error_freq = {}
words_total_freq = {}
low_freq_error = 0
low_freq_total = 0
for word, error in error_words_freqs.items():
if word in rareset:
rarewords.append(word)
rare_freq += freqlist_test[word]
rare_error += error
elif word not in freqlist:
oovwords.append(word)
oov_freq += freqlist_test[word] if word in freqlist_test else 1
oov_error += error
else:
commonwords.append(word)
common_freq += freqlist_test[word]
common_error += error
total_words = common_freq + rare_freq + oov_freq
insert_common = insert_error - insert_rare
total_errors = common_error + rare_error + oov_error + insert_error
WER = total_errors / total_words
print("=" * 89)
print(
"Common words error freq: {} / {} = {}".format(
common_error + insert_common, common_freq, (common_error + insert_common) / common_freq
)
)
print(
"Rare words error freq: {} / {} = {}".format(
rare_error + insert_rare, rare_freq, (rare_error + insert_rare) / rare_freq
)
)
print("OOV words error freq: {} / {} = {}".format(oov_error, oov_freq, oov_error / max(oov_freq, 1)))
print("WER estimate: {} / {} = {}".format(total_errors, total_words, WER))
print(
"Insert error: {} / {} = {}".format(
insert_error - insert_rare, total_words, (insert_error - insert_rare) / total_words
)
)
print("Insertion + OOV error {}".format((insert_error + oov_error - insert_rare) / total_words))
print("=" * 89)
import logging
import os
import pathlib
from argparse import ArgumentParser
import torch
import torchaudio
from lightning import ConformerRNNTModule
from transforms import get_data_module
logger = logging.getLogger()
def compute_word_level_distance(seq1, seq2):
return torchaudio.functional.edit_distance(seq1.lower().split(), seq2.lower().split())
def run_eval(args):
model = ConformerRNNTModule.load_from_checkpoint(
args.checkpoint_path, sp_model=str(args.sp_model_path), biasing=args.biasing
).eval()
data_module = get_data_module(
str(args.librispeech_path),
str(args.global_stats_path),
str(args.sp_model_path),
biasinglist=args.biasing_list,
droprate=args.droprate,
maxsize=args.maxsize,
)
if args.use_cuda:
model = model.to(device="cuda")
total_edit_distance = 0
total_length = 0
dataloader = data_module.test_dataloader()
hypout = []
refout = []
with torch.no_grad():
for idx, (batch, sample) in enumerate(dataloader):
filename = "librispeech_clean_100_{}".format(idx)
actual = sample[0][2]
predicted = model(batch)
hypout.append("{} ({})\n".format(predicted.upper().strip(), filename))
refout.append("{} ({})\n".format(actual.upper().strip(), filename))
total_edit_distance += compute_word_level_distance(actual, predicted)
total_length += len(actual.split())
if idx % 100 == 0:
logger.warning(f"Processed elem {idx}; WER: {total_edit_distance / total_length}")
logger.warning(f"Final WER: {total_edit_distance / total_length}")
with open(os.path.join(args.expdir, "hyp.trn.txt"), "w") as fout:
fout.writelines(hypout)
with open(os.path.join(args.expdir, "ref.trn.txt"), "w") as fout:
fout.writelines(refout)
def cli_main():
parser = ArgumentParser()
parser.add_argument(
"--checkpoint-path",
type=pathlib.Path,
help="Path to checkpoint to use for evaluation.",
required=True,
)
parser.add_argument(
"--global-stats-path",
default=pathlib.Path("global_stats_100.json"),
type=pathlib.Path,
help="Path to JSON file containing feature means and stddevs.",
)
parser.add_argument(
"--librispeech-path",
type=pathlib.Path,
help="Path to LibriSpeech datasets.",
required=True,
)
parser.add_argument(
"--sp-model-path",
type=pathlib.Path,
help="Path to SentencePiece model.",
required=True,
)
parser.add_argument(
"--expdir",
type=pathlib.Path,
help="Output path.",
required=True,
)
parser.add_argument(
"--use-cuda",
action="store_true",
default=False,
help="Run using CUDA.",
)
parser.add_argument(
"--biasing-list",
type=str,
default="",
help="Path to the biasing list used for inference.",
)
parser.add_argument(
"--droprate",
type=float,
default=0.0,
help="biasing list true entry drop rate",
)
parser.add_argument(
"--maxsize",
type=int,
default=0,
help="biasing list size",
)
parser.add_argument(
"--biasing",
action="store_true",
help="Use biasing",
)
args = parser.parse_args()
run_eval(args)
if __name__ == "__main__":
cli_main()
{
"mean": [
15.637551307678223,
16.923139572143555,
16.822391510009766,
16.71995735168457,
17.798818588256836,
17.773509979248047,
17.83729362487793,
18.358478546142578,
17.9212646484375,
17.89328956604004,
17.39158821105957,
17.29935646057129,
17.368602752685547,
17.506956100463867,
17.485977172851562,
17.350055694580078,
17.144203186035156,
16.917232513427734,
16.664018630981445,
16.391685485839844,
16.14568328857422,
15.940634727478027,
15.770298957824707,
15.61804485321045,
15.464021682739258,
15.357192039489746,
15.381829261779785,
15.079191207885742,
15.000809669494629,
15.170866012573242,
14.815556526184082,
14.997357368469238,
14.849116325378418,
15.036391258239746,
14.997495651245117,
15.095179557800293,
15.101740837097168,
15.14519214630127,
15.180743217468262,
15.156403541564941,
15.10532283782959,
15.052924156188965,
15.000785827636719,
14.922870635986328,
14.845956802368164,
14.808100700378418,
14.871763229370117,
14.811545372009277,
14.901408195495605,
14.831536293029785,
14.855134010314941,
14.778738975524902,
14.771122932434082,
14.732138633728027,
14.647477149963379,
14.561445236206055,
14.481091499328613,
14.422819137573242,
14.360650062561035,
14.320985794067383,
14.269508361816406,
14.194745063781738,
14.015007972717285,
13.864882469177246,
13.71847915649414,
13.572352409362793,
13.416520118713379,
13.26490306854248,
13.141557693481445,
13.059102058410645,
12.958319664001465,
12.893214225769043,
12.862544059753418,
12.809120178222656,
12.764240264892578,
12.69538688659668,
12.663688659667969,
12.577835083007812,
12.516019821166992,
12.431295394897461
],
"invstddev": [
0.2907441258430481,
0.2902139723300934,
0.2692812979221344,
0.26928815245628357,
0.23503832519054413,
0.2255384773015976,
0.21647363901138306,
0.20862863957881927,
0.2129201591014862,
0.21450363099575043,
0.21588537096977234,
0.21638962626457214,
0.21542198956012726,
0.2138291299343109,
0.21290543675422668,
0.21341055631637573,
0.2152717262506485,
0.21785758435726166,
0.22014079988002777,
0.22229592502117157,
0.2258182317018509,
0.229590505361557,
0.23325258493423462,
0.23702110350131989,
0.24055296182632446,
0.24432097375392914,
0.2479754239320755,
0.25079545378685,
0.25225189328193665,
0.2529183626174927,
0.2534293234348297,
0.25444915890693665,
0.25508981943130493,
0.26440101861953735,
0.27133750915527344,
0.2644997835159302,
0.257159560918808,
0.2571694552898407,
0.25754109025001526,
0.2585192620754242,
0.2599046528339386,
0.26100826263427734,
0.2632879614830017,
0.2646331489086151,
0.26438355445861816,
0.263375461101532,
0.2620519995689392,
0.259790301322937,
0.25840267539024353,
0.2579191029071808,
0.25915762782096863,
0.2594946622848511,
0.260376900434494,
0.2614041864871979,
0.262315034866333,
0.26391837000846863,
0.26464715600013733,
0.2644098997116089,
0.26398345828056335,
0.2636142671108246,
0.2644863724708557,
0.26914820075035095,
0.2653448283672333,
0.2650754153728485,
0.2656823992729187,
0.2668116092681885,
0.2679266929626465,
0.2683681547641754,
0.26889851689338684,
0.2706693708896637,
0.27155616879463196,
0.2738206088542938,
0.27583277225494385,
0.2766948342323303,
0.27835148572921753,
0.2804813086986542,
0.2821349799633026,
0.2801983952522278,
0.2796038091182709,
0.27942603826522827
]
}
import logging
import math
from collections import namedtuple
from typing import List, Tuple
import sentencepiece as spm
import torch
import torchaudio
from pytorch_lightning import LightningModule
from torchaudio.prototype.models import conformer_rnnt_biasing_base, Hypothesis, RNNTBeamSearchBiasing
logger = logging.getLogger()
_expected_spm_vocab_size = 600
Batch = namedtuple("Batch", ["features", "feature_lengths", "targets", "target_lengths", "tries"])
class WarmupLR(torch.optim.lr_scheduler._LRScheduler):
r"""Learning rate scheduler that performs linear warmup and exponential annealing.
Args:
optimizer (torch.optim.Optimizer): optimizer to use.
warmup_steps (int): number of scheduler steps for which to warm up learning rate.
force_anneal_step (int): scheduler step at which annealing of learning rate begins.
anneal_factor (float): factor to scale base learning rate by at each annealing step.
last_epoch (int, optional): The index of last epoch. (Default: -1)
verbose (bool, optional): If ``True``, prints a message to stdout for
each update. (Default: ``False``)
"""
def __init__(
self,
optimizer: torch.optim.Optimizer,
warmup_steps: int,
force_anneal_step: int,
anneal_factor: float,
last_epoch=-1,
verbose=False,
):
self.warmup_steps = warmup_steps
self.force_anneal_step = force_anneal_step
self.anneal_factor = anneal_factor
super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose)
def get_lr(self):
if self._step_count < self.force_anneal_step:
return [(min(1.0, self._step_count / self.warmup_steps)) * base_lr for base_lr in self.base_lrs]
else:
scaling_factor = self.anneal_factor ** (self._step_count - self.force_anneal_step)
return [scaling_factor * base_lr for base_lr in self.base_lrs]
def post_process_hypos(
hypos: List[Hypothesis], sp_model: spm.SentencePieceProcessor
) -> List[Tuple[str, float, List[int], List[int]]]:
tokens_idx = 0
score_idx = 3
post_process_remove_list = [
sp_model.unk_id(),
sp_model.eos_id(),
sp_model.pad_id(),
]
filtered_hypo_tokens = [
[token_index for token_index in h[tokens_idx][1:] if token_index not in post_process_remove_list] for h in hypos
]
hypos_str = [sp_model.decode(s) for s in filtered_hypo_tokens]
hypos_ids = [h[tokens_idx][1:] for h in hypos]
hypos_score = [[math.exp(h[score_idx])] for h in hypos]
nbest_batch = list(zip(hypos_str, hypos_score, hypos_ids))
return nbest_batch
class ConformerRNNTModule(LightningModule):
def __init__(self, sp_model, biasing=False):
super().__init__()
self.sp_model = spm.SentencePieceProcessor(model_file=sp_model)
spm_vocab_size = self.sp_model.get_piece_size()
self.char_list = [self.sp_model.id_to_piece(idx) for idx in range(spm_vocab_size)]
assert spm_vocab_size == _expected_spm_vocab_size, (
"The model returned by conformer_rnnt_base expects a SentencePiece model of "
f"vocabulary size {_expected_spm_vocab_size}, but the given SentencePiece model has a vocabulary size "
f"of {spm_vocab_size}. Please provide a correctly configured SentencePiece model."
)
self.blank_idx = spm_vocab_size
self.char_list.append("<blank>")
# ``conformer_rnnt_biasing_base`` hardcodes a specific Conformer RNN-T configuration.
# For greater customizability, please refer to ``conformer_rnnt_biasing``.
self.biasing = biasing
self.model = conformer_rnnt_biasing_base(charlist=self.char_list, biasing=self.biasing)
self.loss = torchaudio.transforms.RNNTLoss(reduction="sum", fused_log_softmax=False)
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=8e-4, betas=(0.9, 0.98), eps=1e-9)
# This learning-rate schedule assumes training on clean-100 for 90 epochs; adjust it for longer runs
self.warmup_lr_scheduler = WarmupLR(self.optimizer, 35, 60, 0.92)
# The epoch from which the TCPGen starts to train
self.tcpsche = self.model.tcpsche
self.automatic_optimization = False
def _step(self, batch, _, step_type):
if batch is None:
return None
prepended_targets = batch.targets.new_empty([batch.targets.size(0), batch.targets.size(1) + 1])
prepended_targets[:, 1:] = batch.targets
prepended_targets[:, 0] = self.blank_idx
prepended_target_lengths = batch.target_lengths + 1
output, src_lengths, _, _, tcpgen_dist, p_gen = self.model(
batch.features,
batch.feature_lengths,
prepended_targets,
prepended_target_lengths,
batch.tries,
self.current_epoch,
)
if self.biasing and self.current_epoch >= self.tcpsche and p_gen is not None:
# Assuming blank is the last token
model_output = torch.softmax(output, dim=-1)
p_not_null = 1.0 - model_output[:, :, :, -1:]
# Exclude blank prob
ptr_dist_fact = torch.cat([tcpgen_dist[:, :, :, :-2], tcpgen_dist[:, :, :, -1:]], dim=-1) * p_not_null
ptr_gen_complement = tcpgen_dist[:, :, :, -1:] * p_gen
# Interpolate between the TCPGen distribution and the model distribution
p_partial = ptr_dist_fact[:, :, :, :-1] * p_gen + model_output[:, :, :, :-1] * (
1 - p_gen + ptr_gen_complement
)
# Add blank back
p_final = torch.cat([p_partial, model_output[:, :, :, -1:]], dim=-1)
# Add a small epsilon for numerical stability before taking the log
logsmax_output = torch.log(p_final + 1e-12)
else:
logsmax_output = torch.log_softmax(output, dim=-1)
loss = self.loss(logsmax_output, batch.targets, src_lengths, batch.target_lengths)
self.log(f"Losses/{step_type}_loss", loss, on_step=True, on_epoch=True, batch_size=batch.targets.size(0))
return loss
def configure_optimizers(self):
return (
[self.optimizer],
[{"scheduler": self.warmup_lr_scheduler, "interval": "epoch"}],
)
def forward(self, batch: Batch):
decoder = RNNTBeamSearchBiasing(self.model, self.blank_idx, trie=batch.tries, biasing=self.biasing)
hypotheses = decoder(batch.features.to(self.device), batch.feature_lengths.to(self.device), 10)
return post_process_hypos(hypotheses, self.sp_model)[0][0]
def training_step(self, batch: Batch, batch_idx):
"""Custom training step.
By default, DDP does the following on each train step:
- For each GPU, compute loss and gradient on shard of training data.
- Sync and average gradients across all GPUs. The final gradient
is (sum of gradients across all GPUs) / N, where N is the world
size (total number of GPUs).
- Update parameters on each GPU.
Here, we do the following:
- For k-th GPU, compute loss and scale it by (N / B_total), where B_total is
the sum of batch sizes across all GPUs. Compute gradient from scaled loss.
- Sync and average gradients across all GPUs. The final gradient
is (sum of gradients across all GPUs) / B_total.
- Update parameters on each GPU.
Doing so allows us to account for the variability in batch sizes that
variable-length sequential data commonly yields.
"""
opt = self.optimizers()
opt.zero_grad()
loss = self._step(batch, batch_idx, "train")
batch_size = batch.features.size(0)
batch_sizes = self.all_gather(batch_size)
self.log(
"Gathered batch size",
batch_sizes.sum(),
on_step=True,
on_epoch=True,
batch_size=batch.targets.size(0),
)
loss *= batch_sizes.size(0) / batch_sizes.sum()  # world size / total batch size across workers
self.manual_backward(loss)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 10.0)
opt.step()
# step every epoch
sch = self.lr_schedulers()
if self.trainer.is_last_batch:
sch.step()
return loss
def validation_step(self, batch, batch_idx):
return self._step(batch, batch_idx, "val")
def test_step(self, batch, batch_idx):
return self._step(batch, batch_idx, "test")
#!/usr/bin/env bash
if [ $# -ne 1 ]; then
echo "Usage: $0 <DECODING_DIR>"
exit 1
fi
dir=$1 # the path to the decoding dir, e.g. experiments/librispeech_clean100_suffix600_tcpgen500_sche30_nodrop/decode_test_clean_b10_KB1000/
sclite -r "${dir}/ref.trn.txt" trn -h "${dir}/hyp.trn.txt" trn -i rm -o all stdout > "${dir}/result.wrd.txt"
import os
import pathlib
from argparse import ArgumentParser
from lightning import ConformerRNNTModule
from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.strategies import DDPStrategy
from transforms import get_data_module
def run_train(args):
seed_everything(1)
checkpoint_dir = args.exp_dir / "checkpoints"
checkpoint = ModelCheckpoint(
checkpoint_dir,
monitor="Losses/val_loss",
mode="min",
save_top_k=5,
save_weights_only=False,
verbose=True,
)
train_checkpoint = ModelCheckpoint(
checkpoint_dir,
monitor="Losses/train_loss",
mode="min",
save_top_k=5,
save_weights_only=False,
verbose=True,
)
lr_monitor = LearningRateMonitor(logging_interval="step")
callbacks = [
checkpoint,
train_checkpoint,
lr_monitor,
]
if os.path.exists(args.resume) and args.resume != "":
trainer = Trainer(
default_root_dir=args.exp_dir,
max_epochs=args.epochs,
num_nodes=args.nodes,
gpus=args.gpus,
accelerator="gpu",
strategy=DDPStrategy(find_unused_parameters=False),
callbacks=callbacks,
reload_dataloaders_every_n_epochs=1,
resume_from_checkpoint=args.resume,
)
else:
trainer = Trainer(
default_root_dir=args.exp_dir,
max_epochs=args.epochs,
num_nodes=args.nodes,
gpus=args.gpus,
accelerator="gpu",
strategy=DDPStrategy(find_unused_parameters=False),
callbacks=callbacks,
reload_dataloaders_every_n_epochs=1,
)
model = ConformerRNNTModule(str(args.sp_model_path), args.biasing)
data_module = get_data_module(
str(args.librispeech_path),
str(args.global_stats_path),
str(args.sp_model_path),
subset=args.subset,
biasinglist=args.biasing_list,
droprate=args.droprate,
maxsize=args.maxsize,
)
trainer.fit(model, data_module, ckpt_path=args.checkpoint_path)
def cli_main():
parser = ArgumentParser()
parser.add_argument(
"--checkpoint-path",
default=None,
type=pathlib.Path,
help="Path to checkpoint from which to resume training.",
)
parser.add_argument(
"--exp-dir",
default=pathlib.Path("./exp"),
type=pathlib.Path,
help="Directory to save checkpoints and logs to. (Default: './exp')",
)
parser.add_argument(
"--global-stats-path",
default=pathlib.Path("global_stats_100.json"),
type=pathlib.Path,
help="Path to JSON file containing feature means and stddevs.",
)
parser.add_argument(
"--librispeech-path",
type=pathlib.Path,
help="Path to LibriSpeech datasets.",
required=True,
)
parser.add_argument(
"--sp-model-path",
type=pathlib.Path,
help="Path to SentencePiece model.",
required=True,
)
parser.add_argument(
"--nodes",
default=1,
type=int,
help="Number of nodes to use for training. (Default: 1)",
)
parser.add_argument(
"--gpus",
default=1,
type=int,
help="Number of GPUs per node to use for training. (Default: 1)",
)
parser.add_argument(
"--epochs",
default=120,
type=int,
help="Number of epochs to train for. (Default: 120)",
)
parser.add_argument(
"--subset",
default="train-clean-100",
type=str,
help="Train on subset of librispeech.",
)
parser.add_argument(
"--biasing",
action="store_true",
help="Use biasing",
)
parser.add_argument(
"--biasing-list",
type=pathlib.Path,
help="Path to the biasing list.",
required=True,
)
parser.add_argument("--maxsize", default=1000, type=int, help="Size of biasing lists")
parser.add_argument("--droprate", default=0.0, type=float, help="Biasing component regularisation drop rate")
parser.add_argument(
"--resume",
default="",
type=str,
help="Path to resume model.",
)
args = parser.parse_args()
run_train(args)
if __name__ == "__main__":
cli_main()
#!/usr/bin/env python3
"""Trains a SentencePiece model on transcripts across LibriSpeech train-clean-100, train-clean-360, and train-other-500.
Uses a unigram wordpiece model, optionally with suffix-based wordpieces (--suffix).
Example:
python train_spm.py --librispeech-path ./datasets
"""
import io
import pathlib
from argparse import ArgumentParser, RawTextHelpFormatter
import sentencepiece as spm
def get_transcript_text(transcript_path):
with open(transcript_path) as f:
return [line.strip().split(" ", 1)[1].lower() for line in f]
def get_transcripts(dataset_path):
transcript_paths = dataset_path.glob("*/*/*.trans.txt")
merged_transcripts = []
for path in transcript_paths:
merged_transcripts += get_transcript_text(path)
return merged_transcripts
def train_spm(input, suffix=False):
model_writer = io.BytesIO()
spm.SentencePieceTrainer.train(
sentence_iterator=iter(input),
model_writer=model_writer,
vocab_size=600,
model_type="unigram",
input_sentence_size=-1,
character_coverage=1.0,
treat_whitespace_as_suffix=suffix,
bos_id=0,
pad_id=1,
eos_id=2,
unk_id=3,
)
return model_writer.getvalue()
def parse_args():
default_output_path = "./spm_unigram_600_100suffix.model"
parser = ArgumentParser(description=__doc__, formatter_class=RawTextHelpFormatter)
parser.add_argument(
"--librispeech-path",
required=True,
type=pathlib.Path,
help="Path to LibriSpeech dataset.",
)
parser.add_argument(
"--output-file",
default=pathlib.Path(default_output_path),
type=pathlib.Path,
help=f"File to save model to. (Default: '{default_output_path}')",
)
parser.add_argument(
"--suffix",
action="store_true",
help="whether to use suffix-based wordpieces",
)
return parser.parse_args()
def run_cli():
args = parse_args()
root = args.librispeech_path / "LibriSpeech"
# Uncomment this to train the SentencePiece model on the full 960-hour data
# splits = ["train-clean-100", "train-clean-360", "train-other-500"]
splits = ["train-clean-100"]
merged_transcripts = []
for split in splits:
path = pathlib.Path(root) / split
merged_transcripts += get_transcripts(path)
model = train_spm(merged_transcripts, suffix=args.suffix)
with open(args.output_file, "wb") as f:
f.write(model)
if __name__ == "__main__":
run_cli()
import json
import math
import random
from functools import partial
from typing import List
import sentencepiece as spm
import torch
import torchaudio
from data_module import LibriSpeechDataModule
from lightning import Batch
_decibel = 2 * 20 * math.log10(torch.iinfo(torch.int16).max)
_gain = pow(10, 0.05 * _decibel)
_spectrogram_transform = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_fft=400, n_mels=80, hop_length=160)
random.seed(999)
def _piecewise_linear_log(x):
x = x * _gain
x[x > math.e] = torch.log(x[x > math.e])
x[x <= math.e] = x[x <= math.e] / math.e
return x
class FunctionalModule(torch.nn.Module):
def __init__(self, functional):
super().__init__()
self.functional = functional
def forward(self, input):
return self.functional(input)
class GlobalStatsNormalization(torch.nn.Module):
def __init__(self, global_stats_path):
super().__init__()
with open(global_stats_path) as f:
blob = json.loads(f.read())
self.mean = torch.tensor(blob["mean"])
self.invstddev = torch.tensor(blob["invstddev"])
def forward(self, input):
return (input - self.mean) * self.invstddev
def _extract_labels(sp_model, samples: List):
targets = [sp_model.encode(sample[2].lower()) for sample in samples]
biasingwords = []
for sample in samples:
for word in sample[6]:
if word not in biasingwords:
biasingwords.append(word)
lengths = torch.tensor([len(elem) for elem in targets]).to(dtype=torch.int32)
targets = torch.nn.utils.rnn.pad_sequence(
[torch.tensor(elem) for elem in targets],
batch_first=True,
padding_value=1.0,
).to(dtype=torch.int32)
return targets, lengths, biasingwords
def _extract_features(data_pipeline, samples: List):
mel_features = [_spectrogram_transform(sample[0].squeeze()).transpose(1, 0) for sample in samples]
features = torch.nn.utils.rnn.pad_sequence(mel_features, batch_first=True)
features = data_pipeline(features)
lengths = torch.tensor([elem.shape[0] for elem in mel_features], dtype=torch.int32)
return features, lengths
def _extract_tries(sp_model, biasingwords, blist, droprate, maxsize):
if len(biasingwords) > 0 and droprate > 0:
biasingwords = random.sample(biasingwords, k=int(len(biasingwords) * (1 - droprate)))
if len(biasingwords) < maxsize:
distractors = random.sample(blist, k=max(0, maxsize - len(biasingwords)))
for word in distractors:
if word not in biasingwords:
biasingwords.append(word)
biasingwords = [sp_model.encode(word.lower()) for word in biasingwords]
biasingwords = sorted(biasingwords)
worddict = {tuple(word): i + 1 for i, word in enumerate(biasingwords)}
lextree = make_lexical_tree(worddict, -1)
return lextree, biasingwords
def make_lexical_tree(word_dict, word_unk):
"""Make a prefix tree"""
# node [dict(subword_id -> node), word_id, word_set[start-1, end]]
root = [{}, -1, None]
for w, wid in word_dict.items():
if wid > 0 and wid != word_unk:
succ = root[0]
for i, cid in enumerate(w):
if cid not in succ:
succ[cid] = [{}, -1, (wid - 1, wid)]
else:
prev = succ[cid][2]
succ[cid][2] = (min(prev[0], wid - 1), max(prev[1], wid))
if i == len(w) - 1:
succ[cid][1] = wid
succ = succ[cid][0]
return root
class TrainTransform:
def __init__(self, global_stats_path: str, sp_model_path: str, blist: List[str], droprate: float, maxsize: int):
self.sp_model = spm.SentencePieceProcessor(model_file=sp_model_path)
self.train_data_pipeline = torch.nn.Sequential(
FunctionalModule(_piecewise_linear_log),
GlobalStatsNormalization(global_stats_path),
FunctionalModule(partial(torch.transpose, dim0=1, dim1=2)),
torchaudio.transforms.FrequencyMasking(27),
torchaudio.transforms.FrequencyMasking(27),
torchaudio.transforms.TimeMasking(100, p=0.2),
torchaudio.transforms.TimeMasking(100, p=0.2),
FunctionalModule(partial(torch.transpose, dim0=1, dim1=2)),
)
self.blist = blist
self.droprate = droprate
self.maxsize = maxsize
def __call__(self, samples: List):
features, feature_lengths = _extract_features(self.train_data_pipeline, samples)
targets, target_lengths, biasingwords = _extract_labels(self.sp_model, samples)
tries, biasingwords = _extract_tries(self.sp_model, biasingwords, self.blist, self.droprate, self.maxsize)
return Batch(features, feature_lengths, targets, target_lengths, tries)
class ValTransform:
def __init__(self, global_stats_path: str, sp_model_path: str, blist: List[str], droprate: float, maxsize: int):
self.sp_model = spm.SentencePieceProcessor(model_file=sp_model_path)
self.valid_data_pipeline = torch.nn.Sequential(
FunctionalModule(_piecewise_linear_log),
GlobalStatsNormalization(global_stats_path),
)
self.blist = blist
self.droprate = droprate
self.maxsize = maxsize
def __call__(self, samples: List):
features, feature_lengths = _extract_features(self.valid_data_pipeline, samples)
targets, target_lengths, biasingwords = _extract_labels(self.sp_model, samples)
if self.blist:
tries, biasingwords = _extract_tries(self.sp_model, biasingwords, self.blist, self.droprate, self.maxsize)
else:
tries = []
return Batch(features, feature_lengths, targets, target_lengths, tries)
class TestTransform:
def __init__(self, global_stats_path: str, sp_model_path: str, blist: List[str], droprate: float, maxsize: int):
self.val_transforms = ValTransform(global_stats_path, sp_model_path, blist, droprate, maxsize)
def __call__(self, sample):
return self.val_transforms([sample]), [sample]
def get_data_module(
librispeech_path, global_stats_path, sp_model_path, subset=None, biasinglist=None, droprate=0.0, maxsize=1000
):
fullbiasinglist = []
if biasinglist:
with open(biasinglist) as fin:
fullbiasinglist = [line.strip() for line in fin]
train_transform = TrainTransform(
global_stats_path=global_stats_path,
sp_model_path=sp_model_path,
blist=fullbiasinglist,
droprate=droprate,
maxsize=maxsize,
)
val_transform = ValTransform(
global_stats_path=global_stats_path,
sp_model_path=sp_model_path,
blist=fullbiasinglist,
droprate=droprate,
maxsize=maxsize,
)
test_transform = TestTransform(
global_stats_path=global_stats_path,
sp_model_path=sp_model_path,
blist=fullbiasinglist,
droprate=droprate,
maxsize=maxsize,
)
return LibriSpeechDataModule(
librispeech_path=librispeech_path,
train_transform=train_transform,
val_transform=val_transform,
test_transform=test_transform,
subset=subset,
fullbiasinglist=fullbiasinglist,
)
## Results
The table below contains WER results for various pretrained models on LibriSpeech, using a beam size of 1500, and language model weight and word insertion scores taken from Table 7 of [wav2vec 2.0](https://arxiv.org/pdf/2006.11477.pdf).
| Model | LM weight | word insertion | dev-clean | dev-other | test-clean | test-other |
|:------------------------------------------------------------------------------------------------|-----------:|-----------:|-----------:|-----------:|-----------:|-----------:|
| [WAV2VEC2_ASR_BASE_10M](https://pytorch.org/audio/main/generated/torchaudio.pipelines.WAV2VEC2_ASR_BASE_10M.html#torchaudio.pipelines.WAV2VEC2_ASR_BASE_10M) | 3.23| -0.26| 9.41| 15.95| 9.35| 15.91|
| [WAV2VEC2_ASR_BASE_100H](https://pytorch.org/audio/main/generated/torchaudio.pipelines.WAV2VEC2_ASR_BASE_100H.html#torchaudio.pipelines.WAV2VEC2_ASR_BASE_100H) | 2.15| -0.52| 3.08| 7.89| 3.42| 8.07|
| [WAV2VEC2_ASR_BASE_960H](https://pytorch.org/audio/main/generated/torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H.html#torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H) | 1.74| 0.52| 2.56| 6.26| 2.61| 6.15|
| [WAV2VEC2_ASR_LARGE_960H](https://pytorch.org/audio/main/generated/torchaudio.pipelines.WAV2VEC2_ASR_LARGE_960H.html#torchaudio.pipelines.WAV2VEC2_ASR_LARGE_960H) | 1.74| 0.52| 2.14| 4.62| 2.34| 4.98|
| [WAV2VEC2_ASR_LARGE_LV60K_10M](https://pytorch.org/audio/main/generated/torchaudio.pipelines.WAV2VEC2_ASR_LARGE_LV60K_10M.html#torchaudio.pipelines.WAV2VEC2_ASR_LARGE_LV60K_10M) | 3.86| -1.18| 6.77| 10.03| 6.87| 10.51|
| [WAV2VEC2_ASR_LARGE_LV60K_100H](https://pytorch.org/audio/main/generated/torchaudio.pipelines.WAV2VEC2_ASR_LARGE_LV60K_100H.html#torchaudio.pipelines.WAV2VEC2_ASR_LARGE_LV60K_100H) | 2.15| -0.52| 2.19| 4.55| 2.32| 4.64|
| [WAV2VEC2_ASR_LARGE_LV60K_960H](https://pytorch.org/audio/main/generated/torchaudio.pipelines.WAV2VEC2_ASR_LARGE_LV60K_960H.html#torchaudio.pipelines.WAV2VEC2_ASR_LARGE_LV60K_960H) | 1.57| -0.64| 1.78| 3.51| 2.03| 3.68|
| [HUBERT_ASR_LARGE](https://pytorch.org/audio/main/generated/torchaudio.pipelines.HUBERT_ASR_LARGE.html#torchaudio.pipelines.HUBERT_ASR_LARGE) | 1.57| -0.64| 1.77| 3.32| 2.03| 3.68|
| [HUBERT_ASR_XLARGE](https://pytorch.org/audio/main/generated/torchaudio.pipelines.HUBERT_ASR_XLARGE.html#torchaudio.pipelines.HUBERT_ASR_XLARGE) | 1.57| -0.64| 1.73| 2.72| 1.90| 3.16|
# Speech Recognition Inference with CUDA CTC Beam Search Decoder
This is an example inference script for running decoding on the LibriSpeech dataset with [zipformer](https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless7_ctc) models, using a CUDA-based CTC beam search decoder that supports parallel decoding across the batch and vocabulary axes.
## Usage
Additional command line parameters and information are available via the `--help` option.
Sample command:
```
pip install sentencepiece
# download pretrained files
wget -nc https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-ctc-2022-12-01/resolve/main/data/lang_bpe_500/bpe.model
wget -nc https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-ctc-2022-12-01/resolve/main/exp/cpu_jit.pt
python inference.py \
--librispeech_path ./librispeech/ \
--split test-other \
--model ./cpu_jit.pt \
--bp-model ./bpe.model \
--beam-size 10 \
--blank-skip-threshold 0.95
```
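As a minimal sketch of the decoder API that [`inference.py`](./inference.py) uses (the vocabulary, shapes, and acoustic posteriors below are placeholders, not real model outputs):
```python
import torch
from torchaudio.models.decoder import cuda_ctc_decoder

# The decoder expects the blank symbol at index 0 of the token list (see inference.py).
tokens = ["<blk>", "a", "b", "c"]  # placeholder vocabulary
decoder = cuda_ctc_decoder(tokens, nbest=10, beam_size=10, blank_skip_threshold=0.95)

# Placeholder CTC log-probabilities of shape (batch, frames, vocab), on the GPU.
log_probs = torch.randn(4, 100, len(tokens), device="cuda").log_softmax(-1)
lengths = torch.full((4,), 100, dtype=torch.int32, device="cuda")

results = decoder(log_probs, lengths)  # nested list: one n-best list per utterance
best = results[0][0]                   # best hypothesis for the first utterance
print(best.tokens)                     # token ids (the hypothesis also carries words and a score)
```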
## Results
The table below contains WER and throughput benchmark results on the LibriSpeech test-other set, comparing the CUDA CTC decoder with the flashlight CPU decoder.
(Note: batch_size=4, beam_size=10, nbest=10, vocab_size=500, no LM, Intel(R) Xeon(R) CPU E5-2698 v4 @ 2.20GHz, V100 GPU)
| Decoder | Setting | WER (%) | N-Best Oracle WER (%) | Decoder Cost Time (seconds) |
|:-----------|-----------:|-----------:|-----------:|-----------:|
|CUDA decoder|blank_skip_threshold=0.95| 5.81 | 4.11 | 2.57 |
|CUDA decoder|blank_skip_threshold=1.0 (no frame-skip)| 5.81 | 4.09 | 6.24 |
|flashlight decoder|beam_size_token=10| 5.86 | 4.30 | 28.61 |
|flashlight decoder|beam_size_token=vocab_size| 5.86 | 4.30 | 791.80 |
import argparse
import logging
import time
import sentencepiece as spm
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from torchaudio.models.decoder import ctc_decoder, cuda_ctc_decoder
logger = logging.getLogger(__name__)
def collate_wrapper(batch):
speeches, labels = [], []
for (speech, _, label, _, _, _) in batch:
speeches.append(speech)
labels.append(label.strip().lower())
return speeches, labels
def run_inference(args):
device = torch.device("cuda", 0)
model = torch.jit.load(args.model)
model.to(device)
model.eval()
bpe_model = spm.SentencePieceProcessor()
bpe_model.load(args.bpe_model)
vocabs = [bpe_model.id_to_piece(id) for id in range(bpe_model.get_piece_size())]
if args.using_cpu_decoder:
cpu_decoder = ctc_decoder(
lexicon=None,
tokens=vocabs,
lm=None,
nbest=args.nbest,
beam_size=args.beam_size,
beam_size_token=args.beam_size_token,
beam_threshold=args.beam_threshold,
blank_token="<blk>",
sil_token="<blk>",
)
else:
assert vocabs[0] == "<blk>", "idx of blank token has to be zero"
cuda_decoder = cuda_ctc_decoder(
vocabs, nbest=args.nbest, beam_size=args.beam_size, blank_skip_threshold=args.blank_skip_threshold
)
dataset = torchaudio.datasets.LIBRISPEECH(args.librispeech_path, url=args.split, download=True)
total_edit_distance, oracle_edit_distance, total_length = 0, 0, 0
data_loader = DataLoader(
dataset, batch_size=args.batch_size, num_workers=4, pin_memory=True, collate_fn=collate_wrapper
)
decoding_duration = 0
for idx, batch in enumerate(data_loader):
waveforms, transcripts = batch
waveforms = [wave.to(device) for wave in waveforms]
features = [torchaudio.compliance.kaldi.fbank(wave, num_mel_bins=80, snip_edges=False) for wave in waveforms]
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(features, batch_first=True, padding_value=torch.log(torch.tensor(1e-10)))
feature_lengths = torch.tensor(feature_lengths, device=device)
encoder_out, encoder_out_lens = model.encoder(
x=features,
x_lens=feature_lengths,
)
nnet_output = model.ctc_output(encoder_out)
log_prob = torch.nn.functional.log_softmax(nnet_output, -1)
decoding_start = time.perf_counter()
preds = []
if args.using_cpu_decoder:
results = cpu_decoder(log_prob.cpu())
duration = time.perf_counter() - decoding_start
for i in range(len(results)):
ith_preds = bpe_model.decode([results[i][j].tokens.tolist() for j in range(len(results[i]))])
ith_preds = [pred.lower().split() for pred in ith_preds]
preds.append(ith_preds)
else:
results = cuda_decoder(log_prob, encoder_out_lens.to(torch.int32))
duration = time.perf_counter() - decoding_start
for i in range(len(results)):
ith_preds = bpe_model.decode([results[i][j].tokens for j in range(len(results[i]))])
ith_preds = [pred.lower().split() for pred in ith_preds]
preds.append(ith_preds)
decoding_duration += duration
for transcript, nbest_pred in zip(transcripts, preds):
total_edit_distance += torchaudio.functional.edit_distance(transcript.split(), nbest_pred[0])
oracle_edit_distance += min(
[torchaudio.functional.edit_distance(transcript.split(), nbest_pred[i]) for i in range(len(nbest_pred))]
)
total_length += len(transcript.split())
if idx % 10 == 0:
logger.info(
f"Processed elem {idx}; "
f"WER: {total_edit_distance / total_length}, "
f"Oracle WER: {oracle_edit_distance / total_length}, "
f"decoding time for batch size {args.batch_size}: {duration}"
)
logger.info(
f"Final WER: {total_edit_distance / total_length}, "
f"Oracle WER: {oracle_edit_distance / total_length}, "
f"time for decoding: {decoding_duration} [sec]."
)
def _parse_args():
parser = argparse.ArgumentParser(
description=__doc__,
formatter_class=argparse.RawTextHelpFormatter,
)
parser.add_argument(
"--librispeech_path",
type=str,
help="folder where LibriSpeech is stored",
default="./librispeech",
)
parser.add_argument(
"--split",
type=str,
help="LibriSpeech dataset split",
choices=["dev-clean", "dev-other", "test-clean", "test-other"],
default="test-other",
)
parser.add_argument(
"--model",
type=str,
default="./cpu_jit.pt",
help="pretrained ASR model using CTC loss",
)
parser.add_argument(
"--bpe-model",
type=str,
default="./bpe.model",
help="bpe file for pretrained ASR model",
)
parser.add_argument(
"--nbest",
type=int,
default=10,
help="number of best hypotheses to return",
)
parser.add_argument(
"--beam-size",
type=int,
default=10,
help="beam size for determining number of hypotheses to store",
)
parser.add_argument(
"--batch-size",
type=int,
default=4,
help="batch size for decoding",
)
parser.add_argument(
"--blank-skip-threshold",
type=float,
default=0.95,
help="skip frames where prob_blank > 0.95, https://ieeexplore.ieee.org/document/7736093",
)
parser.add_argument("--debug", action="store_true", help="whether to use debug level for logging")
# cpu decoder specific parameters
parser.add_argument("--using-cpu-decoder", action="store_true", help="whether to use flashlight cpu ctc decoder")
parser.add_argument("--beam-threshold", type=int, default=50, help="beam threshold for pruning hypotheses")
parser.add_argument(
"--beam-size-token",
type=int,
default=None,
help="number of tokens to consider at each beam search step",
)
return parser.parse_args()
def _init_logger(debug):
fmt = "%(asctime)s %(message)s" if debug else "%(message)s"
level = logging.DEBUG if debug else logging.INFO
logging.basicConfig(format=fmt, level=level, datefmt="%Y-%m-%d %H:%M:%S")
def _main():
args = _parse_args()
_init_logger(args.debug)
run_inference(args)
if __name__ == "__main__":
_main()
<p align="center"><img width="160" src="https://download.pytorch.org/torchaudio/doc-assets/avsr/lip_white.png" alt="logo"></p>
<h1 align="center">Real-time ASR/VSR/AV-ASR Examples</h1>
<div align="center">
[📘Introduction](#introduction) |
[📊Training](#Training) |
[🔮Evaluation](#Evaluation)
</div>
## Introduction
This directory contains the training recipe for real-time audio, visual, and audio-visual speech recognition (ASR, VSR, AV-ASR) models, which is an extension of [Auto-AVSR](https://arxiv.org/abs/2303.14307).
## Preparation
1. Install PyTorch (torch, torchvision, torchaudio) by following the instructions at [pytorch.org/get-started](https://pytorch.org/get-started/), along with the other required packages:
```Shell
pip install torch torchvision torchaudio pytorch-lightning sentencepiece
```
2. Preprocess LRS3. See the instructions in the [data_prep](./data_prep) folder.
## Usage
### Training
```Shell
python train.py --exp-dir=[exp_dir] \
--exp-name=[exp_name] \
--modality=[modality] \
--mode=[mode] \
--root-dir=[root-dir] \
--sp-model-path=[sp_model_path] \
--num-nodes=[num_nodes] \
--gpus=[gpus]
```
- `exp-dir` and `exp-name`: Checkpoints will be saved at `[exp_dir]/[exp_name]`.
- `modality`: Type of the input modality. Valid values are: `video`, `audio`, and `audiovisual`.
- `mode`: Type of the mode. Valid values are: `online` and `offline`.
- `root-dir`: Path to the root directory where all preprocessed files will be stored.
- `sp-model-path`: Path to the sentencepiece model. Default: `./spm_unigram_1023.model`, which can be produced using `train_spm.py`.
- `num-nodes`: The number of machines used. Default: 4.
- `gpus`: The number of gpus in each machine. Default: 8.
### Evaluation
```Shell
python eval.py --modality=[modality] \
--mode=[mode] \
--root-dir=[dataset_path] \
--sp-model-path=[sp_model_path] \
--checkpoint-path=[checkpoint_path]
```
- `modality`: Type of the input modality. Valid values are: `video`, `audio`, and `audiovisual`.
- `mode`: Type of the mode. Valid values are: `online` and `offline`.
- `root-dir`: Path to the root directory where all preprocessed files will be stored.
- `sp-model-path`: Path to the sentencepiece model. Default: `./spm_unigram_1023.model`.
- `checkpoint-path`: Path to a pretrained model.
## Results
The table below contains WER for AV-ASR models that were trained from scratch (offline evaluation).
| Model | Training dataset (hours) | WER [%] | Params (M) |
|:--------------------:|:------------------------:|:-------:|:----------:|
| Non-streaming models | | | |
| AV-ASR | LRS3 (438) | 3.9 | 50 |
| Streaming models | | | |
| AV-ASR | LRS3 (438) | 3.9 | 40 |
import os
import torch
def average_checkpoints(last):
avg = None
for path in last:
states = torch.load(path, map_location=lambda storage, loc: storage)["state_dict"]
if avg is None:
avg = states
else:
for k in avg.keys():
avg[k] += states[k]
# average
for k in avg.keys():
if avg[k] is not None:
if avg[k].is_floating_point():
avg[k] /= len(last)
else:
avg[k] //= len(last)
return avg
def ensemble(args):
last = [os.path.join(args.exp_dir, args.exp_name, f"epoch={n}.ckpt") for n in range(args.epochs - 10, args.epochs)]
model_path = os.path.join(args.exp_dir, args.exp_name, "model_avg_10.pth")
torch.save({"state_dict": average_checkpoints(last)}, model_path)
import random
import torch
from lrs3 import LRS3
from pytorch_lightning import LightningDataModule
def _batch_by_token_count(idx_target_lengths, max_frames, batch_size=None):
batches = []
current_batch = []
current_token_count = 0
for idx, target_length in idx_target_lengths:
if current_token_count + target_length > max_frames or (batch_size and len(current_batch) == batch_size):
batches.append(current_batch)
current_batch = [idx]
current_token_count = target_length
else:
current_batch.append(idx)
current_token_count += target_length
if current_batch:
batches.append(current_batch)
return batches
class CustomBucketDataset(torch.utils.data.Dataset):
def __init__(
self,
dataset,
lengths,
max_frames,
num_buckets,
shuffle=False,
batch_size=None,
):
super().__init__()
assert len(dataset) == len(lengths)
self.dataset = dataset
max_length = max(lengths)
min_length = min(lengths)
assert max_frames >= max_length
buckets = torch.linspace(min_length, max_length, num_buckets)
lengths = torch.tensor(lengths)
bucket_assignments = torch.bucketize(lengths, buckets)
idx_length_buckets = [(idx, length, bucket_assignments[idx]) for idx, length in enumerate(lengths)]
if shuffle:
idx_length_buckets = random.sample(idx_length_buckets, len(idx_length_buckets))
else:
idx_length_buckets = sorted(idx_length_buckets, key=lambda x: x[1], reverse=True)
sorted_idx_length_buckets = sorted(idx_length_buckets, key=lambda x: x[2])
self.batches = _batch_by_token_count(
[(idx, length) for idx, length, _ in sorted_idx_length_buckets],
max_frames,
batch_size=batch_size,
)
def __getitem__(self, idx):
return [self.dataset[subidx] for subidx in self.batches[idx]]
def __len__(self):
return len(self.batches)
class TransformDataset(torch.utils.data.Dataset):
def __init__(self, dataset, transform_fn):
self.dataset = dataset
self.transform_fn = transform_fn
def __getitem__(self, idx):
return self.transform_fn(self.dataset[idx])
def __len__(self):
return len(self.dataset)
class LRS3DataModule(LightningDataModule):
def __init__(
self,
*,
args,
train_transform,
val_transform,
test_transform,
max_frames,
batch_size=None,
train_num_buckets=50,
train_shuffle=True,
num_workers=10,
):
super().__init__()
self.args = args
self.train_dataset_lengths = None
self.val_dataset_lengths = None
self.train_transform = train_transform
self.val_transform = val_transform
self.test_transform = test_transform
self.max_frames = max_frames
self.batch_size = batch_size
self.train_num_buckets = train_num_buckets
self.train_shuffle = train_shuffle
self.num_workers = num_workers
def train_dataloader(self):
dataset = LRS3(self.args, subset="train")
dataset = CustomBucketDataset(
dataset, dataset.lengths, self.max_frames, self.train_num_buckets, batch_size=self.batch_size
)
dataset = TransformDataset(dataset, self.train_transform)
dataloader = torch.utils.data.DataLoader(
dataset, num_workers=self.num_workers, batch_size=None, shuffle=self.train_shuffle
)
return dataloader
def val_dataloader(self):
dataset = LRS3(self.args, subset="val")
dataset = CustomBucketDataset(dataset, dataset.lengths, self.max_frames, 1, batch_size=self.batch_size)
dataset = TransformDataset(dataset, self.val_transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=None, num_workers=self.num_workers)
return dataloader
def test_dataloader(self):
dataset = LRS3(self.args, subset="test")
dataset = TransformDataset(dataset, self.test_transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=None)
return dataloader
# Pre-process LRS3
We provide a pre-processing pipeline in this repository for detecting and cropping full-face regions of interest (ROIs) from LRS3 videos, together with the corresponding audio waveforms.
## Introduction
Before feeding the raw stream into our model, each video sequence has to undergo a specific pre-processing procedure. This involves three critical steps. The first step is to perform face detection. Following that, each individual frame is aligned to a referenced frame, commonly known as the mean face, in order to normalize rotation and size differences across frames. The final step in the pre-processing module is to crop the face region from the aligned face image.
<div align="center">
<table style="display: inline-table;">
<tr><td><img src="https://download.pytorch.org/torchaudio/doc-assets/avsr/original.gif", width="144"></td><td><img src="https://download.pytorch.org/torchaudio/doc-assets/avsr/detected.gif" width="144"></td><td><img src="https://download.pytorch.org/torchaudio/doc-assets/avsr/transformed.gif" width="144"></td><td><img src="https://download.pytorch.org/torchaudio/doc-assets/avsr/cropped.gif" width="144"></td></tr>
<tr><td>0. Original</td> <td>1. Detection</td> <td>2. Transformation</td> <td>3. Face ROIs</td> </tr>
</table>
</div>
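The snippet below sketches those three steps using the detector and video-processing helpers that appear later in this example (`LandmarksDetector`, `VideoProcess`); it assumes the retinaface tools from [tools](./tools) are installed and is illustrative only, not a replacement for `preprocess_lrs3.py`.
```python
import torch
import torchvision
from detectors.retinaface.detector import LandmarksDetector  # provided by ./tools (assumed installed)
from detectors.retinaface.video_process import VideoProcess

landmarks_detector = LandmarksDetector(device="cuda:0")
video_process = VideoProcess(resize=(96, 96))  # illustrative output size

# 0. Load the original video as a (T, H, W, C) uint8 array.
video = torchvision.io.read_video("example.mp4", pts_unit="sec")[0].numpy()
# 1. Detection: locate face landmarks in every frame.
landmarks = landmarks_detector(video)
# 2-3. Transformation + cropping: align each frame to the mean face and crop the face ROI.
rois = torch.tensor(video_process(video, landmarks))
```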
## Preparation
1. Install all dependency-packages.
```Shell
pip install -r requirements.txt
```
2. Install [retinaface](./tools) or [mediapipe](https://pypi.org/project/mediapipe/) tracker. If you have installed the tracker, please skip it.
## Preprocessing LRS3
To pre-process the LRS3 dataset, please follow these steps:
1. Download the LRS3 dataset from the official website.
2. Run the following command to preprocess the dataset:
```Shell
python preprocess_lrs3.py \
--data-dir=[data_dir] \
--detector=[detector] \
--dataset=[dataset] \
--root-dir=[root] \
--subset=[subset] \
--seg-duration=[seg_duration] \
--groups=[n] \
--job-index=[j]
```
- `data-dir`: Path to the directory containing video files.
- `detector`: Type of face detector. Valid values are: `mediapipe` and `retinaface`. Default: `retinaface`.
- `dataset`: Name of the dataset. Valid value is: `lrs3`.
- `root-dir`: Path to the root directory where all preprocessed files will be stored.
- `subset`: Name of the subset. Valid values are: `train` and `test`.
- `seg-duration`: Length of the maximal segment in seconds. Default: `16`.
- `groups`: Number of groups to split the dataset into.
- `job-index`: Job index for the current group. Valid values are an integer within the range of `[0, n)`.
3. Run the following command to merge all labels:
```Shell
python merge.py \
--root-dir=[root_dir] \
--dataset=[dataset] \
--subset=[subset] \
--seg-duration=[seg_duration] \
--groups=[n]
```
- `root-dir`: Path to the root directory where all preprocessed files will be stored.
- `dataset`: Name of the dataset. Valid values are: `lrs2` and `lrs3`.
- `subset`: The subset name of the dataset. For LRS2, valid values are `train`, `val`, and `test`. For LRS3, valid values are `train` and `test`.
- `seg-duration`: Length of the maximal segment in seconds. Default: `16`.
- `groups`: Number of groups to split the dataset into.
#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 Imperial College London (Pingchuan Ma)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import torch
import torchaudio
import torchvision
class AVSRDataLoader:
def __init__(self, modality, detector="retinaface", resize=None):
self.modality = modality
if modality == "video":
if detector == "retinaface":
from detectors.retinaface.detector import LandmarksDetector
from detectors.retinaface.video_process import VideoProcess
self.landmarks_detector = LandmarksDetector(device="cuda:0")
self.video_process = VideoProcess(resize=resize)
if detector == "mediapipe":
from detectors.mediapipe.detector import LandmarksDetector
from detectors.mediapipe.video_process import VideoProcess
self.landmarks_detector = LandmarksDetector()
self.video_process = VideoProcess(resize=resize)
def load_data(self, data_filename, transform=True):
if self.modality == "audio":
audio, sample_rate = self.load_audio(data_filename)
audio = self.audio_process(audio, sample_rate)
return audio
if self.modality == "video":
video = self.load_video(data_filename)
landmarks = self.landmarks_detector(video)
video = self.video_process(video, landmarks)
video = torch.tensor(video)
return video
def load_audio(self, data_filename):
waveform, sample_rate = torchaudio.load(data_filename, normalize=True)
return waveform, sample_rate
def load_video(self, data_filename):
return torchvision.io.read_video(data_filename, pts_unit="sec")[0].numpy()
def audio_process(self, waveform, sample_rate, target_sample_rate=16000):
if sample_rate != target_sample_rate:
waveform = torchaudio.functional.resample(waveform, sample_rate, target_sample_rate)
waveform = torch.mean(waveform, dim=0, keepdim=True)
return waveform
#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2021 Imperial College London (Pingchuan Ma)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import warnings
import mediapipe as mp
import numpy as np
warnings.filterwarnings("ignore")
class LandmarksDetector:
def __init__(self):
self.mp_face_detection = mp.solutions.face_detection
self.short_range_detector = self.mp_face_detection.FaceDetection(
min_detection_confidence=0.5, model_selection=0
)
self.full_range_detector = self.mp_face_detection.FaceDetection(min_detection_confidence=0.5, model_selection=1)
def __call__(self, video_frames):
landmarks = self.detect(video_frames, self.full_range_detector)
if all(element is None for element in landmarks):
landmarks = self.detect(video_frames, self.short_range_detector)
assert any(l is not None for l in landmarks), "Cannot detect any frames in the video"
return landmarks
def detect(self, video_frames, detector):
landmarks = []
for frame in video_frames:
results = detector.process(frame)
if not results.detections:
landmarks.append(None)
continue
face_points = []
max_id, max_size = 0, 0  # track the largest detected face across this frame's detections
for idx, detected_faces in enumerate(results.detections):
bboxC = detected_faces.location_data.relative_bounding_box
ih, iw, ic = frame.shape
bbox = int(bboxC.xmin * iw), int(bboxC.ymin * ih), int(bboxC.width * iw), int(bboxC.height * ih)
bbox_size = (bbox[2] - bbox[0]) + (bbox[3] - bbox[1])
if bbox_size > max_size:
max_id, max_size = idx, bbox_size
lmx = [[int(bboxC.xmin * iw), int(bboxC.ymin * ih)], [int(bboxC.width * iw), int(bboxC.height * ih)]]
face_points.append(lmx)
landmarks.append(np.reshape(np.array(face_points[max_id]), (2, 2)))
return landmarks