# ljspeech_test.py: tests for torchaudio.datasets.ljspeech.LJSPEECH.
import csv
import os
from pathlib import Path

from torchaudio.datasets import ljspeech
from torchaudio_unittest.common_utils import (
    get_whitenoise,
    normalize_wav,
    save_wav,
    TempDirMixin,
    TorchaudioTestCase,
)

# Raw transcripts written to the second column of the mocked metadata.csv.
# They must not contain "|" (the column delimiter): the metadata writer uses
# csv.QUOTE_NONE without an escapechar, so such a field would make it raise.
_TRANSCRIPTS = [
    "Test transcript 1",
    "Test transcript 2",
    "Test transcript 3",
    "In 1465 Sweynheim and Pannartz began printing in the monastery of Subiaco near Rome,",
]

# Normalized counterparts written to the third column: _NORMALIZED_TRANSCRIPT[i]
# is _TRANSCRIPTS[i] with its numbers spelled out as words. Same "|" restriction.
_NORMALIZED_TRANSCRIPT = [
    "Test transcript one",
    "Test transcript two",
    "Test transcript three",
    "In fourteen sixty-five Sweynheim and Pannartz began printing in the monastery of Subiaco near Rome,",
]


def get_mock_dataset(root_dir):
    """Write a small mocked LJSpeech-1.1 dataset under ``root_dir``.

    Creates ``<root_dir>/LJSpeech-1.1/metadata.csv`` with one unquoted,
    ``|``-delimited row ``fileid|transcript|normalized_transcript`` per entry of
    ``_TRANSCRIPTS`` / ``_NORMALIZED_TRANSCRIPT``. For each row it also writes
    ``<root_dir>/LJSpeech-1.1/wavs/<fileid>.wav``, which holds mono int16
    white noise at 22050 Hz, seeded by the row index.

    Args:
        root_dir: path of the directory in which to create the mocked dataset.

    Returns:
        Tuple ``(mocked_data, transcripts, normalized_transcripts)``.
        ``mocked_data[i]`` is the waveform written for row ``i``, passed through
        ``normalize_wav``. The two transcript lists are the module-level
        ``_TRANSCRIPTS`` and ``_NORMALIZED_TRANSCRIPT``.
    """
    mocked_data = []
    base_dir = os.path.join(root_dir, "LJSpeech-1.1")
    archive_dir = os.path.join(base_dir, "wavs")
    # Also creates base_dir, where metadata.csv goes.
    os.makedirs(archive_dir, exist_ok=True)
    metadata_path = os.path.join(base_dir, "metadata.csv")
    sample_rate = 22050

    # newline="" lets the csv module control line endings, as its docs require.
    with open(metadata_path, mode="w", newline="") as metadata_file:
        metadata_writer = csv.writer(metadata_file, delimiter="|", quoting=csv.QUOTE_NONE)
        for i, (transcript, normalized_transcript) in enumerate(zip(_TRANSCRIPTS, _NORMALIZED_TRANSCRIPT)):
            # Zero-padded ids in the LJSpeech style: LJ001-0000, LJ001-0001, ...
            fileid = f"LJ001-{i:04d}"
            metadata_writer.writerow([fileid, transcript, normalized_transcript])
            filename = fileid + ".wav"
            path = os.path.join(archive_dir, filename)
            # Seeding with the row index gives each file distinct but reproducible audio.
            data = get_whitenoise(sample_rate=sample_rate, duration=1, n_channels=1, dtype="int16", seed=i)
            save_wav(path, data, sample_rate)
            mocked_data.append(normalize_wav(data))
    return mocked_data, _TRANSCRIPTS, _NORMALIZED_TRANSCRIPT


class TestLJSpeech(TempDirMixin, TorchaudioTestCase):
    """Round-trip tests for ``ljspeech.LJSPEECH`` on a mocked LJSpeech-1.1 tree.

    ``setUpClass`` writes the mocked dataset once per class. Each test builds
    the dataset from ``root_dir``, given as ``str`` or ``pathlib.Path``, and
    checks every item it yields against the values that were written.
    """

    backend = "default"

    # Must match the sample rate get_mock_dataset uses when writing the WAV files.
    _sample_rate = 22050

    root_dir = None
    # Filled in by setUpClass with the values returned by get_mock_dataset.
    data, _transcripts, _normalized_transcript = [], [], []

    @classmethod
    def setUpClass(cls):
        # unittest runs a base class's setUpClass only when it is called explicitly.
        super().setUpClass()
        cls.root_dir = cls.get_base_temp_dir()
        cls.data, cls._transcripts, cls._normalized_transcript = get_mock_dataset(cls.root_dir)

    def _test_ljspeech(self, dataset):
        """Assert that ``dataset`` yields exactly the mocked samples, in order.

        Each item is unpacked as
        ``(waveform, sample_rate, transcript, normalized_transcript)``. Each
        field is compared with the i-th mocked sample. The total number of
        items must equal the number of mocked samples.
        """
        n_ite = 0
        for i, (waveform, sample_rate, transcript, normalized_transcript) in enumerate(dataset):
            expected_transcript = self._transcripts[i]
            expected_normalized_transcript = self._normalized_transcript[i]
            expected_data = self.data[i]
            # The audio round-trips through an int16 WAV file, so compare with a
            # small tolerance rather than exactly.
            self.assertEqual(expected_data, waveform, atol=5e-5, rtol=1e-8)
            # Compare with the rate the files were written at.
            assert sample_rate == self._sample_rate
            assert transcript == expected_transcript
            assert normalized_transcript == expected_normalized_transcript
            n_ite += 1
        assert n_ite == len(self.data)

    def test_ljspeech_str(self):
        """LJSPEECH accepts the dataset root as a ``str``."""
        dataset = ljspeech.LJSPEECH(self.root_dir)
        self._test_ljspeech(dataset)

    def test_ljspeech_path(self):
        """LJSPEECH accepts the dataset root as a ``pathlib.Path``."""
        dataset = ljspeech.LJSPEECH(Path(self.root_dir))
        self._test_ljspeech(dataset)