Commit fdea0a7c authored by Zhaoheng Ni, committed by Facebook GitHub Bot
Browse files

Refactor pipeline_demo script in emformer_rnnt recipes (#2239)

Summary:
- Use a dictionary to select the `RNNTBundle` and the corresponding dataset.
- Use the dictionary's keys as the `choices` for the `--model-type` argument in `ArgumentParser`.

Pull Request resolved: https://github.com/pytorch/audio/pull/2239

Reviewed By: mthrok

Differential Revision: D34267070

Pulled By: nateanl

fbshipit-source-id: 99c7942d5c7c1518694e1ae02a55a7decd87c220
parent e3b40d1c
...@@ -6,39 +6,43 @@ python pipeline_demo.py --model-type librispeech --dataset-path ./datasets/libri ...@@ -6,39 +6,43 @@ python pipeline_demo.py --model-type librispeech --dataset-path ./datasets/libri
""" """
import logging import logging
import pathlib import pathlib
from argparse import ArgumentParser from argparse import ArgumentParser, RawTextHelpFormatter
from dataclasses import dataclass
from functools import partial
from typing import Callable
import torch import torch
import torchaudio import torchaudio
from common import MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_TEDLIUM3 from common import MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_TEDLIUM3
from torchaudio.pipelines import EMFORMER_RNNT_BASE_LIBRISPEECH from torchaudio.pipelines import EMFORMER_RNNT_BASE_LIBRISPEECH
from torchaudio.pipelines import RNNTBundle
from torchaudio.prototype.pipelines import EMFORMER_RNNT_BASE_TEDLIUM3 from torchaudio.prototype.pipelines import EMFORMER_RNNT_BASE_TEDLIUM3
logger = logging.getLogger() logger = logging.getLogger(__name__)
def get_dataset(model_type, dataset_path): @dataclass
if model_type == MODEL_TYPE_LIBRISPEECH: class Config:
return torchaudio.datasets.LIBRISPEECH(dataset_path, url="test-clean") dataset: Callable
elif model_type == MODEL_TYPE_TEDLIUM3: bundle: RNNTBundle
return torchaudio.datasets.TEDLIUM(dataset_path, release="release3", subset="test")
else:
raise ValueError(f"Encountered unsupported model type {model_type}.")
def get_pipeline_bundle(model_type): _CONFIGS = {
if model_type == MODEL_TYPE_LIBRISPEECH: MODEL_TYPE_LIBRISPEECH: Config(
return EMFORMER_RNNT_BASE_LIBRISPEECH partial(torchaudio.datasets.LIBRISPEECH, url="test-clean"),
elif model_type == MODEL_TYPE_TEDLIUM3: EMFORMER_RNNT_BASE_LIBRISPEECH,
return EMFORMER_RNNT_BASE_TEDLIUM3 ),
else: MODEL_TYPE_TEDLIUM3: Config(
raise ValueError(f"Encountered unsupported model type {model_type}.") partial(torchaudio.datasets.TEDLIUM, release="release3", subset="test"),
EMFORMER_RNNT_BASE_TEDLIUM3,
),
}
def run_eval_streaming(args): def run_eval_streaming(args):
dataset = get_dataset(args.model_type, args.dataset_path) dataset = _CONFIGS[args.model_type].dataset(args.dataset_path)
bundle = get_pipeline_bundle(args.model_type) bundle = _CONFIGS[args.model_type].bundle
decoder = bundle.get_decoder() decoder = bundle.get_decoder()
token_processor = bundle.get_token_processor() token_processor = bundle.get_token_processor()
feature_extractor = bundle.get_feature_extractor() feature_extractor = bundle.get_feature_extractor()
...@@ -72,8 +76,8 @@ def run_eval_streaming(args): ...@@ -72,8 +76,8 @@ def run_eval_streaming(args):
def parse_args(): def parse_args():
parser = ArgumentParser() parser = ArgumentParser(description=__doc__, formatter_class=RawTextHelpFormatter)
parser.add_argument("--model-type", type=str, choices=[MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_TEDLIUM3], required=True) parser.add_argument("--model-type", type=str, choices=_CONFIGS.keys(), required=True)
parser.add_argument( parser.add_argument(
"--dataset-path", "--dataset-path",
type=pathlib.Path, type=pathlib.Path,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment