Commit fdea0a7c authored by Zhaoheng Ni's avatar Zhaoheng Ni Committed by Facebook GitHub Bot
Browse files

Refactor pipeline_demo script in emformer_rnnt recipes (#2239)

Summary:
- Use dictionary to select the `RNNTBundle` and the corresponding dataset.
- Use the dictionary's keys as choices in ArgumentParser

Pull Request resolved: https://github.com/pytorch/audio/pull/2239

Reviewed By: mthrok

Differential Revision: D34267070

Pulled By: nateanl

fbshipit-source-id: 99c7942d5c7c1518694e1ae02a55a7decd87c220
parent e3b40d1c
......@@ -6,39 +6,43 @@ python pipeline_demo.py --model-type librispeech --dataset-path ./datasets/libri
"""
import logging
import pathlib
from argparse import ArgumentParser
from argparse import ArgumentParser, RawTextHelpFormatter
from dataclasses import dataclass
from functools import partial
from typing import Callable
import torch
import torchaudio
from common import MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_TEDLIUM3
from torchaudio.pipelines import EMFORMER_RNNT_BASE_LIBRISPEECH
from torchaudio.pipelines import RNNTBundle
from torchaudio.prototype.pipelines import EMFORMER_RNNT_BASE_TEDLIUM3
logger = logging.getLogger()
logger = logging.getLogger(__name__)
def get_dataset(model_type, dataset_path):
if model_type == MODEL_TYPE_LIBRISPEECH:
return torchaudio.datasets.LIBRISPEECH(dataset_path, url="test-clean")
elif model_type == MODEL_TYPE_TEDLIUM3:
return torchaudio.datasets.TEDLIUM(dataset_path, release="release3", subset="test")
else:
raise ValueError(f"Encountered unsupported model type {model_type}.")
@dataclass
class Config:
dataset: Callable
bundle: RNNTBundle
def get_pipeline_bundle(model_type):
if model_type == MODEL_TYPE_LIBRISPEECH:
return EMFORMER_RNNT_BASE_LIBRISPEECH
elif model_type == MODEL_TYPE_TEDLIUM3:
return EMFORMER_RNNT_BASE_TEDLIUM3
else:
raise ValueError(f"Encountered unsupported model type {model_type}.")
_CONFIGS = {
MODEL_TYPE_LIBRISPEECH: Config(
partial(torchaudio.datasets.LIBRISPEECH, url="test-clean"),
EMFORMER_RNNT_BASE_LIBRISPEECH,
),
MODEL_TYPE_TEDLIUM3: Config(
partial(torchaudio.datasets.TEDLIUM, release="release3", subset="test"),
EMFORMER_RNNT_BASE_TEDLIUM3,
),
}
def run_eval_streaming(args):
dataset = get_dataset(args.model_type, args.dataset_path)
bundle = get_pipeline_bundle(args.model_type)
dataset = _CONFIGS[args.model_type].dataset(args.dataset_path)
bundle = _CONFIGS[args.model_type].bundle
decoder = bundle.get_decoder()
token_processor = bundle.get_token_processor()
feature_extractor = bundle.get_feature_extractor()
......@@ -72,8 +76,8 @@ def run_eval_streaming(args):
def parse_args():
parser = ArgumentParser()
parser.add_argument("--model-type", type=str, choices=[MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_TEDLIUM3], required=True)
parser = ArgumentParser(description=__doc__, formatter_class=RawTextHelpFormatter)
parser.add_argument("--model-type", type=str, choices=_CONFIGS.keys(), required=True)
parser.add_argument(
"--dataset-path",
type=pathlib.Path,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment