Commit 9b4ee17c authored by Zhaoheng Ni, committed by Facebook GitHub Bot
Browse files

Fix download links of RNNT pipelines in prototype (#2444)

Summary:
In https://github.com/pytorch/audio/issues/2283, torchaudio's downloading function was updated to reduce code duplication. The download links in `EMFORMER_RNNT_BASE_LIBRISPEECH` were updated at that time, but the ones in the prototype pipelines were not. This PR fixes that by updating the download links of `EMFORMER_RNNT_BASE_MUSTC` and `EMFORMER_RNNT_BASE_TEDLIUM3` in prototype. Corresponding integration tests are added as well.

Pull Request resolved: https://github.com/pytorch/audio/pull/2444

Reviewed By: mthrok

Differential Revision: D37389178

Pulled By: nateanl

fbshipit-source-id: 46598dd71c95be47d1e1b54cef89ea51d280e17a
parent 4ba7dc38
import pytest import pytest
import torchaudio import torchaudio
from torchaudio.pipelines import EMFORMER_RNNT_BASE_LIBRISPEECH from torchaudio.pipelines import EMFORMER_RNNT_BASE_LIBRISPEECH
from torchaudio.prototype.pipelines import EMFORMER_RNNT_BASE_MUSTC, EMFORMER_RNNT_BASE_TEDLIUM3
@pytest.mark.parametrize( @pytest.mark.parametrize(
"bundle,lang,expected", "bundle,lang,expected",
[(EMFORMER_RNNT_BASE_LIBRISPEECH, "en", "i have that curiosity beside me at this moment")], [
(EMFORMER_RNNT_BASE_LIBRISPEECH, "en", "i have that curiosity beside me at this moment"),
(EMFORMER_RNNT_BASE_MUSTC, "en", "I had that curiosity beside me at this moment."),
(EMFORMER_RNNT_BASE_TEDLIUM3, "en", "i had that curiosity beside me at this moment"),
],
) )
def test_rnnt(bundle, sample_speech, expected): def test_rnnt(bundle, sample_speech, expected):
feature_extractor = bundle.get_feature_extractor() feature_extractor = bundle.get_feature_extractor()
......
...@@ -5,10 +5,10 @@ from torchaudio.pipelines import RNNTBundle ...@@ -5,10 +5,10 @@ from torchaudio.pipelines import RNNTBundle
EMFORMER_RNNT_BASE_MUSTC = RNNTBundle( EMFORMER_RNNT_BASE_MUSTC = RNNTBundle(
_rnnt_path="emformer_rnnt_base_mustc.pt", _rnnt_path="models/emformer_rnnt_base_mustc.pt",
_rnnt_factory_func=partial(emformer_rnnt_base, num_symbols=501), _rnnt_factory_func=partial(emformer_rnnt_base, num_symbols=501),
_global_stats_path="global_stats_rnnt_mustc.json", _global_stats_path="pipeline-assets/global_stats_rnnt_mustc.json",
_sp_model_path="spm_bpe_500_mustc.model", _sp_model_path="pipeline-assets/spm_bpe_500_mustc.model",
_right_padding=4, _right_padding=4,
_blank=500, _blank=500,
_sample_rate=16000, _sample_rate=16000,
...@@ -27,10 +27,10 @@ EMFORMER_RNNT_BASE_MUSTC.__doc__ = """Pre-trained Emformer-RNNT-based ASR pipeli ...@@ -27,10 +27,10 @@ EMFORMER_RNNT_BASE_MUSTC.__doc__ = """Pre-trained Emformer-RNNT-based ASR pipeli
EMFORMER_RNNT_BASE_TEDLIUM3 = RNNTBundle( EMFORMER_RNNT_BASE_TEDLIUM3 = RNNTBundle(
_rnnt_path="emformer_rnnt_base_tedlium3.pt", _rnnt_path="models/emformer_rnnt_base_tedlium3.pt",
_rnnt_factory_func=partial(emformer_rnnt_base, num_symbols=501), _rnnt_factory_func=partial(emformer_rnnt_base, num_symbols=501),
_global_stats_path="global_stats_rnnt_tedlium3.json", _global_stats_path="pipeline-assets/global_stats_rnnt_tedlium3.json",
_sp_model_path="spm_bpe_500_tedlium3.model", _sp_model_path="pipeline-assets/spm_bpe_500_tedlium3.model",
_right_padding=4, _right_padding=4,
_blank=500, _blank=500,
_sample_rate=16000, _sample_rate=16000,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment