Unverified Commit 4875007d authored by moto's avatar moto Committed by GitHub
Browse files

Extract JIT tests from filter test module and put in JIT test module. (#507)

parent 21269247
...@@ -9,15 +9,6 @@ import torchaudio.transforms as T ...@@ -9,15 +9,6 @@ import torchaudio.transforms as T
from common_utils import AudioBackendScope, BACKENDS, create_temp_assets_dir from common_utils import AudioBackendScope, BACKENDS, create_temp_assets_dir
def _test_torchscript_functional(py_method, *args, **kwargs):
jit_method = torch.jit.script(py_method)
jit_out = jit_method(*args, **kwargs)
py_out = py_method(*args, **kwargs)
assert torch.allclose(jit_out, py_out)
class TestFunctionalFiltering(unittest.TestCase): class TestFunctionalFiltering(unittest.TestCase):
test_dirpath, test_dir = create_temp_assets_dir() test_dirpath, test_dir = create_temp_assets_dir()
...@@ -88,7 +79,6 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -88,7 +79,6 @@ class TestFunctionalFiltering(unittest.TestCase):
assert len(output_waveform.size()) == 2 assert len(output_waveform.size()) == 2
assert output_waveform.size(0) == waveform.size(0) assert output_waveform.size(0) == waveform.size(0)
assert output_waveform.size(1) == waveform.size(1) assert output_waveform.size(1) == waveform.size(1)
_test_torchscript_functional(F.lfilter, waveform, a_coeffs, b_coeffs)
def test_lfilter(self): def test_lfilter(self):
...@@ -189,7 +179,6 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -189,7 +179,6 @@ class TestFunctionalFiltering(unittest.TestCase):
output_waveform = F.lowpass_biquad(waveform, sample_rate, CUTOFF_FREQ) output_waveform = F.lowpass_biquad(waveform, sample_rate, CUTOFF_FREQ)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4) assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
_test_torchscript_functional(F.lowpass_biquad, waveform, sample_rate, CUTOFF_FREQ)
@unittest.skipIf("sox" not in BACKENDS, "sox not available") @unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox") @AudioBackendScope("sox")
...@@ -211,7 +200,6 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -211,7 +200,6 @@ class TestFunctionalFiltering(unittest.TestCase):
# TBD - this fails at the 1e-4 level, debug why # TBD - this fails at the 1e-4 level, debug why
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-3) assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-3)
_test_torchscript_functional(F.highpass_biquad, waveform, sample_rate, CUTOFF_FREQ)
@unittest.skipIf("sox" not in BACKENDS, "sox not available") @unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox") @AudioBackendScope("sox")
...@@ -233,7 +221,6 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -233,7 +221,6 @@ class TestFunctionalFiltering(unittest.TestCase):
output_waveform = F.allpass_biquad(waveform, sample_rate, CENTRAL_FREQ, Q) output_waveform = F.allpass_biquad(waveform, sample_rate, CENTRAL_FREQ, Q)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4) assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
_test_torchscript_functional(F.allpass_biquad, waveform, sample_rate, CENTRAL_FREQ, Q)
@unittest.skipIf("sox" not in BACKENDS, "sox not available") @unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox") @AudioBackendScope("sox")
...@@ -256,7 +243,6 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -256,7 +243,6 @@ class TestFunctionalFiltering(unittest.TestCase):
output_waveform = F.bandpass_biquad(waveform, sample_rate, CENTRAL_FREQ, Q, CONST_SKIRT_GAIN) output_waveform = F.bandpass_biquad(waveform, sample_rate, CENTRAL_FREQ, Q, CONST_SKIRT_GAIN)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4) assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
_test_torchscript_functional(F.bandpass_biquad, waveform, sample_rate, CENTRAL_FREQ, Q, CONST_SKIRT_GAIN)
@unittest.skipIf("sox" not in BACKENDS, "sox not available") @unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox") @AudioBackendScope("sox")
...@@ -279,7 +265,6 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -279,7 +265,6 @@ class TestFunctionalFiltering(unittest.TestCase):
output_waveform = F.bandpass_biquad(waveform, sample_rate, CENTRAL_FREQ, Q, CONST_SKIRT_GAIN) output_waveform = F.bandpass_biquad(waveform, sample_rate, CENTRAL_FREQ, Q, CONST_SKIRT_GAIN)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4) assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
_test_torchscript_functional(F.bandpass_biquad, waveform, sample_rate, CENTRAL_FREQ, Q, CONST_SKIRT_GAIN)
@unittest.skipIf("sox" not in BACKENDS, "sox not available") @unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox") @AudioBackendScope("sox")
...@@ -301,7 +286,6 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -301,7 +286,6 @@ class TestFunctionalFiltering(unittest.TestCase):
output_waveform = F.bandreject_biquad(waveform, sample_rate, CENTRAL_FREQ, Q) output_waveform = F.bandreject_biquad(waveform, sample_rate, CENTRAL_FREQ, Q)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4) assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
_test_torchscript_functional(F.bandreject_biquad, waveform, sample_rate, CENTRAL_FREQ, Q)
@unittest.skipIf("sox" not in BACKENDS, "sox not available") @unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox") @AudioBackendScope("sox")
...@@ -324,7 +308,6 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -324,7 +308,6 @@ class TestFunctionalFiltering(unittest.TestCase):
output_waveform = F.band_biquad(waveform, sample_rate, CENTRAL_FREQ, Q, NOISE) output_waveform = F.band_biquad(waveform, sample_rate, CENTRAL_FREQ, Q, NOISE)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4) assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
_test_torchscript_functional(F.band_biquad, waveform, sample_rate, CENTRAL_FREQ, Q, NOISE)
@unittest.skipIf("sox" not in BACKENDS, "sox not available") @unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox") @AudioBackendScope("sox")
...@@ -347,7 +330,6 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -347,7 +330,6 @@ class TestFunctionalFiltering(unittest.TestCase):
output_waveform = F.band_biquad(waveform, sample_rate, CENTRAL_FREQ, Q, NOISE) output_waveform = F.band_biquad(waveform, sample_rate, CENTRAL_FREQ, Q, NOISE)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4) assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
_test_torchscript_functional(F.band_biquad, waveform, sample_rate, CENTRAL_FREQ, Q, NOISE)
@unittest.skipIf("sox" not in BACKENDS, "sox not available") @unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox") @AudioBackendScope("sox")
...@@ -370,7 +352,6 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -370,7 +352,6 @@ class TestFunctionalFiltering(unittest.TestCase):
output_waveform = F.treble_biquad(waveform, sample_rate, GAIN, CENTRAL_FREQ, Q) output_waveform = F.treble_biquad(waveform, sample_rate, GAIN, CENTRAL_FREQ, Q)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4) assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
_test_torchscript_functional(F.treble_biquad, waveform, sample_rate, GAIN, CENTRAL_FREQ, Q)
@unittest.skipIf("sox" not in BACKENDS, "sox not available") @unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox") @AudioBackendScope("sox")
...@@ -389,7 +370,6 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -389,7 +370,6 @@ class TestFunctionalFiltering(unittest.TestCase):
output_waveform = F.deemph_biquad(waveform, sample_rate) output_waveform = F.deemph_biquad(waveform, sample_rate)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4) assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
_test_torchscript_functional(F.deemph_biquad, waveform, sample_rate)
@unittest.skipIf("sox" not in BACKENDS, "sox not available") @unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox") @AudioBackendScope("sox")
...@@ -408,7 +388,6 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -408,7 +388,6 @@ class TestFunctionalFiltering(unittest.TestCase):
output_waveform = F.riaa_biquad(waveform, sample_rate) output_waveform = F.riaa_biquad(waveform, sample_rate)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4) assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
_test_torchscript_functional(F.riaa_biquad, waveform, sample_rate)
@unittest.skipIf("sox" not in BACKENDS, "sox not available") @unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox") @AudioBackendScope("sox")
...@@ -431,7 +410,6 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -431,7 +410,6 @@ class TestFunctionalFiltering(unittest.TestCase):
output_waveform = F.equalizer_biquad(waveform, sample_rate, CENTER_FREQ, GAIN, Q) output_waveform = F.equalizer_biquad(waveform, sample_rate, CENTER_FREQ, GAIN, Q)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4) assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
_test_torchscript_functional(F.equalizer_biquad, waveform, sample_rate, CENTER_FREQ, GAIN, Q)
@unittest.skipIf("sox" not in BACKENDS, "sox not available") @unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox") @AudioBackendScope("sox")
...@@ -458,9 +436,6 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -458,9 +436,6 @@ class TestFunctionalFiltering(unittest.TestCase):
) )
assert torch.allclose(waveform_sox_out, waveform_lfilter_out, atol=1e-4) assert torch.allclose(waveform_sox_out, waveform_lfilter_out, atol=1e-4)
_test_torchscript_functional(
F.lfilter, waveform, torch.tensor([a0, a1, a2]), torch.tensor([b0, b1, b2])
)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -25,6 +25,42 @@ def _test_torchscript_functional(py_method, *args, **kwargs): ...@@ -25,6 +25,42 @@ def _test_torchscript_functional(py_method, *args, **kwargs):
assert torch.allclose(jit_out, py_out) assert torch.allclose(jit_out, py_out)
def _test_lfilter(waveform):
"""
Design an IIR lowpass filter using scipy.signal filter design
https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.iirdesign.html#scipy.signal.iirdesign
Example
>>> from scipy.signal import iirdesign
>>> b, a = iirdesign(0.2, 0.3, 1, 60)
"""
b_coeffs = torch.tensor(
[
0.00299893,
-0.0051152,
0.00841964,
-0.00747802,
0.00841964,
-0.0051152,
0.00299893,
],
device=waveform.device,
)
a_coeffs = torch.tensor(
[
1.0,
-4.8155751,
10.2217618,
-12.14481273,
8.49018171,
-3.3066882,
0.56088705,
],
device=waveform.device,
)
_test_torchscript_functional(F.lfilter, waveform, a_coeffs, b_coeffs)
class TestFunctional(unittest.TestCase): class TestFunctional(unittest.TestCase):
"""Test functions in `functional` module.""" """Test functions in `functional` module."""
def test_spectrogram(self): def test_spectrogram(self):
...@@ -151,6 +187,122 @@ class TestFunctional(unittest.TestCase): ...@@ -151,6 +187,122 @@ class TestFunctional(unittest.TestCase):
_test_torchscript_functional_shape(F.dither, tensor, "RPDF") _test_torchscript_functional_shape(F.dither, tensor, "RPDF")
_test_torchscript_functional_shape(F.dither, tensor, "GPDF") _test_torchscript_functional_shape(F.dither, tensor, "GPDF")
def test_lfilter(self):
filepath = os.path.join(common_utils.TEST_DIR_PATH, 'assets', 'whitenoise.wav')
waveform, _ = torchaudio.load(filepath, normalization=True)
_test_lfilter(waveform)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_lfilter_cuda(self):
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
_test_lfilter(waveform.cuda(device=torch.device("cuda:0")))
def test_lowpass(self):
cutoff_freq = 3000
filepath = os.path.join(common_utils.TEST_DIR_PATH, 'assets', 'whitenoise.wav')
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(F.lowpass_biquad, waveform, sample_rate, cutoff_freq)
def test_highpass(self):
cutoff_freq = 2000
filepath = os.path.join(common_utils.TEST_DIR_PATH, 'assets', 'whitenoise.wav')
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(F.highpass_biquad, waveform, sample_rate, cutoff_freq)
def test_allpass(self):
central_freq = 1000
q = 0.707
filepath = os.path.join(common_utils.TEST_DIR_PATH, 'assets', 'whitenoise.wav')
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(F.allpass_biquad, waveform, sample_rate, central_freq, q)
def test_bandpass_with_csg(self):
central_freq = 1000
q = 0.707
const_skirt_gain = True
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(
F.bandpass_biquad, waveform, sample_rate, central_freq, q, const_skirt_gain)
def test_bandpass_withou_csg(self):
central_freq = 1000
q = 0.707
const_skirt_gain = False
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(
F.bandpass_biquad, waveform, sample_rate, central_freq, q, const_skirt_gain)
def test_bandreject(self):
central_freq = 1000
q = 0.707
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(
F.bandreject_biquad, waveform, sample_rate, central_freq, q)
def test_band_with_noise(self):
central_freq = 1000
q = 0.707
noise = True
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(F.band_biquad, waveform, sample_rate, central_freq, q, noise)
def test_band_without_noise(self):
central_freq = 1000
q = 0.707
noise = False
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(F.band_biquad, waveform, sample_rate, central_freq, q, noise)
def test_treble(self):
gain = 40
central_freq = 1000
q = 0.707
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(F.treble_biquad, waveform, sample_rate, gain, central_freq, q)
def test_deemph(self):
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(F.deemph_biquad, waveform, sample_rate)
def test_riaa(self):
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(F.riaa_biquad, waveform, sample_rate)
def test_equalizer(self):
center_freq = 300
gain = 1
q = 0.707
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(
F.equalizer_biquad, waveform, sample_rate, center_freq, gain, q)
def test_perf_biquad_filtering(self):
a = torch.tensor([0.7, 0.2, 0.6])
b = torch.tensor([0.4, 0.2, 0.9])
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(F.lfilter, waveform, a, b)
RUN_CUDA = torch.cuda.is_available() RUN_CUDA = torch.cuda.is_available()
print("Run test with cuda:", RUN_CUDA) print("Run test with cuda:", RUN_CUDA)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment