import pytest
import torch
from torchaudio._internal import download_url_to_file


class GreedyCTCDecoder(torch.nn.Module):
    """Greedy (best-path) CTC decoder used by the tests.

    For every frame the highest-scoring label is taken, consecutive repeats
    are merged, blank labels are dropped, and the remaining labels are joined
    into a single string.

    Args:
        labels: Int-indexable container mapping a label index to the string
            emitted for it (e.g. a tuple/list of tokens or a string of chars).
        blank (int): Index of the CTC blank label. (Default: ``0``)
    """

    def __init__(self, labels, blank: int = 0):
        super().__init__()
        self.blank = blank
        self.labels = labels

    def forward(self, logits: torch.Tensor) -> str:
        """Given a sequence of logits over labels, get the best-path string.

        Args:
            logits (Tensor): Logit tensor. Shape `[num_seq, num_label]`.

        Returns:
            str: The resulting transcript.
        """
        best_path = torch.argmax(logits, dim=-1)  # [num_seq,]
        # Merge runs of the same label *before* removing blanks, so that
        # "a a" collapses to "a" while "a <blank> a" still yields "aa".
        best_path = torch.unique_consecutive(best_path, dim=-1)
        # Convert to Python ints once: avoids a 0-d tensor per frame and lets
        # ``labels`` be any int-indexable container (tensors hash by identity,
        # so they cannot be used as dict keys).
        return ''.join(self.labels[i] for i in best_path.tolist() if i != self.blank)


@pytest.fixture
def ctc_decoder():
    """Provide the ``GreedyCTCDecoder`` class itself (not an instance).

    Presumably returned uninstantiated so each test can construct it with its
    own label set.
    """
    decoder_cls = GreedyCTCDecoder
    return decoder_cls


_FILES = {
    'en': 'Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.flac',
37
    'de': '20090505-0900-PLENARY-16-de_20090505-21_56_00_8.flac',
38
    'en2': '20120613-0900-PLENARY-8-en_20120613-13_46_50_3.flac',
39
    'es': '20130207-0900-PLENARY-7-es_20130207-13_02_05_5.flac',
40
    'fr': '20121212-0900-PLENARY-5-fr_20121212-11_37_04_10.flac',
41
    'it': '20170516-0900-PLENARY-16-it_20170516-18_56_31_1.flac',
moto's avatar
moto committed
42
43
44
}


@pytest.fixture
def sample_speech(tmp_path, lang):
    """Return the local path of the sample speech file for ``lang``.

    The file is stored in the parent directory of ``tmp_path``. A copy fetched
    earlier is reused when present; otherwise it is downloaded from the
    torchaudio test-assets bucket.

    Raises:
        NotImplementedError: If ``lang`` is not a key of ``_FILES``.
    """
    filename = _FILES.get(lang)
    if filename is None:
        raise NotImplementedError(f'Unexpected lang: {lang}')

    cached = tmp_path.parent / filename
    if cached.exists():
        return cached

    url = f'https://download.pytorch.org/torchaudio/test-assets/{filename}'
    print(f'downloading from {url}')
    download_url_to_file(url, cached, progress=False)
    return cached


def pytest_addoption(parser):
    """Register the ``--use-tmp-hub-dir`` boolean command-line flag."""
    help_text = (
        "When provided, tests will use temporary directory as Torch Hub directory. "
        "Downloaded models will be deleted after each test."
    )
    parser.addoption("--use-tmp-hub-dir", action="store_true", help=help_text)


@pytest.fixture(autouse=True)
def temp_hub_dir(tmpdir, pytestconfig):
    """Optionally redirect Torch Hub to the per-test temporary directory.

    This is a no-op unless pytest was invoked with ``--use-tmp-hub-dir``. When the
    flag is set, the current Torch Hub directory is saved, ``tmpdir`` is used
    as the hub directory while the test runs, and the saved directory is
    restored in teardown (the code after ``yield``).
    """
    if pytestconfig.getoption('use_tmp_hub_dir'):
        saved_hub_dir = torch.hub.get_dir()
        torch.hub.set_dir(tmpdir)
        yield
        torch.hub.set_dir(saved_hub_dir)
    else:
        yield