Unverified Commit c1ef2edd authored by Chin-Yun Yu's avatar Chin-Yun Yu Committed by GitHub
Browse files

Refactor `F.lfilter` and `F.*_biquad` autograd tests (#1438)

parent 9a0e70ea
from typing import Callable, Tuple
import torch import torch
from parameterized import parameterized
from torch import Tensor
import torchaudio.functional as F import torchaudio.functional as F
from torch.autograd import gradcheck from torch.autograd import gradcheck
from torchaudio_unittest import common_utils from torchaudio_unittest.common_utils import (
TestBaseMixin,
get_whitenoise,
)
class Autograd(TestBaseMixin):
def assert_grad(
self,
transform: Callable[..., Tensor],
inputs: Tuple[torch.Tensor],
*,
enable_all_grad: bool = True,
):
inputs_ = []
for i in inputs:
if torch.is_tensor(i):
i = i.to(dtype=self.dtype, device=self.device)
if enable_all_grad:
i.requires_grad = True
inputs_.append(i)
assert gradcheck(transform, inputs_)
class Autograd(common_utils.TestBaseMixin):
def test_lfilter_x(self): def test_lfilter_x(self):
torch.random.manual_seed(2434) torch.random.manual_seed(2434)
x = torch.rand(2, 4, 256 * 2, dtype=self.dtype, device=self.device) x = get_whitenoise(sample_rate=22050, duration=0.025, n_channels=2)
a = torch.tensor([0.7, 0.2, 0.6], dtype=self.dtype, device=self.device) a = torch.tensor([0.7, 0.2, 0.6])
b = torch.tensor([0.4, 0.2, 0.9], dtype=self.dtype, device=self.device) b = torch.tensor([0.4, 0.2, 0.9])
x.requires_grad = True x.requires_grad = True
assert gradcheck(F.lfilter, (x, a, b), eps=1e-10) self.assert_grad(F.lfilter, (x, a, b), enable_all_grad=False)
def test_lfilter_a(self): def test_lfilter_a(self):
torch.random.manual_seed(2434) torch.random.manual_seed(2434)
x = torch.rand(2, 4, 256 * 2, dtype=self.dtype, device=self.device) x = get_whitenoise(sample_rate=22050, duration=0.05, n_channels=2)
a = torch.tensor([0.7, 0.2, 0.6], dtype=self.dtype, device=self.device) a = torch.tensor([0.7, 0.2, 0.6])
b = torch.tensor([0.4, 0.2, 0.9], dtype=self.dtype, device=self.device) b = torch.tensor([0.4, 0.2, 0.9])
a.requires_grad = True a.requires_grad = True
assert gradcheck(F.lfilter, (x, a, b), eps=1e-10) self.assert_grad(F.lfilter, (x, a, b), enable_all_grad=False)
def test_lfilter_b(self): def test_lfilter_b(self):
torch.random.manual_seed(2434) torch.random.manual_seed(2434)
x = torch.rand(2, 4, 256 * 2, dtype=self.dtype, device=self.device) x = get_whitenoise(sample_rate=22050, duration=0.05, n_channels=2)
a = torch.tensor([0.7, 0.2, 0.6], dtype=self.dtype, device=self.device) a = torch.tensor([0.7, 0.2, 0.6])
b = torch.tensor([0.4, 0.2, 0.9], dtype=self.dtype, device=self.device) b = torch.tensor([0.4, 0.2, 0.9])
b.requires_grad = True b.requires_grad = True
assert gradcheck(F.lfilter, (x, a, b), eps=1e-10) self.assert_grad(F.lfilter, (x, a, b), enable_all_grad=False)
def test_lfilter_all_inputs(self): def test_lfilter_all_inputs(self):
torch.random.manual_seed(2434) torch.random.manual_seed(2434)
x = torch.rand(2, 4, 256 * 2, dtype=self.dtype, device=self.device) x = get_whitenoise(sample_rate=22050, duration=0.05, n_channels=2)
a = torch.tensor([0.7, 0.2, 0.6], dtype=self.dtype, device=self.device) a = torch.tensor([0.7, 0.2, 0.6])
b = torch.tensor([0.4, 0.2, 0.9], dtype=self.dtype, device=self.device) b = torch.tensor([0.4, 0.2, 0.9])
b.requires_grad = True self.assert_grad(F.lfilter, (x, a, b))
a.requires_grad = True
x.requires_grad = True
assert gradcheck(F.lfilter, (x, a, b), eps=1e-10)
def test_biquad(self): def test_biquad(self):
torch.random.manual_seed(2434) torch.random.manual_seed(2434)
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True) x = get_whitenoise(sample_rate=22050, duration=0.05, n_channels=2)
a = torch.tensor([0.7, 0.2, 0.6], dtype=self.dtype, device=self.device, requires_grad=True) a = torch.tensor([0.7, 0.2, 0.6])
b = torch.tensor([0.4, 0.2, 0.9], dtype=self.dtype, device=self.device, requires_grad=True) b = torch.tensor([0.4, 0.2, 0.9])
assert gradcheck(F.biquad, (x, b[0], b[1], b[2], a[0], a[1], a[2]), eps=1e-10) self.assert_grad(F.biquad, (x, b[0], b[1], b[2], a[0], a[1], a[2]))
def test_band_biquad(self): @parameterized.expand([
(800, 0.7, True),
(800, 0.7, False),
])
def test_band_biquad(self, central_freq, Q, noise):
torch.random.manual_seed(2434) torch.random.manual_seed(2434)
sr = 22050 sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True) x = get_whitenoise(sample_rate=sr, duration=0.05, n_channels=2)
central_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True) central_freq = torch.tensor(central_freq)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True) Q = torch.tensor(Q)
assert gradcheck(F.band_biquad, (x, sr, central_freq, Q)) self.assert_grad(F.band_biquad, (x, sr, central_freq, Q, noise))
def test_band_biquad_with_noise(self): @parameterized.expand([
(800, 0.7, 10),
(800, 0.7, -10),
])
def test_bass_biquad(self, central_freq, Q, gain):
torch.random.manual_seed(2434) torch.random.manual_seed(2434)
sr = 22050 sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True) x = get_whitenoise(sample_rate=sr, duration=0.05, n_channels=2)
central_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True) central_freq = torch.tensor(central_freq)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True) Q = torch.tensor(Q)
assert gradcheck(F.band_biquad, (x, sr, central_freq, Q, True)) gain = torch.tensor(gain)
self.assert_grad(F.bass_biquad, (x, sr, gain, central_freq, Q))
def test_bass_biquad(self): @parameterized.expand([
torch.random.manual_seed(2434) (3000, 0.7, 10),
sr = 22050 (3000, 0.7, -10),
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True)
central_freq = torch.tensor(100, dtype=self.dtype, device=self.device, requires_grad=True)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True)
gain = torch.tensor(10, dtype=self.dtype, device=self.device, requires_grad=True)
assert gradcheck(F.bass_biquad, (x, sr, gain, central_freq, Q))
def test_treble_biquad(self):
torch.random.manual_seed(2434)
sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True)
central_freq = torch.tensor(3000, dtype=self.dtype, device=self.device, requires_grad=True)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True)
gain = torch.tensor(10, dtype=self.dtype, device=self.device, requires_grad=True)
assert gradcheck(F.treble_biquad, (x, sr, gain, central_freq, Q))
def test_allpass_biquad(self): ])
def test_treble_biquad(self, central_freq, Q, gain):
torch.random.manual_seed(2434) torch.random.manual_seed(2434)
sr = 22050 sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True) x = get_whitenoise(sample_rate=sr, duration=0.05, n_channels=2)
central_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True) central_freq = torch.tensor(central_freq)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True) Q = torch.tensor(Q)
assert gradcheck(F.allpass_biquad, (x, sr, central_freq, Q)) gain = torch.tensor(gain)
self.assert_grad(F.treble_biquad, (x, sr, gain, central_freq, Q))
def test_lowpass_biquad(self): @parameterized.expand([
(800, 0.7, ),
])
def test_allpass_biquad(self, central_freq, Q):
torch.random.manual_seed(2434) torch.random.manual_seed(2434)
sr = 22050 sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True) x = get_whitenoise(sample_rate=sr, duration=0.05, n_channels=2)
cutoff_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True) central_freq = torch.tensor(central_freq)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True) Q = torch.tensor(Q)
assert gradcheck(F.lowpass_biquad, (x, sr, cutoff_freq, Q)) self.assert_grad(F.allpass_biquad, (x, sr, central_freq, Q))
def test_highpass_biquad(self): @parameterized.expand([
(800, 0.7, ),
])
def test_lowpass_biquad(self, cutoff_freq, Q):
torch.random.manual_seed(2434) torch.random.manual_seed(2434)
sr = 22050 sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True) x = get_whitenoise(sample_rate=sr, duration=0.05, n_channels=2)
cutoff_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True) cutoff_freq = torch.tensor(cutoff_freq)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True) Q = torch.tensor(Q)
assert gradcheck(F.highpass_biquad, (x, sr, cutoff_freq, Q)) self.assert_grad(F.lowpass_biquad, (x, sr, cutoff_freq, Q))
def test_bandpass_biquad(self): @parameterized.expand([
(800, 0.7, ),
])
def test_highpass_biquad(self, cutoff_freq, Q):
torch.random.manual_seed(2434) torch.random.manual_seed(2434)
sr = 22050 sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True) x = get_whitenoise(sample_rate=sr, duration=0.05, n_channels=2)
central_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True) cutoff_freq = torch.tensor(cutoff_freq)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True) Q = torch.tensor(Q)
assert gradcheck(F.bandpass_biquad, (x, sr, central_freq, Q)) self.assert_grad(F.highpass_biquad, (x, sr, cutoff_freq, Q))
def test_bandpass_biquad_with_const_skirt_gain(self): @parameterized.expand([
(800, 0.7, True),
(800, 0.7, False),
])
def test_bandpass_biquad(self, central_freq, Q, const_skirt_gain):
torch.random.manual_seed(2434) torch.random.manual_seed(2434)
sr = 22050 sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True) x = get_whitenoise(sample_rate=sr, duration=0.05, n_channels=2)
central_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True) central_freq = torch.tensor(central_freq)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True) Q = torch.tensor(Q)
assert gradcheck(F.bandpass_biquad, (x, sr, central_freq, Q, True)) self.assert_grad(F.bandpass_biquad, (x, sr, central_freq, Q, const_skirt_gain))
def test_equalizer_biquad(self): @parameterized.expand([
(800, 0.7, 10),
(800, 0.7, -10),
])
def test_equalizer_biquad(self, central_freq, Q, gain):
torch.random.manual_seed(2434) torch.random.manual_seed(2434)
sr = 22050 sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True) x = get_whitenoise(sample_rate=sr, duration=0.05, n_channels=2)
central_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True) central_freq = torch.tensor(central_freq)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True) Q = torch.tensor(Q)
gain = torch.tensor(10, dtype=self.dtype, device=self.device, requires_grad=True) gain = torch.tensor(gain)
assert gradcheck(F.equalizer_biquad, (x, sr, central_freq, gain, Q)) self.assert_grad(F.equalizer_biquad, (x, sr, central_freq, gain, Q))
def test_bandreject_biquad(self): @parameterized.expand([
(800, 0.7, ),
])
def test_bandreject_biquad(self, central_freq, Q):
torch.random.manual_seed(2434) torch.random.manual_seed(2434)
sr = 22050 sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True) x = get_whitenoise(sample_rate=sr, duration=0.05, n_channels=2)
central_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True) central_freq = torch.tensor(central_freq)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True) Q = torch.tensor(Q)
assert gradcheck(F.bandreject_biquad, (x, sr, central_freq, Q)) self.assert_grad(F.bandreject_biquad, (x, sr, central_freq, Q))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment