import pytest
import torchaudio
from torchaudio.models import (
    HUBERT_BASE,
    HUBERT_LARGE,
    HUBERT_XLARGE,
    HUBERT_ASR_LARGE,
    HUBERT_ASR_XLARGE,
)


@pytest.mark.parametrize(
    "bundle",
    [
        HUBERT_BASE,
        HUBERT_LARGE,
        HUBERT_XLARGE,
    ],
)
def test_pretraining_models(bundle):
    """Smoke test of downloading weights for pretraining models.

    Only verifies that the bundle's weights can be fetched and the model
    constructed; no inference is performed.
    """
    # get_model() downloads the pretrained checkpoint on first use,
    # so this test requires network access.
    bundle.get_model()


@pytest.mark.parametrize(
    "bundle,expected",
    [
        (HUBERT_ASR_LARGE, 'I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'),
        (HUBERT_ASR_XLARGE, 'I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'),
    ],
)
def test_finetune_asr_model(
        bundle,
        expected,
        sample_speech_16000_en,
        ctc_decoder,
):
    """Smoke test of downloading weights for fine-tuning models and simple transcription.

    Loads the ASR bundle, runs a single utterance through it, and checks the
    greedy CTC-decoded transcript against the known expected string.
    Requires network access to download the checkpoint.
    """
    model = bundle.get_model().eval()
    # Sample rate is discarded; the fixture name suggests the clip is already
    # 16 kHz (the rate these bundles expect) — confirm against the fixture.
    waveform, _ = torchaudio.load(sample_speech_16000_en)
    emission, _ = model(waveform)
    decoder = ctc_decoder(bundle.labels)
    # Decode the emission for the first (only) item in the batch.
    result = decoder(emission[0])
    assert result == expected