Commit c2b62ae8 authored by Jeff Hwang, committed by Facebook GitHub Bot

Add preemphasis and deemphasis transforms (#2935)

Summary: Pull Request resolved: https://github.com/pytorch/audio/pull/2935

Reviewed By: mthrok

Differential Revision: D42302275

Pulled By: hwangjeff

fbshipit-source-id: d995d335bf17d63d3c1dda77d8ef596570853638
parent d0dca115
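
Before the diff itself, a minimal usage sketch of the two transforms this commit adds. It is based on the round-trip and lfilter tests further down; the coefficient 0.97 mirrors the test values, and the tensor shape is illustrative only.

import torch
import torchaudio.prototype.transforms as T

waveform = torch.rand(2, 1, 100, dtype=torch.float64)  # (..., time)

preemphasis = T.Preemphasis(coeff=0.97)
deemphasis = T.Deemphasis(coeff=0.97)

# Pre-emphasize (y[n] = x[n] - 0.97 * x[n-1]), then undo it; the functional
# round-trip test in this diff checks the same recovery property.
emphasized = preemphasis(waveform)
restored = deemphasis(emphasized)
torch.testing.assert_close(restored, waveform)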
@@ -17,3 +17,5 @@ torchaudio.prototype.transforms
 BarkSpectrogram
 Speed
 SpeedPerturbation
+Deemphasis
+Preemphasis
@@ -87,3 +87,15 @@ class Autograd(TestBaseMixin):
         add_noise = T.AddNoise().to(self.device, torch.float64)
         assert gradcheck(add_noise, (waveform, noise, lengths, snr))
         assert gradgradcheck(add_noise, (waveform, noise, lengths, snr))
+
+    def test_Preemphasis(self):
+        waveform = torch.rand(3, 4, 10, dtype=torch.float64, device=self.device, requires_grad=True)
+        preemphasis = T.Preemphasis(coeff=0.97).to(dtype=torch.float64, device=self.device)
+        assert gradcheck(preemphasis, (waveform,))
+        assert gradgradcheck(preemphasis, (waveform,))
+
+    def test_Deemphasis(self):
+        waveform = torch.rand(3, 4, 10, dtype=torch.float64, device=self.device, requires_grad=True)
+        deemphasis = T.Deemphasis(coeff=0.97).to(dtype=torch.float64, device=self.device)
+        assert gradcheck(deemphasis, (waveform,))
+        assert gradgradcheck(deemphasis, (waveform,))
@@ -133,3 +133,29 @@ class BatchConsistencyTest(TorchaudioTestCase):
                     expected.append(add_noise(waveform[i][j][k], noise[i][j][k], lengths[i][j][k], snr[i][j][k]))
         self.assertEqual(torch.stack(expected), actual.reshape(-1, L))
+
+    def test_Preemphasis(self):
+        waveform = torch.rand((3, 5, 2, 100), dtype=self.dtype, device=self.device)
+        preemphasis = T.Preemphasis(coeff=0.97)
+        actual = preemphasis(waveform)
+
+        expected = []
+        for i in range(waveform.size(0)):
+            for j in range(waveform.size(1)):
+                for k in range(waveform.size(2)):
+                    expected.append(preemphasis(waveform[i][j][k]))
+        self.assertEqual(torch.stack(expected), actual.reshape(-1, waveform.size(-1)))
+
+    def test_Deemphasis(self):
+        waveform = torch.rand((3, 5, 2, 100), dtype=self.dtype, device=self.device)
+        deemphasis = T.Deemphasis(coeff=0.97)
+        actual = deemphasis(waveform)
+
+        expected = []
+        for i in range(waveform.size(0)):
+            for j in range(waveform.size(1)):
+                for k in range(waveform.size(2)):
+                    expected.append(deemphasis(waveform[i][j][k]))
+        self.assertEqual(torch.stack(expected), actual.reshape(-1, waveform.size(-1)))
@@ -54,3 +54,17 @@ class Transforms(TestBaseMixin):
         output = add_noise(waveform, noise, lengths, snr)
         ts_output = torch_script(add_noise)(waveform, noise, lengths, snr)
         self.assertEqual(ts_output, output)
+
+    def test_Preemphasis(self):
+        waveform = torch.rand(3, 4, 10, dtype=self.dtype, device=self.device)
+        preemphasis = T.Preemphasis(coeff=0.97).to(dtype=self.dtype, device=self.device)
+        output = preemphasis(waveform)
+        ts_output = torch_script(preemphasis)(waveform)
+        self.assertEqual(ts_output, output)
+
+    def test_Deemphasis(self):
+        waveform = torch.rand(3, 4, 10, dtype=self.dtype, device=self.device)
+        deemphasis = T.Deemphasis(coeff=0.97).to(dtype=self.dtype, device=self.device)
+        output = deemphasis(waveform)
+        ts_output = torch_script(deemphasis)(waveform)
+        self.assertEqual(ts_output, output)
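
Because the new transforms are plain torch.nn.Module wrappers around the functional API, they can be scripted directly. A minimal sketch using torch.jit.script follows; the torch_script helper in the test above presumably wraps something similar, so this is an illustration rather than the exact test machinery.

import torch
import torchaudio.prototype.transforms as T

preemphasis = T.Preemphasis(coeff=0.97)
scripted = torch.jit.script(preemphasis)  # TorchScript compilation of the module

# The scripted module should produce the same output as eager mode.
waveform = torch.rand(3, 4, 10)
torch.testing.assert_close(scripted(waveform), preemphasis(waveform))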
@@ -7,6 +7,8 @@ import torch
 import torchaudio.prototype.transforms as T
 from parameterized import parameterized
 from scipy import signal
+from torchaudio.functional import lfilter
+from torchaudio.prototype.functional import preemphasis
 from torchaudio_unittest.common_utils import get_spectrogram, get_whitenoise, nested_params, TestBaseMixin
@@ -222,3 +224,28 @@ class TransformsTestImpl(TestBaseMixin):
         with self.assertRaisesRegex(ValueError, "Length dimensions"):
             add_noise(waveform, noise, lengths, snr)
+
+    @nested_params(
+        [(2, 1, 31)],
+        [0.97, 0.72],
+    )
+    def test_Preemphasis(self, input_shape, coeff):
+        waveform = torch.rand(*input_shape, dtype=self.dtype, device=self.device)
+        preemphasis = T.Preemphasis(coeff=coeff).to(dtype=self.dtype, device=self.device)
+        actual = preemphasis(waveform)
+
+        a_coeffs = torch.tensor([1.0, 0.0], device=self.device, dtype=self.dtype)
+        b_coeffs = torch.tensor([1.0, -coeff], device=self.device, dtype=self.dtype)
+        expected = lfilter(waveform, a_coeffs=a_coeffs, b_coeffs=b_coeffs)
+        self.assertEqual(actual, expected)
+
+    @nested_params(
+        [(2, 1, 31)],
+        [0.97, 0.72],
+    )
+    def test_Deemphasis(self, input_shape, coeff):
+        waveform = torch.rand(*input_shape, dtype=self.dtype, device=self.device)
+        preemphasized = preemphasis(waveform, coeff=coeff)
+        deemphasis = T.Deemphasis(coeff=coeff).to(dtype=self.dtype, device=self.device)
+        deemphasized = deemphasis(preemphasized)
+        self.assertEqual(deemphasized, waveform)
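
For reference, the filter forms these tests exercise: the lfilter comparison above encodes pre-emphasis as the FIR filter y[n] = x[n] - coeff * x[n-1] (b_coeffs = [1, -coeff], a_coeffs = [1, 0]). De-emphasis is its inverse, which should correspond to the IIR recursion y[n] = x[n] + coeff * y[n-1]; the test only checks invertibility via the round trip, so the explicit IIR comparison in the sketch below is an assumption about the expected form. The input is scaled down so that any output clamping in lfilter never triggers.

import torch
from torchaudio.functional import lfilter
from torchaudio.prototype.functional import deemphasis, preemphasis

coeff = 0.97
# Small values keep filter outputs well inside [-1, 1], so lfilter's default clamping is moot.
waveform = 0.01 * torch.rand(2, 1, 31, dtype=torch.float64)

# Pre-emphasis as an FIR filter: y[n] = x[n] - coeff * x[n-1]  (as in the test above)
fir = lfilter(
    waveform,
    a_coeffs=torch.tensor([1.0, 0.0], dtype=torch.float64),
    b_coeffs=torch.tensor([1.0, -coeff], dtype=torch.float64),
)
torch.testing.assert_close(preemphasis(waveform, coeff=coeff), fir)

# Assumed inverse form: de-emphasis as the IIR filter y[n] = x[n] + coeff * y[n-1]
iir = lfilter(
    waveform,
    a_coeffs=torch.tensor([1.0, -coeff], dtype=torch.float64),
    b_coeffs=torch.tensor([1.0, 0.0], dtype=torch.float64),
)
torch.testing.assert_close(deemphasis(waveform, coeff=coeff), iir)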
@@ -3,8 +3,10 @@ from ._transforms import (
     BarkScale,
     BarkSpectrogram,
     Convolve,
+    Deemphasis,
     FFTConvolve,
     InverseBarkScale,
+    Preemphasis,
     Speed,
     SpeedPerturbation,
 )
@@ -14,8 +16,10 @@ __all__ = [
     "BarkScale",
     "BarkSpectrogram",
     "Convolve",
+    "Deemphasis",
     "FFTConvolve",
     "InverseBarkScale",
+    "Preemphasis",
     "SpeedPerturbation",
     "Speed",
 ]
@@ -2,7 +2,7 @@ import math
 from typing import Callable, Optional, Sequence, Tuple
 
 import torch
-from torchaudio.prototype.functional import add_noise, barkscale_fbanks, convolve, fftconvolve
+from torchaudio.prototype.functional import add_noise, barkscale_fbanks, convolve, deemphasis, fftconvolve, preemphasis
 from torchaudio.prototype.functional.functional import _check_convolve_mode
 from torchaudio.transforms import Resample, Spectrogram
@@ -510,3 +510,59 @@ class AddNoise(torch.nn.Module):
             (same shape as ``waveform``).
         """
         return add_noise(waveform, noise, lengths, snr)
+
+
+class Preemphasis(torch.nn.Module):
+    r"""Pre-emphasizes a waveform along its last dimension.
+
+    See :meth:`torchaudio.prototype.functional.preemphasis` for more details.
+
+    .. devices:: CPU CUDA
+
+    .. properties:: Autograd TorchScript
+
+    Args:
+        coeff (float, optional): Pre-emphasis coefficient. Typically between 0.0 and 1.0.
+            (Default: 0.97)
+    """
+
+    def __init__(self, coeff: float = 0.97) -> None:
+        super().__init__()
+        self.coeff = coeff
+
+    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
+        r"""
+        Args:
+            waveform (torch.Tensor): Waveform, with shape `(..., N)`.
+
+        Returns:
+            torch.Tensor: Pre-emphasized waveform, with shape `(..., N)`.
+        """
+        return preemphasis(waveform, coeff=self.coeff)
+
+
+class Deemphasis(torch.nn.Module):
+    r"""De-emphasizes a waveform along its last dimension.
+
+    See :meth:`torchaudio.prototype.functional.deemphasis` for more details.
+
+    .. devices:: CPU CUDA
+
+    .. properties:: Autograd TorchScript
+
+    Args:
+        coeff (float, optional): De-emphasis coefficient. Typically between 0.0 and 1.0.
+            (Default: 0.97)
+    """
+
+    def __init__(self, coeff: float = 0.97) -> None:
+        super().__init__()
+        self.coeff = coeff
+
+    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
+        r"""
+        Args:
+            waveform (torch.Tensor): Waveform, with shape `(..., N)`.
+
+        Returns:
+            torch.Tensor: De-emphasized waveform, with shape `(..., N)`.
+        """
+        return deemphasis(waveform, coeff=self.coeff)