Unverified Commit 725f8b06 authored by moto's avatar moto Committed by GitHub
Browse files

Add metrics to source separation example(#894)

parent 9871219d
import math
from itertools import permutations
import torch
def sdr(estimate: torch.Tensor, reference: torch.Tensor, epsilon: float = 1e-8) -> torch.Tensor:
    """Computes scale-invariant source-to-distortion ratio.

    1. scale the reference signal with ``mean(s_est * s_ref) / mean(s_ref * s_ref)``
    2. compute SNR (in dB) between the scaled reference and the residual
       ``estimate - scaled reference``.

    Args:
        estimate (torch.Tensor): Estimated signal.
            Shape: [batch, speakers (can be 1), time frame]
        reference (torch.Tensor): Reference signal.
            Shape: [batch, speakers, time frame]
        epsilon (float): constant value added to the reference power to
            stabilize the division.

    Returns:
        torch.Tensor: scale-invariant source-to-distortion ratio.
            Shape: [batch, speaker]

    References:
        - Single-channel multi-speaker separation using deep clustering
          Y. Isik, J. Le Roux, Z. Chen, S. Watanabe, and J. R. Hershey,
        - Conv-TasNet: Surpassing Ideal Time--Frequency Magnitude Masking for Speech Separation
          Luo, Yi and Mesgarani, Nima
          https://arxiv.org/abs/1809.07454

    Notes:
        This function is tested to produce the exact same result as
        https://github.com/naplab/Conv-TasNet/blob/e66d82a8f956a69749ec8a4ae382217faa097c5c/utility/sdr.py#L34-L56
    """
    # All reductions run over the time axis (dim 2). ``keepdim=True`` keeps the
    # per-(batch, speaker) scale broadcastable against the full signal.
    reference_pow = reference.pow(2).mean(dim=2, keepdim=True)
    mix_pow = (estimate * reference).mean(dim=2, keepdim=True)
    scale = mix_pow / (reference_pow + epsilon)

    # Scaled reference and the residual left in the estimate.
    target = scale * reference
    error = estimate - target

    target_pow = target.pow(2).mean(dim=2)
    error_pow = error.pow(2).mean(dim=2)
    return 10 * torch.log10(target_pow) - 10 * torch.log10(error_pow)
class PIT(torch.nn.Module):
    """Applies utterance-level speaker permutation

    Computes the maximum possible value of the given utility function
    over the permutations of the speakers.

    Args:
        utility_func (function):
            Function that computes the utility (opposite of loss) with signature of
            (estimate: torch.Tensor, reference: torch.Tensor) -> torch.Tensor
            where input Tensors are shape of [batch, speakers, frame] and
            the output Tensor is shape of [batch, speakers].

    References:
        - Multi-talker Speech Separation with Utterance-level Permutation Invariant Training of
          Deep Recurrent Neural Networks
          Morten Kolbæk, Dong Yu, Zheng-Hua Tan and Jesper Jensen
          https://arxiv.org/abs/1703.06284
    """

    def __init__(self, utility_func):
        super().__init__()
        self.utility_func = utility_func

    def forward(self, estimate: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
        """Compute the utterance-level PIT utility (best value over speaker permutations)

        Args:
            estimate (torch.Tensor): Estimated source signals.
                Shape: [batch, speakers, time frame]
            reference (torch.Tensor): Reference (original) source signals.
                Shape: [batch, speakers, time frame]

        Returns:
            torch.Tensor: Maximum criterion over the speaker permutation.
                Shape: [batch, ]
        """
        assert estimate.shape == reference.shape, (
            "estimate and reference must have the same shape; "
            f"got {tuple(estimate.shape)} and {tuple(reference.shape)}"
        )

        batch_size, num_speakers = reference.shape[:2]
        num_permute = math.factorial(num_speakers)

        # One column per permutation of the reference speakers.
        util_mat = torch.zeros(
            batch_size, num_permute, dtype=estimate.dtype, device=estimate.device
        )
        for i, idx in enumerate(permutations(range(num_speakers))):
            # Reorder the reference speakers and score against the fixed estimate.
            util = self.utility_func(estimate, reference[:, idx, :])
            util_mat[:, i] = util.mean(dim=1)  # take the average over speaker dimension
        return util_mat.max(dim=1).values
_sdr_pit = PIT(utility_func=sdr)


def sdr_pit(estimate: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
    """Computes scale-invariant source-to-distortion ratio with utterance-level PIT.

    ``sdr`` is evaluated for every permutation of the reference speakers,
    averaged over the speaker dimension, and the best (maximum) value over
    the permutations is returned. See ``PIT`` and ``sdr`` for the details.

    Note that this function does **not** remove the mean of the inputs.
    Callers that need to match the reference implementation have to pass
    zero-mean ``estimate`` and ``reference``.

    Args:
        estimate (torch.Tensor): Estimated signal.
            Shape: [batch, speakers, time frame]
        reference (torch.Tensor): Reference signal.
            Shape: [batch, speakers, time frame]

    Returns:
        torch.Tensor: scale-invariant source-to-distortion ratio of the best
            speaker permutation, averaged over speakers.
            Shape: [batch, ]

    References:
        - Single-channel multi-speaker separation using deep clustering
          Y. Isik, J. Le Roux, Z. Chen, S. Watanabe, and J. R. Hershey,
        - Conv-TasNet: Surpassing Ideal Time--Frequency Magnitude Masking for Speech Separation
          Luo, Yi and Mesgarani, Nima
          https://arxiv.org/abs/1809.07454

    Notes:
        This function is tested to produce the exact same result as the reference implementation,
        *when the inputs have 0-mean*
        https://github.com/naplab/Conv-TasNet/blob/e66d82a8f956a69749ec8a4ae382217faa097c5c/utility/sdr.py#L107-L153
    """
    return _sdr_pit(estimate, reference)
def sdri(estimate: torch.Tensor, reference: torch.Tensor, mix: torch.Tensor) -> torch.Tensor:
    """Compute the improvement of SDR (SDRi).

    SDRi measures how much the SDR improves when the original mixture is
    replaced by the estimated source signals, that is,
    ``SDR(estimate, reference) - SDR(mix, reference)``.

    ``SDR(estimate, reference)`` is computed with PIT (permutation invariant
    training), so the best pairing between the reference signals and the
    estimated signals is used.

    Args:
        estimate (torch.Tensor): Estimated source signals.
            Shape: [batch, speakers, time frame]
        reference (torch.Tensor): Reference (original) source signals.
            Shape: [batch, speakers, time frame]
        mix (torch.Tensor): Mixed source signals, from which the estimated signals were generated.
            Shape: [batch, speakers == 1, time frame]

    Returns:
        torch.Tensor: Improved SDR. Shape: [batch, ]

    References:
        - Conv-TasNet: Surpassing Ideal Time--Frequency Magnitude Masking for Speech Separation
          Luo, Yi and Mesgarani, Nima
          https://arxiv.org/abs/1809.07454
    """
    separated_sdr = sdr_pit(estimate, reference)  # [batch, ]
    mixture_sdr = sdr(mix, reference)  # [batch, speaker]
    # Broadcast the per-utterance score against the per-speaker baseline,
    # then average the gain over speakers.
    gain = separated_sdr.unsqueeze(1) - mixture_sdr
    return gain.mean(dim=1)
import os
import sys
sys.path.append(
os.path.join(
os.path.dirname(__file__),
'..', '..', '..', 'examples'))
from itertools import product
import torch
from torch.testing._internal.common_utils import TestCase
from parameterized import parameterized
from . import sdr_reference
from source_separation.utils import metrics
class TestSDR(TestCase):
    """Compares ``metrics.sdr`` / ``metrics.sdr_pit`` with the reference implementation."""

    @parameterized.expand([(1, ), (2, ), (32, )])
    def test_sdr(self, batch_size):
        """sdr produces the same result as the reference implementation"""
        shape = (batch_size, 256)
        estimation = torch.rand(*shape)
        origin = torch.rand(*shape)

        expected = sdr_reference.calc_sdr_torch(estimation, origin)
        # ``metrics.sdr`` expects a speaker dimension; add it and strip it again.
        actual = metrics.sdr(estimation.unsqueeze(1), origin.unsqueeze(1)).squeeze(1)

        self.assertEqual(actual, expected)

    @parameterized.expand(list(product([1, 2, 32], [2, 3, 4, 5])))
    def test_sdr_pit(self, batch_size, num_sources):
        """sdr_pit produces the same result as the reference implementation"""
        shape = (batch_size, num_sources, 256)
        estimation = torch.randn(*shape)
        origin = torch.randn(*shape)

        # The reference implementation removes the mean internally while
        # ``metrics.sdr_pit`` does not, so feed zero-mean signals to both.
        estimation -= estimation.mean(dim=2, keepdim=True)
        origin -= origin.mean(dim=2, keepdim=True)

        expected = sdr_reference.batch_SDR_torch(estimation, origin)
        actual = metrics.sdr_pit(estimation, origin)

        self.assertEqual(actual, expected)
"""Reference Implementation of SDR and PIT SDR.
This module was taken from the following implementation
https://github.com/naplab/Conv-TasNet/blob/e66d82a8f956a69749ec8a4ae382217faa097c5c/utility/sdr.py
which was made available by Yi Luo under the following license,
Creative Commons Attribution-NonCommercial-ShareAlike 3.0 United States License.
The module was modified in the following manner;
- Remove the functions other than `calc_sdr_torch` and `batch_SDR_torch`,
- Remove the import statements required only for the removed functions.
- Add `# flake8: noqa` so as not to report any format issue on this module.
The implementation of the retained functions and their formats are kept as-is.
"""
# flake8: noqa
import numpy as np
from itertools import permutations
import torch
def calc_sdr_torch(estimation, origin, mask=None):
    """
    batch-wise SDR calculation for one audio file on pytorch Variables.
    estimation: (batch, nsample)
    origin: (batch, nsample)
    mask: optional, (batch, nsample), binary

    Returns the SDR in dB, one value per batch entry: 10*log10 of the power of
    the scaled origin minus 10*log10 of the power of the residual.
    """

    # When a mask is given, it is applied to both signals before any power is computed.
    if mask is not None:
        origin = origin * mask
        estimation = estimation * mask

    # 1e-8 keeps the division below finite when origin is all zeros.
    origin_power = torch.pow(origin, 2).sum(1, keepdim=True) + 1e-8 # (batch, 1)

    # Projection coefficient of the estimation onto origin.
    scale = torch.sum(origin*estimation, 1, keepdim=True) / origin_power # (batch, 1)

    est_true = scale * origin # (batch, nsample)
    est_res = estimation - est_true # (batch, nsample)

    # No keepdim here, so both powers (and the result) are 1-D.
    true_power = torch.pow(est_true, 2).sum(1)
    res_power = torch.pow(est_res, 2).sum(1)

    return 10*torch.log10(true_power) - 10*torch.log10(res_power) # (batch,)
def batch_SDR_torch(estimation, origin, mask=None):
    """
    batch-wise SDR calculation for multiple audio files.
    estimation: (batch, nsource, nsample)
    origin: (batch, nsource, nsample)
    mask: optional, (batch, nsample), binary

    Both inputs are made zero-mean along the sample axis, the SDR of every
    (estimated source, original source) pair is computed with
    `calc_sdr_torch`, and the source permutation with the largest summed SDR
    is selected. Returns that sum divided by the number of sources, one value
    per batch entry.
    """

    batch_size_est, nsource_est, nsample_est = estimation.size()
    batch_size_ori, nsource_ori, nsample_ori = origin.size()

    assert batch_size_est == batch_size_ori, "Estimation and original sources should have same shape."
    assert nsource_est == nsource_ori, "Estimation and original sources should have same shape."
    assert nsample_est == nsample_ori, "Estimation and original sources should have same shape."

    # Guards against inputs whose source and sample axes are swapped.
    assert nsource_est < nsample_est, "Axis 1 should be the number of sources, and axis 2 should be the signal."

    batch_size = batch_size_est
    nsource = nsource_est
    nsample = nsample_est

    # zero mean signals
    estimation = estimation - torch.mean(estimation, 2, keepdim=True).expand_as(estimation)
    origin = origin - torch.mean(origin, 2, keepdim=True).expand_as(estimation)

    # possible permutations
    perm = list(set(permutations(np.arange(nsource))))

    # pair-wise SDR
    # SDR[:, i, j] holds the SDR of estimated source i against original source j.
    SDR = torch.zeros((batch_size, nsource, nsource)).type(estimation.type())
    for i in range(nsource):
        for j in range(nsource):
            SDR[:,i,j] = calc_sdr_torch(estimation[:,i], origin[:,j], mask)

    # choose the best permutation
    SDR_max = []  # placeholder; rebound by torch.max below
    SDR_perm = []
    for permute in perm:
        sdr = []
        for idx in range(len(permute)):
            # estimated source idx is paired with original source permute[idx]
            sdr.append(SDR[:,idx,permute[idx]].view(batch_size,-1))
        # total SDR of this assignment, summed over the sources
        sdr = torch.sum(torch.cat(sdr, 1), 1)
        SDR_perm.append(sdr.view(batch_size, 1))
    SDR_perm = torch.cat(SDR_perm, 1)  # (batch, number of permutations)
    SDR_max, _ = torch.max(SDR_perm, dim=1)

    # average over the sources
    return SDR_max / nsource
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment