tedlium_test.py 5.18 KB
Newer Older
1
import os
Vincent QB's avatar
Vincent QB committed
2
import platform
3
from pathlib import Path
4

Aziz's avatar
Aziz committed
5
from torchaudio.datasets import tedlium
6
7
8
9
10
11
12
from torchaudio_unittest.common_utils import (
    get_whitenoise,
    save_wav,
    skipIfNoSox,
    TempDirMixin,
    TorchaudioTestCase,
)
Aziz's avatar
Aziz committed
13

14
# Transcript lines written (all five of them, identically) to every mocked ``.stm`` file.
# Each line is space-separated as
#   "<talk_id> <ignored> <speaker_id> <start_time> <end_time> <identifier> <transcript>\n"
# and is parsed with ``split(" ", 6)``, so the transcript keeps its trailing newline.
_UTTERANCES = [
    "AaronHuey_2010X 1 AaronHuey_2010X 0.0 2.0 <o,f0,female> script1\n",
    "AaronHuey_2010X 1 AaronHuey_2010X 2.0 4.0 <o,f0,female> script2\n",
    "AaronHuey_2010X 1 AaronHuey_2010X 4.0 6.0 <o,f0,female> script3\n",
    "AaronHuey_2010X 1 AaronHuey_2010X 6.0 8.0 <o,f0,female> script4\n",
    "AaronHuey_2010X 1 AaronHuey_2010X 8.0 10.0 <o,f0,female> script5\n",
]

Aziz's avatar
Aziz committed
23
# Pronunciation-dictionary lines written (newline-joined) to every mocked
# ``<release>.dic`` file. Each entry is "<word> <PHONE> [<PHONE> ...]"; the dataset's
# parsed dictionary must re-serialize to exactly this list. "a(2)" is listed as its
# own entry (presumably an alternate pronunciation of "a").
_PHONEME = [
    "a AH",
    "a(2) EY",
    "aachen AA K AH N",
    "aad AE D",
    "aaden EY D AH N",
    "aadmi AE D M IY",
    "aae EY EY",
]


Aziz's avatar
Aziz committed
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
def get_mock_dataset(dataset_dir):
    """Create a mocked TED-LIUM dataset for release1, release2 and release3.

    For every release, the following is written under the release directory:
      * ``sph/<release>.sph``: 10 seconds of mono white noise (written with ``save_wav``),
      * ``stm/<release>.stm``: the transcript lines in ``_UTTERANCES``,
      * ``<release>.dic``: the dictionary lines in ``_PHONEME``.

    Args:
        dataset_dir: Root directory in which the mocked dataset is created.

    Returns:
        dict mapping each release name to the list of samples the dataset is
        expected to yield, one per utterance, each a tuple
        ``(waveform_segment, sample_rate, transcript, talk_id, speaker_id, identifier)``.
    """
    mocked_samples = {}
    os.makedirs(dataset_dir, exist_ok=True)
    sample_rate = 16000  # 16kHz

    # Seed 0, 1, 2, ... so each release gets different noise.
    for seed, release in enumerate(["release1", "release2", "release3"]):
        # 10 s covers the end time of the last utterance (10.0).
        data = get_whitenoise(sample_rate=sample_rate, duration=10.00, n_channels=1, dtype="float32", seed=seed)

        config = tedlium._RELEASE_CONFIGS[release]
        # release1/release2 nest their files under "subset"; release3 under "data_path".
        sub_dir = config["subset"] if release in ("release1", "release2") else config["data_path"]
        release_dir = os.path.join(dataset_dir, config["folder_in_archive"], sub_dir)

        stm_dir = os.path.join(release_dir, "stm")  # transcripts
        sph_dir = os.path.join(release_dir, "sph")  # audio files
        os.makedirs(stm_dir, exist_ok=True)  # also creates release_dir
        os.makedirs(sph_dir, exist_ok=True)

        # NOTE: the audio is written with save_wav even though the file is named *.sph.
        save_wav(os.path.join(sph_dir, f"{release}.sph"), data, sample_rate)

        with open(os.path.join(stm_dir, f"{release}.stm"), "w") as f:
            f.write("".join(_UTTERANCES))

        with open(os.path.join(release_dir, f"{release}.dic"), "w") as f:
            f.write("\n".join(_PHONEME))

        # Expected samples, one per transcript line, to compare the dataset against.
        mocked_samples[release] = []
        for utterance in _UTTERANCES:
            talk_id, _, speaker_id, start_time, end_time, identifier, transcript = utterance.split(" ", 6)
            # Scale before truncating: int(float(t)) * sample_rate would drop the
            # sub-second part of a fractional timestamp such as "2.5".
            start_frame = int(float(start_time) * sample_rate)
            end_frame = int(float(end_time) * sample_rate)
            mocked_samples[release].append(
                (
                    data[:, start_frame:end_frame],
                    sample_rate,
                    transcript,  # still carries its trailing "\n"
                    talk_id,
                    speaker_id,
                    identifier,
                )
            )
    return mocked_samples


Vincent QB's avatar
Vincent QB committed
93
class Tedlium(TempDirMixin):
    """Backend-agnostic TEDLIUM dataset tests.

    The mocked dataset is generated once per concrete test class in ``setUpClass``.
    Concrete subclasses combine this class with ``TorchaudioTestCase`` and set ``backend``.
    """

    root_dir = None  # root of the mocked dataset; set in setUpClass
    samples = {}  # release name -> expected samples; set in setUpClass

    @classmethod
    def setUpClass(cls):
        # unittest does not chain setUpClass automatically; call up so any
        # class-level setup defined further along the MRO still runs.
        super().setUpClass()
        cls.root_dir = os.path.join(cls.get_base_temp_dir(), "tedlium")
        cls.samples = get_mock_dataset(cls.root_dir)

    def _test_tedlium(self, dataset, release):
        """Check that ``dataset`` yields exactly the mocked samples of ``release``
        and that its phoneme dictionary re-serializes to ``_PHONEME``."""
        expected_samples = self.samples[release]

        num_samples = 0
        for i, (data, sample_rate, transcript, talk_id, speaker_id, identifier) in enumerate(dataset):
            expected = expected_samples[i]
            self.assertEqual(data, expected[0], atol=5e-5, rtol=1e-8)
            assert (sample_rate, transcript, talk_id, speaker_id, identifier) == expected[1:]
            num_samples += 1

        assert num_samples == len(expected_samples)

        # The mock writes the dictionary as "<release>.dic"; point the dataset at
        # that file (assumes dataset._path is the release directory the mock
        # populated -- TODO confirm against the TEDLIUM implementation).
        dataset._dict_path = os.path.join(dataset._path, f"{release}.dic")
        phoneme_dict = dataset.phoneme_dict
        # Rebuild "<word> <PHONE> ..." lines to compare with what was written.
        phonemes = [f"{word} {' '.join(phones)}" for word, phones in phoneme_dict.items()]
        assert phonemes == _PHONEME

    def test_tedlium_release1_str(self):
        release = "release1"
        dataset = tedlium.TEDLIUM(self.root_dir, release=release)
        self._test_tedlium(dataset, release)

    def test_tedlium_release1_path(self):
        # Same as above, but the root is passed as a pathlib.Path.
        release = "release1"
        dataset = tedlium.TEDLIUM(Path(self.root_dir), release=release)
        self._test_tedlium(dataset, release)

    def test_tedlium_release2(self):
        release = "release2"
        dataset = tedlium.TEDLIUM(self.root_dir, release=release)
        self._test_tedlium(dataset, release)

    def test_tedlium_release3(self):
        release = "release3"
        dataset = tedlium.TEDLIUM(self.root_dir, release=release)
        self._test_tedlium(dataset, release)
Vincent QB's avatar
Vincent QB committed
140
141
142
143
144
145
146


class TestTedliumSoundfile(Tedlium, TorchaudioTestCase):
    """Run the TEDLIUM tests with the "soundfile" audio backend."""

    backend = "soundfile"


# The sox_io variant is only defined off Windows (presumably because the sox
# backend is unavailable there -- TODO confirm).
if platform.system() != "Windows":

    @skipIfNoSox
    class TestTedliumSoxIO(Tedlium, TorchaudioTestCase):
        """Run the TEDLIUM tests with the "sox_io" audio backend (guarded by skipIfNoSox)."""

        backend = "sox_io"