Commit 55e9978a authored by hwangjeff's avatar hwangjeff Committed by Facebook GitHub Bot
Browse files

Add pre-emphasis and de-emphasis functions (#2871)

Summary:
Adds pre-emphasis and de-emphasis functions.

Pull Request resolved: https://github.com/pytorch/audio/pull/2871

Reviewed By: carolineechen

Differential Revision: D41651097

Pulled By: hwangjeff

fbshipit-source-id: 7a3cf6ce68b6ce1b9ae315ddd8bd8ed71acccdf1
parent c28073cc
......@@ -19,11 +19,21 @@ convolve
.. autofunction:: convolve
deemphasis
~~~~~~~~~~
.. autofunction:: deemphasis
fftconvolve
~~~~~~~~~~~
.. autofunction:: fftconvolve
preemphasis
~~~~~~~~~~~
.. autofunction:: preemphasis
speed
~~~~~
......
......@@ -75,3 +75,15 @@ class AutogradTestImpl(TestBaseMixin):
lengths = torch.randint(1, T, leading_dims, dtype=self.dtype, device=self.device)
self.assertTrue(gradcheck(F.speed, (waveform, lengths, 1000, 1.1)))
self.assertTrue(gradgradcheck(F.speed, (waveform, lengths, 1000, 1.1)))
def test_preemphasis(self):
waveform = torch.rand(3, 2, 100, device=self.device, dtype=self.dtype, requires_grad=True)
coeff = 0.9
self.assertTrue(gradcheck(F.preemphasis, (waveform, coeff)))
self.assertTrue(gradgradcheck(F.preemphasis, (waveform, coeff)))
def test_deemphasis(self):
waveform = torch.rand(3, 2, 100, device=self.device, dtype=self.dtype, requires_grad=True)
coeff = 0.9
self.assertTrue(gradcheck(F.deemphasis, (waveform, coeff)))
self.assertTrue(gradgradcheck(F.deemphasis, (waveform, coeff)))
......@@ -67,3 +67,25 @@ class BatchConsistencyTest(TorchaudioTestCase):
for idx in range(len(unbatched_output)):
w, l = output[idx], output_lengths[idx]
self.assertEqual(unbatched_output[idx], w[:l])
def test_preemphasis(self):
waveform = torch.rand(3, 2, 100, device=self.device, dtype=self.dtype)
coeff = 0.9
actual = F.preemphasis(waveform, coeff=coeff)
expected = []
for i in range(waveform.size(0)):
expected.append(F.preemphasis(waveform[i], coeff=coeff))
self.assertEqual(torch.stack(expected), actual)
def test_deemphasis(self):
waveform = torch.rand(3, 2, 100, device=self.device, dtype=self.dtype)
coeff = 0.9
actual = F.deemphasis(waveform, coeff=coeff)
expected = []
for i in range(waveform.size(0)):
expected.append(F.deemphasis(waveform[i], coeff=coeff))
self.assertEqual(torch.stack(expected), actual)
......@@ -5,6 +5,7 @@ import torch
import torchaudio.prototype.functional as F
from parameterized import param, parameterized
from scipy import signal
from torchaudio.functional import lfilter
from torchaudio_unittest.common_utils import nested_params, TestBaseMixin
from .dsp_utils import oscillator_bank as oscillator_bank_np, sinc_ir as sinc_ir_np
......@@ -446,6 +447,29 @@ class FunctionalTestImpl(TestBaseMixin):
expected_waveform[..., n_to_trim:-n_to_trim], output[..., n_to_trim:-n_to_trim], atol=1e-1, rtol=1e-4
)
@nested_params(
[(3, 2, 100), (95,)],
[0.97, 0.9, 0.68],
)
def test_preemphasis(self, input_shape, coeff):
waveform = torch.rand(*input_shape, device=self.device, dtype=self.dtype)
actual = F.preemphasis(waveform, coeff=coeff)
a_coeffs = torch.tensor([1.0, 0.0], device=self.device, dtype=self.dtype)
b_coeffs = torch.tensor([1.0, -coeff], device=self.device, dtype=self.dtype)
expected = lfilter(waveform, a_coeffs=a_coeffs, b_coeffs=b_coeffs)
self.assertEqual(actual, expected)
@nested_params(
[(3, 2, 100), (95,)],
[0.97, 0.9, 0.68],
)
def test_preemphasis_deemphasis_roundtrip(self, input_shape, coeff):
waveform = torch.rand(*input_shape, device=self.device, dtype=self.dtype)
preemphasized = F.preemphasis(waveform, coeff=coeff)
deemphasized = F.deemphasis(preemphasized, coeff=coeff)
self.assertEqual(deemphasized, waveform)
class Functional64OnlyTestImpl(TestBaseMixin):
@nested_params(
......
......@@ -88,3 +88,13 @@ class TorchScriptConsistencyTestImpl(TestBaseMixin):
waveform = torch.rand(*leading_dims, T, dtype=self.dtype, device=self.device, requires_grad=True)
lengths = torch.randint(1, T, leading_dims, dtype=self.dtype, device=self.device)
self._assert_consistency(F.speed, (waveform, lengths, 1000, 1.1))
def test_preemphasis(self):
waveform = torch.rand(3, 2, 100, device=self.device, dtype=self.dtype)
coeff = 0.9
self._assert_consistency(F.preemphasis, (waveform, coeff))
def test_deemphasis(self):
waveform = torch.rand(3, 2, 100, device=self.device, dtype=self.dtype)
coeff = 0.9
self._assert_consistency(F.deemphasis, (waveform, coeff))
from ._dsp import adsr_envelope, extend_pitch, oscillator_bank, sinc_impulse_response
from .functional import add_noise, barkscale_fbanks, convolve, fftconvolve, speed
from .functional import add_noise, barkscale_fbanks, convolve, deemphasis, fftconvolve, preemphasis, speed
__all__ = [
"add_noise",
"adsr_envelope",
"barkscale_fbanks",
"convolve",
"deemphasis",
"extend_pitch",
"fftconvolve",
"oscillator_bank",
"preemphasis",
"sinc_impulse_response",
"speed",
]
......@@ -4,8 +4,7 @@ import warnings
from typing import Tuple
import torch
from torchaudio.functional import resample
from torchaudio.functional import lfilter, resample
from torchaudio.functional.functional import _create_triangular_filterbank
......@@ -345,3 +344,45 @@ def speed(
return resample(waveform, source_sample_rate, target_sample_rate), torch.ceil(
lengths * target_sample_rate / source_sample_rate
).to(lengths.dtype)
def preemphasis(waveform, coeff: float = 0.97) -> torch.Tensor:
r"""Pre-emphasizes a waveform along its last dimension, i.e.
for each signal :math:`x` in ``waveform``, computes
output :math:`y` as
.. math::
y[i] = x[i] - \text{coeff} \cdot x[i - 1]
Args:
waveform (torch.Tensor): Waveform, with shape `(..., N)`.
coeff (float, optional): Pre-emphasis coefficient. Typically between 0.0 and 1.0.
(Default: 0.97)
Returns:
torch.Tensor: Pre-emphasized waveform, with shape `(..., N)`.
"""
waveform = waveform.clone()
waveform[..., 1:] -= coeff * waveform[..., :-1]
return waveform
def deemphasis(waveform, coeff: float = 0.97) -> torch.Tensor:
r"""De-emphasizes a waveform along its last dimension.
Inverse of :meth:`preemphasis`. Concretely, for each signal
:math:`x` in ``waveform``, computes output :math:`y` as
.. math::
y[i] = x[i] + \text{coeff} \cdot y[i - 1]
Args:
waveform (torch.Tensor): Waveform, with shape `(..., N)`.
coeff (float, optional): De-emphasis coefficient. Typically between 0.0 and 1.0.
(Default: 0.97)
Returns:
torch.Tensor: De-emphasized waveform, with shape `(..., N)`.
"""
a_coeffs = torch.tensor([1.0, -coeff], dtype=waveform.dtype, device=waveform.device)
b_coeffs = torch.tensor([1.0, 0.0], dtype=waveform.dtype, device=waveform.device)
return lfilter(waveform, a_coeffs=a_coeffs, b_coeffs=b_coeffs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment