Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
2a3d52ff
Unverified
Commit
2a3d52ff
authored
Mar 16, 2021
by
chin yun yu
Committed by
GitHub
Mar 15, 2021
Browse files
Add backprop support to lfilter (#1310)
parent
ed9020c1
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
222 additions
and
4 deletions
+222
-4
test/torchaudio_unittest/functional/autograd_cpu_test.py
test/torchaudio_unittest/functional/autograd_cpu_test.py
+8
-0
test/torchaudio_unittest/functional/autograd_cuda_test.py
test/torchaudio_unittest/functional/autograd_cuda_test.py
+9
-0
test/torchaudio_unittest/functional/autograd_impl.py
test/torchaudio_unittest/functional/autograd_impl.py
+40
-0
test/torchaudio_unittest/functional/functional_cpu_test.py
test/torchaudio_unittest/functional/functional_cpu_test.py
+5
-0
test/torchaudio_unittest/functional/functional_cuda_test.py
test/torchaudio_unittest/functional/functional_cuda_test.py
+5
-0
test/torchaudio_unittest/functional/functional_impl.py
test/torchaudio_unittest/functional/functional_impl.py
+23
-0
torchaudio/csrc/lfilter.cpp
torchaudio/csrc/lfilter.cpp
+128
-4
torchaudio/functional/filtering.py
torchaudio/functional/filtering.py
+4
-0
No files found.
test/torchaudio_unittest/functional/autograd_cpu_test.py
0 → 100644
View file @
2a3d52ff
import
torch
from
.autograd_impl
import
Autograd
from
torchaudio_unittest
import
common_utils
class
TestAutogradLfilterCPU
(
Autograd
,
common_utils
.
PytorchTestCase
):
dtype
=
torch
.
float64
device
=
torch
.
device
(
'cpu'
)
test/torchaudio_unittest/functional/autograd_cuda_test.py
0 → 100644
View file @
2a3d52ff
import
torch
from
.autograd_impl
import
Autograd
from
torchaudio_unittest
import
common_utils
@
common_utils
.
skipIfNoCuda
class
TestAutogradLfilterCUDA
(
Autograd
,
common_utils
.
PytorchTestCase
):
dtype
=
torch
.
float64
device
=
torch
.
device
(
'cuda'
)
test/torchaudio_unittest/functional/autograd_impl.py
0 → 100644
View file @
2a3d52ff
import
torch
import
torchaudio.functional
as
F
from
torch.autograd
import
gradcheck
from
torchaudio_unittest
import
common_utils
class
Autograd
(
common_utils
.
TestBaseMixin
):
def
test_x_grad
(
self
):
torch
.
random
.
manual_seed
(
2434
)
x
=
torch
.
rand
(
2
,
4
,
256
*
2
,
dtype
=
self
.
dtype
,
device
=
self
.
device
)
a
=
torch
.
tensor
([
0.7
,
0.2
,
0.6
],
dtype
=
self
.
dtype
,
device
=
self
.
device
)
b
=
torch
.
tensor
([
0.4
,
0.2
,
0.9
],
dtype
=
self
.
dtype
,
device
=
self
.
device
)
x
.
requires_grad
=
True
assert
gradcheck
(
F
.
lfilter
,
(
x
,
a
,
b
),
eps
=
1e-10
)
def
test_a_grad
(
self
):
torch
.
random
.
manual_seed
(
2434
)
x
=
torch
.
rand
(
2
,
4
,
256
*
2
,
dtype
=
self
.
dtype
,
device
=
self
.
device
)
a
=
torch
.
tensor
([
0.7
,
0.2
,
0.6
],
dtype
=
self
.
dtype
,
device
=
self
.
device
)
b
=
torch
.
tensor
([
0.4
,
0.2
,
0.9
],
dtype
=
self
.
dtype
,
device
=
self
.
device
)
a
.
requires_grad
=
True
assert
gradcheck
(
F
.
lfilter
,
(
x
,
a
,
b
),
eps
=
1e-10
)
def
test_b_grad
(
self
):
torch
.
random
.
manual_seed
(
2434
)
x
=
torch
.
rand
(
2
,
4
,
256
*
2
,
dtype
=
self
.
dtype
,
device
=
self
.
device
)
a
=
torch
.
tensor
([
0.7
,
0.2
,
0.6
],
dtype
=
self
.
dtype
,
device
=
self
.
device
)
b
=
torch
.
tensor
([
0.4
,
0.2
,
0.9
],
dtype
=
self
.
dtype
,
device
=
self
.
device
)
b
.
requires_grad
=
True
assert
gradcheck
(
F
.
lfilter
,
(
x
,
a
,
b
),
eps
=
1e-10
)
def
test_all_grad
(
self
):
torch
.
random
.
manual_seed
(
2434
)
x
=
torch
.
rand
(
2
,
4
,
256
*
2
,
dtype
=
self
.
dtype
,
device
=
self
.
device
)
a
=
torch
.
tensor
([
0.7
,
0.2
,
0.6
],
dtype
=
self
.
dtype
,
device
=
self
.
device
)
b
=
torch
.
tensor
([
0.4
,
0.2
,
0.9
],
dtype
=
self
.
dtype
,
device
=
self
.
device
)
b
.
requires_grad
=
True
a
.
requires_grad
=
True
x
.
requires_grad
=
True
assert
gradcheck
(
F
.
lfilter
,
(
x
,
a
,
b
),
eps
=
1e-10
)
test/torchaudio_unittest/functional/functional_cpu_test.py
View file @
2a3d52ff
...
@@ -6,6 +6,7 @@ import torchaudio
...
@@ -6,6 +6,7 @@ import torchaudio
import
torchaudio.functional
as
F
import
torchaudio.functional
as
F
from
parameterized
import
parameterized
from
parameterized
import
parameterized
import
itertools
import
itertools
import
unittest
from
torchaudio_unittest
import
common_utils
from
torchaudio_unittest
import
common_utils
from
torchaudio_unittest.common_utils
import
(
from
torchaudio_unittest.common_utils
import
(
...
@@ -21,6 +22,10 @@ class TestLFilterFloat32(Lfilter, common_utils.PytorchTestCase):
...
@@ -21,6 +22,10 @@ class TestLFilterFloat32(Lfilter, common_utils.PytorchTestCase):
dtype
=
torch
.
float32
dtype
=
torch
.
float32
device
=
torch
.
device
(
'cpu'
)
device
=
torch
.
device
(
'cpu'
)
@
unittest
.
expectedFailure
def
test_9th_order_filter_stability
(
self
):
super
().
test_9th_order_filter_stability
()
class
TestLFilterFloat64
(
Lfilter
,
common_utils
.
PytorchTestCase
):
class
TestLFilterFloat64
(
Lfilter
,
common_utils
.
PytorchTestCase
):
dtype
=
torch
.
float64
dtype
=
torch
.
float64
...
...
test/torchaudio_unittest/functional/functional_cuda_test.py
View file @
2a3d52ff
import
torch
import
torch
import
unittest
from
torchaudio_unittest
import
common_utils
from
torchaudio_unittest
import
common_utils
from
.functional_impl
import
Lfilter
,
Spectrogram
from
.functional_impl
import
Lfilter
,
Spectrogram
...
@@ -9,6 +10,10 @@ class TestLFilterFloat32(Lfilter, common_utils.PytorchTestCase):
...
@@ -9,6 +10,10 @@ class TestLFilterFloat32(Lfilter, common_utils.PytorchTestCase):
dtype
=
torch
.
float32
dtype
=
torch
.
float32
device
=
torch
.
device
(
'cuda'
)
device
=
torch
.
device
(
'cuda'
)
@
unittest
.
expectedFailure
def
test_9th_order_filter_stability
(
self
):
super
().
test_9th_order_filter_stability
()
@
common_utils
.
skipIfNoCuda
@
common_utils
.
skipIfNoCuda
class
TestLFilterFloat64
(
Lfilter
,
common_utils
.
PytorchTestCase
):
class
TestLFilterFloat64
(
Lfilter
,
common_utils
.
PytorchTestCase
):
...
...
test/torchaudio_unittest/functional/functional_impl.py
View file @
2a3d52ff
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
import
torch
import
torch
import
torchaudio.functional
as
F
import
torchaudio.functional
as
F
from
parameterized
import
parameterized
from
parameterized
import
parameterized
from
scipy
import
signal
from
torchaudio_unittest
import
common_utils
from
torchaudio_unittest
import
common_utils
...
@@ -45,6 +46,28 @@ class Lfilter(common_utils.TestBaseMixin):
...
@@ -45,6 +46,28 @@ class Lfilter(common_utils.TestBaseMixin):
output_waveform
=
F
.
lfilter
(
waveform
,
a_coeffs
,
b_coeffs
)
output_waveform
=
F
.
lfilter
(
waveform
,
a_coeffs
,
b_coeffs
)
assert
shape
==
waveform
.
size
()
==
output_waveform
.
size
()
assert
shape
==
waveform
.
size
()
==
output_waveform
.
size
()
def
test_9th_order_filter_stability
(
self
):
"""
Validate the precision of lfilter against reference scipy implementation when using high order filter.
The reference implementation use cascaded second-order filters so is more numerically accurate.
"""
# create an impulse signal
x
=
torch
.
zeros
(
1024
,
dtype
=
self
.
dtype
,
device
=
self
.
device
)
x
[
0
]
=
1
# get target impulse response
sos
=
signal
.
butter
(
9
,
850
,
'hp'
,
fs
=
22050
,
output
=
'sos'
)
y
=
torch
.
from_numpy
(
signal
.
sosfilt
(
sos
,
x
.
cpu
().
numpy
())).
to
(
self
.
dtype
).
to
(
self
.
device
)
# get lfilter coefficients
b
,
a
=
signal
.
butter
(
9
,
850
,
'hp'
,
fs
=
22050
,
output
=
'ba'
)
b
,
a
=
torch
.
from_numpy
(
b
).
to
(
self
.
dtype
).
to
(
self
.
device
),
torch
.
from_numpy
(
a
).
to
(
self
.
dtype
).
to
(
self
.
device
)
# predict impulse response
yhat
=
F
.
lfilter
(
x
,
a
,
b
,
False
)
self
.
assertEqual
(
yhat
,
y
,
atol
=
1e-4
,
rtol
=
1e-5
)
class
Spectrogram
(
common_utils
.
TestBaseMixin
):
class
Spectrogram
(
common_utils
.
TestBaseMixin
):
@
parameterized
.
expand
([(
0.
,
),
(
1.
,
),
(
2.
,
),
(
3.
,
)])
@
parameterized
.
expand
([(
0.
,
),
(
1.
,
),
(
2.
,
),
(
3.
,
)])
...
...
torchaudio/csrc/lfilter.cpp
View file @
2a3d52ff
...
@@ -80,7 +80,7 @@ void lfilter_core_generic_loop(
...
@@ -80,7 +80,7 @@ void lfilter_core_generic_loop(
}
}
}
}
torch
::
Tensor
lfilter_core
(
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
>
lfilter_core
(
const
torch
::
Tensor
&
waveform
,
const
torch
::
Tensor
&
waveform
,
const
torch
::
Tensor
&
a_coeffs
,
const
torch
::
Tensor
&
a_coeffs
,
const
torch
::
Tensor
&
b_coeffs
)
{
const
torch
::
Tensor
&
b_coeffs
)
{
...
@@ -123,7 +123,127 @@ torch::Tensor lfilter_core(
...
@@ -123,7 +123,127 @@ torch::Tensor lfilter_core(
{
torch
::
indexing
::
Slice
(),
{
torch
::
indexing
::
Slice
(),
torch
::
indexing
::
Slice
(
n_order
-
1
,
torch
::
indexing
::
None
)});
torch
::
indexing
::
Slice
(
n_order
-
1
,
torch
::
indexing
::
None
)});
return
output
;
return
{
output
,
input_signal_windows
};
}
torch
::
Tensor
lfilter_simple
(
const
torch
::
Tensor
&
waveform
,
const
torch
::
Tensor
&
a_coeffs
,
const
torch
::
Tensor
&
b_coeffs
)
{
return
std
::
get
<
0
>
(
lfilter_core
(
waveform
,
a_coeffs
,
b_coeffs
));
}
class
DifferentiableLfilter
:
public
torch
::
autograd
::
Function
<
DifferentiableLfilter
>
{
public:
static
torch
::
Tensor
forward
(
torch
::
autograd
::
AutogradContext
*
ctx
,
const
torch
::
Tensor
&
waveform
,
const
torch
::
Tensor
&
a_coeffs
,
const
torch
::
Tensor
&
b_coeffs
)
{
at
::
AutoNonVariableTypeMode
g
;
auto
result
=
lfilter_core
(
waveform
,
a_coeffs
,
b_coeffs
);
ctx
->
save_for_backward
(
{
waveform
,
a_coeffs
,
b_coeffs
,
std
::
get
<
0
>
(
result
),
std
::
get
<
1
>
(
result
)});
return
std
::
get
<
0
>
(
result
);
}
static
torch
::
autograd
::
tensor_list
backward
(
torch
::
autograd
::
AutogradContext
*
ctx
,
torch
::
autograd
::
tensor_list
grad_outputs
)
{
auto
saved
=
ctx
->
get_saved_variables
();
auto
waveform
=
saved
[
0
];
auto
a_coeffs
=
saved
[
1
];
auto
b_coeffs
=
saved
[
2
];
auto
y
=
saved
[
3
];
auto
xh
=
saved
[
4
];
auto
device
=
waveform
.
device
();
auto
dtype
=
waveform
.
dtype
();
int64_t
n_channel
=
waveform
.
size
(
0
);
int64_t
n_sample
=
waveform
.
size
(
1
);
int64_t
n_order
=
a_coeffs
.
size
(
0
);
int64_t
n_sample_padded
=
n_sample
+
n_order
-
1
;
auto
a_coeff_flipped
=
a_coeffs
.
flip
(
0
).
contiguous
();
auto
b_coeff_flipped
=
b_coeffs
.
flip
(
0
).
contiguous
();
b_coeff_flipped
.
div_
(
a_coeffs
[
0
]);
a_coeff_flipped
.
div_
(
a_coeffs
[
0
]);
auto
dx
=
torch
::
Tensor
();
auto
da
=
torch
::
Tensor
();
auto
db
=
torch
::
Tensor
();
auto
dy
=
grad_outputs
[
0
];
at
::
AutoNonVariableTypeMode
g
;
namespace
F
=
torch
::
nn
::
functional
;
auto
options
=
torch
::
TensorOptions
().
dtype
(
dtype
).
device
(
device
);
if
(
a_coeffs
.
requires_grad
())
{
auto
dyda
=
torch
::
zeros
({
n_channel
,
n_sample_padded
},
options
);
if
(
device
.
is_cpu
())
{
cpu_lfilter_core_loop
(
-
y
,
a_coeff_flipped
,
dyda
);
}
else
{
lfilter_core_generic_loop
(
-
y
,
a_coeff_flipped
,
dyda
);
}
da
=
F
::
conv1d
(
dyda
.
unsqueeze
(
0
),
dy
.
unsqueeze
(
1
),
F
::
Conv1dFuncOptions
().
groups
(
n_channel
))
.
sum
(
1
)
.
squeeze
(
0
)
.
flip
(
0
);
da
.
div_
(
a_coeffs
[
0
]);
}
if
(
b_coeffs
.
requires_grad
()
||
waveform
.
requires_grad
())
{
auto
dxh
=
torch
::
zeros
({
n_channel
,
n_sample_padded
},
options
);
if
(
device
.
is_cpu
())
{
cpu_lfilter_core_loop
(
dy
.
flip
(
1
),
a_coeff_flipped
,
dxh
);
}
else
{
lfilter_core_generic_loop
(
dy
.
flip
(
1
),
a_coeff_flipped
,
dxh
);
}
dxh
=
dxh
.
index
(
{
torch
::
indexing
::
Slice
(),
torch
::
indexing
::
Slice
(
n_order
-
1
,
torch
::
indexing
::
None
)})
.
flip
(
1
);
if
(
waveform
.
requires_grad
())
{
dx
=
F
::
conv1d
(
F
::
pad
(
dxh
.
unsqueeze
(
1
),
F
::
PadFuncOptions
({
0
,
n_order
-
1
})),
b_coeffs
.
view
({
1
,
1
,
n_order
}))
.
squeeze
(
1
);
dx
.
div_
(
a_coeffs
[
0
]);
}
if
(
b_coeffs
.
requires_grad
())
{
db
=
F
::
conv1d
(
F
::
pad
(
waveform
.
unsqueeze
(
0
),
F
::
PadFuncOptions
({
n_order
-
1
,
0
})),
dxh
.
unsqueeze
(
1
),
F
::
Conv1dFuncOptions
().
groups
(
n_channel
))
.
sum
(
1
)
.
squeeze
(
0
)
.
flip
(
0
);
db
.
div_
(
a_coeffs
[
0
]);
}
}
return
{
dx
,
da
,
db
};
}
};
torch
::
Tensor
lfilter_autograd
(
const
torch
::
Tensor
&
waveform
,
const
torch
::
Tensor
&
a_coeffs
,
const
torch
::
Tensor
&
b_coeffs
)
{
return
DifferentiableLfilter
::
apply
(
waveform
,
a_coeffs
,
b_coeffs
);
}
}
}
// namespace
}
// namespace
...
@@ -139,6 +259,10 @@ TORCH_LIBRARY(torchaudio, m) {
...
@@ -139,6 +259,10 @@ TORCH_LIBRARY(torchaudio, m) {
"torchaudio::_lfilter(Tensor waveform, Tensor a_coeffs, Tensor b_coeffs) -> Tensor"
);
"torchaudio::_lfilter(Tensor waveform, Tensor a_coeffs, Tensor b_coeffs) -> Tensor"
);
}
}
TORCH_LIBRARY_IMPL
(
torchaudio
,
Math
,
m
)
{
TORCH_LIBRARY_IMPL
(
torchaudio
,
DefaultBackend
,
m
)
{
m
.
impl
(
"torchaudio::_lfilter"
,
lfilter_core
);
m
.
impl
(
"torchaudio::_lfilter"
,
lfilter_simple
);
}
TORCH_LIBRARY_IMPL
(
torchaudio
,
Autograd
,
m
)
{
m
.
impl
(
"torchaudio::_lfilter"
,
lfilter_autograd
);
}
}
torchaudio/functional/filtering.py
View file @
2a3d52ff
...
@@ -884,6 +884,10 @@ def lfilter(
...
@@ -884,6 +884,10 @@ def lfilter(
)
->
Tensor
:
)
->
Tensor
:
r
"""Perform an IIR filter by evaluating difference equation.
r
"""Perform an IIR filter by evaluating difference equation.
Note:
To avoid numerical problems, small filter order is prefered.
Using double precision could also minimize numerical precision errors.
Args:
Args:
waveform (Tensor): audio waveform of dimension of ``(..., time)``. Must be normalized to -1 to 1.
waveform (Tensor): audio waveform of dimension of ``(..., time)``. Must be normalized to -1 to 1.
a_coeffs (Tensor): denominator coefficients of difference equation of dimension of ``(n_order + 1)``.
a_coeffs (Tensor): denominator coefficients of difference equation of dimension of ``(n_order + 1)``.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment