Commit a1dc9e0a authored by Zhaoheng Ni's avatar Zhaoheng Ni Committed by Facebook GitHub Bot
Browse files

Add RNNTBundle with weights pre-trained on tedlium3 dataset (#2177)

Summary: Pull Request resolved: https://github.com/pytorch/audio/pull/2177

Reviewed By: hwangjeff

Differential Revision: D33893052

Pulled By: nateanl

fbshipit-source-id: 00ff011eb96662b162c0327196a9564721e9c8f7
parent b986e9ef
...@@ -59,6 +59,7 @@ Prototype API References ...@@ -59,6 +59,7 @@ Prototype API References
prototype prototype
prototype.io prototype.io
prototype.ctc_decoder prototype.ctc_decoder
prototype.pipelines
Getting Started Getting Started
--------------- ---------------
......
torchaudio.prototype.pipelines
==============================
.. py:module:: torchaudio.prototype.pipelines
.. currentmodule:: torchaudio.prototype.pipelines
The pipelines subpackage contains APIs to access models with pretrained weights and relevant utilities.
RNN-T Streaming/Non-Streaming ASR
---------------------------------
EMFORMER_RNNT_BASE_TEDLIUM3
~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. container:: py attribute
.. autodata:: EMFORMER_RNNT_BASE_TEDLIUM3
:no-value:
...@@ -19,3 +19,4 @@ imported explicitly, e.g. ...@@ -19,3 +19,4 @@ imported explicitly, e.g.
.. toctree:: .. toctree::
prototype.io prototype.io
prototype.ctc_decoder prototype.ctc_decoder
prototype.pipelines
...@@ -29,3 +29,14 @@ Sample SLURM command: ...@@ -29,3 +29,14 @@ Sample SLURM command:
``` ```
srun python eval.py --checkpoint-path ./experiments/checkpoints/epoch=119-step=254999.ckpt --tedlium-path ./datasets/ --sp-model-path ./spm-bpe-500.model --use-cuda srun python eval.py --checkpoint-path ./experiments/checkpoints/epoch=119-step=254999.ckpt --tedlium-path ./datasets/ --sp-model-path ./spm-bpe-500.model --use-cuda
``` ```
### Evaluation using `torchaudio.pipelines.EMFORMER_RNNT_BASE_TEDLIUM3` bundle
[`eval_pipeline.py`](./eval_pipeline.py) evaluates the `EMFORMER_RNNT_BASE_TEDLIUM3` bundle on the dev and test sets of TED-LIUM release 3.
You should be able to get WER results identical to those in the above table.
Sample SLURM command:
```
srun python eval_pipeline.py --tedlium-path ./datasets/ --use-cuda
```
import logging
import pathlib
from argparse import ArgumentParser, RawTextHelpFormatter
import torch
import torchaudio
from torchaudio.prototype.pipelines import EMFORMER_RNNT_BASE_TEDLIUM3
logger = logging.getLogger(__name__)
def compute_word_level_distance(seq1, seq2):
    """Compute the word-level edit distance between two transcripts.

    Both inputs are lower-cased and split on whitespace, so the distance
    counts word insertions, deletions, and substitutions case-insensitively.
    """
    reference_words = seq1.lower().split()
    hypothesis_words = seq2.lower().split()
    return torchaudio.functional.edit_distance(reference_words, hypothesis_words)
def _eval_subset(tedlium_path, subset, feature_extractor, decoder, token_processor, use_cuda):
    """Evaluate the RNN-T pipeline on one TED-LIUM release 3 split and print the WER.

    Args:
        tedlium_path: Root directory of the TED-LIUM release 3 dataset.
        subset (str): Split to evaluate; must be ``"dev"`` or ``"test"``.
        feature_extractor: Callable mapping a waveform to ``(features, length)``.
        decoder: RNN-T decoder; called as ``decoder(features, length, beam_width)``.
        token_processor: Maps a hypothesis' token sequence to a transcript string.
        use_cuda (bool): If True, move each waveform to the default CUDA device.

    Raises:
        ValueError: If ``subset`` is neither ``"dev"`` nor ``"test"``.
    """
    # Validate up front: previously an unknown subset fell through both
    # branches and crashed later with a confusing NameError on ``dataset``.
    if subset not in ("dev", "test"):
        raise ValueError(f"Expected subset to be 'dev' or 'test', but got {subset}")
    dataset = torchaudio.datasets.TEDLIUM(tedlium_path, release="release3", subset=subset)
    total_edit_distance = 0
    total_length = 0
    with torch.no_grad():
        for idx in range(len(dataset)):
            sample = dataset[idx]
            waveform = sample[0].squeeze()
            if use_cuda:
                waveform = waveform.to(device="cuda")
            actual = sample[2].replace("\n", "")
            # TED-LIUM marks some segments as excluded from scoring; skip them.
            if actual == "ignore_time_segment_in_scoring":
                continue
            features, length = feature_extractor(waveform)
            hypos = decoder(features, length, 20)  # beam width of 20
            # Score only the top hypothesis.
            hypothesis = token_processor(hypos[0].tokens)
            total_edit_distance += compute_word_level_distance(actual, hypothesis)
            total_length += len(actual.split())
            if idx % 100 == 0:
                print(f"Processed elem {idx}; WER: {total_edit_distance / total_length}")
    print(f"Final WER for {subset} set: {total_edit_distance / total_length}")
def run_eval_pipeline(args):
    """Evaluate the TED-LIUM 3 Emformer RNN-T bundle on the dev and test splits."""
    bundle = EMFORMER_RNNT_BASE_TEDLIUM3
    decoder = bundle.get_decoder()
    token_processor = bundle.get_token_processor()
    feature_extractor = bundle.get_feature_extractor()
    if args.use_cuda:
        # NOTE(review): ``.eval()`` is applied only on the CUDA path in the
        # original code — confirm whether the CPU path should also call it.
        feature_extractor = feature_extractor.to(device="cuda").eval()
        decoder = decoder.to(device="cuda")
    for subset in ("dev", "test"):
        _eval_subset(args.tedlium_path, subset, feature_extractor, decoder, token_processor, args.use_cuda)
def _parse_args():
parser = ArgumentParser(
description=__doc__,
formatter_class=RawTextHelpFormatter,
)
parser.add_argument(
"--tedlium-path",
type=pathlib.Path,
help="Path to TED-LIUM release 3 dataset.",
)
parser.add_argument(
"--use-cuda",
action="store_true",
default=False,
help="Run using CUDA.",
)
parser.add_argument("--debug", action="store_true", help="whether to use debug level for logging")
return parser.parse_args()
def _init_logger(debug):
fmt = "%(asctime)s %(message)s" if debug else "%(message)s"
level = logging.DEBUG if debug else logging.INFO
logging.basicConfig(format=fmt, level=level, datefmt="%Y-%m-%d %H:%M:%S")
def cli_main():
    """Script entry point: parse CLI arguments, set up logging, run evaluation."""
    parsed = _parse_args()
    _init_logger(parsed.debug)
    run_eval_pipeline(parsed)
# Allow direct execution: ``python eval_pipeline.py ...``.
if __name__ == "__main__":
    cli_main()
# Re-export the TED-LIUM 3 Emformer RNN-T bundle as the package's public API.
from .rnnt_pipeline import EMFORMER_RNNT_BASE_TEDLIUM3

__all__ = [
    "EMFORMER_RNNT_BASE_TEDLIUM3",
]
from functools import partial
from torchaudio.models import emformer_rnnt_base
from torchaudio.pipelines import RNNTBundle
# Pre-configured RNN-T bundle for TED-LIUM release 3.
# NOTE(review): the asset paths below are presumably resolved/downloaded by
# RNNTBundle's accessor methods — confirm against the RNNTBundle implementation.
EMFORMER_RNNT_BASE_TEDLIUM3 = RNNTBundle(
    _rnnt_path="emformer_rnnt_base_tedlium3.pt",
    # 501 output symbols — consistent with the 500-piece BPE model below plus
    # one blank symbol (``_blank=500``).
    _rnnt_factory_func=partial(emformer_rnnt_base, num_symbols=501),
    _global_stats_path="global_stats_rnnt_tedlium3.json",
    _sp_model_path="spm_bpe_500_tedlium3.model",
    _right_padding=4,
    _blank=500,  # blank token index: the last of the 501 symbols
    _sample_rate=16000,
    _n_fft=400,
    _n_mels=80,
    _hop_length=160,
    _segment_length=16,
    _right_context_length=4,
)

# Override the instance docstring (dataclass instances otherwise surface the
# class docstring) so ``help()``/Sphinx show pipeline-specific documentation.
EMFORMER_RNNT_BASE_TEDLIUM3.__doc__ = """Pre-trained Emformer-RNNT-based ASR pipeline capable of performing both streaming and non-streaming inference.
The underlying model is constructed by :py:func:`torchaudio.models.emformer_rnnt_base`
and utilizes weights trained on TED-LIUM Release 3 dataset using training script ``train.py``
`here <https://github.com/pytorch/audio/tree/main/examples/asr/tedlium3_emformer_rnnt>`__ with ``num_symbols=501``.
Please refer to :py:class:`torchaudio.pipelines.RNNTBundle` for usage instructions.
"""
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment