Unverified Commit 550b6a30 authored by Bhargav Kathivarapu's avatar Bhargav Kathivarapu Committed by GitHub
Browse files

Add pathlib support for SPEECHCOMMANDS (#1039)

parent 619da1f2
import os import os
from pathlib import Path
from torchaudio.datasets import speechcommands from torchaudio.datasets import speechcommands
...@@ -105,85 +106,39 @@ class TestSpeechCommands(TempDirMixin, TorchaudioTestCase): ...@@ -105,85 +106,39 @@ class TestSpeechCommands(TempDirMixin, TorchaudioTestCase):
test.write(f'{label}/{filename}\n') test.write(f'{label}/{filename}\n')
cls.test_samples.append(sample) cls.test_samples.append(sample)
def testSpeechCommands(self): def _testSpeechCommands(self, dataset, data_samples):
dataset = speechcommands.SPEECHCOMMANDS(self.root_dir)
num_samples = 0 num_samples = 0
for i, (data, sample_rate, label, speaker_id, utterance_number) in enumerate( for i, (data, sample_rate, label, speaker_id, utterance_number) in enumerate(
dataset dataset
): ):
self.assertEqual(data, self.samples[i][0], atol=5e-5, rtol=1e-8) self.assertEqual(data, data_samples[i][0], atol=5e-5, rtol=1e-8)
assert sample_rate == self.samples[i][1] assert sample_rate == data_samples[i][1]
assert label == self.samples[i][2] assert label == data_samples[i][2]
assert speaker_id == self.samples[i][3] assert speaker_id == data_samples[i][3]
assert utterance_number == self.samples[i][4] assert utterance_number == data_samples[i][4]
num_samples += 1 num_samples += 1
assert num_samples == len(self.samples) assert num_samples == len(data_samples)
def testSpeechCommandsNone(self): def testSpeechCommands_str(self):
dataset = speechcommands.SPEECHCOMMANDS(self.root_dir, subset=None) dataset = speechcommands.SPEECHCOMMANDS(self.root_dir)
self._testSpeechCommands(dataset, self.samples)
num_samples = 0
for i, (data, sample_rate, label, speaker_id, utterance_number) in enumerate(
dataset
):
self.assertEqual(data, self.samples[i][0], atol=5e-5, rtol=1e-8)
assert sample_rate == self.samples[i][1]
assert label == self.samples[i][2]
assert speaker_id == self.samples[i][3]
assert utterance_number == self.samples[i][4]
num_samples += 1
assert num_samples == len(self.samples) def testSpeechCommands_path(self):
dataset = speechcommands.SPEECHCOMMANDS(Path(self.root_dir))
self._testSpeechCommands(dataset, self.samples)
def testSpeechCommandsSubsetTrain(self): def testSpeechCommandsSubsetTrain(self):
dataset = speechcommands.SPEECHCOMMANDS(self.root_dir, subset="training") dataset = speechcommands.SPEECHCOMMANDS(self.root_dir, subset="training")
self._testSpeechCommands(dataset, self.train_samples)
num_samples = 0
for i, (data, sample_rate, label, speaker_id, utterance_number) in enumerate(
dataset
):
self.assertEqual(data, self.train_samples[i][0], atol=5e-5, rtol=1e-8)
assert sample_rate == self.train_samples[i][1]
assert label == self.train_samples[i][2]
assert speaker_id == self.train_samples[i][3]
assert utterance_number == self.train_samples[i][4]
num_samples += 1
assert num_samples == len(self.train_samples)
def testSpeechCommandsSubsetValid(self): def testSpeechCommandsSubsetValid(self):
dataset = speechcommands.SPEECHCOMMANDS(self.root_dir, subset="validation") dataset = speechcommands.SPEECHCOMMANDS(self.root_dir, subset="validation")
self._testSpeechCommands(dataset, self.valid_samples)
num_samples = 0
for i, (data, sample_rate, label, speaker_id, utterance_number) in enumerate(
dataset
):
self.assertEqual(data, self.valid_samples[i][0], atol=5e-5, rtol=1e-8)
assert sample_rate == self.valid_samples[i][1]
assert label == self.valid_samples[i][2]
assert speaker_id == self.valid_samples[i][3]
assert utterance_number == self.valid_samples[i][4]
num_samples += 1
assert num_samples == len(self.valid_samples)
def testSpeechCommandsSubsetTest(self): def testSpeechCommandsSubsetTest(self):
dataset = speechcommands.SPEECHCOMMANDS(self.root_dir, subset="testing") dataset = speechcommands.SPEECHCOMMANDS(self.root_dir, subset="testing")
self._testSpeechCommands(dataset, self.test_samples)
num_samples = 0
for i, (data, sample_rate, label, speaker_id, utterance_number) in enumerate(
dataset
):
self.assertEqual(data, self.test_samples[i][0], atol=5e-5, rtol=1e-8)
assert sample_rate == self.test_samples[i][1]
assert label == self.test_samples[i][2]
assert speaker_id == self.test_samples[i][3]
assert utterance_number == self.test_samples[i][4]
num_samples += 1
assert num_samples == len(self.test_samples)
def testSpeechCommandsSum(self): def testSpeechCommandsSum(self):
dataset_all = speechcommands.SPEECHCOMMANDS(self.root_dir) dataset_all = speechcommands.SPEECHCOMMANDS(self.root_dir)
......
import os import os
from typing import Tuple, Optional from typing import Tuple, Optional, Union
from pathlib import Path
import torchaudio import torchaudio
from torch.utils.data import Dataset from torch.utils.data import Dataset
...@@ -48,7 +49,7 @@ class SPEECHCOMMANDS(Dataset): ...@@ -48,7 +49,7 @@ class SPEECHCOMMANDS(Dataset):
"""Create a Dataset for Speech Commands. """Create a Dataset for Speech Commands.
Args: Args:
root (str): Path to the directory where the dataset is found or downloaded. root (str or Path): Path to the directory where the dataset is found or downloaded.
url (str, optional): The URL to download the dataset from, url (str, optional): The URL to download the dataset from,
or the type of the dataset to dowload. or the type of the dataset to dowload.
Allowed type values are ``"speech_commands_v0.01"`` and ``"speech_commands_v0.02"`` Allowed type values are ``"speech_commands_v0.01"`` and ``"speech_commands_v0.02"``
...@@ -64,7 +65,7 @@ class SPEECHCOMMANDS(Dataset): ...@@ -64,7 +65,7 @@ class SPEECHCOMMANDS(Dataset):
""" """
def __init__(self, def __init__(self,
root: str, root: Union[str, Path],
url: str = URL, url: str = URL,
folder_in_archive: str = FOLDER_IN_ARCHIVE, folder_in_archive: str = FOLDER_IN_ARCHIVE,
download: bool = False, download: bool = False,
...@@ -85,6 +86,9 @@ class SPEECHCOMMANDS(Dataset): ...@@ -85,6 +86,9 @@ class SPEECHCOMMANDS(Dataset):
url = os.path.join(base_url, url + ext_archive) url = os.path.join(base_url, url + ext_archive)
# Get string representation of 'root' in case Path object is passed
root = os.fspath(root)
basename = os.path.basename(url) basename = os.path.basename(url)
archive = os.path.join(root, basename) archive = os.path.join(root, basename)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment