Unverified Commit 1838f927 authored by Aziz's avatar Aziz Committed by GitHub
Browse files

Refactor LJSpeech unittest (#1138)

parent 64956d54
...@@ -2,8 +2,6 @@ import csv ...@@ -2,8 +2,6 @@ import csv
import os import os
from pathlib import Path from pathlib import Path
from torchaudio.datasets import ljspeech
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
TempDirMixin, TempDirMixin,
TorchaudioTestCase, TorchaudioTestCase,
...@@ -12,30 +10,29 @@ from torchaudio_unittest.common_utils import ( ...@@ -12,30 +10,29 @@ from torchaudio_unittest.common_utils import (
save_wav, save_wav,
) )
from torchaudio.datasets import ljspeech
class TestLJSpeech(TempDirMixin, TorchaudioTestCase): _TRANSCRIPTS = [
backend = "default"
root_dir = None
data = []
transcripts = [
"Test transcript 1", "Test transcript 1",
"Test transcript 2", "Test transcript 2",
"Test transcript 3", "Test transcript 3",
"In 1465 Sweynheim and Pannartz began printing in the monastery of Subiaco near Rome," "In 1465 Sweynheim and Pannartz began printing in the monastery of Subiaco near Rome,"
] ]
normalized_transcripts = [ _NORMALIZED_TRANSCRIPT = [
"Test transcript one", "Test transcript one",
"Test transcript two", "Test transcript two",
"Test transcript three", "Test transcript three",
"In fourteen sixty-five Sweynheim and Pannartz began printing in the monastery of Subiaco near Rome," "In fourteen sixty-five Sweynheim and Pannartz began printing in the monastery of Subiaco near Rome,"
] ]
@classmethod
def setUpClass(cls): def get_mock_dataset(root_dir):
cls.root_dir = cls.get_base_temp_dir() """
base_dir = os.path.join(cls.root_dir, "LJSpeech-1.1") root_dir: path to the mocked dataset
"""
mocked_data = []
base_dir = os.path.join(root_dir, "LJSpeech-1.1")
archive_dir = os.path.join(base_dir, "wavs") archive_dir = os.path.join(base_dir, "wavs")
os.makedirs(archive_dir, exist_ok=True) os.makedirs(archive_dir, exist_ok=True)
metadata_path = os.path.join(base_dir, "metadata.csv") metadata_path = os.path.join(base_dir, "metadata.csv")
...@@ -46,7 +43,7 @@ class TestLJSpeech(TempDirMixin, TorchaudioTestCase): ...@@ -46,7 +43,7 @@ class TestLJSpeech(TempDirMixin, TorchaudioTestCase):
metadata_file, delimiter="|", quoting=csv.QUOTE_NONE metadata_file, delimiter="|", quoting=csv.QUOTE_NONE
) )
for i, (transcript, normalized_transcript) in enumerate( for i, (transcript, normalized_transcript) in enumerate(
zip(cls.transcripts, cls.normalized_transcripts) zip(_TRANSCRIPTS, _NORMALIZED_TRANSCRIPT)
): ):
fileid = f'LJ001-{i:04d}' fileid = f'LJ001-{i:04d}'
metadata_writer.writerow([fileid, transcript, normalized_transcript]) metadata_writer.writerow([fileid, transcript, normalized_transcript])
...@@ -56,15 +53,28 @@ class TestLJSpeech(TempDirMixin, TorchaudioTestCase): ...@@ -56,15 +53,28 @@ class TestLJSpeech(TempDirMixin, TorchaudioTestCase):
sample_rate=sample_rate, duration=1, n_channels=1, dtype="int16", seed=i sample_rate=sample_rate, duration=1, n_channels=1, dtype="int16", seed=i
) )
save_wav(path, data, sample_rate) save_wav(path, data, sample_rate)
cls.data.append(normalize_wav(data)) mocked_data.append(normalize_wav(data))
return mocked_data, _TRANSCRIPTS, _NORMALIZED_TRANSCRIPT
class TestLJSpeech(TempDirMixin, TorchaudioTestCase):
backend = "default"
root_dir = None
data, _transcripts, _normalized_transcript = [], [], []
@classmethod
def setUpClass(cls):
cls.root_dir = cls.get_base_temp_dir()
cls.data, cls._transcripts, cls._normalized_transcript = get_mock_dataset(cls.root_dir)
def _test_ljspeech(self, dataset): def _test_ljspeech(self, dataset):
n_ite = 0 n_ite = 0
for i, (waveform, sample_rate, transcript, normalized_transcript) in enumerate( for i, (waveform, sample_rate, transcript, normalized_transcript) in enumerate(
dataset dataset
): ):
expected_transcript = self.transcripts[i] expected_transcript = self._transcripts[i]
expected_normalized_transcript = self.normalized_transcripts[i] expected_normalized_transcript = self._normalized_transcript[i]
expected_data = self.data[i] expected_data = self.data[i]
self.assertEqual(expected_data, waveform, atol=5e-5, rtol=1e-8) self.assertEqual(expected_data, waveform, atol=5e-5, rtol=1e-8)
assert sample_rate == sample_rate assert sample_rate == sample_rate
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment