Unverified Commit c1ef2edd authored by Chin-Yun Yu's avatar Chin-Yun Yu Committed by GitHub
Browse files

Refactor `F.lfilter` and `F.*_biquad` autograd tests (#1438)

parent 9a0e70ea
from typing import Callable, Tuple
import torch
from parameterized import parameterized
from torch import Tensor
import torchaudio.functional as F
from torch.autograd import gradcheck
from torchaudio_unittest import common_utils
from torchaudio_unittest.common_utils import (
TestBaseMixin,
get_whitenoise,
)
class Autograd(TestBaseMixin):
def assert_grad(
self,
transform: Callable[..., Tensor],
inputs: Tuple[torch.Tensor],
*,
enable_all_grad: bool = True,
):
inputs_ = []
for i in inputs:
if torch.is_tensor(i):
i = i.to(dtype=self.dtype, device=self.device)
if enable_all_grad:
i.requires_grad = True
inputs_.append(i)
assert gradcheck(transform, inputs_)
class Autograd(common_utils.TestBaseMixin):
def test_lfilter_x(self):
torch.random.manual_seed(2434)
x = torch.rand(2, 4, 256 * 2, dtype=self.dtype, device=self.device)
a = torch.tensor([0.7, 0.2, 0.6], dtype=self.dtype, device=self.device)
b = torch.tensor([0.4, 0.2, 0.9], dtype=self.dtype, device=self.device)
x = get_whitenoise(sample_rate=22050, duration=0.025, n_channels=2)
a = torch.tensor([0.7, 0.2, 0.6])
b = torch.tensor([0.4, 0.2, 0.9])
x.requires_grad = True
assert gradcheck(F.lfilter, (x, a, b), eps=1e-10)
self.assert_grad(F.lfilter, (x, a, b), enable_all_grad=False)
def test_lfilter_a(self):
torch.random.manual_seed(2434)
x = torch.rand(2, 4, 256 * 2, dtype=self.dtype, device=self.device)
a = torch.tensor([0.7, 0.2, 0.6], dtype=self.dtype, device=self.device)
b = torch.tensor([0.4, 0.2, 0.9], dtype=self.dtype, device=self.device)
x = get_whitenoise(sample_rate=22050, duration=0.05, n_channels=2)
a = torch.tensor([0.7, 0.2, 0.6])
b = torch.tensor([0.4, 0.2, 0.9])
a.requires_grad = True
assert gradcheck(F.lfilter, (x, a, b), eps=1e-10)
self.assert_grad(F.lfilter, (x, a, b), enable_all_grad=False)
def test_lfilter_b(self):
torch.random.manual_seed(2434)
x = torch.rand(2, 4, 256 * 2, dtype=self.dtype, device=self.device)
a = torch.tensor([0.7, 0.2, 0.6], dtype=self.dtype, device=self.device)
b = torch.tensor([0.4, 0.2, 0.9], dtype=self.dtype, device=self.device)
x = get_whitenoise(sample_rate=22050, duration=0.05, n_channels=2)
a = torch.tensor([0.7, 0.2, 0.6])
b = torch.tensor([0.4, 0.2, 0.9])
b.requires_grad = True
assert gradcheck(F.lfilter, (x, a, b), eps=1e-10)
self.assert_grad(F.lfilter, (x, a, b), enable_all_grad=False)
def test_lfilter_all_inputs(self):
torch.random.manual_seed(2434)
x = torch.rand(2, 4, 256 * 2, dtype=self.dtype, device=self.device)
a = torch.tensor([0.7, 0.2, 0.6], dtype=self.dtype, device=self.device)
b = torch.tensor([0.4, 0.2, 0.9], dtype=self.dtype, device=self.device)
b.requires_grad = True
a.requires_grad = True
x.requires_grad = True
assert gradcheck(F.lfilter, (x, a, b), eps=1e-10)
x = get_whitenoise(sample_rate=22050, duration=0.05, n_channels=2)
a = torch.tensor([0.7, 0.2, 0.6])
b = torch.tensor([0.4, 0.2, 0.9])
self.assert_grad(F.lfilter, (x, a, b))
def test_biquad(self):
torch.random.manual_seed(2434)
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True)
a = torch.tensor([0.7, 0.2, 0.6], dtype=self.dtype, device=self.device, requires_grad=True)
b = torch.tensor([0.4, 0.2, 0.9], dtype=self.dtype, device=self.device, requires_grad=True)
assert gradcheck(F.biquad, (x, b[0], b[1], b[2], a[0], a[1], a[2]), eps=1e-10)
x = get_whitenoise(sample_rate=22050, duration=0.05, n_channels=2)
a = torch.tensor([0.7, 0.2, 0.6])
b = torch.tensor([0.4, 0.2, 0.9])
self.assert_grad(F.biquad, (x, b[0], b[1], b[2], a[0], a[1], a[2]))
def test_band_biquad(self):
@parameterized.expand([
(800, 0.7, True),
(800, 0.7, False),
])
def test_band_biquad(self, central_freq, Q, noise):
torch.random.manual_seed(2434)
sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True)
central_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True)
assert gradcheck(F.band_biquad, (x, sr, central_freq, Q))
x = get_whitenoise(sample_rate=sr, duration=0.05, n_channels=2)
central_freq = torch.tensor(central_freq)
Q = torch.tensor(Q)
self.assert_grad(F.band_biquad, (x, sr, central_freq, Q, noise))
def test_band_biquad_with_noise(self):
@parameterized.expand([
(800, 0.7, 10),
(800, 0.7, -10),
])
def test_bass_biquad(self, central_freq, Q, gain):
torch.random.manual_seed(2434)
sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True)
central_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True)
assert gradcheck(F.band_biquad, (x, sr, central_freq, Q, True))
x = get_whitenoise(sample_rate=sr, duration=0.05, n_channels=2)
central_freq = torch.tensor(central_freq)
Q = torch.tensor(Q)
gain = torch.tensor(gain)
self.assert_grad(F.bass_biquad, (x, sr, gain, central_freq, Q))
def test_bass_biquad(self):
torch.random.manual_seed(2434)
sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True)
central_freq = torch.tensor(100, dtype=self.dtype, device=self.device, requires_grad=True)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True)
gain = torch.tensor(10, dtype=self.dtype, device=self.device, requires_grad=True)
assert gradcheck(F.bass_biquad, (x, sr, gain, central_freq, Q))
def test_treble_biquad(self):
torch.random.manual_seed(2434)
sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True)
central_freq = torch.tensor(3000, dtype=self.dtype, device=self.device, requires_grad=True)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True)
gain = torch.tensor(10, dtype=self.dtype, device=self.device, requires_grad=True)
assert gradcheck(F.treble_biquad, (x, sr, gain, central_freq, Q))
@parameterized.expand([
(3000, 0.7, 10),
(3000, 0.7, -10),
def test_allpass_biquad(self):
])
def test_treble_biquad(self, central_freq, Q, gain):
torch.random.manual_seed(2434)
sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True)
central_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True)
assert gradcheck(F.allpass_biquad, (x, sr, central_freq, Q))
x = get_whitenoise(sample_rate=sr, duration=0.05, n_channels=2)
central_freq = torch.tensor(central_freq)
Q = torch.tensor(Q)
gain = torch.tensor(gain)
self.assert_grad(F.treble_biquad, (x, sr, gain, central_freq, Q))
def test_lowpass_biquad(self):
@parameterized.expand([
(800, 0.7, ),
])
def test_allpass_biquad(self, central_freq, Q):
torch.random.manual_seed(2434)
sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True)
cutoff_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True)
assert gradcheck(F.lowpass_biquad, (x, sr, cutoff_freq, Q))
x = get_whitenoise(sample_rate=sr, duration=0.05, n_channels=2)
central_freq = torch.tensor(central_freq)
Q = torch.tensor(Q)
self.assert_grad(F.allpass_biquad, (x, sr, central_freq, Q))
def test_highpass_biquad(self):
@parameterized.expand([
(800, 0.7, ),
])
def test_lowpass_biquad(self, cutoff_freq, Q):
torch.random.manual_seed(2434)
sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True)
cutoff_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True)
assert gradcheck(F.highpass_biquad, (x, sr, cutoff_freq, Q))
x = get_whitenoise(sample_rate=sr, duration=0.05, n_channels=2)
cutoff_freq = torch.tensor(cutoff_freq)
Q = torch.tensor(Q)
self.assert_grad(F.lowpass_biquad, (x, sr, cutoff_freq, Q))
def test_bandpass_biquad(self):
@parameterized.expand([
(800, 0.7, ),
])
def test_highpass_biquad(self, cutoff_freq, Q):
torch.random.manual_seed(2434)
sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True)
central_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True)
assert gradcheck(F.bandpass_biquad, (x, sr, central_freq, Q))
x = get_whitenoise(sample_rate=sr, duration=0.05, n_channels=2)
cutoff_freq = torch.tensor(cutoff_freq)
Q = torch.tensor(Q)
self.assert_grad(F.highpass_biquad, (x, sr, cutoff_freq, Q))
def test_bandpass_biquad_with_const_skirt_gain(self):
@parameterized.expand([
(800, 0.7, True),
(800, 0.7, False),
])
def test_bandpass_biquad(self, central_freq, Q, const_skirt_gain):
torch.random.manual_seed(2434)
sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True)
central_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True)
assert gradcheck(F.bandpass_biquad, (x, sr, central_freq, Q, True))
x = get_whitenoise(sample_rate=sr, duration=0.05, n_channels=2)
central_freq = torch.tensor(central_freq)
Q = torch.tensor(Q)
self.assert_grad(F.bandpass_biquad, (x, sr, central_freq, Q, const_skirt_gain))
def test_equalizer_biquad(self):
@parameterized.expand([
(800, 0.7, 10),
(800, 0.7, -10),
])
def test_equalizer_biquad(self, central_freq, Q, gain):
torch.random.manual_seed(2434)
sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True)
central_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True)
gain = torch.tensor(10, dtype=self.dtype, device=self.device, requires_grad=True)
assert gradcheck(F.equalizer_biquad, (x, sr, central_freq, gain, Q))
x = get_whitenoise(sample_rate=sr, duration=0.05, n_channels=2)
central_freq = torch.tensor(central_freq)
Q = torch.tensor(Q)
gain = torch.tensor(gain)
self.assert_grad(F.equalizer_biquad, (x, sr, central_freq, gain, Q))
def test_bandreject_biquad(self):
@parameterized.expand([
(800, 0.7, ),
])
def test_bandreject_biquad(self, central_freq, Q):
torch.random.manual_seed(2434)
sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True)
central_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True)
assert gradcheck(F.bandreject_biquad, (x, sr, central_freq, Q))
x = get_whitenoise(sample_rate=sr, duration=0.05, n_channels=2)
central_freq = torch.tensor(central_freq)
Q = torch.tensor(Q)
self.assert_grad(F.bandreject_biquad, (x, sr, central_freq, Q))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment