Unverified Commit 5bf6b146 authored by Aziz's avatar Aziz Committed by GitHub
Browse files

Refactor speechcommands unittest (#1136)

parent 70fd2f3d
import os import os
from pathlib import Path from pathlib import Path
from torchaudio.datasets import speechcommands
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
TempDirMixin, TempDirMixin,
TorchaudioTestCase, TorchaudioTestCase,
...@@ -11,7 +9,9 @@ from torchaudio_unittest.common_utils import ( ...@@ -11,7 +9,9 @@ from torchaudio_unittest.common_utils import (
save_wav, save_wav,
) )
LABELS = [ from torchaudio.datasets import speechcommands
_LABELS = [
"bed", "bed",
"bird", "bird",
"cat", "cat",
...@@ -49,28 +49,21 @@ LABELS = [ ...@@ -49,28 +49,21 @@ LABELS = [
] ]
class TestSpeechCommands(TempDirMixin, TorchaudioTestCase): def get_mock_dataset(dataset_dir):
backend = "default" """
dataset_dir: directory to the mocked dataset
root_dir = None """
samples = [] mocked_samples = []
train_samples = [] mocked_train_samples = []
valid_samples = [] mocked_valid_samples = []
test_samples = [] mocked_test_samples = []
@classmethod
def setUpClass(cls):
cls.root_dir = cls.get_base_temp_dir()
dataset_dir = os.path.join(
cls.root_dir, speechcommands.FOLDER_IN_ARCHIVE, speechcommands.URL
)
os.makedirs(dataset_dir, exist_ok=True) os.makedirs(dataset_dir, exist_ok=True)
sample_rate = 16000 # 16kHz sample rate sample_rate = 16000 # 16kHz sample rate
seed = 0 seed = 0
valid_file = os.path.join(dataset_dir, "validation_list.txt") valid_file = os.path.join(dataset_dir, "validation_list.txt")
test_file = os.path.join(dataset_dir, "testing_list.txt") test_file = os.path.join(dataset_dir, "testing_list.txt")
with open(valid_file, "w") as valid, open(test_file, "w") as test: with open(valid_file, "w") as valid, open(test_file, "w") as test:
for label in LABELS: for label in _LABELS:
path = os.path.join(dataset_dir, label) path = os.path.join(dataset_dir, label)
os.makedirs(path, exist_ok=True) os.makedirs(path, exist_ok=True)
for j in range(6): for j in range(6):
...@@ -96,15 +89,34 @@ class TestSpeechCommands(TempDirMixin, TorchaudioTestCase): ...@@ -96,15 +89,34 @@ class TestSpeechCommands(TempDirMixin, TorchaudioTestCase):
speaker, speaker,
utterance, utterance,
) )
cls.samples.append(sample) mocked_samples.append(sample)
if j < 2: if j < 2:
cls.train_samples.append(sample) mocked_train_samples.append(sample)
elif j < 4: elif j < 4:
valid.write(f'{label}/{filename}\n') valid.write(f'{label}/{filename}\n')
cls.valid_samples.append(sample) mocked_valid_samples.append(sample)
elif j < 6: elif j < 6:
test.write(f'{label}/{filename}\n') test.write(f'{label}/{filename}\n')
cls.test_samples.append(sample) mocked_test_samples.append(sample)
return mocked_samples, mocked_train_samples, mocked_valid_samples, mocked_test_samples
class TestSpeechCommands(TempDirMixin, TorchaudioTestCase):
backend = "default"
root_dir = None
samples = []
train_samples = []
valid_samples = []
test_samples = []
@classmethod
def setUpClass(cls):
cls.root_dir = cls.get_base_temp_dir()
dataset_dir = os.path.join(
cls.root_dir, speechcommands.FOLDER_IN_ARCHIVE, speechcommands.URL
)
cls.samples, cls.train_samples, cls.valid_samples, cls.test_samples = get_mock_dataset(dataset_dir)
def _testSpeechCommands(self, dataset, data_samples): def _testSpeechCommands(self, dataset, data_samples):
num_samples = 0 num_samples = 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment