# wav2vec2_model_test.py: smoke tests for pretrained wav2vec2 / HuBERT bundles.
import torchaudio
from torchaudio.models import (
    WAV2VEC2_BASE,
    WAV2VEC2_LARGE,
    WAV2VEC2_LARGE_LV60K,
    WAV2VEC2_ASR_BASE_10M,
    WAV2VEC2_ASR_BASE_100H,
    WAV2VEC2_ASR_BASE_960H,
    WAV2VEC2_ASR_LARGE_10M,
    WAV2VEC2_ASR_LARGE_100H,
    WAV2VEC2_ASR_LARGE_960H,
    WAV2VEC2_ASR_LARGE_LV60K_10M,
    WAV2VEC2_ASR_LARGE_LV60K_100H,
    WAV2VEC2_ASR_LARGE_LV60K_960H,
    WAV2VEC2_XLSR53,
    HUBERT_BASE,
    HUBERT_LARGE,
    HUBERT_XLARGE,
    HUBERT_ASR_LARGE,
    HUBERT_ASR_XLARGE,
)
import pytest


@pytest.mark.parametrize(
    "bundle",
    [
        WAV2VEC2_BASE,
        WAV2VEC2_LARGE,
        WAV2VEC2_LARGE_LV60K,
        WAV2VEC2_XLSR53,
        HUBERT_BASE,
        HUBERT_LARGE,
        HUBERT_XLARGE,
    ]
)
def test_pretraining_models(bundle):
    """Smoke test of downloading weights for pretraining models.

    Only checks that ``bundle.get_model()`` completes without raising;
    the returned model itself is not inspected.

    Args:
        bundle: a pretraining (non-ASR) bundle exposing ``get_model()``.
    """
    bundle.get_model()


@pytest.mark.parametrize(
    "bundle,expected",
    [
        (WAV2VEC2_ASR_BASE_10M, 'I|HAD|THAT|CURIYOSSITY|BESID|ME|AT|THIS|MOMENT|'),
        (WAV2VEC2_ASR_BASE_100H, 'I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'),
        (WAV2VEC2_ASR_BASE_960H, 'I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'),
        (WAV2VEC2_ASR_LARGE_10M, 'I|HAD|THAT|CURIOUSITY|BESIDE|ME|AT|THIS|MOMENT|'),
        (WAV2VEC2_ASR_LARGE_100H, 'I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'),
        (WAV2VEC2_ASR_LARGE_960H, 'I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'),
        (WAV2VEC2_ASR_LARGE_LV60K_10M, 'I|HAD|THAT|CURIOUSSITY|BESID|ME|AT|THISS|MOMENT|'),
        (WAV2VEC2_ASR_LARGE_LV60K_100H, 'I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'),
        (WAV2VEC2_ASR_LARGE_LV60K_960H, 'I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'),
        (HUBERT_ASR_LARGE, 'I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'),
        (HUBERT_ASR_XLARGE, 'I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'),
    ]
)
def test_finetune_asr_model(
        bundle,
        expected,
        sample_speech_16000_en,
        ctc_decoder,
):
    """Smoke test of downloading weights for fine-tuning models and simple transcription.

    Args:
        bundle: ASR bundle exposing ``get_model()`` and ``get_labels()``.
        expected: decoder output expected for the sample clip (words joined by ``|``).
        sample_speech_16000_en: fixture giving a path to an English speech sample,
            16 kHz according to its name (checked below).
        ctc_decoder: fixture that builds a decoder from the bundle's label set.
    """
    import torch  # local import: only needed to disable autograd below

    model = bundle.get_model().eval()
    waveform, sample_rate = torchaudio.load(sample_speech_16000_en)
    # The waveform is fed to the model without resampling, so make the
    # fixture's 16 kHz promise explicit instead of silently ignoring the rate.
    assert sample_rate == 16000
    # Inference only: avoid building an autograd graph (matters for the large models).
    with torch.no_grad():
        emission, _ = model(waveform)

    decoder = ctc_decoder(bundle.get_labels())
    # Decode the first row of the model output.
    result = decoder(emission[0])
    assert result == expected