"vscode:/vscode.git/clone" did not exist on "40f16872c749d5b5cbf26218b3bd33c6a6788582"
Unverified Commit a9c4d0a8 authored by moto's avatar moto Committed by GitHub
Browse files

Refactor torchscript test helper function (#521)

parent 657f0a02
...@@ -10,19 +10,16 @@ import torchaudio.transforms ...@@ -10,19 +10,16 @@ import torchaudio.transforms
import common_utils import common_utils
def _test_torchscript_functional_shape(py_method, *args, **kwargs): def _assert_functional_consistency(py_method, *args, shape_only=False, **kwargs):
jit_method = torch.jit.script(py_method) jit_method = torch.jit.script(py_method)
jit_out = jit_method(*args, **kwargs) jit_out = jit_method(*args, **kwargs)
py_out = py_method(*args, **kwargs) py_out = py_method(*args, **kwargs)
assert jit_out.shape == py_out.shape if shape_only:
return jit_out, py_out assert jit_out.shape == py_out.shape, (jit_out.shape, py_out.shape)
else:
torch.testing.assert_allclose(jit_out, py_out)
def _test_torchscript_functional(py_method, *args, **kwargs):
jit_out, py_out = _test_torchscript_functional_shape(py_method, *args, **kwargs)
torch.testing.assert_allclose(jit_out, py_out)
def _test_lfilter(waveform): def _test_lfilter(waveform):
...@@ -58,7 +55,7 @@ def _test_lfilter(waveform): ...@@ -58,7 +55,7 @@ def _test_lfilter(waveform):
], ],
device=waveform.device, device=waveform.device,
) )
_test_torchscript_functional(F.lfilter, waveform, a_coeffs, b_coeffs) _assert_functional_consistency(F.lfilter, waveform, a_coeffs, b_coeffs)
class TestFunctional(unittest.TestCase): class TestFunctional(unittest.TestCase):
...@@ -73,7 +70,7 @@ class TestFunctional(unittest.TestCase): ...@@ -73,7 +70,7 @@ class TestFunctional(unittest.TestCase):
power = 2 power = 2
normalize = False normalize = False
_test_torchscript_functional( _assert_functional_consistency(
F.spectrogram, tensor, pad, window, n_fft, hop, ws, power, normalize F.spectrogram, tensor, pad, window, n_fft, hop, ws, power, normalize
) )
...@@ -89,7 +86,7 @@ class TestFunctional(unittest.TestCase): ...@@ -89,7 +86,7 @@ class TestFunctional(unittest.TestCase):
n_iter = 32 n_iter = 32
length = 1000 length = 1000
_test_torchscript_functional( _assert_functional_consistency(
F.griffinlim, tensor, window, n_fft, hop, ws, power, normalize, n_iter, momentum, length, 0 F.griffinlim, tensor, window, n_fft, hop, ws, power, normalize, n_iter, momentum, length, 0
) )
...@@ -100,13 +97,13 @@ class TestFunctional(unittest.TestCase): ...@@ -100,13 +97,13 @@ class TestFunctional(unittest.TestCase):
win_length = 2 * 7 + 1 win_length = 2 * 7 + 1
specgram = torch.randn(channel, n_mfcc, time) specgram = torch.randn(channel, n_mfcc, time)
_test_torchscript_functional(F.compute_deltas, specgram, win_length=win_length) _assert_functional_consistency(F.compute_deltas, specgram, win_length=win_length)
def test_detect_pitch_frequency(self): def test_detect_pitch_frequency(self):
filepath = os.path.join( filepath = os.path.join(
common_utils.TEST_DIR_PATH, 'assets', 'steam-train-whistle-daniel_simon.mp3') common_utils.TEST_DIR_PATH, 'assets', 'steam-train-whistle-daniel_simon.mp3')
waveform, sample_rate = torchaudio.load(filepath) waveform, sample_rate = torchaudio.load(filepath)
_test_torchscript_functional(F.detect_pitch_frequency, waveform, sample_rate) _assert_functional_consistency(F.detect_pitch_frequency, waveform, sample_rate)
def test_create_fb_matrix(self): def test_create_fb_matrix(self):
n_stft = 100 n_stft = 100
...@@ -115,7 +112,7 @@ class TestFunctional(unittest.TestCase): ...@@ -115,7 +112,7 @@ class TestFunctional(unittest.TestCase):
n_mels = 10 n_mels = 10
sample_rate = 16000 sample_rate = 16000
_test_torchscript_functional(F.create_fb_matrix, n_stft, f_min, f_max, n_mels, sample_rate) _assert_functional_consistency(F.create_fb_matrix, n_stft, f_min, f_max, n_mels, sample_rate)
def test_amplitude_to_DB(self): def test_amplitude_to_DB(self):
spec = torch.rand((6, 201)) spec = torch.rand((6, 201))
...@@ -124,39 +121,39 @@ class TestFunctional(unittest.TestCase): ...@@ -124,39 +121,39 @@ class TestFunctional(unittest.TestCase):
db_multiplier = 0.0 db_multiplier = 0.0
top_db = 80.0 top_db = 80.0
_test_torchscript_functional(F.amplitude_to_DB, spec, multiplier, amin, db_multiplier, top_db) _assert_functional_consistency(F.amplitude_to_DB, spec, multiplier, amin, db_multiplier, top_db)
def test_DB_to_amplitude(self): def test_DB_to_amplitude(self):
x = torch.rand((1, 100)) x = torch.rand((1, 100))
ref = 1. ref = 1.
power = 1. power = 1.
_test_torchscript_functional(F.DB_to_amplitude, x, ref, power) _assert_functional_consistency(F.DB_to_amplitude, x, ref, power)
def test_create_dct(self): def test_create_dct(self):
n_mfcc = 40 n_mfcc = 40
n_mels = 128 n_mels = 128
norm = "ortho" norm = "ortho"
_test_torchscript_functional(F.create_dct, n_mfcc, n_mels, norm) _assert_functional_consistency(F.create_dct, n_mfcc, n_mels, norm)
def test_mu_law_encoding(self): def test_mu_law_encoding(self):
tensor = torch.rand((1, 10)) tensor = torch.rand((1, 10))
qc = 256 qc = 256
_test_torchscript_functional(F.mu_law_encoding, tensor, qc) _assert_functional_consistency(F.mu_law_encoding, tensor, qc)
def test_mu_law_decoding(self): def test_mu_law_decoding(self):
tensor = torch.rand((1, 10)) tensor = torch.rand((1, 10))
qc = 256 qc = 256
_test_torchscript_functional(F.mu_law_decoding, tensor, qc) _assert_functional_consistency(F.mu_law_decoding, tensor, qc)
def test_complex_norm(self): def test_complex_norm(self):
complex_tensor = torch.randn(1, 2, 1025, 400, 2) complex_tensor = torch.randn(1, 2, 1025, 400, 2)
power = 2 power = 2
_test_torchscript_functional(F.complex_norm, complex_tensor, power) _assert_functional_consistency(F.complex_norm, complex_tensor, power)
def test_mask_along_axis(self): def test_mask_along_axis(self):
specgram = torch.randn(2, 1025, 400) specgram = torch.randn(2, 1025, 400)
...@@ -164,7 +161,7 @@ class TestFunctional(unittest.TestCase): ...@@ -164,7 +161,7 @@ class TestFunctional(unittest.TestCase):
mask_value = 30. mask_value = 30.
axis = 2 axis = 2
_test_torchscript_functional(F.mask_along_axis, specgram, mask_param, mask_value, axis) _assert_functional_consistency(F.mask_along_axis, specgram, mask_param, mask_value, axis)
def test_mask_along_axis_iid(self): def test_mask_along_axis_iid(self):
specgrams = torch.randn(4, 2, 1025, 400) specgrams = torch.randn(4, 2, 1025, 400)
...@@ -172,20 +169,20 @@ class TestFunctional(unittest.TestCase): ...@@ -172,20 +169,20 @@ class TestFunctional(unittest.TestCase):
mask_value = 30. mask_value = 30.
axis = 2 axis = 2
_test_torchscript_functional(F.mask_along_axis_iid, specgrams, mask_param, mask_value, axis) _assert_functional_consistency(F.mask_along_axis_iid, specgrams, mask_param, mask_value, axis)
def test_gain(self): def test_gain(self):
tensor = torch.rand((1, 1000)) tensor = torch.rand((1, 1000))
gainDB = 2.0 gainDB = 2.0
_test_torchscript_functional(F.gain, tensor, gainDB) _assert_functional_consistency(F.gain, tensor, gainDB)
def test_dither(self): def test_dither(self):
tensor = torch.rand((2, 1000)) tensor = torch.rand((2, 1000))
_test_torchscript_functional_shape(F.dither, tensor) _assert_functional_consistency(F.dither, tensor, shape_only=True)
_test_torchscript_functional_shape(F.dither, tensor, "RPDF") _assert_functional_consistency(F.dither, tensor, "RPDF", shape_only=True)
_test_torchscript_functional_shape(F.dither, tensor, "GPDF") _assert_functional_consistency(F.dither, tensor, "GPDF", shape_only=True)
def test_lfilter(self): def test_lfilter(self):
filepath = os.path.join(common_utils.TEST_DIR_PATH, 'assets', 'whitenoise.wav') filepath = os.path.join(common_utils.TEST_DIR_PATH, 'assets', 'whitenoise.wav')
...@@ -203,14 +200,14 @@ class TestFunctional(unittest.TestCase): ...@@ -203,14 +200,14 @@ class TestFunctional(unittest.TestCase):
filepath = os.path.join(common_utils.TEST_DIR_PATH, 'assets', 'whitenoise.wav') filepath = os.path.join(common_utils.TEST_DIR_PATH, 'assets', 'whitenoise.wav')
waveform, sample_rate = torchaudio.load(filepath, normalization=True) waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(F.lowpass_biquad, waveform, sample_rate, cutoff_freq) _assert_functional_consistency(F.lowpass_biquad, waveform, sample_rate, cutoff_freq)
def test_highpass(self): def test_highpass(self):
cutoff_freq = 2000 cutoff_freq = 2000
filepath = os.path.join(common_utils.TEST_DIR_PATH, 'assets', 'whitenoise.wav') filepath = os.path.join(common_utils.TEST_DIR_PATH, 'assets', 'whitenoise.wav')
waveform, sample_rate = torchaudio.load(filepath, normalization=True) waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(F.highpass_biquad, waveform, sample_rate, cutoff_freq) _assert_functional_consistency(F.highpass_biquad, waveform, sample_rate, cutoff_freq)
def test_allpass(self): def test_allpass(self):
central_freq = 1000 central_freq = 1000
...@@ -218,7 +215,7 @@ class TestFunctional(unittest.TestCase): ...@@ -218,7 +215,7 @@ class TestFunctional(unittest.TestCase):
filepath = os.path.join(common_utils.TEST_DIR_PATH, 'assets', 'whitenoise.wav') filepath = os.path.join(common_utils.TEST_DIR_PATH, 'assets', 'whitenoise.wav')
waveform, sample_rate = torchaudio.load(filepath, normalization=True) waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(F.allpass_biquad, waveform, sample_rate, central_freq, q) _assert_functional_consistency(F.allpass_biquad, waveform, sample_rate, central_freq, q)
def test_bandpass_with_csg(self): def test_bandpass_with_csg(self):
central_freq = 1000 central_freq = 1000
...@@ -227,7 +224,7 @@ class TestFunctional(unittest.TestCase): ...@@ -227,7 +224,7 @@ class TestFunctional(unittest.TestCase):
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav") filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, sample_rate = torchaudio.load(filepath, normalization=True) waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional( _assert_functional_consistency(
F.bandpass_biquad, waveform, sample_rate, central_freq, q, const_skirt_gain) F.bandpass_biquad, waveform, sample_rate, central_freq, q, const_skirt_gain)
def test_bandpass_withou_csg(self): def test_bandpass_withou_csg(self):
...@@ -237,7 +234,7 @@ class TestFunctional(unittest.TestCase): ...@@ -237,7 +234,7 @@ class TestFunctional(unittest.TestCase):
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav") filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, sample_rate = torchaudio.load(filepath, normalization=True) waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional( _assert_functional_consistency(
F.bandpass_biquad, waveform, sample_rate, central_freq, q, const_skirt_gain) F.bandpass_biquad, waveform, sample_rate, central_freq, q, const_skirt_gain)
def test_bandreject(self): def test_bandreject(self):
...@@ -246,7 +243,7 @@ class TestFunctional(unittest.TestCase): ...@@ -246,7 +243,7 @@ class TestFunctional(unittest.TestCase):
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav") filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, sample_rate = torchaudio.load(filepath, normalization=True) waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional( _assert_functional_consistency(
F.bandreject_biquad, waveform, sample_rate, central_freq, q) F.bandreject_biquad, waveform, sample_rate, central_freq, q)
def test_band_with_noise(self): def test_band_with_noise(self):
...@@ -256,7 +253,7 @@ class TestFunctional(unittest.TestCase): ...@@ -256,7 +253,7 @@ class TestFunctional(unittest.TestCase):
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav") filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, sample_rate = torchaudio.load(filepath, normalization=True) waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(F.band_biquad, waveform, sample_rate, central_freq, q, noise) _assert_functional_consistency(F.band_biquad, waveform, sample_rate, central_freq, q, noise)
def test_band_without_noise(self): def test_band_without_noise(self):
central_freq = 1000 central_freq = 1000
...@@ -265,7 +262,7 @@ class TestFunctional(unittest.TestCase): ...@@ -265,7 +262,7 @@ class TestFunctional(unittest.TestCase):
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav") filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, sample_rate = torchaudio.load(filepath, normalization=True) waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(F.band_biquad, waveform, sample_rate, central_freq, q, noise) _assert_functional_consistency(F.band_biquad, waveform, sample_rate, central_freq, q, noise)
def test_treble(self): def test_treble(self):
gain = 40 gain = 40
...@@ -274,17 +271,17 @@ class TestFunctional(unittest.TestCase): ...@@ -274,17 +271,17 @@ class TestFunctional(unittest.TestCase):
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav") filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, sample_rate = torchaudio.load(filepath, normalization=True) waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(F.treble_biquad, waveform, sample_rate, gain, central_freq, q) _assert_functional_consistency(F.treble_biquad, waveform, sample_rate, gain, central_freq, q)
def test_deemph(self): def test_deemph(self):
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav") filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, sample_rate = torchaudio.load(filepath, normalization=True) waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(F.deemph_biquad, waveform, sample_rate) _assert_functional_consistency(F.deemph_biquad, waveform, sample_rate)
def test_riaa(self): def test_riaa(self):
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav") filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, sample_rate = torchaudio.load(filepath, normalization=True) waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(F.riaa_biquad, waveform, sample_rate) _assert_functional_consistency(F.riaa_biquad, waveform, sample_rate)
def test_equalizer(self): def test_equalizer(self):
center_freq = 300 center_freq = 300
...@@ -293,7 +290,7 @@ class TestFunctional(unittest.TestCase): ...@@ -293,7 +290,7 @@ class TestFunctional(unittest.TestCase):
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav") filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, sample_rate = torchaudio.load(filepath, normalization=True) waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional( _assert_functional_consistency(
F.equalizer_biquad, waveform, sample_rate, center_freq, gain, q) F.equalizer_biquad, waveform, sample_rate, center_freq, gain, q)
def test_perf_biquad_filtering(self): def test_perf_biquad_filtering(self):
...@@ -301,7 +298,7 @@ class TestFunctional(unittest.TestCase): ...@@ -301,7 +298,7 @@ class TestFunctional(unittest.TestCase):
b = torch.tensor([0.4, 0.2, 0.9]) b = torch.tensor([0.4, 0.2, 0.9])
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav") filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True) waveform, _ = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(F.lfilter, waveform, a, b) _assert_functional_consistency(F.lfilter, waveform, a, b)
RUN_CUDA = torch.cuda.is_available() RUN_CUDA = torch.cuda.is_available()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment