Commit 38569ef0 authored by Zhaoheng Ni, committed by Facebook GitHub Bot

Add EMFORMER_RNNT_BASE_MUSTC into pipeline demo script (#2248)

Summary:
This PR adds `EMFORMER_RNNT_BASE_MUSTC` support to `pipeline_demo.py`. The bundle is trained on the MuST-C release 2.0 dataset, and the model preserves casing and punctuation in the transcript.

Here is a screen recording of how it works in streaming and non-streaming modes:

https://user-images.githubusercontent.com/8653221/154356521-fe84bdc1-fb0c-41bd-8729-9edbb3224a07.mov

Pull Request resolved: https://github.com/pytorch/audio/pull/2248

Reviewed By: hwangjeff

Differential Revision: D34282598

Pulled By: nateanl

fbshipit-source-id: 42ed7e2623031dfebd176ef0c6bfd70da3c897d4
parent 87d79889
@@ -13,11 +13,11 @@ from typing import Callable
 import torch
 import torchaudio
-from common import MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_TEDLIUM3
+from common import MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_MUSTC, MODEL_TYPE_TEDLIUM3
 from mustc.dataset import MUSTC
 from torchaudio.pipelines import EMFORMER_RNNT_BASE_LIBRISPEECH
 from torchaudio.pipelines import RNNTBundle
-from torchaudio.prototype.pipelines import EMFORMER_RNNT_BASE_TEDLIUM3
+from torchaudio.prototype.pipelines import EMFORMER_RNNT_BASE_MUSTC, EMFORMER_RNNT_BASE_TEDLIUM3
 logger = logging.getLogger(__name__)
@@ -33,6 +33,10 @@ _CONFIGS = {
         partial(torchaudio.datasets.LIBRISPEECH, url="test-clean"),
         EMFORMER_RNNT_BASE_LIBRISPEECH,
     ),
+    MODEL_TYPE_MUSTC: Config(
+        partial(MUSTC, subset="tst-COMMON"),
+        EMFORMER_RNNT_BASE_MUSTC,
+    ),
     MODEL_TYPE_TEDLIUM3: Config(
         partial(torchaudio.datasets.TEDLIUM, release="release3", subset="test"),
         EMFORMER_RNNT_BASE_TEDLIUM3,
......