Unverified Commit b9ee0139 authored by Bhargav Kathivarapu's avatar Bhargav Kathivarapu Committed by GitHub
Browse files

YesNo Dataset Pathlib change (#1015)


Co-authored-by: default avatarVincent QB <vincentqb@users.noreply.github.com>
parent 5630fe35
import os
from pathlib import Path
from torchaudio.datasets import yesno
......@@ -36,8 +37,7 @@ class TestYesNo(TempDirMixin, TorchaudioTestCase):
save_wav(path, data, 8000)
cls.data.append(normalize_wav(data))
def test_yesno(self):
dataset = yesno.YESNO(self.root_dir)
def _test_yesno(self, dataset):
n_ite = 0
for i, (waveform, sample_rate, label) in enumerate(dataset):
expected_label = self.labels[i]
......@@ -47,3 +47,11 @@ class TestYesNo(TempDirMixin, TorchaudioTestCase):
assert label == expected_label
n_ite += 1
assert n_ite == len(self.data)
def test_yesno_str(self):
dataset = yesno.YESNO(self.root_dir)
self._test_yesno(dataset)
def test_yesno_path(self):
dataset = yesno.YESNO(Path(self.root_dir))
self._test_yesno(dataset)
import os
import warnings
from typing import Any, List, Tuple
from typing import Any, List, Tuple, Union
from pathlib import Path
import torchaudio
from torch import Tensor
......@@ -34,7 +35,7 @@ class YESNO(Dataset):
"""Create a Dataset for YesNo.
Args:
root (str): Path to the directory where the dataset is found or downloaded.
root (str or Path): Path to the directory where the dataset is found or downloaded.
url (str, optional): The URL to download the dataset from.
(default: ``"http://www.openslr.org/resources/1/waves_yesno.tar.gz"``)
folder_in_archive (str, optional):
......@@ -48,7 +49,7 @@ class YESNO(Dataset):
_ext_audio = ".wav"
def __init__(self,
root: str,
root: Union[str, Path],
url: str = URL,
folder_in_archive: str = FOLDER_IN_ARCHIVE,
download: bool = False,
......@@ -65,6 +66,9 @@ class YESNO(Dataset):
self.transform = transform
self.target_transform = target_transform
# Get string representation of 'root' in case Path object is passed
root = os.fspath(root)
archive = os.path.basename(url)
archive = os.path.join(root, archive)
self._path = os.path.join(root, folder_in_archive)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment