Unverified Commit b5c16d33 authored by Bhargav Kathivarapu's avatar Bhargav Kathivarapu Committed by GitHub
Browse files

Add pathlib support for LIBRITTS and LIBRISPEECH (#1046)

parent 37b4e136
import os import os
from pathlib import Path
from torchaudio.datasets import librispeech from torchaudio.datasets import librispeech
...@@ -91,11 +92,7 @@ class TestLibriSpeech(TempDirMixin, TorchaudioTestCase): ...@@ -91,11 +92,7 @@ class TestLibriSpeech(TempDirMixin, TorchaudioTestCase):
# In case of test failure # In case of test failure
librispeech.LIBRISPEECH._ext_audio = '.flac' librispeech.LIBRISPEECH._ext_audio = '.flac'
def test_librispeech(self): def _test_librispeech(self, dataset):
librispeech.LIBRISPEECH._ext_audio = '.wav'
dataset = librispeech.LIBRISPEECH(self.root_dir)
print(dataset._path)
num_samples = 0 num_samples = 0
for i, ( for i, (
data, sample_rate, utterance, speaker_id, chapter_id, utterance_id data, sample_rate, utterance, speaker_id, chapter_id, utterance_id
...@@ -110,3 +107,13 @@ class TestLibriSpeech(TempDirMixin, TorchaudioTestCase): ...@@ -110,3 +107,13 @@ class TestLibriSpeech(TempDirMixin, TorchaudioTestCase):
assert num_samples == len(self.samples) assert num_samples == len(self.samples)
librispeech.LIBRISPEECH._ext_audio = '.flac' librispeech.LIBRISPEECH._ext_audio = '.flac'
def test_librispeech_str(self):
librispeech.LIBRISPEECH._ext_audio = '.wav'
dataset = librispeech.LIBRISPEECH(self.root_dir)
self._test_librispeech(dataset)
def test_librispeech_path(self):
librispeech.LIBRISPEECH._ext_audio = '.wav'
dataset = librispeech.LIBRISPEECH(Path(self.root_dir))
self._test_librispeech(dataset)
import os import os
from pathlib import Path
from torchaudio.datasets.libritts import LIBRITTS from torchaudio.datasets.libritts import LIBRITTS
...@@ -47,8 +48,7 @@ class TestLibriTTS(TempDirMixin, TorchaudioTestCase): ...@@ -47,8 +48,7 @@ class TestLibriTTS(TempDirMixin, TorchaudioTestCase):
with open(path_normalized, 'w') as file_: with open(path_normalized, 'w') as file_:
file_.write(cls.normalized_text) file_.write(cls.normalized_text)
def test_libritts(self): def _test_libritts(self, dataset):
dataset = LIBRITTS(self.root_dir)
n_ites = 0 n_ites = 0
for i, (waveform, for i, (waveform,
sample_rate, sample_rate,
...@@ -69,3 +69,11 @@ class TestLibriTTS(TempDirMixin, TorchaudioTestCase): ...@@ -69,3 +69,11 @@ class TestLibriTTS(TempDirMixin, TorchaudioTestCase):
assert utterance_id == f'{"_".join(str(u) for u in expected_ids[-4:])}' assert utterance_id == f'{"_".join(str(u) for u in expected_ids[-4:])}'
n_ites += 1 n_ites += 1
assert n_ites == len(self.utterance_ids) assert n_ites == len(self.utterance_ids)
def test_libritts_str(self):
dataset = LIBRITTS(self.root_dir)
self._test_libritts(dataset)
def test_libritts_path(self):
dataset = LIBRITTS(Path(self.root_dir))
self._test_libritts(dataset)
import os import os
from typing import Tuple from typing import Tuple, Union
from pathlib import Path
import torchaudio import torchaudio
from torch import Tensor from torch import Tensor
...@@ -70,7 +71,7 @@ class LIBRISPEECH(Dataset): ...@@ -70,7 +71,7 @@ class LIBRISPEECH(Dataset):
"""Create a Dataset for LibriSpeech. """Create a Dataset for LibriSpeech.
Args: Args:
root (str): Path to the directory where the dataset is found or downloaded. root (str or Path): Path to the directory where the dataset is found or downloaded.
url (str, optional): The URL to download the dataset from, url (str, optional): The URL to download the dataset from,
or the type of the dataset to dowload. or the type of the dataset to dowload.
Allowed type values are ``"dev-clean"``, ``"dev-other"``, ``"test-clean"``, Allowed type values are ``"dev-clean"``, ``"dev-other"``, ``"test-clean"``,
...@@ -86,7 +87,7 @@ class LIBRISPEECH(Dataset): ...@@ -86,7 +87,7 @@ class LIBRISPEECH(Dataset):
_ext_audio = ".flac" _ext_audio = ".flac"
def __init__(self, def __init__(self,
root: str, root: Union[str, Path],
url: str = URL, url: str = URL,
folder_in_archive: str = FOLDER_IN_ARCHIVE, folder_in_archive: str = FOLDER_IN_ARCHIVE,
download: bool = False) -> None: download: bool = False) -> None:
...@@ -106,6 +107,9 @@ class LIBRISPEECH(Dataset): ...@@ -106,6 +107,9 @@ class LIBRISPEECH(Dataset):
url = os.path.join(base_url, url + ext_archive) url = os.path.join(base_url, url + ext_archive)
# Get string representation of 'root' in case Path object is passed
root = os.fspath(root)
basename = os.path.basename(url) basename = os.path.basename(url)
archive = os.path.join(root, basename) archive = os.path.join(root, basename)
......
import os import os
from typing import Tuple from typing import Tuple, Union
from pathlib import Path
import torchaudio import torchaudio
from torch import Tensor from torch import Tensor
...@@ -68,7 +69,7 @@ class LIBRITTS(Dataset): ...@@ -68,7 +69,7 @@ class LIBRITTS(Dataset):
"""Create a Dataset for LibriTTS. """Create a Dataset for LibriTTS.
Args: Args:
root (str): Path to the directory where the dataset is found or downloaded. root (str or Path): Path to the directory where the dataset is found or downloaded.
url (str, optional): The URL to download the dataset from, url (str, optional): The URL to download the dataset from,
or the type of the dataset to dowload. or the type of the dataset to dowload.
Allowed type values are ``"dev-clean"``, ``"dev-other"``, ``"test-clean"``, Allowed type values are ``"dev-clean"``, ``"dev-other"``, ``"test-clean"``,
...@@ -86,7 +87,7 @@ class LIBRITTS(Dataset): ...@@ -86,7 +87,7 @@ class LIBRITTS(Dataset):
def __init__( def __init__(
self, self,
root: str, root: Union[str, Path],
url: str = URL, url: str = URL,
folder_in_archive: str = FOLDER_IN_ARCHIVE, folder_in_archive: str = FOLDER_IN_ARCHIVE,
download: bool = False, download: bool = False,
...@@ -107,6 +108,9 @@ class LIBRITTS(Dataset): ...@@ -107,6 +108,9 @@ class LIBRITTS(Dataset):
url = os.path.join(base_url, url + ext_archive) url = os.path.join(base_url, url + ext_archive)
# Get string representation of 'root' in case Path object is passed
root = os.fspath(root)
basename = os.path.basename(url) basename = os.path.basename(url)
archive = os.path.join(root, basename) archive = os.path.join(root, basename)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment