Unverified Commit 5630fe35 authored by Kshiteej K's avatar Kshiteej K Committed by GitHub
Browse files

Add pathlib.Path support to `cmuarctic` (#1025)

parent 8d0c08db
import os import os
from pathlib import Path
from torchaudio.datasets import cmuarctic from torchaudio.datasets import cmuarctic
...@@ -54,8 +55,7 @@ class TestCMUARCTIC(TempDirMixin, TorchaudioTestCase): ...@@ -54,8 +55,7 @@ class TestCMUARCTIC(TempDirMixin, TorchaudioTestCase):
txt.write(f'( {utterance_id} "{utterance}" )\n') txt.write(f'( {utterance_id} "{utterance}" )\n')
seed += 1 seed += 1
def test_cmuarctic(self): def _test_cmuarctic(self, dataset):
dataset = cmuarctic.CMUARCTIC(self.root_dir)
n_ite = 0 n_ite = 0
for i, (waveform, sample_rate, utterance, utterance_id) in enumerate(dataset): for i, (waveform, sample_rate, utterance, utterance_id) in enumerate(dataset):
expected_sample = self.samples[i] expected_sample = self.samples[i]
...@@ -65,3 +65,11 @@ class TestCMUARCTIC(TempDirMixin, TorchaudioTestCase): ...@@ -65,3 +65,11 @@ class TestCMUARCTIC(TempDirMixin, TorchaudioTestCase):
self.assertEqual(expected_sample[0], waveform, atol=5e-5, rtol=1e-8) self.assertEqual(expected_sample[0], waveform, atol=5e-5, rtol=1e-8)
n_ite += 1 n_ite += 1
assert n_ite == len(self.samples) assert n_ite == len(self.samples)
def test_cmuarctic_str(self):
dataset = cmuarctic.CMUARCTIC(self.root_dir)
self._test_cmuarctic(dataset)
def test_cmuarctic_path(self):
dataset = cmuarctic.CMUARCTIC(Path(self.root_dir))
self._test_cmuarctic(dataset)
import os import os
from typing import Tuple from pathlib import Path
from typing import Tuple, Union
import torchaudio import torchaudio
from torch import Tensor from torch import Tensor
...@@ -79,7 +80,7 @@ class CMUARCTIC(Dataset): ...@@ -79,7 +80,7 @@ class CMUARCTIC(Dataset):
"""Create a Dataset for CMU_ARCTIC. """Create a Dataset for CMU_ARCTIC.
Args: Args:
root (str): Path to the directory where the dataset is found or downloaded. root (str or Path): Path to the directory where the dataset is found or downloaded.
url (str, optional): url (str, optional):
The URL to download the dataset from or the type of the dataset to download. The URL to download the dataset from or the type of the dataset to download.
(default: ``"aew"``) (default: ``"aew"``)
...@@ -98,7 +99,7 @@ class CMUARCTIC(Dataset): ...@@ -98,7 +99,7 @@ class CMUARCTIC(Dataset):
_folder_audio = "wav" _folder_audio = "wav"
def __init__(self, def __init__(self,
root: str, root: Union[str, Path],
url: str = URL, url: str = URL,
folder_in_archive: str = FOLDER_IN_ARCHIVE, folder_in_archive: str = FOLDER_IN_ARCHIVE,
download: bool = False) -> None: download: bool = False) -> None:
...@@ -130,6 +131,9 @@ class CMUARCTIC(Dataset): ...@@ -130,6 +131,9 @@ class CMUARCTIC(Dataset):
url = os.path.join(base_url, url + ext_archive) url = os.path.join(base_url, url + ext_archive)
# Get string representation of 'root' in case Path object is passed
root = os.fspath(root)
basename = os.path.basename(url) basename = os.path.basename(url)
root = os.path.join(root, folder_in_archive) root = os.path.join(root, folder_in_archive)
if not os.path.isdir(root): if not os.path.isdir(root):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment