Unverified Commit 5bf6b146 authored by Aziz's avatar Aziz Committed by GitHub
Browse files

Refactor speechcommands unittest (#1136)

parent 70fd2f3d
import os
from pathlib import Path
from torchaudio.datasets import speechcommands
from torchaudio_unittest.common_utils import (
TempDirMixin,
TorchaudioTestCase,
......@@ -11,7 +9,9 @@ from torchaudio_unittest.common_utils import (
save_wav,
)
LABELS = [
from torchaudio.datasets import speechcommands
_LABELS = [
"bed",
"bird",
"cat",
......@@ -49,6 +49,58 @@ LABELS = [
]
def get_mock_dataset(dataset_dir):
    """Write a mocked SpeechCommands-style dataset to disk and return its samples.

    For every label in ``_LABELS`` a sub-directory of ``dataset_dir`` is
    created and filled with wav files for 6 speakers x 3 utterances each.
    Each file holds data produced by ``get_whitenoise`` with a distinct,
    incrementing seed. Files are assigned to a split by speaker index:

    * speakers 0-1: training (listed in neither list file)
    * speakers 2-3: validation (listed in ``validation_list.txt``)
    * speakers 4-5: testing (listed in ``testing_list.txt``)

    Args:
        dataset_dir: directory to the mocked dataset; created if missing.

    Returns:
        A 4-tuple of lists ``(all, train, valid, test)``. Every element is a
        tuple ``(normalize_wav(data), sample_rate, label, speaker, utterance)``.
    """
    os.makedirs(dataset_dir, exist_ok=True)

    sample_rate = 16000  # 16kHz sample rate
    all_samples, train_split, valid_split, test_split = [], [], [], []

    valid_list_path = os.path.join(dataset_dir, "validation_list.txt")
    test_list_path = os.path.join(dataset_dir, "testing_list.txt")

    # Incremented before every generated file so each one gets its own seed.
    seed = 0
    with open(valid_list_path, "w") as valid_list, open(test_list_path, "w") as test_list:
        for label in _LABELS:
            label_dir = os.path.join(dataset_dir, label)
            os.makedirs(label_dir, exist_ok=True)

            for speaker_index in range(6):
                # The speaker's hash ID is the index as 8 hex digits.
                speaker = "{:08x}".format(speaker_index)

                for utterance in range(3):
                    filename = f"{speaker}{speechcommands.HASH_DIVIDER}{utterance}.wav"
                    seed += 1
                    waveform = get_whitenoise(
                        sample_rate=sample_rate,
                        duration=0.01,
                        n_channels=1,
                        dtype="int16",
                        seed=seed,
                    )
                    save_wav(os.path.join(label_dir, filename), waveform, sample_rate)

                    sample = (normalize_wav(waveform), sample_rate, label, speaker, utterance)
                    all_samples.append(sample)

                    if speaker_index < 2:
                        train_split.append(sample)
                    elif speaker_index < 4:
                        valid_list.write(f'{label}/{filename}\n')
                        valid_split.append(sample)
                    else:
                        # Remaining speakers (4 and 5) form the test split.
                        test_list.write(f'{label}/{filename}\n')
                        test_split.append(sample)

    return all_samples, train_split, valid_split, test_split
class TestSpeechCommands(TempDirMixin, TorchaudioTestCase):
backend = "default"
......@@ -64,52 +116,12 @@ class TestSpeechCommands(TempDirMixin, TorchaudioTestCase):
dataset_dir = os.path.join(
cls.root_dir, speechcommands.FOLDER_IN_ARCHIVE, speechcommands.URL
)
os.makedirs(dataset_dir, exist_ok=True)
sample_rate = 16000 # 16kHz sample rate
seed = 0
valid_file = os.path.join(dataset_dir, "validation_list.txt")
test_file = os.path.join(dataset_dir, "testing_list.txt")
with open(valid_file, "w") as valid, open(test_file, "w") as test:
for label in LABELS:
path = os.path.join(dataset_dir, label)
os.makedirs(path, exist_ok=True)
for j in range(6):
# generate hash ID for speaker
speaker = "{:08x}".format(j)
for utterance in range(3):
filename = f"{speaker}{speechcommands.HASH_DIVIDER}{utterance}.wav"
file_path = os.path.join(path, filename)
seed += 1
data = get_whitenoise(
sample_rate=sample_rate,
duration=0.01,
n_channels=1,
dtype="int16",
seed=seed,
)
save_wav(file_path, data, sample_rate)
sample = (
normalize_wav(data),
sample_rate,
label,
speaker,
utterance,
)
cls.samples.append(sample)
if j < 2:
cls.train_samples.append(sample)
elif j < 4:
valid.write(f'{label}/{filename}\n')
cls.valid_samples.append(sample)
elif j < 6:
test.write(f'{label}/{filename}\n')
cls.test_samples.append(sample)
cls.samples, cls.train_samples, cls.valid_samples, cls.test_samples = get_mock_dataset(dataset_dir)
def _testSpeechCommands(self, dataset, data_samples):
num_samples = 0
for i, (data, sample_rate, label, speaker_id, utterance_number) in enumerate(
dataset
dataset
):
self.assertEqual(data, data_samples[i][0], atol=5e-5, rtol=1e-8)
assert sample_rate == data_samples[i][1]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment