Unverified Commit 2a3d52ff authored by chin yun yu, committed by GitHub

Add backprop support to lfilter (#1310)

parent ed9020c1
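In short: with this change, torchaudio.functional.lfilter participates in autograd, so gradients can flow back to the input waveform and to both coefficient tensors. A minimal sketch of the newly supported usage (coefficient values and double precision borrowed from the tests below):

import torch
import torchaudio.functional as F

x = torch.rand(2, 4, 512, dtype=torch.float64, requires_grad=True)
a = torch.tensor([0.7, 0.2, 0.6], dtype=torch.float64, requires_grad=True)
b = torch.tensor([0.4, 0.2, 0.9], dtype=torch.float64, requires_grad=True)

y = F.lfilter(x, a, b)
y.sum().backward()  # x.grad, a.grad, and b.grad are now populated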
import torch
from .autograd_impl import Autograd
from torchaudio_unittest import common_utils


class TestAutogradLfilterCPU(Autograd, common_utils.PytorchTestCase):
    dtype = torch.float64
    device = torch.device('cpu')
import torch
from .autograd_impl import Autograd
from torchaudio_unittest import common_utils


@common_utils.skipIfNoCuda
class TestAutogradLfilterCUDA(Autograd, common_utils.PytorchTestCase):
    dtype = torch.float64
    device = torch.device('cuda')
import torch
import torchaudio.functional as F
from torch.autograd import gradcheck
from torchaudio_unittest import common_utils


class Autograd(common_utils.TestBaseMixin):
    def test_x_grad(self):
        torch.random.manual_seed(2434)
        x = torch.rand(2, 4, 256 * 2, dtype=self.dtype, device=self.device)
        a = torch.tensor([0.7, 0.2, 0.6], dtype=self.dtype, device=self.device)
        b = torch.tensor([0.4, 0.2, 0.9], dtype=self.dtype, device=self.device)
        x.requires_grad = True
        assert gradcheck(F.lfilter, (x, a, b), eps=1e-10)

    def test_a_grad(self):
        torch.random.manual_seed(2434)
        x = torch.rand(2, 4, 256 * 2, dtype=self.dtype, device=self.device)
        a = torch.tensor([0.7, 0.2, 0.6], dtype=self.dtype, device=self.device)
        b = torch.tensor([0.4, 0.2, 0.9], dtype=self.dtype, device=self.device)
        a.requires_grad = True
        assert gradcheck(F.lfilter, (x, a, b), eps=1e-10)

    def test_b_grad(self):
        torch.random.manual_seed(2434)
        x = torch.rand(2, 4, 256 * 2, dtype=self.dtype, device=self.device)
        a = torch.tensor([0.7, 0.2, 0.6], dtype=self.dtype, device=self.device)
        b = torch.tensor([0.4, 0.2, 0.9], dtype=self.dtype, device=self.device)
        b.requires_grad = True
        assert gradcheck(F.lfilter, (x, a, b), eps=1e-10)

    def test_all_grad(self):
        torch.random.manual_seed(2434)
        x = torch.rand(2, 4, 256 * 2, dtype=self.dtype, device=self.device)
        a = torch.tensor([0.7, 0.2, 0.6], dtype=self.dtype, device=self.device)
        b = torch.tensor([0.4, 0.2, 0.9], dtype=self.dtype, device=self.device)
        b.requires_grad = True
        a.requires_grad = True
        x.requires_grad = True
        assert gradcheck(F.lfilter, (x, a, b), eps=1e-10)
@@ -6,6 +6,7 @@ import torchaudio
 import torchaudio.functional as F
 from parameterized import parameterized
 import itertools
+import unittest
 from torchaudio_unittest import common_utils
 from torchaudio_unittest.common_utils import (
@@ -21,6 +22,10 @@ class TestLFilterFloat32(Lfilter, common_utils.PytorchTestCase):
     dtype = torch.float32
     device = torch.device('cpu')
 
+    @unittest.expectedFailure
+    def test_9th_order_filter_stability(self):
+        super().test_9th_order_filter_stability()
+
 class TestLFilterFloat64(Lfilter, common_utils.PytorchTestCase):
     dtype = torch.float64
...
 import torch
+import unittest
 from torchaudio_unittest import common_utils
 from .functional_impl import Lfilter, Spectrogram
@@ -9,6 +10,10 @@ class TestLFilterFloat32(Lfilter, common_utils.PytorchTestCase):
     dtype = torch.float32
     device = torch.device('cuda')
 
+    @unittest.expectedFailure
+    def test_9th_order_filter_stability(self):
+        super().test_9th_order_filter_stability()
+
 @common_utils.skipIfNoCuda
 class TestLFilterFloat64(Lfilter, common_utils.PytorchTestCase):
...
@@ -2,6 +2,7 @@
 import torch
 import torchaudio.functional as F
 from parameterized import parameterized
+from scipy import signal
 from torchaudio_unittest import common_utils
@@ -45,6 +46,28 @@ class Lfilter(common_utils.TestBaseMixin):
         output_waveform = F.lfilter(waveform, a_coeffs, b_coeffs)
         assert shape == waveform.size() == output_waveform.size()
 
+    def test_9th_order_filter_stability(self):
+        """Validate the precision of lfilter against the reference scipy
+        implementation when using a high-order filter.
+
+        The reference implementation uses cascaded second-order sections,
+        so it is more numerically accurate.
+        """
+        # create an impulse signal
+        x = torch.zeros(1024, dtype=self.dtype, device=self.device)
+        x[0] = 1
+
+        # get the target impulse response from cascaded second-order sections
+        sos = signal.butter(9, 850, 'hp', fs=22050, output='sos')
+        y = torch.from_numpy(signal.sosfilt(sos, x.cpu().numpy())).to(self.dtype).to(self.device)
+
+        # get the equivalent direct-form lfilter coefficients
+        b, a = signal.butter(9, 850, 'hp', fs=22050, output='ba')
+        b, a = torch.from_numpy(b).to(self.dtype).to(self.device), torch.from_numpy(
+            a).to(self.dtype).to(self.device)
+
+        # predict the impulse response (clamp disabled)
+        yhat = F.lfilter(x, a, b, False)
+        self.assertEqual(yhat, y, atol=1e-4, rtol=1e-5)
+
 class Spectrogram(common_utils.TestBaseMixin):
     @parameterized.expand([(0., ), (1., ), (2., ), (3., )])
...
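For context on the expectedFailure markers added for float32 above: a direct-form realization of a high-order IIR filter is numerically fragile, because its polynomial coefficients interact across the full filter order, while cascaded second-order sections keep each stage well conditioned. A rough scipy-only illustration of the gap, not part of the change itself:

import numpy as np
from scipy import signal

# The same 9th-order Butterworth high-pass as the test, realized two ways.
sos = signal.butter(9, 850, 'hp', fs=22050, output='sos')
b, a = signal.butter(9, 850, 'hp', fs=22050, output='ba')

x = np.zeros(1024, dtype=np.float32)
x[0] = 1.0  # unit impulse

y_ref = signal.sosfilt(sos, x.astype(np.float64))  # cascaded biquads, float64
y_f32 = signal.lfilter(b.astype(np.float32), a.astype(np.float32), x)  # direct form, float32

# The direct-form error grows with filter order and shrinking precision.
print(np.abs(y_ref - y_f32).max())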
@@ -80,7 +80,7 @@ void lfilter_core_generic_loop(
   }
 }
 
-torch::Tensor lfilter_core(
+std::tuple<at::Tensor, at::Tensor> lfilter_core(
     const torch::Tensor& waveform,
     const torch::Tensor& a_coeffs,
     const torch::Tensor& b_coeffs) {
@@ -123,7 +123,127 @@ torch::Tensor lfilter_core(
       {torch::indexing::Slice(),
        torch::indexing::Slice(n_order - 1, torch::indexing::None)});
-  return output;
+  // Also return the FIR-stage signal so the backward pass can reuse it.
+  return {output, input_signal_windows};
 }
 
+torch::Tensor lfilter_simple(
+    const torch::Tensor& waveform,
+    const torch::Tensor& a_coeffs,
+    const torch::Tensor& b_coeffs) {
+  return std::get<0>(lfilter_core(waveform, a_coeffs, b_coeffs));
+}
+
+class DifferentiableLfilter
+    : public torch::autograd::Function<DifferentiableLfilter> {
+ public:
+  static torch::Tensor forward(
+      torch::autograd::AutogradContext* ctx,
+      const torch::Tensor& waveform,
+      const torch::Tensor& a_coeffs,
+      const torch::Tensor& b_coeffs) {
+    at::AutoNonVariableTypeMode g;
+    auto result = lfilter_core(waveform, a_coeffs, b_coeffs);
+    // Save the filtered output y and the FIR-stage signal xh for backward.
+    ctx->save_for_backward(
+        {waveform,
+         a_coeffs,
+         b_coeffs,
+         std::get<0>(result),
+         std::get<1>(result)});
+    return std::get<0>(result);
+  }
+
+  static torch::autograd::tensor_list backward(
+      torch::autograd::AutogradContext* ctx,
+      torch::autograd::tensor_list grad_outputs) {
+    auto saved = ctx->get_saved_variables();
+    auto waveform = saved[0];
+    auto a_coeffs = saved[1];
+    auto b_coeffs = saved[2];
+    auto y = saved[3];
+    auto xh = saved[4];
+
+    auto device = waveform.device();
+    auto dtype = waveform.dtype();
+    int64_t n_channel = waveform.size(0);
+    int64_t n_sample = waveform.size(1);
+    int64_t n_order = a_coeffs.size(0);
+    int64_t n_sample_padded = n_sample + n_order - 1;
+
+    auto a_coeff_flipped = a_coeffs.flip(0).contiguous();
+    auto b_coeff_flipped = b_coeffs.flip(0).contiguous();
+    b_coeff_flipped.div_(a_coeffs[0]);
+    a_coeff_flipped.div_(a_coeffs[0]);
+
+    auto dx = torch::Tensor();
+    auto da = torch::Tensor();
+    auto db = torch::Tensor();
+    auto dy = grad_outputs[0];
+
+    at::AutoNonVariableTypeMode g;
+    namespace F = torch::nn::functional;
+    auto options = torch::TensorOptions().dtype(dtype).device(device);
+
+    if (a_coeffs.requires_grad()) {
+      // dL/da: run -y through the all-pole recursion, then correlate with dy.
+      auto dyda = torch::zeros({n_channel, n_sample_padded}, options);
+      if (device.is_cpu()) {
+        cpu_lfilter_core_loop(-y, a_coeff_flipped, dyda);
+      } else {
+        lfilter_core_generic_loop(-y, a_coeff_flipped, dyda);
+      }
+      da = F::conv1d(
+               dyda.unsqueeze(0),
+               dy.unsqueeze(1),
+               F::Conv1dFuncOptions().groups(n_channel))
+               .sum(1)
+               .squeeze(0)
+               .flip(0);
+      da.div_(a_coeffs[0]);
+    }
+
+    if (b_coeffs.requires_grad() || waveform.requires_grad()) {
+      // dL/dxh: the adjoint of the all-pole stage is the same recursion
+      // applied to dy reversed in time.
+      auto dxh = torch::zeros({n_channel, n_sample_padded}, options);
+      if (device.is_cpu()) {
+        cpu_lfilter_core_loop(dy.flip(1), a_coeff_flipped, dxh);
+      } else {
+        lfilter_core_generic_loop(dy.flip(1), a_coeff_flipped, dxh);
+      }
+      dxh = dxh.index(
+                   {torch::indexing::Slice(),
+                    torch::indexing::Slice(n_order - 1, torch::indexing::None)})
+                .flip(1);
+
+      if (waveform.requires_grad()) {
+        // dL/dx: correlate dxh with the FIR coefficients b.
+        dx = F::conv1d(
+                 F::pad(dxh.unsqueeze(1), F::PadFuncOptions({0, n_order - 1})),
+                 b_coeffs.view({1, 1, n_order}))
+                 .squeeze(1);
+        dx.div_(a_coeffs[0]);
+      }
+
+      if (b_coeffs.requires_grad()) {
+        // dL/db: cross-correlate the input waveform with dxh.
+        db =
+            F::conv1d(
+                F::pad(
+                    waveform.unsqueeze(0), F::PadFuncOptions({n_order - 1, 0})),
+                dxh.unsqueeze(1),
+                F::Conv1dFuncOptions().groups(n_channel))
+                .sum(1)
+                .squeeze(0)
+                .flip(0);
+        db.div_(a_coeffs[0]);
+      }
+    }
+    return {dx, da, db};
+  }
+};
+
+torch::Tensor lfilter_autograd(
+    const torch::Tensor& waveform,
+    const torch::Tensor& a_coeffs,
+    const torch::Tensor& b_coeffs) {
+  return DifferentiableLfilter::apply(waveform, a_coeffs, b_coeffs);
+}
 
 } // namespace
@@ -139,6 +259,10 @@ TORCH_LIBRARY(torchaudio, m) {
       "torchaudio::_lfilter(Tensor waveform, Tensor a_coeffs, Tensor b_coeffs) -> Tensor");
 }
 
-TORCH_LIBRARY_IMPL(torchaudio, Math, m) {
-  m.impl("torchaudio::_lfilter", lfilter_core);
+TORCH_LIBRARY_IMPL(torchaudio, DefaultBackend, m) {
+  m.impl("torchaudio::_lfilter", lfilter_simple);
+}
+
+TORCH_LIBRARY_IMPL(torchaudio, Autograd, m) {
+  m.impl("torchaudio::_lfilter", lfilter_autograd);
 }
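Reading the backward pass above: the forward splits the filter into an FIR stage (the input convolved with b / a[0], producing xh) followed by an all-pole recursion that produces y, and the backward uses the fact that the adjoint of a causal LTI filter is the same filter run backwards in time. Concretely, dL/dxh is obtained by running the all-pole recursion over a time-reversed dy (the flip(1) calls), dL/dx correlates dxh with b, dL/db cross-correlates the input with dxh, and dL/da comes from filtering -y through the recursion. The registration change at the bottom then routes calls through the Autograd dispatch key to lfilter_autograd, while DefaultBackend keeps serving plain inference with lfilter_simple. A rough pure-PyTorch cross-check of the gradients; lfilter_ref is a hypothetical helper of mine, and the trailing False (clamp disabled) follows the usage in the tests above:

import torch
import torchaudio.functional as F

def lfilter_ref(x, a, b):
    # Naive differentiable direct-form filter: letting autograd trace the
    # loop yields reference gradients for x, a, and b.
    a0 = a[0]
    a_n, b_n = a / a0, b / a0
    y = []
    for n in range(x.size(-1)):
        yn = sum(b_n[k] * x[..., n - k] for k in range(b_n.size(0)) if n >= k)
        yn = yn - sum(a_n[k] * y[n - k] for k in range(1, a_n.size(0)) if n >= k)
        y.append(yn)
    return torch.stack(y, dim=-1)

x = torch.rand(1, 16, dtype=torch.float64, requires_grad=True)
a = torch.tensor([0.9, 0.3, 0.1], dtype=torch.float64, requires_grad=True)
b = torch.tensor([0.5, 0.2, 0.4], dtype=torch.float64, requires_grad=True)

g_fast = torch.autograd.grad(F.lfilter(x, a, b, False).sum(), (x, a, b))
g_ref = torch.autograd.grad(lfilter_ref(x, a, b).sum(), (x, a, b))
assert all(torch.allclose(f, r) for f, r in zip(g_fast, g_ref))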
@@ -884,6 +884,10 @@ def lfilter(
 ) -> Tensor:
     r"""Perform an IIR filter by evaluating difference equation.
 
+    Note:
+        To avoid numerical problems, a small filter order is preferred.
+        Using double precision can also reduce numerical errors.
+
     Args:
         waveform (Tensor): audio waveform of dimension of ``(..., time)``. Must be normalized to -1 to 1.
         a_coeffs (Tensor): denominator coefficients of difference equation of dimension of ``(n_order + 1)``.
...
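A small illustration of the note above (my own example, not part of the patch): when a high-order filter is unavoidable, one mitigation is to run the difference equation in double precision and cast the result back.

import torch
import torchaudio.functional as F
from scipy import signal

b, a = signal.butter(9, 850, 'hp', fs=22050, output='ba')  # float64 arrays
x = torch.rand(1, 1024)  # a float32 waveform within [-1, 1]

y = F.lfilter(
    x.double(),            # run the recursion in float64
    torch.from_numpy(a),
    torch.from_numpy(b),
    False,                 # clamp disabled, as in the stability test
).to(x.dtype)              # cast back to the original precision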