gtzan_test.py 4.27 KB
Newer Older
1
import os
2
from pathlib import Path
3
4

from torchaudio.datasets import gtzan
5
from torchaudio_unittest.common_utils import (
6
7
    get_whitenoise,
    normalize_wav,
8
9
10
    save_wav,
    TempDirMixin,
    TorchaudioTestCase,
11
12
13
)


14
15
16
17
18
19
20
21
22
23
24
25
def get_mock_dataset(root_dir):
    """Create a mocked GTZAN dataset on disk and return the expected samples.

    For every genre in ``gtzan.gtzan_genres``, 100 files named
    ``<genre>.<index:05d>.wav`` are written to ``<root_dir>/genres/<genre>/``.
    Each file holds single-channel ``int16`` data produced by ``get_whitenoise``
    (``duration=0.01``, sample rate 22050). Every file gets a distinct,
    increasing seed, so the mocked content is deterministic.

    Args:
        root_dir (str or path-like): Directory in which the mocked dataset is created.

    Returns:
        tuple: ``(samples, training, validation, testing)``. These are four lists of
        ``(waveform, sample_rate, genre)`` tuples in file-creation order.
        ``samples`` holds every generated file. The other three hold only the
        files whose base name appears in ``gtzan.filtered_train``,
        ``gtzan.filtered_valid`` and ``gtzan.filtered_test`` respectively.
        ``waveform`` is the written data passed through ``normalize_wav``.
    """
    mocked_samples = []
    mocked_training = []
    mocked_validation = []
    mocked_testing = []
    sample_rate = 22050

    # Membership in each split is tested once per generated file; build the
    # lookup sets a single time instead of scanning the collections repeatedly.
    test_names = set(gtzan.filtered_test)
    train_names = set(gtzan.filtered_train)
    valid_names = set(gtzan.filtered_valid)

    seed = 0
    for genre in gtzan.gtzan_genres:
        base_dir = os.path.join(root_dir, "genres", genre)
        os.makedirs(base_dir, exist_ok=True)
        for i in range(100):
            filename = f"{genre}.{i:05d}"
            path = os.path.join(base_dir, f"{filename}.wav")
            # A distinct seed per file keeps every mocked waveform different.
            data = get_whitenoise(sample_rate=sample_rate, duration=0.01, n_channels=1, dtype="int16", seed=seed)
            save_wav(path, data, sample_rate)
            sample = (normalize_wav(data), sample_rate, genre)
            mocked_samples.append(sample)
            if filename in test_names:
                mocked_testing.append(sample)
            if filename in train_names:
                mocked_training.append(sample)
            if filename in valid_names:
                mocked_validation.append(sample)
            seed += 1
    return (mocked_samples, mocked_training, mocked_validation, mocked_testing)


45
class TestGTZAN(TempDirMixin, TorchaudioTestCase):
    """Tests for ``gtzan.GTZAN`` against a mocked on-disk dataset.

    ``setUpClass`` generates the mocked dataset once per class via
    ``get_mock_dataset``. Each test loads the full dataset or one subset through
    ``gtzan.GTZAN``, passing ``root_dir`` as either ``str`` or ``pathlib.Path``.
    It then checks that the items match the expected samples, in order.
    """

    backend = "default"

    # Populated by setUpClass from the return value of get_mock_dataset().
    root_dir = None
    samples = []
    training = []
    validation = []
    testing = []

    @classmethod
    def setUpClass(cls):
        cls.root_dir = cls.get_base_temp_dir()
        cls.samples, cls.training, cls.validation, cls.testing = get_mock_dataset(cls.root_dir)

    def _test_dataset(self, dataset, expected):
        """Assert that iterating ``dataset`` yields exactly ``expected``, in order.

        Args:
            dataset: Iterable of ``(waveform, sample_rate, label)`` items.
            expected (list): Expected ``(waveform, sample_rate, label)`` tuples.
        """
        n_ite = 0
        for i, (waveform, sample_rate, label) in enumerate(dataset):
            # Report surplus items as a test failure rather than an IndexError.
            self.assertLess(i, len(expected), msg="dataset yielded more samples than expected")
            expected_waveform, expected_sample_rate, expected_label = expected[i]
            self.assertEqual(waveform, expected_waveform, atol=5e-5, rtol=1e-8)
            self.assertEqual(sample_rate, expected_sample_rate)
            self.assertEqual(label, expected_label)
            n_ite += 1
        self.assertEqual(n_ite, len(expected))

    def _test_training(self, dataset):
        """Check ``dataset`` against the mocked training split."""
        self._test_dataset(dataset, self.training)

    def _test_validation(self, dataset):
        """Check ``dataset`` against the mocked validation split."""
        self._test_dataset(dataset, self.validation)

    def _test_testing(self, dataset):
        """Check ``dataset`` against the mocked testing split."""
        self._test_dataset(dataset, self.testing)

    def test_no_subset(self):
        dataset = gtzan.GTZAN(self.root_dir)
        self._test_dataset(dataset, self.samples)

    def test_training_str(self):
        train_dataset = gtzan.GTZAN(self.root_dir, subset="training")
        self._test_training(train_dataset)

    def test_validation_str(self):
        val_dataset = gtzan.GTZAN(self.root_dir, subset="validation")
        self._test_validation(val_dataset)

    def test_testing_str(self):
        test_dataset = gtzan.GTZAN(self.root_dir, subset="testing")
        self._test_testing(test_dataset)

    def test_training_path(self):
        root_dir = Path(self.root_dir)
        train_dataset = gtzan.GTZAN(root_dir, subset="training")
        self._test_training(train_dataset)

    def test_validation_path(self):
        root_dir = Path(self.root_dir)
        val_dataset = gtzan.GTZAN(root_dir, subset="validation")
        self._test_validation(val_dataset)

    def test_testing_path(self):
        root_dir = Path(self.root_dir)
        test_dataset = gtzan.GTZAN(root_dir, subset="testing")
        self._test_testing(test_dataset)