Unverified Commit 68f6a6a0 authored by moto's avatar moto Committed by GitHub
Browse files

Make GTZAN dataset sorted and use on-the-fly data in GTZAN test (#819)

parent 3cdcd7ba
......@@ -4,7 +4,6 @@ from torchaudio.datasets.speechcommands import SPEECHCOMMANDS
from torchaudio.datasets.utils import diskcache_iterator, bg_iterator
from torchaudio.datasets.vctk import VCTK
from torchaudio.datasets.ljspeech import LJSPEECH
from torchaudio.datasets.gtzan import GTZAN
from torchaudio.datasets.cmuarctic import CMUARCTIC
from ..common_utils import (
......@@ -33,10 +32,6 @@ class TestDatasets(TorchaudioTestCase):
data = SPEECHCOMMANDS(self.path)
data[0]
def test_gtzan(self):
    """Smoke test: build ``GTZAN`` from ``self.path`` and fetch its first item."""
    # Only checks that construction and indexing do not raise; the returned
    # sample is intentionally discarded.
    _ = GTZAN(self.path)[0]
def test_cmuarctic(self):
data = CMUARCTIC(self.path)
data[0]
......
import os
from torchaudio.datasets import gtzan
from ..common_utils import (
TempDirMixin,
TorchaudioTestCase,
get_whitenoise,
save_wav,
normalize_wav,
)
class TestGTZAN(TempDirMixin, TorchaudioTestCase):
    """Check ``gtzan.GTZAN`` against a synthetic dataset generated on the fly.

    ``setUpClass`` writes 100 short WAV files per genre under
    ``<root_dir>/genres/<genre>/`` and records, in creation order, the
    ``(waveform, sample_rate, label)`` tuple expected for every file. The
    same tuples are also collected per subset according to the
    ``gtzan.filtered_*`` name lists. Each test then iterates a ``GTZAN``
    instance and compares it item by item with the recorded expectations.
    """

    backend = 'default'

    # Placeholders; ``setUpClass`` rebinds all of these on the concrete class.
    root_dir = None
    samples = []
    training = []
    validation = []
    testing = []

    @classmethod
    def setUpClass(cls):
        """Generate the on-disk dataset and the matching expected samples."""
        cls.root_dir = cls.get_base_temp_dir()
        # Rebind fresh lists instead of appending to the class-level
        # defaults: those objects are shared by every subclass, so a second
        # ``setUpClass`` run would otherwise accumulate duplicate samples.
        cls.samples = []
        cls.training = []
        cls.validation = []
        cls.testing = []

        sample_rate = 22050
        seed = 0
        for genre in gtzan.gtzan_genres:
            base_dir = os.path.join(cls.root_dir, 'genres', genre)
            os.makedirs(base_dir, exist_ok=True)
            for i in range(100):
                filename = f'{genre}.{i:05d}'
                path = os.path.join(base_dir, f'{filename}.wav')
                # A distinct seed per file is passed so that each file is
                # generated from its own seed value.
                data = get_whitenoise(
                    sample_rate=sample_rate,
                    duration=0.01,
                    n_channels=1,
                    dtype='int16',
                    seed=seed,
                )
                save_wav(path, data, sample_rate)
                sample = (normalize_wav(data), sample_rate, genre)
                cls.samples.append(sample)
                # Independent checks (not ``elif``), mirroring how the
                # subsets are defined by three separate name lists.
                if filename in gtzan.filtered_test:
                    cls.testing.append(sample)
                if filename in gtzan.filtered_train:
                    cls.training.append(sample)
                if filename in gtzan.filtered_valid:
                    cls.validation.append(sample)
                seed += 1

    def _assert_dataset_matches(self, dataset, expected):
        """Assert that iterating ``dataset`` yields exactly ``expected``, in order.

        Args:
            dataset: Iterable yielding ``(waveform, sample_rate, label)``.
            expected: List of ``(waveform, sample_rate, label)`` tuples
                recorded by ``setUpClass``.
        """
        n_ite = 0
        for i, (waveform, sample_rate, label) in enumerate(dataset):
            # Fail with a readable message rather than an IndexError when
            # the dataset yields more items than were generated.
            assert i < len(expected), (
                f'dataset yielded more than the expected {len(expected)} items'
            )
            ref_waveform, ref_sample_rate, ref_label = expected[i]
            self.assertEqual(waveform, ref_waveform, atol=5e-5, rtol=1e-8)
            assert sample_rate == ref_sample_rate
            assert label == ref_label
            n_ite += 1
        # Guards against the dataset silently yielding too few items.
        assert n_ite == len(expected)

    def test_no_subset(self):
        """Without ``subset`` the dataset yields every generated sample."""
        dataset = gtzan.GTZAN(self.root_dir)
        self._assert_dataset_matches(dataset, self.samples)

    def test_training(self):
        """``subset='training'`` yields the files named in ``filtered_train``."""
        dataset = gtzan.GTZAN(self.root_dir, subset='training')
        self._assert_dataset_matches(dataset, self.training)

    def test_validation(self):
        """``subset='validation'`` yields the files named in ``filtered_valid``."""
        dataset = gtzan.GTZAN(self.root_dir, subset='validation')
        self._assert_dataset_matches(dataset, self.validation)

    def test_testing(self):
        """``subset='testing'`` yields the files named in ``filtered_test``."""
        dataset = gtzan.GTZAN(self.root_dir, subset='testing')
        self._assert_dataset_matches(dataset, self.testing)
......@@ -1064,6 +1064,7 @@ class GTZAN(Dataset):
continue
songs_in_genre = os.listdir(fulldir)
songs_in_genre.sort()
for fname in songs_in_genre:
name, ext = os.path.splitext(fname)
if ext.lower() == ".wav" and "." in name:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment