Commit 29ecf7e8, authored by hwangjeff and committed by Facebook GitHub Bot
Browse files

Add additive noise transform (#2889)

Summary: Pull Request resolved: https://github.com/pytorch/audio/pull/2889

Reviewed By: xiaohui-zhang

Differential Revision: D41760084

Pulled By: hwangjeff

fbshipit-source-id: d2f5253e1fae7e7aafa9fa6043c6a7045c5b33a0
parent 45c7d05a
......@@ -9,6 +9,7 @@ torchaudio.prototype.transforms
:toctree: generated
:nosignatures:
AddNoise
Convolve
FFTConvolve
BarkScale
......
......@@ -74,3 +74,16 @@ class Autograd(TestBaseMixin):
speed = T.SpeedPerturbation(1000, [0.9]).to(device=self.device, dtype=torch.float64)
assert gradcheck(speed, (waveform, lengths))
assert gradgradcheck(speed, (waveform, lengths))
def test_AddNoise(self):
    """Run first- and second-order gradient checks on ``T.AddNoise`` with float64 inputs."""
    batch_shape = (2, 3)
    num_samples = 31
    # All four inputs share dtype/device and are created with ``requires_grad=True``
    # before being handed to the gradient checks.
    opts = dict(dtype=torch.float64, device=self.device, requires_grad=True)

    waveform = torch.rand(*batch_shape, num_samples, **opts)
    noise = torch.rand(*batch_shape, num_samples, **opts)
    lengths = torch.rand(*batch_shape, **opts)
    snr = torch.rand(*batch_shape, **opts) * 10
    inputs = (waveform, noise, lengths, snr)

    transform = T.AddNoise().to(self.device, torch.float64)
    assert gradcheck(transform, inputs)
    assert gradgradcheck(transform, inputs)
......@@ -113,3 +113,23 @@ class BatchConsistencyTest(TorchaudioTestCase):
for idx in range(len(unbatched_output)):
w, l = output[idx], output_lengths[idx]
self.assertEqual(unbatched_output[idx], w[:l])
def test_AddNoise(self):
    """Check that one batched ``T.AddNoise`` call matches applying it item by item."""
    batch_shape = (5, 2, 3)
    num_samples = 51
    opts = dict(dtype=self.dtype, device=self.device)

    waveform = torch.rand(*batch_shape, num_samples, **opts)
    noise = torch.rand(*batch_shape, num_samples, **opts)
    lengths = torch.rand(*batch_shape, **opts)
    snr = torch.rand(*batch_shape, **opts) * 10

    transform = T.AddNoise()
    batched = transform(waveform, noise, lengths, snr)

    # Apply the transform to every (i, j, k) element on its own, visiting the
    # elements in the same order that ``reshape(-1, num_samples)`` flattens them.
    per_item = [
        transform(waveform[i, j, k], noise[i, j, k], lengths[i, j, k], snr[i, j, k])
        for i in range(batch_shape[0])
        for j in range(batch_shape[1])
        for k in range(batch_shape[2])
    ]
    self.assertEqual(torch.stack(per_item), batched.reshape(-1, num_samples))
......@@ -40,3 +40,17 @@ class Transforms(TestBaseMixin):
output = speed(waveform, lengths)
ts_output = torch_script(speed)(waveform, lengths)
self.assertEqual(ts_output, output)
def test_AddNoise(self):
    """Check that the ``torch_script``-wrapped ``T.AddNoise`` matches the eager module output."""
    batch_shape = (2, 3)
    num_samples = 31
    opts = dict(dtype=self.dtype, device=self.device, requires_grad=True)

    waveform = torch.rand(*batch_shape, num_samples, **opts)
    noise = torch.rand(*batch_shape, num_samples, **opts)
    lengths = torch.rand(*batch_shape, **opts)
    snr = torch.rand(*batch_shape, **opts) * 10
    inputs = (waveform, noise, lengths, snr)

    transform = T.AddNoise().to(self.device, self.dtype)
    # Eager result first, then the scripted module on the very same inputs.
    eager_output = transform(*inputs)
    scripted_output = torch_script(transform)(*inputs)
    self.assertEqual(scripted_output, eager_output)
......@@ -5,6 +5,7 @@ from unittest.mock import patch
import numpy as np
import torch
import torchaudio.prototype.transforms as T
from parameterized import parameterized
from scipy import signal
from torchaudio_unittest.common_utils import get_spectrogram, get_whitenoise, nested_params, TestBaseMixin
......@@ -169,3 +170,55 @@ class TransformsTestImpl(TestBaseMixin):
atol=1e-1,
rtol=1e-4,
)
def test_AddNoise_broadcast(self):
    """Check that inputs with size-1 leading dims give the same output as explicitly expanded inputs."""
    full_dims = (5, 2, 3)
    num_samples = 51
    opts = dict(dtype=self.dtype, device=self.device)

    waveform = torch.rand(*full_dims, num_samples, **opts)
    # The remaining inputs carry size-1 entries in some of their leading dimensions.
    noise = torch.rand(5, 1, 1, num_samples, **opts)
    lengths = torch.rand(5, 1, 3, **opts)
    snr = torch.rand(1, 1, 1, **opts) * 10

    transform = T.AddNoise()
    actual = transform(waveform, noise, lengths, snr)

    # Reference result: expand every input to the full leading shape up front.
    expected = transform(
        waveform,
        noise.expand(*full_dims, num_samples),
        lengths.expand(*full_dims),
        snr.expand(*full_dims),
    )
    self.assertEqual(expected, actual)
@parameterized.expand(
    [
        ((5, 2, 3), (2, 1, 1), (5, 2), (5, 2, 3)),
        ((2, 1), (5,), (5,), (5,)),
        ((3,), (5, 2, 3), (2, 1, 1), (5, 2)),
    ]
)
def test_AddNoise_leading_dim_check(self, waveform_dims, noise_dims, lengths_dims, snr_dims):
    """Check that ``T.AddNoise`` raises ``ValueError`` for each parameterized set of leading dims.

    Each case supplies the leading dimensions used to build ``waveform``, ``noise``,
    ``lengths`` and ``snr``; the error message must match "Input leading dimensions".
    """
    num_samples = 51
    opts = dict(dtype=self.dtype, device=self.device)

    waveform = torch.rand(*waveform_dims, num_samples, **opts)
    noise = torch.rand(*noise_dims, num_samples, **opts)
    lengths = torch.rand(*lengths_dims, **opts)
    snr = torch.rand(*snr_dims, **opts) * 10

    transform = T.AddNoise()
    with self.assertRaisesRegex(ValueError, "Input leading dimensions"):
        transform(waveform, noise, lengths, snr)
def test_AddNoise_length_check(self):
    """Check that ``T.AddNoise`` raises ``ValueError`` when waveform and noise differ in their last dim."""
    batch_shape = (5, 2, 3)
    waveform_samples = 51
    opts = dict(dtype=self.dtype, device=self.device)

    waveform = torch.rand(*batch_shape, waveform_samples, **opts)
    # The noise is deliberately one sample shorter (50 vs. 51) than the waveform.
    noise = torch.rand(*batch_shape, 50, **opts)
    lengths = torch.rand(*batch_shape, **opts)
    snr = torch.rand(*batch_shape, **opts) * 10

    transform = T.AddNoise()
    with self.assertRaisesRegex(ValueError, "Length dimensions"):
        transform(waveform, noise, lengths, snr)
from ._transforms import BarkScale, BarkSpectrogram, Convolve, FFTConvolve, InverseBarkScale, Speed, SpeedPerturbation
from ._transforms import (
AddNoise,
BarkScale,
BarkSpectrogram,
Convolve,
FFTConvolve,
InverseBarkScale,
Speed,
SpeedPerturbation,
)
__all__ = [
"AddNoise",
"BarkScale",
"BarkSpectrogram",
"Convolve",
......
......@@ -2,7 +2,7 @@ import math
from typing import Callable, Optional, Sequence, Tuple
import torch
from torchaudio.prototype.functional import barkscale_fbanks, convolve, fftconvolve
from torchaudio.prototype.functional import add_noise, barkscale_fbanks, convolve, fftconvolve
from torchaudio.prototype.functional.functional import _check_convolve_mode
from torchaudio.transforms import Resample, Spectrogram
......@@ -483,3 +483,30 @@ class SpeedPerturbation(torch.nn.Module):
if idx == speeder_idx:
return speeder(waveform, lengths)
raise RuntimeError("Speeder not found; execution should have never reached here.")
class AddNoise(torch.nn.Module):
    r"""Scales noise according to a signal-to-noise ratio and adds it to a waveform.

    This module passes its inputs straight to
    :meth:`torchaudio.prototype.functional.add_noise`; see that function for more details.

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript
    """

    def forward(
        self,
        waveform: torch.Tensor,
        noise: torch.Tensor,
        lengths: torch.Tensor,
        snr: torch.Tensor,
    ) -> torch.Tensor:
        r"""Add scaled ``noise`` to ``waveform``.

        Args:
            waveform (torch.Tensor): Input waveform, with shape `(..., L)`.
            noise (torch.Tensor): Noise, with shape `(..., L)` (same shape as ``waveform``).
            lengths (torch.Tensor): Valid lengths of signals in ``waveform`` and ``noise``,
                with shape `(...,)` (leading dimensions must match those of ``waveform``).
            snr (torch.Tensor): Signal-to-noise ratios in dB, with shape `(...,)`.

        Returns:
            torch.Tensor: Result of scaling and adding ``noise`` to ``waveform``,
            with shape `(..., L)` (same shape as ``waveform``).
        """
        noisy_waveform = add_noise(waveform, noise, lengths, snr)
        return noisy_waveform
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment