import csv
import os
from pathlib import Path

from torchaudio.datasets import ljspeech
from torchaudio_unittest.common_utils import (
    get_whitenoise,
    normalize_wav,
    save_wav,
    TempDirMixin,
    TorchaudioTestCase,
)

# Raw transcripts written to the second column of the mocked ``metadata.csv``.
_TRANSCRIPTS = [
    "Test transcript 1",
    "Test transcript 2",
    "Test transcript 3",
    "In 1465 Sweynheim and Pannartz began printing in the monastery of Subiaco near Rome,",
]

# Normalized transcripts written to the third column; index-aligned with
# ``_TRANSCRIPTS``.
_NORMALIZED_TRANSCRIPT = [
    "Test transcript one",
    "Test transcript two",
    "Test transcript three",
    "In fourteen sixty-five Sweynheim and Pannartz began printing in the monastery of Subiaco near Rome,",
]

# Sample rate used for every mocked wav file; the tests check the dataset
# reports this same value back.
_SAMPLE_RATE = 22050


def get_mock_dataset(root_dir):
    """Create a mocked LJSpeech directory tree under ``root_dir``.

    Writes ``LJSpeech-1.1/metadata.csv`` (pipe-delimited, unquoted) and one
    white-noise wav file per transcript into ``LJSpeech-1.1/wavs``.

    Args:
        root_dir: Directory in which the ``LJSpeech-1.1`` folder is created.

    Returns:
        A tuple ``(mocked_data, transcripts, normalized_transcripts)`` where
        ``mocked_data`` holds ``normalize_wav(data)`` for each generated
        waveform, in file order, and the other two items are the module-level
        ``_TRANSCRIPTS`` and ``_NORMALIZED_TRANSCRIPT`` lists.
    """
    mocked_data = []
    base_dir = os.path.join(root_dir, "LJSpeech-1.1")
    archive_dir = os.path.join(base_dir, "wavs")
    os.makedirs(archive_dir, exist_ok=True)
    metadata_path = os.path.join(base_dir, "metadata.csv")
    sample_rate = _SAMPLE_RATE

    with open(metadata_path, mode="w", newline="") as metadata_file:
        # QUOTE_NONE with no escapechar: fields are written verbatim, so the
        # transcripts must not contain the "|" delimiter.
        metadata_writer = csv.writer(metadata_file, delimiter="|", quoting=csv.QUOTE_NONE)
        for i, (transcript, normalized_transcript) in enumerate(zip(_TRANSCRIPTS, _NORMALIZED_TRANSCRIPT)):
            fileid = f"LJ001-{i:04d}"
            metadata_writer.writerow([fileid, transcript, normalized_transcript])

            filename = fileid + ".wav"
            path = os.path.join(archive_dir, filename)
            # Seed with the index so every file gets distinct, reproducible noise.
            data = get_whitenoise(sample_rate=sample_rate, duration=1, n_channels=1, dtype="int16", seed=i)
            save_wav(path, data, sample_rate)
            mocked_data.append(normalize_wav(data))

    return mocked_data, _TRANSCRIPTS, _NORMALIZED_TRANSCRIPT


class TestLJSpeech(TempDirMixin, TorchaudioTestCase):
    """Check that ``ljspeech.LJSPEECH`` reads back the mocked dataset."""

    backend = "default"

    # Populated once per class in ``setUpClass``.
    root_dir = None
    data, _transcripts, _normalized_transcript = [], [], []

    @classmethod
    def setUpClass(cls):
        """Build the mocked dataset once in the class-level temp directory."""
        cls.root_dir = cls.get_base_temp_dir()
        cls.data, cls._transcripts, cls._normalized_transcript = get_mock_dataset(cls.root_dir)

    def _test_ljspeech(self, dataset):
        """Iterate ``dataset`` and compare every item against the mocked data.

        Each item is expected to unpack as
        ``(waveform, sample_rate, transcript, normalized_transcript)``.
        """
        n_ite = 0
        for i, (waveform, sample_rate, transcript, normalized_transcript) in enumerate(dataset):
            expected_transcript = self._transcripts[i]
            expected_normalized_transcript = self._normalized_transcript[i]
            expected_data = self.data[i]
            self.assertEqual(expected_data, waveform, atol=5e-5, rtol=1e-8)
            # Compare against the rate the files were written with; the
            # previous check compared the value with itself and could not fail.
            assert sample_rate == _SAMPLE_RATE
            assert transcript == expected_transcript
            assert normalized_transcript == expected_normalized_transcript
            n_ite += 1
        # Guard against the dataset silently yielding fewer items than were mocked.
        assert n_ite == len(self.data)

    def test_ljspeech_str(self):
        """The dataset accepts the root directory as a ``str``."""
        dataset = ljspeech.LJSPEECH(self.root_dir)
        self._test_ljspeech(dataset)

    def test_ljspeech_path(self):
        """The dataset accepts the root directory as a ``pathlib.Path``."""
        dataset = ljspeech.LJSPEECH(Path(self.root_dir))
        self._test_ljspeech(dataset)