Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
ac97ad82
Unverified
Commit
ac97ad82
authored
Sep 20, 2021
by
nateanl
Committed by
GitHub
Sep 20, 2021
Browse files
Move MVDR and PSD modules to transforms (#1771)
parent
88ca1e05
Changes
21
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
235 additions
and
729 deletions
+235
-729
docs/source/refs.bib
docs/source/refs.bib
+39
-1
docs/source/transforms.rst
docs/source/transforms.rst
+17
-0
examples/beamforming/mvdr.py
examples/beamforming/mvdr.py
+0
-446
test/torchaudio_unittest/common_utils/psd_utils.py
test/torchaudio_unittest/common_utils/psd_utils.py
+27
-0
test/torchaudio_unittest/example/beamforming/__init__.py
test/torchaudio_unittest/example/beamforming/__init__.py
+0
-0
test/torchaudio_unittest/example/beamforming/autograd_cpu_test.py
...chaudio_unittest/example/beamforming/autograd_cpu_test.py
+0
-10
test/torchaudio_unittest/example/beamforming/autograd_cuda_test.py
...haudio_unittest/example/beamforming/autograd_cuda_test.py
+0
-15
test/torchaudio_unittest/example/beamforming/autograd_test_impl.py
...haudio_unittest/example/beamforming/autograd_test_impl.py
+0
-70
test/torchaudio_unittest/example/beamforming/batch_consistency_test.py
...io_unittest/example/beamforming/batch_consistency_test.py
+0
-59
test/torchaudio_unittest/example/beamforming/torchscript_consistency_cpu_test.py
...t/example/beamforming/torchscript_consistency_cpu_test.py
+0
-14
test/torchaudio_unittest/example/beamforming/torchscript_consistency_cuda_test.py
.../example/beamforming/torchscript_consistency_cuda_test.py
+0
-16
test/torchaudio_unittest/example/beamforming/torchscript_consistency_impl.py
...ttest/example/beamforming/torchscript_consistency_impl.py
+0
-57
test/torchaudio_unittest/example/beamforming/transforms_cpu_test.py
...audio_unittest/example/beamforming/transforms_cpu_test.py
+0
-14
test/torchaudio_unittest/example/beamforming/transforms_cuda_test.py
...udio_unittest/example/beamforming/transforms_cuda_test.py
+0
-19
test/torchaudio_unittest/transforms/autograd_test_impl.py
test/torchaudio_unittest/transforms/autograd_test_impl.py
+36
-0
test/torchaudio_unittest/transforms/batch_consistency_test.py
.../torchaudio_unittest/transforms/batch_consistency_test.py
+51
-0
test/torchaudio_unittest/transforms/torchscript_consistency_cpu_test.py
...o_unittest/transforms/torchscript_consistency_cpu_test.py
+2
-2
test/torchaudio_unittest/transforms/torchscript_consistency_cuda_test.py
..._unittest/transforms/torchscript_consistency_cuda_test.py
+2
-2
test/torchaudio_unittest/transforms/torchscript_consistency_impl.py
...audio_unittest/transforms/torchscript_consistency_impl.py
+37
-4
test/torchaudio_unittest/transforms/transforms_test_impl.py
test/torchaudio_unittest/transforms/transforms_test_impl.py
+24
-0
No files found.
docs/source/refs.bib
View file @
ac97ad82
...
...
@@ -95,4 +95,42 @@
pages
=
{4779--4783}
,
year
=
{2018}
,
organization
=
{IEEE}
}
\ No newline at end of file
}
@inproceedings
{
souden2009optimal
,
title
=
{On optimal frequency-domain multichannel linear filtering for noise reduction}
,
author
=
{Souden, Mehrez and Benesty, Jacob and Affes, Sofiene}
,
booktitle
=
{IEEE Transactions on audio, speech, and language processing}
,
volume
=
{18}
,
number
=
{2}
,
pages
=
{260--276}
,
year
=
{2009}
,
publisher
=
{IEEE}
}
@inproceedings
{
higuchi2016robust
,
title
=
{Robust MVDR beamforming using time-frequency masks for online/offline ASR in noise}
,
author
=
{Higuchi, Takuya and Ito, Nobutaka and Yoshioka, Takuya and Nakatani, Tomohiro}
,
booktitle
=
{2016 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)}
,
pages
=
{5210--5214}
,
year
=
{2016}
,
organization
=
{IEEE}
}
@article
{
mises1929praktische
,
title
=
{Praktische Verfahren der Gleichungsaufl{\"o}sung.}
,
author
=
{Mises, RV and Pollaczek-Geiringer, Hilda}
,
journal
=
{ZAMM-Journal of Applied Mathematics and Mechanics/Zeitschrift f{\"u}r Angewandte Mathematik und Mechanik}
,
volume
=
{9}
,
number
=
{1}
,
pages
=
{58--77}
,
year
=
{1929}
,
publisher
=
{Wiley Online Library}
}
@article
{
higuchi2017online
,
title
=
{Online MVDR beamformer based on complex Gaussian mixture model with spatial prior for noise robust ASR}
,
author
=
{Higuchi, Takuya and Ito, Nobutaka and Araki, Shoko and Yoshioka, Takuya and Delcroix, Marc and Nakatani, Tomohiro}
,
journal
=
{IEEE/ACM Transactions on Audio, Speech, and Language Processing}
,
volume
=
{25}
,
number
=
{4}
,
pages
=
{780--793}
,
year
=
{2017}
,
publisher
=
{IEEE}
}
docs/source/transforms.rst
View file @
ac97ad82
...
...
@@ -188,6 +188,23 @@ Transforms are common audio transforms. They can be chained together using :clas
.. automethod:: forward
:hidden:`Multi-channel`
~~~~~~~~~~~~~~~~~~~~~~~
:hidden:`PSD`
-------------
.. autoclass:: PSD
.. automethod:: forward
:hidden:`MVDR`
--------------
.. autoclass:: MVDR
.. automethod:: forward
References
~~~~~~~~~~
...
...
examples/beamforming/mvdr.py
deleted
100644 → 0
View file @
88ca1e05
"""Implementation of MVDR Beamforming Module
Based on https://github.com/espnet/espnet/blob/master/espnet2/enh/layers/beamformer.py
We provide three solutions of MVDR beamforming. One is based on reference channel selection:
Souden, Mehrez, Jacob Benesty, and Sofiene Affes.
"On optimal frequency-domain multichannel linear filtering for noise reduction."
IEEE Transactions on audio, speech, and language processing 18.2 (2009): 260-276.
The other two solutions are based on the steering vector. We apply either eigenvalue decomposition
or the power method to get the steering vector from the PSD matrices.
For eigenvalue decomposistion method, please refer:
Higuchi, Takuya, et al. "Robust MVDR beamforming using time-frequency masks for online/offline ASR in noise."
2016 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP). IEEE, 2016.
For power method, please refer:
Mises, R. V., and Hilda Pollaczek‐Geiringer.
"Praktische Verfahren der Gleichungsauflösung."
ZAMM‐Journal of Applied Mathematics and Mechanics/Zeitschrift für Angewandte Mathematik und Mechanik 9.1 (1929): 58-77.
For online streaming audio, we provide a recursive method to update PSD matrices based on:
Higuchi, Takuya, et al.
"Online MVDR beamformer based on complex Gaussian mixture model with spatial prior for noise robust ASR."
IEEE/ACM Transactions on Audio, Speech, and Language Processing 25.4 (2017): 780-793.
"""
from
typing
import
Optional
import
torch
def
mat_trace
(
input
:
torch
.
Tensor
,
dim1
:
int
=
-
1
,
dim2
:
int
=
-
2
)
->
torch
.
Tensor
:
r
"""Compute the trace of a Tensor along ``dim1`` and ``dim2`` dimensions.
Args:
input (torch.Tensor): Tensor of dimension (..., channel, channel)
dim1 (int, optional): the first dimension of the diagonal matrix
(Default: -1)
dim2 (int, optional): the second dimension of the diagonal matrix
(Default: -2)
Returns:
torch.Tensor: trace of the input Tensor
"""
assert
input
.
ndim
>=
2
,
"The dimension of the tensor must be at least 2."
assert
input
.
shape
[
dim1
]
==
input
.
shape
[
dim2
],
\
"The size of ``dim1`` and ``dim2`` must be the same."
input
=
torch
.
diagonal
(
input
,
0
,
dim1
=
dim1
,
dim2
=
dim2
)
return
input
.
sum
(
dim
=-
1
)
class
PSD
(
torch
.
nn
.
Module
):
r
"""Compute cross-channel power spectral density (PSD) matrix.
Args:
multi_mask (bool, optional): whether to use multi-channel Time-Frequency masks (Default: ``False``)
normalize (bool, optional): whether normalize the mask along the time dimension
eps (float, optional): a value added to the denominator in mask normalization. Default: 1e-15
"""
def
__init__
(
self
,
multi_mask
:
bool
=
False
,
normalize
:
bool
=
True
,
eps
:
float
=
1e-15
):
super
().
__init__
()
self
.
multi_mask
=
multi_mask
self
.
normalize
=
normalize
self
.
eps
=
eps
def
forward
(
self
,
X
:
torch
.
Tensor
,
mask
:
Optional
[
torch
.
Tensor
]
=
None
):
"""
Args:
X (torch.Tensor): multi-channel complex-valued STFT matrix.
Tensor of dimension (..., channel, freq, time)
mask (torch.Tensor or None, optional): Time-Frequency mask for normalization.
Tensor of dimension (..., freq, time) if multi_mask is ``False`` or
of dimension (..., channel, freq, time) if multi_mask is ``True``
Returns:
torch.Tensor: PSD matrix of the input STFT matrix.
Tensor of dimension (..., freq, channel, channel)
"""
# outer product:
# (..., ch_1, freq, time) x (..., ch_2, freq, time) -> (..., time, ch_1, ch_2)
psd_X
=
torch
.
einsum
(
"...cft,...eft->...ftce"
,
[
X
,
X
.
conj
()])
if
mask
is
None
:
psd
=
psd_X
else
:
if
self
.
multi_mask
:
# Averaging mask along channel dimension
mask
=
mask
.
mean
(
dim
=-
3
)
# (..., freq, time)
# Normalized mask along time dimension:
if
self
.
normalize
:
mask
=
mask
/
(
mask
.
sum
(
dim
=-
1
,
keepdim
=
True
)
+
self
.
eps
)
psd
=
psd_X
*
mask
.
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
psd
=
psd
.
sum
(
dim
=-
3
)
return
psd
class
MVDR
(
torch
.
nn
.
Module
):
"""MVDR module that performs MVDR beamforming with Time-Frequency masks.
Args:
ref_channel (int, optional): the reference channel for beamforming. (Default: ``0``)
solution (str, optional): the solution to get MVDR weight.
Options: [``ref_channel``, ``stv_evd``, ``stv_power``]. (Default: ``ref_channel``)
multi_mask (bool, optional): whether to use multi-channel Time-Frequency masks (Default: ``False``)
diag_loading (bool, optional): whether apply diagonal loading on the psd matrix of noise
(Default: ``True``)
diag_eps (float, optional): the coefficient multipied to the identity matrix for diagonal loading
(Default: 1e-7)
online (bool, optional): whether to update the mvdr vector based on the previous psd matrices.
(Default: ``False``)
Note:
If you use ``stv_evd`` solution, the gradient of the same input may not be identical if the
eigenvalues of the PSD matrix are not distinct (i.e. some eigenvalues are close or identical).
"""
def
__init__
(
self
,
ref_channel
:
int
=
0
,
solution
:
str
=
"ref_channel"
,
multi_mask
:
bool
=
False
,
diag_loading
:
bool
=
True
,
diag_eps
:
float
=
1e-7
,
online
:
bool
=
False
,
):
super
().
__init__
()
assert
solution
in
[
"ref_channel"
,
"stv_evd"
,
"stv_power"
],
\
"Unknown solution provided. Must be one of [``ref_channel``, ``stv_evd``, ``stv_power``]."
self
.
ref_channel
=
ref_channel
self
.
solution
=
solution
self
.
multi_mask
=
multi_mask
self
.
diag_loading
=
diag_loading
self
.
diag_eps
=
diag_eps
self
.
online
=
online
self
.
psd
=
PSD
(
multi_mask
)
psd_s
:
torch
.
Tensor
=
torch
.
zeros
(
1
)
psd_n
:
torch
.
Tensor
=
torch
.
zeros
(
1
)
mask_sum_s
:
torch
.
Tensor
=
torch
.
zeros
(
1
)
mask_sum_n
:
torch
.
Tensor
=
torch
.
zeros
(
1
)
self
.
register_buffer
(
'psd_s'
,
psd_s
)
self
.
register_buffer
(
'psd_n'
,
psd_n
)
self
.
register_buffer
(
'mask_sum_s'
,
mask_sum_s
)
self
.
register_buffer
(
'mask_sum_n'
,
mask_sum_n
)
def
_get_updated_mvdr_vector
(
self
,
psd_s
:
torch
.
Tensor
,
psd_n
:
torch
.
Tensor
,
mask_s
:
torch
.
Tensor
,
mask_n
:
torch
.
Tensor
,
reference_vector
:
torch
.
Tensor
,
solution
:
str
=
'ref_channel'
,
diagonal_loading
:
bool
=
True
,
diag_eps
:
float
=
1e-7
,
eps
:
float
=
1e-8
,
)
->
torch
.
Tensor
:
r
"""Recursively update the MVDR beamforming vector.
Args:
psd_s (torch.Tensor): psd matrix of target speech
psd_n (torch.Tensor): psd matrix of noise
mask_s (torch.Tensor): T-F mask of target speech
mask_n (torch.Tensor): T-F mask of noise
reference_vector (torch.Tensor): one-hot reference channel matrix
solution (str, optional): the solution to estimate the beamforming weight
(Default: ``ref_channel``)
diagonal_loading (bool, optional): whether to apply diagonal loading to psd_n
(Default: ``True``)
diag_eps (float, optional): The coefficient multipied to the identity matrix for diagonal loading
(Default: 1e-7)
eps (float, optional): a value added to the denominator in mask normalization. (Default: 1e-8)
Returns:
torch.Tensor: the mvdr beamforming weight matrix
"""
if
self
.
multi_mask
:
# Averaging mask along channel dimension
mask_s
=
mask_s
.
mean
(
dim
=-
3
)
# (..., freq, time)
mask_n
=
mask_n
.
mean
(
dim
=-
3
)
# (..., freq, time)
if
self
.
psd_s
.
ndim
==
1
:
self
.
psd_s
=
psd_s
self
.
psd_n
=
psd_n
self
.
mask_sum_s
=
mask_s
.
sum
(
dim
=-
1
)
self
.
mask_sum_n
=
mask_n
.
sum
(
dim
=-
1
)
return
self
.
_get_mvdr_vector
(
psd_s
,
psd_n
,
reference_vector
,
solution
,
diagonal_loading
,
diag_eps
,
eps
)
else
:
psd_s
=
self
.
_get_updated_psd_speech
(
psd_s
,
mask_s
)
psd_n
=
self
.
_get_updated_psd_noise
(
psd_n
,
mask_n
)
self
.
psd_s
=
psd_s
self
.
psd_n
=
psd_n
self
.
mask_sum_s
=
self
.
mask_sum_s
+
mask_s
.
sum
(
dim
=-
1
)
self
.
mask_sum_n
=
self
.
mask_sum_n
+
mask_n
.
sum
(
dim
=-
1
)
return
self
.
_get_mvdr_vector
(
psd_s
,
psd_n
,
reference_vector
,
solution
,
diagonal_loading
,
diag_eps
,
eps
)
def
_get_updated_psd_speech
(
self
,
psd_s
:
torch
.
Tensor
,
mask_s
:
torch
.
Tensor
)
->
torch
.
Tensor
:
r
"""Update psd of speech recursively.
Args:
psd_s (torch.Tensor): psd matrix of target speech
mask_s (torch.Tensor): T-F mask of target speech
Returns:
torch.Tensor: the updated psd of speech
"""
numerator
=
self
.
mask_sum_s
/
(
self
.
mask_sum_s
+
mask_s
.
sum
(
dim
=-
1
))
denominator
=
1
/
(
self
.
mask_sum_s
+
mask_s
.
sum
(
dim
=-
1
))
psd_s
=
self
.
psd_s
*
numerator
[...,
None
,
None
]
+
psd_s
*
denominator
[...,
None
,
None
]
return
psd_s
def
_get_updated_psd_noise
(
self
,
psd_n
:
torch
.
Tensor
,
mask_n
:
torch
.
Tensor
)
->
torch
.
Tensor
:
r
"""Update psd of noise recursively.
Args:
psd_n (torch.Tensor): psd matrix of target noise
mask_n (torch.Tensor): T-F mask of target noise
Returns:
torch.Tensor: the updated psd of noise
"""
numerator
=
self
.
mask_sum_n
/
(
self
.
mask_sum_n
+
mask_n
.
sum
(
dim
=-
1
))
denominator
=
1
/
(
self
.
mask_sum_n
+
mask_n
.
sum
(
dim
=-
1
))
psd_n
=
self
.
psd_n
*
numerator
[...,
None
,
None
]
+
psd_n
*
denominator
[...,
None
,
None
]
return
psd_n
def
_get_mvdr_vector
(
self
,
psd_s
:
torch
.
Tensor
,
psd_n
:
torch
.
Tensor
,
reference_vector
:
torch
.
Tensor
,
solution
:
str
=
'ref_channel'
,
diagonal_loading
:
bool
=
True
,
diag_eps
:
float
=
1e-7
,
eps
:
float
=
1e-8
,
)
->
torch
.
Tensor
:
r
"""Compute beamforming vector by the reference channel selection method.
Args:
psd_s (torch.Tensor): psd matrix of target speech
psd_n (torch.Tensor): psd matrix of noise
reference_vector (torch.Tensor): one-hot reference channel matrix
solution (str, optional): the solution to estimate the beamforming weight
(Default: ``ref_channel``)
diagonal_loading (bool, optional): whether to apply diagonal loading to psd_n
(Default: ``True``)
diag_eps (float, optional): The coefficient multipied to the identity matrix for diagonal loading
(Default: 1e-7)
eps (float, optional): a value added to the denominator in mask normalization. Default: 1e-8
Returns:
torch.Tensor: the mvdr beamforming weight matrix
"""
if
diagonal_loading
:
psd_n
=
self
.
_tik_reg
(
psd_n
,
reg
=
diag_eps
,
eps
=
eps
)
if
solution
==
"ref_channel"
:
numerator
=
torch
.
linalg
.
solve
(
psd_n
,
psd_s
)
# psd_n.inv() @ psd_s
# ws: (..., C, C) / (...,) -> (..., C, C)
ws
=
numerator
/
(
mat_trace
(
numerator
)[...,
None
,
None
]
+
eps
)
# h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
beamform_vector
=
torch
.
einsum
(
"...fec,...c->...fe"
,
[
ws
,
reference_vector
])
else
:
if
solution
==
"stv_evd"
:
stv
=
self
.
_get_steering_vector_evd
(
psd_s
)
else
:
stv
=
self
.
_get_steering_vector_power
(
psd_s
,
psd_n
,
reference_vector
)
# numerator = psd_n.inv() @ stv
numerator
=
torch
.
linalg
.
solve
(
psd_n
,
stv
).
squeeze
(
-
1
)
# (..., freq, channel)
# denominator = stv^H @ psd_n.inv() @ stv
denominator
=
torch
.
einsum
(
"...d,...d->..."
,
[
stv
.
conj
().
squeeze
(
-
1
),
numerator
])
# normalzie the numerator
scale
=
stv
.
squeeze
(
-
1
)[...,
self
.
ref_channel
,
None
].
conj
()
beamform_vector
=
numerator
*
scale
/
(
denominator
.
real
.
unsqueeze
(
-
1
)
+
eps
)
return
beamform_vector
def
_get_steering_vector_evd
(
self
,
psd_s
:
torch
.
Tensor
)
->
torch
.
Tensor
:
r
"""Estimate the steering vector by eigenvalue decomposition.
Args:
psd_s (torch.tensor): covariance matrix of speech
Tensor of dimension (..., freq, channel, channel)
Returns:
torch.Tensor: the enhanced STFT
Tensor of dimension (..., freq, channel, 1)
"""
w
,
v
=
torch
.
linalg
.
eig
(
psd_s
)
# (..., freq, channel, channel)
_
,
indices
=
torch
.
max
(
w
.
abs
(),
dim
=-
1
,
keepdim
=
True
)
indices
=
indices
.
unsqueeze
(
-
1
)
stv
=
v
.
gather
(
-
1
,
indices
.
expand
(
psd_s
.
shape
[:
-
1
]
+
(
1
,)))
# (..., freq, channel, 1)
return
stv
def
_get_steering_vector_power
(
self
,
psd_s
:
torch
.
Tensor
,
psd_n
:
torch
.
Tensor
,
reference_vector
:
torch
.
Tensor
)
->
torch
.
Tensor
:
r
"""Estimate the steering vector by the power method.
Args:
psd_s (torch.tensor): covariance matrix of speech
Tensor of dimension (..., freq, channel, channel)
psd_n (torch.Tensor): covariance matrix of noise
Tensor of dimension (..., freq, channel, channel)
reference_vector (torch.Tensor): one-hot reference channel matrix
Returns:
torch.Tensor: the enhanced STFT
Tensor of dimension (..., freq, channel, 1)
"""
phi
=
torch
.
linalg
.
solve
(
psd_n
,
psd_s
)
# psd_n.inv() @ psd_s
stv
=
torch
.
einsum
(
"...fec,...c->...fe"
,
[
phi
,
reference_vector
])
stv
=
stv
.
unsqueeze
(
-
1
)
stv
=
torch
.
matmul
(
phi
,
stv
)
stv
=
torch
.
matmul
(
psd_s
,
stv
)
return
stv
def
_apply_beamforming_vector
(
self
,
X
:
torch
.
Tensor
,
beamform_vector
:
torch
.
Tensor
)
->
torch
.
Tensor
:
r
"""Apply the beamforming weight to the noisy STFT
Args:
X (torch.tensor): multi-channel noisy STFT
Tensor of dimension (..., channel, freq, time)
beamform_vector (torch.Tensor): beamforming weight matrix
Tensor of dimension (..., freq, channel)
Returns:
torch.Tensor: the enhanced STFT
Tensor of dimension (..., freq, time)
"""
# (..., channel) x (..., channel, freq, time) -> (..., freq, time)
Y
=
torch
.
einsum
(
"...fc,...cft->...ft"
,
[
beamform_vector
.
conj
(),
X
])
return
Y
def
_tik_reg
(
self
,
mat
:
torch
.
Tensor
,
reg
:
float
=
1e-7
,
eps
:
float
=
1e-8
)
->
torch
.
Tensor
:
"""Perform Tikhonov regularization (only modifying real part).
Args:
mat (torch.Tensor): input matrix (..., channel, channel)
reg (float, optional): regularization factor (Default: 1e-8)
eps (float, optional): a value to avoid the correlation matrix is all-zero (Default: 1e-8)
Returns:
torch.Tensor: regularized matrix (..., channel, channel)
"""
# Add eps
C
=
mat
.
size
(
-
1
)
eye
=
torch
.
eye
(
C
,
dtype
=
mat
.
dtype
,
device
=
mat
.
device
)
with
torch
.
no_grad
():
epsilon
=
mat_trace
(
mat
).
real
[...,
None
,
None
]
*
reg
# in case that correlation_matrix is all-zero
epsilon
=
epsilon
+
eps
mat
=
mat
+
epsilon
*
eye
[...,
:,
:]
return
mat
def
forward
(
self
,
X
:
torch
.
Tensor
,
mask_s
:
torch
.
Tensor
,
mask_n
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
"""Perform MVDR beamforming.
Args:
X (torch.Tensor): the multi-channel STF of the noisy speech.
Tensor of dimension (..., channel, freq, time)
mask_s (torch.Tensor): Time-Frequency mask of target speech
Tensor of dimension (..., freq, time) if multi_mask is ``False``
or or dimension (..., channel, freq, time) if multi_mask is ``True``
mask_n (torch.Tensor or None, optional): Time-Frequency mask of noise
Tensor of dimension (..., freq, time) if multi_mask is ``False``
or or dimension (..., channel, freq, time) if multi_mask is ``True``
(Default: None)
Returns:
torch.Tensor: The single-channel STFT of the enhanced speech.
Tensor of dimension (..., freq, time)
"""
if
X
.
ndim
<
3
:
raise
ValueError
(
f
"Expected at least 3D tensor (..., channel, freq, time). Found:
{
X
.
shape
}
"
)
if
X
.
dtype
!=
torch
.
cdouble
:
raise
ValueError
(
f
"The type of the input STFT tensor must be ``torch.cdouble``. Found:
{
X
.
dtype
}
"
)
if
mask_n
is
None
:
mask_n
=
1
-
mask_s
shape
=
X
.
size
()
# pack batch
X
=
X
.
reshape
(
-
1
,
shape
[
-
3
],
shape
[
-
2
],
shape
[
-
1
])
if
self
.
multi_mask
:
mask_s
=
mask_s
.
reshape
(
-
1
,
shape
[
-
3
],
shape
[
-
2
],
shape
[
-
1
])
mask_n
=
mask_n
.
reshape
(
-
1
,
shape
[
-
3
],
shape
[
-
2
],
shape
[
-
1
])
else
:
mask_s
=
mask_s
.
reshape
(
-
1
,
shape
[
-
2
],
shape
[
-
1
])
mask_n
=
mask_n
.
reshape
(
-
1
,
shape
[
-
2
],
shape
[
-
1
])
psd_s
=
self
.
psd
(
X
,
mask_s
)
# (..., freq, time, channel, channel)
psd_n
=
self
.
psd
(
X
,
mask_n
)
# (..., freq, time, channel, channel)
u
=
torch
.
zeros
(
X
.
size
()[:
-
2
],
device
=
X
.
device
,
dtype
=
torch
.
cdouble
)
# (..., channel)
u
[...,
self
.
ref_channel
].
fill_
(
1
)
if
self
.
online
:
w_mvdr
=
self
.
_get_updated_mvdr_vector
(
psd_s
,
psd_n
,
mask_s
,
mask_n
,
u
,
self
.
solution
,
self
.
diag_loading
,
self
.
diag_eps
)
else
:
w_mvdr
=
self
.
_get_mvdr_vector
(
psd_s
,
psd_n
,
u
,
self
.
solution
,
self
.
diag_loading
,
self
.
diag_eps
)
Y
=
self
.
_apply_beamforming_vector
(
X
,
w_mvdr
)
# unpack batch
Y
=
Y
.
reshape
(
shape
[:
-
3
]
+
shape
[
-
2
:])
return
Y
test/torchaudio_unittest/
example/beamforming/transforms_test_impl
.py
→
test/torchaudio_unittest/
common_utils/psd_utils
.py
View file @
ac97ad82
...
...
@@ -2,14 +2,6 @@ from typing import Optional
import
numpy
as
np
import
torch
from
beamforming.mvdr
import
PSD
from
parameterized
import
parameterized
,
param
from
torchaudio_unittest.common_utils
import
(
TestBaseMixin
,
get_whitenoise
,
get_spectrogram
,
)
def
psd_numpy
(
...
...
@@ -33,28 +25,3 @@ def psd_numpy(
psd
=
psd
.
sum
(
axis
=-
3
)
return
torch
.
tensor
(
psd
,
dtype
=
torch
.
cdouble
)
class
TransformsTestBase
(
TestBaseMixin
):
@
parameterized
.
expand
([
param
(
0.5
,
1
,
True
,
False
),
param
(
0.5
,
1
,
None
,
False
),
param
(
1
,
4
,
True
,
True
),
param
(
1
,
6
,
None
,
True
),
])
def
test_psd
(
self
,
duration
,
channel
,
mask
,
multi_mask
):
"""Providing dtype changes the kernel cache dtype"""
transform
=
PSD
(
multi_mask
)
waveform
=
get_whitenoise
(
sample_rate
=
8000
,
duration
=
duration
,
n_channels
=
channel
)
spectrogram
=
get_spectrogram
(
waveform
,
n_fft
=
400
)
# (channel, freq, time)
spectrogram
=
spectrogram
.
to
(
torch
.
cdouble
)
if
mask
is
not
None
:
if
multi_mask
:
mask
=
torch
.
rand
(
spectrogram
.
shape
[
-
3
:])
else
:
mask
=
torch
.
rand
(
spectrogram
.
shape
[
-
2
:])
psd_np
=
psd_numpy
(
spectrogram
.
detach
().
numpy
(),
mask
.
detach
().
numpy
(),
multi_mask
)
else
:
psd_np
=
psd_numpy
(
spectrogram
.
detach
().
numpy
(),
mask
,
multi_mask
)
psd
=
transform
(
spectrogram
,
mask
)
self
.
assertEqual
(
psd
,
psd_np
,
atol
=
1e-5
,
rtol
=
1e-5
)
test/torchaudio_unittest/example/beamforming/__init__.py
deleted
100644 → 0
View file @
88ca1e05
test/torchaudio_unittest/example/beamforming/autograd_cpu_test.py
deleted
100644 → 0
View file @
88ca1e05
from
torchaudio_unittest.common_utils
import
PytorchTestCase
from
.autograd_test_impl
import
AutogradTestMixin
class
AutogradCPUTest
(
AutogradTestMixin
,
PytorchTestCase
):
device
=
'cpu'
class
AutogradRNNTCPUTest
(
PytorchTestCase
):
device
=
'cpu'
test/torchaudio_unittest/example/beamforming/autograd_cuda_test.py
deleted
100644 → 0
View file @
88ca1e05
from
torchaudio_unittest.common_utils
import
(
PytorchTestCase
,
skipIfNoCuda
,
)
from
.autograd_test_impl
import
AutogradTestMixin
@
skipIfNoCuda
class
AutogradCUDATest
(
AutogradTestMixin
,
PytorchTestCase
):
device
=
'cuda'
@
skipIfNoCuda
class
AutogradRNNTCUDATest
(
PytorchTestCase
):
device
=
'cuda'
test/torchaudio_unittest/example/beamforming/autograd_test_impl.py
deleted
100644 → 0
View file @
88ca1e05
from
typing
import
List
import
torch
from
beamforming.mvdr
import
PSD
,
MVDR
from
parameterized
import
parameterized
,
param
from
torch.autograd
import
gradcheck
,
gradgradcheck
from
torchaudio_unittest.common_utils
import
(
TestBaseMixin
,
get_whitenoise
,
get_spectrogram
,
)
class
AutogradTestMixin
(
TestBaseMixin
):
def
assert_grad
(
self
,
transform
:
torch
.
nn
.
Module
,
inputs
:
List
[
torch
.
Tensor
],
*
,
nondet_tol
:
float
=
0.0
,
):
transform
=
transform
.
to
(
dtype
=
torch
.
float64
,
device
=
self
.
device
)
# gradcheck and gradgradcheck only pass if the input tensors are of dtype `torch.double` or
# `torch.cdouble`, when the default eps and tolerance values are used.
inputs_
=
[]
for
i
in
inputs
:
if
torch
.
is_tensor
(
i
):
i
=
i
.
to
(
dtype
=
torch
.
cdouble
if
i
.
is_complex
()
else
torch
.
double
,
device
=
self
.
device
)
i
.
requires_grad
=
True
inputs_
.
append
(
i
)
assert
gradcheck
(
transform
,
inputs_
)
assert
gradgradcheck
(
transform
,
inputs_
,
nondet_tol
=
nondet_tol
)
def
test_psd
(
self
):
transform
=
PSD
()
waveform
=
get_whitenoise
(
sample_rate
=
8000
,
duration
=
0.05
,
n_channels
=
2
)
spectrogram
=
get_spectrogram
(
waveform
,
n_fft
=
400
)
self
.
assert_grad
(
transform
,
[
spectrogram
])
@
parameterized
.
expand
([
[
True
],
[
False
],
])
def
test_psd_with_mask
(
self
,
multi_mask
):
transform
=
PSD
(
multi_mask
=
multi_mask
)
waveform
=
get_whitenoise
(
sample_rate
=
8000
,
duration
=
0.05
,
n_channels
=
2
)
spectrogram
=
get_spectrogram
(
waveform
,
n_fft
=
400
)
if
multi_mask
:
mask
=
torch
.
rand
(
spectrogram
.
shape
[
-
3
:])
else
:
mask
=
torch
.
rand
(
spectrogram
.
shape
[
-
2
:])
self
.
assert_grad
(
transform
,
[
spectrogram
,
mask
])
@
parameterized
.
expand
([
param
(
solution
=
"ref_channel"
),
param
(
solution
=
"stv_power"
),
# evd will fail since the eigenvalues are not distinct
# param(solution="stv_evd"),
])
def
test_mvdr
(
self
,
solution
):
transform
=
MVDR
(
solution
=
solution
)
waveform
=
get_whitenoise
(
sample_rate
=
8000
,
duration
=
0.05
,
n_channels
=
2
)
spectrogram
=
get_spectrogram
(
waveform
,
n_fft
=
400
)
mask
=
torch
.
rand
(
spectrogram
.
shape
[
-
2
:])
self
.
assert_grad
(
transform
,
[
spectrogram
,
mask
])
test/torchaudio_unittest/example/beamforming/batch_consistency_test.py
deleted
100644 → 0
View file @
88ca1e05
"""Test numerical consistency among single input and batched input."""
import
torch
from
beamforming.mvdr
import
PSD
,
MVDR
from
parameterized
import
parameterized
from
torchaudio_unittest
import
common_utils
class
TestTransforms
(
common_utils
.
TorchaudioTestCase
):
def
test_batch_PSD
(
self
):
spec
=
torch
.
rand
((
4
,
6
,
201
,
100
),
dtype
=
torch
.
cdouble
)
# Single then transform then batch
expected
=
[]
for
i
in
range
(
4
):
expected
.
append
(
PSD
()(
spec
[
i
]))
expected
=
torch
.
stack
(
expected
)
# Batch then transform
computed
=
PSD
()(
spec
)
self
.
assertEqual
(
computed
,
expected
)
def
test_batch_PSD_with_mask
(
self
):
spec
=
torch
.
rand
((
4
,
6
,
201
,
100
),
dtype
=
torch
.
cdouble
)
mask
=
torch
.
rand
((
4
,
201
,
100
))
# Single then transform then batch
expected
=
[]
for
i
in
range
(
4
):
expected
.
append
(
PSD
()(
spec
[
i
],
mask
[
i
]))
expected
=
torch
.
stack
(
expected
)
# Batch then transform
computed
=
PSD
()(
spec
,
mask
)
self
.
assertEqual
(
computed
,
expected
)
@
parameterized
.
expand
([
[
True
],
[
False
],
])
def
test_MVDR
(
self
,
multi_mask
):
spec
=
torch
.
rand
((
4
,
6
,
201
,
100
),
dtype
=
torch
.
cdouble
)
if
multi_mask
:
mask
=
torch
.
rand
((
4
,
6
,
201
,
100
))
else
:
mask
=
torch
.
rand
((
4
,
201
,
100
))
# Single then transform then batch
expected
=
[]
for
i
in
range
(
4
):
expected
.
append
(
MVDR
(
multi_mask
=
multi_mask
)(
spec
[
i
],
mask
[
i
]))
expected
=
torch
.
stack
(
expected
)
# Batch then transform
computed
=
MVDR
(
multi_mask
=
multi_mask
)(
spec
,
mask
)
self
.
assertEqual
(
computed
,
expected
)
test/torchaudio_unittest/example/beamforming/torchscript_consistency_cpu_test.py
deleted
100644 → 0
View file @
88ca1e05
import
torch
from
torchaudio_unittest.common_utils
import
PytorchTestCase
from
.torchscript_consistency_impl
import
Transforms
,
TransformsFloat64Only
class
TestTransformsFloat32
(
Transforms
,
PytorchTestCase
):
dtype
=
torch
.
float32
device
=
torch
.
device
(
'cpu'
)
class
TestTransformsFloat64
(
Transforms
,
TransformsFloat64Only
,
PytorchTestCase
):
dtype
=
torch
.
float64
device
=
torch
.
device
(
'cpu'
)
test/torchaudio_unittest/example/beamforming/torchscript_consistency_cuda_test.py
deleted
100644 → 0
View file @
88ca1e05
import
torch
from
torchaudio_unittest.common_utils
import
skipIfNoCuda
,
PytorchTestCase
from
.torchscript_consistency_impl
import
Transforms
,
TransformsFloat64Only
@
skipIfNoCuda
class
TestTransformsFloat32
(
Transforms
,
PytorchTestCase
):
dtype
=
torch
.
float32
device
=
torch
.
device
(
'cuda'
)
@
skipIfNoCuda
class
TestTransformsFloat64
(
Transforms
,
TransformsFloat64Only
,
PytorchTestCase
):
dtype
=
torch
.
float64
device
=
torch
.
device
(
'cuda'
)
test/torchaudio_unittest/example/beamforming/torchscript_consistency_impl.py
deleted
100644 → 0
View file @
88ca1e05
"""Test suites for jit-ability and its numerical compatibility"""
import
torch
from
beamforming.mvdr
import
PSD
,
MVDR
from
parameterized
import
parameterized
,
param
from
torchaudio_unittest
import
common_utils
from
torchaudio_unittest.common_utils
import
(
TempDirMixin
,
TestBaseMixin
,
)
class
Transforms
(
TempDirMixin
,
TestBaseMixin
):
"""Implements test for Transforms that are performed for different devices"""
def
_assert_consistency_complex
(
self
,
transform
,
tensors
):
assert
tensors
[
0
].
is_complex
()
tensors
=
[
tensor
.
to
(
device
=
self
.
device
,
dtype
=
self
.
complex_dtype
)
for
tensor
in
tensors
]
transform
=
transform
.
to
(
device
=
self
.
device
,
dtype
=
self
.
dtype
)
path
=
self
.
get_temp_path
(
'func.zip'
)
torch
.
jit
.
script
(
transform
).
save
(
path
)
ts_transform
=
torch
.
jit
.
load
(
path
)
output
=
transform
(
*
tensors
)
ts_output
=
ts_transform
(
*
tensors
)
self
.
assertEqual
(
ts_output
,
output
)
def
test_PSD
(
self
):
tensor
=
common_utils
.
get_whitenoise
(
sample_rate
=
8000
,
n_channels
=
4
)
spectrogram
=
common_utils
.
get_spectrogram
(
tensor
,
n_fft
=
400
,
hop_length
=
100
)
self
.
_assert_consistency_complex
(
PSD
(),
(
spectrogram
,))
def
test_PSD_with_mask
(
self
):
tensor
=
common_utils
.
get_whitenoise
(
sample_rate
=
8000
,
n_channels
=
4
)
spectrogram
=
common_utils
.
get_spectrogram
(
tensor
,
n_fft
=
400
,
hop_length
=
100
)
mask
=
torch
.
rand
(
spectrogram
.
shape
[
-
2
:])
self
.
_assert_consistency_complex
(
PSD
(),
(
spectrogram
,
mask
))
class
TransformsFloat64Only
(
TestBaseMixin
):
@
parameterized
.
expand
([
param
(
solution
=
"ref_channel"
,
online
=
True
),
param
(
solution
=
"stv_evd"
,
online
=
True
),
param
(
solution
=
"stv_power"
,
online
=
True
),
param
(
solution
=
"ref_channel"
,
online
=
False
),
param
(
solution
=
"stv_evd"
,
online
=
False
),
param
(
solution
=
"stv_power"
,
online
=
False
),
])
def
test_MVDR
(
self
,
solution
,
online
):
tensor
=
common_utils
.
get_whitenoise
(
sample_rate
=
8000
,
n_channels
=
4
)
spectrogram
=
common_utils
.
get_spectrogram
(
tensor
,
n_fft
=
400
,
hop_length
=
100
)
mask
=
torch
.
rand
(
spectrogram
.
shape
[
-
2
:])
self
.
_assert_consistency_complex
(
MVDR
(
solution
=
solution
,
online
=
online
),
(
spectrogram
,
mask
)
)
test/torchaudio_unittest/example/beamforming/transforms_cpu_test.py
deleted
100644 → 0
View file @
88ca1e05
import
torch
from
torchaudio_unittest.common_utils
import
PytorchTestCase
from
.
transforms_test_impl
import
TransformsTestBase
class
TransformsCPUFloat32Test
(
TransformsTestBase
,
PytorchTestCase
):
device
=
'cpu'
dtype
=
torch
.
float32
class
TransformsCPUFloat64Test
(
TransformsTestBase
,
PytorchTestCase
):
device
=
'cpu'
dtype
=
torch
.
float64
test/torchaudio_unittest/example/beamforming/transforms_cuda_test.py
deleted
100644 → 0
View file @
88ca1e05
import
torch
from
torchaudio_unittest.common_utils
import
(
PytorchTestCase
,
skipIfNoCuda
,
)
from
.
transforms_test_impl
import
TransformsTestBase
@
skipIfNoCuda
class
TransformsCPUFloat32Test
(
TransformsTestBase
,
PytorchTestCase
):
device
=
'cuda'
dtype
=
torch
.
float32
@
skipIfNoCuda
class
TransformsCPUFloat64Test
(
TransformsTestBase
,
PytorchTestCase
):
device
=
'cpu'
dtype
=
torch
.
float64
test/torchaudio_unittest/transforms/autograd_test_impl.py
View file @
ac97ad82
...
...
@@ -262,6 +262,42 @@ class AutogradTestMixin(TestBaseMixin):
spectrogram
=
torch
.
view_as_real
(
spectrogram
)
self
.
assert_grad
(
transform
,
[
spectrogram
])
def
test_psd
(
self
):
transform
=
T
.
PSD
()
waveform
=
get_whitenoise
(
sample_rate
=
8000
,
duration
=
0.05
,
n_channels
=
2
)
spectrogram
=
get_spectrogram
(
waveform
,
n_fft
=
400
)
self
.
assert_grad
(
transform
,
[
spectrogram
])
@
parameterized
.
expand
([
[
True
],
[
False
],
])
def
test_psd_with_mask
(
self
,
multi_mask
):
transform
=
T
.
PSD
(
multi_mask
=
multi_mask
)
waveform
=
get_whitenoise
(
sample_rate
=
8000
,
duration
=
0.05
,
n_channels
=
2
)
spectrogram
=
get_spectrogram
(
waveform
,
n_fft
=
400
)
if
multi_mask
:
mask
=
torch
.
rand
(
spectrogram
.
shape
[
-
3
:])
else
:
mask
=
torch
.
rand
(
spectrogram
.
shape
[
-
2
:])
self
.
assert_grad
(
transform
,
[
spectrogram
,
mask
])
@
parameterized
.
expand
([
"ref_channel"
,
# stv_power test time too long, comment for now
# "stv_power",
# stv_evd will fail since the eigenvalues are not distinct
# "stv_evd",
])
def
test_mvdr
(
self
,
solution
):
transform
=
T
.
MVDR
(
solution
=
solution
)
waveform
=
get_whitenoise
(
sample_rate
=
8000
,
duration
=
0.05
,
n_channels
=
2
)
spectrogram
=
get_spectrogram
(
waveform
,
n_fft
=
400
)
mask_s
=
torch
.
rand
(
spectrogram
.
shape
[
-
2
:])
mask_n
=
torch
.
rand
(
spectrogram
.
shape
[
-
2
:])
self
.
assert_grad
(
transform
,
[
spectrogram
,
mask_s
,
mask_n
])
class
AutogradTestFloat32
(
TestBaseMixin
):
def
assert_grad
(
...
...
test/torchaudio_unittest/transforms/batch_consistency_test.py
View file @
ac97ad82
...
...
@@ -175,3 +175,54 @@ class TestTransforms(common_utils.TorchaudioTestCase):
transform
=
T
.
PitchShift
(
sample_rate
,
n_steps
,
n_fft
=
400
)
self
.
assert_batch_consistency
(
transform
,
waveform
)
def
test_batch_PSD
(
self
):
waveform
=
common_utils
.
get_whitenoise
(
sample_rate
=
8000
,
duration
=
1
,
n_channels
=
6
)
specgram
=
common_utils
.
get_spectrogram
(
waveform
,
n_fft
=
400
)
specgram
=
specgram
.
reshape
(
3
,
2
,
specgram
.
shape
[
-
2
],
specgram
.
shape
[
-
1
])
transform
=
T
.
PSD
()
self
.
assert_batch_consistency
(
transform
,
specgram
)
def
test_batch_PSD_with_mask
(
self
):
waveform
=
common_utils
.
get_whitenoise
(
sample_rate
=
8000
,
duration
=
1
,
n_channels
=
6
)
waveform
=
waveform
.
to
(
torch
.
double
)
specgram
=
common_utils
.
get_spectrogram
(
waveform
,
n_fft
=
400
)
specgram
=
specgram
.
reshape
(
3
,
2
,
specgram
.
shape
[
-
2
],
specgram
.
shape
[
-
1
])
mask
=
torch
.
rand
((
3
,
specgram
.
shape
[
-
2
],
specgram
.
shape
[
-
1
]))
transform
=
T
.
PSD
()
# Single then transform then batch
expected
=
[
transform
(
specgram
[
i
],
mask
[
i
])
for
i
in
range
(
3
)]
expected
=
torch
.
stack
(
expected
)
# Batch then transform
computed
=
transform
(
specgram
,
mask
)
self
.
assertEqual
(
computed
,
expected
)
@
parameterized
.
expand
([
[
True
],
[
False
],
])
def
test_MVDR
(
self
,
multi_mask
):
waveform
=
common_utils
.
get_whitenoise
(
sample_rate
=
8000
,
duration
=
1
,
n_channels
=
6
)
waveform
=
waveform
.
to
(
torch
.
double
)
specgram
=
common_utils
.
get_spectrogram
(
waveform
,
n_fft
=
400
)
specgram
=
specgram
.
reshape
(
3
,
2
,
specgram
.
shape
[
-
2
],
specgram
.
shape
[
-
1
])
if
multi_mask
:
mask_s
=
torch
.
rand
((
3
,
2
,
specgram
.
shape
[
-
2
],
specgram
.
shape
[
-
1
]))
mask_n
=
torch
.
rand
((
3
,
2
,
specgram
.
shape
[
-
2
],
specgram
.
shape
[
-
1
]))
else
:
mask_s
=
torch
.
rand
((
3
,
specgram
.
shape
[
-
2
],
specgram
.
shape
[
-
1
]))
mask_n
=
torch
.
rand
((
3
,
specgram
.
shape
[
-
2
],
specgram
.
shape
[
-
1
]))
transform
=
T
.
MVDR
(
multi_mask
=
multi_mask
)
# Single then transform then batch
expected
=
[
transform
(
specgram
[
i
],
mask_s
[
i
],
mask_n
[
i
])
for
i
in
range
(
3
)]
expected
=
torch
.
stack
(
expected
)
# Batch then transform
computed
=
transform
(
specgram
,
mask_s
,
mask_n
)
self
.
assertEqual
(
computed
,
expected
)
test/torchaudio_unittest/transforms/torchscript_consistency_cpu_test.py
View file @
ac97ad82
import
torch
from
torchaudio_unittest.common_utils
import
PytorchTestCase
from
.torchscript_consistency_impl
import
Transforms
,
TransformsFloat32Only
from
.torchscript_consistency_impl
import
Transforms
,
TransformsFloat32Only
,
TransformsFloat64Only
class
TestTransformsFloat32
(
Transforms
,
TransformsFloat32Only
,
PytorchTestCase
):
...
...
@@ -9,6 +9,6 @@ class TestTransformsFloat32(Transforms, TransformsFloat32Only, PytorchTestCase):
device
=
torch
.
device
(
'cpu'
)
class
TestTransformsFloat64
(
Transforms
,
PytorchTestCase
):
class
TestTransformsFloat64
(
Transforms
,
TransformsFloat64Only
,
PytorchTestCase
):
dtype
=
torch
.
float64
device
=
torch
.
device
(
'cpu'
)
test/torchaudio_unittest/transforms/torchscript_consistency_cuda_test.py
View file @
ac97ad82
import
torch
from
torchaudio_unittest.common_utils
import
skipIfNoCuda
,
PytorchTestCase
from
.torchscript_consistency_impl
import
Transforms
,
TransformsFloat32Only
from
.torchscript_consistency_impl
import
Transforms
,
TransformsFloat32Only
,
TransformsFloat64Only
@
skipIfNoCuda
...
...
@@ -11,6 +11,6 @@ class TestTransformsFloat32(Transforms, TransformsFloat32Only, PytorchTestCase):
@
skipIfNoCuda
class
TestTransformsFloat64
(
Transforms
,
PytorchTestCase
):
class
TestTransformsFloat64
(
Transforms
,
TransformsFloat64Only
,
PytorchTestCase
):
dtype
=
torch
.
float64
device
=
torch
.
device
(
'cuda'
)
test/torchaudio_unittest/transforms/torchscript_consistency_impl.py
View file @
ac97ad82
...
...
@@ -24,7 +24,7 @@ class Transforms(TestBaseMixin):
ts_output
=
ts_transform
(
tensor
,
*
args
)
self
.
assertEqual
(
ts_output
,
output
)
def
_assert_consistency_complex
(
self
,
transform
,
tensor
,
test_pseudo_complex
=
False
):
def
_assert_consistency_complex
(
self
,
transform
,
tensor
,
test_pseudo_complex
=
False
,
*
args
):
assert
tensor
.
is_complex
()
tensor
=
tensor
.
to
(
device
=
self
.
device
,
dtype
=
self
.
complex_dtype
)
transform
=
transform
.
to
(
device
=
self
.
device
,
dtype
=
self
.
dtype
)
...
...
@@ -33,9 +33,8 @@ class Transforms(TestBaseMixin):
if
test_pseudo_complex
:
tensor
=
torch
.
view_as_real
(
tensor
)
output
=
transform
(
tensor
)
ts_output
=
ts_transform
(
tensor
)
output
=
transform
(
tensor
,
*
args
)
ts_output
=
ts_transform
(
tensor
,
*
args
)
self
.
assertEqual
(
ts_output
,
output
)
def
test_Spectrogram
(
self
):
...
...
@@ -152,6 +151,19 @@ class Transforms(TestBaseMixin):
waveform
)
def
test_PSD
(
self
):
tensor
=
common_utils
.
get_whitenoise
(
sample_rate
=
8000
,
n_channels
=
4
)
spectrogram
=
common_utils
.
get_spectrogram
(
tensor
,
n_fft
=
400
,
hop_length
=
100
)
spectrogram
=
spectrogram
.
to
(
self
.
device
)
self
.
_assert_consistency_complex
(
T
.
PSD
(),
spectrogram
)
def
test_PSD_with_mask
(
self
):
tensor
=
common_utils
.
get_whitenoise
(
sample_rate
=
8000
,
n_channels
=
4
)
spectrogram
=
common_utils
.
get_spectrogram
(
tensor
,
n_fft
=
400
,
hop_length
=
100
)
spectrogram
=
spectrogram
.
to
(
self
.
device
)
mask
=
torch
.
rand
(
spectrogram
.
shape
[
-
2
:],
device
=
self
.
device
)
self
.
_assert_consistency_complex
(
T
.
PSD
(),
spectrogram
,
False
,
mask
)
class
TransformsFloat32Only
(
TestBaseMixin
):
def
test_rnnt_loss
(
self
):
...
...
@@ -167,3 +179,24 @@ class TransformsFloat32Only(TestBaseMixin):
target_lengths
=
torch
.
tensor
([
2
],
device
=
tensor
.
device
,
dtype
=
torch
.
int32
)
self
.
_assert_consistency
(
T
.
RNNTLoss
(),
logits
,
targets
,
logit_lengths
,
target_lengths
)
class
TransformsFloat64Only
(
TestBaseMixin
):
@
parameterized
.
expand
([
[
"ref_channel"
,
True
],
[
"stv_evd"
,
True
],
[
"stv_power"
,
True
],
[
"ref_channel"
,
False
],
[
"stv_evd"
,
False
],
[
"stv_power"
,
False
],
])
def
test_MVDR
(
self
,
solution
,
online
):
tensor
=
common_utils
.
get_whitenoise
(
sample_rate
=
8000
,
n_channels
=
4
)
spectrogram
=
common_utils
.
get_spectrogram
(
tensor
,
n_fft
=
400
,
hop_length
=
100
)
spectrogram
=
spectrogram
.
to
(
device
=
self
.
device
,
dtype
=
torch
.
cdouble
)
mask_s
=
torch
.
rand
(
spectrogram
.
shape
[
-
2
:],
device
=
self
.
device
)
mask_n
=
torch
.
rand
(
spectrogram
.
shape
[
-
2
:],
device
=
self
.
device
)
self
.
_assert_consistency_complex
(
T
.
MVDR
(
solution
=
solution
,
online
=
online
),
spectrogram
,
False
,
mask_s
,
mask_n
)
test/torchaudio_unittest/transforms/transforms_test_impl.py
View file @
ac97ad82
...
...
@@ -7,6 +7,7 @@ from torchaudio_unittest.common_utils import (
get_spectrogram
,
nested_params
,
)
from
torchaudio_unittest.common_utils.psd_utils
import
psd_numpy
def
_get_ratio
(
mat
):
...
...
@@ -108,3 +109,26 @@ class TransformsTestBase(TestBaseMixin):
transformed
=
s
.
forward
(
waveform
)
restored
=
inv_s
.
forward
(
transformed
,
length
=
waveform
.
shape
[
-
1
])
self
.
assertEqual
(
waveform
,
restored
,
atol
=
1e-6
,
rtol
=
1e-6
)
@
parameterized
.
expand
([
param
(
0.5
,
1
,
True
,
False
),
param
(
0.5
,
1
,
None
,
False
),
param
(
1
,
4
,
True
,
True
),
param
(
1
,
6
,
None
,
True
),
])
def
test_psd
(
self
,
duration
,
channel
,
mask
,
multi_mask
):
"""Providing dtype changes the kernel cache dtype"""
transform
=
T
.
PSD
(
multi_mask
)
waveform
=
get_whitenoise
(
sample_rate
=
8000
,
duration
=
duration
,
n_channels
=
channel
)
spectrogram
=
get_spectrogram
(
waveform
,
n_fft
=
400
)
# (channel, freq, time)
spectrogram
=
spectrogram
.
to
(
torch
.
cdouble
)
if
mask
is
not
None
:
if
multi_mask
:
mask
=
torch
.
rand
(
spectrogram
.
shape
[
-
3
:])
else
:
mask
=
torch
.
rand
(
spectrogram
.
shape
[
-
2
:])
psd_np
=
psd_numpy
(
spectrogram
.
detach
().
numpy
(),
mask
.
detach
().
numpy
(),
multi_mask
)
else
:
psd_np
=
psd_numpy
(
spectrogram
.
detach
().
numpy
(),
mask
,
multi_mask
)
psd
=
transform
(
spectrogram
,
mask
)
self
.
assertEqual
(
psd
,
psd_np
,
atol
=
1e-5
,
rtol
=
1e-5
)
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment