Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
8094751f
Unverified
Commit
8094751f
authored
Aug 11, 2021
by
Chin-Yun Yu
Committed by
GitHub
Aug 10, 2021
Browse files
Add batch support to lfilter (#1638)
parent
15bc554f
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
36 additions
and
2 deletions
+36
-2
test/torchaudio_unittest/functional/autograd_impl.py
test/torchaudio_unittest/functional/autograd_impl.py
+10
-0
test/torchaudio_unittest/functional/batch_consistency_test.py
.../torchaudio_unittest/functional/batch_consistency_test.py
+15
-0
test/torchaudio_unittest/functional/functional_impl.py
test/torchaudio_unittest/functional/functional_impl.py
+1
-1
torchaudio/functional/filtering.py
torchaudio/functional/filtering.py
+10
-1
No files found.
test/torchaudio_unittest/functional/autograd_impl.py
View file @
8094751f
from
typing
import
Callable
,
Tuple
from
functools
import
partial
import
torch
from
parameterized
import
parameterized
from
torch
import
Tensor
...
...
@@ -62,6 +63,15 @@ class Autograd(TestBaseMixin):
def
test_lfilter_filterbanks
(
self
):
torch
.
random
.
manual_seed
(
2434
)
x
=
get_whitenoise
(
sample_rate
=
22050
,
duration
=
0.01
,
n_channels
=
3
)
a
=
torch
.
tensor
([[
0.7
,
0.2
,
0.6
],
[
0.8
,
0.2
,
0.9
]])
b
=
torch
.
tensor
([[
0.4
,
0.2
,
0.9
],
[
0.7
,
0.2
,
0.6
]])
self
.
assert_grad
(
partial
(
F
.
lfilter
,
batching
=
False
),
(
x
,
a
,
b
))
def
test_lfilter_batching
(
self
):
torch
.
random
.
manual_seed
(
2434
)
x
=
get_whitenoise
(
sample_rate
=
22050
,
duration
=
0.01
,
n_channels
=
2
)
a
=
torch
.
tensor
([[
0.7
,
0.2
,
0.6
],
[
0.8
,
0.2
,
0.9
]])
b
=
torch
.
tensor
([[
0.4
,
0.2
,
0.9
],
...
...
test/torchaudio_unittest/functional/batch_consistency_test.py
View file @
8094751f
...
...
@@ -217,3 +217,18 @@ class TestFunctional(common_utils.TorchaudioTestCase):
batch
=
waveform
.
view
(
self
.
batch_size
,
n_channels
,
waveform
.
size
(
-
1
))
self
.
assert_batch_consistency
(
F
.
compute_kaldi_pitch
,
batch
,
sample_rate
=
sample_rate
)
def
test_lfilter
(
self
):
signal_length
=
2048
torch
.
manual_seed
(
2434
)
x
=
torch
.
randn
(
self
.
batch_size
,
signal_length
)
a
=
torch
.
rand
(
self
.
batch_size
,
3
)
b
=
torch
.
rand
(
self
.
batch_size
,
3
)
batchwise_output
=
F
.
lfilter
(
x
,
a
,
b
,
batching
=
True
)
itemwise_output
=
torch
.
stack
([
F
.
lfilter
(
x
[
i
],
a
[
i
],
b
[
i
])
for
i
in
range
(
self
.
batch_size
)
])
self
.
assertEqual
(
batchwise_output
,
itemwise_output
)
test/torchaudio_unittest/functional/functional_impl.py
View file @
8094751f
...
...
@@ -80,7 +80,7 @@ class Functional(TestBaseMixin):
waveform
=
torch
.
rand
(
*
input_shape
,
dtype
=
self
.
dtype
,
device
=
self
.
device
)
b_coeffs
=
torch
.
rand
(
*
coeff_shape
,
dtype
=
self
.
dtype
,
device
=
self
.
device
)
a_coeffs
=
torch
.
rand
(
*
coeff_shape
,
dtype
=
self
.
dtype
,
device
=
self
.
device
)
output_waveform
=
F
.
lfilter
(
waveform
,
a_coeffs
,
b_coeffs
)
output_waveform
=
F
.
lfilter
(
waveform
,
a_coeffs
,
b_coeffs
,
batching
=
False
)
assert
input_shape
==
waveform
.
size
()
assert
target_shape
==
output_waveform
.
size
()
...
...
torchaudio/functional/filtering.py
View file @
8094751f
...
...
@@ -930,6 +930,7 @@ def lfilter(
a_coeffs
:
Tensor
,
b_coeffs
:
Tensor
,
clamp
:
bool
=
True
,
batching
:
bool
=
True
)
->
Tensor
:
r
"""Perform an IIR filter by evaluating difference equation.
...
...
@@ -948,6 +949,10 @@ def lfilter(
Lower delays coefficients are first, e.g. ``[b0, b1, b2, ...]``.
Must be same size as a_coeffs (pad with 0's as necessary).
clamp (bool, optional): If ``True``, clamp the output signal to be in the range [-1, 1] (Default: ``True``)
batching (bool, optional): Activate when coefficients are in 2D. If ``True``, then waveform should be at least
2D, and the size of second axis from last should equals to ``num_filters``.
The output can be expressed as ``output[..., i, :] = lfilter(waveform[..., i, :],
a_coeffs[i], b_coeffs[i], clamp=clamp, batching=False)``. (Default: ``True``)
Returns:
Tensor: Waveform with dimension of either ``(..., num_filters, time)`` if ``a_coeffs`` and ``b_coeffs``
...
...
@@ -957,7 +962,11 @@ def lfilter(
assert
a_coeffs
.
ndim
<=
2
if
a_coeffs
.
ndim
>
1
:
waveform
=
torch
.
stack
([
waveform
]
*
a_coeffs
.
shape
[
0
],
-
2
)
if
batching
:
assert
waveform
.
ndim
>
1
assert
waveform
.
shape
[
-
2
]
==
a_coeffs
.
shape
[
0
]
else
:
waveform
=
torch
.
stack
([
waveform
]
*
a_coeffs
.
shape
[
0
],
-
2
)
else
:
a_coeffs
=
a_coeffs
.
unsqueeze
(
0
)
b_coeffs
=
b_coeffs
.
unsqueeze
(
0
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment