tedlium_test.py 5.15 KB
Newer Older
1
import os
Vincent QB's avatar
Vincent QB committed
2
import platform
3
from pathlib import Path
4

Aziz's avatar
Aziz committed
5
from torchaudio.datasets import tedlium
6
from torchaudio_unittest.common_utils import get_whitenoise, save_wav, skipIfNoSox, TempDirMixin, TorchaudioTestCase
Aziz's avatar
Aziz committed
7

8
# Used to generate a unique utterance for each dummy audio file
Aziz's avatar
Aziz committed
9
_UTTERANCES = [
10
11
12
13
14
15
16
    "AaronHuey_2010X 1 AaronHuey_2010X 0.0 2.0 <o,f0,female> script1\n",
    "AaronHuey_2010X 1 AaronHuey_2010X 2.0 4.0 <o,f0,female> script2\n",
    "AaronHuey_2010X 1 AaronHuey_2010X 4.0 6.0 <o,f0,female> script3\n",
    "AaronHuey_2010X 1 AaronHuey_2010X 6.0 8.0 <o,f0,female> script4\n",
    "AaronHuey_2010X 1 AaronHuey_2010X 8.0 10.0 <o,f0,female> script5\n",
]

Aziz's avatar
Aziz committed
17
_PHONEME = [
18
19
20
21
22
23
24
25
26
27
    "a AH",
    "a(2) EY",
    "aachen AA K AH N",
    "aad AE D",
    "aaden EY D AH N",
    "aadmi AE D M IY",
    "aae EY EY",
]


Aziz's avatar
Aziz committed
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
def get_mock_dataset(dataset_dir):
    """Create a mocked TEDLIUM directory tree and return the expected samples.

    For each of "release1", "release2" and "release3" this writes, below the
    release directory:

    * ``sph/<release>.sph`` -- 10 seconds of single-channel white noise,
    * ``stm/<release>.stm`` -- the transcript lines from ``_UTTERANCES``,
    * ``<release>.dic`` -- the phoneme dictionary lines from ``_PHONEME``.

    Args:
        dataset_dir: directory of the mocked dataset; created if it is missing.

    Returns:
        A dict mapping each release name to a list with one tuple per entry
        of ``_UTTERANCES``:
        ``(waveform_slice, sample_rate, transcript, talk_id, speaker_id, identifier)``.
    """
    mocked_samples = {}
    os.makedirs(dataset_dir, exist_ok=True)
    sample_rate = 16000  # 16kHz

    # The enumerate index doubles as the noise seed, so each release gets
    # different audio.
    for seed, release in enumerate(["release1", "release2", "release3"]):
        data = get_whitenoise(sample_rate=sample_rate, duration=10.00, n_channels=1, dtype="float32", seed=seed)

        # Releases 1 and 2 nest their files under the "subset" entry of the
        # release config, release 3 under its "data_path" entry.
        config = tedlium._RELEASE_CONFIGS[release]
        subfolder_key = "subset" if release in ["release1", "release2"] else "data_path"
        release_dir = os.path.join(dataset_dir, config["folder_in_archive"], config[subfolder_key])
        stm_dir = os.path.join(release_dir, "stm")  # Subfolder for transcripts
        sph_dir = os.path.join(release_dir, "sph")  # Subfolder for audio files
        os.makedirs(release_dir, exist_ok=True)
        os.makedirs(stm_dir, exist_ok=True)
        os.makedirs(sph_dir, exist_ok=True)

        # The audio is written with save_wav even though the file name ends in ".sph".
        save_wav(os.path.join(sph_dir, f"{release}.sph"), data, sample_rate)

        with open(os.path.join(stm_dir, f"{release}.stm"), "w") as f:
            f.write("".join(_UTTERANCES))

        with open(os.path.join(release_dir, f"{release}.dic"), "w") as f:
            f.write("\n".join(_PHONEME))

        # Create a samples list to compare with
        expected = []
        for utterance in _UTTERANCES:
            talk_id, _, speaker_id, start_time, end_time, identifier, transcript = utterance.split(" ", 6)
            # Scale to frames BEFORE truncating to int; truncating first would
            # round fractional timestamps down to whole seconds.
            # NOTE(review): presumably this is also how the dataset derives its
            # frame offsets -- confirm against torchaudio.datasets.tedlium.
            start_frame = int(float(start_time) * sample_rate)
            end_frame = int(float(end_time) * sample_rate)
            expected.append(
                (
                    data[:, start_frame:end_frame],
                    sample_rate,
                    transcript,
                    talk_id,
                    speaker_id,
                    identifier,
                )
            )
        mocked_samples[release] = expected
    return mocked_samples


Vincent QB's avatar
Vincent QB committed
87
class Tedlium(TempDirMixin):
    """Shared TEDLIUM dataset checks, run against a mocked dataset on disk.

    The concrete test classes in this file combine this mixin with
    ``TorchaudioTestCase`` and set a ``backend`` attribute.
    """

    root_dir = None  # directory containing the mocked releases; set in setUpClass
    samples = {}  # release name -> list of expected sample tuples; set in setUpClass

    @classmethod
    def setUpClass(cls):
        """Create the mocked dataset once and remember the expected samples."""
        cls.root_dir = os.path.join(cls.get_base_temp_dir(), "tedlium")
        cls.samples = get_mock_dataset(cls.root_dir)

    def _test_tedlium(self, dataset, release):
        """Check every item of ``dataset`` and its phoneme dict against the mock.

        Args:
            dataset: the ``tedlium.TEDLIUM`` instance under test.
            release: release name used to look up the expected samples and
                the ``<release>.dic`` file written by ``get_mock_dataset``.
        """
        expected_samples = self.samples[release]
        num_samples = 0
        for i, (data, sample_rate, transcript, talk_id, speaker_id, identifier) in enumerate(dataset):
            expected = expected_samples[i]
            self.assertEqual(data, expected[0], atol=5e-5, rtol=1e-8)
            assert sample_rate == expected[1]
            assert transcript == expected[2]
            assert talk_id == expected[3]
            assert speaker_id == expected[4]
            assert identifier == expected[5]
            num_samples += 1

        # Guards against the dataset yielding fewer items than were mocked.
        assert num_samples == len(expected_samples)

        # Point the dataset at the dictionary file the mock wrote before
        # reading phoneme_dict.
        dataset._dict_path = os.path.join(dataset._path, f"{release}.dic")
        phoneme_dict = dataset.phoneme_dict
        phonemes = [f"{key} {' '.join(value)}" for key, value in phoneme_dict.items()]
        assert phonemes == _PHONEME

    def test_tedlium_release1_str(self):
        """Release 1, with the root directory given as a ``str``."""
        release = "release1"
        dataset = tedlium.TEDLIUM(self.root_dir, release=release)
        self._test_tedlium(dataset, release)

    def test_tedlium_release1_path(self):
        """Release 1, with the root directory given as a ``pathlib.Path``."""
        release = "release1"
        dataset = tedlium.TEDLIUM(Path(self.root_dir), release=release)
        self._test_tedlium(dataset, release)

    def test_tedlium_release2(self):
        """Release 2, with the root directory given as a ``str``."""
        release = "release2"
        dataset = tedlium.TEDLIUM(self.root_dir, release=release)
        self._test_tedlium(dataset, release)

    def test_tedlium_release3(self):
        """Release 3, with the root directory given as a ``str``."""
        release = "release3"
        dataset = tedlium.TEDLIUM(self.root_dir, release=release)
        self._test_tedlium(dataset, release)
Vincent QB's avatar
Vincent QB committed
134
135
136
137
138
139
140


class TestTedliumSoundfile(Tedlium, TorchaudioTestCase):
    """Run the ``Tedlium`` checks with ``backend`` set to "soundfile".

    NOTE(review): presumably ``TorchaudioTestCase`` reads ``backend`` to select
    the audio I/O backend -- confirm.
    """

    backend = "soundfile"


# The sox_io variant is only defined on non-Windows platforms.
# NOTE(review): presumably because the sox_io backend is unavailable on
# Windows -- confirm.
if platform.system() != "Windows":

    @skipIfNoSox
    class TestTedliumSoxIO(Tedlium, TorchaudioTestCase):
        """Run the ``Tedlium`` checks with ``backend`` set to "sox_io"."""

        backend = "sox_io"