Unverified Commit edaeda4f authored by Kshiteej K's avatar Kshiteej K Committed by GitHub
Browse files

Add pathlib.Path support to `gtzan` (#1032)


Co-authored-by: default avatarVincent QB <vincentqb@users.noreply.github.com>
parent 55175003
import os
from pathlib import Path
from torchaudio.datasets import gtzan
......@@ -54,9 +55,7 @@ class TestGTZAN(TempDirMixin, TorchaudioTestCase):
n_ite += 1
assert n_ite == len(self.samples)
def test_training(self):
dataset = gtzan.GTZAN(self.root_dir, subset='training')
def _test_training(self, dataset):
n_ite = 0
for i, (waveform, sample_rate, label) in enumerate(dataset):
self.assertEqual(waveform, self.training[i][0], atol=5e-5, rtol=1e-8)
......@@ -65,9 +64,7 @@ class TestGTZAN(TempDirMixin, TorchaudioTestCase):
n_ite += 1
assert n_ite == len(self.training)
def test_validation(self):
dataset = gtzan.GTZAN(self.root_dir, subset='validation')
def _test_validation(self, dataset):
n_ite = 0
for i, (waveform, sample_rate, label) in enumerate(dataset):
self.assertEqual(waveform, self.validation[i][0], atol=5e-5, rtol=1e-8)
......@@ -76,9 +73,7 @@ class TestGTZAN(TempDirMixin, TorchaudioTestCase):
n_ite += 1
assert n_ite == len(self.validation)
def test_testing(self):
dataset = gtzan.GTZAN(self.root_dir, subset='testing')
def _test_testing(self, dataset):
n_ite = 0
for i, (waveform, sample_rate, label) in enumerate(dataset):
self.assertEqual(waveform, self.testing[i][0], atol=5e-5, rtol=1e-8)
......@@ -86,3 +81,30 @@ class TestGTZAN(TempDirMixin, TorchaudioTestCase):
assert label == self.testing[i][2]
n_ite += 1
assert n_ite == len(self.testing)
def test_training_str(self):
train_dataset = gtzan.GTZAN(self.root_dir, subset='training')
self._test_training(train_dataset)
def test_validation_str(self):
val_dataset = gtzan.GTZAN(self.root_dir, subset='validation')
self._test_validation(val_dataset)
def test_testing_str(self):
test_dataset = gtzan.GTZAN(self.root_dir, subset='testing')
self._test_testing(test_dataset)
def test_training_path(self):
root_dir = Path(self.root_dir)
train_dataset = gtzan.GTZAN(root_dir, subset='training')
self._test_training(train_dataset)
def test_validation_path(self):
root_dir = Path(self.root_dir)
val_dataset = gtzan.GTZAN(root_dir, subset='validation')
self._test_validation(val_dataset)
def test_testing_path(self):
root_dir = Path(self.root_dir)
test_dataset = gtzan.GTZAN(root_dir, subset='testing')
self._test_testing(test_dataset)
import os
import warnings
from typing import Any, Tuple, Optional
from pathlib import Path
from typing import Any, Tuple, Optional, Union
import torchaudio
from torch import Tensor
......@@ -1005,7 +1006,7 @@ class GTZAN(Dataset):
this dataset to publish results.
Args:
root (str): Path to the directory where the dataset is found or downloaded.
root (str or Path): Path to the directory where the dataset is found or downloaded.
url (str, optional): The URL to download the dataset from.
(default: ``"http://opihi.cs.uvic.ca/sound/genres.tar.gz"``)
folder_in_archive (str, optional): The top-level directory of the dataset.
......@@ -1020,7 +1021,7 @@ class GTZAN(Dataset):
def __init__(
self,
root: str,
root: Union[str, Path],
url: str = URL,
folder_in_archive: str = FOLDER_IN_ARCHIVE,
download: bool = False,
......@@ -1028,6 +1029,10 @@ class GTZAN(Dataset):
) -> None:
# super(GTZAN, self).__init__()
# Get string representation of 'root' in case Path object is passed
root = os.fspath(root)
self.root = root
self.url = url
self.folder_in_archive = folder_in_archive
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment