OpenDAS / Torchaudio / Commits / 723e9a52

Commit 723e9a52 (unverified)
Support higher order derivatives for `F.lfilter` (#1441)
Authored May 07, 2021 by Chin-Yun Yu; committed by GitHub May 06, 2021.
Parent: 5417e4fb

Showing 2 changed files with 110 additions and 124 deletions:
  test/torchaudio_unittest/functional/autograd_impl.py  (+2, -1)
  torchaudio/csrc/lfilter.cpp                            (+108, -123)
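In Python terms, the change means `torchaudio.functional.lfilter` can now be differentiated twice, so `gradgradcheck` passes alongside `gradcheck`. A minimal sketch of that usage (the coefficient values, tensor shapes, and the `clamp=False` flag are illustrative, not taken from this commit):

import torch
import torchaudio.functional as F
from torch.autograd import gradcheck, gradgradcheck

torch.random.manual_seed(2434)

# Double precision keeps the numerical Jacobian checks stable.
x = torch.randn(2, 16, dtype=torch.float64, requires_grad=True)
a = torch.tensor([1.0, -0.2, 0.1], dtype=torch.float64, requires_grad=True)  # illustrative IIR coefficients
b = torch.tensor([0.5, 0.3, 0.2], dtype=torch.float64, requires_grad=True)   # illustrative FIR coefficients

def transform(x, a, b):
    return F.lfilter(x, a, b, clamp=False)

assert gradcheck(transform, (x, a, b))      # first-order gradients (already supported)
assert gradgradcheck(transform, (x, a, b))  # second-order gradients enabled by this commit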
test/torchaudio_unittest/functional/autograd_impl.py

@@ -3,7 +3,7 @@ import torch
 from parameterized import parameterized
 from torch import Tensor
 import torchaudio.functional as F
-from torch.autograd import gradcheck
+from torch.autograd import gradcheck, gradgradcheck
 from torchaudio_unittest.common_utils import (
     TestBaseMixin,
     get_whitenoise,
@@ -26,6 +26,7 @@ class Autograd(TestBaseMixin):
                 i.requires_grad = True
             inputs_.append(i)
         assert gradcheck(transform, inputs_)
+        assert gradgradcheck(transform, inputs_)

     def test_lfilter_x(self):
         torch.random.manual_seed(2434)
torchaudio/csrc/lfilter.cpp

@@ -80,170 +80,159 @@ void lfilter_core_generic_loop(
   }
 }

-std::tuple<at::Tensor, at::Tensor> lfilter_core(
-    const torch::Tensor& waveform,
-    const torch::Tensor& a_coeffs,
-    const torch::Tensor& b_coeffs) {
-  TORCH_CHECK(waveform.device() == a_coeffs.device());
-  TORCH_CHECK(b_coeffs.device() == a_coeffs.device());
-  TORCH_CHECK(a_coeffs.size(0) == b_coeffs.size(0));
-
-  TORCH_INTERNAL_ASSERT(waveform.sizes().size() == 2);
-
-  auto device = waveform.device();
-  int64_t n_order = a_coeffs.size(0);
-
-  TORCH_INTERNAL_ASSERT(n_order > 0);
-
-  namespace F = torch::nn::functional;
-
-  auto padded_waveform = F::pad(waveform, F::PadFuncOptions({n_order - 1, 0}));
-  auto padded_output_waveform = torch::zeros_like(padded_waveform);
-
-  auto a_coeff_flipped = a_coeffs.flip(0).contiguous();
-  auto b_coeff_flipped = b_coeffs.flip(0).contiguous();
-
-  auto input_signal_windows =
-      F::conv1d(
-          padded_waveform.unsqueeze(1),
-          b_coeff_flipped.view({1, 1, n_order}))
-          .squeeze(1);
-
-  input_signal_windows.div_(a_coeffs[0]);
-  a_coeff_flipped.div_(a_coeffs[0]);
-
-  if (device.is_cpu()) {
-    cpu_lfilter_core_loop(
-        input_signal_windows, a_coeff_flipped, padded_output_waveform);
-  } else {
-    lfilter_core_generic_loop(
-        input_signal_windows, a_coeff_flipped, padded_output_waveform);
-  }
-
-  auto output = padded_output_waveform.index(
-      {torch::indexing::Slice(),
-       torch::indexing::Slice(n_order - 1, torch::indexing::None)});
-
-  return {output, input_signal_windows};
-}
-
-torch::Tensor lfilter_simple(
-    const torch::Tensor& waveform,
-    const torch::Tensor& a_coeffs,
-    const torch::Tensor& b_coeffs) {
-  return std::get<0>(lfilter_core(waveform, a_coeffs, b_coeffs));
-}
-
-class DifferentiableLfilter
-    : public torch::autograd::Function<DifferentiableLfilter> {
- public:
-  static torch::Tensor forward(
-      torch::autograd::AutogradContext* ctx,
-      const torch::Tensor& waveform,
-      const torch::Tensor& a_coeffs,
-      const torch::Tensor& b_coeffs) {
-    at::AutoNonVariableTypeMode g;
-    auto result = lfilter_core(waveform, a_coeffs, b_coeffs);
-    ctx->save_for_backward(
-        {waveform,
-         a_coeffs,
-         b_coeffs,
-         std::get<0>(result),
-         std::get<1>(result)});
-    return std::get<0>(result);
-  }
-
-  static torch::autograd::tensor_list backward(
-      torch::autograd::AutogradContext* ctx,
-      torch::autograd::tensor_list grad_outputs) {
-    auto saved = ctx->get_saved_variables();
-    auto waveform = saved[0];
-    auto a_coeffs = saved[1];
-    auto b_coeffs = saved[2];
-    auto y = saved[3];
-    auto xh = saved[4];
-
-    auto device = waveform.device();
-    auto dtype = waveform.dtype();
-    int64_t n_channel = waveform.size(0);
-    int64_t n_sample = waveform.size(1);
-    int64_t n_order = a_coeffs.size(0);
-    int64_t n_sample_padded = n_sample + n_order - 1;
-
-    auto a_coeff_flipped = a_coeffs.flip(0).contiguous();
-    auto b_coeff_flipped = b_coeffs.flip(0).contiguous();
-    b_coeff_flipped.div_(a_coeffs[0]);
-    a_coeff_flipped.div_(a_coeffs[0]);
-
-    auto dx = torch::Tensor();
-    auto da = torch::Tensor();
-    auto db = torch::Tensor();
-    auto dy = grad_outputs[0];
-
-    at::AutoNonVariableTypeMode g;
-    namespace F = torch::nn::functional;
-    auto options = torch::TensorOptions().dtype(dtype).device(device);
-
-    if (a_coeffs.requires_grad()) {
-      auto dyda = torch::zeros({n_channel, n_sample_padded}, options);
-      if (device.is_cpu()) {
-        cpu_lfilter_core_loop(-y, a_coeff_flipped, dyda);
-      } else {
-        lfilter_core_generic_loop(-y, a_coeff_flipped, dyda);
-      }
-      da = F::conv1d(
-               dyda.unsqueeze(0),
-               dy.unsqueeze(1),
-               F::Conv1dFuncOptions().groups(n_channel))
-               .sum(1)
-               .squeeze(0)
-               .flip(0);
-      da.div_(a_coeffs[0]);
-    }
-
-    if (b_coeffs.requires_grad() || waveform.requires_grad()) {
-      auto dxh = torch::zeros({n_channel, n_sample_padded}, options);
-      if (device.is_cpu()) {
-        cpu_lfilter_core_loop(dy.flip(1), a_coeff_flipped, dxh);
-      } else {
-        lfilter_core_generic_loop(dy.flip(1), a_coeff_flipped, dxh);
-      }
-      dxh = dxh.index(
-                   {torch::indexing::Slice(),
-                    torch::indexing::Slice(n_order - 1, torch::indexing::None)})
-                .flip(1);
-      if (waveform.requires_grad()) {
-        dx = F::conv1d(
-                 F::pad(dxh.unsqueeze(1), F::PadFuncOptions({0, n_order - 1})),
-                 b_coeffs.view({1, 1, n_order}))
-                 .squeeze(1);
-        dx.div_(a_coeffs[0]);
-      }
-      if (b_coeffs.requires_grad()) {
-        db = F::conv1d(
-                 F::pad(waveform.unsqueeze(0), F::PadFuncOptions({n_order - 1, 0})),
-                 dxh.unsqueeze(1),
-                 F::Conv1dFuncOptions().groups(n_channel))
-                 .sum(1)
-                 .squeeze(0)
-                 .flip(0);
-        db.div_(a_coeffs[0]);
-      }
-    }
-    return {dx, da, db};
-  }
-};
-
-torch::Tensor lfilter_autograd(
-    const torch::Tensor& waveform,
-    const torch::Tensor& a_coeffs,
-    const torch::Tensor& b_coeffs) {
-  return DifferentiableLfilter::apply(waveform, a_coeffs, b_coeffs);
-}
+class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {
+ public:
+  static torch::Tensor forward(
+      torch::autograd::AutogradContext* ctx,
+      const torch::Tensor& waveform,
+      const torch::Tensor& a_coeffs_normalized) {
+    auto device = waveform.device();
+    auto dtype = waveform.dtype();
+    int64_t n_channel = waveform.size(0);
+    int64_t n_sample = waveform.size(1);
+    int64_t n_order = a_coeffs_normalized.size(0);
+    int64_t n_sample_padded = n_sample + n_order - 1;
+
+    auto a_coeff_flipped = a_coeffs_normalized.flip(0).contiguous();
+
+    auto options = torch::TensorOptions().dtype(dtype).device(device);
+    auto padded_output_waveform =
+        torch::zeros({n_channel, n_sample_padded}, options);
+
+    if (device.is_cpu()) {
+      cpu_lfilter_core_loop(waveform, a_coeff_flipped, padded_output_waveform);
+    } else {
+      lfilter_core_generic_loop(
+          waveform, a_coeff_flipped, padded_output_waveform);
+    }
+
+    auto output = padded_output_waveform.index(
+        {torch::indexing::Slice(),
+         torch::indexing::Slice(n_order - 1, torch::indexing::None)});
+
+    ctx->save_for_backward({waveform, a_coeffs_normalized, output});
+    return output;
+  }
+
+  static torch::autograd::tensor_list backward(
+      torch::autograd::AutogradContext* ctx,
+      torch::autograd::tensor_list grad_outputs) {
+    auto saved = ctx->get_saved_variables();
+    auto x = saved[0];
+    auto a_coeffs_normalized = saved[1];
+    auto y = saved[2];
+
+    int64_t n_channel = x.size(0);
+    int64_t n_order = a_coeffs_normalized.size(0);
+
+    auto dx = torch::Tensor();
+    auto da = torch::Tensor();
+    auto dy = grad_outputs[0];
+
+    namespace F = torch::nn::functional;
+
+    if (a_coeffs_normalized.requires_grad()) {
+      auto dyda = F::pad(
+          DifferentiableIIR::apply(-y, a_coeffs_normalized),
+          F::PadFuncOptions({n_order - 1, 0}));
+
+      da = F::conv1d(
+               dyda.unsqueeze(0),
+               dy.unsqueeze(1),
+               F::Conv1dFuncOptions().groups(n_channel))
+               .sum(1)
+               .squeeze(0)
+               .flip(0);
+    }
+
+    if (x.requires_grad()) {
+      dx = DifferentiableIIR::apply(dy.flip(1), a_coeffs_normalized).flip(1);
+    }
+    return {dx, da};
+  }
+};
+
+class DifferentiableFIR : public torch::autograd::Function<DifferentiableFIR> {
+ public:
+  static torch::Tensor forward(
+      torch::autograd::AutogradContext* ctx,
+      const torch::Tensor& waveform,
+      const torch::Tensor& b_coeffs) {
+    int64_t n_order = b_coeffs.size(0);
+
+    namespace F = torch::nn::functional;
+    auto b_coeff_flipped = b_coeffs.flip(0).contiguous();
+    auto padded_waveform =
+        F::pad(waveform, F::PadFuncOptions({n_order - 1, 0}));
+
+    auto output =
+        F::conv1d(
+            padded_waveform.unsqueeze(1),
+            b_coeff_flipped.view({1, 1, n_order}))
+            .squeeze(1);
+
+    ctx->save_for_backward({waveform, b_coeffs, output});
+    return output;
+  }
+
+  static torch::autograd::tensor_list backward(
+      torch::autograd::AutogradContext* ctx,
+      torch::autograd::tensor_list grad_outputs) {
+    auto saved = ctx->get_saved_variables();
+    auto x = saved[0];
+    auto b_coeffs = saved[1];
+    auto y = saved[2];
+
+    int64_t n_channel = x.size(0);
+    int64_t n_order = b_coeffs.size(0);
+
+    auto dx = torch::Tensor();
+    auto db = torch::Tensor();
+    auto dy = grad_outputs[0];
+
+    namespace F = torch::nn::functional;
+
+    if (b_coeffs.requires_grad()) {
+      db = F::conv1d(
+               F::pad(x.unsqueeze(0), F::PadFuncOptions({n_order - 1, 0})),
+               dy.unsqueeze(1),
+               F::Conv1dFuncOptions().groups(n_channel))
+               .sum(1)
+               .squeeze(0)
+               .flip(0);
+    }
+
+    if (x.requires_grad()) {
+      dx = F::conv1d(
+               F::pad(dy.unsqueeze(1), F::PadFuncOptions({0, n_order - 1})),
+               b_coeffs.view({1, 1, n_order}))
+               .squeeze(1);
+    }
+    return {dx, db};
+  }
+};
+
+torch::Tensor lfilter_core(
+    const torch::Tensor& waveform,
+    const torch::Tensor& a_coeffs,
+    const torch::Tensor& b_coeffs) {
+  TORCH_CHECK(waveform.device() == a_coeffs.device());
+  TORCH_CHECK(b_coeffs.device() == a_coeffs.device());
+  TORCH_CHECK(a_coeffs.size(0) == b_coeffs.size(0));
+
+  TORCH_INTERNAL_ASSERT(waveform.sizes().size() == 2);
+
+  int64_t n_order = b_coeffs.size(0);
+
+  TORCH_INTERNAL_ASSERT(n_order > 0);
+
+  auto filtered_waveform =
+      DifferentiableFIR::apply(waveform, b_coeffs / a_coeffs[0]);
+  auto output =
+      DifferentiableIIR::apply(filtered_waveform, a_coeffs / a_coeffs[0]);
+  return output;
+}
 } // namespace

@@ -259,10 +248,6 @@ TORCH_LIBRARY(torchaudio, m) {
       "torchaudio::_lfilter(Tensor waveform, Tensor a_coeffs, Tensor b_coeffs) -> Tensor");
 }

-TORCH_LIBRARY_IMPL(torchaudio, DefaultBackend, m) {
-  m.impl("torchaudio::_lfilter", lfilter_simple);
-}
-
-TORCH_LIBRARY_IMPL(torchaudio, Autograd, m) {
-  m.impl("torchaudio::_lfilter", lfilter_autograd);
-}
+TORCH_LIBRARY_IMPL(torchaudio, CompositeImplicitAutograd, m) {
+  m.impl("torchaudio::_lfilter", lfilter_core);
+}
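The rewrite splits `_lfilter` into a differentiable FIR stage (`DifferentiableFIR`, coefficients `b / a[0]`) followed by a differentiable all-pole IIR stage (`DifferentiableIIR`, coefficients `a / a[0]`), with each backward expressed through `conv1d` and the IIR loop itself, which is what makes double backward possible. The key identity used in `DifferentiableIIR::backward` is that the gradient with respect to the input of a causal all-pole filter is the upstream gradient, time-reversed, filtered with the same coefficients, and reversed again. A rough Python check of that identity against autograd (coefficients, sizes, and the `clamp=False` flag are made up for illustration):

import torch
import torchaudio.functional as F

torch.random.manual_seed(0)

a = torch.tensor([1.0, -0.5, 0.25], dtype=torch.float64)    # illustrative all-pole coefficients
iir_b = torch.tensor([1.0, 0.0, 0.0], dtype=torch.float64)  # b = [1, 0, 0] turns lfilter into a pure IIR stage

x = torch.randn(1, 32, dtype=torch.float64, requires_grad=True)
y = F.lfilter(x, a, iir_b, clamp=False)

dy = torch.randn_like(y)
(dx_autograd,) = torch.autograd.grad(y, x, grad_outputs=dy)

# Gradient w.r.t. the input: flip the upstream gradient in time, run it through
# the same all-pole filter, then flip back (mirrors DifferentiableIIR::backward).
with torch.no_grad():
    dx_manual = F.lfilter(dy.flip(1), a, iir_b, clamp=False).flip(1)

print(torch.allclose(dx_autograd, dx_manual))  # expected: True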