Unverified Commit 4d251485 authored by Caroline Chen, committed by GitHub

[release 0.13] Remove prototype (#2749)

parent 84d8ced9
@@ -53,18 +53,6 @@ API References
kaldi_io
utils
Prototype API References
------------------------
.. toctree::
:maxdepth: 1
:caption: Prototype API Reference
prototype
prototype.functional
prototype.models
prototype.pipelines
Getting Started
---------------
...
torchaudio.prototype.functional
===============================
.. py:module:: torchaudio.prototype.functional
.. currentmodule:: torchaudio.prototype.functional
add_noise
~~~~~~~~~
.. autofunction:: add_noise
convolve
~~~~~~~~
.. autofunction:: convolve
fftconvolve
~~~~~~~~~~~
.. autofunction:: fftconvolve
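For instance, ``fftconvolve`` can filter a batch of signals with per-signal FIR kernels.
A minimal sketch (the batch shapes and the SciPy-style ``mode`` argument here are
assumptions for illustration, not part of this reference):

.. code-block:: python

    import torch
    from torchaudio.prototype.functional import fftconvolve

    waveform = torch.randn(2, 16000)  # batch of two 1-second signals at 16 kHz
    kernels = torch.randn(2, 101)     # one FIR filter per signal
    filtered = fftconvolve(waveform, kernels, mode="same")  # output length matches input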
torchaudio.prototype.models
===========================
.. py:module:: torchaudio.prototype.models
.. currentmodule:: torchaudio.prototype.models
conformer_rnnt_model
~~~~~~~~~~~~~~~~~~~~
.. autofunction:: conformer_rnnt_model
conformer_rnnt_base
~~~~~~~~~~~~~~~~~~~
.. autofunction:: conformer_rnnt_base
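As the training recipe later in this diff illustrates, the base factory takes no
arguments and returns a fixed configuration; per that recipe, it expects targets from a
SentencePiece model with a vocabulary size of 1023, plus one appended blank symbol.
A minimal sketch:

.. code-block:: python

    from torchaudio.prototype.models import conformer_rnnt_base

    # Fixed base configuration; use conformer_rnnt_model for custom configurations.
    model = conformer_rnnt_base()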
ConvEmformer
~~~~~~~~~~~~
.. autoclass:: ConvEmformer
.. automethod:: forward
.. automethod:: infer
torchaudio.prototype.pipelines
==============================
.. py:module:: torchaudio.prototype.pipelines
.. currentmodule:: torchaudio.prototype.pipelines
The pipelines subpackage contains APIs for models with pretrained weights and related utilities; a short usage sketch follows the bundle listings below.
RNN-T Streaming/Non-Streaming ASR
---------------------------------
EMFORMER_RNNT_BASE_MUSTC
~~~~~~~~~~~~~~~~~~~~~~~~
.. container:: py attribute
.. autodata:: EMFORMER_RNNT_BASE_MUSTC
:no-value:
EMFORMER_RNNT_BASE_TEDLIUM3
~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. container:: py attribute
.. autodata:: EMFORMER_RNNT_BASE_TEDLIUM3
:no-value:
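For orientation, a minimal full-context decoding sketch using one of these bundles,
mirroring the TED-LIUM evaluation script later in this diff (``speech.wav`` is a
placeholder path):

.. code-block:: python

    import torchaudio
    from torchaudio.prototype.pipelines import EMFORMER_RNNT_BASE_TEDLIUM3

    # Build the components bundled with the pre-trained pipeline.
    feature_extractor = EMFORMER_RNNT_BASE_TEDLIUM3.get_feature_extractor()
    decoder = EMFORMER_RNNT_BASE_TEDLIUM3.get_decoder()
    token_processor = EMFORMER_RNNT_BASE_TEDLIUM3.get_token_processor()

    waveform, sample_rate = torchaudio.load("speech.wav")
    features, length = feature_extractor(waveform.squeeze())
    hypos = decoder(features, length, 10)  # beam width of 10
    transcript = token_processor(hypos[0][0])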
torchaudio.prototype
====================
``torchaudio.prototype`` provides prototype features;
they are at an early stage for feedback and testing,
and their interfaces may change without prior notice.
Most prototype modules are excluded from official releases.
Please refer to `here <https://pytorch.org/audio>`_ for
more information on prototype features.
The modules under ``torchaudio.prototype`` must be
imported explicitly, e.g.
.. code-block:: python
import torchaudio.prototype.models
.. toctree::
prototype.functional
prototype.models
prototype.pipelines
@@ -15,7 +15,7 @@ This directory contains sample implementations of training and evaluation pipelines
### Pipeline Demo
-[`pipeline_demo.py`](./pipeline_demo.py) demonstrates how to use the `EMFORMER_RNNT_BASE_LIBRISPEECH` or `EMFORMER_RNNT_BASE_TEDLIUM3` bundle that wraps a pre-trained Emformer RNN-T produced by the corresponding recipe below to perform streaming and full-context ASR on several audio samples.
+[`pipeline_demo.py`](./pipeline_demo.py) demonstrates how to use the `EMFORMER_RNNT_BASE_LIBRISPEECH` bundle that wraps a pre-trained Emformer RNN-T produced by the LibriSpeech recipe below to perform streaming and full-context ASR on several audio samples.
## Model Types
@@ -67,8 +67,6 @@ The table below contains WER results for dev and test subsets of TED-LIUM release 3
| dev | 0.108 |
| test | 0.098 |
[`tedlium3/eval_pipeline.py`](./tedlium3/eval_pipeline.py) evaluates the pre-trained `EMFORMER_RNNT_BASE_TEDLIUM3` bundle on the dev and test sets of TED-LIUM release 3. Running the script should produce WER results that are identical to those in the above table.
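For reference, a sample invocation might look like the following (the dataset path is a placeholder):
```
python tedlium3/eval_pipeline.py --tedlium-path ./datasets/ --use-cuda
```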
### MuST-C release v2.0
The MuST-C model is configured with a vocabulary size of 500. Consequently, the MuST-C model's last linear layer in the joiner has an output dimension of 501 (500 + 1 to account for the blank symbol). In contrast to those of the datasets for the above two models, MuST-C's transcripts are cased and punctuated; we preserve the casing and punctuation when training the SentencePiece model.
...
@@ -13,10 +13,8 @@ from typing import Callable
import torch
import torchaudio
-from common import MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_MUSTC, MODEL_TYPE_TEDLIUM3
+from common import MODEL_TYPE_LIBRISPEECH
from mustc.dataset import MUSTC
from torchaudio.pipelines import EMFORMER_RNNT_BASE_LIBRISPEECH, RNNTBundle
from torchaudio.prototype.pipelines import EMFORMER_RNNT_BASE_MUSTC, EMFORMER_RNNT_BASE_TEDLIUM3
logger = logging.getLogger(__name__)
@@ -32,14 +30,6 @@ _CONFIGS = {
        partial(torchaudio.datasets.LIBRISPEECH, url="test-clean"),
        EMFORMER_RNNT_BASE_LIBRISPEECH,
    ),
    MODEL_TYPE_MUSTC: Config(
        partial(MUSTC, subset="tst-COMMON"),
        EMFORMER_RNNT_BASE_MUSTC,
    ),
    MODEL_TYPE_TEDLIUM3: Config(
        partial(torchaudio.datasets.TEDLIUM, release="release3", subset="test"),
        EMFORMER_RNNT_BASE_TEDLIUM3,
    ),
}
...
#!/usr/bin/env python3
import logging
import pathlib
from argparse import ArgumentParser, RawTextHelpFormatter
import torch
import torchaudio
from torchaudio.prototype.pipelines import EMFORMER_RNNT_BASE_TEDLIUM3
logger = logging.getLogger(__name__)
def compute_word_level_distance(seq1, seq2):
return torchaudio.functional.edit_distance(seq1.lower().split(), seq2.lower().split())
def _eval_subset(tedlium_path, subset, feature_extractor, decoder, token_processor, use_cuda):
total_edit_distance = 0
total_length = 0
    if subset == "dev":
        dataset = torchaudio.datasets.TEDLIUM(tedlium_path, release="release3", subset="dev")
    elif subset == "test":
        dataset = torchaudio.datasets.TEDLIUM(tedlium_path, release="release3", subset="test")
    else:
        raise ValueError(f"Expected subset to be 'dev' or 'test'; got {subset}")
with torch.no_grad():
for idx in range(len(dataset)):
sample = dataset[idx]
waveform = sample[0].squeeze()
if use_cuda:
waveform = waveform.to(device="cuda")
actual = sample[2].replace("\n", "")
if actual == "ignore_time_segment_in_scoring":
continue
features, length = feature_extractor(waveform)
hypos = decoder(features, length, 20)
hypothesis = hypos[0]
hypothesis = token_processor(hypothesis[0])
total_edit_distance += compute_word_level_distance(actual, hypothesis)
total_length += len(actual.split())
if idx % 100 == 0:
print(f"Processed elem {idx}; WER: {total_edit_distance / total_length}")
print(f"Final WER for {subset} set: {total_edit_distance / total_length}")
def run_eval_pipeline(args):
decoder = EMFORMER_RNNT_BASE_TEDLIUM3.get_decoder()
token_processor = EMFORMER_RNNT_BASE_TEDLIUM3.get_token_processor()
feature_extractor = EMFORMER_RNNT_BASE_TEDLIUM3.get_feature_extractor()
if args.use_cuda:
feature_extractor = feature_extractor.to(device="cuda").eval()
decoder = decoder.to(device="cuda")
_eval_subset(args.tedlium_path, "dev", feature_extractor, decoder, token_processor, args.use_cuda)
_eval_subset(args.tedlium_path, "test", feature_extractor, decoder, token_processor, args.use_cuda)
def _parse_args():
parser = ArgumentParser(
description=__doc__,
formatter_class=RawTextHelpFormatter,
)
parser.add_argument(
"--tedlium-path",
type=pathlib.Path,
help="Path to TED-LIUM release 3 dataset.",
)
parser.add_argument(
"--use-cuda",
action="store_true",
default=False,
help="Run using CUDA.",
)
parser.add_argument("--debug", action="store_true", help="whether to use debug level for logging")
return parser.parse_args()
def _init_logger(debug):
fmt = "%(asctime)s %(message)s" if debug else "%(message)s"
level = logging.DEBUG if debug else logging.INFO
logging.basicConfig(format=fmt, level=level, datefmt="%Y-%m-%d %H:%M:%S")
def cli_main():
args = _parse_args()
_init_logger(args.debug)
run_eval_pipeline(args)
if __name__ == "__main__":
cli_main()
{
"mean": [
14.762723922729492,
16.020633697509766,
16.911531448364258,
16.80994415283203,
18.72406005859375,
18.84550666809082,
19.021404266357422,
19.623443603515625,
19.403806686401367,
19.52766990661621,
19.253433227539062,
19.211227416992188,
19.216045379638672,
19.315574645996094,
19.267532348632812,
19.146976470947266,
18.98181915283203,
18.81462287902832,
18.67916488647461,
18.5198917388916,
18.360441207885742,
18.18699836730957,
18.008447647094727,
17.82094955444336,
17.644861221313477,
17.51972007751465,
17.51348876953125,
17.171707153320312,
17.070415496826172,
17.21990394592285,
16.868940353393555,
17.048307418823242,
16.894960403442383,
17.04732322692871,
16.955705642700195,
17.053966522216797,
17.037548065185547,
17.03425407409668,
17.03618621826172,
16.979724884033203,
16.889690399169922,
16.779285430908203,
16.689767837524414,
16.62590789794922,
16.600360870361328,
16.610321044921875,
16.692338943481445,
16.61323356628418,
16.638328552246094,
16.494739532470703,
16.42980194091797,
16.23759651184082,
16.144210815429688,
16.018585205078125,
15.985218048095703,
15.947102546691895,
15.894798278808594,
15.832999229431152,
15.704426765441895,
15.538087844848633,
15.378302574157715,
15.19461441040039,
15.00456714630127,
14.861663818359375,
14.676336288452148,
14.594626426696777,
14.561753273010254,
14.464197158813477,
14.43082046508789,
14.388801574707031,
14.257562637329102,
14.231459617614746,
14.19768238067627,
14.123900413513184,
14.159867286682129,
14.059795379638672,
13.968880653381348,
13.927794456481934,
13.645783424377441,
12.086114883422852
],
"invstddev": [
0.3553205132484436,
0.3363242745399475,
0.3194723129272461,
0.3199574947357178,
0.28755369782447815,
0.2879481613636017,
0.27939942479133606,
0.27543479204177856,
0.2806696891784668,
0.28141146898269653,
0.2753477990627289,
0.274241179227829,
0.27815768122673035,
0.27794352173805237,
0.2763032615184784,
0.2744459807872772,
0.27375343441963196,
0.27415215969085693,
0.27628427743911743,
0.27667510509490967,
0.2806207835674286,
0.28371962904930115,
0.2893684506416321,
0.2944427728652954,
0.2989389896392822,
0.30326008796691895,
0.30760079622268677,
0.3089521527290344,
0.3105863034725189,
0.31274259090423584,
0.31318506598472595,
0.3154853880405426,
0.3167822062969208,
0.3182784914970398,
0.31875282526016235,
0.3185810148715973,
0.31908345222473145,
0.3207632303237915,
0.32282087206840515,
0.3241617977619171,
0.3260948061943054,
0.32735878229141235,
0.32947203516960144,
0.33052706718444824,
0.3309975266456604,
0.3301711678504944,
0.32793518900871277,
0.3252142369747162,
0.32336947321891785,
0.32320502400398254,
0.3264254927635193,
0.32860180735588074,
0.3322647213935852,
0.3100382685661316,
0.3216720223426819,
0.32280418276786804,
0.32710719108581543,
0.3284962773323059,
0.3319654166698456,
0.32880258560180664,
0.33075764775276184,
0.32947179675102234,
0.32880640029907227,
0.3296009302139282,
0.324250727891922,
0.3247823715209961,
0.328702837228775,
0.32418182492256165,
0.3247915208339691,
0.3251509964466095,
0.31811773777008057,
0.3195462226867676,
0.3187839686870575,
0.31459841132164,
0.32190003991127014,
0.3193890154361725,
0.315574049949646,
0.317360520362854,
0.3075887858867645,
0.3034747838973999
]
}
import os
from functools import partial
from typing import List
import sentencepiece as spm
import torch
import torchaudio
from common import (
Batch,
batch_by_token_count,
FunctionalModule,
GlobalStatsNormalization,
piecewise_linear_log,
post_process_hypos,
spectrogram_transform,
WarmupLR,
)
from pytorch_lightning import LightningModule
from torchaudio.models import emformer_rnnt_base, RNNTBeamSearch
class CustomDataset(torch.utils.data.Dataset):
r"""Sort TEDLIUM3 samples by target length and batch to max durations."""
def __init__(self, base_dataset, max_token_limit):
super().__init__()
self.base_dataset = base_dataset
idx_target_lengths = [
(idx, self._target_length(fileid, line)) for idx, (fileid, line) in enumerate(self.base_dataset._filelist)
]
idx_target_lengths = [(idx, length) for idx, length in idx_target_lengths if length != -1]
assert len(idx_target_lengths) > 0
idx_target_lengths = sorted(idx_target_lengths, key=lambda x: x[1])
assert max_token_limit >= idx_target_lengths[-1][1]
self.batches = batch_by_token_count(idx_target_lengths, max_token_limit)
def _target_length(self, fileid, line):
transcript_path = os.path.join(self.base_dataset._path, "stm", fileid)
with open(transcript_path + ".stm") as f:
transcript = f.readlines()[line]
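        # Each STM line has the form "<talk_id> <channel> <speaker> <start> <end> <label> <transcript>".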
_, _, _, start_time, end_time, _, transcript = transcript.split(" ", 6)
if transcript.lower() == "ignore_time_segment_in_scoring\n":
return -1
else:
return float(end_time) - float(start_time)
def __getitem__(self, idx):
return [self.base_dataset[subidx] for subidx in self.batches[idx]]
def __len__(self):
return len(self.batches)
class EvalDataset(torch.utils.data.IterableDataset):
def __init__(self, base_dataset):
super().__init__()
self.base_dataset = base_dataset
def __iter__(self):
for sample in iter(self.base_dataset):
actual = sample[2].replace("\n", "")
if actual == "ignore_time_segment_in_scoring":
continue
yield sample
class TEDLIUM3RNNTModule(LightningModule):
def __init__(
self,
*,
tedlium_path: str,
sp_model_path: str,
global_stats_path: str,
):
super().__init__()
self.model = emformer_rnnt_base(num_symbols=501)
self.loss = torchaudio.transforms.RNNTLoss(reduction="mean", clamp=1.0)
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=5e-4, betas=(0.9, 0.999), eps=1e-8)
self.warmup_lr_scheduler = WarmupLR(self.optimizer, 10000)
self.train_data_pipeline = torch.nn.Sequential(
FunctionalModule(piecewise_linear_log),
GlobalStatsNormalization(global_stats_path),
FunctionalModule(partial(torch.transpose, dim0=1, dim1=2)),
torchaudio.transforms.FrequencyMasking(27),
torchaudio.transforms.FrequencyMasking(27),
torchaudio.transforms.TimeMasking(100, p=0.2),
torchaudio.transforms.TimeMasking(100, p=0.2),
FunctionalModule(partial(torch.nn.functional.pad, pad=(0, 4))),
FunctionalModule(partial(torch.transpose, dim0=1, dim1=2)),
)
self.valid_data_pipeline = torch.nn.Sequential(
FunctionalModule(piecewise_linear_log),
GlobalStatsNormalization(global_stats_path),
FunctionalModule(partial(torch.transpose, dim0=1, dim1=2)),
FunctionalModule(partial(torch.nn.functional.pad, pad=(0, 4))),
FunctionalModule(partial(torch.transpose, dim0=1, dim1=2)),
)
self.tedlium_path = tedlium_path
self.sp_model = spm.SentencePieceProcessor(model_file=sp_model_path)
self.blank_idx = self.sp_model.get_piece_size()
def _extract_labels(self, samples: List):
"""Convert text transcript into int labels.
Note:
There are ``<unk>`` tokens in the training set that are regarded as normal tokens
by the SentencePiece model. This will impact RNNT decoding since the decoding result
of ``<unk>`` will be ``?? unk ??`` and will not be excluded from the final prediction.
To address it, here we replace ``<unk>`` with ``<garbage>`` and set
``user_defined_symbols=["<garbage>"]`` in the SentencePiece model training.
Then we map the index of ``<garbage>`` to the real ``unknown`` index.
"""
targets = [
self.sp_model.encode(sample[2].lower().replace("<unk>", "<garbage>").replace("\n", ""))
for sample in samples
]
        targets = [
            [ele if ele != 4 else self.sp_model.unk_id() for ele in target] for target in targets
        ]  # map the id of <garbage> (4, the first id after bos/pad/eos/unk) back to the real unk id
lengths = torch.tensor([len(elem) for elem in targets]).to(dtype=torch.int32)
targets = torch.nn.utils.rnn.pad_sequence(
[torch.tensor(elem) for elem in targets],
batch_first=True,
padding_value=1.0,
).to(dtype=torch.int32)
return targets, lengths
def _train_extract_features(self, samples: List):
mel_features = [spectrogram_transform(sample[0].squeeze()).transpose(1, 0) for sample in samples]
features = torch.nn.utils.rnn.pad_sequence(mel_features, batch_first=True)
features = self.train_data_pipeline(features)
lengths = torch.tensor([elem.shape[0] for elem in mel_features], dtype=torch.int32)
return features, lengths
def _valid_extract_features(self, samples: List):
mel_features = [spectrogram_transform(sample[0].squeeze()).transpose(1, 0) for sample in samples]
features = torch.nn.utils.rnn.pad_sequence(mel_features, batch_first=True)
features = self.valid_data_pipeline(features)
lengths = torch.tensor([elem.shape[0] for elem in mel_features], dtype=torch.int32)
return features, lengths
def _train_collate_fn(self, samples: List):
features, feature_lengths = self._train_extract_features(samples)
targets, target_lengths = self._extract_labels(samples)
return Batch(features, feature_lengths, targets, target_lengths)
def _valid_collate_fn(self, samples: List):
features, feature_lengths = self._valid_extract_features(samples)
targets, target_lengths = self._extract_labels(samples)
return Batch(features, feature_lengths, targets, target_lengths)
def _test_collate_fn(self, samples: List):
return self._valid_collate_fn(samples), [sample[2] for sample in samples]
def _step(self, batch, batch_idx, step_type):
if batch is None:
return None
prepended_targets = batch.targets.new_empty([batch.targets.size(0), batch.targets.size(1) + 1])
prepended_targets[:, 1:] = batch.targets
prepended_targets[:, 0] = self.blank_idx
prepended_target_lengths = batch.target_lengths + 1
output, src_lengths, _, _ = self.model(
batch.features,
batch.feature_lengths,
prepended_targets,
prepended_target_lengths,
)
loss = self.loss(output, batch.targets, src_lengths, batch.target_lengths)
self.log(f"Losses/{step_type}_loss", loss, on_step=True, on_epoch=True)
return loss
def configure_optimizers(self):
return (
[self.optimizer],
[
{"scheduler": self.warmup_lr_scheduler, "interval": "step"},
],
)
def forward(self, batch: Batch):
decoder = RNNTBeamSearch(self.model, self.blank_idx)
hypotheses = decoder(batch.features.to(self.device), batch.feature_lengths.to(self.device), 20)
return post_process_hypos(hypotheses, self.sp_model)[0][0]
def training_step(self, batch: Batch, batch_idx):
return self._step(batch, batch_idx, "train")
def validation_step(self, batch, batch_idx):
return self._step(batch, batch_idx, "val")
def test_step(self, batch_tuple, batch_idx):
return self._step(batch_tuple[0], batch_idx, "test")
def train_dataloader(self):
dataset = CustomDataset(torchaudio.datasets.TEDLIUM(self.tedlium_path, release="release3", subset="train"), 100)
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=None,
collate_fn=self._train_collate_fn,
num_workers=10,
shuffle=True,
)
return dataloader
def val_dataloader(self):
dataset = CustomDataset(torchaudio.datasets.TEDLIUM(self.tedlium_path, release="release3", subset="dev"), 100)
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=None,
collate_fn=self._valid_collate_fn,
num_workers=10,
)
return dataloader
def test_dataloader(self):
dataset = EvalDataset(torchaudio.datasets.TEDLIUM(self.tedlium_path, release="release3", subset="test"))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, collate_fn=self._test_collate_fn)
return dataloader
def dev_dataloader(self):
dataset = EvalDataset(torchaudio.datasets.TEDLIUM(self.tedlium_path, release="release3", subset="dev"))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, collate_fn=self._test_collate_fn)
return dataloader
#!/usr/bin/env python3
"""Train the SentencePiece model by using the transcripts of TED-LIUM release 3 training set.
Example:
python train_spm.py --tedlium-path /home/datasets/
"""
import io
import logging
import os
import pathlib
from argparse import ArgumentParser, RawTextHelpFormatter
import sentencepiece as spm
logger = logging.getLogger(__name__)
def _parse_args():
parser = ArgumentParser(description=__doc__, formatter_class=RawTextHelpFormatter)
parser.add_argument(
"--tedlium-path",
required=True,
type=pathlib.Path,
help="Path to TED-LIUM release 3 dataset.",
)
parser.add_argument(
"--output-file",
default=pathlib.Path("./spm_bpe_500.model"),
type=pathlib.Path,
help="File to save model to. (Default: './spm_bpe_500.model')",
)
parser.add_argument("--debug", action="store_true", help="whether to use debug level for logging")
return parser.parse_args()
def _extract_train_text(tedlium_path):
stm_path = tedlium_path / "TEDLIUM_release-3/data/stm/"
transcripts = []
for file in sorted(os.listdir(stm_path)):
if file.endswith(".stm"):
file = os.path.join(stm_path, file)
with open(file) as f:
for line in f.readlines():
talk_id, _, speaker_id, start_time, end_time, identifier, transcript = line.split(" ", 6)
if transcript == "ignore_time_segment_in_scoring\n":
continue
else:
transcript = transcript.replace("<unk>", "<garbage>").replace("\n", "")
transcripts.append(transcript)
return transcripts
def train_spm(input):
    model_writer = io.BytesIO()
    spm.SentencePieceTrainer.train(
        sentence_iterator=iter(input),
        model_writer=model_writer,
        vocab_size=500,
        model_type="bpe",
        input_sentence_size=-1,
        character_coverage=1.0,
        user_defined_symbols=["<garbage>"],
        bos_id=0,
        pad_id=1,
        eos_id=2,
        unk_id=3,
    )
    return model_writer.getvalue()
def _init_logger(debug):
fmt = "%(asctime)s %(message)s" if debug else "%(message)s"
level = logging.DEBUG if debug else logging.INFO
logging.basicConfig(format=fmt, level=level, datefmt="%Y-%m-%d %H:%M:%S")
def cli_main():
args = _parse_args()
_init_logger(args.debug)
    transcripts = _extract_train_text(args.tedlium_path)
model = train_spm(transcripts)
with open(args.output_file, "wb") as f:
f.write(model)
logger.info("Successfully trained the sentencepiece model")
if __name__ == "__main__":
cli_main()
# Conformer RNN-T ASR Example
This directory contains sample implementations of training and evaluation pipelines for a Conformer RNN-T ASR model.
## Setup
### Install PyTorch and TorchAudio nightly or from source
Because Conformer RNN-T is currently a prototype feature, you will need to either use the TorchAudio nightly build or build TorchAudio from source. Note also that GPU support is required for training.
To install the nightly, follow the directions at <https://pytorch.org/>.
To build TorchAudio from source, refer to the [contributing guidelines](https://github.com/pytorch/audio/blob/main/CONTRIBUTING.md).
### Install additional dependencies
```bash
pip install pytorch-lightning sentencepiece
```
## Usage
### Training
[`train.py`](./train.py) trains a Conformer RNN-T model (30.2M parameters, 121MB) on LibriSpeech using PyTorch Lightning. Note that the script expects users to have the following:
- Access to GPU nodes for training.
- Full LibriSpeech dataset.
- SentencePiece model to be used to encode targets; the model can be generated using [`train_spm.py`](./train_spm.py).
- A file (passed via `--global_stats_path`) that contains training set feature statistics; this file can be generated using [`global_stats.py`](../emformer_rnnt/global_stats.py).
Sample SLURM command:
```
srun --cpus-per-task=12 --gpus-per-node=8 -N 4 --ntasks-per-node=8 python train.py --exp_dir ./experiments --librispeech_path ./librispeech/ --global_stats_path ./global_stats.json --sp_model_path ./spm_unigram_1023.model --epochs 160
```
### Evaluation
[`eval.py`](./eval.py) evaluates a trained Conformer RNN-T model on LibriSpeech test-clean.
Sample SLURM command:
```
srun python eval.py --checkpoint_path ./experiments/checkpoints/epoch=159.ckpt --librispeech_path ./librispeech/ --sp_model_path ./spm_unigram_1023.model --use_cuda
```
The table below contains WER results for various splits.
| | WER |
|:-------------------:|-------------:|
| test-clean | 0.0310 |
| test-other | 0.0805 |
| dev-clean | 0.0314 |
| dev-other | 0.0827 |
import os
import random
import torch
import torchaudio
from pytorch_lightning import LightningDataModule
def _batch_by_token_count(idx_target_lengths, max_tokens, batch_size=None):
batches = []
current_batch = []
current_token_count = 0
for idx, target_length in idx_target_lengths:
if current_token_count + target_length > max_tokens or (batch_size and len(current_batch) == batch_size):
batches.append(current_batch)
current_batch = [idx]
current_token_count = target_length
else:
current_batch.append(idx)
current_token_count += target_length
if current_batch:
batches.append(current_batch)
return batches
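# A quick illustration of the batching behavior (made-up values):
#   _batch_by_token_count([(0, 30), (1, 40), (2, 50), (3, 20)], max_tokens=80)
#   returns [[0, 1], [2, 3]]: indices 0 and 1 fit within the 80-token budget,
#   while index 2 would overflow it and therefore starts a new batch.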
def get_sample_lengths(librispeech_dataset):
fileid_to_target_length = {}
def _target_length(fileid):
if fileid not in fileid_to_target_length:
speaker_id, chapter_id, _ = fileid.split("-")
file_text = speaker_id + "-" + chapter_id + librispeech_dataset._ext_txt
file_text = os.path.join(librispeech_dataset._path, speaker_id, chapter_id, file_text)
with open(file_text) as ft:
for line in ft:
fileid_text, transcript = line.strip().split(" ", 1)
fileid_to_target_length[fileid_text] = len(transcript)
return fileid_to_target_length[fileid]
return [_target_length(fileid) for fileid in librispeech_dataset._walker]
class CustomBucketDataset(torch.utils.data.Dataset):
def __init__(
self,
dataset,
lengths,
max_tokens,
num_buckets,
shuffle=False,
batch_size=None,
):
super().__init__()
assert len(dataset) == len(lengths)
self.dataset = dataset
max_length = max(lengths)
min_length = min(lengths)
assert max_tokens >= max_length
buckets = torch.linspace(min_length, max_length, num_buckets)
lengths = torch.tensor(lengths)
bucket_assignments = torch.bucketize(lengths, buckets)
idx_length_buckets = [(idx, length, bucket_assignments[idx]) for idx, length in enumerate(lengths)]
if shuffle:
idx_length_buckets = random.sample(idx_length_buckets, len(idx_length_buckets))
else:
idx_length_buckets = sorted(idx_length_buckets, key=lambda x: x[1], reverse=True)
sorted_idx_length_buckets = sorted(idx_length_buckets, key=lambda x: x[2])
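        # Grouping indices by bucket keeps samples of similar length together,
        # which reduces padding inside each token-count-limited batch.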
self.batches = _batch_by_token_count(
[(idx, length) for idx, length, _ in sorted_idx_length_buckets],
max_tokens,
batch_size=batch_size,
)
def __getitem__(self, idx):
return [self.dataset[subidx] for subidx in self.batches[idx]]
def __len__(self):
return len(self.batches)
class TransformDataset(torch.utils.data.Dataset):
def __init__(self, dataset, transform_fn):
self.dataset = dataset
self.transform_fn = transform_fn
def __getitem__(self, idx):
return self.transform_fn(self.dataset[idx])
def __len__(self):
return len(self.dataset)
class LibriSpeechDataModule(LightningDataModule):
librispeech_cls = torchaudio.datasets.LIBRISPEECH
def __init__(
self,
*,
librispeech_path,
train_transform,
val_transform,
test_transform,
max_tokens=700,
batch_size=2,
train_num_buckets=50,
train_shuffle=True,
num_workers=10,
):
super().__init__()
self.librispeech_path = librispeech_path
self.train_dataset_lengths = None
self.val_dataset_lengths = None
self.train_transform = train_transform
self.val_transform = val_transform
self.test_transform = test_transform
self.max_tokens = max_tokens
self.batch_size = batch_size
self.train_num_buckets = train_num_buckets
self.train_shuffle = train_shuffle
self.num_workers = num_workers
def train_dataloader(self):
datasets = [
self.librispeech_cls(self.librispeech_path, url="train-clean-360"),
self.librispeech_cls(self.librispeech_path, url="train-clean-100"),
self.librispeech_cls(self.librispeech_path, url="train-other-500"),
]
if not self.train_dataset_lengths:
self.train_dataset_lengths = [get_sample_lengths(dataset) for dataset in datasets]
dataset = torch.utils.data.ConcatDataset(
[
CustomBucketDataset(
dataset,
lengths,
self.max_tokens,
self.train_num_buckets,
batch_size=self.batch_size,
)
for dataset, lengths in zip(datasets, self.train_dataset_lengths)
]
)
dataset = TransformDataset(dataset, self.train_transform)
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=self.num_workers,
batch_size=None,
shuffle=self.train_shuffle,
)
return dataloader
def val_dataloader(self):
datasets = [
self.librispeech_cls(self.librispeech_path, url="dev-clean"),
self.librispeech_cls(self.librispeech_path, url="dev-other"),
]
if not self.val_dataset_lengths:
self.val_dataset_lengths = [get_sample_lengths(dataset) for dataset in datasets]
dataset = torch.utils.data.ConcatDataset(
[
CustomBucketDataset(
dataset,
lengths,
self.max_tokens,
1,
batch_size=self.batch_size,
)
for dataset, lengths in zip(datasets, self.val_dataset_lengths)
]
)
dataset = TransformDataset(dataset, self.val_transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=None, num_workers=self.num_workers)
return dataloader
def test_dataloader(self):
dataset = self.librispeech_cls(self.librispeech_path, url="test-clean")
dataset = TransformDataset(dataset, self.test_transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=None)
return dataloader
import logging
import pathlib
from argparse import ArgumentParser
import sentencepiece as spm
import torch
import torchaudio
from lightning import ConformerRNNTModule
from transforms import get_data_module
logger = logging.getLogger()
def compute_word_level_distance(seq1, seq2):
return torchaudio.functional.edit_distance(seq1.lower().split(), seq2.lower().split())
def run_eval(args):
sp_model = spm.SentencePieceProcessor(model_file=str(args.sp_model_path))
model = ConformerRNNTModule.load_from_checkpoint(args.checkpoint_path, sp_model=sp_model).eval()
data_module = get_data_module(str(args.librispeech_path), str(args.global_stats_path), str(args.sp_model_path))
if args.use_cuda:
model = model.to(device="cuda")
total_edit_distance = 0
total_length = 0
dataloader = data_module.test_dataloader()
with torch.no_grad():
for idx, (batch, sample) in enumerate(dataloader):
actual = sample[0][2]
predicted = model(batch)
total_edit_distance += compute_word_level_distance(actual, predicted)
total_length += len(actual.split())
if idx % 100 == 0:
logger.warning(f"Processed elem {idx}; WER: {total_edit_distance / total_length}")
logger.warning(f"Final WER: {total_edit_distance / total_length}")
def cli_main():
parser = ArgumentParser()
parser.add_argument(
"--checkpoint-path",
type=pathlib.Path,
help="Path to checkpoint to use for evaluation.",
required=True,
)
parser.add_argument(
"--global-stats-path",
default=pathlib.Path("global_stats.json"),
type=pathlib.Path,
help="Path to JSON file containing feature means and stddevs.",
)
parser.add_argument(
"--librispeech-path",
type=pathlib.Path,
help="Path to LibriSpeech datasets.",
required=True,
)
parser.add_argument(
"--sp-model-path",
type=pathlib.Path,
help="Path to SentencePiece model.",
required=True,
)
parser.add_argument(
"--use-cuda",
action="store_true",
default=False,
help="Run using CUDA.",
)
args = parser.parse_args()
run_eval(args)
if __name__ == "__main__":
cli_main()
{
"mean": [
15.058613777160645,
16.34557342529297,
16.34653663635254,
16.240671157836914,
17.45355224609375,
17.445302963256836,
17.52323341369629,
18.076807022094727,
17.699262619018555,
17.706790924072266,
17.24724578857422,
17.153791427612305,
17.213361740112305,
17.347240447998047,
17.331117630004883,
17.21516227722168,
17.030071258544922,
16.818960189819336,
16.573062896728516,
16.29717254638672,
16.00996971130371,
15.794167518615723,
15.616395950317383,
15.459056854248047,
15.306838989257812,
15.199165344238281,
15.208144187927246,
14.883454322814941,
14.787869453430176,
14.947835922241211,
14.5912504196167,
14.76955509185791,
14.617781639099121,
14.840407371520996,
14.83073616027832,
14.909119606018066,
14.89070987701416,
14.918207168579102,
14.939517974853516,
14.913643836975098,
14.863334655761719,
14.803299903869629,
14.751264572143555,
14.688116073608398,
14.63498306274414,
14.615056037902832,
14.680213928222656,
14.616259574890137,
14.707776069641113,
14.630264282226562,
14.644737243652344,
14.547430038452148,
14.529033660888672,
14.49357795715332,
14.411538124084473,
14.33312702178955,
14.260393142700195,
14.204919815063477,
14.130182266235352,
14.06987476348877,
14.010197639465332,
13.938552856445312,
13.750232696533203,
13.607213973999023,
13.457777976989746,
13.31512451171875,
13.167718887329102,
13.019341468811035,
12.8869047164917,
12.795098304748535,
12.685126304626465,
12.620392799377441,
12.58949089050293,
12.537697792053223,
12.496938705444336,
12.410022735595703,
12.346826553344727,
12.221966743469238,
12.122841835021973,
12.005624771118164
],
"invstddev": [
0.25952333211898804,
0.2590482831001282,
0.24866817891597748,
0.24776232242584229,
0.22200720012187958,
0.21363843977451324,
0.20652402937412262,
0.19909949600696564,
0.2021811604499817,
0.20355898141860962,
0.20546883344650269,
0.2061648815870285,
0.20569036900997162,
0.20412985980510712,
0.20357738435268402,
0.2041499763727188,
0.2055872678756714,
0.20807604491710663,
0.21054454147815704,
0.21341396868228912,
0.21418628096580505,
0.22065168619155884,
0.2248840034008026,
0.22723940014839172,
0.230172261595726,
0.23371541500091553,
0.23734734952449799,
0.23960146307945251,
0.24088498950004578,
0.241532102227211,
0.24218633770942688,
0.24371792376041412,
0.2447739839553833,
0.25564682483673096,
0.2632736265659332,
0.2549223005771637,
0.24608071148395538,
0.2464841604232788,
0.2470586597919464,
0.24785254895687103,
0.24904784560203552,
0.2503036856651306,
0.25226327776908875,
0.2532329559326172,
0.2527913451194763,
0.2518651783466339,
0.2504975199699402,
0.24836081266403198,
0.24765831232070923,
0.24767662584781647,
0.24965286254882812,
0.2501370906829834,
0.2508895993232727,
0.2512582540512085,
0.25150999426841736,
0.2525503635406494,
0.25313329696655273,
0.2534785270690918,
0.25330957770347595,
0.25366073846817017,
0.25502219796180725,
0.2608155608177185,
0.25662899017333984,
0.2558451294898987,
0.25671014189720154,
0.2577403485774994,
0.25914356112480164,
0.2596718966960907,
0.25953933596611023,
0.2610883116722107,
0.26132410764694214,
0.26272818446159363,
0.26397505402565,
0.26440608501434326,
0.26543495059013367,
0.26753780245780945,
0.26935192942619324,
0.26732245087623596,
0.26666897535324097,
0.2663257420063019
]
}
import logging
import math
from collections import namedtuple
from typing import List, Tuple
import sentencepiece as spm
import torch
import torchaudio
from pytorch_lightning import LightningModule
from torchaudio.models import Hypothesis, RNNTBeamSearch
from torchaudio.prototype.models import conformer_rnnt_base
logger = logging.getLogger()
_expected_spm_vocab_size = 1023
Batch = namedtuple("Batch", ["features", "feature_lengths", "targets", "target_lengths"])
class WarmupLR(torch.optim.lr_scheduler._LRScheduler):
r"""Learning rate scheduler that performs linear warmup and exponential annealing.
Args:
optimizer (torch.optim.Optimizer): optimizer to use.
warmup_steps (int): number of scheduler steps for which to warm up learning rate.
force_anneal_step (int): scheduler step at which annealing of learning rate begins.
anneal_factor (float): factor to scale base learning rate by at each annealing step.
last_epoch (int, optional): The index of last epoch. (Default: -1)
verbose (bool, optional): If ``True``, prints a message to stdout for
each update. (Default: ``False``)
"""
def __init__(
self,
optimizer: torch.optim.Optimizer,
warmup_steps: int,
force_anneal_step: int,
anneal_factor: float,
last_epoch=-1,
verbose=False,
):
self.warmup_steps = warmup_steps
self.force_anneal_step = force_anneal_step
self.anneal_factor = anneal_factor
super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose)
def get_lr(self):
if self._step_count < self.force_anneal_step:
return [(min(1.0, self._step_count / self.warmup_steps)) * base_lr for base_lr in self.base_lrs]
else:
scaling_factor = self.anneal_factor ** (self._step_count - self.force_anneal_step)
return [scaling_factor * base_lr for base_lr in self.base_lrs]
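    # Example, using the values from the training recipe below: WarmupLR(opt, 40, 120, 0.96)
    # ramps the learning rate linearly from 0 to base_lr over the first 40 scheduler steps,
    # holds it at base_lr until step 120, then decays it by a factor of 0.96 per step.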
def post_process_hypos(
hypos: List[Hypothesis], sp_model: spm.SentencePieceProcessor
) -> List[Tuple[str, float, List[int], List[int]]]:
tokens_idx = 0
score_idx = 3
post_process_remove_list = [
sp_model.unk_id(),
sp_model.eos_id(),
sp_model.pad_id(),
]
filtered_hypo_tokens = [
[token_index for token_index in h[tokens_idx][1:] if token_index not in post_process_remove_list] for h in hypos
]
hypos_str = [sp_model.decode(s) for s in filtered_hypo_tokens]
hypos_ids = [h[tokens_idx][1:] for h in hypos]
hypos_score = [[math.exp(h[score_idx])] for h in hypos]
nbest_batch = list(zip(hypos_str, hypos_score, hypos_ids))
return nbest_batch
class ConformerRNNTModule(LightningModule):
def __init__(self, sp_model):
super().__init__()
self.sp_model = sp_model
spm_vocab_size = self.sp_model.get_piece_size()
assert spm_vocab_size == _expected_spm_vocab_size, (
"The model returned by conformer_rnnt_base expects a SentencePiece model of "
f"vocabulary size {_expected_spm_vocab_size}, but the given SentencePiece model has a vocabulary size "
f"of {spm_vocab_size}. Please provide a correctly configured SentencePiece model."
)
self.blank_idx = spm_vocab_size
# ``conformer_rnnt_base`` hardcodes a specific Conformer RNN-T configuration.
# For greater customizability, please refer to ``conformer_rnnt_model``.
self.model = conformer_rnnt_base()
self.loss = torchaudio.transforms.RNNTLoss(reduction="sum")
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=8e-4, betas=(0.9, 0.98), eps=1e-9)
self.warmup_lr_scheduler = WarmupLR(self.optimizer, 40, 120, 0.96)
self.automatic_optimization = False
def _step(self, batch, _, step_type):
if batch is None:
return None
prepended_targets = batch.targets.new_empty([batch.targets.size(0), batch.targets.size(1) + 1])
prepended_targets[:, 1:] = batch.targets
prepended_targets[:, 0] = self.blank_idx
prepended_target_lengths = batch.target_lengths + 1
output, src_lengths, _, _ = self.model(
batch.features,
batch.feature_lengths,
prepended_targets,
prepended_target_lengths,
)
loss = self.loss(output, batch.targets, src_lengths, batch.target_lengths)
self.log(f"Losses/{step_type}_loss", loss, on_step=True, on_epoch=True)
return loss
def configure_optimizers(self):
return (
[self.optimizer],
[{"scheduler": self.warmup_lr_scheduler, "interval": "epoch"}],
)
def forward(self, batch: Batch):
decoder = RNNTBeamSearch(self.model, self.blank_idx)
hypotheses = decoder(batch.features.to(self.device), batch.feature_lengths.to(self.device), 20)
return post_process_hypos(hypotheses, self.sp_model)[0][0]
def training_step(self, batch: Batch, batch_idx):
"""Custom training step.
By default, DDP does the following on each train step:
- For each GPU, compute loss and gradient on shard of training data.
- Sync and average gradients across all GPUs. The final gradient
is (sum of gradients across all GPUs) / N, where N is the world
size (total number of GPUs).
- Update parameters on each GPU.
Here, we do the following:
- For k-th GPU, compute loss and scale it by (N / B_total), where B_total is
the sum of batch sizes across all GPUs. Compute gradient from scaled loss.
- Sync and average gradients across all GPUs. The final gradient
is (sum of gradients across all GPUs) / B_total.
- Update parameters on each GPU.
Doing so allows us to account for the variability in batch sizes that
variable-length sequential data commonly yields.
"""
opt = self.optimizers()
opt.zero_grad()
loss = self._step(batch, batch_idx, "train")
batch_size = batch.features.size(0)
batch_sizes = self.all_gather(batch_size)
self.log("Gathered batch size", batch_sizes.sum(), on_step=True, on_epoch=True)
        loss *= batch_sizes.size(0) / batch_sizes.sum()  # world size / total batch size across GPUs
self.manual_backward(loss)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 10.0)
opt.step()
# step every epoch
sch = self.lr_schedulers()
if self.trainer.is_last_batch:
sch.step()
return loss
def validation_step(self, batch, batch_idx):
return self._step(batch, batch_idx, "val")
def test_step(self, batch, batch_idx):
return self._step(batch, batch_idx, "test")
import pathlib
from argparse import ArgumentParser
import sentencepiece as spm
from lightning import ConformerRNNTModule
from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.plugins import DDPPlugin
from transforms import get_data_module
def run_train(args):
seed_everything(1)
checkpoint_dir = args.exp_dir / "checkpoints"
checkpoint = ModelCheckpoint(
checkpoint_dir,
monitor="Losses/val_loss",
mode="min",
save_top_k=5,
save_weights_only=False,
verbose=True,
)
train_checkpoint = ModelCheckpoint(
checkpoint_dir,
monitor="Losses/train_loss",
mode="min",
save_top_k=5,
save_weights_only=False,
verbose=True,
)
lr_monitor = LearningRateMonitor(logging_interval="step")
callbacks = [
checkpoint,
train_checkpoint,
lr_monitor,
]
trainer = Trainer(
default_root_dir=args.exp_dir,
max_epochs=args.epochs,
num_nodes=args.nodes,
gpus=args.gpus,
accelerator="gpu",
strategy=DDPPlugin(find_unused_parameters=False),
callbacks=callbacks,
reload_dataloaders_every_n_epochs=1,
)
sp_model = spm.SentencePieceProcessor(model_file=str(args.sp_model_path))
model = ConformerRNNTModule(sp_model)
data_module = get_data_module(str(args.librispeech_path), str(args.global_stats_path), str(args.sp_model_path))
trainer.fit(model, data_module, ckpt_path=args.checkpoint_path)
def cli_main():
parser = ArgumentParser()
parser.add_argument(
"--checkpoint-path",
default=None,
type=pathlib.Path,
help="Path to checkpoint to use for evaluation.",
)
parser.add_argument(
"--exp-dir",
default=pathlib.Path("./exp"),
type=pathlib.Path,
help="Directory to save checkpoints and logs to. (Default: './exp')",
)
parser.add_argument(
"--global-stats-path",
default=pathlib.Path("global_stats.json"),
type=pathlib.Path,
help="Path to JSON file containing feature means and stddevs.",
)
parser.add_argument(
"--librispeech-path",
type=pathlib.Path,
help="Path to LibriSpeech datasets.",
required=True,
)
parser.add_argument(
"--sp-model-path",
type=pathlib.Path,
help="Path to SentencePiece model.",
required=True,
)
parser.add_argument(
"--nodes",
default=4,
type=int,
help="Number of nodes to use for training. (Default: 4)",
)
parser.add_argument(
"--gpus",
default=8,
type=int,
help="Number of GPUs per node to use for training. (Default: 8)",
)
parser.add_argument(
"--epochs",
default=120,
type=int,
help="Number of epochs to train for. (Default: 120)",
)
args = parser.parse_args()
run_train(args)
if __name__ == "__main__":
cli_main()
#!/usr/bin/env python3
"""Trains a SentencePiece model on transcripts across LibriSpeech train-clean-100, train-clean-360, and train-other-500.
Example:
python train_spm.py --librispeech-path ./datasets
"""
import io
import pathlib
from argparse import ArgumentParser, RawTextHelpFormatter
import sentencepiece as spm
def get_transcript_text(transcript_path):
with open(transcript_path) as f:
return [line.strip().split(" ", 1)[1].lower() for line in f]
def get_transcripts(dataset_path):
transcript_paths = dataset_path.glob("*/*/*.trans.txt")
merged_transcripts = []
for path in transcript_paths:
merged_transcripts += get_transcript_text(path)
return merged_transcripts
def train_spm(input):
model_writer = io.BytesIO()
spm.SentencePieceTrainer.train(
sentence_iterator=iter(input),
model_writer=model_writer,
vocab_size=1023,
model_type="unigram",
input_sentence_size=-1,
character_coverage=1.0,
bos_id=0,
pad_id=1,
eos_id=2,
unk_id=3,
)
return model_writer.getvalue()
def parse_args():
default_output_path = "./spm_unigram_1023.model"
parser = ArgumentParser(description=__doc__, formatter_class=RawTextHelpFormatter)
parser.add_argument(
"--librispeech-path",
required=True,
type=pathlib.Path,
help="Path to LibriSpeech dataset.",
)
parser.add_argument(
"--output-file",
default=pathlib.Path(default_output_path),
type=pathlib.Path,
help=f"File to save model to. (Default: '{default_output_path}')",
)
return parser.parse_args()
def run_cli():
args = parse_args()
root = args.librispeech_path / "LibriSpeech"
splits = ["train-clean-100", "train-clean-360", "train-other-500"]
merged_transcripts = []
for split in splits:
path = pathlib.Path(root) / split
merged_transcripts += get_transcripts(path)
model = train_spm(merged_transcripts)
with open(args.output_file, "wb") as f:
f.write(model)
if __name__ == "__main__":
run_cli()
import json
import math
from functools import partial
from typing import List
import sentencepiece as spm
import torch
import torchaudio
from data_module import LibriSpeechDataModule
from lightning import Batch
# Constant gain applied to the mel spectrogram before piecewise-linear log compression.
_decibel = 2 * 20 * math.log10(torch.iinfo(torch.int16).max)
_gain = pow(10, 0.05 * _decibel)
_spectrogram_transform = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_fft=400, n_mels=80, hop_length=160)
def _piecewise_linear_log(x):
    # Log compression that is linear for values <= e and logarithmic above,
    # continuous (with matching slope 1/e) at the joint.
    x = x * _gain
    x[x > math.e] = torch.log(x[x > math.e])
    x[x <= math.e] = x[x <= math.e] / math.e
    return x
class FunctionalModule(torch.nn.Module):
def __init__(self, functional):
super().__init__()
self.functional = functional
def forward(self, input):
return self.functional(input)
class GlobalStatsNormalization(torch.nn.Module):
def __init__(self, global_stats_path):
super().__init__()
with open(global_stats_path) as f:
blob = json.loads(f.read())
self.mean = torch.tensor(blob["mean"])
self.invstddev = torch.tensor(blob["invstddev"])
def forward(self, input):
return (input - self.mean) * self.invstddev
def _extract_labels(sp_model, samples: List):
targets = [sp_model.encode(sample[2].lower()) for sample in samples]
lengths = torch.tensor([len(elem) for elem in targets]).to(dtype=torch.int32)
targets = torch.nn.utils.rnn.pad_sequence(
[torch.tensor(elem) for elem in targets],
batch_first=True,
padding_value=1.0,
).to(dtype=torch.int32)
return targets, lengths
def _extract_features(data_pipeline, samples: List):
mel_features = [_spectrogram_transform(sample[0].squeeze()).transpose(1, 0) for sample in samples]
features = torch.nn.utils.rnn.pad_sequence(mel_features, batch_first=True)
features = data_pipeline(features)
lengths = torch.tensor([elem.shape[0] for elem in mel_features], dtype=torch.int32)
return features, lengths
class TrainTransform:
def __init__(self, global_stats_path: str, sp_model_path: str):
self.sp_model = spm.SentencePieceProcessor(model_file=sp_model_path)
self.train_data_pipeline = torch.nn.Sequential(
FunctionalModule(_piecewise_linear_log),
GlobalStatsNormalization(global_stats_path),
FunctionalModule(partial(torch.transpose, dim0=1, dim1=2)),
torchaudio.transforms.FrequencyMasking(27),
torchaudio.transforms.FrequencyMasking(27),
torchaudio.transforms.TimeMasking(100, p=0.2),
torchaudio.transforms.TimeMasking(100, p=0.2),
FunctionalModule(partial(torch.transpose, dim0=1, dim1=2)),
)
def __call__(self, samples: List):
features, feature_lengths = _extract_features(self.train_data_pipeline, samples)
targets, target_lengths = _extract_labels(self.sp_model, samples)
return Batch(features, feature_lengths, targets, target_lengths)
class ValTransform:
def __init__(self, global_stats_path: str, sp_model_path: str):
self.sp_model = spm.SentencePieceProcessor(model_file=sp_model_path)
self.valid_data_pipeline = torch.nn.Sequential(
FunctionalModule(_piecewise_linear_log),
GlobalStatsNormalization(global_stats_path),
)
def __call__(self, samples: List):
features, feature_lengths = _extract_features(self.valid_data_pipeline, samples)
targets, target_lengths = _extract_labels(self.sp_model, samples)
return Batch(features, feature_lengths, targets, target_lengths)
class TestTransform:
def __init__(self, global_stats_path: str, sp_model_path: str):
self.val_transforms = ValTransform(global_stats_path, sp_model_path)
def __call__(self, sample):
return self.val_transforms([sample]), [sample]
def get_data_module(librispeech_path, global_stats_path, sp_model_path):
train_transform = TrainTransform(global_stats_path=global_stats_path, sp_model_path=sp_model_path)
val_transform = ValTransform(global_stats_path=global_stats_path, sp_model_path=sp_model_path)
test_transform = TestTransform(global_stats_path=global_stats_path, sp_model_path=sp_model_path)
return LibriSpeechDataModule(
librispeech_path=librispeech_path,
train_transform=train_transform,
val_transform=val_transform,
test_transform=test_transform,
)
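# Typical use, as in eval.py and train.py earlier in this diff (paths are placeholders):
#   data_module = get_data_module("./librispeech", "./global_stats.json", "./spm_unigram_1023.model")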
import pytest
import torchaudio
from torchaudio.pipelines import EMFORMER_RNNT_BASE_LIBRISPEECH
from torchaudio.prototype.pipelines import EMFORMER_RNNT_BASE_MUSTC, EMFORMER_RNNT_BASE_TEDLIUM3
@pytest.mark.parametrize(
    "bundle,lang,expected",
    [
        (EMFORMER_RNNT_BASE_LIBRISPEECH, "en", "i have that curiosity beside me at this moment"),
        (EMFORMER_RNNT_BASE_MUSTC, "en", "I had that curiosity beside me at this moment."),
        (EMFORMER_RNNT_BASE_TEDLIUM3, "en", "i had that curiosity beside me at this moment"),
    ],
)
def test_rnnt(bundle, sample_speech, expected):
...