Unverified Commit 64956d54 authored by Aziz's avatar Aziz Committed by GitHub
Browse files

Refactor TEDLIUM unittest (#1135)

parent 8f02af5f
import os import os
import platform import platform
import unittest
from pathlib import Path from pathlib import Path
from torchaudio.datasets import tedlium
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
TestBaseMixin,
TempDirMixin, TempDirMixin,
TorchaudioTestCase, TorchaudioTestCase,
get_whitenoise, get_whitenoise,
save_wav, save_wav,
normalize_wav,
) )
from torchaudio.datasets import tedlium
# Used to generate a unique utterance for each dummy audio file # Used to generate a unique utterance for each dummy audio file
UTTERANCES = [ _UTTERANCES = [
"AaronHuey_2010X 1 AaronHuey_2010X 0.0 2.0 <o,f0,female> script1\n", "AaronHuey_2010X 1 AaronHuey_2010X 0.0 2.0 <o,f0,female> script1\n",
"AaronHuey_2010X 1 AaronHuey_2010X 2.0 4.0 <o,f0,female> script2\n", "AaronHuey_2010X 1 AaronHuey_2010X 2.0 4.0 <o,f0,female> script2\n",
"AaronHuey_2010X 1 AaronHuey_2010X 4.0 6.0 <o,f0,female> script3\n", "AaronHuey_2010X 1 AaronHuey_2010X 4.0 6.0 <o,f0,female> script3\n",
...@@ -23,7 +20,7 @@ UTTERANCES = [ ...@@ -23,7 +20,7 @@ UTTERANCES = [
"AaronHuey_2010X 1 AaronHuey_2010X 8.0 10.0 <o,f0,female> script5\n", "AaronHuey_2010X 1 AaronHuey_2010X 8.0 10.0 <o,f0,female> script5\n",
] ]
PHONEME = [ _PHONEME = [
"a AH", "a AH",
"a(2) EY", "a(2) EY",
"aachen AA K AH N", "aachen AA K AH N",
...@@ -34,6 +31,65 @@ PHONEME = [ ...@@ -34,6 +31,65 @@ PHONEME = [
] ]
def get_mock_dataset(dataset_dir):
"""
dataset_dir: directory of the mocked dataset
"""
mocked_samples = {}
os.makedirs(dataset_dir, exist_ok=True)
sample_rate = 16000 # 16kHz
seed = 0
for release in ["release1", "release2", "release3"]:
data = get_whitenoise(sample_rate=sample_rate, duration=10.00, n_channels=1, dtype="float32", seed=seed)
if release in ["release1", "release2"]:
release_dir = os.path.join(
dataset_dir,
tedlium._RELEASE_CONFIGS[release]["folder_in_archive"],
tedlium._RELEASE_CONFIGS[release]["subset"],
)
else:
release_dir = os.path.join(
dataset_dir,
tedlium._RELEASE_CONFIGS[release]["folder_in_archive"],
tedlium._RELEASE_CONFIGS[release]["data_path"],
)
os.makedirs(release_dir, exist_ok=True)
os.makedirs(os.path.join(release_dir, "stm"), exist_ok=True) # Subfolder for transcripts
os.makedirs(os.path.join(release_dir, "sph"), exist_ok=True) # Subfolder for audio files
filename = f"{release}.sph"
path = os.path.join(os.path.join(release_dir, "sph"), filename)
save_wav(path, data, sample_rate)
trans_filename = f"{release}.stm"
trans_path = os.path.join(os.path.join(release_dir, "stm"), trans_filename)
with open(trans_path, "w") as f:
f.write("".join(_UTTERANCES))
dict_filename = f"{release}.dic"
dict_path = os.path.join(release_dir, dict_filename)
with open(dict_path, "w") as f:
f.write("\n".join(_PHONEME))
# Create a samples list to compare with
mocked_samples[release] = []
for utterance in _UTTERANCES:
talk_id, _, speaker_id, start_time, end_time, identifier, transcript = utterance.split(" ", 6)
start_time = int(float(start_time)) * sample_rate
end_time = int(float(end_time)) * sample_rate
sample = (
data[:, start_time:end_time],
sample_rate,
transcript,
talk_id,
speaker_id,
identifier,
)
mocked_samples[release].append(sample)
seed += 1
return mocked_samples
class Tedlium(TempDirMixin): class Tedlium(TempDirMixin):
root_dir = None root_dir = None
samples = {} samples = {}
...@@ -42,57 +98,7 @@ class Tedlium(TempDirMixin): ...@@ -42,57 +98,7 @@ class Tedlium(TempDirMixin):
def setUpClass(cls): def setUpClass(cls):
cls.root_dir = cls.get_base_temp_dir() cls.root_dir = cls.get_base_temp_dir()
cls.root_dir = dataset_dir = os.path.join(cls.root_dir, "tedlium") cls.root_dir = dataset_dir = os.path.join(cls.root_dir, "tedlium")
os.makedirs(dataset_dir, exist_ok=True) cls.samples = get_mock_dataset(dataset_dir)
sample_rate = 16000 # 16kHz
seed = 0
for release in ["release1", "release2", "release3"]:
data = get_whitenoise(sample_rate=sample_rate, duration=10.00, n_channels=1, dtype="float32", seed=seed)
if release in ["release1", "release2"]:
release_dir = os.path.join(
dataset_dir,
tedlium._RELEASE_CONFIGS[release]["folder_in_archive"],
tedlium._RELEASE_CONFIGS[release]["subset"],
)
else:
release_dir = os.path.join(
dataset_dir,
tedlium._RELEASE_CONFIGS[release]["folder_in_archive"],
tedlium._RELEASE_CONFIGS[release]["data_path"],
)
os.makedirs(release_dir, exist_ok=True)
os.makedirs(os.path.join(release_dir, "stm"), exist_ok=True) # Subfolder for transcripts
os.makedirs(os.path.join(release_dir, "sph"), exist_ok=True) # Subfolder for audio files
filename = f"{release}.sph"
path = os.path.join(os.path.join(release_dir, "sph"), filename)
save_wav(path, data, sample_rate)
trans_filename = f"{release}.stm"
trans_path = os.path.join(os.path.join(release_dir, "stm"), trans_filename)
with open(trans_path, "w") as f:
f.write("".join(UTTERANCES))
dict_filename = f"{release}.dic"
dict_path = os.path.join(release_dir, dict_filename)
with open(dict_path, "w") as f:
f.write("\n".join(PHONEME))
# Create a samples list to compare with
cls.samples[release] = []
for utterance in UTTERANCES:
talk_id, _, speaker_id, start_time, end_time, identifier, transcript = utterance.split(" ", 6)
start_time = int(float(start_time)) * sample_rate
end_time = int(float(end_time)) * sample_rate
sample = (
data[:, start_time:end_time],
sample_rate,
transcript,
talk_id,
speaker_id,
identifier,
)
cls.samples[release].append(sample)
seed += 1
def _test_tedlium(self, dataset, release): def _test_tedlium(self, dataset, release):
num_samples = 0 num_samples = 0
...@@ -110,7 +116,7 @@ class Tedlium(TempDirMixin): ...@@ -110,7 +116,7 @@ class Tedlium(TempDirMixin):
dataset._dict_path = os.path.join(dataset._path, f"{release}.dic") dataset._dict_path = os.path.join(dataset._path, f"{release}.dic")
phoneme_dict = dataset.phoneme_dict phoneme_dict = dataset.phoneme_dict
phoenemes = [f"{key} {' '.join(value)}" for key, value in phoneme_dict.items()] phoenemes = [f"{key} {' '.join(value)}" for key, value in phoneme_dict.items()]
assert phoenemes == PHONEME assert phoenemes == _PHONEME
def test_tedlium_release1_str(self): def test_tedlium_release1_str(self):
release = "release1" release = "release1"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment