yesno_test.py 1.85 KB
Newer Older
moto's avatar
moto committed
1
import os
2
from pathlib import Path
moto's avatar
moto committed
3
4

from torchaudio.datasets import yesno
5
from torchaudio_unittest.common_utils import (
moto's avatar
moto committed
6
7
    get_whitenoise,
    normalize_wav,
8
9
10
    save_wav,
    TempDirMixin,
    TorchaudioTestCase,
moto's avatar
moto committed
11
12
13
)


14
15
16
17
18
19
def get_mock_data(root_dir, labels):
    """Write one mock wav file per label and return the expected waveforms.

    The files are created in a ``waves_yesno`` directory under ``root_dir``
    (the directory is created if it does not exist). Each file is named after
    its label: the label's elements joined with ``_`` plus a ``.wav`` suffix,
    e.g. ``[0, 1]`` becomes ``0_1.wav``.

    Args:
        root_dir: Path of the directory under which ``waves_yesno`` is created.
        labels: List of labels; each label is an iterable whose elements are
            converted with ``str`` to build the file name.

    Returns:
        A list with one entry per label, in the order of ``labels``, holding
        ``normalize_wav(data)`` for the data that was saved to that file.
    """
    mocked_data = []
    base_dir = os.path.join(root_dir, "waves_yesno")
    os.makedirs(base_dir, exist_ok=True)
    for i, label in enumerate(labels):
        filename = "_".join(str(digit) for digit in label) + ".wav"
        path = os.path.join(base_dir, filename)
        # The loop index is used as the seed, so every file gets its own
        # noise and the generated data is reproducible across runs.
        data = get_whitenoise(sample_rate=8000, duration=6, n_channels=1, dtype="int16", seed=i)
        save_wav(path, data, 8000)
        # Keep the normalized form of what was written; callers compare it
        # against what the dataset loads back.
        mocked_data.append(normalize_wav(data))
    return mocked_data


moto's avatar
moto committed
31
class TestYesNo(TempDirMixin, TorchaudioTestCase):
    """Check ``yesno.YESNO`` against mock wav files generated in a temp directory."""

    backend = "default"

    # Both attributes below are placeholders that setUpClass replaces.
    root_dir = None
    data = []
    # One label per mock file; get_mock_data derives the file names from them.
    labels = [
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 1, 1, 1],
        [0, 1, 0, 1, 0, 1, 1, 0],
        [1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1],
    ]

    @classmethod
    def setUpClass(cls):
        """Generate the mock dataset once and keep the expected waveforms."""
        cls.root_dir = cls.get_base_temp_dir()
        cls.data = get_mock_data(cls.root_dir, cls.labels)

    def _test_yesno(self, dataset):
        """Compare every sample of ``dataset`` with the mock data.

        Args:
            dataset: Iterable yielding ``(waveform, sample_rate, label)`` tuples.
        """
        # NOTE(review): the i-th sample is compared with the i-th label, so
        # this presumably relies on the dataset yielding files in the same
        # order as ``labels`` - confirm against the dataset's file ordering.
        n_ite = 0
        for i, (waveform, sample_rate, label) in enumerate(dataset):
            expected_label = self.labels[i]
            expected_data = self.data[i]
            self.assertEqual(expected_data, waveform, atol=5e-5, rtol=1e-8)
            assert sample_rate == 8000, f"sample {i}: unexpected sample rate {sample_rate}"
            assert label == expected_label, f"sample {i}: expected label {expected_label}, got {label}"
            n_ite += 1
        # Guards against a dataset that yields fewer samples than were generated.
        assert n_ite == len(self.data), f"expected {len(self.data)} samples, iterated over {n_ite}"

    def test_yesno_str(self):
        """The dataset root is given as the value returned by get_base_temp_dir."""
        dataset = yesno.YESNO(self.root_dir)
        self._test_yesno(dataset)

    def test_yesno_path(self):
        """The dataset root is given as a ``pathlib.Path``."""
        dataset = yesno.YESNO(Path(self.root_dir))
        self._test_yesno(dataset)