cmuarctic_test.py 2.49 KB
Newer Older
1
import os
2
from pathlib import Path
3
4

from torchaudio.datasets import cmuarctic
5
from torchaudio_unittest.common_utils import get_whitenoise, normalize_wav, save_wav, TempDirMixin, TorchaudioTestCase
6
7


8
9
10
11
12
13
def get_mock_dataset(root_dir):
    """Create a mocked CMU ARCTIC dataset on disk and return the expected samples.

    Writes ten white-noise utterances (``arctic_a0000`` .. ``arctic_a0004`` and
    ``arctic_b0000`` .. ``arctic_b0004``) as WAV files under
    ``<root_dir>/ARCTIC/cmu_us_aew_arctic/wav``. It also writes the matching
    transcript index ``etc/txt.done.data``, one ``( <id> "<transcript>" )``
    line per utterance.

    Args:
        root_dir: Directory in which the mocked dataset is created.

    Returns:
        A list of ``(waveform, sample_rate, transcript, utterance_id)`` tuples.
        They are in the same order as the transcript lines are written.
        ``waveform`` is the generated audio passed through ``normalize_wav``.
        ``utterance_id`` is the part after ``"arctic_"`` (e.g. ``"a0000"``).
    """
    mocked_data = []
    sample_rate = 16000
    transcript = "This is a test transcript."

    base_dir = os.path.join(root_dir, "ARCTIC", "cmu_us_aew_arctic")
    txt_dir = os.path.join(base_dir, "etc")
    os.makedirs(txt_dir, exist_ok=True)
    txt_file = os.path.join(txt_dir, "txt.done.data")
    audio_dir = os.path.join(base_dir, "wav")
    os.makedirs(audio_dir, exist_ok=True)

    # A distinct, deterministic seed per utterance keeps every waveform different
    # yet reproducible across runs.
    seed = 42
    # Explicit encoding so the index file does not depend on the platform locale.
    with open(txt_file, "w", encoding="utf-8") as txt:
        for c in ["a", "b"]:
            for i in range(5):
                utterance_id = f"arctic_{c}{i:04d}"
                path = os.path.join(audio_dir, f"{utterance_id}.wav")
                data = get_whitenoise(
                    sample_rate=sample_rate,
                    duration=3,
                    n_channels=1,
                    dtype="int16",
                    seed=seed,
                )
                save_wav(path, data, sample_rate)
                sample = (
                    normalize_wav(data),
                    sample_rate,
                    transcript,
                    # The expected ID drops the "arctic_" prefix, e.g. "a0000".
                    utterance_id.split("_")[1],
                )
                mocked_data.append(sample)
                txt.write(f'( {utterance_id} "{transcript}" )\n')
                seed += 1
    return mocked_data


49
50
51
52
53
54
55
56
class TestCMUARCTIC(TempDirMixin, TorchaudioTestCase):
    """Check that ``cmuarctic.CMUARCTIC`` loads a mocked dataset correctly.

    The mocked dataset is generated once per class in ``setUpClass``.
    Every test method shares it.
    """

    # Populated once in setUpClass and shared by all test methods.
    root_dir = None
    samples = []

    @classmethod
    def setUpClass(cls):
        # Chain to the base classes so any class-level fixtures they define still run.
        super().setUpClass()
        cls.root_dir = cls.get_base_temp_dir()
        cls.samples = get_mock_dataset(cls.root_dir)

    def _test_cmuarctic(self, dataset):
        """Compare each item yielded by ``dataset``, in order, against ``self.samples``.

        Also verify that the dataset yields exactly ``len(self.samples)`` items.
        Uses the test-case assertion methods rather than bare ``assert``, so the
        checks still run under ``python -O``.
        """
        n_ite = 0
        for i, (waveform, sample_rate, transcript, utterance_id) in enumerate(dataset):
            # Fail with a clear message instead of an IndexError if the dataset
            # yields more items than were mocked.
            self.assertLess(i, len(self.samples), "dataset yielded more items than were mocked")
            expected_waveform, expected_sample_rate, expected_transcript, expected_utterance_id = self.samples[i]
            self.assertEqual(sample_rate, expected_sample_rate)
            self.assertEqual(transcript, expected_transcript)
            self.assertEqual(utterance_id, expected_utterance_id)
            # Tolerance presumably absorbs the int16 save/load round trip — keep as in the original.
            self.assertEqual(expected_waveform, waveform, atol=5e-5, rtol=1e-8)
            n_ite += 1
        self.assertEqual(n_ite, len(self.samples))

    def test_cmuarctic_str(self):
        """The dataset accepts its root as a ``str``."""
        dataset = cmuarctic.CMUARCTIC(self.root_dir)
        self._test_cmuarctic(dataset)

    def test_cmuarctic_path(self):
        """The dataset accepts its root as a ``pathlib.Path``."""
        dataset = cmuarctic.CMUARCTIC(Path(self.root_dir))
        self._test_cmuarctic(dataset)