Commit 16d02a9e authored by nateanl's avatar nateanl Committed by Facebook GitHub Bot
Browse files

Refactor pipeline_demo.py to support variant EMFORMER_RNNT bundles (#2203)

Summary:
We refactored the demo script so that it can apply RNN-T decoding using both `torchaudio.pipelines.EMFORMER_RNNT_BASE_LIBRISPEECH` and `torchaudio.prototype.pipelines.EMFORMER_RNNT_BASE_TEDLIUM3`, in both streaming and non-streaming modes. (The first hypothesis prediction is streaming and the second one is non-streaming.)

We convert each token id sequence to word pieces and then manually join the word pieces. This allows us to preserve leading whitespaces on output strings and therefore account for word breaks and continuations across token processor invocations, which is particularly useful when performing streaming ASR.

https://user-images.githubusercontent.com/8653221/153627956-f0806f18-3c1c-44df-ac07-ec2def58a0cf.mov

Pull Request resolved: https://github.com/pytorch/audio/pull/2203

Reviewed By: carolineechen

Differential Revision: D34006388

Pulled By: nateanl

fbshipit-source-id: 3d31173ee10cdab8a2f5802570e22b50fcce5632
parent bbdbd582
...@@ -12,6 +12,11 @@ This directory contains sample implementations of training and evaluation pipeli ...@@ -12,6 +12,11 @@ This directory contains sample implementations of training and evaluation pipeli
[`eval.py`](./eval.py) evaluates a trained Emformer RNN-T model on a given dataset. [`eval.py`](./eval.py) evaluates a trained Emformer RNN-T model on a given dataset.
### Pipeline Demo
[`pipeline_demo.py`](./pipeline_demo.py) demonstrates how to use the `EMFORMER_RNNT_BASE_LIBRISPEECH`
or `EMFORMER_RNNT_BASE_TEDLIUM3` bundle that wraps a pre-trained Emformer RNN-T produced by the corresponding recipe below to perform streaming and full-context ASR on several audio samples.
## Model Types ## Model Types
Currently, we have training recipes for the LibriSpeech and TED-LIUM Release 3 datasets. Currently, we have training recipes for the LibriSpeech and TED-LIUM Release 3 datasets.
...@@ -39,8 +44,6 @@ The table below contains WER results for various splits. ...@@ -39,8 +44,6 @@ The table below contains WER results for various splits.
| dev-clean | 0.0415 | | dev-clean | 0.0415 |
| dev-other | 0.1110 | | dev-other | 0.1110 |
[`librispeech/pipeline_demo.py`](./librispeech/pipeline_demo.py) demonstrates how to use the `EMFORMER_RNNT_BASE_LIBRISPEECH` bundle that wraps a pre-trained Emformer RNN-T produced by the above recipe to perform streaming and full-context ASR on several LibriSpeech samples.
### TED-LIUM Release 3 ### TED-LIUM Release 3
Whereas the LibriSpeech model is configured with a vocabulary size of 4096, the TED-LIUM Release 3 model is configured with a vocabulary size of 500. Consequently, the TED-LIUM Release 3 model's last linear layer in the joiner has an output dimension of 501 (500 + 1 to account for the blank symbol); the rest of the model is identical to the LibriSpeech model. Whereas the LibriSpeech model is configured with a vocabulary size of 4096, the TED-LIUM Release 3 model is configured with a vocabulary size of 500. Consequently, the TED-LIUM Release 3 model's last linear layer in the joiner has an output dimension of 501 (500 + 1 to account for the blank symbol); the rest of the model is identical to the LibriSpeech model.
......
import logging
import pathlib import pathlib
from argparse import ArgumentParser from argparse import ArgumentParser
import torch import torch
import torchaudio import torchaudio
from common import MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_TEDLIUM3
from torchaudio.pipelines import EMFORMER_RNNT_BASE_LIBRISPEECH from torchaudio.pipelines import EMFORMER_RNNT_BASE_LIBRISPEECH
from torchaudio.prototype.pipelines import EMFORMER_RNNT_BASE_TEDLIUM3
def cli_main(): logger = logging.getLogger()
parser = ArgumentParser()
parser.add_argument(
"--librispeech_path", def get_dataset(model_type, dataset_path):
type=pathlib.Path, if model_type == MODEL_TYPE_LIBRISPEECH:
required=True, return torchaudio.datasets.LIBRISPEECH(dataset_path, url="test-clean")
help="Path to LibriSpeech datasets.", elif model_type == MODEL_TYPE_TEDLIUM3:
) return torchaudio.datasets.TEDLIUM(dataset_path, release="release3", subset="test")
args = parser.parse_args() else:
raise ValueError(f"Encountered unsupported model type {model_type}.")
dataset = torchaudio.datasets.LIBRISPEECH(args.librispeech_path, url="test-clean")
decoder = EMFORMER_RNNT_BASE_LIBRISPEECH.get_decoder()
def get_pipeline_bundle(model_type):
    """Return the pre-trained Emformer RNN-T bundle for ``model_type``.

    Args:
        model_type: one of ``MODEL_TYPE_LIBRISPEECH`` or ``MODEL_TYPE_TEDLIUM3``.

    Returns:
        The matching ``EMFORMER_RNNT_BASE_*`` pipeline bundle.

    Raises:
        ValueError: if ``model_type`` is not one of the supported types.
    """
    # Both bundles are module-level constants, so a lookup table is safe here
    # (nothing expensive is constructed for the unselected branch).
    bundles = {
        MODEL_TYPE_LIBRISPEECH: EMFORMER_RNNT_BASE_LIBRISPEECH,
        MODEL_TYPE_TEDLIUM3: EMFORMER_RNNT_BASE_TEDLIUM3,
    }
    if model_type not in bundles:
        raise ValueError(f"Encountered unsupported model type {model_type}.")
    return bundles[model_type]
num_samples_segment + EMFORMER_RNNT_BASE_LIBRISPEECH.right_context_length * hop_length
)
def run_eval_streaming(args):
dataset = get_dataset(args.model_type, args.dataset_path)
bundle = get_pipeline_bundle(args.model_type)
decoder = bundle.get_decoder()
token_processor = bundle.get_token_processor()
feature_extractor = bundle.get_feature_extractor()
streaming_feature_extractor = bundle.get_streaming_feature_extractor()
hop_length = bundle.hop_length
num_samples_segment = bundle.segment_length * hop_length
num_samples_segment_right_context = num_samples_segment + bundle.right_context_length * hop_length
for idx in range(10): for idx in range(10):
sample = dataset[idx] sample = dataset[idx]
waveform = sample[0].squeeze() waveform = sample[0].squeeze()
# Streaming decode. # Streaming decode.
state, hypothesis = None, None state, hypothesis = None, None
for idx in range(0, len(waveform), num_samples_segment): for idx in range(0, len(waveform), num_samples_segment):
...@@ -53,5 +65,30 @@ def cli_main(): ...@@ -53,5 +65,30 @@ def cli_main():
print() print()
def parse_args():
    """Parse command-line arguments for the streaming pipeline demo.

    Returns:
        argparse.Namespace with ``model_type``, ``dataset_path``, and ``debug``.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--model_type",
        type=str,
        choices=[MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_TEDLIUM3],
        required=True,
    )
    parser.add_argument("--dataset_path", type=pathlib.Path, help="Path to dataset.", required=True)
    parser.add_argument("--debug", action="store_true", help="whether to use debug level for logging")
    return parser.parse_args()
def init_logger(debug):
    """Configure root logging for the demo.

    Args:
        debug: when True, log at DEBUG level and prefix messages with a
            timestamp; otherwise log plain messages at INFO level.
    """
    if debug:
        message_format = "%(asctime)s %(message)s"
        log_level = logging.DEBUG
    else:
        message_format = "%(message)s"
        log_level = logging.INFO
    logging.basicConfig(format=message_format, level=log_level, datefmt="%Y-%m-%d %H:%M:%S")
def cli_main():
    """Entry point: parse CLI arguments, set up logging, and run the streaming evaluation."""
    cli_args = parse_args()
    init_logger(cli_args.debug)
    run_eval_streaming(cli_args)
if __name__ == "__main__": if __name__ == "__main__":
cli_main() cli_main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment