Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
725f8b06
Unverified
Commit
725f8b06
authored
Oct 06, 2020
by
moto
Committed by
GitHub
Oct 06, 2020
Browse files
Add metrics to source separation example(#894)
parent
9871219d
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
312 additions
and
0 deletions
+312
-0
examples/source_separation/utils/__init__.py
examples/source_separation/utils/__init__.py
+3
-0
examples/source_separation/utils/metrics.py
examples/source_separation/utils/metrics.py
+164
-0
test/torchaudio_unittest/example/__init__.py
test/torchaudio_unittest/example/__init__.py
+8
-0
test/torchaudio_unittest/example/souce_sepration/__init__.py
test/torchaudio_unittest/example/souce_sepration/__init__.py
+0
-0
test/torchaudio_unittest/example/souce_sepration/metrics_test.py
...rchaudio_unittest/example/souce_sepration/metrics_test.py
+39
-0
test/torchaudio_unittest/example/souce_sepration/sdr_reference.py
...chaudio_unittest/example/souce_sepration/sdr_reference.py
+98
-0
No files found.
examples/source_separation/utils/__init__.py
0 → 100644
View file @
725f8b06
from
.
import
(
metrics
,
)
examples/source_separation/utils/metrics.py
0 → 100644
View file @
725f8b06
import
math
from
itertools
import
permutations
import
torch
def sdr(
        estimate: torch.Tensor,
        reference: torch.Tensor,
        epsilon: float = 1e-8
) -> torch.Tensor:
    """Computes source-to-distortion ratio.

    1. scale the reference signal with power(s_est * s_ref) / power(s_ref * s_ref)
    2. compute SNR between adjusted estimate and reference.

    Args:
        estimate (torch.Tensor): Estimated signal.
            Shape: [batch, speakers (can be 1), time frame]
        reference (torch.Tensor): Reference signal.
            Shape: [batch, speakers, time frame]
        epsilon (float): constant value used to stabilize division.
            It is only added to the denominator of the scale factor; nothing
            is added inside the logarithms, so a zero-power error (or a
            zero-power scaled reference) yields an infinite (or NaN) value.

    Returns:
        torch.Tensor: scale-invariant source-to-distortion ratio.
        Shape: [batch, speaker]

    References:
        - Single-channel multi-speaker separation using deep clustering
          Y. Isik, J. Le Roux, Z. Chen, S. Watanabe, and J. R. Hershey,
        - Conv-TasNet: Surpassing Ideal Time--Frequency Magnitude Masking for Speech Separation
          Luo, Yi and Mesgarani, Nima
          https://arxiv.org/abs/1809.07454

    Notes:
        This function is tested to produce the exact same result as
        https://github.com/naplab/Conv-TasNet/blob/e66d82a8f956a69749ec8a4ae382217faa097c5c/utility/sdr.py#L34-L56
    """
    # Projection coefficient of the estimate onto the reference, per (batch, speaker).
    # ``keepdim=True`` keeps the time axis so the scale broadcasts over the signal.
    reference_pow = reference.pow(2).mean(dim=2, keepdim=True)
    mix_pow = (estimate * reference).mean(dim=2, keepdim=True)
    scale = mix_pow / (reference_pow + epsilon)

    # The scaled reference is the "target" component of the estimate;
    # whatever remains is treated as distortion.
    target = scale * reference
    error = estimate - target

    target_pow = target.pow(2).mean(dim=2)
    error_pow = error.pow(2).mean(dim=2)
    return 10 * torch.log10(target_pow) - 10 * torch.log10(error_pow)
class PIT(torch.nn.Module):
    """Applies utterance-level speaker permutation

    Computes the maximum possible value of the given utility function
    over the permutations of the speakers.

    Args:
        utility_func (function):
            Function that computes the utility (opposite of loss) with signature of
            (estimate: torch.Tensor, reference: torch.Tensor) -> torch.Tensor
            where input Tensors are shape of [batch, speakers, frame] and
            the output Tensor is shape of [batch, speakers].

    References:
        - Multi-talker Speech Separation with Utterance-level Permutation Invariant Training of
          Deep Recurrent Neural Networks
          Morten Kolbæk, Dong Yu, Zheng-Hua Tan and Jesper Jensen
          https://arxiv.org/abs/1703.06284
    """

    def __init__(self, utility_func):
        super().__init__()
        self.utility_func = utility_func

    def forward(self, estimate: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
        """Compute utterance-level PIT Loss

        Args:
            estimate (torch.Tensor): Estimated source signals.
                Shape: [batch, speakers, time frame]
            reference (torch.Tensor): Reference (original) source signals.
                Shape: [batch, speakers, time frame]

        Returns:
            torch.Tensor: Maximum criterion over the speaker permutation.
            Shape: [batch, ]

        Raises:
            ValueError: If ``estimate`` and ``reference`` do not have the same shape.
        """
        # Validate explicitly instead of with ``assert``: assertions are removed
        # under ``python -O``, and a mismatched speaker dimension could then be
        # silently broadcast inside the utility function.
        if estimate.shape != reference.shape:
            raise ValueError(
                "estimate and reference must have the same shape. "
                f"Found: {tuple(estimate.shape)} and {tuple(reference.shape)}"
            )

        batch_size, num_speakers = reference.shape[:2]
        num_permute = math.factorial(num_speakers)

        # One column per speaker permutation.
        util_mat = torch.zeros(
            batch_size, num_permute, dtype=estimate.dtype, device=estimate.device)
        for i, perm in enumerate(permutations(range(num_speakers))):
            # Reorder the reference speakers according to this permutation.
            util = self.utility_func(estimate, reference[:, list(perm), :])
            util_mat[:, i] = util.mean(dim=1)  # take the average over speaker dimension
        return util_mat.max(dim=1).values
# Shared PIT module that uses ``sdr`` as the utility to maximize over speaker permutations.
_sdr_pit = PIT(utility_func=sdr)


def sdr_pit(estimate: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
    """Computes scale-invariant source-to-distortion ratio with permutation invariance.

    ``sdr`` is evaluated for every permutation of the reference speakers and
    averaged over the speaker dimension; the maximum over the permutations is returned.

    For each permutation, ``sdr`` performs the following steps:

    1. scale the reference signal with power(s_est * s_ref) / power(s_ref * s_ref)
    2. compute SNR between adjusted estimate and reference.

    This function does NOT adjust the inputs to have 0-mean; the caller is
    expected to provide 0-mean signals if parity with the reference
    implementation is required (see Notes).

    Args:
        estimate (torch.Tensor): Estimated signal.
            Shape: [batch, speakers, time frame]
        reference (torch.Tensor): Reference signal.
            Shape: [batch, speakers, time frame] (must be the same shape as ``estimate``)

    Returns:
        torch.Tensor: scale-invariant source-to-distortion ratio, averaged over
        speakers, for the best speaker permutation.
        Shape: [batch, ]

    References:
        - Single-channel multi-speaker separation using deep clustering
          Y. Isik, J. Le Roux, Z. Chen, S. Watanabe, and J. R. Hershey,
        - Conv-TasNet: Surpassing Ideal Time--Frequency Magnitude Masking for Speech Separation
          Luo, Yi and Mesgarani, Nima
          https://arxiv.org/abs/1809.07454

    Notes:
        This function is tested to produce the exact same result as the reference implementation,
        *when the inputs have 0-mean*
        https://github.com/naplab/Conv-TasNet/blob/e66d82a8f956a69749ec8a4ae382217faa097c5c/utility/sdr.py#L107-L153
    """
    return _sdr_pit(estimate, reference)
def sdri(
        estimate: torch.Tensor,
        reference: torch.Tensor,
        mix: torch.Tensor,
) -> torch.Tensor:
    """Compute the improvement of SDR (SDRi).

    SDRi measures how much the SDR improves when the separated sources are used
    in place of the unprocessed mixture, i.e.
    ``SDR(estimate, reference) - SDR(mix, reference)``.

    ``SDR(estimate, reference)`` is computed with PIT (permutation invariant
    training), so the best assignment between the estimated sources and the
    reference sources is used.

    Args:
        estimate (torch.Tensor): Estimated source signals.
            Shape: [batch, speakers, time frame]
        reference (torch.Tensor): Reference (original) source signals.
            Shape: [batch, speakers, time frame]
        mix (torch.Tensor): Mixed source signals, from which the estimated signals were generated.
            Shape: [batch, speakers == 1, time frame]

    Returns:
        torch.Tensor: Improved SDR. Shape: [batch, ]

    References:
        - Conv-TasNet: Surpassing Ideal Time--Frequency Magnitude Masking for Speech Separation
          Luo, Yi and Mesgarani, Nima
          https://arxiv.org/abs/1809.07454
    """
    separated_sdr = sdr_pit(estimate, reference)  # [batch, ]
    mixture_sdr = sdr(mix, reference)  # [batch, speaker]
    # Broadcast the per-utterance score against the per-speaker baseline,
    # then average the improvement over the speakers.
    improvement = separated_sdr.unsqueeze(1) - mixture_sdr
    return improvement.mean(dim=1)
test/torchaudio_unittest/example/__init__.py
0 → 100644
View file @
725f8b06
import
os
import
sys
# Put the repository's ``examples`` directory on the import path so that the
# example code can be imported by the tests in this package.
sys.path.append(
    os.path.join(
        os.path.dirname(__file__),  # directory containing this ``__init__.py``
        '..', '..', '..',           # three levels up
        'examples',
    )
)
test/torchaudio_unittest/example/souce_sepration/__init__.py
0 → 100644
View file @
725f8b06
test/torchaudio_unittest/example/souce_sepration/metrics_test.py
0 → 100644
View file @
725f8b06
from
itertools
import
product
import
torch
from
torch.testing._internal.common_utils
import
TestCase
from
parameterized
import
parameterized
from
.
import
sdr_reference
from
source_separation.utils
import
metrics
class TestSDR(TestCase):
    """Compares the example SDR metrics against the reference implementation."""

    @parameterized.expand([(1, ), (2, ), (32, )])
    def test_sdr(self, batch_size):
        """sdr produces the same result as the reference implementation"""
        shape = (batch_size, 256)  # (batch, frames)

        estimation = torch.rand(*shape)
        origin = torch.rand(*shape)

        expected = sdr_reference.calc_sdr_torch(estimation, origin)
        # ``metrics.sdr`` expects a speaker dimension; add it and drop it again.
        actual = metrics.sdr(estimation.unsqueeze(1), origin.unsqueeze(1))
        actual = actual.squeeze(1)

        self.assertEqual(actual, expected)

    @parameterized.expand(list(product([1, 2, 32], [2, 3, 4, 5])))
    def test_sdr_pit(self, batch_size, num_sources):
        """sdr_pit produces the same result as the reference implementation"""
        shape = (batch_size, num_sources, 256)  # (batch, sources, frames)

        estimation = torch.randn(*shape)
        origin = torch.randn(*shape)

        # Remove the per-signal mean before comparing.
        estimation = estimation - estimation.mean(dim=2, keepdim=True)
        origin = origin - origin.mean(dim=2, keepdim=True)

        expected = sdr_reference.batch_SDR_torch(estimation, origin)
        actual = metrics.sdr_pit(estimation, origin)

        self.assertEqual(actual, expected)
test/torchaudio_unittest/example/souce_sepration/sdr_reference.py
0 → 100644
View file @
725f8b06
"""Reference Implementation of SDR and PIT SDR.
This module was taken from the following implementation
https://github.com/naplab/Conv-TasNet/blob/e66d82a8f956a69749ec8a4ae382217faa097c5c/utility/sdr.py
which was made available by Yi Luo under the following license,
Creative Commons Attribution-NonCommercial-ShareAlike 3.0 United States License.
The module was modified in the following manner;
- Remove the functions other than `calc_sdr_torch` and `batch_SDR_torch`,
- Remove the import statements required only for the removed functions.
- Add `# flake8: noqa` so as not to report any format issue on this module.
The implementation of the retained functions and their formats are kept as-is.
"""
# flake8: noqa
import
numpy
as
np
from
itertools
import
permutations
import
torch
def calc_sdr_torch(estimation, origin, mask=None):
    """
    batch-wise SDR calculation for one audio file on pytorch Variables.

    The estimate is projected onto the original signal; the SDR is the ratio
    (in dB) between the power of the projected component and the power of the residual.

    estimation: (batch, nsample)
    origin: (batch, nsample)
    mask: optional, (batch, nsample), binary

    returns: SDR in dB; the sample axis is reduced without keepdim.
    """

    # When a mask is given, it is applied to both signals by element-wise multiplication.
    if mask is not None:
        origin = origin * mask
        estimation = estimation * mask

    # 1e-8 keeps the division below from dividing by zero for an all-zero origin.
    origin_power = torch.pow(origin, 2).sum(1, keepdim=True) + 1e-8  # (batch, 1)

    # Projection coefficient of the estimate onto the origin.
    scale = torch.sum(origin * estimation, 1, keepdim=True) / origin_power  # (batch, 1)

    est_true = scale * origin  # (batch, nsample)
    est_res = estimation - est_true  # (batch, nsample)

    # No keepdim here, so the sample axis is dropped.
    true_power = torch.pow(est_true, 2).sum(1)
    res_power = torch.pow(est_res, 2).sum(1)

    return 10 * torch.log10(true_power) - 10 * torch.log10(res_power)  # (batch,)
def batch_SDR_torch(estimation, origin, mask=None):
    """
    batch-wise SDR calculation for multiple audio files.

    Both inputs are made zero-mean along the sample axis, the SDR of every
    (estimated source, original source) pair is computed with `calc_sdr_torch`,
    and the source permutation with the largest summed SDR is selected.

    estimation: (batch, nsource, nsample)
    origin: (batch, nsource, nsample)
    mask: optional, (batch, nsample), binary

    returns: for each batch item, the summed SDR of the best permutation divided by nsource.
    """

    batch_size_est, nsource_est, nsample_est = estimation.size()
    batch_size_ori, nsource_ori, nsample_ori = origin.size()

    assert batch_size_est == batch_size_ori, "Estimation and original sources should have same shape."
    assert nsource_est == nsource_ori, "Estimation and original sources should have same shape."
    assert nsample_est == nsample_ori, "Estimation and original sources should have same shape."

    # Sanity check on the axis order.
    assert nsource_est < nsample_est, "Axis 1 should be the number of sources, and axis 2 should be the signal."

    batch_size = batch_size_est
    nsource = nsource_est
    nsample = nsample_est  # not used below

    # zero mean signals
    estimation = estimation - torch.mean(estimation, 2, keepdim=True).expand_as(estimation)
    # `expand_as(estimation)` is used for origin too; both shapes were asserted equal above.
    origin = origin - torch.mean(origin, 2, keepdim=True).expand_as(estimation)

    # possible permutations
    perm = list(set(permutations(np.arange(nsource))))

    # pair-wise SDR
    # SDR[:, i, j] holds the SDR between estimated source i and original source j.
    SDR = torch.zeros((batch_size, nsource, nsource)).type(estimation.type())
    for i in range(nsource):
        for j in range(nsource):
            SDR[:, i, j] = calc_sdr_torch(estimation[:, i], origin[:, j], mask)

    # choose the best permutation
    SDR_max = []  # placeholder; rebound by the `torch.max` call below
    SDR_perm = []
    for permute in perm:
        sdr = []
        for idx in range(len(permute)):
            # Estimated source `idx` is paired with original source `permute[idx]`.
            sdr.append(SDR[:, idx, permute[idx]].view(batch_size, -1))
        # Sum of the pair-wise SDRs for this permutation.
        sdr = torch.sum(torch.cat(sdr, 1), 1)
        SDR_perm.append(sdr.view(batch_size, 1))
    # One column per permutation.
    SDR_perm = torch.cat(SDR_perm, 1)
    SDR_max, _ = torch.max(SDR_perm, dim=1)

    # Convert the summed SDR of the best permutation into a per-source average.
    return SDR_max / nsource
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment