# rnnt_pipeline_test.py
import pytest
import torchaudio
3
from torchaudio.pipelines import EMFORMER_RNNT_BASE_LIBRISPEECH
# The MuST-C and TED-LIUM 3 bundles are imported from the prototype namespace.
from torchaudio.prototype.pipelines import EMFORMER_RNNT_BASE_MUSTC, EMFORMER_RNNT_BASE_TEDLIUM3
5
6
7
8


@pytest.mark.parametrize(
    "bundle,lang,expected",
    [
        (EMFORMER_RNNT_BASE_LIBRISPEECH, "en", "i have that curiosity beside me at this moment"),
        (EMFORMER_RNNT_BASE_MUSTC, "en", "I had that curiosity beside me at this moment."),
        (EMFORMER_RNNT_BASE_TEDLIUM3, "en", "i had that curiosity beside me at this moment"),
    ],
)
def test_rnnt(bundle, sample_speech, expected):
    """Run one RNN-T bundle end to end and compare its transcript to ``expected``.

    The test loads the audio file referenced by ``sample_speech``, extracts
    features with the bundle's feature extractor, decodes them with the
    bundle's decoder, converts the first entry of the top hypothesis to text
    with the bundle's token processor, and asserts an exact string match.

    Args:
        bundle: Pipeline bundle under test; must provide
            ``get_feature_extractor()``, ``get_decoder()`` and
            ``get_token_processor()``.
        sample_speech: Value accepted by ``torchaudio.load`` (supplied by a
            fixture that is not defined in this file).
        expected: Exact transcript the pipeline must produce for this bundle.

    NOTE(review): ``lang`` is parametrized but is not an argument of this
    function; presumably it is consumed by the ``sample_speech`` fixture
    (defined elsewhere, likely ``conftest.py``) -- confirm.
    """
    feature_extractor = bundle.get_feature_extractor()
    # Switch the decoder to evaluation mode before running it.
    decoder = bundle.get_decoder().eval()
    token_processor = bundle.get_token_processor()
    # The second value returned by ``load`` is not needed here.
    waveform, _ = torchaudio.load(sample_speech)
    # ``squeeze`` drops size-1 dimensions before feature extraction.
    features, length = feature_extractor(waveform.squeeze())
    # NOTE(review): the third argument (10) is presumably the beam width -- confirm
    # against the decoder's signature.
    hypotheses = decoder(features, length, 10)
    # Take the first element of the first hypothesis and turn it into text.
    text = token_processor(hypotheses[0][0])
    assert text == expected