"vscode:/vscode.git/clone" did not exist on "0e6a8403f6b4d2a2778c12d1e76588d00a8d8f1a"
Unverified Commit 4875007d authored by moto's avatar moto Committed by GitHub
Browse files

Extract JIT tests from filter test module and put in JIT test module. (#507)

parent 21269247
......@@ -9,15 +9,6 @@ import torchaudio.transforms as T
from common_utils import AudioBackendScope, BACKENDS, create_temp_assets_dir
def _test_torchscript_functional(py_method, *args, **kwargs):
jit_method = torch.jit.script(py_method)
jit_out = jit_method(*args, **kwargs)
py_out = py_method(*args, **kwargs)
assert torch.allclose(jit_out, py_out)
class TestFunctionalFiltering(unittest.TestCase):
test_dirpath, test_dir = create_temp_assets_dir()
......@@ -88,7 +79,6 @@ class TestFunctionalFiltering(unittest.TestCase):
assert len(output_waveform.size()) == 2
assert output_waveform.size(0) == waveform.size(0)
assert output_waveform.size(1) == waveform.size(1)
_test_torchscript_functional(F.lfilter, waveform, a_coeffs, b_coeffs)
def test_lfilter(self):
......@@ -189,7 +179,6 @@ class TestFunctionalFiltering(unittest.TestCase):
output_waveform = F.lowpass_biquad(waveform, sample_rate, CUTOFF_FREQ)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
_test_torchscript_functional(F.lowpass_biquad, waveform, sample_rate, CUTOFF_FREQ)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
......@@ -211,7 +200,6 @@ class TestFunctionalFiltering(unittest.TestCase):
# TBD - this fails at the 1e-4 level, debug why
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-3)
_test_torchscript_functional(F.highpass_biquad, waveform, sample_rate, CUTOFF_FREQ)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
......@@ -233,7 +221,6 @@ class TestFunctionalFiltering(unittest.TestCase):
output_waveform = F.allpass_biquad(waveform, sample_rate, CENTRAL_FREQ, Q)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
_test_torchscript_functional(F.allpass_biquad, waveform, sample_rate, CENTRAL_FREQ, Q)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
......@@ -256,7 +243,6 @@ class TestFunctionalFiltering(unittest.TestCase):
output_waveform = F.bandpass_biquad(waveform, sample_rate, CENTRAL_FREQ, Q, CONST_SKIRT_GAIN)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
_test_torchscript_functional(F.bandpass_biquad, waveform, sample_rate, CENTRAL_FREQ, Q, CONST_SKIRT_GAIN)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
......@@ -279,7 +265,6 @@ class TestFunctionalFiltering(unittest.TestCase):
output_waveform = F.bandpass_biquad(waveform, sample_rate, CENTRAL_FREQ, Q, CONST_SKIRT_GAIN)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
_test_torchscript_functional(F.bandpass_biquad, waveform, sample_rate, CENTRAL_FREQ, Q, CONST_SKIRT_GAIN)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
......@@ -301,7 +286,6 @@ class TestFunctionalFiltering(unittest.TestCase):
output_waveform = F.bandreject_biquad(waveform, sample_rate, CENTRAL_FREQ, Q)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
_test_torchscript_functional(F.bandreject_biquad, waveform, sample_rate, CENTRAL_FREQ, Q)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
......@@ -324,7 +308,6 @@ class TestFunctionalFiltering(unittest.TestCase):
output_waveform = F.band_biquad(waveform, sample_rate, CENTRAL_FREQ, Q, NOISE)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
_test_torchscript_functional(F.band_biquad, waveform, sample_rate, CENTRAL_FREQ, Q, NOISE)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
......@@ -347,7 +330,6 @@ class TestFunctionalFiltering(unittest.TestCase):
output_waveform = F.band_biquad(waveform, sample_rate, CENTRAL_FREQ, Q, NOISE)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
_test_torchscript_functional(F.band_biquad, waveform, sample_rate, CENTRAL_FREQ, Q, NOISE)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
......@@ -370,7 +352,6 @@ class TestFunctionalFiltering(unittest.TestCase):
output_waveform = F.treble_biquad(waveform, sample_rate, GAIN, CENTRAL_FREQ, Q)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
_test_torchscript_functional(F.treble_biquad, waveform, sample_rate, GAIN, CENTRAL_FREQ, Q)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
......@@ -389,7 +370,6 @@ class TestFunctionalFiltering(unittest.TestCase):
output_waveform = F.deemph_biquad(waveform, sample_rate)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
_test_torchscript_functional(F.deemph_biquad, waveform, sample_rate)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
......@@ -408,7 +388,6 @@ class TestFunctionalFiltering(unittest.TestCase):
output_waveform = F.riaa_biquad(waveform, sample_rate)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
_test_torchscript_functional(F.riaa_biquad, waveform, sample_rate)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
......@@ -431,7 +410,6 @@ class TestFunctionalFiltering(unittest.TestCase):
output_waveform = F.equalizer_biquad(waveform, sample_rate, CENTER_FREQ, GAIN, Q)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
_test_torchscript_functional(F.equalizer_biquad, waveform, sample_rate, CENTER_FREQ, GAIN, Q)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
......@@ -458,9 +436,6 @@ class TestFunctionalFiltering(unittest.TestCase):
)
assert torch.allclose(waveform_sox_out, waveform_lfilter_out, atol=1e-4)
_test_torchscript_functional(
F.lfilter, waveform, torch.tensor([a0, a1, a2]), torch.tensor([b0, b1, b2])
)
if __name__ == "__main__":
......
......@@ -25,6 +25,42 @@ def _test_torchscript_functional(py_method, *args, **kwargs):
assert torch.allclose(jit_out, py_out)
def _test_lfilter(waveform):
"""
Design an IIR lowpass filter using scipy.signal filter design
https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.iirdesign.html#scipy.signal.iirdesign
Example
>>> from scipy.signal import iirdesign
>>> b, a = iirdesign(0.2, 0.3, 1, 60)
"""
b_coeffs = torch.tensor(
[
0.00299893,
-0.0051152,
0.00841964,
-0.00747802,
0.00841964,
-0.0051152,
0.00299893,
],
device=waveform.device,
)
a_coeffs = torch.tensor(
[
1.0,
-4.8155751,
10.2217618,
-12.14481273,
8.49018171,
-3.3066882,
0.56088705,
],
device=waveform.device,
)
_test_torchscript_functional(F.lfilter, waveform, a_coeffs, b_coeffs)
class TestFunctional(unittest.TestCase):
"""Test functions in `functional` module."""
def test_spectrogram(self):
......@@ -151,6 +187,122 @@ class TestFunctional(unittest.TestCase):
_test_torchscript_functional_shape(F.dither, tensor, "RPDF")
_test_torchscript_functional_shape(F.dither, tensor, "GPDF")
def test_lfilter(self):
filepath = os.path.join(common_utils.TEST_DIR_PATH, 'assets', 'whitenoise.wav')
waveform, _ = torchaudio.load(filepath, normalization=True)
_test_lfilter(waveform)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_lfilter_cuda(self):
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
_test_lfilter(waveform.cuda(device=torch.device("cuda:0")))
def test_lowpass(self):
cutoff_freq = 3000
filepath = os.path.join(common_utils.TEST_DIR_PATH, 'assets', 'whitenoise.wav')
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(F.lowpass_biquad, waveform, sample_rate, cutoff_freq)
def test_highpass(self):
cutoff_freq = 2000
filepath = os.path.join(common_utils.TEST_DIR_PATH, 'assets', 'whitenoise.wav')
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(F.highpass_biquad, waveform, sample_rate, cutoff_freq)
def test_allpass(self):
central_freq = 1000
q = 0.707
filepath = os.path.join(common_utils.TEST_DIR_PATH, 'assets', 'whitenoise.wav')
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(F.allpass_biquad, waveform, sample_rate, central_freq, q)
def test_bandpass_with_csg(self):
central_freq = 1000
q = 0.707
const_skirt_gain = True
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(
F.bandpass_biquad, waveform, sample_rate, central_freq, q, const_skirt_gain)
def test_bandpass_withou_csg(self):
central_freq = 1000
q = 0.707
const_skirt_gain = False
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(
F.bandpass_biquad, waveform, sample_rate, central_freq, q, const_skirt_gain)
def test_bandreject(self):
central_freq = 1000
q = 0.707
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(
F.bandreject_biquad, waveform, sample_rate, central_freq, q)
def test_band_with_noise(self):
central_freq = 1000
q = 0.707
noise = True
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(F.band_biquad, waveform, sample_rate, central_freq, q, noise)
def test_band_without_noise(self):
central_freq = 1000
q = 0.707
noise = False
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(F.band_biquad, waveform, sample_rate, central_freq, q, noise)
def test_treble(self):
gain = 40
central_freq = 1000
q = 0.707
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(F.treble_biquad, waveform, sample_rate, gain, central_freq, q)
def test_deemph(self):
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(F.deemph_biquad, waveform, sample_rate)
def test_riaa(self):
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(F.riaa_biquad, waveform, sample_rate)
def test_equalizer(self):
center_freq = 300
gain = 1
q = 0.707
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(
F.equalizer_biquad, waveform, sample_rate, center_freq, gain, q)
def test_perf_biquad_filtering(self):
a = torch.tensor([0.7, 0.2, 0.6])
b = torch.tensor([0.4, 0.2, 0.9])
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
_test_torchscript_functional(F.lfilter, waveform, a, b)
RUN_CUDA = torch.cuda.is_available()
print("Run test with cuda:", RUN_CUDA)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment