Commit bb077284 authored by hwangjeff, committed by Facebook GitHub Bot

Make lengths optional for additive noise operators (#2977)

Summary:
For greater flexibility, this PR makes the `lengths` argument optional for `add_noise` and `AddNoise`, moving it after `snr` in the argument list so that it can default to `None`.

Pull Request resolved: https://github.com/pytorch/audio/pull/2977

Reviewed By: nateanl

Differential Revision: D42484211

Pulled By: hwangjeff

fbshipit-source-id: 54757dcc73df194bb98c1d9d42a2f43f3027b190
parent 51731bf9
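As a quick orientation to the new call signature, here is a minimal usage sketch. It assumes a torchaudio build that already contains this commit; the import path is taken from the prototype module this diff touches and may differ in other versions, and all tensor shapes and values are illustrative only.

```python
import torch
import torchaudio.prototype.functional as F  # assumed path; adjust for your torchaudio version

waveform = torch.rand(2, 3, 100)  # (..., L)
noise = torch.rand(2, 3, 100)     # same shape as waveform
snr = torch.full((2, 3), 10.0)    # target SNR in dB, shape (...,)

# New behavior: `lengths` may be omitted, in which case every sample is treated as valid.
noisy = F.add_noise(waveform, noise, snr)

# `lengths` can still be supplied, now as the last (optional) argument, to restrict the
# energy computation to the leading valid samples of each signal.
lengths = torch.full((2, 3), 80)
noisy_masked = F.add_noise(waveform, noise, snr, lengths)
```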
@@ -27,8 +27,8 @@ class AutogradTestImpl(TestBaseMixin):
         lengths = torch.rand(*leading_dims, dtype=self.dtype, device=self.device, requires_grad=True)
         snr = torch.rand(*leading_dims, dtype=self.dtype, device=self.device, requires_grad=True) * 10
-        self.assertTrue(gradcheck(F.add_noise, (waveform, noise, lengths, snr)))
-        self.assertTrue(gradgradcheck(F.add_noise, (waveform, noise, lengths, snr)))
+        self.assertTrue(gradcheck(F.add_noise, (waveform, noise, snr, lengths)))
+        self.assertTrue(gradgradcheck(F.add_noise, (waveform, noise, snr, lengths)))

     @parameterized.expand(
         [
......
@@ -35,13 +35,13 @@ class BatchConsistencyTest(TorchaudioTestCase):
         lengths = torch.rand(*leading_dims, dtype=self.dtype, device=self.device)
         snr = torch.rand(*leading_dims, dtype=self.dtype, device=self.device) * 10

-        actual = F.add_noise(waveform, noise, lengths, snr)
+        actual = F.add_noise(waveform, noise, snr, lengths)

         expected = []
         for i in range(leading_dims[0]):
             for j in range(leading_dims[1]):
                 for k in range(leading_dims[2]):
-                    expected.append(F.add_noise(waveform[i][j][k], noise[i][j][k], lengths[i][j][k], snr[i][j][k]))
+                    expected.append(F.add_noise(waveform[i][j][k], noise[i][j][k], snr[i][j][k], lengths[i][j][k]))

         self.assertEqual(torch.stack(expected), actual.reshape(-1, L))
......
@@ -135,12 +135,12 @@ class FunctionalTestImpl(TestBaseMixin):
         noise = torch.rand(5, 1, 1, L, dtype=self.dtype, device=self.device)
         lengths = torch.rand(5, 1, 3, dtype=self.dtype, device=self.device)
         snr = torch.rand(1, 1, 1, dtype=self.dtype, device=self.device) * 10

-        actual = F.add_noise(waveform, noise, lengths, snr)
+        actual = F.add_noise(waveform, noise, snr, lengths)

         noise_expanded = noise.expand(*leading_dims, L)
         snr_expanded = snr.expand(*leading_dims)
         lengths_expanded = lengths.expand(*leading_dims)
-        expected = F.add_noise(waveform, noise_expanded, lengths_expanded, snr_expanded)
+        expected = F.add_noise(waveform, noise_expanded, snr_expanded, lengths_expanded)

         self.assertEqual(expected, actual)
@@ -157,7 +157,7 @@ class FunctionalTestImpl(TestBaseMixin):
         snr = torch.rand(*snr_dims, dtype=self.dtype, device=self.device) * 10

         with self.assertRaisesRegex(ValueError, "Input leading dimensions"):
-            F.add_noise(waveform, noise, lengths, snr)
+            F.add_noise(waveform, noise, snr, lengths)

     def test_add_noise_length_check(self):
         """Check that add_noise properly rejects inputs that have inconsistent length dimensions."""
@@ -170,7 +170,7 @@ class FunctionalTestImpl(TestBaseMixin):
         snr = torch.rand(*leading_dims, dtype=self.dtype, device=self.device) * 10

         with self.assertRaisesRegex(ValueError, "Length dimensions"):
-            F.add_noise(waveform, noise, lengths, snr)
+            F.add_noise(waveform, noise, snr, lengths)

     @nested_params(
         [(2, 3), (2, 3, 5), (2, 3, 5, 7)],
......
@@ -37,16 +37,20 @@ class TorchScriptConsistencyTestImpl(TestBaseMixin):
         self._assert_consistency(getattr(F, fn), (x, y, mode))

-    def test_add_noise(self):
+    @nested_params([True, False])
+    def test_add_noise(self, use_lengths):
         leading_dims = (2, 3)
         L = 31
         waveform = torch.rand(*leading_dims, L, dtype=self.dtype, device=self.device, requires_grad=True)
         noise = torch.rand(*leading_dims, L, dtype=self.dtype, device=self.device, requires_grad=True)
-        lengths = torch.rand(*leading_dims, dtype=self.dtype, device=self.device, requires_grad=True)
+        if use_lengths:
+            lengths = torch.rand(*leading_dims, dtype=self.dtype, device=self.device, requires_grad=True)
+        else:
+            lengths = None
         snr = torch.rand(*leading_dims, dtype=self.dtype, device=self.device, requires_grad=True) * 10

-        self._assert_consistency(F.add_noise, (waveform, noise, lengths, snr))
+        self._assert_consistency(F.add_noise, (waveform, noise, snr, lengths))

     def test_barkscale_fbanks(self):
         if self.device != torch.device("cpu"):
......
@@ -75,18 +75,22 @@ class Autograd(TestBaseMixin):
         assert gradcheck(speed, (waveform, lengths))
         assert gradgradcheck(speed, (waveform, lengths))

-    def test_AddNoise(self):
+    @nested_params([True, False])
+    def test_AddNoise(self, use_lengths):
         leading_dims = (2, 3)
         L = 31
         waveform = torch.rand(*leading_dims, L, dtype=torch.float64, device=self.device, requires_grad=True)
         noise = torch.rand(*leading_dims, L, dtype=torch.float64, device=self.device, requires_grad=True)
-        lengths = torch.rand(*leading_dims, dtype=torch.float64, device=self.device, requires_grad=True)
+        if use_lengths:
+            lengths = torch.rand(*leading_dims, dtype=torch.float64, device=self.device, requires_grad=True)
+        else:
+            lengths = None
         snr = torch.rand(*leading_dims, dtype=torch.float64, device=self.device, requires_grad=True) * 10
         add_noise = T.AddNoise().to(self.device, torch.float64)

-        assert gradcheck(add_noise, (waveform, noise, lengths, snr))
-        assert gradgradcheck(add_noise, (waveform, noise, lengths, snr))
+        assert gradcheck(add_noise, (waveform, noise, snr, lengths))
+        assert gradgradcheck(add_noise, (waveform, noise, snr, lengths))

     def test_Preemphasis(self):
         waveform = torch.rand(3, 4, 10, dtype=torch.float64, device=self.device, requires_grad=True)
......
@@ -124,13 +124,13 @@ class BatchConsistencyTest(TorchaudioTestCase):
         snr = torch.rand(*leading_dims, dtype=self.dtype, device=self.device) * 10
         add_noise = T.AddNoise()

-        actual = add_noise(waveform, noise, lengths, snr)
+        actual = add_noise(waveform, noise, snr, lengths)

         expected = []
         for i in range(leading_dims[0]):
             for j in range(leading_dims[1]):
                 for k in range(leading_dims[2]):
-                    expected.append(add_noise(waveform[i][j][k], noise[i][j][k], lengths[i][j][k], snr[i][j][k]))
+                    expected.append(add_noise(waveform[i][j][k], noise[i][j][k], snr[i][j][k], lengths[i][j][k]))

         self.assertEqual(torch.stack(expected), actual.reshape(-1, L))
......
@@ -41,18 +41,22 @@ class Transforms(TestBaseMixin):
         ts_output = torch_script(speed)(waveform, lengths)
         self.assertEqual(ts_output, output)

-    def test_AddNoise(self):
+    @nested_params([True, False])
+    def test_AddNoise(self, use_lengths):
         leading_dims = (2, 3)
         L = 31
         waveform = torch.rand(*leading_dims, L, dtype=self.dtype, device=self.device, requires_grad=True)
         noise = torch.rand(*leading_dims, L, dtype=self.dtype, device=self.device, requires_grad=True)
-        lengths = torch.rand(*leading_dims, dtype=self.dtype, device=self.device, requires_grad=True)
+        if use_lengths:
+            lengths = torch.rand(*leading_dims, dtype=self.dtype, device=self.device, requires_grad=True)
+        else:
+            lengths = None
         snr = torch.rand(*leading_dims, dtype=self.dtype, device=self.device, requires_grad=True) * 10
         add_noise = T.AddNoise().to(self.device, self.dtype)

-        output = add_noise(waveform, noise, lengths, snr)
-        ts_output = torch_script(add_noise)(waveform, noise, lengths, snr)
+        output = add_noise(waveform, noise, snr, lengths)
+        ts_output = torch_script(add_noise)(waveform, noise, snr, lengths)
         self.assertEqual(ts_output, output)

     def test_Preemphasis(self):
......
@@ -184,12 +184,12 @@ class TransformsTestImpl(TestBaseMixin):
         snr = torch.rand(1, 1, 1, dtype=self.dtype, device=self.device) * 10
         add_noise = T.AddNoise()

-        actual = add_noise(waveform, noise, lengths, snr)
+        actual = add_noise(waveform, noise, snr, lengths)

         noise_expanded = noise.expand(*leading_dims, L)
         snr_expanded = snr.expand(*leading_dims)
         lengths_expanded = lengths.expand(*leading_dims)
-        expected = add_noise(waveform, noise_expanded, lengths_expanded, snr_expanded)
+        expected = add_noise(waveform, noise_expanded, snr_expanded, lengths_expanded)

         self.assertEqual(expected, actual)
@@ -208,7 +208,7 @@ class TransformsTestImpl(TestBaseMixin):
         add_noise = T.AddNoise()

         with self.assertRaisesRegex(ValueError, "Input leading dimensions"):
-            add_noise(waveform, noise, lengths, snr)
+            add_noise(waveform, noise, snr, lengths)

     def test_AddNoise_length_check(self):
         """Check that add_noise properly rejects inputs that have inconsistent length dimensions."""
@@ -223,7 +223,7 @@ class TransformsTestImpl(TestBaseMixin):
         add_noise = T.AddNoise()

         with self.assertRaisesRegex(ValueError, "Length dimensions"):
-            add_noise(waveform, noise, lengths, snr)
+            add_noise(waveform, noise, snr, lengths)

     @nested_params(
         [(2, 1, 31)],
......
 import math
 import warnings
-from typing import Tuple
+from typing import Optional, Tuple

 import torch
 from torchaudio.functional import lfilter, resample
@@ -134,7 +134,9 @@ def convolve(x: torch.Tensor, y: torch.Tensor, mode: str = "full") -> torch.Tensor:
     return _apply_convolve_mode(result, x_size, y_size, mode)


-def add_noise(waveform: torch.Tensor, noise: torch.Tensor, lengths: torch.Tensor, snr: torch.Tensor) -> torch.Tensor:
+def add_noise(
+    waveform: torch.Tensor, noise: torch.Tensor, snr: torch.Tensor, lengths: Optional[torch.Tensor] = None
+) -> torch.Tensor:
     r"""Scales and adds noise to waveform per signal-to-noise ratio.

     Specifically, for each pair of waveform vector :math:`x \in \mathbb{R}^L` and noise vector
@@ -160,16 +162,17 @@ def add_noise(waveform: torch.Tensor, noise: torch.Tensor, lengths: torch.Tensor
     Args:
         waveform (torch.Tensor): Input waveform, with shape `(..., L)`.
         noise (torch.Tensor): Noise, with shape `(..., L)` (same shape as ``waveform``).
-        lengths (torch.Tensor): Valid lengths of signals in ``waveform`` and ``noise``, with shape `(...,)`
-            (leading dimensions must match those of ``waveform``).
         snr (torch.Tensor): Signal-to-noise ratios in dB, with shape `(...,)`.
+        lengths (torch.Tensor or None, optional): Valid lengths of signals in ``waveform`` and ``noise``, with shape
+            `(...,)` (leading dimensions must match those of ``waveform``). If ``None``, all elements in ``waveform``
+            and ``noise`` are treated as valid. (Default: ``None``)

     Returns:
         torch.Tensor: Result of scaling and adding ``noise`` to ``waveform``, with shape `(..., L)`
             (same shape as ``waveform``).
     """

-    if not (waveform.ndim - 1 == noise.ndim - 1 == lengths.ndim == snr.ndim):
+    if not (waveform.ndim - 1 == noise.ndim - 1 == snr.ndim and (lengths is None or lengths.ndim == snr.ndim)):
         raise ValueError("Input leading dimensions don't match.")

     L = waveform.size(-1)
@@ -178,11 +181,18 @@ def add_noise(waveform: torch.Tensor, noise: torch.Tensor, lengths: torch.Tensor
         raise ValueError(f"Length dimensions of waveform and noise don't match (got {L} and {noise.size(-1)}).")

     # compute scale
-    mask = torch.arange(0, L, device=lengths.device).expand(waveform.shape) < lengths.unsqueeze(
-        -1
-    )  # (*, L) < (*, 1) = (*, L)
-    energy_signal = torch.linalg.vector_norm(waveform * mask, ord=2, dim=-1) ** 2  # (*,)
-    energy_noise = torch.linalg.vector_norm(noise * mask, ord=2, dim=-1) ** 2  # (*,)
+    if lengths is not None:
+        mask = torch.arange(0, L, device=lengths.device).expand(waveform.shape) < lengths.unsqueeze(
+            -1
+        )  # (*, L) < (*, 1) = (*, L)
+        masked_waveform = waveform * mask
+        masked_noise = noise * mask
+    else:
+        masked_waveform = waveform
+        masked_noise = noise
+
+    energy_signal = torch.linalg.vector_norm(masked_waveform, ord=2, dim=-1) ** 2  # (*,)
+    energy_noise = torch.linalg.vector_norm(masked_noise, ord=2, dim=-1) ** 2  # (*,)
     original_snr_db = 10 * (torch.log10(energy_signal) - torch.log10(energy_noise))
     scale = 10 ** ((original_snr_db - snr) / 20.0)  # (*,)
......
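To make the new optional-`lengths` control flow easier to follow outside the diff context, here is a condensed, self-contained sketch of the scaling logic above. It is not the library implementation: the final "waveform plus scaled noise" step is assumed from the formula quoted in the function's docstring rather than shown in this hunk.

```python
from typing import Optional

import torch


def add_noise_sketch(
    waveform: torch.Tensor, noise: torch.Tensor, snr: torch.Tensor, lengths: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Condensed sketch of the optional-lengths scaling logic (illustrative, not the library code)."""
    L = waveform.size(-1)
    if lengths is not None:
        # Only samples before each signal's valid length contribute to the energy estimates.
        mask = torch.arange(0, L, device=lengths.device).expand(waveform.shape) < lengths.unsqueeze(-1)
        masked_waveform = waveform * mask
        masked_noise = noise * mask
    else:
        # No lengths supplied: every sample is treated as valid.
        masked_waveform = waveform
        masked_noise = noise

    energy_signal = torch.linalg.vector_norm(masked_waveform, ord=2, dim=-1) ** 2  # (*,)
    energy_noise = torch.linalg.vector_norm(masked_noise, ord=2, dim=-1) ** 2  # (*,)
    original_snr_db = 10 * (torch.log10(energy_signal) - torch.log10(energy_noise))
    scale = 10 ** ((original_snr_db - snr) / 20.0)  # (*,)

    # Assumed final step, following the docstring formula y = x + a * n.
    return waveform + scale.unsqueeze(-1) * noise
```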
@@ -495,21 +495,22 @@ class AddNoise(torch.nn.Module):
     """

     def forward(
-        self, waveform: torch.Tensor, noise: torch.Tensor, lengths: torch.Tensor, snr: torch.Tensor
+        self, waveform: torch.Tensor, noise: torch.Tensor, snr: torch.Tensor, lengths: Optional[torch.Tensor] = None
     ) -> torch.Tensor:
         r"""
         Args:
             waveform (torch.Tensor): Input waveform, with shape `(..., L)`.
             noise (torch.Tensor): Noise, with shape `(..., L)` (same shape as ``waveform``).
-            lengths (torch.Tensor): Valid lengths of signals in ``waveform`` and ``noise``, with shape `(...,)`
-                (leading dimensions must match those of ``waveform``).
             snr (torch.Tensor): Signal-to-noise ratios in dB, with shape `(...,)`.
+            lengths (torch.Tensor or None, optional): Valid lengths of signals in ``waveform`` and ``noise``,
+                with shape `(...,)` (leading dimensions must match those of ``waveform``). If ``None``, all
+                elements in ``waveform`` and ``noise`` are treated as valid. (Default: ``None``)

         Returns:
             torch.Tensor: Result of scaling and adding ``noise`` to ``waveform``, with shape `(..., L)`
                 (same shape as ``waveform``).
         """
-        return add_noise(waveform, noise, lengths, snr)
+        return add_noise(waveform, noise, snr, lengths)


 class Preemphasis(torch.nn.Module):
......
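For completeness, a matching sketch of the transform-level call with the new argument order. As before, the import path is an assumption based on the prototype module this commit modifies and may differ across torchaudio versions; shapes and values are illustrative only.

```python
import torch
import torchaudio.prototype.transforms as T  # assumed path; adjust for your torchaudio version

add_noise = T.AddNoise()

waveform = torch.rand(4, 16000)              # (batch, L)
noise = torch.rand(4, 16000)                 # same shape as waveform
snr = torch.tensor([5.0, 10.0, 15.0, 20.0])  # one target SNR (dB) per batch item

# lengths omitted: all samples of each signal are treated as valid
out = add_noise(waveform, noise, snr)

# lengths supplied as the optional trailing argument
lengths = torch.tensor([16000, 12000, 8000, 4000])
out_masked = add_noise(waveform, noise, snr, lengths)
```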