Commit 06301c0a authored by Jeff Hwang's avatar Jeff Hwang Committed by Facebook GitHub Bot
Browse files

Add Frechet distance function (#3545)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3545

Adds function for computing the Fréchet distance between two multivariate normal distributions.

Reviewed By: mthrok

Differential Revision: D48126102

fbshipit-source-id: e4e122b831e1e752037c03f5baa9451e81ef1697
parent 8d858c38
......@@ -32,6 +32,7 @@ Utility
preemphasis
deemphasis
speed
frechet_distance
Forced Alignment
----------------
......
......@@ -579,3 +579,14 @@ booktitle = {International Conference on Acoustics, Speech and Signal Processing
archivePrefix={arXiv},
primaryClass={cs.CL}
}
@article{dowson1982frechet,
title={The Fr{\'e}chet distance between multivariate normal distributions},
author={Dowson, DC and Landau, BV},
journal={Journal of multivariate analysis},
volume={12},
number={3},
pages={450--455},
year={1982},
publisher={Elsevier}
}
......@@ -383,6 +383,14 @@ class Autograd(TestBaseMixin):
coeff = 0.9
self.assert_grad(F.deemphasis, (waveform, coeff))
def test_frechet_distance(self):
    """Run the gradient check on ``F.frechet_distance`` with random 16-dimensional inputs."""
    dim = 16
    # Draw order (mean, covariance, mean, covariance) is kept so RNG consumption is unchanged.
    mean_a = torch.rand((dim,))
    cov_a = torch.rand((dim, dim))
    mean_b = torch.rand((dim,))
    cov_b = torch.rand((dim, dim))
    inputs = (mean_a, cov_a, mean_b, cov_b)
    self.assert_grad(F.frechet_distance, inputs)
class AutogradFloat32(TestBaseMixin):
def assert_grad(
......
......@@ -1282,6 +1282,38 @@ class Functional(TestBaseMixin):
spans = F.merge_tokens(tokens_, scores_, blank=0)
self._assert_tokens(spans, expected_)
def test_frechet_distance_univariate(self):
    r"""Verify the Frechet distance against the closed form for one-dimensional Gaussians."""
    dev = self.device
    mean_a = torch.rand((1,), device=dev)
    var_a = torch.rand((1, 1), device=dev)
    mean_b = torch.rand((1,), device=dev)
    var_b = torch.rand((1, 1), device=dev)

    # With 1x1 covariances the matrix square root is an ordinary scalar square root.
    cross_term = torch.sqrt(var_a * var_b)
    reference = ((mean_a - mean_b) ** 2 + var_a + var_b - 2 * cross_term).item()

    result = F.frechet_distance(mean_a, var_a, mean_b, var_b)
    self.assertEqual(reference, result)
def test_frechet_distance_diagonal_covariance(self):
    r"""Verify the Frechet distance against the closed form when both covariance matrices are diagonal."""
    dim = 15
    dev = self.device
    mean_a = torch.rand((dim,), device=dev)
    cov_a = torch.diag(torch.rand((dim,), device=dev))
    mean_b = torch.rand((dim,), device=dev)
    cov_b = torch.diag(torch.rand((dim,), device=dev))

    # Off-diagonal entries are zero, so element-wise sums over the full matrices
    # only pick up the diagonal contributions.
    mean_term = torch.sum((mean_a - mean_b) ** 2)
    trace_term = torch.sum(cov_a + cov_b)
    cross_term = torch.sum(torch.sqrt(cov_a * cov_b))
    reference = (mean_term + trace_term - 2 * cross_term).item()

    result = F.frechet_distance(mean_a, cov_a, mean_b, cov_b)
    self.assertEqual(reference, result)
class FunctionalCPUOnly(TestBaseMixin):
def test_melscale_fbanks_no_warning_high_n_freq(self):
......
......@@ -36,6 +36,7 @@ from .functional import (
detect_pitch_frequency,
edit_distance,
fftconvolve,
frechet_distance,
griffinlim,
inverse_spectrogram,
linear_fbanks,
......@@ -122,4 +123,5 @@ __all__ = [
"speed",
"preemphasis",
"deemphasis",
"frechet_distance",
]
......@@ -2499,3 +2499,41 @@ def deemphasis(waveform, coeff: float = 0.97) -> torch.Tensor:
a_coeffs = torch.tensor([1.0, -coeff], dtype=waveform.dtype, device=waveform.device)
b_coeffs = torch.tensor([1.0, 0.0], dtype=waveform.dtype, device=waveform.device)
return torchaudio.functional.lfilter(waveform, a_coeffs=a_coeffs, b_coeffs=b_coeffs)
def frechet_distance(mu_x, sigma_x, mu_y, sigma_y):
    r"""Computes the Fréchet distance between two multivariate normal distributions :cite:`dowson1982frechet`.

    Concretely, for multivariate Gaussians :math:`X(\mu_X, \Sigma_X)`
    and :math:`Y(\mu_Y, \Sigma_Y)`, the function computes and returns :math:`F` as

    .. math::
        F(X, Y) = || \mu_X - \mu_Y ||_2^2
        + \text{Tr}\left( \Sigma_X + \Sigma_Y - 2 \sqrt{\Sigma_X \Sigma_Y} \right)

    Args:
        mu_x (torch.Tensor): mean :math:`\mu_X` of multivariate Gaussian :math:`X`, with shape `(N,)`.
        sigma_x (torch.Tensor): covariance matrix :math:`\Sigma_X` of :math:`X`, with shape `(N, N)`.
        mu_y (torch.Tensor): mean :math:`\mu_Y` of multivariate Gaussian :math:`Y`, with shape `(N,)`.
        sigma_y (torch.Tensor): covariance matrix :math:`\Sigma_Y` of :math:`Y`, with shape `(N, N)`.

    Returns:
        torch.Tensor: the Fréchet distance between :math:`X` and :math:`Y`.

    Raises:
        ValueError: if ``mu_x`` is not one-dimensional, if ``sigma_x`` is not two-dimensional,
            if ``sigma_x`` is not square with side length equal to ``mu_x``'s size, or if the
            shapes of ``mu_x``/``mu_y`` or ``sigma_x``/``sigma_y`` differ.
    """
    if len(mu_x.size()) != 1:
        raise ValueError(f"Input mu_x must be one-dimensional; got dimension {len(mu_x.size())}.")
    if len(sigma_x.size()) != 2:
        raise ValueError(f"Input sigma_x must be two-dimensional; got dimension {len(sigma_x.size())}.")
    # Both dimensions of sigma_x must equal N. A chained ``a != b != c`` means
    # ``a != b and b != c`` and would let mismatched shapes through, so the two
    # comparisons are joined with ``or``.
    if sigma_x.size(0) != sigma_x.size(1) or sigma_x.size(1) != mu_x.size(0):
        raise ValueError("Each of sigma_x's dimensions must match mu_x's size.")
    if mu_x.size() != mu_y.size():
        raise ValueError(f"Inputs mu_x and mu_y must have the same shape; got {mu_x.size()} and {mu_y.size()}.")
    if sigma_x.size() != sigma_y.size():
        raise ValueError(
            f"Inputs sigma_x and sigma_y must have the same shape; got {sigma_x.size()} and {sigma_y.size()}."
        )

    # Squared Euclidean distance between the means.
    a = (mu_x - mu_y).square().sum()
    # Tr(Sigma_X + Sigma_Y).
    b = sigma_x.trace() + sigma_y.trace()
    # Tr(sqrt(Sigma_X Sigma_Y)), computed as the sum of the square roots of the
    # eigenvalues of the product; eigvals yields complex values, so keep the real part.
    c = torch.linalg.eigvals(sigma_x @ sigma_y).sqrt().real.sum()
    return a + b - 2 * c
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment