# ljspeech_test.py
import csv
import os
3
from pathlib import Path
4

5
from torchaudio.datasets import ljspeech
6
from torchaudio_unittest.common_utils import get_whitenoise, normalize_wav, save_wav, TempDirMixin, TorchaudioTestCase
7

_TRANSCRIPTS = [
    "Test transcript 1",
    "Test transcript 2",
    "Test transcript 3",
12
    "In 1465 Sweynheim and Pannartz began printing in the monastery of Subiaco near Rome,",
Aziz's avatar
Aziz committed
13
14
15
16
17
18
]

_NORMALIZED_TRANSCRIPT = [
    "Test transcript one",
    "Test transcript two",
    "Test transcript three",
19
    "In fourteen sixty-five Sweynheim and Pannartz began printing in the monastery of Subiaco near Rome,",
Aziz's avatar
Aziz committed
20
21
22
23
24
25
26
27
28
29
30
31
32
33
]


def get_mock_dataset(root_dir):
    """Create a small mocked LJSpeech-1.1 dataset under ``root_dir``.

    Files produced::

        root_dir/LJSpeech-1.1/metadata.csv
        root_dir/LJSpeech-1.1/wavs/LJ001-0000.wav, LJ001-0001.wav, ...

    ``metadata.csv`` has one unquoted ``fileid|transcript|normalized_transcript``
    row per utterance, taken pairwise from ``_TRANSCRIPTS`` and
    ``_NORMALIZED_TRANSCRIPT``. Each wav holds one second of mono int16 white
    noise at 22050 Hz, seeded by the utterance index so it is reproducible.

    Args:
        root_dir: Directory (str or path-like) in which to create the dataset.

    Returns:
        tuple: ``(mocked_data, transcripts, normalized_transcripts)``.
        ``mocked_data[i]`` is the waveform written for utterance ``i`` after
        ``normalize_wav``. The other two items are the module-level transcript
        lists.
    """
    mocked_data = []
    base_dir = os.path.join(root_dir, "LJSpeech-1.1")
    archive_dir = os.path.join(base_dir, "wavs")
    os.makedirs(archive_dir, exist_ok=True)
    metadata_path = os.path.join(base_dir, "metadata.csv")
    sample_rate = 22050

    # newline="" lets the csv module control line endings itself.
    with open(metadata_path, mode="w", newline="") as metadata_file:
        # QUOTE_NONE with no escapechar: transcripts must not contain "|".
        metadata_writer = csv.writer(metadata_file, delimiter="|", quoting=csv.QUOTE_NONE)
        for i, (transcript, normalized_transcript) in enumerate(zip(_TRANSCRIPTS, _NORMALIZED_TRANSCRIPT)):
            fileid = f"LJ001-{i:04d}"
            metadata_writer.writerow([fileid, transcript, normalized_transcript])
            path = os.path.join(archive_dir, fileid + ".wav")
            # Seeding by index makes every run write identical audio.
            data = get_whitenoise(sample_rate=sample_rate, duration=1, n_channels=1, dtype="int16", seed=i)
            save_wav(path, data, sample_rate)
            # Tests compare the loader's waveform against this normalized copy.
            mocked_data.append(normalize_wav(data))
    return mocked_data, _TRANSCRIPTS, _NORMALIZED_TRANSCRIPT


class TestLJSpeech(TempDirMixin, TorchaudioTestCase):
    """Check ``ljspeech.LJSPEECH`` against a mocked LJSpeech-1.1 directory."""

    # Rate at which get_mock_dataset writes the mocked wav files; keep in sync.
    _EXPECTED_SAMPLE_RATE = 22050

    root_dir = None
    # Populated once in setUpClass from get_mock_dataset's return value.
    data, _transcripts, _normalized_transcript = [], [], []

    @classmethod
    def setUpClass(cls):
        """Create the mocked dataset once, in the class's base temp dir."""
        super().setUpClass()
        cls.root_dir = cls.get_base_temp_dir()
        cls.data, cls._transcripts, cls._normalized_transcript = get_mock_dataset(cls.root_dir)

    def _test_ljspeech(self, dataset):
        """Iterate ``dataset`` and compare every item with the mocked data.

        Each item is unpacked as
        ``(waveform, sample_rate, transcript, normalized_transcript)``.
        The method also checks that exactly ``len(self.data)`` items are yielded.
        """
        n_ite = 0
        for i, (waveform, sample_rate, transcript, normalized_transcript) in enumerate(dataset):
            self.assertEqual(self.data[i], waveform, atol=5e-5, rtol=1e-8)
            # Compare against the known rate; comparing the loop variable with
            # itself (as before) could never fail.
            self.assertEqual(sample_rate, self._EXPECTED_SAMPLE_RATE)
            self.assertEqual(transcript, self._transcripts[i])
            self.assertEqual(normalized_transcript, self._normalized_transcript[i])
            n_ite += 1
        self.assertEqual(n_ite, len(self.data))

    def test_ljspeech_str(self):
        """Loading works when the root is given as a ``str``."""
        dataset = ljspeech.LJSPEECH(self.root_dir)
        self._test_ljspeech(dataset)

    def test_ljspeech_path(self):
        """Loading works when the root is given as a ``pathlib.Path``."""
        dataset = ljspeech.LJSPEECH(Path(self.root_dir))
        self._test_ljspeech(dataset)