Commit f3bb30b8 authored by hwangjeff's avatar hwangjeff Committed by Facebook GitHub Bot
Browse files

Add additive noise function (#2608)

Summary:
Adds function `add_noise`, which computes and returns the sum of a waveform and scaled noise.

Pull Request resolved: https://github.com/pytorch/audio/pull/2608

Reviewed By: nateanl

Differential Revision: D38557141

Pulled By: hwangjeff

fbshipit-source-id: 1457fa213f43ca5b4333d3c7580971655d4260a0
parent cd4d6607
...@@ -4,6 +4,11 @@ torchaudio.prototype.functional ...@@ -4,6 +4,11 @@ torchaudio.prototype.functional
.. py:module:: torchaudio.prototype.functional .. py:module:: torchaudio.prototype.functional
.. currentmodule:: torchaudio.prototype.functional .. currentmodule:: torchaudio.prototype.functional
add_noise
~~~~~~~~~
.. autofunction:: add_noise
convolve convolve
~~~~~~~~ ~~~~~~~~
......
...@@ -19,3 +19,15 @@ class AutogradTestImpl(TestBaseMixin): ...@@ -19,3 +19,15 @@ class AutogradTestImpl(TestBaseMixin):
y = torch.rand(*leading_dims, L_y, dtype=self.dtype, device=self.device, requires_grad=True) y = torch.rand(*leading_dims, L_y, dtype=self.dtype, device=self.device, requires_grad=True)
self.assertTrue(gradcheck(fn, (x, y))) self.assertTrue(gradcheck(fn, (x, y)))
self.assertTrue(gradgradcheck(fn, (x, y))) self.assertTrue(gradgradcheck(fn, (x, y)))
def test_add_noise(self):
leading_dims = (5, 2, 3)
L = 51
waveform = torch.rand(*leading_dims, L, dtype=self.dtype, device=self.device, requires_grad=True)
noise = torch.rand(*leading_dims, L, dtype=self.dtype, device=self.device, requires_grad=True)
lengths = torch.rand(*leading_dims, dtype=self.dtype, device=self.device, requires_grad=True)
snr = torch.rand(*leading_dims, dtype=self.dtype, device=self.device, requires_grad=True) * 10
self.assertTrue(gradcheck(F.add_noise, (waveform, noise, lengths, snr)))
self.assertTrue(gradgradcheck(F.add_noise, (waveform, noise, lengths, snr)))
...@@ -22,3 +22,22 @@ class BatchConsistencyTest(TorchaudioTestCase): ...@@ -22,3 +22,22 @@ class BatchConsistencyTest(TorchaudioTestCase):
) )
self.assertEqual(expected, actual) self.assertEqual(expected, actual)
def test_add_noise(self):
leading_dims = (5, 2, 3)
L = 51
waveform = torch.rand(*leading_dims, L, dtype=self.dtype, device=self.device)
noise = torch.rand(*leading_dims, L, dtype=self.dtype, device=self.device)
lengths = torch.rand(*leading_dims, dtype=self.dtype, device=self.device)
snr = torch.rand(*leading_dims, dtype=self.dtype, device=self.device) * 10
actual = F.add_noise(waveform, noise, lengths, snr)
expected = []
for i in range(leading_dims[0]):
for j in range(leading_dims[1]):
for k in range(leading_dims[2]):
expected.append(F.add_noise(waveform[i][j][k], noise[i][j][k], lengths[i][j][k], snr[i][j][k]))
self.assertEqual(torch.stack(expected), actual.reshape(-1, L))
import numpy as np import numpy as np
import torch import torch
import torchaudio.prototype.functional as F import torchaudio.prototype.functional as F
from parameterized import parameterized
from scipy import signal from scipy import signal
from torchaudio_unittest.common_utils import nested_params, TestBaseMixin from torchaudio_unittest.common_utils import nested_params, TestBaseMixin
...@@ -60,3 +61,49 @@ class FunctionalTestImpl(TestBaseMixin): ...@@ -60,3 +61,49 @@ class FunctionalTestImpl(TestBaseMixin):
y = torch.rand(*y_shape, dtype=self.dtype, device=self.device) y = torch.rand(*y_shape, dtype=self.dtype, device=self.device)
with self.assertRaisesRegex(ValueError, "Leading dimensions"): with self.assertRaisesRegex(ValueError, "Leading dimensions"):
fn(x, y) fn(x, y)
def test_add_noise_broadcast(self):
"""Check that add_noise produces correct outputs when broadcasting input dimensions."""
leading_dims = (5, 2, 3)
L = 51
waveform = torch.rand(*leading_dims, L, dtype=self.dtype, device=self.device)
noise = torch.rand(5, 1, 1, L, dtype=self.dtype, device=self.device)
lengths = torch.rand(5, 1, 3, dtype=self.dtype, device=self.device)
snr = torch.rand(1, 1, 1, dtype=self.dtype, device=self.device) * 10
actual = F.add_noise(waveform, noise, lengths, snr)
noise_expanded = noise.expand(*leading_dims, L)
snr_expanded = snr.expand(*leading_dims)
lengths_expanded = lengths.expand(*leading_dims)
expected = F.add_noise(waveform, noise_expanded, lengths_expanded, snr_expanded)
self.assertEqual(expected, actual)
@parameterized.expand(
[((5, 2, 3), (2, 1, 1), (5, 2), (5, 2, 3)), ((2, 1), (5,), (5,), (5,)), ((3,), (5, 2, 3), (2, 1, 1), (5, 2))]
)
def test_add_noise_leading_dim_check(self, waveform_dims, noise_dims, lengths_dims, snr_dims):
"""Check that add_noise properly rejects inputs with different leading dimension lengths."""
L = 51
waveform = torch.rand(*waveform_dims, L, dtype=self.dtype, device=self.device)
noise = torch.rand(*noise_dims, L, dtype=self.dtype, device=self.device)
lengths = torch.rand(*lengths_dims, dtype=self.dtype, device=self.device)
snr = torch.rand(*snr_dims, dtype=self.dtype, device=self.device) * 10
with self.assertRaisesRegex(ValueError, "Input leading dimensions"):
F.add_noise(waveform, noise, lengths, snr)
def test_add_noise_length_check(self):
"""Check that add_noise properly rejects inputs that have inconsistent length dimensions."""
leading_dims = (5, 2, 3)
L = 51
waveform = torch.rand(*leading_dims, L, dtype=self.dtype, device=self.device)
noise = torch.rand(*leading_dims, 50, dtype=self.dtype, device=self.device)
lengths = torch.rand(*leading_dims, dtype=self.dtype, device=self.device)
snr = torch.rand(*leading_dims, dtype=self.dtype, device=self.device) * 10
with self.assertRaisesRegex(ValueError, "Length dimensions"):
F.add_noise(waveform, noise, lengths, snr)
...@@ -37,3 +37,14 @@ class TorchScriptConsistencyTestImpl(TestBaseMixin): ...@@ -37,3 +37,14 @@ class TorchScriptConsistencyTestImpl(TestBaseMixin):
y = torch.rand(*leading_dims, L_y, dtype=self.dtype, device=self.device) y = torch.rand(*leading_dims, L_y, dtype=self.dtype, device=self.device)
self._assert_consistency(fn, (x, y)) self._assert_consistency(fn, (x, y))
def test_add_noise(self):
leading_dims = (2, 3)
L = 31
waveform = torch.rand(*leading_dims, L, dtype=self.dtype, device=self.device, requires_grad=True)
noise = torch.rand(*leading_dims, L, dtype=self.dtype, device=self.device, requires_grad=True)
lengths = torch.rand(*leading_dims, dtype=self.dtype, device=self.device, requires_grad=True)
snr = torch.rand(*leading_dims, dtype=self.dtype, device=self.device, requires_grad=True) * 10
self._assert_consistency(F.add_noise, (waveform, noise, lengths, snr))
from .functional import convolve, fftconvolve from .functional import add_noise, convolve, fftconvolve
__all__ = ["convolve", "fftconvolve"] __all__ = ["add_noise", "convolve", "fftconvolve"]
...@@ -19,12 +19,12 @@ def fftconvolve(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: ...@@ -19,12 +19,12 @@ def fftconvolve(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
.. properties:: Autograd TorchScript .. properties:: Autograd TorchScript
Args: Args:
x (torch.Tensor): First convolution operand, with shape `(*, N)`. x (torch.Tensor): First convolution operand, with shape `(..., N)`.
y (torch.Tensor): Second convolution operand, with shape `(*, M)` y (torch.Tensor): Second convolution operand, with shape `(..., M)`
(leading dimensions must match those of ``x``). (leading dimensions must match those of ``x``).
Returns: Returns:
torch.Tensor: Result of convolving ``x`` and ``y``, with shape `(*, N + M - 1)`, where torch.Tensor: Result of convolving ``x`` and ``y``, with shape `(..., N + M - 1)`, where
the leading dimensions match those of ``x``. the leading dimensions match those of ``x``.
.. _convolution: .. _convolution:
...@@ -48,12 +48,12 @@ def convolve(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: ...@@ -48,12 +48,12 @@ def convolve(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
.. properties:: Autograd TorchScript .. properties:: Autograd TorchScript
Args: Args:
x (torch.Tensor): First convolution operand, with shape `(*, N)`. x (torch.Tensor): First convolution operand, with shape `(..., N)`.
y (torch.Tensor): Second convolution operand, with shape `(*, M)` y (torch.Tensor): Second convolution operand, with shape `(..., M)`
(leading dimensions must match those of ``x``). (leading dimensions must match those of ``x``).
Returns: Returns:
torch.Tensor: Result of convolving ``x`` and ``y``, with shape `(*, N + M - 1)`, where torch.Tensor: Result of convolving ``x`` and ``y``, with shape `(..., N + M - 1)`, where
the leading dimensions match those of ``x``. the leading dimensions match those of ``x``.
.. _convolution: .. _convolution:
...@@ -76,3 +76,61 @@ def convolve(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: ...@@ -76,3 +76,61 @@ def convolve(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
) )
output_shape = x.shape[:-1] + (-1,) output_shape = x.shape[:-1] + (-1,)
return output.reshape(output_shape) return output.reshape(output_shape)
def add_noise(waveform: torch.Tensor, noise: torch.Tensor, lengths: torch.Tensor, snr: torch.Tensor) -> torch.Tensor:
r"""Scales and adds noise to waveform per signal-to-noise ratio.
Specifically, for each pair of waveform vector :math:`x \in \mathbb{R}^L` and noise vector
:math:`n \in \mathbb{R}^L`, the function computes output :math:`y` as
.. math::
y = x + a n \, \text{,}
where
.. math::
a = \sqrt{ \frac{ ||x||_{2}^{2} }{ ||n||_{2}^{2} } \cdot 10^{-\frac{\text{SNR}}{10}} } \, \text{,}
with :math:`\text{SNR}` being the desired signal-to-noise ratio between :math:`x` and :math:`n`, in dB.
Note that this function broadcasts singleton leading dimensions in its inputs in a manner that is
consistent with the above formulae and PyTorch's broadcasting semantics.
.. devices:: CPU CUDA
.. properties:: Autograd TorchScript
Args:
waveform (torch.Tensor): Input waveform, with shape `(..., L)`.
noise (torch.Tensor): Noise, with shape `(..., L)` (same shape as ``waveform``).
lengths (torch.Tensor): Valid lengths of signals in ``waveform`` and ``noise``, with shape `(...,)`
(leading dimensions must match those of ``waveform``).
snr (torch.Tensor): Signal-to-noise ratios in dB, with shape `(...,)`.
Returns:
torch.Tensor: Result of scaling and adding ``noise`` to ``waveform``, with shape `(..., L)`
(same shape as ``waveform``).
"""
if not (waveform.ndim - 1 == noise.ndim - 1 == lengths.ndim == snr.ndim):
raise ValueError("Input leading dimensions don't match.")
L = waveform.size(-1)
if L != noise.size(-1):
raise ValueError(f"Length dimensions of waveform and noise don't match (got {L} and {noise.size(-1)}).")
# compute scale
mask = torch.arange(0, L, device=lengths.device).expand(waveform.shape) < lengths.unsqueeze(
-1
) # (*, L) < (*, 1) = (*, L)
energy_signal = torch.linalg.vector_norm(waveform * mask, ord=2, dim=-1) ** 2 # (*,)
energy_noise = torch.linalg.vector_norm(noise * mask, ord=2, dim=-1) ** 2 # (*,)
original_snr_db = 10 * (torch.log10(energy_signal) - torch.log10(energy_noise))
scale = 10 ** ((original_snr_db - snr) / 20.0) # (*,)
# scale noise
scaled_noise = scale.unsqueeze(-1) * noise # (*, 1) * (*, L) = (*, L)
return waveform + scaled_noise # (*, L)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment