"vscode:/vscode.git/clone" did not exist on "8a63aa5e4f02fde83755d1a5066713dffcd76248"
Commit 99b5ef5c authored by Zhaoheng Ni's avatar Zhaoheng Ni Committed by Facebook GitHub Bot
Browse files

Add EMFORMER_RNNT_BASE_MUSTC bundle to torchaudio.prototype (#2241)

Summary:
This PR provides a RNNTBundle that is pre-trained on the MuST-C release v2.0 dataset.
The model preserves the casing and punctuations of the transcripts when training the SentencePiece model.

Here is the model performance on the dev and test sets of MuST-C 2.0:
|                   |          WER |
|:-----------------:|-------------:|
| dev               |       0.190  |
| tst-COMMON        |       0.213  |
| tst-HE            |       0.186  |

Pull Request resolved: https://github.com/pytorch/audio/pull/2241

Reviewed By: mthrok

Differential Revision: D34267792

Pulled By: nateanl

fbshipit-source-id: 67bca9f277e66d41a4530d01615f249b3cec7167
parent 81f56f64
...@@ -9,6 +9,14 @@ The pipelines subpackage contains APIs to models with pretrained weights and rel ...@@ -9,6 +9,14 @@ The pipelines subpackage contains APIs to models with pretrained weights and rel
RNN-T Streaming/Non-Streaming ASR RNN-T Streaming/Non-Streaming ASR
--------------------------------- ---------------------------------
EMFORMER_RNNT_BASE_MUSTC
~~~~~~~~~~~~~~~~~~~~~~~~
.. container:: py attribute
.. autodata:: EMFORMER_RNNT_BASE_MUSTC
:no-value:
EMFORMER_RNNT_BASE_TEDLIUM3 EMFORMER_RNNT_BASE_TEDLIUM3
~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
from .rnnt_pipeline import EMFORMER_RNNT_BASE_TEDLIUM3 from .rnnt_pipeline import EMFORMER_RNNT_BASE_MUSTC, EMFORMER_RNNT_BASE_TEDLIUM3
__all__ = [ __all__ = [
"EMFORMER_RNNT_BASE_MUSTC",
"EMFORMER_RNNT_BASE_TEDLIUM3", "EMFORMER_RNNT_BASE_TEDLIUM3",
] ]
...@@ -4,6 +4,28 @@ from torchaudio.models import emformer_rnnt_base ...@@ -4,6 +4,28 @@ from torchaudio.models import emformer_rnnt_base
from torchaudio.pipelines import RNNTBundle from torchaudio.pipelines import RNNTBundle
EMFORMER_RNNT_BASE_MUSTC = RNNTBundle(
_rnnt_path="emformer_rnnt_base_mustc.pt",
_rnnt_factory_func=partial(emformer_rnnt_base, num_symbols=501),
_global_stats_path="global_stats_rnnt_mustc.json",
_sp_model_path="spm_bpe_500_mustc.model",
_right_padding=4,
_blank=500,
_sample_rate=16000,
_n_fft=400,
_n_mels=80,
_hop_length=160,
_segment_length=16,
_right_context_length=4,
)
EMFORMER_RNNT_BASE_MUSTC.__doc__ = """Pre-trained Emformer-RNNT-based ASR pipeline capable of performing both streaming and non-streaming inference.
The underlying model is constructed by :py:func:`torchaudio.models.emformer_rnnt_base`
and utilizes weights trained on MuST-C release v2.0 dataset using training script ``train.py``
`here <https://github.com/pytorch/audio/tree/main/examples/asr/emformer_rnnt>`__ with ``num_symbols=501``.
Please refer to :py:class:`torchaudio.pipelines.RNNTBundle` for usage instructions.
"""
EMFORMER_RNNT_BASE_TEDLIUM3 = RNNTBundle( EMFORMER_RNNT_BASE_TEDLIUM3 = RNNTBundle(
_rnnt_path="emformer_rnnt_base_tedlium3.pt", _rnnt_path="emformer_rnnt_base_tedlium3.pt",
_rnnt_factory_func=partial(emformer_rnnt_base, num_symbols=501), _rnnt_factory_func=partial(emformer_rnnt_base, num_symbols=501),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment