Commit bb077284 authored by hwangjeff, committed by Facebook GitHub Bot

Make lengths optional for additive noise operators (#2977)

Summary:
For greater flexibility, this PR makes the `lengths` argument optional for `add_noise` and `AddNoise`. Because `lengths` now has a default of `None`, it moves after `snr` in the argument order; when it is omitted, all samples in `waveform` and `noise` are treated as valid.
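A quick sketch of the resulting call patterns. The import path below is an assumption for illustration (at the time of this change `add_noise` lives in the prototype functional namespace); adjust it to wherever the function is exposed in your build:

```python
import torch
from torchaudio.prototype.functional import add_noise  # path assumed

waveform = torch.rand(2, 3, 31)
noise = torch.rand(2, 3, 31)
snr = torch.rand(2, 3) * 10  # dB, one value per leading index

# New argument order: snr before lengths.
noisy_full = add_noise(waveform, noise, snr)  # lengths omitted: whole signals are valid

lengths = torch.full((2, 3), 16.0)
noisy_masked = add_noise(waveform, noise, snr, lengths)  # only the first 16 samples count toward the SNR
```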

Pull Request resolved: https://github.com/pytorch/audio/pull/2977

Reviewed By: nateanl

Differential Revision: D42484211

Pulled By: hwangjeff

fbshipit-source-id: 54757dcc73df194bb98c1d9d42a2f43f3027b190
parent 51731bf9
@@ -27,8 +27,8 @@ class AutogradTestImpl(TestBaseMixin):
         lengths = torch.rand(*leading_dims, dtype=self.dtype, device=self.device, requires_grad=True)
         snr = torch.rand(*leading_dims, dtype=self.dtype, device=self.device, requires_grad=True) * 10
-        self.assertTrue(gradcheck(F.add_noise, (waveform, noise, lengths, snr)))
-        self.assertTrue(gradgradcheck(F.add_noise, (waveform, noise, lengths, snr)))
+        self.assertTrue(gradcheck(F.add_noise, (waveform, noise, snr, lengths)))
+        self.assertTrue(gradgradcheck(F.add_noise, (waveform, noise, snr, lengths)))

     @parameterized.expand(
         [
@@ -35,13 +35,13 @@ class BatchConsistencyTest(TorchaudioTestCase):
         lengths = torch.rand(*leading_dims, dtype=self.dtype, device=self.device)
         snr = torch.rand(*leading_dims, dtype=self.dtype, device=self.device) * 10

-        actual = F.add_noise(waveform, noise, lengths, snr)
+        actual = F.add_noise(waveform, noise, snr, lengths)

         expected = []
         for i in range(leading_dims[0]):
             for j in range(leading_dims[1]):
                 for k in range(leading_dims[2]):
-                    expected.append(F.add_noise(waveform[i][j][k], noise[i][j][k], lengths[i][j][k], snr[i][j][k]))
+                    expected.append(F.add_noise(waveform[i][j][k], noise[i][j][k], snr[i][j][k], lengths[i][j][k]))

         self.assertEqual(torch.stack(expected), actual.reshape(-1, L))
@@ -135,12 +135,12 @@ class FunctionalTestImpl(TestBaseMixin):
         noise = torch.rand(5, 1, 1, L, dtype=self.dtype, device=self.device)
         lengths = torch.rand(5, 1, 3, dtype=self.dtype, device=self.device)
         snr = torch.rand(1, 1, 1, dtype=self.dtype, device=self.device) * 10

-        actual = F.add_noise(waveform, noise, lengths, snr)
+        actual = F.add_noise(waveform, noise, snr, lengths)

         noise_expanded = noise.expand(*leading_dims, L)
         snr_expanded = snr.expand(*leading_dims)
         lengths_expanded = lengths.expand(*leading_dims)
-        expected = F.add_noise(waveform, noise_expanded, lengths_expanded, snr_expanded)
+        expected = F.add_noise(waveform, noise_expanded, snr_expanded, lengths_expanded)

         self.assertEqual(expected, actual)
@@ -157,7 +157,7 @@ class FunctionalTestImpl(TestBaseMixin):
         snr = torch.rand(*snr_dims, dtype=self.dtype, device=self.device) * 10

         with self.assertRaisesRegex(ValueError, "Input leading dimensions"):
-            F.add_noise(waveform, noise, lengths, snr)
+            F.add_noise(waveform, noise, snr, lengths)

     def test_add_noise_length_check(self):
         """Check that add_noise properly rejects inputs that have inconsistent length dimensions."""
@@ -170,7 +170,7 @@ class FunctionalTestImpl(TestBaseMixin):
         snr = torch.rand(*leading_dims, dtype=self.dtype, device=self.device) * 10

         with self.assertRaisesRegex(ValueError, "Length dimensions"):
-            F.add_noise(waveform, noise, lengths, snr)
+            F.add_noise(waveform, noise, snr, lengths)

     @nested_params(
         [(2, 3), (2, 3, 5), (2, 3, 5, 7)],
@@ -37,16 +37,20 @@ class TorchScriptConsistencyTestImpl(TestBaseMixin):
         self._assert_consistency(getattr(F, fn), (x, y, mode))

-    def test_add_noise(self):
+    @nested_params([True, False])
+    def test_add_noise(self, use_lengths):
         leading_dims = (2, 3)
         L = 31

         waveform = torch.rand(*leading_dims, L, dtype=self.dtype, device=self.device, requires_grad=True)
         noise = torch.rand(*leading_dims, L, dtype=self.dtype, device=self.device, requires_grad=True)
-        lengths = torch.rand(*leading_dims, dtype=self.dtype, device=self.device, requires_grad=True)
+        if use_lengths:
+            lengths = torch.rand(*leading_dims, dtype=self.dtype, device=self.device, requires_grad=True)
+        else:
+            lengths = None
         snr = torch.rand(*leading_dims, dtype=self.dtype, device=self.device, requires_grad=True) * 10

-        self._assert_consistency(F.add_noise, (waveform, noise, lengths, snr))
+        self._assert_consistency(F.add_noise, (waveform, noise, snr, lengths))

     def test_barkscale_fbanks(self):
         if self.device != torch.device("cpu"):
@@ -75,18 +75,22 @@ class Autograd(TestBaseMixin):
         assert gradcheck(speed, (waveform, lengths))
         assert gradgradcheck(speed, (waveform, lengths))

-    def test_AddNoise(self):
+    @nested_params([True, False])
+    def test_AddNoise(self, use_lengths):
         leading_dims = (2, 3)
         L = 31

         waveform = torch.rand(*leading_dims, L, dtype=torch.float64, device=self.device, requires_grad=True)
         noise = torch.rand(*leading_dims, L, dtype=torch.float64, device=self.device, requires_grad=True)
-        lengths = torch.rand(*leading_dims, dtype=torch.float64, device=self.device, requires_grad=True)
+        if use_lengths:
+            lengths = torch.rand(*leading_dims, dtype=torch.float64, device=self.device, requires_grad=True)
+        else:
+            lengths = None
         snr = torch.rand(*leading_dims, dtype=torch.float64, device=self.device, requires_grad=True) * 10

         add_noise = T.AddNoise().to(self.device, torch.float64)
-        assert gradcheck(add_noise, (waveform, noise, lengths, snr))
-        assert gradgradcheck(add_noise, (waveform, noise, lengths, snr))
+        assert gradcheck(add_noise, (waveform, noise, snr, lengths))
+        assert gradgradcheck(add_noise, (waveform, noise, snr, lengths))

     def test_Preemphasis(self):
         waveform = torch.rand(3, 4, 10, dtype=torch.float64, device=self.device, requires_grad=True)
@@ -124,13 +124,13 @@ class BatchConsistencyTest(TorchaudioTestCase):
         snr = torch.rand(*leading_dims, dtype=self.dtype, device=self.device) * 10

         add_noise = T.AddNoise()
-        actual = add_noise(waveform, noise, lengths, snr)
+        actual = add_noise(waveform, noise, snr, lengths)

         expected = []
         for i in range(leading_dims[0]):
             for j in range(leading_dims[1]):
                 for k in range(leading_dims[2]):
-                    expected.append(add_noise(waveform[i][j][k], noise[i][j][k], lengths[i][j][k], snr[i][j][k]))
+                    expected.append(add_noise(waveform[i][j][k], noise[i][j][k], snr[i][j][k], lengths[i][j][k]))

         self.assertEqual(torch.stack(expected), actual.reshape(-1, L))
@@ -41,18 +41,22 @@ class Transforms(TestBaseMixin):
         ts_output = torch_script(speed)(waveform, lengths)
         self.assertEqual(ts_output, output)

-    def test_AddNoise(self):
+    @nested_params([True, False])
+    def test_AddNoise(self, use_lengths):
         leading_dims = (2, 3)
         L = 31

         waveform = torch.rand(*leading_dims, L, dtype=self.dtype, device=self.device, requires_grad=True)
         noise = torch.rand(*leading_dims, L, dtype=self.dtype, device=self.device, requires_grad=True)
-        lengths = torch.rand(*leading_dims, dtype=self.dtype, device=self.device, requires_grad=True)
+        if use_lengths:
+            lengths = torch.rand(*leading_dims, dtype=self.dtype, device=self.device, requires_grad=True)
+        else:
+            lengths = None
         snr = torch.rand(*leading_dims, dtype=self.dtype, device=self.device, requires_grad=True) * 10

         add_noise = T.AddNoise().to(self.device, self.dtype)
-        output = add_noise(waveform, noise, lengths, snr)
-        ts_output = torch_script(add_noise)(waveform, noise, lengths, snr)
+        output = add_noise(waveform, noise, snr, lengths)
+        ts_output = torch_script(add_noise)(waveform, noise, snr, lengths)
         self.assertEqual(ts_output, output)

     def test_Preemphasis(self):
@@ -184,12 +184,12 @@ class TransformsTestImpl(TestBaseMixin):
         snr = torch.rand(1, 1, 1, dtype=self.dtype, device=self.device) * 10

         add_noise = T.AddNoise()
-        actual = add_noise(waveform, noise, lengths, snr)
+        actual = add_noise(waveform, noise, snr, lengths)

         noise_expanded = noise.expand(*leading_dims, L)
         snr_expanded = snr.expand(*leading_dims)
         lengths_expanded = lengths.expand(*leading_dims)
-        expected = add_noise(waveform, noise_expanded, lengths_expanded, snr_expanded)
+        expected = add_noise(waveform, noise_expanded, snr_expanded, lengths_expanded)

         self.assertEqual(expected, actual)
@@ -208,7 +208,7 @@ class TransformsTestImpl(TestBaseMixin):
         add_noise = T.AddNoise()
         with self.assertRaisesRegex(ValueError, "Input leading dimensions"):
-            add_noise(waveform, noise, lengths, snr)
+            add_noise(waveform, noise, snr, lengths)

     def test_AddNoise_length_check(self):
         """Check that add_noise properly rejects inputs that have inconsistent length dimensions."""
@@ -223,7 +223,7 @@ class TransformsTestImpl(TestBaseMixin):
         add_noise = T.AddNoise()
         with self.assertRaisesRegex(ValueError, "Length dimensions"):
-            add_noise(waveform, noise, lengths, snr)
+            add_noise(waveform, noise, snr, lengths)

     @nested_params(
         [(2, 1, 31)],
 import math
 import warnings
-from typing import Tuple
+from typing import Optional, Tuple

 import torch
 from torchaudio.functional import lfilter, resample
@@ -134,7 +134,9 @@ def convolve(x: torch.Tensor, y: torch.Tensor, mode: str = "full") -> torch.Tens
     return _apply_convolve_mode(result, x_size, y_size, mode)

-def add_noise(waveform: torch.Tensor, noise: torch.Tensor, lengths: torch.Tensor, snr: torch.Tensor) -> torch.Tensor:
+def add_noise(
+    waveform: torch.Tensor, noise: torch.Tensor, snr: torch.Tensor, lengths: Optional[torch.Tensor] = None
+) -> torch.Tensor:
     r"""Scales and adds noise to waveform per signal-to-noise ratio.

     Specifically, for each pair of waveform vector :math:`x \in \mathbb{R}^L` and noise vector
@@ -160,16 +162,17 @@ def add_noise(waveform: torch.Tensor, noise: torch.Tensor, lengths: torch.Tensor
     Args:
         waveform (torch.Tensor): Input waveform, with shape `(..., L)`.
         noise (torch.Tensor): Noise, with shape `(..., L)` (same shape as ``waveform``).
-        lengths (torch.Tensor): Valid lengths of signals in ``waveform`` and ``noise``, with shape `(...,)`
-            (leading dimensions must match those of ``waveform``).
         snr (torch.Tensor): Signal-to-noise ratios in dB, with shape `(...,)`.
+        lengths (torch.Tensor or None, optional): Valid lengths of signals in ``waveform`` and ``noise``, with shape
+            `(...,)` (leading dimensions must match those of ``waveform``). If ``None``, all elements in ``waveform``
+            and ``noise`` are treated as valid. (Default: ``None``)

     Returns:
         torch.Tensor: Result of scaling and adding ``noise`` to ``waveform``, with shape `(..., L)`
         (same shape as ``waveform``).
     """
-    if not (waveform.ndim - 1 == noise.ndim - 1 == lengths.ndim == snr.ndim):
+    if not (waveform.ndim - 1 == noise.ndim - 1 == snr.ndim and (lengths is None or lengths.ndim == snr.ndim)):
         raise ValueError("Input leading dimensions don't match.")

     L = waveform.size(-1)
@@ -178,11 +181,18 @@ def add_noise(waveform: torch.Tensor, noise: torch.Tensor, lengths: torch.Tensor
         raise ValueError(f"Length dimensions of waveform and noise don't match (got {L} and {noise.size(-1)}).")

     # compute scale
-    mask = torch.arange(0, L, device=lengths.device).expand(waveform.shape) < lengths.unsqueeze(
-        -1
-    )  # (*, L) < (*, 1) = (*, L)
-    energy_signal = torch.linalg.vector_norm(waveform * mask, ord=2, dim=-1) ** 2  # (*,)
-    energy_noise = torch.linalg.vector_norm(noise * mask, ord=2, dim=-1) ** 2  # (*,)
+    if lengths is not None:
+        mask = torch.arange(0, L, device=lengths.device).expand(waveform.shape) < lengths.unsqueeze(
+            -1
+        )  # (*, L) < (*, 1) = (*, L)
+        masked_waveform = waveform * mask
+        masked_noise = noise * mask
+    else:
+        masked_waveform = waveform
+        masked_noise = noise
+
+    energy_signal = torch.linalg.vector_norm(masked_waveform, ord=2, dim=-1) ** 2  # (*,)
+    energy_noise = torch.linalg.vector_norm(masked_noise, ord=2, dim=-1) ** 2  # (*,)
     original_snr_db = 10 * (torch.log10(energy_signal) - torch.log10(energy_noise))
     scale = 10 ** ((original_snr_db - snr) / 20.0)  # (*,)
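For reference, the scale factor above follows from standard dB algebra: with signal energy E_s and noise energy E_n, the original SNR is 10 * log10(E_s / E_n) dB, and multiplying the noise amplitude by 10 ** ((original_snr_db - snr) / 20) shifts it to the target SNR. A standalone numeric sketch (not library code), assuming all samples are valid:

```python
import torch

waveform = torch.rand(1000)
noise = torch.rand(1000)

# Energies of the full (unmasked) signals.
e_s = torch.linalg.vector_norm(waveform, ord=2) ** 2
e_n = torch.linalg.vector_norm(noise, ord=2) ** 2

snr_target = torch.tensor(5.0)  # target SNR in dB
original_snr_db = 10 * (torch.log10(e_s) - torch.log10(e_n))
scale = 10 ** ((original_snr_db - snr_target) / 20.0)

# Scaling the noise amplitude scales its energy by scale ** 2,
# so the achieved SNR lands exactly on the target.
achieved = 10 * torch.log10(e_s / (torch.linalg.vector_norm(scale * noise, ord=2) ** 2))
print(torch.isclose(achieved, snr_target))  # tensor(True)
```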
@@ -495,21 +495,22 @@ class AddNoise(torch.nn.Module):
     """

     def forward(
-        self, waveform: torch.Tensor, noise: torch.Tensor, lengths: torch.Tensor, snr: torch.Tensor
+        self, waveform: torch.Tensor, noise: torch.Tensor, snr: torch.Tensor, lengths: Optional[torch.Tensor] = None
     ) -> torch.Tensor:
         r"""
         Args:
             waveform (torch.Tensor): Input waveform, with shape `(..., L)`.
             noise (torch.Tensor): Noise, with shape `(..., L)` (same shape as ``waveform``).
-            lengths (torch.Tensor): Valid lengths of signals in ``waveform`` and ``noise``, with shape `(...,)`
-                (leading dimensions must match those of ``waveform``).
             snr (torch.Tensor): Signal-to-noise ratios in dB, with shape `(...,)`.
+            lengths (torch.Tensor or None, optional): Valid lengths of signals in ``waveform`` and ``noise``,
+                with shape `(...,)` (leading dimensions must match those of ``waveform``). If ``None``, all
+                elements in ``waveform`` and ``noise`` are treated as valid. (Default: ``None``)

         Returns:
             torch.Tensor: Result of scaling and adding ``noise`` to ``waveform``, with shape `(..., L)`
             (same shape as ``waveform``).
         """
-        return add_noise(waveform, noise, lengths, snr)
+        return add_noise(waveform, noise, snr, lengths)


 class Preemphasis(torch.nn.Module):
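A corresponding sketch for the transform (module path assumed; on this branch `AddNoise` sits in the prototype transforms namespace), including a TorchScript round-trip that relies on the new `Optional` annotation:

```python
import torch
from torchaudio.prototype.transforms import AddNoise  # path assumed

add_noise = AddNoise()
waveform = torch.rand(2, 3, 31)
noise = torch.rand(2, 3, 31)
snr = torch.rand(2, 3) * 10

out = add_noise(waveform, noise, snr)  # lengths defaults to None: every sample counts toward the SNR
lengths = torch.full((2, 3), 16.0)
out_masked = add_noise(waveform, noise, snr, lengths)  # only the first 16 samples of each signal count

# Optional[torch.Tensor] keeps the module scriptable even when lengths is omitted.
scripted = torch.jit.script(add_noise)
torch.testing.assert_close(scripted(waveform, noise, snr), out)
```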