import torchaudio
from torchaudio.pipelines import (
    WAV2VEC2_BASE,
    WAV2VEC2_LARGE,
    WAV2VEC2_LARGE_LV60K,
    WAV2VEC2_ASR_BASE_10M,
    WAV2VEC2_ASR_BASE_100H,
    WAV2VEC2_ASR_BASE_960H,
    WAV2VEC2_ASR_LARGE_10M,
    WAV2VEC2_ASR_LARGE_100H,
    WAV2VEC2_ASR_LARGE_960H,
    WAV2VEC2_ASR_LARGE_LV60K_10M,
    WAV2VEC2_ASR_LARGE_LV60K_100H,
    WAV2VEC2_ASR_LARGE_LV60K_960H,
    WAV2VEC2_XLSR53,
    HUBERT_BASE,
    HUBERT_LARGE,
    HUBERT_XLARGE,
    HUBERT_ASR_LARGE,
    HUBERT_ASR_XLARGE,
    VOXPOPULI_ASR_BASE_10K_EN,
    VOXPOPULI_ASR_BASE_10K_ES,
    VOXPOPULI_ASR_BASE_10K_DE,
    VOXPOPULI_ASR_BASE_10K_FR,
    VOXPOPULI_ASR_BASE_10K_IT,
)
import pytest


@pytest.mark.parametrize(
    "bundle",
    [
        WAV2VEC2_BASE,
        WAV2VEC2_LARGE,
        WAV2VEC2_LARGE_LV60K,
        WAV2VEC2_XLSR53,
        HUBERT_BASE,
        HUBERT_LARGE,
        HUBERT_XLARGE,
    ]
)
def test_pretraining_models(bundle):
    """Smoke test of downloading weights for pretraining models"""
    bundle.get_model()
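

# ``test_pretraining_models`` only verifies that the weights download.  For
# reference, a pre-training bundle (no ASR head) is typically used for feature
# extraction, as in the illustrative sketch below.  This helper is not called
# by any test here, and ``path`` is a hypothetical audio file assumed to be
# already sampled at ``bundle.sample_rate``.
def _extract_features_sketch(bundle, path):
    model = bundle.get_model().eval()
    waveform, sample_rate = torchaudio.load(path)
    assert sample_rate == bundle.sample_rate
    # ``extract_features`` returns the intermediate transformer-layer outputs.
    features, _ = model.extract_features(waveform)
    return features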


@pytest.mark.parametrize(
    "bundle,lang,expected",
    [
        (WAV2VEC2_ASR_BASE_10M, 'en', 'I|HAD|THAT|CURIYOSSITY|BESID|ME|AT|THIS|MOMENT|'),
        (WAV2VEC2_ASR_BASE_100H, 'en', 'I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'),
        (WAV2VEC2_ASR_BASE_960H, 'en', 'I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'),
        (WAV2VEC2_ASR_LARGE_10M, 'en', 'I|HAD|THAT|CURIOUSITY|BESIDE|ME|AT|THIS|MOMENT|'),
        (WAV2VEC2_ASR_LARGE_100H, 'en', 'I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'),
        (WAV2VEC2_ASR_LARGE_960H, 'en', 'I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'),
        (WAV2VEC2_ASR_LARGE_LV60K_10M, 'en', 'I|HAD|THAT|CURIOUSSITY|BESID|ME|AT|THISS|MOMENT|'),
        (WAV2VEC2_ASR_LARGE_LV60K_100H, 'en', 'I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'),
        (WAV2VEC2_ASR_LARGE_LV60K_960H, 'en', 'I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'),
        (HUBERT_ASR_LARGE, 'en', 'I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'),
        (HUBERT_ASR_XLARGE, 'en', 'I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'),
        (VOXPOPULI_ASR_BASE_10K_EN, 'en2', 'i|hope|that|we|will|see|a|ddrasstic|decrease|of|funding|for|the|failed|eu|project|and|that|more|money|will|come|back|to|the|taxpayers'),  # noqa: E501
        (VOXPOPULI_ASR_BASE_10K_ES, 'es', "la|primera|que|es|imprescindible|pensar|a|pequeña|a|escala|para|implicar|y|complementar|así|la|actuación|global"),  # noqa: E501
        (VOXPOPULI_ASR_BASE_10K_DE, 'de', "dabei|spielt|auch|eine|sorgfältige|berichterstattung|eine|wichtige|rolle"),
        (VOXPOPULI_ASR_BASE_10K_FR, 'fr', 'la|commission|va|faire|des|propositions|sur|ce|sujet|comment|mettre|en|place|cette|capacité|fiscale|et|le|conseil|européen|y|reviendra|sour|les|sujets|au|moins|de|mars'),  # noqa: E501
        (VOXPOPULI_ASR_BASE_10K_IT, 'it', 'credo|che|illatino|non|sia|contemplato|tra|le|traduzioni|e|quindi|mi|attengo|allitaliano')  # noqa: E501
    ]
)
def test_finetune_asr_model(
        bundle,
        lang,
        expected,
        sample_speech,
        ctc_decoder,
):
    """Smoke test of downloading weights for fine-tuning models and simple transcription"""
    model = bundle.get_model().eval()
    # ``sample_speech`` is a pytest fixture resolving to a test utterance that
    # matches ``lang``.
    waveform, sample_rate = torchaudio.load(sample_speech)
    # The forward pass yields frame-wise emissions (logits over the label set).
    emission, _ = model(waveform)
    # Build a CTC decoder over the bundle's label set (``ctc_decoder`` is a fixture).
    decoder = ctc_decoder(bundle.get_labels())
    result = decoder(emission[0])
    assert result == expected
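

# ``sample_speech`` and ``ctc_decoder`` above are pytest fixtures assumed to be
# provided by the surrounding suite (e.g. its conftest.py): ``sample_speech``
# yields the path to a short utterance matching ``lang``, and ``ctc_decoder``
# builds a decoder from the bundle's labels.  The helper below is a minimal
# sketch of the greedy CTC decoding such a fixture could perform; it is
# illustrative only and not called by the tests.
def _greedy_ctc_decode_sketch(emission, labels, blank=0):
    """Frame-wise argmax, collapse repeats, drop blanks, then map to labels."""
    from itertools import groupby  # local import keeps the sketch self-contained

    indices = emission.argmax(dim=-1).tolist()
    return "".join(labels[i] for i, _ in groupby(indices) if i != blank)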