import os
import platform
import unittest
from pathlib import Path

from torchaudio.datasets import tedlium

from torchaudio_unittest.common_utils import (
    TestBaseMixin,
    TempDirMixin,
    TorchaudioTestCase,
    get_whitenoise,
    save_wav,
    normalize_wav,
)

# Mock ``.stm`` transcript lines: one talk split into five consecutive 2-second
# segments. Every release's transcript file receives all five lines. Fields are
# space separated: talk_id, an ignored field, speaker_id, start, end, identifier,
# transcript (the transcript keeps its trailing newline).
UTTERANCES = [
    f"AaronHuey_2010X 1 AaronHuey_2010X {2.0 * k:.1f} {2.0 * (k + 1):.1f} <o,f0,female> script{k + 1}\n"
    for k in range(5)
]

# Mock phoneme dictionary lines: a word followed by its space-separated phonemes.
PHONEME = [
    " ".join(entry)
    for entry in (
        ("a", "AH"),
        ("a(2)", "EY"),
        ("aachen", "AA K AH N"),
        ("aad", "AE D"),
        ("aaden", "EY D AH N"),
        ("aadmi", "AE D M IY"),
        ("aae", "EY EY"),
    )
]


class Tedlium(TempDirMixin):
    """Backend-agnostic TEDLIUM dataset tests.

    ``setUpClass`` writes a small mock corpus under a temporary directory. For
    each release this is one white-noise audio file, one ``.stm`` transcript and
    one ``.dic`` phoneme dictionary. It also records the items the dataset is
    expected to yield. Concrete subclasses combine this mixin with
    ``TorchaudioTestCase`` and set ``backend``.
    """

    root_dir = None
    # Placeholder only; setUpClass rebinds a fresh dict on each concrete subclass.
    samples = {}

    @classmethod
    def setUpClass(cls):
        """Create the mock TEDLIUM tree for every release and record expected samples."""
        # unittest does not chain class fixtures automatically.
        super().setUpClass()
        dataset_dir = os.path.join(cls.get_base_temp_dir(), "tedlium")
        cls.root_dir = dataset_dir
        os.makedirs(dataset_dir, exist_ok=True)
        sample_rate = 16000  # 16kHz
        # Rebind rather than mutate: mutating the inherited class-level dict
        # would share one ``samples`` object across every backend subclass.
        cls.samples = {}

        # Each release gets its own noise signal (seed 0, 1, 2).
        for seed, release in enumerate(["release1", "release2", "release3"]):
            data = get_whitenoise(sample_rate=sample_rate, duration=10.00, n_channels=1, dtype="float32", seed=seed)

            config = tedlium._RELEASE_CONFIGS[release]
            # release3 is laid out under "data_path"; the older releases under "subset".
            sub_dir = config["data_path"] if release == "release3" else config["subset"]
            release_dir = os.path.join(dataset_dir, config["folder_in_archive"], sub_dir)
            stm_dir = os.path.join(release_dir, "stm")  # transcripts
            sph_dir = os.path.join(release_dir, "sph")  # audio files
            # makedirs also creates release_dir itself as an intermediate directory.
            os.makedirs(stm_dir, exist_ok=True)
            os.makedirs(sph_dir, exist_ok=True)

            # Written via ``save_wav`` even though the file carries the ``.sph`` extension.
            save_wav(os.path.join(sph_dir, f"{release}.sph"), data, sample_rate)

            with open(os.path.join(stm_dir, f"{release}.stm"), "w", encoding="utf-8") as f:
                f.write("".join(UTTERANCES))

            with open(os.path.join(release_dir, f"{release}.dic"), "w", encoding="utf-8") as f:
                f.write("\n".join(PHONEME))

            cls.samples[release] = [cls._expected_sample(utterance, data, sample_rate) for utterance in UTTERANCES]

    @staticmethod
    def _expected_sample(utterance, data, sample_rate):
        """Build the tuple the dataset should yield for one ``.stm`` line.

        Args:
            utterance: One line of ``UTTERANCES`` (including its trailing newline).
            data: The noise signal written for the release; sliced along its last axis.
            sample_rate: Sample rate used to convert seconds to frame offsets.

        Returns:
            ``(waveform, sample_rate, transcript, talk_id, speaker_id, identifier)``.
        """
        talk_id, _, speaker_id, start_time, end_time, identifier, transcript = utterance.split(" ", 6)
        # Scale first, then truncate: truncating the seconds before multiplying
        # would drop any fractional part of the timestamp.
        start = int(float(start_time) * sample_rate)
        end = int(float(end_time) * sample_rate)
        return (data[:, start:end], sample_rate, transcript, talk_id, speaker_id, identifier)

    def _test_tedlium(self, dataset, release):
        """Check that ``dataset`` yields exactly the recorded samples and phoneme dictionary.

        Args:
            dataset: ``tedlium.TEDLIUM`` instance pointed at the mock corpus.
            release: Release key used to look up the expected samples.
        """
        expected_samples = self.samples[release]
        num_samples = 0
        for i, (data, sample_rate, transcript, talk_id, speaker_id, identifier) in enumerate(dataset):
            exp_data, exp_rate, exp_transcript, exp_talk_id, exp_speaker_id, exp_identifier = expected_samples[i]
            self.assertEqual(data, exp_data, atol=5e-5, rtol=1e-8)
            assert sample_rate == exp_rate
            assert transcript == exp_transcript
            assert talk_id == exp_talk_id
            assert speaker_id == exp_speaker_id
            assert identifier == exp_identifier
            num_samples += 1

        assert num_samples == len(expected_samples)

        # Point the dataset at the mock dictionary written in setUpClass
        # (assumes ``dataset._path`` is the release directory that holds the ``.dic``).
        dataset._dict_path = os.path.join(dataset._path, f"{release}.dic")
        phoneme_dict = dataset.phoneme_dict
        phonemes = [f"{word} {' '.join(phones)}" for word, phones in phoneme_dict.items()]
        assert phonemes == PHONEME

    def test_tedlium_release1_str(self):
        """release1 loads when the root is given as a ``str``."""
        release = "release1"
        dataset = tedlium.TEDLIUM(self.root_dir, release=release)
        self._test_tedlium(dataset, release)

    def test_tedlium_release1_path(self):
        """release1 loads when the root is given as a ``pathlib.Path``."""
        release = "release1"
        dataset = tedlium.TEDLIUM(Path(self.root_dir), release=release)
        self._test_tedlium(dataset, release)

    def test_tedlium_release2(self):
        """release2 loads from its ``subset`` directory."""
        release = "release2"
        dataset = tedlium.TEDLIUM(self.root_dir, release=release)
        self._test_tedlium(dataset, release)

    def test_tedlium_release3(self):
        """release3 loads from its ``data_path`` directory."""
        release = "release3"
        dataset = tedlium.TEDLIUM(self.root_dir, release=release)
        self._test_tedlium(dataset, release)


class TestTedliumSoundfile(Tedlium, TorchaudioTestCase):
    """TEDLIUM tests configured with the ``soundfile`` backend."""

    backend = "soundfile"


class TestTedliumSoundfileNew(Tedlium, TorchaudioTestCase):
    """TEDLIUM tests configured with the ``soundfile-new`` backend."""

    backend = "soundfile-new"


# The sox-based variants are only defined (and therefore only collected) off Windows.
if platform.system() != "Windows":

    class TestTedliumSox(Tedlium, TorchaudioTestCase):
        """TEDLIUM tests configured with the ``sox`` backend."""

        backend = "sox"

    class TestTedliumSoxIO(Tedlium, TorchaudioTestCase):
        """TEDLIUM tests configured with the ``sox_io`` backend."""

        backend = "sox_io"