# Shared pytest fixtures: a greedy CTC decoder and downloadable sample speech files.
import torch
import requests
import pytest


class GreedyCTCDecoder(torch.nn.Module):
    """Greedy (best-path) CTC decoder that turns logits into a transcript.

    Args:
        labels: Indexable collection mapping a label index to its string token.
        blank (int, optional): Index of the blank label, which is dropped
            from the output. (Default: ``0``)
    """

    def __init__(self, labels, blank: int = 0):
        super().__init__()
        self.blank = blank
        self.labels = labels

    def forward(self, logits: torch.Tensor) -> str:
        """Given a sequence of logits over labels, get the best path string.

        Args:
            logits (Tensor): Logit tensors. Shape `[num_seq, num_label]`.

        Returns:
            str: The resulting transcript
        """
        best_path = torch.argmax(logits, dim=-1)  # [num_seq,]
        # Collapse runs of the same index *before* removing blanks, so that a
        # blank between two identical labels keeps them as separate tokens.
        best_path = torch.unique_consecutive(best_path, dim=-1)
        # `tolist()` yields plain Python ints, so the comparison and the
        # lookup below do not go through 0-dim tensor operations.
        return ''.join(
            self.labels[index]
            for index in best_path.tolist()
            if index != self.blank
        )


@pytest.fixture
def ctc_decoder():
    """Provide the ``GreedyCTCDecoder`` class itself (not an instance)."""
    decoder_cls = GreedyCTCDecoder
    return decoder_cls


moto's avatar
moto committed
35
36
_FILES = {
    'en': 'Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.flac',
37
    'fr': '20121212-0900-PLENARY-5-fr_20121212-11_37_04_10.flac',
moto's avatar
moto committed
38
39
40
}


@pytest.fixture
def sample_speech(tmp_path, lang):
    """Return the path of the sample speech file for ``lang``, downloading it if absent.

    The file is stored in the parent of ``tmp_path`` so that it is shared by
    every test that uses the same parent directory instead of being fetched
    once per test.

    Args:
        tmp_path: pytest's per-test temporary directory.
        lang (str): Language code; must be a key of ``_FILES``.

    Returns:
        Path to the downloaded sample file.

    Raises:
        NotImplementedError: If ``lang`` is not a key of ``_FILES``.
        requests.HTTPError: If the download responds with an error status.
    """
    if lang not in _FILES:
        raise NotImplementedError(f'Unexpected lang: {lang}')
    filename = _FILES[lang]
    path = tmp_path.parent / filename
    if not path.exists():
        url = f'https://download.pytorch.org/torchaudio/test-assets/{filename}'
        print(f'downloading from {url}')
        # Fetch and validate the response before touching the filesystem: a
        # failed request must not leave an empty file behind, because the
        # `path.exists()` check above would then treat it as a valid asset.
        # The timeout keeps a stalled connection from hanging the test run.
        with requests.get(url, timeout=60) as resp:
            resp.raise_for_status()
            content = resp.content
        # Write to a sibling file and rename it into place so an interrupted
        # write never leaves a truncated file under the final name.
        partial = path.with_name(f'{filename}.part')
        partial.write_bytes(content)
        partial.replace(path)
    return path