"vscode:/vscode.git/clone" did not exist on "38f0ed63f322b73e81c7d632272a2493c287298f"
speechcommands_test.py 4.98 KB
Newer Older
1
import os
from pathlib import Path

from torchaudio.datasets import speechcommands
from torchaudio_unittest.common_utils import get_whitenoise, normalize_wav, save_wav, TempDirMixin, TorchaudioTestCase

_LABELS = [
    "bed",
    "bird",
    "cat",
    "dog",
    "down",
    "eight",
    "five",
    "follow",
    "forward",
    "four",
    "go",
    "happy",
    "house",
    "learn",
    "left",
    "marvin",
    "nine",
    "no",
    "off",
    "on",
    "one",
    "right",
    "seven",
    "sheila",
    "six",
    "stop",
    "three",
    "tree",
    "two",
    "up",
    "visual",
    "wow",
    "yes",
    "zero",
]


def get_mock_dataset(dataset_dir):
    """
    dataset_dir: directory to the mocked dataset
    """
    mocked_samples = []
    mocked_train_samples = []
    mocked_valid_samples = []
    mocked_test_samples = []
    os.makedirs(dataset_dir, exist_ok=True)
    sample_rate = 16000  # 16kHz sample rate
    seed = 0
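    # The SPEECHCOMMANDS dataset reads "validation_list.txt" and "testing_list.txt"
    # to decide which files belong to the validation and testing subsets, so the
    # mock writes the same two files.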
    valid_file = os.path.join(dataset_dir, "validation_list.txt")
    test_file = os.path.join(dataset_dir, "testing_list.txt")
    with open(valid_file, "w") as valid, open(test_file, "w") as test:
        for label in _LABELS:
            path = os.path.join(dataset_dir, label)
            os.makedirs(path, exist_ok=True)
            for j in range(6):
                # generate hash ID for speaker
                speaker = "{:08x}".format(j)

                for utterance in range(3):
                    filename = f"{speaker}{speechcommands.HASH_DIVIDER}{utterance}.wav"
                    file_path = os.path.join(path, filename)
                    seed += 1
                    data = get_whitenoise(
                        sample_rate=sample_rate,
                        duration=0.01,
                        n_channels=1,
                        dtype="int16",
                        seed=seed,
                    )
                    save_wav(file_path, data, sample_rate)
                    sample = (
                        normalize_wav(data),
                        sample_rate,
                        label,
                        speaker,
                        utterance,
                    )
                    mocked_samples.append(sample)
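                    # Assign speakers 0-1 to training, 2-3 to validation, and 4-5 to testing.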
                    if j < 2:
                        mocked_train_samples.append(sample)
                    elif j < 4:
                        valid.write(f"{label}/{filename}\n")
                        mocked_valid_samples.append(sample)
                    elif j < 6:
                        test.write(f"{label}/{filename}\n")
                        mocked_test_samples.append(sample)
    return mocked_samples, mocked_train_samples, mocked_valid_samples, mocked_test_samples


class TestSpeechCommands(TempDirMixin, TorchaudioTestCase):
    backend = "default"

    root_dir = None
    samples = []
    train_samples = []
    valid_samples = []
    test_samples = []

    @classmethod
    def setUpClass(cls):
        cls.root_dir = cls.get_base_temp_dir()
        # Lay out the mock data where SPEECHCOMMANDS expects the extracted archive.
        dataset_dir = os.path.join(cls.root_dir, speechcommands.FOLDER_IN_ARCHIVE, speechcommands.URL)
        cls.samples, cls.train_samples, cls.valid_samples, cls.test_samples = get_mock_dataset(dataset_dir)

    def _testSpeechCommands(self, dataset, data_samples):
        num_samples = 0
        for i, (data, sample_rate, label, speaker_id, utterance_number) in enumerate(dataset):
            # Compare each yielded item against the corresponding mocked reference sample.
            self.assertEqual(data, data_samples[i][0], atol=5e-5, rtol=1e-8)
            assert sample_rate == data_samples[i][1]
            assert label == data_samples[i][2]
            assert speaker_id == data_samples[i][3]
            assert utterance_number == data_samples[i][4]
            num_samples += 1

        assert num_samples == len(data_samples)

    def testSpeechCommands_str(self):
        dataset = speechcommands.SPEECHCOMMANDS(self.root_dir)
        self._testSpeechCommands(dataset, self.samples)

    def testSpeechCommands_path(self):
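        # Same check as above, but with the root passed as a pathlib.Path.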
        dataset = speechcommands.SPEECHCOMMANDS(Path(self.root_dir))
        self._testSpeechCommands(dataset, self.samples)

    def testSpeechCommandsSubsetTrain(self):
        dataset = speechcommands.SPEECHCOMMANDS(self.root_dir, subset="training")
        self._testSpeechCommands(dataset, self.train_samples)

    def testSpeechCommandsSubsetValid(self):
        dataset = speechcommands.SPEECHCOMMANDS(self.root_dir, subset="validation")
        self._testSpeechCommands(dataset, self.valid_samples)

    def testSpeechCommandsSubsetTest(self):
        dataset = speechcommands.SPEECHCOMMANDS(self.root_dir, subset="testing")
        self._testSpeechCommands(dataset, self.test_samples)

    def testSpeechCommandsSum(self):
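        # The training, validation, and testing subsets should partition the full dataset.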
        dataset_all = speechcommands.SPEECHCOMMANDS(self.root_dir)
        dataset_train = speechcommands.SPEECHCOMMANDS(self.root_dir, subset="training")
        dataset_valid = speechcommands.SPEECHCOMMANDS(self.root_dir, subset="validation")
        dataset_test = speechcommands.SPEECHCOMMANDS(self.root_dir, subset="testing")

        assert len(dataset_train) + len(dataset_valid) + len(dataset_test) == len(dataset_all)