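"""Smoke tests for torchaudio's wav2vec2 / HuBERT / WavLM pipeline bundles.

Pre-trained bundles are only checked for successful weight download; the
fine-tuned ASR bundles are additionally run on a short sample utterance and the
decoded transcript is compared against a fixed expected string.
"""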
import os

import pytest
import torchaudio
from torchaudio.pipelines import (
    HUBERT_ASR_LARGE,
    HUBERT_ASR_XLARGE,
    HUBERT_BASE,
    HUBERT_LARGE,
    HUBERT_XLARGE,
    VOXPOPULI_ASR_BASE_10K_DE,
    VOXPOPULI_ASR_BASE_10K_EN,
    VOXPOPULI_ASR_BASE_10K_ES,
    VOXPOPULI_ASR_BASE_10K_FR,
    VOXPOPULI_ASR_BASE_10K_IT,
    WAV2VEC2_ASR_BASE_100H,
    WAV2VEC2_ASR_BASE_10M,
    WAV2VEC2_ASR_BASE_960H,
    WAV2VEC2_ASR_LARGE_100H,
    WAV2VEC2_ASR_LARGE_10M,
    WAV2VEC2_ASR_LARGE_960H,
    WAV2VEC2_ASR_LARGE_LV60K_100H,
    WAV2VEC2_ASR_LARGE_LV60K_10M,
    WAV2VEC2_ASR_LARGE_LV60K_960H,
    WAV2VEC2_BASE,
    WAV2VEC2_LARGE,
    WAV2VEC2_LARGE_LV60K,
    WAV2VEC2_XLSR53,
    WAV2VEC2_XLSR_1B,
    WAV2VEC2_XLSR_300M,
    WAVLM_BASE,
    WAVLM_BASE_PLUS,
    WAVLM_LARGE,
)


@pytest.mark.parametrize(
    "bundle",
    [
        WAV2VEC2_BASE,
        WAV2VEC2_LARGE,
        WAV2VEC2_LARGE_LV60K,
        WAV2VEC2_XLSR53,
        HUBERT_BASE,
        HUBERT_LARGE,
        HUBERT_XLARGE,
        WAVLM_BASE,
        WAVLM_BASE_PLUS,
        WAVLM_LARGE,
    ],
)
def test_pretraining_models(bundle):
    """Smoke test of downloading weights for pretraining models"""
    bundle.get_model()


@pytest.mark.skipif("CI" not in os.environ, reason="Run tests only in CI environment.")
@pytest.mark.parametrize(
    "bundle",
    [
        WAV2VEC2_XLSR_300M,
        WAV2VEC2_XLSR_1B,
    ],
)
def test_xlsr_pretraining_models(bundle):
    """Smoke test of downloading weights for pretraining models"""
    bundle.get_model()


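# `sample_speech` and `ctc_decoder` in the test below are pytest fixtures,
# assumed to be provided by the suite's conftest.py: `sample_speech` yields the
# path to a short utterance for the requested language, and `ctc_decoder`
# builds a simple CTC decoder over the labels of the given bundle.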
@pytest.mark.parametrize(
    "bundle,lang,expected",
    [
        (WAV2VEC2_ASR_BASE_10M, "en", "I|HAD|THAT|CURIYOSSITY|BESID|ME|AT|THIS|MOMENT|"),
        (WAV2VEC2_ASR_BASE_100H, "en", "I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|"),
        (WAV2VEC2_ASR_BASE_960H, "en", "I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|"),
        (WAV2VEC2_ASR_LARGE_10M, "en", "I|HAD|THAT|CURIOUSITY|BESIDE|ME|AT|THIS|MOMENT|"),
        (WAV2VEC2_ASR_LARGE_100H, "en", "I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|"),
        (WAV2VEC2_ASR_LARGE_960H, "en", "I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|"),
        (WAV2VEC2_ASR_LARGE_LV60K_10M, "en", "I|HAD|THAT|CURIOUSITY|BESID|ME|AT|THISS|MOMENT|"),
        (WAV2VEC2_ASR_LARGE_LV60K_100H, "en", "I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|"),
        (WAV2VEC2_ASR_LARGE_LV60K_960H, "en", "I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|"),
        (HUBERT_ASR_LARGE, "en", "I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|"),
        (HUBERT_ASR_XLARGE, "en", "I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|"),
        (
            VOXPOPULI_ASR_BASE_10K_EN,
            "en2",
            "i|hope|that|we|will|see|a|ddrasstic|decrease|of|funding|for|the|failed|eu|project|and|that|more|money|will|come|back|to|the|taxpayers",  # noqa: E501
        ),
        (
            VOXPOPULI_ASR_BASE_10K_ES,
            "es",
            "la|primera|que|es|imprescindible|pensar|a|pequeña|a|escala|para|implicar|y|complementar|así|la|actuación|global",  # noqa: E501
        ),
        (VOXPOPULI_ASR_BASE_10K_DE, "de", "dabei|spielt|auch|eine|sorgfältige|berichterstattung|eine|wichtige|rolle"),
        (
            VOXPOPULI_ASR_BASE_10K_FR,
            "fr",
            "la|commission|va|faire|des|propositions|sur|ce|sujet|comment|mettre|en|place|cette|capacité|fiscale|et|le|conseil|européen|y|reviendra|sour|les|sujets|au|moins|de|mars",  # noqa: E501
        ),
        (
            VOXPOPULI_ASR_BASE_10K_IT,
            "it",
            "credo|che|illatino|non|sia|contemplato|tra|le|traduzioni|e|quindi|mi|attengo|allitaliano",
        ),
    ],
)
def test_finetune_asr_model(
    bundle,
    lang,
    expected,
    sample_speech,
    ctc_decoder,
):
    """Smoke test of downloading weights for fine-tuning models and simple transcription"""
    model = bundle.get_model().eval()
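    # Load the fixture waveform, run the model to get frame-level emissions,
    # then decode the emissions into a transcript and compare it against the
    # expected string.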
    waveform, sample_rate = torchaudio.load(sample_speech)
    emission, _ = model(waveform)
    decoder = ctc_decoder(bundle.get_labels())
    result = decoder(emission[0])
    assert result == expected