conftest.py 3.33 KB
Newer Older
1
2
3
import os
import shutil

4
import pytest
5
import torch
6
import torchaudio
7
8
9


class GreedyCTCDecoder(torch.nn.Module):
    """Greedy (best-path) CTC decoder that turns frame-wise logits into text.

    Args:
        labels: Container mapping a label index to its text token. Anything
            indexable by a Python ``int`` works (list, tuple, str, dict).
        blank (int, optional): Index of the CTC blank label. (Default: ``0``)
    """

    def __init__(self, labels, blank: int = 0):
        super().__init__()
        self.blank = blank
        self.labels = labels

    def forward(self, logits: torch.Tensor) -> str:
        """Given a sequence of logits over labels, get the best path string.

        Args:
            logits (Tensor): Logit tensors. Shape `[num_seq, num_label]`.

        Returns:
            str: The resulting transcript
        """
        best_path = torch.argmax(logits, dim=-1)  # [num_seq,]
        # Collapse consecutive repeats of the same label before dropping blanks.
        best_path = torch.unique_consecutive(best_path, dim=-1)
        # Convert to plain Python ints once: this avoids per-element tensor
        # comparisons and lets ``labels`` be a dict as well as a sequence
        # (0-dim tensors are not usable as dict keys).
        return "".join(
            self.labels[index]
            for index in best_path.tolist()
            if index != self.blank
        )
31
32
33
34
35
36
37


@pytest.fixture
def ctc_decoder():
    """Provide the ``GreedyCTCDecoder`` class itself (not an instance)."""
    decoder_cls = GreedyCTCDecoder
    return decoder_cls


moto's avatar
moto committed
38
# Sample speech recordings, keyed by the ``lang`` value that the
# ``sample_speech`` fixture receives. Each value is a file name that is
# downloaded from under ``test-assets/``.
_FILES = {
    "en": "Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.flac",
    "de": "20090505-0900-PLENARY-16-de_20090505-21_56_00_8.flac",
    "en2": "20120613-0900-PLENARY-8-en_20120613-13_46_50_3.flac",
    "es": "20130207-0900-PLENARY-7-es_20130207-13_02_05_5.flac",
    "fr": "20121212-0900-PLENARY-5-fr_20121212-11_37_04_10.flac",
    "it": "20170516-0900-PLENARY-16-it_20170516-18_56_31_1.flac",
}
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# Mixed-signal input file per task; looked up by the ``mixture_source``
# fixture and downloaded from under ``test-assets/``.
_MIXTURE_FILES = {
    "speech_separation": "mixture_3729-6852-0037_8463-287645-0000.wav",
    "music_separation": "al_james_mixture_shorter.wav",
}
# Individual source files per task (same task keys as ``_MIXTURE_FILES``);
# looked up by the ``clean_sources`` fixture, which downloads them in the
# order listed here.
_CLEAN_FILES = {
    "speech_separation": [
        "s1_3729-6852-0037_8463-287645-0000.wav",
        "s2_3729-6852-0037_8463-287645-0000.wav",
    ],
    "music_separation": [
        "al_james_drums_shorter.wav",
        "al_james_bass_shorter.wav",
        "al_james_other_shorter.wav",
        "al_james_vocals_shorter.wav",
    ],
}
moto's avatar
moto committed
62
63


64
@pytest.fixture
def sample_speech(lang):
    """Download the sample speech recording registered for ``lang``.

    Args:
        lang: Key into ``_FILES`` selecting which recording to fetch.

    Returns:
        The value returned by ``torchaudio.utils.download_asset`` for
        ``test-assets/<file name>``.

    Raises:
        NotImplementedError: If ``lang`` is not a key of ``_FILES``.
    """
    if lang not in _FILES:
        raise NotImplementedError(f"Unexpected lang: {lang}")
    return torchaudio.utils.download_asset(f"test-assets/{_FILES[lang]}")
moto's avatar
moto committed
71
72


73
@pytest.fixture
def mixture_source(task):
    """Download the mixture recording registered for ``task``.

    Args:
        task: Key into ``_MIXTURE_FILES`` selecting which mixture to fetch.

    Returns:
        The value returned by ``torchaudio.utils.download_asset`` for
        ``test-assets/<file name>``.

    Raises:
        NotImplementedError: If ``task`` is not a key of ``_MIXTURE_FILES``.
    """
    if task not in _MIXTURE_FILES:
        raise NotImplementedError(f"Unexpected task: {task}")
    return torchaudio.utils.download_asset(f"test-assets/{_MIXTURE_FILES[task]}")


@pytest.fixture
def clean_sources(task):
    """Download every clean source recording registered for ``task``.

    Args:
        task: Key into ``_CLEAN_FILES`` selecting which set of sources to fetch.

    Returns:
        list: One ``torchaudio.utils.download_asset`` result per file, in the
        same order as the entries of ``_CLEAN_FILES[task]``.

    Raises:
        NotImplementedError: If ``task`` is not a key of ``_CLEAN_FILES``.
    """
    if task not in _CLEAN_FILES:
        raise NotImplementedError(f"Unexpected task: {task}")
    return [
        torchaudio.utils.download_asset(f"test-assets/{file}")
        for file in _CLEAN_FILES[task]
    ]


moto's avatar
moto committed
92
93
94
95
96
97
98
def pytest_addoption(parser):
    """Register the ``--use-tmp-hub-dir`` command line flag.

    The flag is a boolean switch (``store_true``); the ``temp_hub_dir``
    fixture reads it back under the name ``use_tmp_hub_dir``.

    Args:
        parser: Object whose ``addoption`` method is called to declare the flag.
    """
    parser.addoption(
        "--use-tmp-hub-dir",
        action="store_true",
        help=(
            "When provided, tests will use temporary directory as Torch Hub directory. "
            "Downloaded models will be deleted after each test."
        ),
    )


@pytest.fixture(autouse=True)
def temp_hub_dir(tmp_path, pytestconfig):
    """Optionally point Torch Hub at a per-test temporary directory.

    When the ``use_tmp_hub_dir`` option is not set, the fixture does nothing.
    Otherwise it switches the Torch Hub directory to ``<tmp_path>/hub`` for
    the duration of the test, then restores the previous directory and
    deletes the temporary one.

    Args:
        tmp_path: Directory under which the temporary ``hub`` folder is placed.
        pytestconfig: Object queried via ``getoption`` for ``use_tmp_hub_dir``.
    """
    if not pytestconfig.getoption("use_tmp_hub_dir", default=False):
        yield
        return

    org_dir = torch.hub.get_dir()
    subdir = os.path.join(tmp_path, "hub")
    torch.hub.set_dir(subdir)
    try:
        yield
    finally:
        # Always undo the global hub-dir change, even if the generator is
        # closed abnormally, so later tests do not inherit a deleted path.
        torch.hub.set_dir(org_dir)
        shutil.rmtree(subdir, ignore_errors=True)
114
115
116
117
118
119


@pytest.fixture()
def emissions():
    """Download ``test-assets/emissions-8555-28447-0012.pt`` and load it.

    Returns:
        Whatever ``torch.load`` yields for the downloaded file.
    """
    return torch.load(
        torchaudio.utils.download_asset("test-assets/emissions-8555-28447-0012.pt")
    )