"docs/git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "fb9af9f36f131daf3dde1416faa4aa3cd1f203c7"
Unverified Commit 02e4f6d2 authored by Aziz's avatar Aziz Committed by GitHub
Browse files

Refactor LibriTTS unittest (#1139)

parent 1838f927
import os import os
from pathlib import Path from pathlib import Path
from torchaudio.datasets.libritts import LIBRITTS
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
TempDirMixin, TempDirMixin,
TorchaudioTestCase, TorchaudioTestCase,
...@@ -11,42 +9,55 @@ from torchaudio_unittest.common_utils import ( ...@@ -11,42 +9,55 @@ from torchaudio_unittest.common_utils import (
normalize_wav, normalize_wav,
) )
from torchaudio.datasets.libritts import LIBRITTS
_UTTERANCE_IDS = [
[19, 198, '000000', '000000'],
[26, 495, '000004', '000000'],
]
_ORIGINAL_TEXT = 'this is the original text.'
_NORMALIZED_TEXT = 'this is the normalized text.'
def get_mock_dataset(root_dir):
"""
root_dir: directory to the mocked dataset
"""
mocked_data = []
base_dir = os.path.join(root_dir, 'LibriTTS', 'train-clean-100')
for i, utterance_id in enumerate(_UTTERANCE_IDS):
filename = f'{"_".join(str(u) for u in utterance_id)}.wav'
file_dir = os.path.join(base_dir, str(utterance_id[0]), str(utterance_id[1]))
os.makedirs(file_dir, exist_ok=True)
path = os.path.join(file_dir, filename)
data = get_whitenoise(sample_rate=24000, duration=2, n_channels=1, dtype='int16', seed=i)
save_wav(path, data, 24000)
mocked_data.append(normalize_wav(data))
original_text_filename = f'{"_".join(str(u) for u in utterance_id)}.original.txt'
path_original = os.path.join(file_dir, original_text_filename)
with open(path_original, 'w') as file_:
file_.write(_ORIGINAL_TEXT)
normalized_text_filename = f'{"_".join(str(u) for u in utterance_id)}.normalized.txt'
path_normalized = os.path.join(file_dir, normalized_text_filename)
with open(path_normalized, 'w') as file_:
file_.write(_NORMALIZED_TEXT)
return mocked_data, _UTTERANCE_IDS, _ORIGINAL_TEXT, _NORMALIZED_TEXT
class TestLibriTTS(TempDirMixin, TorchaudioTestCase): class TestLibriTTS(TempDirMixin, TorchaudioTestCase):
backend = 'default' backend = 'default'
root_dir = None root_dir = None
data = [] data = []
utterance_ids = [ _utterance_ids, _original_text, _normalized_text = [], [], []
[19, 198, '000000', '000000'],
[26, 495, '000004', '000000'],
]
original_text = 'this is the original text.'
normalized_text = 'this is the normalized text.'
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls.root_dir = cls.get_base_temp_dir() cls.root_dir = cls.get_base_temp_dir()
base_dir = os.path.join(cls.root_dir, 'LibriTTS', 'train-clean-100') cls.data, cls._utterance_ids, cls._original_text, cls._normalized_text = get_mock_dataset(cls.root_dir)
for i, utterance_id in enumerate(cls.utterance_ids):
filename = f'{"_".join(str(u) for u in utterance_id)}.wav'
file_dir = os.path.join(base_dir, str(utterance_id[0]), str(utterance_id[1]))
os.makedirs(file_dir, exist_ok=True)
path = os.path.join(file_dir, filename)
data = get_whitenoise(sample_rate=24000, duration=2, n_channels=1, dtype='int16', seed=i)
save_wav(path, data, 24000)
cls.data.append(normalize_wav(data))
original_text_filename = f'{"_".join(str(u) for u in utterance_id)}.original.txt'
path_original = os.path.join(file_dir, original_text_filename)
with open(path_original, 'w') as file_:
file_.write(cls.original_text)
normalized_text_filename = f'{"_".join(str(u) for u in utterance_id)}.normalized.txt'
path_normalized = os.path.join(file_dir, normalized_text_filename)
with open(path_normalized, 'w') as file_:
file_.write(cls.normalized_text)
def _test_libritts(self, dataset): def _test_libritts(self, dataset):
n_ites = 0 n_ites = 0
...@@ -57,18 +68,17 @@ class TestLibriTTS(TempDirMixin, TorchaudioTestCase): ...@@ -57,18 +68,17 @@ class TestLibriTTS(TempDirMixin, TorchaudioTestCase):
speaker_id, speaker_id,
chapter_id, chapter_id,
utterance_id) in enumerate(dataset): utterance_id) in enumerate(dataset):
expected_ids = self._utterance_ids[i]
expected_ids = self.utterance_ids[i]
expected_data = self.data[i] expected_data = self.data[i]
self.assertEqual(expected_data, waveform, atol=5e-5, rtol=1e-8) self.assertEqual(expected_data, waveform, atol=5e-5, rtol=1e-8)
assert sample_rate == 24000 assert sample_rate == 24000
assert speaker_id == expected_ids[0] assert speaker_id == expected_ids[0]
assert chapter_id == expected_ids[1] assert chapter_id == expected_ids[1]
assert original_text == self.original_text assert original_text == self._original_text
assert normalized_text == self.normalized_text assert normalized_text == self._normalized_text
assert utterance_id == f'{"_".join(str(u) for u in expected_ids[-4:])}' assert utterance_id == f'{"_".join(str(u) for u in expected_ids[-4:])}'
n_ites += 1 n_ites += 1
assert n_ites == len(self.utterance_ids) assert n_ites == len(self._utterance_ids)
def test_libritts_str(self): def test_libritts_str(self):
dataset = LIBRITTS(self.root_dir) dataset = LIBRITTS(self.root_dir)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment