Commit edaeda4f (unverified), authored by Kshiteej K, committed by GitHub
Browse files

Add pathlib.Path support to `gtzan` (#1032)


Co-authored-by: Vincent QB <vincentqb@users.noreply.github.com>
parent 55175003
import os import os
from pathlib import Path
from torchaudio.datasets import gtzan from torchaudio.datasets import gtzan
...@@ -54,9 +55,7 @@ class TestGTZAN(TempDirMixin, TorchaudioTestCase): ...@@ -54,9 +55,7 @@ class TestGTZAN(TempDirMixin, TorchaudioTestCase):
n_ite += 1 n_ite += 1
assert n_ite == len(self.samples) assert n_ite == len(self.samples)
def test_training(self): def _test_training(self, dataset):
dataset = gtzan.GTZAN(self.root_dir, subset='training')
n_ite = 0 n_ite = 0
for i, (waveform, sample_rate, label) in enumerate(dataset): for i, (waveform, sample_rate, label) in enumerate(dataset):
self.assertEqual(waveform, self.training[i][0], atol=5e-5, rtol=1e-8) self.assertEqual(waveform, self.training[i][0], atol=5e-5, rtol=1e-8)
...@@ -65,9 +64,7 @@ class TestGTZAN(TempDirMixin, TorchaudioTestCase): ...@@ -65,9 +64,7 @@ class TestGTZAN(TempDirMixin, TorchaudioTestCase):
n_ite += 1 n_ite += 1
assert n_ite == len(self.training) assert n_ite == len(self.training)
def test_validation(self): def _test_validation(self, dataset):
dataset = gtzan.GTZAN(self.root_dir, subset='validation')
n_ite = 0 n_ite = 0
for i, (waveform, sample_rate, label) in enumerate(dataset): for i, (waveform, sample_rate, label) in enumerate(dataset):
self.assertEqual(waveform, self.validation[i][0], atol=5e-5, rtol=1e-8) self.assertEqual(waveform, self.validation[i][0], atol=5e-5, rtol=1e-8)
...@@ -76,9 +73,7 @@ class TestGTZAN(TempDirMixin, TorchaudioTestCase): ...@@ -76,9 +73,7 @@ class TestGTZAN(TempDirMixin, TorchaudioTestCase):
n_ite += 1 n_ite += 1
assert n_ite == len(self.validation) assert n_ite == len(self.validation)
def test_testing(self): def _test_testing(self, dataset):
dataset = gtzan.GTZAN(self.root_dir, subset='testing')
n_ite = 0 n_ite = 0
for i, (waveform, sample_rate, label) in enumerate(dataset): for i, (waveform, sample_rate, label) in enumerate(dataset):
self.assertEqual(waveform, self.testing[i][0], atol=5e-5, rtol=1e-8) self.assertEqual(waveform, self.testing[i][0], atol=5e-5, rtol=1e-8)
...@@ -86,3 +81,30 @@ class TestGTZAN(TempDirMixin, TorchaudioTestCase): ...@@ -86,3 +81,30 @@ class TestGTZAN(TempDirMixin, TorchaudioTestCase):
assert label == self.testing[i][2] assert label == self.testing[i][2]
n_ite += 1 n_ite += 1
assert n_ite == len(self.testing) assert n_ite == len(self.testing)
def test_training_str(self):
train_dataset = gtzan.GTZAN(self.root_dir, subset='training')
self._test_training(train_dataset)
def test_validation_str(self):
val_dataset = gtzan.GTZAN(self.root_dir, subset='validation')
self._test_validation(val_dataset)
def test_testing_str(self):
test_dataset = gtzan.GTZAN(self.root_dir, subset='testing')
self._test_testing(test_dataset)
def test_training_path(self):
root_dir = Path(self.root_dir)
train_dataset = gtzan.GTZAN(root_dir, subset='training')
self._test_training(train_dataset)
def test_validation_path(self):
root_dir = Path(self.root_dir)
val_dataset = gtzan.GTZAN(root_dir, subset='validation')
self._test_validation(val_dataset)
def test_testing_path(self):
root_dir = Path(self.root_dir)
test_dataset = gtzan.GTZAN(root_dir, subset='testing')
self._test_testing(test_dataset)
import os import os
import warnings import warnings
from typing import Any, Tuple, Optional from pathlib import Path
from typing import Any, Tuple, Optional, Union
import torchaudio import torchaudio
from torch import Tensor from torch import Tensor
...@@ -1005,7 +1006,7 @@ class GTZAN(Dataset): ...@@ -1005,7 +1006,7 @@ class GTZAN(Dataset):
this dataset to publish results. this dataset to publish results.
Args: Args:
root (str): Path to the directory where the dataset is found or downloaded. root (str or Path): Path to the directory where the dataset is found or downloaded.
url (str, optional): The URL to download the dataset from. url (str, optional): The URL to download the dataset from.
(default: ``"http://opihi.cs.uvic.ca/sound/genres.tar.gz"``) (default: ``"http://opihi.cs.uvic.ca/sound/genres.tar.gz"``)
folder_in_archive (str, optional): The top-level directory of the dataset. folder_in_archive (str, optional): The top-level directory of the dataset.
...@@ -1020,7 +1021,7 @@ class GTZAN(Dataset): ...@@ -1020,7 +1021,7 @@ class GTZAN(Dataset):
def __init__( def __init__(
self, self,
root: str, root: Union[str, Path],
url: str = URL, url: str = URL,
folder_in_archive: str = FOLDER_IN_ARCHIVE, folder_in_archive: str = FOLDER_IN_ARCHIVE,
download: bool = False, download: bool = False,
...@@ -1028,6 +1029,10 @@ class GTZAN(Dataset): ...@@ -1028,6 +1029,10 @@ class GTZAN(Dataset):
) -> None: ) -> None:
# super(GTZAN, self).__init__() # super(GTZAN, self).__init__()
# Get string representation of 'root' in case Path object is passed
root = os.fspath(root)
self.root = root self.root = root
self.url = url self.url = url
self.folder_in_archive = folder_in_archive self.folder_in_archive = folder_in_archive
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment