Unverified Commit 2067d034 authored by Krishna Kalyan, committed by GitHub
Browse files

Refactor GTZAN unittest (#1148)


Co-authored-by: krishnakalyan3 <skalyan@cloudera.com>
parent 6edb3355
...@@ -12,6 +12,37 @@ from torchaudio_unittest.common_utils import ( ...@@ -12,6 +12,37 @@ from torchaudio_unittest.common_utils import (
) )
def get_mock_dataset(root_dir):
    """Write a mocked GTZAN dataset to disk and return the expected samples.

    For each genre in ``gtzan.gtzan_genres`` a ``genres/<genre>`` directory
    is created under ``root_dir`` and 100 WAV files named
    ``<genre>.<index:05d>.wav`` are saved into it. The audio comes from
    ``get_whitenoise`` (0.01 s, one channel, ``int16``) and every file gets
    its own seed, counting up from 0 across all genres.

    Args:
        root_dir: directory to the mocked dataset

    Returns:
        A 4-tuple ``(samples, training, validation, testing)``. Each element
        is a list of ``(normalize_wav(data), sample_rate, genre)`` tuples.
        ``samples`` holds every generated file; the other three hold only
        the files whose name (without extension) appears in
        ``gtzan.filtered_train``, ``gtzan.filtered_valid`` and
        ``gtzan.filtered_test`` respectively.
    """
    sample_rate = 22050
    files_per_genre = 100

    all_samples = []
    train_samples = []
    valid_samples = []
    test_samples = []

    # Each subset list is paired with the collection of file names selecting it.
    subsets = (
        (gtzan.filtered_test, test_samples),
        (gtzan.filtered_train, train_samples),
        (gtzan.filtered_valid, valid_samples),
    )

    for genre_index, genre in enumerate(gtzan.gtzan_genres):
        genre_dir = os.path.join(root_dir, 'genres', genre)
        os.makedirs(genre_dir, exist_ok=True)

        for file_index in range(files_per_genre):
            # One distinct seed per file: 0, 1, 2, ... over the whole dataset.
            seed = genre_index * files_per_genre + file_index
            stem = f'{genre}.{file_index:05d}'
            wav_path = os.path.join(genre_dir, f'{stem}.wav')

            waveform = get_whitenoise(
                sample_rate=sample_rate,
                duration=0.01,
                n_channels=1,
                dtype='int16',
                seed=seed,
            )
            save_wav(wav_path, waveform, sample_rate)

            entry = (normalize_wav(waveform), sample_rate, genre)
            all_samples.append(entry)
            for selected_names, bucket in subsets:
                if stem in selected_names:
                    bucket.append(entry)

    return (all_samples, train_samples, valid_samples, test_samples)
class TestGTZAN(TempDirMixin, TorchaudioTestCase): class TestGTZAN(TempDirMixin, TorchaudioTestCase):
backend = 'default' backend = 'default'
...@@ -24,25 +55,11 @@ class TestGTZAN(TempDirMixin, TorchaudioTestCase): ...@@ -24,25 +55,11 @@ class TestGTZAN(TempDirMixin, TorchaudioTestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls.root_dir = cls.get_base_temp_dir() cls.root_dir = cls.get_base_temp_dir()
sample_rate = 22050 mocked_data = get_mock_dataset(cls.root_dir)
seed = 0 cls.samples = mocked_data[0]
for genre in gtzan.gtzan_genres: cls.training = mocked_data[1]
base_dir = os.path.join(cls.root_dir, 'genres', genre) cls.validation = mocked_data[2]
os.makedirs(base_dir, exist_ok=True) cls.testing = mocked_data[3]
for i in range(100):
filename = f'{genre}.{i:05d}'
path = os.path.join(base_dir, f'{filename}.wav')
data = get_whitenoise(sample_rate=sample_rate, duration=0.01, n_channels=1, dtype='int16', seed=seed)
save_wav(path, data, sample_rate)
sample = (normalize_wav(data), sample_rate, genre)
cls.samples.append(sample)
if filename in gtzan.filtered_test:
cls.testing.append(sample)
if filename in gtzan.filtered_train:
cls.training.append(sample)
if filename in gtzan.filtered_valid:
cls.validation.append(sample)
seed += 1
def test_no_subset(self): def test_no_subset(self):
dataset = gtzan.GTZAN(self.root_dir) dataset = gtzan.GTZAN(self.root_dir)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment