"vscode:/vscode.git/clone" did not exist on "b9b44f7be8a577f556ed0de96c9cda2f0f6cda56"
Unverified Commit 4875007d authored by moto's avatar moto Committed by GitHub
Browse files

Extract JIT tests from filter test module and put in JIT test module. (#507)

parent 21269247
...@@ -9,15 +9,6 @@ import torchaudio.transforms as T ...@@ -9,15 +9,6 @@ import torchaudio.transforms as T
from common_utils import AudioBackendScope, BACKENDS, create_temp_assets_dir from common_utils import AudioBackendScope, BACKENDS, create_temp_assets_dir
def _test_torchscript_functional(py_method, *args, **kwargs):
jit_method = torch.jit.script(py_method)
jit_out = jit_method(*args, **kwargs)
py_out = py_method(*args, **kwargs)
assert torch.allclose(jit_out, py_out)
class TestFunctionalFiltering(unittest.TestCase): class TestFunctionalFiltering(unittest.TestCase):
test_dirpath, test_dir = create_temp_assets_dir() test_dirpath, test_dir = create_temp_assets_dir()
...@@ -88,7 +79,6 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -88,7 +79,6 @@ class TestFunctionalFiltering(unittest.TestCase):
assert len(output_waveform.size()) == 2 assert len(output_waveform.size()) == 2
assert output_waveform.size(0) == waveform.size(0) assert output_waveform.size(0) == waveform.size(0)
assert output_waveform.size(1) == waveform.size(1) assert output_waveform.size(1) == waveform.size(1)
_test_torchscript_functional(F.lfilter, waveform, a_coeffs, b_coeffs)
def test_lfilter(self): def test_lfilter(self):
...@@ -189,7 +179,6 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -189,7 +179,6 @@ class TestFunctionalFiltering(unittest.TestCase):
output_waveform = F.lowpass_biquad(waveform, sample_rate, CUTOFF_FREQ) output_waveform = F.lowpass_biquad(waveform, sample_rate, CUTOFF_FREQ)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4) assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
_test_torchscript_functional(F.lowpass_biquad, waveform, sample_rate, CUTOFF_FREQ)
@unittest.skipIf("sox" not in BACKENDS, "sox not available") @unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox") @AudioBackendScope("sox")
...@@ -211,7 +200,6 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -211,7 +200,6 @@ class TestFunctionalFiltering(unittest.TestCase):
# TBD - this fails at the 1e-4 level, debug why # TBD - this fails at the 1e-4 level, debug why
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-3) assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-3)
_test_torchscript_functional(F.highpass_biquad, waveform, sample_rate, CUTOFF_FREQ)
@unittest.skipIf("sox" not in BACKENDS, "sox not available") @unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox") @AudioBackendScope("sox")
...@@ -233,7 +221,6 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -233,7 +221,6 @@ class TestFunctionalFiltering(unittest.TestCase):
output_waveform = F.allpass_biquad(waveform, sample_rate, CENTRAL_FREQ, Q) output_waveform = F.allpass_biquad(waveform, sample_rate, CENTRAL_FREQ, Q)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4) assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
_test_torchscript_functional(F.allpass_biquad, waveform, sample_rate, CENTRAL_FREQ, Q)
@unittest.skipIf("sox" not in BACKENDS, "sox not available") @unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox") @AudioBackendScope("sox")
...@@ -256,7 +243,6 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -256,7 +243,6 @@ class TestFunctionalFiltering(unittest.TestCase):
output_waveform = F.bandpass_biquad(waveform, sample_rate, CENTRAL_FREQ, Q, CONST_SKIRT_GAIN) output_waveform = F.bandpass_biquad(waveform, sample_rate, CENTRAL_FREQ, Q, CONST_SKIRT_GAIN)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4) assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
_test_torchscript_functional(F.bandpass_biquad, waveform, sample_rate, CENTRAL_FREQ, Q, CONST_SKIRT_GAIN)
@unittest.skipIf("sox" not in BACKENDS, "sox not available") @unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox") @AudioBackendScope("sox")
...@@ -279,7 +265,6 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -279,7 +265,6 @@ class TestFunctionalFiltering(unittest.TestCase):
output_waveform = F.bandpass_biquad(waveform, sample_rate, CENTRAL_FREQ, Q, CONST_SKIRT_GAIN) output_waveform = F.bandpass_biquad(waveform, sample_rate, CENTRAL_FREQ, Q, CONST_SKIRT_GAIN)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4) assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
_test_torchscript_functional(F.bandpass_biquad, waveform, sample_rate, CENTRAL_FREQ, Q, CONST_SKIRT_GAIN)
@unittest.skipIf("sox" not in BACKENDS, "sox not available") @unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox") @AudioBackendScope("sox")
...@@ -301,7 +286,6 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -301,7 +286,6 @@ class TestFunctionalFiltering(unittest.TestCase):
output_waveform = F.bandreject_biquad(waveform, sample_rate, CENTRAL_FREQ, Q) output_waveform = F.bandreject_biquad(waveform, sample_rate, CENTRAL_FREQ, Q)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4) assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
_test_torchscript_functional(F.bandreject_biquad, waveform, sample_rate, CENTRAL_FREQ, Q)
@unittest.skipIf("sox" not in BACKENDS, "sox not available") @unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox") @AudioBackendScope("sox")
...@@ -324,7 +308,6 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -324,7 +308,6 @@ class TestFunctionalFiltering(unittest.TestCase):
output_waveform = F.band_biquad(waveform, sample_rate, CENTRAL_FREQ, Q, NOISE) output_waveform = F.band_biquad(waveform, sample_rate, CENTRAL_FREQ, Q, NOISE)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4) assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
_test_torchscript_functional(F.band_biquad, waveform, sample_rate, CENTRAL_FREQ, Q, NOISE)
@unittest.skipIf("sox" not in BACKENDS, "sox not available") @unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox") @AudioBackendScope("sox")
...@@ -347,7 +330,6 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -347,7 +330,6 @@ class TestFunctionalFiltering(unittest.TestCase):
output_waveform = F.band_biquad(waveform, sample_rate, CENTRAL_FREQ, Q, NOISE) output_waveform = F.band_biquad(waveform, sample_rate, CENTRAL_FREQ, Q, NOISE)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4) assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
_test_torchscript_functional(F.band_biquad, waveform, sample_rate, CENTRAL_FREQ, Q, NOISE)
@unittest.skipIf("sox" not in BACKENDS, "sox not available") @unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox") @AudioBackendScope("sox")
...@@ -370,7 +352,6 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -370,7 +352,6 @@ class TestFunctionalFiltering(unittest.TestCase):
output_waveform = F.treble_biquad(waveform, sample_rate, GAIN, CENTRAL_FREQ, Q) output_waveform = F.treble_biquad(waveform, sample_rate, GAIN, CENTRAL_FREQ, Q)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4) assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
_test_torchscript_functional(F.treble_biquad, waveform, sample_rate, GAIN, CENTRAL_FREQ, Q)
@unittest.skipIf("sox" not in BACKENDS, "sox not available") @unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox") @AudioBackendScope("sox")
...@@ -389,7 +370,6 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -389,7 +370,6 @@ class TestFunctionalFiltering(unittest.TestCase):
output_waveform = F.deemph_biquad(waveform, sample_rate) output_waveform = F.deemph_biquad(waveform, sample_rate)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4) assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
_test_torchscript_functional(F.deemph_biquad, waveform, sample_rate)
@unittest.skipIf("sox" not in BACKENDS, "sox not available") @unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox") @AudioBackendScope("sox")
...@@ -408,7 +388,6 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -408,7 +388,6 @@ class TestFunctionalFiltering(unittest.TestCase):
output_waveform = F.riaa_biquad(waveform, sample_rate) output_waveform = F.riaa_biquad(waveform, sample_rate)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4) assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
_test_torchscript_functional(F.riaa_biquad, waveform, sample_rate)
@unittest.skipIf("sox" not in BACKENDS, "sox not available") @unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox") @AudioBackendScope("sox")
...@@ -431,7 +410,6 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -431,7 +410,6 @@ class TestFunctionalFiltering(unittest.TestCase):
output_waveform = F.equalizer_biquad(waveform, sample_rate, CENTER_FREQ, GAIN, Q) output_waveform = F.equalizer_biquad(waveform, sample_rate, CENTER_FREQ, GAIN, Q)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4) assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
_test_torchscript_functional(F.equalizer_biquad, waveform, sample_rate, CENTER_FREQ, GAIN, Q)
@unittest.skipIf("sox" not in BACKENDS, "sox not available") @unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox") @AudioBackendScope("sox")
...@@ -458,9 +436,6 @@ class TestFunctionalFiltering(unittest.TestCase): ...@@ -458,9 +436,6 @@ class TestFunctionalFiltering(unittest.TestCase):
) )
assert torch.allclose(waveform_sox_out, waveform_lfilter_out, atol=1e-4) assert torch.allclose(waveform_sox_out, waveform_lfilter_out, atol=1e-4)
_test_torchscript_functional(
F.lfilter, waveform, torch.tensor([a0, a1, a2]), torch.tensor([b0, b1, b2])
)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -25,6 +25,42 @@ def _test_torchscript_functional(py_method, *args, **kwargs): ...@@ -25,6 +25,42 @@ def _test_torchscript_functional(py_method, *args, **kwargs):
assert torch.allclose(jit_out, py_out) assert torch.allclose(jit_out, py_out)
def _test_lfilter(waveform):
"""
Design an IIR lowpass filter using scipy.signal filter design
https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.iirdesign.html#scipy.signal.iirdesign
Example
>>> from scipy.signal import iirdesign
>>> b, a = iirdesign(0.2, 0.3, 1, 60)
"""
b_coeffs = torch.tensor(
[
0.00299893,
-0.0051152,
0.00841964,
-0.00747802,
0.00841964,
-0.0051152,
0.00299893,
],
device=waveform.device,
)
a_coeffs = torch.tensor(
[
1.0,
-4.8155751,
10.2217618,
-12.14481273,
8.49018171,
-3.3066882,
0.56088705,
],
device=waveform.device,
)
_test_torchscript_functional(F.lfilter, waveform, a_coeffs, b_coeffs)
class TestFunctional(unittest.TestCase): class TestFunctional(unittest.TestCase):
"""Test functions in `functional` module.""" """Test functions in `functional` module."""
def test_spectrogram(self): def test_spectrogram(self):
...@@ -151,6 +187,122 @@ class TestFunctional(unittest.TestCase): ...@@ -151,6 +187,122 @@ class TestFunctional(unittest.TestCase):
_test_torchscript_functional_shape(F.dither, tensor, "RPDF") _test_torchscript_functional_shape(F.dither, tensor, "RPDF")
_test_torchscript_functional_shape(F.dither, tensor, "GPDF") _test_torchscript_functional_shape(F.dither, tensor, "GPDF")
def test_lfilter(self):
filepath = os.path.join(common_utils.TEST_DIR_PATH, 'assets', 'whitenoise.wav')
waveform, _ = torchaudio.load(filepath, normalization=True)
_test_lfilter(waveform)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_lfilter_cuda(self):
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
_test_lfilter(waveform.cuda(device=torch.device("cuda:0")))
def test_lowpass(self):
cutoff_freq = 3000
filepath = os.path.join(common_utils.TEST_DIR_PATH, 'assets', 'whitenoise.wav')
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(F.lowpass_biquad, waveform, sample_rate, cutoff_freq)
def test_highpass(self):
cutoff_freq = 2000
filepath = os.path.join(common_utils.TEST_DIR_PATH, 'assets', 'whitenoise.wav')
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(F.highpass_biquad, waveform, sample_rate, cutoff_freq)
def test_allpass(self):
central_freq = 1000
q = 0.707
filepath = os.path.join(common_utils.TEST_DIR_PATH, 'assets', 'whitenoise.wav')
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(F.allpass_biquad, waveform, sample_rate, central_freq, q)
def test_bandpass_with_csg(self):
central_freq = 1000
q = 0.707
const_skirt_gain = True
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(
F.bandpass_biquad, waveform, sample_rate, central_freq, q, const_skirt_gain)
def test_bandpass_withou_csg(self):
central_freq = 1000
q = 0.707
const_skirt_gain = False
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(
F.bandpass_biquad, waveform, sample_rate, central_freq, q, const_skirt_gain)
def test_bandreject(self):
central_freq = 1000
q = 0.707
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(
F.bandreject_biquad, waveform, sample_rate, central_freq, q)
def test_band_with_noise(self):
central_freq = 1000
q = 0.707
noise = True
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(F.band_biquad, waveform, sample_rate, central_freq, q, noise)
def test_band_without_noise(self):
central_freq = 1000
q = 0.707
noise = False
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(F.band_biquad, waveform, sample_rate, central_freq, q, noise)
def test_treble(self):
gain = 40
central_freq = 1000
q = 0.707
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(F.treble_biquad, waveform, sample_rate, gain, central_freq, q)
def test_deemph(self):
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(F.deemph_biquad, waveform, sample_rate)
def test_riaa(self):
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(F.riaa_biquad, waveform, sample_rate)
def test_equalizer(self):
center_freq = 300
gain = 1
q = 0.707
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(
F.equalizer_biquad, waveform, sample_rate, center_freq, gain, q)
def test_perf_biquad_filtering(self):
a = torch.tensor([0.7, 0.2, 0.6])
b = torch.tensor([0.4, 0.2, 0.9])
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(F.lfilter, waveform, a, b)
RUN_CUDA = torch.cuda.is_available() RUN_CUDA = torch.cuda.is_available()
print("Run test with cuda:", RUN_CUDA) print("Run test with cuda:", RUN_CUDA)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment