Unverified Commit 102174e9 authored by moto's avatar moto Committed by GitHub
Browse files

Generate YESNO dataset on-the-fly for test (#792)

parent 02b898ff
......@@ -29,8 +29,8 @@ BACKENDS_MP3 = _filter_backends_with_mp3(BACKENDS)
def set_audio_backend(backend):
"""Allow additional backend value, 'default'"""
if backend == 'default':
if 'sox' in BACKENDS:
be = 'sox'
if 'sox_io' in BACKENDS:
be = 'sox_io'
elif 'soundfile' in BACKENDS:
be = 'soundfile'
else:
......
......@@ -15,16 +15,16 @@ class TempDirMixin:
"""Mixin to provide easy access to temp dir"""
temp_dir_ = None
@property
def base_temp_dir(self):
@classmethod
def get_base_temp_dir(cls):
# If TORCHAUDIO_TEST_TEMP_DIR is set, use it instead of temporary directory.
# this is handy for debugging.
key = 'TORCHAUDIO_TEST_TEMP_DIR'
if key in os.environ:
return os.environ[key]
if self.__class__.temp_dir_ is None:
self.__class__.temp_dir_ = tempfile.TemporaryDirectory()
return self.__class__.temp_dir_.name
if cls.temp_dir_ is None:
cls.temp_dir_ = tempfile.TemporaryDirectory()
return cls.temp_dir_.name
@classmethod
def tearDownClass(cls):
......@@ -34,7 +34,7 @@ class TempDirMixin:
cls.temp_dir_ = None
def get_temp_path(self, *paths):
temp_dir = os.path.join(self.base_temp_dir, self.id())
temp_dir = os.path.join(self.get_base_temp_dir(), self.id())
path = os.path.join(temp_dir, *paths)
os.makedirs(os.path.dirname(path), exist_ok=True)
return path
......
import os
import unittest
from torchaudio.datasets.commonvoice import COMMONVOICE
......@@ -10,16 +11,19 @@ from torchaudio.datasets.ljspeech import LJSPEECH
from torchaudio.datasets.gtzan import GTZAN
from torchaudio.datasets.cmuarctic import CMUARCTIC
from . import common_utils
from .common_utils import (
TempDirMixin,
TorchaudioTestCase,
get_asset_path,
get_whitenoise,
save_wav,
normalize_wav,
)
class TestDatasets(common_utils.TorchaudioTestCase):
class TestDatasets(TorchaudioTestCase):
backend = 'default'
path = common_utils.get_asset_path()
def test_yesno(self):
data = YESNO(self.path)
data[0]
path = get_asset_path()
def test_vctk(self):
data = VCTK(self.path)
......@@ -46,9 +50,9 @@ class TestDatasets(common_utils.TorchaudioTestCase):
data[0]
class TestCommonVoice(common_utils.TorchaudioTestCase):
class TestCommonVoice(TorchaudioTestCase):
backend = 'default'
path = common_utils.get_asset_path()
path = get_asset_path()
def test_commonvoice(self):
data = COMMONVOICE(self.path, url="tatar")
......@@ -69,5 +73,42 @@ class TestCommonVoice(common_utils.TorchaudioTestCase):
pass
class TestYesNo(TempDirMixin, TorchaudioTestCase):
backend = 'default'
root_dir = None
data = []
labels = [
[0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 1, 1, 1, 1],
[0, 1, 0, 1, 0, 1, 1, 0],
[1, 1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1],
]
@classmethod
def setUpClass(cls):
cls.root_dir = cls.get_base_temp_dir()
base_dir = os.path.join(cls.root_dir, 'waves_yesno')
os.makedirs(base_dir, exist_ok=True)
for label in cls.labels:
filename = f'{"_".join(str(l) for l in label)}.wav'
path = os.path.join(base_dir, filename)
data = get_whitenoise(sample_rate=8000, duration=6, n_channels=1, dtype='int16')
save_wav(path, data, 8000)
cls.data.append(normalize_wav(data))
def test_yesno(self):
dataset = YESNO(self.root_dir)
samples = list(dataset)
samples.sort(key=lambda s: s[2])
for i, (waveform, sample_rate, label) in enumerate(samples):
expected_label = self.labels[i]
expected_data = self.data[i]
self.assertEqual(expected_data, waveform, atol=5e-5, rtol=1e-8)
assert sample_rate == 8000
assert label == expected_label
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment