Unverified commit 1838f927, authored by Aziz, committed by GitHub
Browse files

Refactor LJSpeech unittest (#1138)

parent 64956d54
...@@ -2,8 +2,6 @@ import csv ...@@ -2,8 +2,6 @@ import csv
import os import os
from pathlib import Path from pathlib import Path
from torchaudio.datasets import ljspeech
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
TempDirMixin, TempDirMixin,
TorchaudioTestCase, TorchaudioTestCase,
...@@ -12,59 +10,71 @@ from torchaudio_unittest.common_utils import ( ...@@ -12,59 +10,71 @@ from torchaudio_unittest.common_utils import (
save_wav, save_wav,
) )
from torchaudio.datasets import ljspeech
# Raw transcripts written to the second column of the mocked metadata.csv.
# Entry ``i`` belongs to the generated file id ``LJ001-{i:04d}``.
_TRANSCRIPTS = [
    "Test transcript 1",
    "Test transcript 2",
    "Test transcript 3",
    "In 1465 Sweynheim and Pannartz began printing in the monastery of Subiaco near Rome,",
]

# Spelled-out counterparts of ``_TRANSCRIPTS`` (numbers written as words),
# written to the third column of the mocked metadata.csv. Must stay the same
# length and order as ``_TRANSCRIPTS`` because the two lists are zipped.
_NORMALIZED_TRANSCRIPT = [
    "Test transcript one",
    "Test transcript two",
    "Test transcript three",
    "In fourteen sixty-five Sweynheim and Pannartz began printing in the monastery of Subiaco near Rome,",
]
def get_mock_dataset(root_dir):
    """Build a miniature LJSpeech-1.1 directory tree for tests.

    Under ``root_dir`` this creates ``LJSpeech-1.1/wavs`` and writes
    ``LJSpeech-1.1/metadata.csv`` as a pipe-delimited, unquoted CSV with one
    row ``[fileid, transcript, normalized_transcript]`` per entry of
    ``_TRANSCRIPTS`` / ``_NORMALIZED_TRANSCRIPT``. For each row a white-noise
    clip (``duration=1``, one channel, ``int16``, seeded with the row index)
    is generated and saved as ``wavs/<fileid>.wav``.

    Args:
        root_dir: path to the mocked dataset

    Returns:
        A 3-tuple ``(mocked_data, _TRANSCRIPTS, _NORMALIZED_TRANSCRIPT)``;
        ``mocked_data`` is the list of ``normalize_wav(...)`` results, one per
        generated clip, in row order. The two transcript lists are the
        module-level lists themselves, not copies.
    """
    dataset_dir = os.path.join(root_dir, "LJSpeech-1.1")
    wav_dir = os.path.join(dataset_dir, "wavs")
    os.makedirs(wav_dir, exist_ok=True)

    sample_rate = 22050
    expected_waveforms = []
    paired_transcripts = zip(_TRANSCRIPTS, _NORMALIZED_TRANSCRIPT)

    # newline='' lets the csv module control line endings itself.
    with open(
        os.path.join(dataset_dir, "metadata.csv"), mode="w", newline=''
    ) as csv_file:
        writer = csv.writer(csv_file, delimiter="|", quoting=csv.QUOTE_NONE)
        for i, (raw_text, spoken_text) in enumerate(paired_transcripts):
            fileid = f'LJ001-{i:04d}'
            writer.writerow([fileid, raw_text, spoken_text])

            # Seeding with the row index makes every clip distinct but
            # reproducible across runs.
            noise = get_whitenoise(
                sample_rate=sample_rate,
                duration=1,
                n_channels=1,
                dtype="int16",
                seed=i,
            )
            save_wav(os.path.join(wav_dir, fileid + ".wav"), noise, sample_rate)
            expected_waveforms.append(normalize_wav(noise))

    return expected_waveforms, _TRANSCRIPTS, _NORMALIZED_TRANSCRIPT
class TestLJSpeech(TempDirMixin, TorchaudioTestCase): class TestLJSpeech(TempDirMixin, TorchaudioTestCase):
backend = "default" backend = "default"
root_dir = None root_dir = None
data = [] data, _transcripts, _normalized_transcript = [], [], []
transcripts = [
"Test transcript 1",
"Test transcript 2",
"Test transcript 3",
"In 1465 Sweynheim and Pannartz began printing in the monastery of Subiaco near Rome,"
]
normalized_transcripts = [
"Test transcript one",
"Test transcript two",
"Test transcript three",
"In fourteen sixty-five Sweynheim and Pannartz began printing in the monastery of Subiaco near Rome,"
]
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls.root_dir = cls.get_base_temp_dir() cls.root_dir = cls.get_base_temp_dir()
base_dir = os.path.join(cls.root_dir, "LJSpeech-1.1") cls.data, cls._transcripts, cls._normalized_transcript = get_mock_dataset(cls.root_dir)
archive_dir = os.path.join(base_dir, "wavs")
os.makedirs(archive_dir, exist_ok=True)
metadata_path = os.path.join(base_dir, "metadata.csv")
sample_rate = 22050
with open(metadata_path, mode="w", newline='') as metadata_file:
metadata_writer = csv.writer(
metadata_file, delimiter="|", quoting=csv.QUOTE_NONE
)
for i, (transcript, normalized_transcript) in enumerate(
zip(cls.transcripts, cls.normalized_transcripts)
):
fileid = f'LJ001-{i:04d}'
metadata_writer.writerow([fileid, transcript, normalized_transcript])
filename = fileid + ".wav"
path = os.path.join(archive_dir, filename)
data = get_whitenoise(
sample_rate=sample_rate, duration=1, n_channels=1, dtype="int16", seed=i
)
save_wav(path, data, sample_rate)
cls.data.append(normalize_wav(data))
def _test_ljspeech(self, dataset): def _test_ljspeech(self, dataset):
n_ite = 0 n_ite = 0
for i, (waveform, sample_rate, transcript, normalized_transcript) in enumerate( for i, (waveform, sample_rate, transcript, normalized_transcript) in enumerate(
dataset dataset
): ):
expected_transcript = self.transcripts[i] expected_transcript = self._transcripts[i]
expected_normalized_transcript = self.normalized_transcripts[i] expected_normalized_transcript = self._normalized_transcript[i]
expected_data = self.data[i] expected_data = self.data[i]
self.assertEqual(expected_data, waveform, atol=5e-5, rtol=1e-8) self.assertEqual(expected_data, waveform, atol=5e-5, rtol=1e-8)
assert sample_rate == sample_rate assert sample_rate == sample_rate
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment