Unverified Commit b95d60c2 authored by moto's avatar moto Committed by GitHub
Browse files

Extract JIT tests from test_functional to the dedicated test module (#480)

parent d63d851e
......@@ -17,21 +17,6 @@ if IMPORT_LIBROSA:
import librosa
def _test_torchscript_functional_shape(py_method, *args, **kwargs):
jit_method = torch.jit.script(py_method)
jit_out = jit_method(*args, **kwargs)
py_out = py_method(*args, **kwargs)
assert jit_out.shape == py_out.shape
return jit_out, py_out
def _test_torchscript_functional(py_method, *args, **kwargs):
jit_out, py_out = _test_torchscript_functional_shape(py_method, *args, **kwargs)
assert torch.allclose(jit_out, py_out)
class TestFunctional(unittest.TestCase):
data_sizes = [(2, 20), (3, 15), (4, 10)]
number_of_trials = 100
......@@ -43,38 +28,6 @@ class TestFunctional(unittest.TestCase):
'steam-train-whistle-daniel_simon.wav')
waveform_train, sr_train = torchaudio.load(test_filepath)
def test_torchscript_spectrogram(self):
tensor = torch.rand((1, 1000))
n_fft = 400
ws = 400
hop = 200
pad = 0
window = torch.hann_window(ws)
power = 2
normalize = False
_test_torchscript_functional(
F.spectrogram, tensor, pad, window, n_fft, hop, ws, power, normalize
)
def test_torchscript_griffinlim(self):
tensor = torch.rand((1, 201, 6))
n_fft = 400
ws = 400
hop = 200
window = torch.hann_window(ws)
power = 2
normalize = False
momentum = 0.99
n_iter = 32
length = 1000
init = 0
_test_torchscript_functional(
F.griffinlim, tensor, window, n_fft, hop, ws, power, normalize, n_iter, momentum, length, 0
)
@unittest.skipIf(not IMPORT_LIBROSA, 'Librosa not available')
def test_griffinlim(self):
......@@ -138,26 +91,10 @@ class TestFunctional(unittest.TestCase):
[0.5, 1.0, 1.0, 0.5]]])
self._test_compute_deltas(specgram, expected)
def test_compute_deltas_randn(self):
channel = 13
n_mfcc = channel * 3
time = 1021
win_length = 2 * 7 + 1
specgram = torch.randn(channel, n_mfcc, time)
computed = F.compute_deltas(specgram, win_length=win_length)
self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))
_test_torchscript_functional(F.compute_deltas, specgram, win_length=win_length)
def test_batch_pitch(self):
waveform, sample_rate = torchaudio.load(self.test_filepath)
self._test_batch(F.detect_pitch_frequency, waveform, sample_rate)
def test_jit_pitch(self):
waveform, sample_rate = torchaudio.load(self.test_filepath)
_test_torchscript_functional(F.detect_pitch_frequency, waveform, sample_rate)
def _compare_estimate(self, sound, estimate, atol=1e-6, rtol=1e-8):
# trim sound for case when constructed signal is shorter than original
sound = sound[..., :estimate.size(-1)]
......@@ -568,33 +505,6 @@ class TestFunctional(unittest.TestCase):
torch.random.manual_seed(42)
computed = functional(tensors.clone(), *args, **kwargs)
def test_torchscript_create_fb_matrix(self):
n_stft = 100
f_min = 0.0
f_max = 20.0
n_mels = 10
sample_rate = 16000
_test_torchscript_functional(F.create_fb_matrix, n_stft, f_min, f_max, n_mels, sample_rate)
def test_torchscript_amplitude_to_DB(self):
spec = torch.rand((6, 201))
multiplier = 10.0
amin = 1e-10
db_multiplier = 0.0
top_db = 80.0
_test_torchscript_functional(F.amplitude_to_DB, spec, multiplier, amin, db_multiplier, top_db)
def test_torchscript_DB_to_amplitude(self):
x = torch.rand((1, 100))
ref = 1.
power = 1.
_test_torchscript_functional(F.DB_to_amplitude, x, ref, power)
def test_DB_to_amplitude(self):
# Make some noise
x = torch.rand(1000)
......@@ -661,66 +571,6 @@ class TestFunctional(unittest.TestCase):
self.assertTrue(torch.allclose(ta_out, lr_out, atol=5e-5))
def test_torchscript_create_dct(self):
n_mfcc = 40
n_mels = 128
norm = "ortho"
_test_torchscript_functional(F.create_dct, n_mfcc, n_mels, norm)
def test_torchscript_mu_law_encoding(self):
tensor = torch.rand((1, 10))
qc = 256
_test_torchscript_functional(F.mu_law_encoding, tensor, qc)
def test_torchscript_mu_law_decoding(self):
tensor = torch.rand((1, 10))
qc = 256
_test_torchscript_functional(F.mu_law_decoding, tensor, qc)
def test_torchscript_complex_norm(self):
complex_tensor = torch.randn(1, 2, 1025, 400, 2)
power = 2
_test_torchscript_functional(F.complex_norm, complex_tensor, power)
def test_mask_along_axis(self):
specgram = torch.randn(2, 1025, 400)
mask_param = 100
mask_value = 30.
axis = 2
_test_torchscript_functional(F.mask_along_axis, specgram, mask_param, mask_value, axis)
def test_mask_along_axis_iid(self):
specgrams = torch.randn(4, 2, 1025, 400)
mask_param = 100
mask_value = 30.
axis = 2
_test_torchscript_functional(F.mask_along_axis_iid, specgrams, mask_param, mask_value, axis)
def test_torchscript_gain(self):
tensor = torch.rand((1, 1000))
gainDB = 2.0
_test_torchscript_functional(F.gain, tensor, gainDB)
def test_torchscript_dither(self):
tensor = torch.rand((2, 1000))
_test_torchscript_functional_shape(F.dither, tensor)
_test_torchscript_functional_shape(F.dither, tensor, "RPDF")
_test_torchscript_functional_shape(F.dither, tensor, "GPDF")
def _num_stft_bins(signal_len, fft_len, hop_length, pad):
return (signal_len + 2 * pad - fft_len + hop_length) // hop_length
......
"""Test suites for jit-ability and its numerical compatibility"""
import os
import unittest
import torch
import torchaudio
import torchaudio.functional as F
import common_utils
def _test_torchscript_functional_shape(py_method, *args, **kwargs):
jit_method = torch.jit.script(py_method)
jit_out = jit_method(*args, **kwargs)
py_out = py_method(*args, **kwargs)
assert jit_out.shape == py_out.shape
return jit_out, py_out
def _test_torchscript_functional(py_method, *args, **kwargs):
jit_out, py_out = _test_torchscript_functional_shape(py_method, *args, **kwargs)
assert torch.allclose(jit_out, py_out)
class TestFunctional(unittest.TestCase):
"""Test functions in `functional` module."""
def test_spectrogram(self):
tensor = torch.rand((1, 1000))
n_fft = 400
ws = 400
hop = 200
pad = 0
window = torch.hann_window(ws)
power = 2
normalize = False
_test_torchscript_functional(
F.spectrogram, tensor, pad, window, n_fft, hop, ws, power, normalize
)
def test_griffinlim(self):
tensor = torch.rand((1, 201, 6))
n_fft = 400
ws = 400
hop = 200
window = torch.hann_window(ws)
power = 2
normalize = False
momentum = 0.99
n_iter = 32
length = 1000
_test_torchscript_functional(
F.griffinlim, tensor, window, n_fft, hop, ws, power, normalize, n_iter, momentum, length, 0
)
def test_compute_deltas(self):
channel = 13
n_mfcc = channel * 3
time = 1021
win_length = 2 * 7 + 1
specgram = torch.randn(channel, n_mfcc, time)
_test_torchscript_functional(F.compute_deltas, specgram, win_length=win_length)
def test_detect_pitch_frequency(self):
filepath = os.path.join(
common_utils.TEST_DIR_PATH, 'assets', 'steam-train-whistle-daniel_simon.mp3')
waveform, sample_rate = torchaudio.load(filepath)
_test_torchscript_functional(F.detect_pitch_frequency, waveform, sample_rate)
def test_create_fb_matrix(self):
n_stft = 100
f_min = 0.0
f_max = 20.0
n_mels = 10
sample_rate = 16000
_test_torchscript_functional(F.create_fb_matrix, n_stft, f_min, f_max, n_mels, sample_rate)
def test_amplitude_to_DB(self):
spec = torch.rand((6, 201))
multiplier = 10.0
amin = 1e-10
db_multiplier = 0.0
top_db = 80.0
_test_torchscript_functional(F.amplitude_to_DB, spec, multiplier, amin, db_multiplier, top_db)
def test_DB_to_amplitude(self):
x = torch.rand((1, 100))
ref = 1.
power = 1.
_test_torchscript_functional(F.DB_to_amplitude, x, ref, power)
def test_create_dct(self):
n_mfcc = 40
n_mels = 128
norm = "ortho"
_test_torchscript_functional(F.create_dct, n_mfcc, n_mels, norm)
def test_mu_law_encoding(self):
tensor = torch.rand((1, 10))
qc = 256
_test_torchscript_functional(F.mu_law_encoding, tensor, qc)
def test_mu_law_decoding(self):
tensor = torch.rand((1, 10))
qc = 256
_test_torchscript_functional(F.mu_law_decoding, tensor, qc)
def test_complex_norm(self):
complex_tensor = torch.randn(1, 2, 1025, 400, 2)
power = 2
_test_torchscript_functional(F.complex_norm, complex_tensor, power)
def test_mask_along_axis(self):
specgram = torch.randn(2, 1025, 400)
mask_param = 100
mask_value = 30.
axis = 2
_test_torchscript_functional(F.mask_along_axis, specgram, mask_param, mask_value, axis)
def test_mask_along_axis_iid(self):
specgrams = torch.randn(4, 2, 1025, 400)
mask_param = 100
mask_value = 30.
axis = 2
_test_torchscript_functional(F.mask_along_axis_iid, specgrams, mask_param, mask_value, axis)
def test_gain(self):
tensor = torch.rand((1, 1000))
gainDB = 2.0
_test_torchscript_functional(F.gain, tensor, gainDB)
def test_dither(self):
tensor = torch.rand((2, 1000))
_test_torchscript_functional_shape(F.dither, tensor)
_test_torchscript_functional_shape(F.dither, tensor, "RPDF")
_test_torchscript_functional_shape(F.dither, tensor, "GPDF")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment