Unverified Commit 5bf6b146 authored by Aziz's avatar Aziz Committed by GitHub
Browse files

Refactor speechcommands unittest (#1136)

parent 70fd2f3d
import os
from pathlib import Path
from torchaudio.datasets import speechcommands
from torchaudio_unittest.common_utils import (
TempDirMixin,
TorchaudioTestCase,
......@@ -11,7 +9,9 @@ from torchaudio_unittest.common_utils import (
save_wav,
)
LABELS = [
from torchaudio.datasets import speechcommands
_LABELS = [
"bed",
"bird",
"cat",
......@@ -49,28 +49,21 @@ LABELS = [
]
class TestSpeechCommands(TempDirMixin, TorchaudioTestCase):
backend = "default"
root_dir = None
samples = []
train_samples = []
valid_samples = []
test_samples = []
@classmethod
def setUpClass(cls):
cls.root_dir = cls.get_base_temp_dir()
dataset_dir = os.path.join(
cls.root_dir, speechcommands.FOLDER_IN_ARCHIVE, speechcommands.URL
)
def get_mock_dataset(dataset_dir):
"""
dataset_dir: directory to the mocked dataset
"""
mocked_samples = []
mocked_train_samples = []
mocked_valid_samples = []
mocked_test_samples = []
os.makedirs(dataset_dir, exist_ok=True)
sample_rate = 16000 # 16kHz sample rate
seed = 0
valid_file = os.path.join(dataset_dir, "validation_list.txt")
test_file = os.path.join(dataset_dir, "testing_list.txt")
with open(valid_file, "w") as valid, open(test_file, "w") as test:
for label in LABELS:
for label in _LABELS:
path = os.path.join(dataset_dir, label)
os.makedirs(path, exist_ok=True)
for j in range(6):
......@@ -96,15 +89,34 @@ class TestSpeechCommands(TempDirMixin, TorchaudioTestCase):
speaker,
utterance,
)
cls.samples.append(sample)
mocked_samples.append(sample)
if j < 2:
cls.train_samples.append(sample)
mocked_train_samples.append(sample)
elif j < 4:
valid.write(f'{label}/{filename}\n')
cls.valid_samples.append(sample)
mocked_valid_samples.append(sample)
elif j < 6:
test.write(f'{label}/{filename}\n')
cls.test_samples.append(sample)
mocked_test_samples.append(sample)
return mocked_samples, mocked_train_samples, mocked_valid_samples, mocked_test_samples
class TestSpeechCommands(TempDirMixin, TorchaudioTestCase):
backend = "default"
root_dir = None
samples = []
train_samples = []
valid_samples = []
test_samples = []
@classmethod
def setUpClass(cls):
cls.root_dir = cls.get_base_temp_dir()
dataset_dir = os.path.join(
cls.root_dir, speechcommands.FOLDER_IN_ARCHIVE, speechcommands.URL
)
cls.samples, cls.train_samples, cls.valid_samples, cls.test_samples = get_mock_dataset(dataset_dir)
def _testSpeechCommands(self, dataset, data_samples):
num_samples = 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment