Unverified commit b34bc7d3, authored by Vincent QB and committed by GitHub
Browse files

Add SpeechCommands train/valid/test split (#966)

parent 51e77964
......@@ -53,9 +53,12 @@ class TestSpeechCommands(TempDirMixin, TorchaudioTestCase):
root_dir = None
samples = []
train_samples = []
valid_samples = []
test_samples = []
@classmethod
def setUp(cls):
def setUpClass(cls):
cls.root_dir = cls.get_base_temp_dir()
dataset_dir = os.path.join(
cls.root_dir, speechcommands.FOLDER_IN_ARCHIVE, speechcommands.URL
......@@ -63,10 +66,13 @@ class TestSpeechCommands(TempDirMixin, TorchaudioTestCase):
os.makedirs(dataset_dir, exist_ok=True)
sample_rate = 16000 # 16kHz sample rate
seed = 0
valid_file = os.path.join(dataset_dir, "validation_list.txt")
test_file = os.path.join(dataset_dir, "testing_list.txt")
with open(valid_file, "w") as valid, open(test_file, "w") as test:
for label in LABELS:
path = os.path.join(dataset_dir, label)
os.makedirs(path, exist_ok=True)
for j in range(2):
for j in range(6):
# generate hash ID for speaker
speaker = "{:08x}".format(j)
......@@ -90,10 +96,17 @@ class TestSpeechCommands(TempDirMixin, TorchaudioTestCase):
utterance,
)
cls.samples.append(sample)
if j < 2:
cls.train_samples.append(sample)
elif j < 4:
valid.write(f'{label}/{filename}\n')
cls.valid_samples.append(sample)
elif j < 6:
test.write(f'{label}/{filename}\n')
cls.test_samples.append(sample)
def testSpeechCommands(self):
dataset = speechcommands.SPEECHCOMMANDS(self.root_dir)
print(dataset._path)
num_samples = 0
for i, (data, sample_rate, label, speaker_id, utterance_number) in enumerate(
......@@ -107,3 +120,75 @@ class TestSpeechCommands(TempDirMixin, TorchaudioTestCase):
num_samples += 1
assert num_samples == len(self.samples)
def testSpeechCommandsNone(self):
    """With ``subset=None`` the dataset yields the items in ``self.samples``, in order."""
    dataset = speechcommands.SPEECHCOMMANDS(self.root_dir, subset=None)
    expected = self.samples

    seen = 0
    for index, item in enumerate(dataset):
        data, sample_rate, label, speaker_id, utterance_number = item
        reference = expected[index]
        # The data tensor is compared with a tolerance; metadata must match exactly.
        self.assertEqual(data, reference[0], atol=5e-5, rtol=1e-8)
        assert sample_rate == reference[1]
        assert label == reference[2]
        assert speaker_id == reference[3]
        assert utterance_number == reference[4]
        seen = index + 1

    # The dataset must not be shorter than the expected list either.
    assert seen == len(expected)
def testSpeechCommandsSubsetTrain(self):
    """The ``"training"`` subset yields the items in ``self.train_samples``, in order."""
    dataset = speechcommands.SPEECHCOMMANDS(self.root_dir, subset="training")
    expected = self.train_samples

    seen = 0
    for index, item in enumerate(dataset):
        data, sample_rate, label, speaker_id, utterance_number = item
        reference = expected[index]
        # The data tensor is compared with a tolerance; metadata must match exactly.
        self.assertEqual(data, reference[0], atol=5e-5, rtol=1e-8)
        assert sample_rate == reference[1]
        assert label == reference[2]
        assert speaker_id == reference[3]
        assert utterance_number == reference[4]
        seen = index + 1

    # The subset must not be shorter than the expected list either.
    assert seen == len(expected)
def testSpeechCommandsSubsetValid(self):
    """The ``"validation"`` subset yields the items in ``self.valid_samples``, in order."""
    dataset = speechcommands.SPEECHCOMMANDS(self.root_dir, subset="validation")
    expected = self.valid_samples

    seen = 0
    for index, item in enumerate(dataset):
        data, sample_rate, label, speaker_id, utterance_number = item
        reference = expected[index]
        # The data tensor is compared with a tolerance; metadata must match exactly.
        self.assertEqual(data, reference[0], atol=5e-5, rtol=1e-8)
        assert sample_rate == reference[1]
        assert label == reference[2]
        assert speaker_id == reference[3]
        assert utterance_number == reference[4]
        seen = index + 1

    # The subset must not be shorter than the expected list either.
    assert seen == len(expected)
def testSpeechCommandsSubsetTest(self):
    """The ``"testing"`` subset yields the items in ``self.test_samples``, in order."""
    dataset = speechcommands.SPEECHCOMMANDS(self.root_dir, subset="testing")
    expected = self.test_samples

    seen = 0
    for index, item in enumerate(dataset):
        data, sample_rate, label, speaker_id, utterance_number = item
        reference = expected[index]
        # The data tensor is compared with a tolerance; metadata must match exactly.
        self.assertEqual(data, reference[0], atol=5e-5, rtol=1e-8)
        assert sample_rate == reference[1]
        assert label == reference[2]
        assert speaker_id == reference[3]
        assert utterance_number == reference[4]
        seen = index + 1

    # The subset must not be shorter than the expected list either.
    assert seen == len(expected)
def testSpeechCommandsSum(self):
    """The three subsets together contain as many items as the unfiltered dataset."""
    full = speechcommands.SPEECHCOMMANDS(self.root_dir)
    parts = [
        speechcommands.SPEECHCOMMANDS(self.root_dir, subset=name)
        for name in ("training", "validation", "testing")
    ]
    combined = sum(len(part) for part in parts)
    assert combined == len(full)
import os
from typing import Tuple
from typing import Tuple, Optional
import torchaudio
from torch.utils.data import Dataset
......@@ -22,6 +22,15 @@ _CHECKSUMS = {
}
def _load_list(root, *filenames):
output = []
for filename in filenames:
filepath = os.path.join(root, filename)
with open(filepath) as fileobj:
output += [os.path.normpath(os.path.join(root, line.strip())) for line in fileobj]
return output
def load_speechcommands_item(filepath: str, path: str) -> Tuple[Tensor, int, str, str, int]:
relpath = os.path.relpath(filepath, path)
label, filename = os.path.split(relpath)
......@@ -48,13 +57,25 @@ class SPEECHCOMMANDS(Dataset):
The top-level directory of the dataset. (default: ``"SpeechCommands"``)
download (bool, optional):
Whether to download the dataset if it is not found at root path. (default: ``False``).
subset (Optional[str]):
Select a subset of the dataset [None, "training", "validation", "testing"]. None means
the whole dataset. "validation" and "testing" are defined in "validation_list.txt" and
"testing_list.txt", respectively, and "training" is the rest. (default: ``None``)
"""
def __init__(self,
root: str,
url: str = URL,
folder_in_archive: str = FOLDER_IN_ARCHIVE,
download: bool = False) -> None:
download: bool = False,
subset: Optional[str] = None,
) -> None:
assert subset is None or subset in ["training", "validation", "testing"], (
"When `subset` not None, it must take a value from "
+ "{'training', 'validation', 'testing'}."
)
if url in [
"speech_commands_v0.01",
"speech_commands_v0.02",
......@@ -79,9 +100,22 @@ class SPEECHCOMMANDS(Dataset):
download_url(url, root, hash_value=checksum, hash_type="md5")
extract_archive(archive, self._path)
if subset == "validation":
self._walker = _load_list(self._path, "validation_list.txt")
elif subset == "testing":
self._walker = _load_list(self._path, "testing_list.txt")
elif subset == "training":
excludes = set(_load_list(self._path, "validation_list.txt", "testing_list.txt"))
walker = walk_files(self._path, suffix=".wav", prefix=True)
self._walker = [
w for w in walker
if HASH_DIVIDER in w
and EXCEPT_FOLDER not in w
and os.path.normpath(w) not in excludes
]
else:
walker = walk_files(self._path, suffix=".wav", prefix=True)
walker = filter(lambda w: HASH_DIVIDER in w and EXCEPT_FOLDER not in w, walker)
self._walker = list(walker)
self._walker = [w for w in walker if HASH_DIVIDER in w and EXCEPT_FOLDER not in w]
def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, int]:
"""Load the n-th sample from the dataset.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment