Unverified Commit 0e5581cb authored by moto's avatar moto Committed by GitHub
Browse files

Simplify and abstract away asset access in test (#542)

This PR aims to do the following:
1. Introduce and adopt a helper function, `get_asset_path`, that abstracts the logic for constructing paths to test assets.
2. Remove the use of `create_temp_assets_dir` everywhere except `test_io`.

The benefits of doing so are:
a. The test code becomes simpler (no manual construction of asset paths with `os.path.join`).
b. No unnecessary directory creation or file copying.

Regarding 2. and b.: the tests in `test_io.py` (or, more generally, tests that use `torchaudio.save`) are the only tests that need to write files to disk, and using a temporary directory keeps those tests clean. `create_temp_assets_dir` is therefore not necessary elsewhere. (Even `test_io` does not need to copy the entire asset directory, but that is beside the point here.)

Also, if any test accidentally overwrites asset data, not working on a copy will make us aware of that behavior, so it is better to get rid of `create_temp_assets_dir`.
parent 0fa07595
...@@ -6,10 +6,15 @@ from shutil import copytree ...@@ -6,10 +6,15 @@ from shutil import copytree
import torch import torch
import torchaudio import torchaudio
TEST_DIR_PATH = os.path.dirname(os.path.realpath(__file__)) _TEST_DIR_PATH = os.path.dirname(os.path.realpath(__file__))
BACKENDS = torchaudio._backend._audio_backends BACKENDS = torchaudio._backend._audio_backends
def get_asset_path(*paths):
    """Return the full path of a test asset.

    Joins ``_TEST_DIR_PATH``, the ``assets`` directory name and the given
    components with ``os.path.join``.

    Args:
        *paths: Path components below the ``assets`` directory. When no
            component is given, the path of the ``assets`` directory
            itself is returned.

    Returns:
        The joined path. The path is only constructed here; it is not
        checked for existence.
    """
    return os.path.join(_TEST_DIR_PATH, 'assets', *paths)
def create_temp_assets_dir(): def create_temp_assets_dir():
""" """
Creates a temporary directory and moves all files from test/assets there. Creates a temporary directory and moves all files from test/assets there.
...@@ -17,7 +22,7 @@ def create_temp_assets_dir(): ...@@ -17,7 +22,7 @@ def create_temp_assets_dir():
and object. and object.
""" """
tmp_dir = tempfile.TemporaryDirectory() tmp_dir = tempfile.TemporaryDirectory()
copytree(os.path.join(TEST_DIR_PATH, "assets"), copytree(os.path.join(_TEST_DIR_PATH, "assets"),
os.path.join(tmp_dir.name, "assets")) os.path.join(tmp_dir.name, "assets"))
return tmp_dir.name, tmp_dir return tmp_dir.name, tmp_dir
...@@ -65,11 +70,7 @@ def AudioBackendScope(new_backend): ...@@ -65,11 +70,7 @@ def AudioBackendScope(new_backend):
def filter_backends_with_mp3(backends): def filter_backends_with_mp3(backends):
# Filter out backends that do not support mp3 # Filter out backends that do not support mp3
test_filepath = get_asset_path('steam-train-whistle-daniel_simon.mp3')
test_dirpath, _ = create_temp_assets_dir()
test_filepath = os.path.join(
test_dirpath, "assets", "steam-train-whistle-daniel_simon.mp3"
)
def supports_mp3(backend): def supports_mp3(backend):
try: try:
......
"""Test numerical consistency among single input and batched input.""" """Test numerical consistency among single input and batched input."""
import os
import unittest import unittest
import torch import torch
...@@ -54,7 +53,7 @@ class TestFunctional(unittest.TestCase): ...@@ -54,7 +53,7 @@ class TestFunctional(unittest.TestCase):
'440Hz_44100Hz_16bit_05sec.wav', # 1ch '440Hz_44100Hz_16bit_05sec.wav', # 1ch
] ]
for filename in filenames: for filename in filenames:
filepath = os.path.join(common_utils.TEST_DIR_PATH, 'assets', filename) filepath = common_utils.get_asset_path(filename)
waveform, sample_rate = torchaudio.load(filepath) waveform, sample_rate = torchaudio.load(filepath)
_test_batch(F.detect_pitch_frequency, waveform, sample_rate) _test_batch(F.detect_pitch_frequency, waveform, sample_rate)
...@@ -133,8 +132,7 @@ class TestTransforms(unittest.TestCase): ...@@ -133,8 +132,7 @@ class TestTransforms(unittest.TestCase):
torch.testing.assert_allclose(computed, expected) torch.testing.assert_allclose(computed, expected)
def test_batch_mulaw(self): def test_batch_mulaw(self):
test_filepath = os.path.join( test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
common_utils.TEST_DIR_PATH, 'assets', 'steam-train-whistle-daniel_simon.wav')
waveform, _ = torchaudio.load(test_filepath) # (2, 278756), 44100 waveform, _ = torchaudio.load(test_filepath) # (2, 278756), 44100
# Single then transform then batch # Single then transform then batch
...@@ -159,8 +157,7 @@ class TestTransforms(unittest.TestCase): ...@@ -159,8 +157,7 @@ class TestTransforms(unittest.TestCase):
torch.testing.assert_allclose(computed, expected) torch.testing.assert_allclose(computed, expected)
def test_batch_spectrogram(self): def test_batch_spectrogram(self):
test_filepath = os.path.join( test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
common_utils.TEST_DIR_PATH, 'assets', 'steam-train-whistle-daniel_simon.wav')
waveform, _ = torchaudio.load(test_filepath) # (2, 278756), 44100 waveform, _ = torchaudio.load(test_filepath) # (2, 278756), 44100
# Single then transform then batch # Single then transform then batch
...@@ -171,8 +168,7 @@ class TestTransforms(unittest.TestCase): ...@@ -171,8 +168,7 @@ class TestTransforms(unittest.TestCase):
torch.testing.assert_allclose(computed, expected) torch.testing.assert_allclose(computed, expected)
def test_batch_melspectrogram(self): def test_batch_melspectrogram(self):
test_filepath = os.path.join( test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
common_utils.TEST_DIR_PATH, 'assets', 'steam-train-whistle-daniel_simon.wav')
waveform, _ = torchaudio.load(test_filepath) # (2, 278756), 44100 waveform, _ = torchaudio.load(test_filepath) # (2, 278756), 44100
# Single then transform then batch # Single then transform then batch
...@@ -185,8 +181,7 @@ class TestTransforms(unittest.TestCase): ...@@ -185,8 +181,7 @@ class TestTransforms(unittest.TestCase):
@unittest.skipIf("sox" not in BACKENDS, "sox not available") @unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox") @AudioBackendScope("sox")
def test_batch_mfcc(self): def test_batch_mfcc(self):
test_filepath = os.path.join( test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.mp3')
common_utils.TEST_DIR_PATH, 'assets', 'steam-train-whistle-daniel_simon.mp3')
waveform, _ = torchaudio.load(test_filepath) waveform, _ = torchaudio.load(test_filepath)
# Single then transform then batch # Single then transform then batch
...@@ -197,8 +192,7 @@ class TestTransforms(unittest.TestCase): ...@@ -197,8 +192,7 @@ class TestTransforms(unittest.TestCase):
torch.testing.assert_allclose(computed, expected, atol=1e-5, rtol=1e-5) torch.testing.assert_allclose(computed, expected, atol=1e-5, rtol=1e-5)
def test_batch_TimeStretch(self): def test_batch_TimeStretch(self):
test_filepath = os.path.join( test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
common_utils.TEST_DIR_PATH, 'assets', 'steam-train-whistle-daniel_simon.wav')
waveform, _ = torchaudio.load(test_filepath) # (2, 278756), 44100 waveform, _ = torchaudio.load(test_filepath) # (2, 278756), 44100
kwargs = { kwargs = {
...@@ -232,8 +226,7 @@ class TestTransforms(unittest.TestCase): ...@@ -232,8 +226,7 @@ class TestTransforms(unittest.TestCase):
torch.testing.assert_allclose(computed, expected, atol=1e-5, rtol=1e-5) torch.testing.assert_allclose(computed, expected, atol=1e-5, rtol=1e-5)
def test_batch_Fade(self): def test_batch_Fade(self):
test_filepath = os.path.join( test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
common_utils.TEST_DIR_PATH, 'assets', 'steam-train-whistle-daniel_simon.wav')
waveform, _ = torchaudio.load(test_filepath) # (2, 278756), 44100 waveform, _ = torchaudio.load(test_filepath) # (2, 278756), 44100
fade_in_len = 3000 fade_in_len = 3000
fade_out_len = 3000 fade_out_len = 3000
...@@ -246,8 +239,7 @@ class TestTransforms(unittest.TestCase): ...@@ -246,8 +239,7 @@ class TestTransforms(unittest.TestCase):
torch.testing.assert_allclose(computed, expected) torch.testing.assert_allclose(computed, expected)
def test_batch_Vol(self): def test_batch_Vol(self):
test_filepath = os.path.join( test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
common_utils.TEST_DIR_PATH, 'assets', 'steam-train-whistle-daniel_simon.wav')
waveform, _ = torchaudio.load(test_filepath) # (2, 278756), 44100 waveform, _ = torchaudio.load(test_filepath) # (2, 278756), 44100
# Single then transform then batch # Single then transform then batch
......
...@@ -5,7 +5,9 @@ import torch ...@@ -5,7 +5,9 @@ import torch
import torchaudio import torchaudio
import torchaudio.compliance.kaldi as kaldi import torchaudio.compliance.kaldi as kaldi
import unittest import unittest
from common_utils import AudioBackendScope, BACKENDS, create_temp_assets_dir
import common_utils
from common_utils import AudioBackendScope, BACKENDS
def extract_window(window, wave, f, frame_length, frame_shift, snip_edges): def extract_window(window, wave, f, frame_length, frame_shift, snip_edges):
...@@ -45,10 +47,9 @@ def extract_window(window, wave, f, frame_length, frame_shift, snip_edges): ...@@ -45,10 +47,9 @@ def extract_window(window, wave, f, frame_length, frame_shift, snip_edges):
class Test_Kaldi(unittest.TestCase): class Test_Kaldi(unittest.TestCase):
test_dirpath, test_dir = create_temp_assets_dir() test_filepath = common_utils.get_asset_path('kaldi_file.wav')
test_filepath = os.path.join(test_dirpath, 'assets', 'kaldi_file.wav') test_8000_filepath = common_utils.get_asset_path('kaldi_file_8000.wav')
test_8000_filepath = os.path.join(test_dirpath, 'assets', 'kaldi_file_8000.wav') kaldi_output_dir = common_utils.get_asset_path('kaldi')
kaldi_output_dir = os.path.join(test_dirpath, 'assets', 'kaldi')
test_filepaths = {prefix: [] for prefix in compliance.utils.TEST_PREFIX} test_filepaths = {prefix: [] for prefix in compliance.utils.TEST_PREFIX}
# separating test files by their types (e.g 'spec', 'fbank', etc.) # separating test files by their types (e.g 'spec', 'fbank', etc.)
...@@ -90,8 +91,7 @@ class Test_Kaldi(unittest.TestCase): ...@@ -90,8 +91,7 @@ class Test_Kaldi(unittest.TestCase):
def _create_data_set(self): def _create_data_set(self):
# used to generate the dataset to test on. this is not used in testing (offline procedure) # used to generate the dataset to test on. this is not used in testing (offline procedure)
test_dirpath = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) test_filepath = common_utils.get_asset_path('kaldi_file.wav')
test_filepath = os.path.join(test_dirpath, 'assets', 'kaldi_file.wav')
sr = 16000 sr = 16000
x = torch.arange(0, 20).float() x = torch.arange(0, 20).float()
# between [-6,6] # between [-6,6]
......
import unittest import unittest
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchaudio import torchaudio
import math from torch.utils.data import Dataset, DataLoader
import os
from common_utils import AudioBackendScope, BACKENDS, create_temp_assets_dir import common_utils
from common_utils import AudioBackendScope, BACKENDS
@unittest.skipIf("sox" not in BACKENDS, "sox not available") @unittest.skipIf("sox" not in BACKENDS, "sox not available")
class TORCHAUDIODS(Dataset): class TORCHAUDIODS(Dataset):
test_dirpath, test_dir = create_temp_assets_dir()
def __init__(self): def __init__(self):
self.asset_dirpath = os.path.join(self.test_dirpath, "assets")
sound_files = ["sinewave.wav", "steam-train-whistle-daniel_simon.mp3"] sound_files = ["sinewave.wav", "steam-train-whistle-daniel_simon.mp3"]
self.data = [os.path.join(self.asset_dirpath, fn) for fn in sound_files] self.data = [common_utils.get_asset_path(fn) for fn in sound_files]
self.si, self.ei = torchaudio.info(os.path.join(self.asset_dirpath, "sinewave.wav")) self.si, self.ei = torchaudio.info(common_utils.get_asset_path("sinewave.wav"))
self.si.precision = 16 self.si.precision = 16
self.E = torchaudio.sox_effects.SoxEffectsChain() self.E = torchaudio.sox_effects.SoxEffectsChain()
self.E.append_effect_to_chain("rate", [self.si.rate]) # resample to 16000hz self.E.append_effect_to_chain("rate", [self.si.rate]) # resample to 16000hz
......
import os
import unittest import unittest
from torchaudio.datasets.commonvoice import COMMONVOICE from torchaudio.datasets.commonvoice import COMMONVOICE
...@@ -13,8 +12,7 @@ import common_utils ...@@ -13,8 +12,7 @@ import common_utils
class TestDatasets(unittest.TestCase): class TestDatasets(unittest.TestCase):
test_dirpath, test_dir = common_utils.create_temp_assets_dir() path = common_utils.get_asset_path()
path = os.path.join(test_dirpath, "assets")
def test_yesno(self): def test_yesno(self):
data = YESNO(self.path) data = YESNO(self.path)
......
import math import math
import os
import unittest import unittest
import torch import torch
...@@ -301,10 +300,8 @@ class TestIstft(unittest.TestCase): ...@@ -301,10 +300,8 @@ class TestIstft(unittest.TestCase):
class TestDetectPitchFrequency(unittest.TestCase): class TestDetectPitchFrequency(unittest.TestCase):
def test_pitch(self): def test_pitch(self):
test_filepath_100 = os.path.join( test_filepath_100 = common_utils.get_asset_path("100Hz_44100Hz_16bit_05sec.wav")
common_utils.TEST_DIR_PATH, 'assets', "100Hz_44100Hz_16bit_05sec.wav") test_filepath_440 = common_utils.get_asset_path("440Hz_44100Hz_16bit_05sec.wav")
test_filepath_440 = os.path.join(
common_utils.TEST_DIR_PATH, 'assets', "440Hz_44100Hz_16bit_05sec.wav")
# Files from https://www.mediacollege.com/audio/tone/download/ # Files from https://www.mediacollege.com/audio/tone/download/
tests = [ tests = [
......
import os import unittest
import torch import torch
import torchaudio.kaldi_io as kio import torchaudio.kaldi_io as kio
import unittest
import common_utils import common_utils
class Test_KaldiIO(unittest.TestCase): class Test_KaldiIO(unittest.TestCase):
data1 = [[1, 2, 3], [11, 12, 13], [21, 22, 23]] data1 = [[1, 2, 3], [11, 12, 13], [21, 22, 23]]
data2 = [[31, 32, 33], [41, 42, 43], [51, 52, 53]] data2 = [[31, 32, 33], [41, 42, 43], [51, 52, 53]]
test_dirpath, test_dir = common_utils.create_temp_assets_dir()
def _test_helper(self, file_name, expected_data, fn, expected_dtype): def _test_helper(self, file_name, expected_data, fn, expected_dtype):
""" Takes a file_name to the input data and a function fn to extract the """ Takes a file_name to the input data and a function fn to extract the
data. It compares the extracted data to the expected_data. The expected_dtype data. It compares the extracted data to the expected_data. The expected_dtype
will be used to check that the extracted data is of the right type. will be used to check that the extracted data is of the right type.
""" """
test_filepath = os.path.join(self.test_dirpath, "assets", file_name) test_filepath = common_utils.get_asset_path(file_name)
expected_output = {'key' + str(idx + 1): torch.tensor(val, dtype=expected_dtype) expected_output = {'key' + str(idx + 1): torch.tensor(val, dtype=expected_dtype)
for idx, val in enumerate(expected_data)} for idx, val in enumerate(expected_data)}
......
...@@ -149,7 +149,7 @@ def test_phase_vocoder(complex_specgrams, rate, hop_length): ...@@ -149,7 +149,7 @@ def test_phase_vocoder(complex_specgrams, rate, hop_length):
def _load_audio_asset(*asset_paths, **kwargs): def _load_audio_asset(*asset_paths, **kwargs):
file_path = os.path.join(common_utils.TEST_DIR_PATH, 'assets', *asset_paths) file_path = common_utils.get_asset_path(*asset_paths)
sound, sample_rate = torchaudio.load(file_path, **kwargs) sound, sample_rate = torchaudio.load(file_path, **kwargs)
return sound, sample_rate return sound, sample_rate
......
import os
import unittest import unittest
import torch import torch
...@@ -6,16 +5,15 @@ import torchaudio ...@@ -6,16 +5,15 @@ import torchaudio
import torchaudio.functional as F import torchaudio.functional as F
import torchaudio.transforms as T import torchaudio.transforms as T
from common_utils import AudioBackendScope, BACKENDS, create_temp_assets_dir import common_utils
from common_utils import AudioBackendScope, BACKENDS
class TestFunctionalFiltering(unittest.TestCase): class TestFunctionalFiltering(unittest.TestCase):
test_dirpath, test_dir = create_temp_assets_dir()
@unittest.skipIf("sox" not in BACKENDS, "sox not available") @unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox") @AudioBackendScope("sox")
def test_gain(self): def test_gain(self):
test_filepath = os.path.join(self.test_dirpath, "assets", "steam-train-whistle-daniel_simon.wav") test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
waveform, _ = torchaudio.load(test_filepath) waveform, _ = torchaudio.load(test_filepath)
waveform_gain = F.gain(waveform, 3) waveform_gain = F.gain(waveform, 3)
...@@ -31,7 +29,7 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -31,7 +29,7 @@ class TestFunctionalFiltering(unittest.TestCase):
@unittest.skipIf("sox" not in BACKENDS, "sox not available") @unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox") @AudioBackendScope("sox")
def test_dither(self): def test_dither(self):
test_filepath = os.path.join(self.test_dirpath, "assets", "steam-train-whistle-daniel_simon.wav") test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
waveform, _ = torchaudio.load(test_filepath) waveform, _ = torchaudio.load(test_filepath)
waveform_dithered = F.dither(waveform) waveform_dithered = F.dither(waveform)
...@@ -53,7 +51,7 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -53,7 +51,7 @@ class TestFunctionalFiltering(unittest.TestCase):
@unittest.skipIf("sox" not in BACKENDS, "sox not available") @unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox") @AudioBackendScope("sox")
def test_vctk_transform_pipeline(self): def test_vctk_transform_pipeline(self):
test_filepath_vctk = os.path.join(self.test_dirpath, "assets/VCTK-Corpus/wav48/p224/", "p224_002.wav") test_filepath_vctk = common_utils.get_asset_path('VCTK-Corpus', 'wav48', 'p224', 'p224_002.wav')
wf_vctk, sr_vctk = torchaudio.load(test_filepath_vctk) wf_vctk, sr_vctk = torchaudio.load(test_filepath_vctk)
# rate # rate
...@@ -76,14 +74,13 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -76,14 +74,13 @@ class TestFunctionalFiltering(unittest.TestCase):
@unittest.skipIf("sox" not in BACKENDS, "sox not available") @unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox") @AudioBackendScope("sox")
def test_lowpass(self): def test_lowpass(self):
""" """
Test biquad lowpass filter, compare to SoX implementation Test biquad lowpass filter, compare to SoX implementation
""" """
CUTOFF_FREQ = 3000 CUTOFF_FREQ = 3000
noise_filepath = os.path.join(self.test_dirpath, "assets", "whitenoise.wav") noise_filepath = common_utils.get_asset_path('whitenoise.wav')
E = torchaudio.sox_effects.SoxEffectsChain() E = torchaudio.sox_effects.SoxEffectsChain()
E.set_input_file(noise_filepath) E.set_input_file(noise_filepath)
E.append_effect_to_chain("lowpass", [CUTOFF_FREQ]) E.append_effect_to_chain("lowpass", [CUTOFF_FREQ])
...@@ -103,7 +100,7 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -103,7 +100,7 @@ class TestFunctionalFiltering(unittest.TestCase):
CUTOFF_FREQ = 2000 CUTOFF_FREQ = 2000
noise_filepath = os.path.join(self.test_dirpath, "assets", "whitenoise.wav") noise_filepath = common_utils.get_asset_path('whitenoise.wav')
E = torchaudio.sox_effects.SoxEffectsChain() E = torchaudio.sox_effects.SoxEffectsChain()
E.set_input_file(noise_filepath) E.set_input_file(noise_filepath)
E.append_effect_to_chain("highpass", [CUTOFF_FREQ]) E.append_effect_to_chain("highpass", [CUTOFF_FREQ])
...@@ -125,7 +122,7 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -125,7 +122,7 @@ class TestFunctionalFiltering(unittest.TestCase):
CENTRAL_FREQ = 1000 CENTRAL_FREQ = 1000
Q = 0.707 Q = 0.707
noise_filepath = os.path.join(self.test_dirpath, "assets", "whitenoise.wav") noise_filepath = common_utils.get_asset_path('whitenoise.wav')
E = torchaudio.sox_effects.SoxEffectsChain() E = torchaudio.sox_effects.SoxEffectsChain()
E.set_input_file(noise_filepath) E.set_input_file(noise_filepath)
E.append_effect_to_chain("allpass", [CENTRAL_FREQ, str(Q) + 'q']) E.append_effect_to_chain("allpass", [CENTRAL_FREQ, str(Q) + 'q'])
...@@ -147,7 +144,7 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -147,7 +144,7 @@ class TestFunctionalFiltering(unittest.TestCase):
Q = 0.707 Q = 0.707
CONST_SKIRT_GAIN = True CONST_SKIRT_GAIN = True
noise_filepath = os.path.join(self.test_dirpath, "assets", "whitenoise.wav") noise_filepath = common_utils.get_asset_path('whitenoise.wav')
E = torchaudio.sox_effects.SoxEffectsChain() E = torchaudio.sox_effects.SoxEffectsChain()
E.set_input_file(noise_filepath) E.set_input_file(noise_filepath)
E.append_effect_to_chain("bandpass", ["-c", CENTRAL_FREQ, str(Q) + 'q']) E.append_effect_to_chain("bandpass", ["-c", CENTRAL_FREQ, str(Q) + 'q'])
...@@ -169,7 +166,7 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -169,7 +166,7 @@ class TestFunctionalFiltering(unittest.TestCase):
Q = 0.707 Q = 0.707
CONST_SKIRT_GAIN = False CONST_SKIRT_GAIN = False
noise_filepath = os.path.join(self.test_dirpath, "assets", "whitenoise.wav") noise_filepath = common_utils.get_asset_path('whitenoise.wav')
E = torchaudio.sox_effects.SoxEffectsChain() E = torchaudio.sox_effects.SoxEffectsChain()
E.set_input_file(noise_filepath) E.set_input_file(noise_filepath)
E.append_effect_to_chain("bandpass", [CENTRAL_FREQ, str(Q) + 'q']) E.append_effect_to_chain("bandpass", [CENTRAL_FREQ, str(Q) + 'q'])
...@@ -190,7 +187,7 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -190,7 +187,7 @@ class TestFunctionalFiltering(unittest.TestCase):
CENTRAL_FREQ = 1000 CENTRAL_FREQ = 1000
Q = 0.707 Q = 0.707
noise_filepath = os.path.join(self.test_dirpath, "assets", "whitenoise.wav") noise_filepath = common_utils.get_asset_path('whitenoise.wav')
E = torchaudio.sox_effects.SoxEffectsChain() E = torchaudio.sox_effects.SoxEffectsChain()
E.set_input_file(noise_filepath) E.set_input_file(noise_filepath)
E.append_effect_to_chain("bandreject", [CENTRAL_FREQ, str(Q) + 'q']) E.append_effect_to_chain("bandreject", [CENTRAL_FREQ, str(Q) + 'q'])
...@@ -212,7 +209,7 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -212,7 +209,7 @@ class TestFunctionalFiltering(unittest.TestCase):
Q = 0.707 Q = 0.707
NOISE = True NOISE = True
noise_filepath = os.path.join(self.test_dirpath, "assets", "whitenoise.wav") noise_filepath = common_utils.get_asset_path('whitenoise.wav')
E = torchaudio.sox_effects.SoxEffectsChain() E = torchaudio.sox_effects.SoxEffectsChain()
E.set_input_file(noise_filepath) E.set_input_file(noise_filepath)
E.append_effect_to_chain("band", ["-n", CENTRAL_FREQ, str(Q) + 'q']) E.append_effect_to_chain("band", ["-n", CENTRAL_FREQ, str(Q) + 'q'])
...@@ -234,7 +231,7 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -234,7 +231,7 @@ class TestFunctionalFiltering(unittest.TestCase):
Q = 0.707 Q = 0.707
NOISE = False NOISE = False
noise_filepath = os.path.join(self.test_dirpath, "assets", "whitenoise.wav") noise_filepath = common_utils.get_asset_path('whitenoise.wav')
E = torchaudio.sox_effects.SoxEffectsChain() E = torchaudio.sox_effects.SoxEffectsChain()
E.set_input_file(noise_filepath) E.set_input_file(noise_filepath)
E.append_effect_to_chain("band", [CENTRAL_FREQ, str(Q) + 'q']) E.append_effect_to_chain("band", [CENTRAL_FREQ, str(Q) + 'q'])
...@@ -256,7 +253,7 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -256,7 +253,7 @@ class TestFunctionalFiltering(unittest.TestCase):
Q = 0.707 Q = 0.707
GAIN = 40 GAIN = 40
noise_filepath = os.path.join(self.test_dirpath, "assets", "whitenoise.wav") noise_filepath = common_utils.get_asset_path('whitenoise.wav')
E = torchaudio.sox_effects.SoxEffectsChain() E = torchaudio.sox_effects.SoxEffectsChain()
E.set_input_file(noise_filepath) E.set_input_file(noise_filepath)
E.append_effect_to_chain("treble", [GAIN, CENTRAL_FREQ, str(Q) + 'q']) E.append_effect_to_chain("treble", [GAIN, CENTRAL_FREQ, str(Q) + 'q'])
...@@ -274,7 +271,7 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -274,7 +271,7 @@ class TestFunctionalFiltering(unittest.TestCase):
Test biquad deemph filter, compare to SoX implementation Test biquad deemph filter, compare to SoX implementation
""" """
noise_filepath = os.path.join(self.test_dirpath, "assets", "whitenoise.wav") noise_filepath = common_utils.get_asset_path('whitenoise.wav')
E = torchaudio.sox_effects.SoxEffectsChain() E = torchaudio.sox_effects.SoxEffectsChain()
E.set_input_file(noise_filepath) E.set_input_file(noise_filepath)
E.append_effect_to_chain("deemph") E.append_effect_to_chain("deemph")
...@@ -292,7 +289,7 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -292,7 +289,7 @@ class TestFunctionalFiltering(unittest.TestCase):
Test biquad riaa filter, compare to SoX implementation Test biquad riaa filter, compare to SoX implementation
""" """
noise_filepath = os.path.join(self.test_dirpath, "assets", "whitenoise.wav") noise_filepath = common_utils.get_asset_path('whitenoise.wav')
E = torchaudio.sox_effects.SoxEffectsChain() E = torchaudio.sox_effects.SoxEffectsChain()
E.set_input_file(noise_filepath) E.set_input_file(noise_filepath)
E.append_effect_to_chain("riaa") E.append_effect_to_chain("riaa")
...@@ -314,7 +311,7 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -314,7 +311,7 @@ class TestFunctionalFiltering(unittest.TestCase):
Q = 0.707 Q = 0.707
GAIN = 1 GAIN = 1
noise_filepath = os.path.join(self.test_dirpath, "assets", "whitenoise.wav") noise_filepath = common_utils.get_asset_path('whitenoise.wav')
E = torchaudio.sox_effects.SoxEffectsChain() E = torchaudio.sox_effects.SoxEffectsChain()
E.set_input_file(noise_filepath) E.set_input_file(noise_filepath)
E.append_effect_to_chain("equalizer", [CENTER_FREQ, Q, GAIN]) E.append_effect_to_chain("equalizer", [CENTER_FREQ, Q, GAIN])
...@@ -329,7 +326,7 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -329,7 +326,7 @@ class TestFunctionalFiltering(unittest.TestCase):
@AudioBackendScope("sox") @AudioBackendScope("sox")
def test_perf_biquad_filtering(self): def test_perf_biquad_filtering(self):
fn_sine = os.path.join(self.test_dirpath, "assets", "whitenoise.wav") fn_sine = common_utils.get_asset_path('whitenoise.wav')
b0 = 0.4 b0 = 0.4
b1 = 0.2 b1 = 0.2
......
...@@ -2,16 +2,14 @@ import unittest ...@@ -2,16 +2,14 @@ import unittest
import torch import torch
import torchaudio import torchaudio
import math import math
import os
from common_utils import AudioBackendScope, BACKENDS, create_temp_assets_dir import common_utils
from common_utils import AudioBackendScope, BACKENDS
@unittest.skipIf("sox" not in BACKENDS, "sox not available") @unittest.skipIf("sox" not in BACKENDS, "sox not available")
class Test_SoxEffectsChain(unittest.TestCase): class Test_SoxEffectsChain(unittest.TestCase):
test_dirpath, test_dir = create_temp_assets_dir() test_filepath = common_utils.get_asset_path("steam-train-whistle-daniel_simon.mp3")
test_filepath = os.path.join(test_dirpath, "assets",
"steam-train-whistle-daniel_simon.mp3")
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
...@@ -22,7 +20,7 @@ class Test_SoxEffectsChain(unittest.TestCase): ...@@ -22,7 +20,7 @@ class Test_SoxEffectsChain(unittest.TestCase):
torchaudio.shutdown_sox() torchaudio.shutdown_sox()
def test_single_channel(self): def test_single_channel(self):
fn_sine = os.path.join(self.test_dirpath, "assets", "sinewave.wav") fn_sine = common_utils.get_asset_path("sinewave.wav")
E = torchaudio.sox_effects.SoxEffectsChain() E = torchaudio.sox_effects.SoxEffectsChain()
E.set_input_file(fn_sine) E.set_input_file(fn_sine)
E.append_effect_to_chain("echos", [0.8, 0.7, 40, 0.25, 63, 0.3]) E.append_effect_to_chain("echos", [0.8, 0.7, 40, 0.25, 63, 0.3])
......
"""Test suites for jit-ability and its numerical compatibility""" """Test suites for jit-ability and its numerical compatibility"""
import os
import unittest import unittest
import torch import torch
...@@ -81,8 +80,7 @@ class _FunctionalTestMixin: ...@@ -81,8 +80,7 @@ class _FunctionalTestMixin:
self._assert_consistency(func, tensor) self._assert_consistency(func, tensor)
def test_detect_pitch_frequency(self): def test_detect_pitch_frequency(self):
filepath = os.path.join( filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.mp3')
common_utils.TEST_DIR_PATH, 'assets', 'steam-train-whistle-daniel_simon.mp3')
waveform, _ = torchaudio.load(filepath) waveform, _ = torchaudio.load(filepath)
def func(tensor): def func(tensor):
...@@ -213,7 +211,7 @@ class _FunctionalTestMixin: ...@@ -213,7 +211,7 @@ class _FunctionalTestMixin:
self._assert_consistency(func, tensor, shape_only=True) self._assert_consistency(func, tensor, shape_only=True)
def test_lfilter(self): def test_lfilter(self):
filepath = os.path.join(common_utils.TEST_DIR_PATH, 'assets', 'whitenoise.wav') filepath = common_utils.get_asset_path('whitenoise.wav')
waveform, _ = torchaudio.load(filepath, normalization=True) waveform, _ = torchaudio.load(filepath, normalization=True)
def func(tensor): def func(tensor):
...@@ -254,7 +252,7 @@ class _FunctionalTestMixin: ...@@ -254,7 +252,7 @@ class _FunctionalTestMixin:
self._assert_consistency(func, waveform) self._assert_consistency(func, waveform)
def test_lowpass(self): def test_lowpass(self):
filepath = os.path.join(common_utils.TEST_DIR_PATH, 'assets', 'whitenoise.wav') filepath = common_utils.get_asset_path('whitenoise.wav')
waveform, _ = torchaudio.load(filepath, normalization=True) waveform, _ = torchaudio.load(filepath, normalization=True)
def func(tensor): def func(tensor):
...@@ -265,7 +263,7 @@ class _FunctionalTestMixin: ...@@ -265,7 +263,7 @@ class _FunctionalTestMixin:
self._assert_consistency(func, waveform) self._assert_consistency(func, waveform)
def test_highpass(self): def test_highpass(self):
filepath = os.path.join(common_utils.TEST_DIR_PATH, 'assets', 'whitenoise.wav') filepath = common_utils.get_asset_path('whitenoise.wav')
waveform, _ = torchaudio.load(filepath, normalization=True) waveform, _ = torchaudio.load(filepath, normalization=True)
def func(tensor): def func(tensor):
...@@ -276,7 +274,7 @@ class _FunctionalTestMixin: ...@@ -276,7 +274,7 @@ class _FunctionalTestMixin:
self._assert_consistency(func, waveform) self._assert_consistency(func, waveform)
def test_allpass(self): def test_allpass(self):
filepath = os.path.join(common_utils.TEST_DIR_PATH, 'assets', 'whitenoise.wav') filepath = common_utils.get_asset_path('whitenoise.wav')
waveform, _ = torchaudio.load(filepath, normalization=True) waveform, _ = torchaudio.load(filepath, normalization=True)
def func(tensor): def func(tensor):
...@@ -288,7 +286,7 @@ class _FunctionalTestMixin: ...@@ -288,7 +286,7 @@ class _FunctionalTestMixin:
self._assert_consistency(func, waveform) self._assert_consistency(func, waveform)
def test_bandpass_with_csg(self): def test_bandpass_with_csg(self):
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav") filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True) waveform, _ = torchaudio.load(filepath, normalization=True)
def func(tensor): def func(tensor):
...@@ -301,7 +299,7 @@ class _FunctionalTestMixin: ...@@ -301,7 +299,7 @@ class _FunctionalTestMixin:
self._assert_consistency(func, waveform) self._assert_consistency(func, waveform)
def test_bandpass_withou_csg(self): def test_bandpass_withou_csg(self):
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav") filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True) waveform, _ = torchaudio.load(filepath, normalization=True)
def func(tensor): def func(tensor):
...@@ -314,7 +312,7 @@ class _FunctionalTestMixin: ...@@ -314,7 +312,7 @@ class _FunctionalTestMixin:
self._assert_consistency(func, waveform) self._assert_consistency(func, waveform)
def test_bandreject(self): def test_bandreject(self):
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav") filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True) waveform, _ = torchaudio.load(filepath, normalization=True)
def func(tensor): def func(tensor):
...@@ -326,7 +324,7 @@ class _FunctionalTestMixin: ...@@ -326,7 +324,7 @@ class _FunctionalTestMixin:
self._assert_consistency(func, waveform) self._assert_consistency(func, waveform)
def test_band_with_noise(self): def test_band_with_noise(self):
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav") filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True) waveform, _ = torchaudio.load(filepath, normalization=True)
def func(tensor): def func(tensor):
...@@ -339,7 +337,7 @@ class _FunctionalTestMixin: ...@@ -339,7 +337,7 @@ class _FunctionalTestMixin:
self._assert_consistency(func, waveform) self._assert_consistency(func, waveform)
def test_band_without_noise(self): def test_band_without_noise(self):
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav") filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True) waveform, _ = torchaudio.load(filepath, normalization=True)
def func(tensor): def func(tensor):
...@@ -352,7 +350,7 @@ class _FunctionalTestMixin: ...@@ -352,7 +350,7 @@ class _FunctionalTestMixin:
self._assert_consistency(func, waveform) self._assert_consistency(func, waveform)
def test_treble(self): def test_treble(self):
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav") filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True) waveform, _ = torchaudio.load(filepath, normalization=True)
def func(tensor): def func(tensor):
...@@ -365,7 +363,7 @@ class _FunctionalTestMixin: ...@@ -365,7 +363,7 @@ class _FunctionalTestMixin:
self._assert_consistency(func, waveform) self._assert_consistency(func, waveform)
def test_deemph(self): def test_deemph(self):
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav") filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True) waveform, _ = torchaudio.load(filepath, normalization=True)
def func(tensor): def func(tensor):
...@@ -375,7 +373,7 @@ class _FunctionalTestMixin: ...@@ -375,7 +373,7 @@ class _FunctionalTestMixin:
self._assert_consistency(func, waveform) self._assert_consistency(func, waveform)
def test_riaa(self): def test_riaa(self):
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav") filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True) waveform, _ = torchaudio.load(filepath, normalization=True)
def func(tensor): def func(tensor):
...@@ -385,7 +383,7 @@ class _FunctionalTestMixin: ...@@ -385,7 +383,7 @@ class _FunctionalTestMixin:
self._assert_consistency(func, waveform) self._assert_consistency(func, waveform)
def test_equalizer(self): def test_equalizer(self):
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav") filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True) waveform, _ = torchaudio.load(filepath, normalization=True)
def func(tensor): def func(tensor):
...@@ -398,7 +396,7 @@ class _FunctionalTestMixin: ...@@ -398,7 +396,7 @@ class _FunctionalTestMixin:
self._assert_consistency(func, waveform) self._assert_consistency(func, waveform)
def test_perf_biquad_filtering(self): def test_perf_biquad_filtering(self):
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav") filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True) waveform, _ = torchaudio.load(filepath, normalization=True)
def func(tensor): def func(tensor):
...@@ -469,8 +467,7 @@ class _TransformsTestMixin: ...@@ -469,8 +467,7 @@ class _TransformsTestMixin:
) )
def test_Fade(self): def test_Fade(self):
test_filepath = os.path.join( test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
common_utils.TEST_DIR_PATH, 'assets', 'steam-train-whistle-daniel_simon.wav')
waveform, _ = torchaudio.load(test_filepath) waveform, _ = torchaudio.load(test_filepath)
fade_in_len = 3000 fade_in_len = 3000
fade_out_len = 3000 fade_out_len = 3000
...@@ -485,8 +482,7 @@ class _TransformsTestMixin: ...@@ -485,8 +482,7 @@ class _TransformsTestMixin:
self._assert_consistency(T.TimeMasking(time_mask_param=30, iid_masks=False), tensor) self._assert_consistency(T.TimeMasking(time_mask_param=30, iid_masks=False), tensor)
def test_Vol(self): def test_Vol(self):
test_filepath = os.path.join( test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
common_utils.TEST_DIR_PATH, 'assets', 'steam-train-whistle-daniel_simon.wav')
waveform, _ = torchaudio.load(test_filepath) waveform, _ = torchaudio.load(test_filepath)
self._assert_consistency(T.Vol(1.1), waveform) self._assert_consistency(T.Vol(1.1), waveform)
......
import math import math
import os
import unittest import unittest
import torch import torch
...@@ -7,7 +6,7 @@ import torchaudio ...@@ -7,7 +6,7 @@ import torchaudio
import torchaudio.transforms as transforms import torchaudio.transforms as transforms
import torchaudio.functional as F import torchaudio.functional as F
from common_utils import AudioBackendScope, BACKENDS, create_temp_assets_dir import common_utils
class Tester(unittest.TestCase): class Tester(unittest.TestCase):
...@@ -19,10 +18,6 @@ class Tester(unittest.TestCase): ...@@ -19,10 +18,6 @@ class Tester(unittest.TestCase):
waveform = (torch.cos(2 * math.pi * torch.arange(0, 4 * sample_rate).float() * freq / sample_rate)) waveform = (torch.cos(2 * math.pi * torch.arange(0, 4 * sample_rate).float() * freq / sample_rate))
waveform.unsqueeze_(0) # (1, 64000) waveform.unsqueeze_(0) # (1, 64000)
waveform = (waveform * volume * 2**31).long() waveform = (waveform * volume * 2**31).long()
# file for stereo stft test
test_dirpath, test_dir = create_temp_assets_dir()
test_filepath = os.path.join(test_dirpath, 'assets',
'steam-train-whistle-daniel_simon.wav')
def scale(self, waveform, factor=2.0**31): def scale(self, waveform, factor=2.0**31):
# scales a waveform by a factor # scales a waveform by a factor
...@@ -45,7 +40,8 @@ class Tester(unittest.TestCase): ...@@ -45,7 +40,8 @@ class Tester(unittest.TestCase):
self.assertTrue(waveform_exp.min() >= -1. and waveform_exp.max() <= 1.) self.assertTrue(waveform_exp.min() >= -1. and waveform_exp.max() <= 1.)
def test_AmplitudeToDB(self): def test_AmplitudeToDB(self):
waveform, sample_rate = torchaudio.load(self.test_filepath) filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
waveform, sample_rate = torchaudio.load(filepath)
mag_to_db_transform = transforms.AmplitudeToDB('magnitude', 80.) mag_to_db_transform = transforms.AmplitudeToDB('magnitude', 80.)
power_to_db_transform = transforms.AmplitudeToDB('power', 80.) power_to_db_transform = transforms.AmplitudeToDB('power', 80.)
...@@ -114,7 +110,8 @@ class Tester(unittest.TestCase): ...@@ -114,7 +110,8 @@ class Tester(unittest.TestCase):
self.assertTrue(mel_transform2.mel_scale.fb.sum(1).le(1.).all()) self.assertTrue(mel_transform2.mel_scale.fb.sum(1).le(1.).all())
self.assertTrue(mel_transform2.mel_scale.fb.sum(1).ge(0.).all()) self.assertTrue(mel_transform2.mel_scale.fb.sum(1).ge(0.).all())
# check on multi-channel audio # check on multi-channel audio
x_stereo, sr_stereo = torchaudio.load(self.test_filepath) # (2, 278756), 44100 filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
x_stereo, sr_stereo = torchaudio.load(filepath) # (2, 278756), 44100
spectrogram_stereo = s2db(mel_transform(x_stereo)) # (2, 128, 1394) spectrogram_stereo = s2db(mel_transform(x_stereo)) # (2, 128, 1394)
self.assertTrue(spectrogram_stereo.dim() == 3) self.assertTrue(spectrogram_stereo.dim() == 3)
self.assertTrue(spectrogram_stereo.size(0) == 2) self.assertTrue(spectrogram_stereo.size(0) == 2)
...@@ -164,7 +161,7 @@ class Tester(unittest.TestCase): ...@@ -164,7 +161,7 @@ class Tester(unittest.TestCase):
self.assertTrue(torch_mfcc_norm_none.allclose(norm_check)) self.assertTrue(torch_mfcc_norm_none.allclose(norm_check))
def test_resample_size(self): def test_resample_size(self):
input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav') input_path = common_utils.get_asset_path('sinewave.wav')
waveform, sample_rate = torchaudio.load(input_path) waveform, sample_rate = torchaudio.load(input_path)
upsample_rate = sample_rate * 2 upsample_rate = sample_rate * 2
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment