libritts_test.py 3.08 KB
Newer Older
moto's avatar
moto committed
1
import os
2
from pathlib import Path
moto's avatar
moto committed
3

4
from torchaudio.datasets.libritts import LIBRITTS
5
from torchaudio_unittest.common_utils import (
moto's avatar
moto committed
6
7
    get_whitenoise,
    normalize_wav,
8
9
10
    save_wav,
    TempDirMixin,
    TorchaudioTestCase,
moto's avatar
moto committed
11
12
)

Aziz's avatar
Aziz committed
13
# Identifiers of the mocked utterances. The first two fields of each entry are
# compared against the dataset's speaker id and chapter id; all four fields,
# joined with "_", form the file stem and the expected utterance id.
_UTTERANCE_IDS = [
    [19, 198, "000000", "000000"],
    [26, 495, "000004", "000000"],
]

# Contents written to every ``*.original.txt`` / ``*.normalized.txt`` file of
# the mocked dataset.
_ORIGINAL_TEXT = "this is the original text."
_NORMALIZED_TEXT = "this is the normalized text."
Aziz's avatar
Aziz committed
19
20
21
22
23
24
25


def get_mock_dataset(root_dir):
    """Write a mocked LibriTTS ``train-clean-100`` subset under ``root_dir``.

    For every entry of ``_UTTERANCE_IDS`` this creates the directory
    ``<root_dir>/LibriTTS/train-clean-100/<id[0]>/<id[1]>/`` and writes three
    files sharing the stem ``"_".join(id)``: a ``.wav`` file of generated white
    noise, a ``.original.txt`` file and a ``.normalized.txt`` file.

    Args:
        root_dir: directory to the mocked dataset

    Returns:
        Tuple ``(mocked_data, _UTTERANCE_IDS, _ORIGINAL_TEXT, _NORMALIZED_TEXT)``
        where ``mocked_data`` holds ``normalize_wav(data)`` for each generated
        waveform, in ``_UTTERANCE_IDS`` order.
    """
    sample_rate = 24000
    mocked_data = []
    base_dir = os.path.join(root_dir, "LibriTTS", "train-clean-100")

    for i, utterance_id in enumerate(_UTTERANCE_IDS):
        # The audio file and both transcripts share this stem; build it once.
        stem = "_".join(str(u) for u in utterance_id)
        file_dir = os.path.join(base_dir, str(utterance_id[0]), str(utterance_id[1]))
        os.makedirs(file_dir, exist_ok=True)

        # The entry index is used as the seed so each utterance gets its own
        # noise (assuming get_whitenoise is deterministic per seed).
        data = get_whitenoise(sample_rate=sample_rate, duration=2, n_channels=1, dtype="int16", seed=i)
        save_wav(os.path.join(file_dir, f"{stem}.wav"), data, sample_rate)
        # Keep the normalized waveform as the reference the tests compare to.
        mocked_data.append(normalize_wav(data))

        # Original transcript first, then the normalized one.
        for suffix, text in (
            (".original.txt", _ORIGINAL_TEXT),
            (".normalized.txt", _NORMALIZED_TEXT),
        ):
            with open(os.path.join(file_dir, f"{stem}{suffix}"), "w") as file_:
                file_.write(text)

    return mocked_data, _UTTERANCE_IDS, _ORIGINAL_TEXT, _NORMALIZED_TEXT

moto's avatar
moto committed
48
49

class TestLibriTTS(TempDirMixin, TorchaudioTestCase):
    """Check ``LIBRITTS`` against a small mocked dataset written to a temp directory."""

    backend = "default"

    # Placeholders; the real values are assigned once in ``setUpClass``.
    root_dir = None
    data = []
    _utterance_ids, _original_text, _normalized_text = [], [], []

    @classmethod
    def setUpClass(cls):
        """Create the mocked dataset once and keep its expected values on the class."""
        cls.root_dir = cls.get_base_temp_dir()
        cls.data, cls._utterance_ids, cls._original_text, cls._normalized_text = get_mock_dataset(cls.root_dir)

    def _test_libritts(self, dataset):
        """Compare every sample yielded by ``dataset`` with the mocked data.

        Samples are matched by position: the i-th sample is checked against
        ``self.data[i]`` and ``self._utterance_ids[i]``. The total number of
        samples must equal the number of mocked utterances.
        """
        n_ites = 0
        for i, (
            waveform,
            sample_rate,
            original_text,
            normalized_text,
            speaker_id,
            chapter_id,
            utterance_id,
        ) in enumerate(dataset):
            expected_ids = self._utterance_ids[i]
            expected_data = self.data[i]
            self.assertEqual(expected_data, waveform, atol=5e-5, rtol=1e-8)
            assert sample_rate == 24000
            assert speaker_id == expected_ids[0]
            assert chapter_id == expected_ids[1]
            assert original_text == self._original_text
            assert normalized_text == self._normalized_text
            # Each mocked id has four fields, so this joins all of them.
            assert utterance_id == f'{"_".join(str(u) for u in expected_ids[-4:])}'
            n_ites += 1
        # Catches a dataset that yields too few (or no) samples.
        assert n_ites == len(self._utterance_ids)

    def test_libritts_str(self):
        """The dataset root may be given as a ``str``."""
        dataset = LIBRITTS(self.root_dir)
        self._test_libritts(dataset)

    def test_libritts_path(self):
        """The dataset root may be given as a ``pathlib.Path``."""
        dataset = LIBRITTS(Path(self.root_dir))
        self._test_libritts(dataset)