Unverified Commit a9c4d0a8 authored by moto's avatar moto Committed by GitHub
Browse files

Refactor torchscript test helper function (#521)

parent 657f0a02
......@@ -10,19 +10,16 @@ import torchaudio.transforms
import common_utils
def _test_torchscript_functional_shape(py_method, *args, **kwargs):
def _assert_functional_consistency(py_method, *args, shape_only=False, **kwargs):
jit_method = torch.jit.script(py_method)
jit_out = jit_method(*args, **kwargs)
py_out = py_method(*args, **kwargs)
assert jit_out.shape == py_out.shape
return jit_out, py_out
def _test_torchscript_functional(py_method, *args, **kwargs):
jit_out, py_out = _test_torchscript_functional_shape(py_method, *args, **kwargs)
torch.testing.assert_allclose(jit_out, py_out)
if shape_only:
assert jit_out.shape == py_out.shape, (jit_out.shape, py_out.shape)
else:
torch.testing.assert_allclose(jit_out, py_out)
def _test_lfilter(waveform):
......@@ -58,7 +55,7 @@ def _test_lfilter(waveform):
],
device=waveform.device,
)
_test_torchscript_functional(F.lfilter, waveform, a_coeffs, b_coeffs)
_assert_functional_consistency(F.lfilter, waveform, a_coeffs, b_coeffs)
class TestFunctional(unittest.TestCase):
......@@ -73,7 +70,7 @@ class TestFunctional(unittest.TestCase):
power = 2
normalize = False
_test_torchscript_functional(
_assert_functional_consistency(
F.spectrogram, tensor, pad, window, n_fft, hop, ws, power, normalize
)
......@@ -89,7 +86,7 @@ class TestFunctional(unittest.TestCase):
n_iter = 32
length = 1000
_test_torchscript_functional(
_assert_functional_consistency(
F.griffinlim, tensor, window, n_fft, hop, ws, power, normalize, n_iter, momentum, length, 0
)
......@@ -100,13 +97,13 @@ class TestFunctional(unittest.TestCase):
win_length = 2 * 7 + 1
specgram = torch.randn(channel, n_mfcc, time)
_test_torchscript_functional(F.compute_deltas, specgram, win_length=win_length)
_assert_functional_consistency(F.compute_deltas, specgram, win_length=win_length)
def test_detect_pitch_frequency(self):
filepath = os.path.join(
common_utils.TEST_DIR_PATH, 'assets', 'steam-train-whistle-daniel_simon.mp3')
waveform, sample_rate = torchaudio.load(filepath)
_test_torchscript_functional(F.detect_pitch_frequency, waveform, sample_rate)
_assert_functional_consistency(F.detect_pitch_frequency, waveform, sample_rate)
def test_create_fb_matrix(self):
n_stft = 100
......@@ -115,7 +112,7 @@ class TestFunctional(unittest.TestCase):
n_mels = 10
sample_rate = 16000
_test_torchscript_functional(F.create_fb_matrix, n_stft, f_min, f_max, n_mels, sample_rate)
_assert_functional_consistency(F.create_fb_matrix, n_stft, f_min, f_max, n_mels, sample_rate)
def test_amplitude_to_DB(self):
spec = torch.rand((6, 201))
......@@ -124,39 +121,39 @@ class TestFunctional(unittest.TestCase):
db_multiplier = 0.0
top_db = 80.0
_test_torchscript_functional(F.amplitude_to_DB, spec, multiplier, amin, db_multiplier, top_db)
_assert_functional_consistency(F.amplitude_to_DB, spec, multiplier, amin, db_multiplier, top_db)
def test_DB_to_amplitude(self):
x = torch.rand((1, 100))
ref = 1.
power = 1.
_test_torchscript_functional(F.DB_to_amplitude, x, ref, power)
_assert_functional_consistency(F.DB_to_amplitude, x, ref, power)
def test_create_dct(self):
n_mfcc = 40
n_mels = 128
norm = "ortho"
_test_torchscript_functional(F.create_dct, n_mfcc, n_mels, norm)
_assert_functional_consistency(F.create_dct, n_mfcc, n_mels, norm)
def test_mu_law_encoding(self):
tensor = torch.rand((1, 10))
qc = 256
_test_torchscript_functional(F.mu_law_encoding, tensor, qc)
_assert_functional_consistency(F.mu_law_encoding, tensor, qc)
def test_mu_law_decoding(self):
tensor = torch.rand((1, 10))
qc = 256
_test_torchscript_functional(F.mu_law_decoding, tensor, qc)
_assert_functional_consistency(F.mu_law_decoding, tensor, qc)
def test_complex_norm(self):
complex_tensor = torch.randn(1, 2, 1025, 400, 2)
power = 2
_test_torchscript_functional(F.complex_norm, complex_tensor, power)
_assert_functional_consistency(F.complex_norm, complex_tensor, power)
def test_mask_along_axis(self):
specgram = torch.randn(2, 1025, 400)
......@@ -164,7 +161,7 @@ class TestFunctional(unittest.TestCase):
mask_value = 30.
axis = 2
_test_torchscript_functional(F.mask_along_axis, specgram, mask_param, mask_value, axis)
_assert_functional_consistency(F.mask_along_axis, specgram, mask_param, mask_value, axis)
def test_mask_along_axis_iid(self):
specgrams = torch.randn(4, 2, 1025, 400)
......@@ -172,20 +169,20 @@ class TestFunctional(unittest.TestCase):
mask_value = 30.
axis = 2
_test_torchscript_functional(F.mask_along_axis_iid, specgrams, mask_param, mask_value, axis)
_assert_functional_consistency(F.mask_along_axis_iid, specgrams, mask_param, mask_value, axis)
def test_gain(self):
tensor = torch.rand((1, 1000))
gainDB = 2.0
_test_torchscript_functional(F.gain, tensor, gainDB)
_assert_functional_consistency(F.gain, tensor, gainDB)
def test_dither(self):
tensor = torch.rand((2, 1000))
_test_torchscript_functional_shape(F.dither, tensor)
_test_torchscript_functional_shape(F.dither, tensor, "RPDF")
_test_torchscript_functional_shape(F.dither, tensor, "GPDF")
_assert_functional_consistency(F.dither, tensor, shape_only=True)
_assert_functional_consistency(F.dither, tensor, "RPDF", shape_only=True)
_assert_functional_consistency(F.dither, tensor, "GPDF", shape_only=True)
def test_lfilter(self):
filepath = os.path.join(common_utils.TEST_DIR_PATH, 'assets', 'whitenoise.wav')
......@@ -203,14 +200,14 @@ class TestFunctional(unittest.TestCase):
filepath = os.path.join(common_utils.TEST_DIR_PATH, 'assets', 'whitenoise.wav')
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(F.lowpass_biquad, waveform, sample_rate, cutoff_freq)
_assert_functional_consistency(F.lowpass_biquad, waveform, sample_rate, cutoff_freq)
def test_highpass(self):
cutoff_freq = 2000
filepath = os.path.join(common_utils.TEST_DIR_PATH, 'assets', 'whitenoise.wav')
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(F.highpass_biquad, waveform, sample_rate, cutoff_freq)
_assert_functional_consistency(F.highpass_biquad, waveform, sample_rate, cutoff_freq)
def test_allpass(self):
central_freq = 1000
......@@ -218,7 +215,7 @@ class TestFunctional(unittest.TestCase):
filepath = os.path.join(common_utils.TEST_DIR_PATH, 'assets', 'whitenoise.wav')
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(F.allpass_biquad, waveform, sample_rate, central_freq, q)
_assert_functional_consistency(F.allpass_biquad, waveform, sample_rate, central_freq, q)
def test_bandpass_with_csg(self):
central_freq = 1000
......@@ -227,7 +224,7 @@ class TestFunctional(unittest.TestCase):
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(
_assert_functional_consistency(
F.bandpass_biquad, waveform, sample_rate, central_freq, q, const_skirt_gain)
def test_bandpass_withou_csg(self):
......@@ -237,7 +234,7 @@ class TestFunctional(unittest.TestCase):
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(
_assert_functional_consistency(
F.bandpass_biquad, waveform, sample_rate, central_freq, q, const_skirt_gain)
def test_bandreject(self):
......@@ -246,7 +243,7 @@ class TestFunctional(unittest.TestCase):
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(
_assert_functional_consistency(
F.bandreject_biquad, waveform, sample_rate, central_freq, q)
def test_band_with_noise(self):
......@@ -256,7 +253,7 @@ class TestFunctional(unittest.TestCase):
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(F.band_biquad, waveform, sample_rate, central_freq, q, noise)
_assert_functional_consistency(F.band_biquad, waveform, sample_rate, central_freq, q, noise)
def test_band_without_noise(self):
central_freq = 1000
......@@ -265,7 +262,7 @@ class TestFunctional(unittest.TestCase):
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(F.band_biquad, waveform, sample_rate, central_freq, q, noise)
_assert_functional_consistency(F.band_biquad, waveform, sample_rate, central_freq, q, noise)
def test_treble(self):
gain = 40
......@@ -274,17 +271,17 @@ class TestFunctional(unittest.TestCase):
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(F.treble_biquad, waveform, sample_rate, gain, central_freq, q)
_assert_functional_consistency(F.treble_biquad, waveform, sample_rate, gain, central_freq, q)
def test_deemph(self):
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(F.deemph_biquad, waveform, sample_rate)
_assert_functional_consistency(F.deemph_biquad, waveform, sample_rate)
def test_riaa(self):
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(F.riaa_biquad, waveform, sample_rate)
_assert_functional_consistency(F.riaa_biquad, waveform, sample_rate)
def test_equalizer(self):
center_freq = 300
......@@ -293,7 +290,7 @@ class TestFunctional(unittest.TestCase):
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(
_assert_functional_consistency(
F.equalizer_biquad, waveform, sample_rate, center_freq, gain, q)
def test_perf_biquad_filtering(self):
......@@ -301,7 +298,7 @@ class TestFunctional(unittest.TestCase):
b = torch.tensor([0.4, 0.2, 0.9])
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(F.lfilter, waveform, a, b)
_assert_functional_consistency(F.lfilter, waveform, a, b)
RUN_CUDA = torch.cuda.is_available()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment