Unverified commit 55175003, authored by Bhargav Kathivarapu, committed by GitHub
Browse files

Pathlib support for VCTK and LJSPEECH (#1028)

parent 0cf4b8a9
import csv import csv
import os import os
from pathlib import Path
from torchaudio.datasets import ljspeech from torchaudio.datasets import ljspeech
...@@ -57,8 +58,7 @@ class TestLJSpeech(TempDirMixin, TorchaudioTestCase): ...@@ -57,8 +58,7 @@ class TestLJSpeech(TempDirMixin, TorchaudioTestCase):
save_wav(path, data, sample_rate) save_wav(path, data, sample_rate)
cls.data.append(normalize_wav(data)) cls.data.append(normalize_wav(data))
def test_ljspeech(self): def _test_ljspeech(self, dataset):
dataset = ljspeech.LJSPEECH(self.root_dir)
n_ite = 0 n_ite = 0
for i, (waveform, sample_rate, transcript, normalized_transcript) in enumerate( for i, (waveform, sample_rate, transcript, normalized_transcript) in enumerate(
dataset dataset
...@@ -72,3 +72,11 @@ class TestLJSpeech(TempDirMixin, TorchaudioTestCase): ...@@ -72,3 +72,11 @@ class TestLJSpeech(TempDirMixin, TorchaudioTestCase):
assert normalized_transcript == expected_normalized_transcript assert normalized_transcript == expected_normalized_transcript
n_ite += 1 n_ite += 1
assert n_ite == len(self.data) assert n_ite == len(self.data)
def test_ljspeech_str(self):
dataset = ljspeech.LJSPEECH(self.root_dir)
self._test_ljspeech(dataset)
def test_ljspeech_path(self):
dataset = ljspeech.LJSPEECH(Path(self.root_dir))
self._test_ljspeech(dataset)
import os import os
from pathlib import Path
from torchaudio.datasets import vctk from torchaudio.datasets import vctk
...@@ -77,8 +78,7 @@ class TestVCTK(TempDirMixin, TorchaudioTestCase): ...@@ -77,8 +78,7 @@ class TestVCTK(TempDirMixin, TorchaudioTestCase):
seed += 1 seed += 1
def test_vctk(self): def _test_vctk(self, dataset):
dataset = vctk.VCTK_092(self.root_dir, audio_ext=".wav")
num_samples = 0 num_samples = 0
for i, (data, sample_rate, utterance, speaker_id, utterance_id) in enumerate(dataset): for i, (data, sample_rate, utterance, speaker_id, utterance_id) in enumerate(dataset):
self.assertEqual(data, self.samples[i][0], atol=5e-5, rtol=1e-8) self.assertEqual(data, self.samples[i][0], atol=5e-5, rtol=1e-8)
...@@ -89,3 +89,11 @@ class TestVCTK(TempDirMixin, TorchaudioTestCase): ...@@ -89,3 +89,11 @@ class TestVCTK(TempDirMixin, TorchaudioTestCase):
num_samples += 1 num_samples += 1
assert num_samples == len(self.samples) assert num_samples == len(self.samples)
def test_vctk_str(self):
dataset = vctk.VCTK_092(self.root_dir, audio_ext=".wav")
self._test_vctk(dataset)
def test_vctk_path(self):
dataset = vctk.VCTK_092(Path(self.root_dir), audio_ext=".wav")
self._test_vctk(dataset)
import os import os
import csv import csv
from typing import List, Tuple from typing import List, Tuple, Union
from pathlib import Path
import torchaudio import torchaudio
from torchaudio.datasets.utils import download_url, extract_archive, unicode_csv_reader from torchaudio.datasets.utils import download_url, extract_archive, unicode_csv_reader
...@@ -36,7 +37,7 @@ class LJSPEECH(Dataset): ...@@ -36,7 +37,7 @@ class LJSPEECH(Dataset):
"""Create a Dataset for LJSpeech-1.1. """Create a Dataset for LJSpeech-1.1.
Args: Args:
root (str): Path to the directory where the dataset is found or downloaded. root (str or Path): Path to the directory where the dataset is found or downloaded.
url (str, optional): The URL to download the dataset from. url (str, optional): The URL to download the dataset from.
(default: ``"https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2"``) (default: ``"https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2"``)
folder_in_archive (str, optional): folder_in_archive (str, optional):
...@@ -49,11 +50,14 @@ class LJSPEECH(Dataset): ...@@ -49,11 +50,14 @@ class LJSPEECH(Dataset):
_ext_archive = '.tar.bz2' _ext_archive = '.tar.bz2'
def __init__(self, def __init__(self,
root: str, root: Union[str, Path],
url: str = URL, url: str = URL,
folder_in_archive: str = FOLDER_IN_ARCHIVE, folder_in_archive: str = FOLDER_IN_ARCHIVE,
download: bool = False) -> None: download: bool = False) -> None:
# Get string representation of 'root' in case Path object is passed
root = os.fspath(root)
basename = os.path.basename(url) basename = os.path.basename(url)
archive = os.path.join(root, basename) archive = os.path.join(root, basename)
......
import os import os
import warnings import warnings
from typing import Any, Tuple from typing import Any, Tuple, Union
from pathlib import Path
import torchaudio import torchaudio
from torch import Tensor from torch import Tensor
...@@ -57,7 +58,7 @@ class VCTK(Dataset): ...@@ -57,7 +58,7 @@ class VCTK(Dataset):
For more information about the dataset visit: https://datashare.is.ed.ac.uk/handle/10283/3443 For more information about the dataset visit: https://datashare.is.ed.ac.uk/handle/10283/3443
Args: Args:
root (str): Path to the directory where the dataset is found or downloaded. root (str or Path): Path to the directory where the dataset is found or downloaded.
url (str, optional): Not used as the dataset is no longer publicly available. url (str, optional): Not used as the dataset is no longer publicly available.
folder_in_archive (str, optional): folder_in_archive (str, optional):
The top-level directory of the dataset. (default: ``"VCTK-Corpus"``) The top-level directory of the dataset. (default: ``"VCTK-Corpus"``)
...@@ -77,7 +78,7 @@ class VCTK(Dataset): ...@@ -77,7 +78,7 @@ class VCTK(Dataset):
_except_folder = "p315" _except_folder = "p315"
def __init__(self, def __init__(self,
root: str, root: Union[str, Path],
url: str = URL, url: str = URL,
folder_in_archive: str = FOLDER_IN_ARCHIVE, folder_in_archive: str = FOLDER_IN_ARCHIVE,
download: bool = False, download: bool = False,
...@@ -103,6 +104,9 @@ class VCTK(Dataset): ...@@ -103,6 +104,9 @@ class VCTK(Dataset):
self.transform = transform self.transform = transform
self.target_transform = target_transform self.target_transform = target_transform
# Get string representation of 'root' in case Path object is passed
root = os.fspath(root)
archive = os.path.basename(url) archive = os.path.basename(url)
archive = os.path.join(root, archive) archive = os.path.join(root, archive)
self._path = os.path.join(root, folder_in_archive) self._path = os.path.join(root, folder_in_archive)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment