Unverified commit 550b6a30, authored by Bhargav Kathivarapu and committed by GitHub
Browse files

Add pathlib support for SPEECHCOMMANDS (#1039)

parent 619da1f2
import os
from pathlib import Path
from torchaudio.datasets import speechcommands
......@@ -105,85 +106,39 @@ class TestSpeechCommands(TempDirMixin, TorchaudioTestCase):
test.write(f'{label}/{filename}\n')
cls.test_samples.append(sample)
def testSpeechCommands(self):
dataset = speechcommands.SPEECHCOMMANDS(self.root_dir)
def _testSpeechCommands(self, dataset, data_samples):
num_samples = 0
for i, (data, sample_rate, label, speaker_id, utterance_number) in enumerate(
dataset
):
self.assertEqual(data, self.samples[i][0], atol=5e-5, rtol=1e-8)
assert sample_rate == self.samples[i][1]
assert label == self.samples[i][2]
assert speaker_id == self.samples[i][3]
assert utterance_number == self.samples[i][4]
self.assertEqual(data, data_samples[i][0], atol=5e-5, rtol=1e-8)
assert sample_rate == data_samples[i][1]
assert label == data_samples[i][2]
assert speaker_id == data_samples[i][3]
assert utterance_number == data_samples[i][4]
num_samples += 1
assert num_samples == len(self.samples)
assert num_samples == len(data_samples)
def testSpeechCommandsNone(self):
dataset = speechcommands.SPEECHCOMMANDS(self.root_dir, subset=None)
num_samples = 0
for i, (data, sample_rate, label, speaker_id, utterance_number) in enumerate(
dataset
):
self.assertEqual(data, self.samples[i][0], atol=5e-5, rtol=1e-8)
assert sample_rate == self.samples[i][1]
assert label == self.samples[i][2]
assert speaker_id == self.samples[i][3]
assert utterance_number == self.samples[i][4]
num_samples += 1
def testSpeechCommands_str(self):
dataset = speechcommands.SPEECHCOMMANDS(self.root_dir)
self._testSpeechCommands(dataset, self.samples)
assert num_samples == len(self.samples)
def testSpeechCommands_path(self):
    """Check the full dataset when ``root`` is passed as a ``pathlib.Path``."""
    # Wrap the root directory in a Path object before handing it to the dataset.
    root_as_path = Path(self.root_dir)
    self._testSpeechCommands(
        speechcommands.SPEECHCOMMANDS(root_as_path),
        self.samples,
    )
def testSpeechCommandsSubsetTrain(self):
dataset = speechcommands.SPEECHCOMMANDS(self.root_dir, subset="training")
num_samples = 0
for i, (data, sample_rate, label, speaker_id, utterance_number) in enumerate(
dataset
):
self.assertEqual(data, self.train_samples[i][0], atol=5e-5, rtol=1e-8)
assert sample_rate == self.train_samples[i][1]
assert label == self.train_samples[i][2]
assert speaker_id == self.train_samples[i][3]
assert utterance_number == self.train_samples[i][4]
num_samples += 1
assert num_samples == len(self.train_samples)
self._testSpeechCommands(dataset, self.train_samples)
def testSpeechCommandsSubsetValid(self):
dataset = speechcommands.SPEECHCOMMANDS(self.root_dir, subset="validation")
num_samples = 0
for i, (data, sample_rate, label, speaker_id, utterance_number) in enumerate(
dataset
):
self.assertEqual(data, self.valid_samples[i][0], atol=5e-5, rtol=1e-8)
assert sample_rate == self.valid_samples[i][1]
assert label == self.valid_samples[i][2]
assert speaker_id == self.valid_samples[i][3]
assert utterance_number == self.valid_samples[i][4]
num_samples += 1
assert num_samples == len(self.valid_samples)
self._testSpeechCommands(dataset, self.valid_samples)
def testSpeechCommandsSubsetTest(self):
dataset = speechcommands.SPEECHCOMMANDS(self.root_dir, subset="testing")
num_samples = 0
for i, (data, sample_rate, label, speaker_id, utterance_number) in enumerate(
dataset
):
self.assertEqual(data, self.test_samples[i][0], atol=5e-5, rtol=1e-8)
assert sample_rate == self.test_samples[i][1]
assert label == self.test_samples[i][2]
assert speaker_id == self.test_samples[i][3]
assert utterance_number == self.test_samples[i][4]
num_samples += 1
assert num_samples == len(self.test_samples)
self._testSpeechCommands(dataset, self.test_samples)
def testSpeechCommandsSum(self):
dataset_all = speechcommands.SPEECHCOMMANDS(self.root_dir)
......
import os
from typing import Tuple, Optional
from typing import Tuple, Optional, Union
from pathlib import Path
import torchaudio
from torch.utils.data import Dataset
......@@ -48,7 +49,7 @@ class SPEECHCOMMANDS(Dataset):
"""Create a Dataset for Speech Commands.
Args:
root (str): Path to the directory where the dataset is found or downloaded.
root (str or Path): Path to the directory where the dataset is found or downloaded.
url (str, optional): The URL to download the dataset from,
or the type of the dataset to download.
Allowed type values are ``"speech_commands_v0.01"`` and ``"speech_commands_v0.02"``
......@@ -64,7 +65,7 @@ class SPEECHCOMMANDS(Dataset):
"""
def __init__(self,
root: str,
root: Union[str, Path],
url: str = URL,
folder_in_archive: str = FOLDER_IN_ARCHIVE,
download: bool = False,
......@@ -85,6 +86,9 @@ class SPEECHCOMMANDS(Dataset):
url = os.path.join(base_url, url + ext_archive)
# Get string representation of 'root' in case Path object is passed
root = os.fspath(root)
basename = os.path.basename(url)
archive = os.path.join(root, basename)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment