Unverified Commit b9ee0139 authored by Bhargav Kathivarapu's avatar Bhargav Kathivarapu Committed by GitHub
Browse files

YesNo Dataset Pathlib change (#1015)


Co-authored-by: default avatarVincent QB <vincentqb@users.noreply.github.com>
parent 5630fe35
import os import os
from pathlib import Path
from torchaudio.datasets import yesno from torchaudio.datasets import yesno
...@@ -36,8 +37,7 @@ class TestYesNo(TempDirMixin, TorchaudioTestCase): ...@@ -36,8 +37,7 @@ class TestYesNo(TempDirMixin, TorchaudioTestCase):
save_wav(path, data, 8000) save_wav(path, data, 8000)
cls.data.append(normalize_wav(data)) cls.data.append(normalize_wav(data))
def test_yesno(self): def _test_yesno(self, dataset):
dataset = yesno.YESNO(self.root_dir)
n_ite = 0 n_ite = 0
for i, (waveform, sample_rate, label) in enumerate(dataset): for i, (waveform, sample_rate, label) in enumerate(dataset):
expected_label = self.labels[i] expected_label = self.labels[i]
...@@ -47,3 +47,11 @@ class TestYesNo(TempDirMixin, TorchaudioTestCase): ...@@ -47,3 +47,11 @@ class TestYesNo(TempDirMixin, TorchaudioTestCase):
assert label == expected_label assert label == expected_label
n_ite += 1 n_ite += 1
assert n_ite == len(self.data) assert n_ite == len(self.data)
def test_yesno_str(self):
dataset = yesno.YESNO(self.root_dir)
self._test_yesno(dataset)
def test_yesno_path(self):
dataset = yesno.YESNO(Path(self.root_dir))
self._test_yesno(dataset)
import os import os
import warnings import warnings
from typing import Any, List, Tuple from typing import Any, List, Tuple, Union
from pathlib import Path
import torchaudio import torchaudio
from torch import Tensor from torch import Tensor
...@@ -34,7 +35,7 @@ class YESNO(Dataset): ...@@ -34,7 +35,7 @@ class YESNO(Dataset):
"""Create a Dataset for YesNo. """Create a Dataset for YesNo.
Args: Args:
root (str): Path to the directory where the dataset is found or downloaded. root (str or Path): Path to the directory where the dataset is found or downloaded.
url (str, optional): The URL to download the dataset from. url (str, optional): The URL to download the dataset from.
(default: ``"http://www.openslr.org/resources/1/waves_yesno.tar.gz"``) (default: ``"http://www.openslr.org/resources/1/waves_yesno.tar.gz"``)
folder_in_archive (str, optional): folder_in_archive (str, optional):
...@@ -48,7 +49,7 @@ class YESNO(Dataset): ...@@ -48,7 +49,7 @@ class YESNO(Dataset):
_ext_audio = ".wav" _ext_audio = ".wav"
def __init__(self, def __init__(self,
root: str, root: Union[str, Path],
url: str = URL, url: str = URL,
folder_in_archive: str = FOLDER_IN_ARCHIVE, folder_in_archive: str = FOLDER_IN_ARCHIVE,
download: bool = False, download: bool = False,
...@@ -65,6 +66,9 @@ class YESNO(Dataset): ...@@ -65,6 +66,9 @@ class YESNO(Dataset):
self.transform = transform self.transform = transform
self.target_transform = target_transform self.target_transform = target_transform
# Get string representation of 'root' in case Path object is passed
root = os.fspath(root)
archive = os.path.basename(url) archive = os.path.basename(url)
archive = os.path.join(root, archive) archive = os.path.join(root, archive)
self._path = os.path.join(root, folder_in_archive) self._path = os.path.join(root, folder_in_archive)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment