Unverified Commit 64956d54 authored by Aziz's avatar Aziz Committed by GitHub
Browse files

Refactor TEDLIUM unittest (#1135)

parent 8f02af5f
import os import os
import platform import platform
import unittest
from pathlib import Path from pathlib import Path
from torchaudio.datasets import tedlium
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
TestBaseMixin,
TempDirMixin, TempDirMixin,
TorchaudioTestCase, TorchaudioTestCase,
get_whitenoise, get_whitenoise,
save_wav, save_wav,
normalize_wav,
) )
from torchaudio.datasets import tedlium
# Used to generate a unique utterance for each dummy audio file # Used to generate a unique utterance for each dummy audio file
UTTERANCES = [ _UTTERANCES = [
"AaronHuey_2010X 1 AaronHuey_2010X 0.0 2.0 <o,f0,female> script1\n", "AaronHuey_2010X 1 AaronHuey_2010X 0.0 2.0 <o,f0,female> script1\n",
"AaronHuey_2010X 1 AaronHuey_2010X 2.0 4.0 <o,f0,female> script2\n", "AaronHuey_2010X 1 AaronHuey_2010X 2.0 4.0 <o,f0,female> script2\n",
"AaronHuey_2010X 1 AaronHuey_2010X 4.0 6.0 <o,f0,female> script3\n", "AaronHuey_2010X 1 AaronHuey_2010X 4.0 6.0 <o,f0,female> script3\n",
...@@ -23,7 +20,7 @@ UTTERANCES = [ ...@@ -23,7 +20,7 @@ UTTERANCES = [
"AaronHuey_2010X 1 AaronHuey_2010X 8.0 10.0 <o,f0,female> script5\n", "AaronHuey_2010X 1 AaronHuey_2010X 8.0 10.0 <o,f0,female> script5\n",
] ]
PHONEME = [ _PHONEME = [
"a AH", "a AH",
"a(2) EY", "a(2) EY",
"aachen AA K AH N", "aachen AA K AH N",
...@@ -34,14 +31,11 @@ PHONEME = [ ...@@ -34,14 +31,11 @@ PHONEME = [
] ]
class Tedlium(TempDirMixin): def get_mock_dataset(dataset_dir):
root_dir = None """
samples = {} dataset_dir: directory of the mocked dataset
"""
@classmethod mocked_samples = {}
def setUpClass(cls):
cls.root_dir = cls.get_base_temp_dir()
cls.root_dir = dataset_dir = os.path.join(cls.root_dir, "tedlium")
os.makedirs(dataset_dir, exist_ok=True) os.makedirs(dataset_dir, exist_ok=True)
sample_rate = 16000 # 16kHz sample_rate = 16000 # 16kHz
seed = 0 seed = 0
...@@ -70,16 +64,16 @@ class Tedlium(TempDirMixin): ...@@ -70,16 +64,16 @@ class Tedlium(TempDirMixin):
trans_filename = f"{release}.stm" trans_filename = f"{release}.stm"
trans_path = os.path.join(os.path.join(release_dir, "stm"), trans_filename) trans_path = os.path.join(os.path.join(release_dir, "stm"), trans_filename)
with open(trans_path, "w") as f: with open(trans_path, "w") as f:
f.write("".join(UTTERANCES)) f.write("".join(_UTTERANCES))
dict_filename = f"{release}.dic" dict_filename = f"{release}.dic"
dict_path = os.path.join(release_dir, dict_filename) dict_path = os.path.join(release_dir, dict_filename)
with open(dict_path, "w") as f: with open(dict_path, "w") as f:
f.write("\n".join(PHONEME)) f.write("\n".join(_PHONEME))
# Create a samples list to compare with # Create a samples list to compare with
cls.samples[release] = [] mocked_samples[release] = []
for utterance in UTTERANCES: for utterance in _UTTERANCES:
talk_id, _, speaker_id, start_time, end_time, identifier, transcript = utterance.split(" ", 6) talk_id, _, speaker_id, start_time, end_time, identifier, transcript = utterance.split(" ", 6)
start_time = int(float(start_time)) * sample_rate start_time = int(float(start_time)) * sample_rate
end_time = int(float(end_time)) * sample_rate end_time = int(float(end_time)) * sample_rate
...@@ -91,8 +85,20 @@ class Tedlium(TempDirMixin): ...@@ -91,8 +85,20 @@ class Tedlium(TempDirMixin):
speaker_id, speaker_id,
identifier, identifier,
) )
cls.samples[release].append(sample) mocked_samples[release].append(sample)
seed += 1 seed += 1
return mocked_samples
class Tedlium(TempDirMixin):
root_dir = None
samples = {}
@classmethod
def setUpClass(cls):
cls.root_dir = cls.get_base_temp_dir()
cls.root_dir = dataset_dir = os.path.join(cls.root_dir, "tedlium")
cls.samples = get_mock_dataset(dataset_dir)
def _test_tedlium(self, dataset, release): def _test_tedlium(self, dataset, release):
num_samples = 0 num_samples = 0
...@@ -110,7 +116,7 @@ class Tedlium(TempDirMixin): ...@@ -110,7 +116,7 @@ class Tedlium(TempDirMixin):
dataset._dict_path = os.path.join(dataset._path, f"{release}.dic") dataset._dict_path = os.path.join(dataset._path, f"{release}.dic")
phoneme_dict = dataset.phoneme_dict phoneme_dict = dataset.phoneme_dict
phoenemes = [f"{key} {' '.join(value)}" for key, value in phoneme_dict.items()] phoenemes = [f"{key} {' '.join(value)}" for key, value in phoneme_dict.items()]
assert phoenemes == PHONEME assert phoenemes == _PHONEME
def test_tedlium_release1_str(self): def test_tedlium_release1_str(self):
release = "release1" release = "release1"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment