import os
from pathlib import Path

from torchaudio_unittest.common_utils import (
    TempDirMixin,
    TorchaudioTestCase,
    get_whitenoise,
    save_wav,
    normalize_wav,
)

from torchaudio.datasets.libritts import LIBRITTS

_UTTERANCE_IDS = [
    [19, 198, '000000', '000000'],
    [26, 495, '000004', '000000'],
]
_ORIGINAL_TEXT = 'this is the original text.'
_NORMALIZED_TEXT = 'this is the normalized text.'


def get_mock_dataset(root_dir):
    """
    root_dir: directory to the mocked dataset
    """
    mocked_data = []
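    # Mirror the LibriTTS on-disk layout: LibriTTS/train-clean-100/<speaker_id>/<chapter_id>/<files>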
    base_dir = os.path.join(root_dir, 'LibriTTS', 'train-clean-100')
    for i, utterance_id in enumerate(_UTTERANCE_IDS):
        filename = f'{"_".join(str(u) for u in utterance_id)}.wav'
        file_dir = os.path.join(base_dir, str(utterance_id[0]), str(utterance_id[1]))
        os.makedirs(file_dir, exist_ok=True)
        path = os.path.join(file_dir, filename)

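        # Generate deterministic white-noise audio for this utterance (seeded by its index) and save it as WAV.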
        data = get_whitenoise(sample_rate=24000, duration=2, n_channels=1, dtype='int16', seed=i)
        save_wav(path, data, 24000)
        mocked_data.append(normalize_wav(data))

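        # Write the matching .original.txt and .normalized.txt transcripts.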
        original_text_filename = f'{"_".join(str(u) for u in utterance_id)}.original.txt'
        path_original = os.path.join(file_dir, original_text_filename)
        with open(path_original, 'w') as file_:
            file_.write(_ORIGINAL_TEXT)

        normalized_text_filename = f'{"_".join(str(u) for u in utterance_id)}.normalized.txt'
        path_normalized = os.path.join(file_dir, normalized_text_filename)
        with open(path_normalized, 'w') as file_:
            file_.write(_NORMALIZED_TEXT)
    return mocked_data, _UTTERANCE_IDS, _ORIGINAL_TEXT, _NORMALIZED_TEXT


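# Tests LIBRITTS string- and Path-based loading against the mocked dataset above.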
class TestLibriTTS(TempDirMixin, TorchaudioTestCase):
    backend = 'default'

    root_dir = None
    data = []
    _utterance_ids, _original_text, _normalized_text = [], [], []

    @classmethod
    def setUpClass(cls):
        cls.root_dir = cls.get_base_temp_dir()
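        # Build the mocked dataset once and keep the expected values for the assertions below.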
        cls.data, cls._utterance_ids, cls._original_text, cls._normalized_text = get_mock_dataset(cls.root_dir)

    def _test_libritts(self, dataset):
        n_iters = 0
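        # Each LIBRITTS item is a tuple:
        # (waveform, sample_rate, original_text, normalized_text, speaker_id, chapter_id, utterance_id)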
        for i, (waveform,
                sample_rate,
                original_text,
                normalized_text,
                speaker_id,
                chapter_id,
                utterance_id) in enumerate(dataset):
            expected_ids = self._utterance_ids[i]
            expected_data = self.data[i]
            self.assertEqual(expected_data, waveform, atol=5e-5, rtol=1e-8)
            assert sample_rate == 24000
            assert speaker_id == expected_ids[0]
            assert chapter_id == expected_ids[1]
            assert original_text == self._original_text
            assert normalized_text == self._normalized_text
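            # The utterance ID is the underscore-joined file stem, e.g. "19_198_000000_000000".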
            assert utterance_id == "_".join(str(u) for u in expected_ids[-4:])
            n_iters += 1
        assert n_iters == len(self._utterance_ids)

    def test_libritts_str(self):
        dataset = LIBRITTS(self.root_dir)
        self._test_libritts(dataset)

    def test_libritts_path(self):
        dataset = LIBRITTS(Path(self.root_dir))
        self._test_libritts(dataset)