Unverified Commit 723e9a52 authored by Chin-Yun Yu's avatar Chin-Yun Yu Committed by GitHub
Browse files

Support higher order derivatives for `F.lfilter` (#1441)

parent 5417e4fb
......@@ -3,7 +3,7 @@ import torch
from parameterized import parameterized
from torch import Tensor
import torchaudio.functional as F
from torch.autograd import gradcheck
from torch.autograd import gradcheck, gradgradcheck
from torchaudio_unittest.common_utils import (
TestBaseMixin,
get_whitenoise,
......@@ -26,6 +26,7 @@ class Autograd(TestBaseMixin):
i.requires_grad = True
inputs_.append(i)
assert gradcheck(transform, inputs_)
assert gradgradcheck(transform, inputs_)
def test_lfilter_x(self):
torch.random.manual_seed(2434)
......
......@@ -80,170 +80,159 @@ void lfilter_core_generic_loop(
}
}
std::tuple<at::Tensor, at::Tensor> lfilter_core(
const torch::Tensor& waveform,
const torch::Tensor& a_coeffs,
const torch::Tensor& b_coeffs) {
TORCH_CHECK(waveform.device() == a_coeffs.device());
TORCH_CHECK(b_coeffs.device() == a_coeffs.device());
TORCH_CHECK(a_coeffs.size(0) == b_coeffs.size(0));
class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {
public:
static torch::Tensor forward(
torch::autograd::AutogradContext* ctx,
const torch::Tensor& waveform,
const torch::Tensor& a_coeffs_normalized) {
auto device = waveform.device();
auto dtype = waveform.dtype();
int64_t n_channel = waveform.size(0);
int64_t n_sample = waveform.size(1);
int64_t n_order = a_coeffs_normalized.size(0);
int64_t n_sample_padded = n_sample + n_order - 1;
TORCH_INTERNAL_ASSERT(waveform.sizes().size() == 2);
auto a_coeff_flipped = a_coeffs_normalized.flip(0).contiguous();
auto device = waveform.device();
int64_t n_order = a_coeffs.size(0);
auto options = torch::TensorOptions().dtype(dtype).device(device);
auto padded_output_waveform =
torch::zeros({n_channel, n_sample_padded}, options);
if (device.is_cpu()) {
cpu_lfilter_core_loop(waveform, a_coeff_flipped, padded_output_waveform);
} else {
lfilter_core_generic_loop(
waveform, a_coeff_flipped, padded_output_waveform);
}
TORCH_INTERNAL_ASSERT(n_order > 0);
auto output = padded_output_waveform.index(
{torch::indexing::Slice(),
torch::indexing::Slice(n_order - 1, torch::indexing::None)});
namespace F = torch::nn::functional;
ctx->save_for_backward({waveform, a_coeffs_normalized, output});
return output;
}
auto padded_waveform = F::pad(waveform, F::PadFuncOptions({n_order - 1, 0}));
auto padded_output_waveform = torch::zeros_like(padded_waveform);
static torch::autograd::tensor_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::tensor_list grad_outputs) {
auto saved = ctx->get_saved_variables();
auto x = saved[0];
auto a_coeffs_normalized = saved[1];
auto y = saved[2];
auto a_coeff_flipped = a_coeffs.flip(0).contiguous();
auto b_coeff_flipped = b_coeffs.flip(0).contiguous();
int64_t n_channel = x.size(0);
int64_t n_order = a_coeffs_normalized.size(0);
auto input_signal_windows =
F::conv1d(
padded_waveform.unsqueeze(1), b_coeff_flipped.view({1, 1, n_order}))
.squeeze(1);
auto dx = torch::Tensor();
auto da = torch::Tensor();
auto dy = grad_outputs[0];
input_signal_windows.div_(a_coeffs[0]);
a_coeff_flipped.div_(a_coeffs[0]);
namespace F = torch::nn::functional;
if (device.is_cpu()) {
cpu_lfilter_core_loop(
input_signal_windows, a_coeff_flipped, padded_output_waveform);
} else {
lfilter_core_generic_loop(
input_signal_windows, a_coeff_flipped, padded_output_waveform);
}
if (a_coeffs_normalized.requires_grad()) {
auto dyda = F::pad(
DifferentiableIIR::apply(-y, a_coeffs_normalized),
F::PadFuncOptions({n_order - 1, 0}));
auto output = padded_output_waveform.index(
{torch::indexing::Slice(),
torch::indexing::Slice(n_order - 1, torch::indexing::None)});
da = F::conv1d(
dyda.unsqueeze(0),
dy.unsqueeze(1),
F::Conv1dFuncOptions().groups(n_channel))
.sum(1)
.squeeze(0)
.flip(0);
}
return {output, input_signal_windows};
}
if (x.requires_grad()) {
dx = DifferentiableIIR::apply(dy.flip(1), a_coeffs_normalized).flip(1);
}
torch::Tensor lfilter_simple(
const torch::Tensor& waveform,
const torch::Tensor& a_coeffs,
const torch::Tensor& b_coeffs) {
return std::get<0>(lfilter_core(waveform, a_coeffs, b_coeffs));
}
return {dx, da};
}
};
class DifferentiableLfilter
: public torch::autograd::Function<DifferentiableLfilter> {
class DifferentiableFIR : public torch::autograd::Function<DifferentiableFIR> {
public:
static torch::Tensor forward(
torch::autograd::AutogradContext* ctx,
const torch::Tensor& waveform,
const torch::Tensor& a_coeffs,
const torch::Tensor& b_coeffs) {
at::AutoNonVariableTypeMode g;
auto result = lfilter_core(waveform, a_coeffs, b_coeffs);
ctx->save_for_backward(
{waveform,
a_coeffs,
b_coeffs,
std::get<0>(result),
std::get<1>(result)});
return std::get<0>(result);
int64_t n_order = b_coeffs.size(0);
namespace F = torch::nn::functional;
auto b_coeff_flipped = b_coeffs.flip(0).contiguous();
auto padded_waveform =
F::pad(waveform, F::PadFuncOptions({n_order - 1, 0}));
auto output =
F::conv1d(
padded_waveform.unsqueeze(1), b_coeff_flipped.view({1, 1, n_order}))
.squeeze(1);
ctx->save_for_backward({waveform, b_coeffs, output});
return output;
}
static torch::autograd::tensor_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::tensor_list grad_outputs) {
auto saved = ctx->get_saved_variables();
auto waveform = saved[0];
auto a_coeffs = saved[1];
auto b_coeffs = saved[2];
auto y = saved[3];
auto xh = saved[4];
auto device = waveform.device();
auto dtype = waveform.dtype();
int64_t n_channel = waveform.size(0);
int64_t n_sample = waveform.size(1);
int64_t n_order = a_coeffs.size(0);
int64_t n_sample_padded = n_sample + n_order - 1;
auto x = saved[0];
auto b_coeffs = saved[1];
auto y = saved[2];
auto a_coeff_flipped = a_coeffs.flip(0).contiguous();
auto b_coeff_flipped = b_coeffs.flip(0).contiguous();
b_coeff_flipped.div_(a_coeffs[0]);
a_coeff_flipped.div_(a_coeffs[0]);
int64_t n_channel = x.size(0);
int64_t n_order = b_coeffs.size(0);
auto dx = torch::Tensor();
auto da = torch::Tensor();
auto db = torch::Tensor();
auto dy = grad_outputs[0];
at::AutoNonVariableTypeMode g;
namespace F = torch::nn::functional;
auto options = torch::TensorOptions().dtype(dtype).device(device);
if (a_coeffs.requires_grad()) {
auto dyda = torch::zeros({n_channel, n_sample_padded}, options);
if (device.is_cpu()) {
cpu_lfilter_core_loop(-y, a_coeff_flipped, dyda);
} else {
lfilter_core_generic_loop(-y, a_coeff_flipped, dyda);
}
da = F::conv1d(
dyda.unsqueeze(0),
if (b_coeffs.requires_grad()) {
db = F::conv1d(
F::pad(x.unsqueeze(0), F::PadFuncOptions({n_order - 1, 0})),
dy.unsqueeze(1),
F::Conv1dFuncOptions().groups(n_channel))
.sum(1)
.squeeze(0)
.flip(0);
da.div_(a_coeffs[0]);
}
if (b_coeffs.requires_grad() || waveform.requires_grad()) {
auto dxh = torch::zeros({n_channel, n_sample_padded}, options);
if (device.is_cpu()) {
cpu_lfilter_core_loop(dy.flip(1), a_coeff_flipped, dxh);
} else {
lfilter_core_generic_loop(dy.flip(1), a_coeff_flipped, dxh);
}
dxh = dxh.index(
{torch::indexing::Slice(),
torch::indexing::Slice(n_order - 1, torch::indexing::None)})
.flip(1);
if (waveform.requires_grad()) {
dx = F::conv1d(
F::pad(dxh.unsqueeze(1), F::PadFuncOptions({0, n_order - 1})),
b_coeffs.view({1, 1, n_order}))
.squeeze(1);
dx.div_(a_coeffs[0]);
}
if (b_coeffs.requires_grad()) {
db =
F::conv1d(
F::pad(
waveform.unsqueeze(0), F::PadFuncOptions({n_order - 1, 0})),
dxh.unsqueeze(1),
F::Conv1dFuncOptions().groups(n_channel))
.sum(1)
.squeeze(0)
.flip(0);
db.div_(a_coeffs[0]);
}
if (x.requires_grad()) {
dx = F::conv1d(
F::pad(dy.unsqueeze(1), F::PadFuncOptions({0, n_order - 1})),
b_coeffs.view({1, 1, n_order}))
.squeeze(1);
}
return {dx, da, db};
return {dx, db};
}
};
torch::Tensor lfilter_autograd(
torch::Tensor lfilter_core(
const torch::Tensor& waveform,
const torch::Tensor& a_coeffs,
const torch::Tensor& b_coeffs) {
return DifferentiableLfilter::apply(waveform, a_coeffs, b_coeffs);
TORCH_CHECK(waveform.device() == a_coeffs.device());
TORCH_CHECK(b_coeffs.device() == a_coeffs.device());
TORCH_CHECK(a_coeffs.size(0) == b_coeffs.size(0));
TORCH_INTERNAL_ASSERT(waveform.sizes().size() == 2);
int64_t n_order = b_coeffs.size(0);
TORCH_INTERNAL_ASSERT(n_order > 0);
auto filtered_waveform =
DifferentiableFIR::apply(waveform, b_coeffs / a_coeffs[0]);
auto output =
DifferentiableIIR::apply(filtered_waveform, a_coeffs / a_coeffs[0]);
return output;
}
} // namespace
......@@ -259,10 +248,6 @@ TORCH_LIBRARY(torchaudio, m) {
"torchaudio::_lfilter(Tensor waveform, Tensor a_coeffs, Tensor b_coeffs) -> Tensor");
}
TORCH_LIBRARY_IMPL(torchaudio, DefaultBackend, m) {
m.impl("torchaudio::_lfilter", lfilter_simple);
}
TORCH_LIBRARY_IMPL(torchaudio, Autograd, m) {
m.impl("torchaudio::_lfilter", lfilter_autograd);
TORCH_LIBRARY_IMPL(torchaudio, CompositeImplicitAutograd, m) {
m.impl("torchaudio::_lfilter", lfilter_core);
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment