Unverified Commit 3cdcd7ba authored by moto's avatar moto Committed by GitHub
Browse files

Refactor datasets test (#817)

parent 0406d30d
from torchaudio.datasets.commonvoice import COMMONVOICE
from torchaudio.datasets.librispeech import LIBRISPEECH
from torchaudio.datasets.speechcommands import SPEECHCOMMANDS
from torchaudio.datasets.utils import diskcache_iterator, bg_iterator
from torchaudio.datasets.vctk import VCTK
from torchaudio.datasets.ljspeech import LJSPEECH
from torchaudio.datasets.gtzan import GTZAN
from torchaudio.datasets.cmuarctic import CMUARCTIC
from ..common_utils import (
TorchaudioTestCase,
get_asset_path,
)
class TestDatasets(TorchaudioTestCase):
backend = 'default'
path = get_asset_path()
def test_vctk(self):
data = VCTK(self.path)
data[0]
def test_librispeech(self):
data = LIBRISPEECH(self.path, "dev-clean")
data[0]
def test_ljspeech(self):
data = LJSPEECH(self.path)
data[0]
def test_speechcommands(self):
data = SPEECHCOMMANDS(self.path)
data[0]
def test_gtzan(self):
data = GTZAN(self.path)
data[0]
def test_cmuarctic(self):
data = CMUARCTIC(self.path)
data[0]
class TestCommonVoice(TorchaudioTestCase):
backend = 'default'
path = get_asset_path()
def test_commonvoice(self):
data = COMMONVOICE(self.path, url="tatar")
data[0]
def test_commonvoice_diskcache(self):
data = COMMONVOICE(self.path, url="tatar")
data = diskcache_iterator(data)
# Save
data[0]
# Load
data[0]
def test_commonvoice_bg(self):
data = COMMONVOICE(self.path, url="tatar")
data = bg_iterator(data, 5)
for _ in data:
pass
import os import os
import unittest
from torchaudio.datasets.commonvoice import COMMONVOICE
from torchaudio.datasets.librispeech import LIBRISPEECH
from torchaudio.datasets.speechcommands import SPEECHCOMMANDS
from torchaudio.datasets.utils import diskcache_iterator, bg_iterator
from torchaudio.datasets.vctk import VCTK
from torchaudio.datasets.yesno import YESNO
from torchaudio.datasets.ljspeech import LJSPEECH
from torchaudio.datasets.gtzan import GTZAN
from torchaudio.datasets.cmuarctic import CMUARCTIC
from torchaudio.datasets.libritts import LIBRITTS from torchaudio.datasets.libritts import LIBRITTS
from .common_utils import ( from ..common_utils import (
TempDirMixin, TempDirMixin,
TorchaudioTestCase, TorchaudioTestCase,
get_asset_path,
get_whitenoise, get_whitenoise,
save_wav, save_wav,
normalize_wav, normalize_wav,
) )
class TestDatasets(TorchaudioTestCase):
backend = 'default'
path = get_asset_path()
def test_vctk(self):
data = VCTK(self.path)
data[0]
def test_librispeech(self):
data = LIBRISPEECH(self.path, "dev-clean")
data[0]
def test_ljspeech(self):
data = LJSPEECH(self.path)
data[0]
def test_speechcommands(self):
data = SPEECHCOMMANDS(self.path)
data[0]
def test_gtzan(self):
data = GTZAN(self.path)
data[0]
def test_cmuarctic(self):
data = CMUARCTIC(self.path)
data[0]
class TestCommonVoice(TorchaudioTestCase):
backend = 'default'
path = get_asset_path()
def test_commonvoice(self):
data = COMMONVOICE(self.path, url="tatar")
data[0]
def test_commonvoice_diskcache(self):
data = COMMONVOICE(self.path, url="tatar")
data = diskcache_iterator(data)
# Save
data[0]
# Load
data[0]
def test_commonvoice_bg(self):
data = COMMONVOICE(self.path, url="tatar")
data = bg_iterator(data, 5)
for _ in data:
pass
class TestYesNo(TempDirMixin, TorchaudioTestCase):
backend = 'default'
root_dir = None
data = []
labels = [
[0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 1, 1, 1, 1],
[0, 1, 0, 1, 0, 1, 1, 0],
[1, 1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1],
]
@classmethod
def setUpClass(cls):
cls.root_dir = cls.get_base_temp_dir()
base_dir = os.path.join(cls.root_dir, 'waves_yesno')
os.makedirs(base_dir, exist_ok=True)
for label in cls.labels:
filename = f'{"_".join(str(l) for l in label)}.wav'
path = os.path.join(base_dir, filename)
data = get_whitenoise(sample_rate=8000, duration=6, n_channels=1, dtype='int16')
save_wav(path, data, 8000)
cls.data.append(normalize_wav(data))
def test_yesno(self):
dataset = YESNO(self.root_dir)
samples = list(dataset)
samples.sort(key=lambda s: s[2])
for i, (waveform, sample_rate, label) in enumerate(samples):
expected_label = self.labels[i]
expected_data = self.data[i]
self.assertEqual(expected_data, waveform, atol=5e-5, rtol=1e-8)
assert sample_rate == 8000
assert label == expected_label
class TestLibriTTS(TempDirMixin, TorchaudioTestCase): class TestLibriTTS(TempDirMixin, TorchaudioTestCase):
backend = 'default' backend = 'default'
...@@ -139,15 +39,13 @@ class TestLibriTTS(TempDirMixin, TorchaudioTestCase): ...@@ -139,15 +39,13 @@ class TestLibriTTS(TempDirMixin, TorchaudioTestCase):
original_text_filename = f'{"_".join(str(u) for u in utterance_id)}.original.txt' original_text_filename = f'{"_".join(str(u) for u in utterance_id)}.original.txt'
path_original = os.path.join(file_dir, original_text_filename) path_original = os.path.join(file_dir, original_text_filename)
f = open(path_original, 'w') with open(path_original, 'w') as file_:
f.write(cls.original_text) file_.write(cls.original_text)
f.close()
normalized_text_filename = f'{"_".join(str(u) for u in utterance_id)}.normalized.txt' normalized_text_filename = f'{"_".join(str(u) for u in utterance_id)}.normalized.txt'
path_normalized = os.path.join(file_dir, normalized_text_filename) path_normalized = os.path.join(file_dir, normalized_text_filename)
f = open(path_normalized, 'w') with open(path_normalized, 'w') as file_:
f.write(cls.normalized_text) file_.write(cls.normalized_text)
f.close()
def test_libritts(self): def test_libritts(self):
dataset = LIBRITTS(self.root_dir) dataset = LIBRITTS(self.root_dir)
......
import os
from torchaudio.datasets import yesno
from ..common_utils import (
TempDirMixin,
TorchaudioTestCase,
get_whitenoise,
save_wav,
normalize_wav,
)
class TestYesNo(TempDirMixin, TorchaudioTestCase):
backend = 'default'
root_dir = None
data = []
labels = [
[0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 1, 1, 1, 1],
[0, 1, 0, 1, 0, 1, 1, 0],
[1, 1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1],
]
@classmethod
def setUpClass(cls):
cls.root_dir = cls.get_base_temp_dir()
base_dir = os.path.join(cls.root_dir, 'waves_yesno')
os.makedirs(base_dir, exist_ok=True)
for label in cls.labels:
filename = f'{"_".join(str(l) for l in label)}.wav'
path = os.path.join(base_dir, filename)
data = get_whitenoise(sample_rate=8000, duration=6, n_channels=1, dtype='int16')
save_wav(path, data, 8000)
cls.data.append(normalize_wav(data))
def test_yesno(self):
dataset = yesno.YESNO(self.root_dir)
samples = list(dataset)
samples.sort(key=lambda s: s[2])
for i, (waveform, sample_rate, label) in enumerate(samples):
expected_label = self.labels[i]
expected_data = self.data[i]
self.assertEqual(expected_data, waveform, atol=5e-5, rtol=1e-8)
assert sample_rate == 8000
assert label == expected_label
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment