Commit d234498c authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Add frequency_impulse_response (#2879)

Summary:
This commit adds `frequency_impulse_response` function, which generates filter from desired frequency response.

[Example](https://output.circle-artifacts.com/output/job/5233fda9-dadb-4710-9389-7e8ac20a062f/artifacts/0/docs/tutorials/filter_design_tutorial.html#frequency-sampling)

Pull Request resolved: https://github.com/pytorch/audio/pull/2879

Reviewed By: hwangjeff

Differential Revision: D41767787

Pulled By: mthrok

fbshipit-source-id: 6d5e44c6390e8cf3028994a1b1de590ff3aaf6c2
parent d8a5a11d
......@@ -50,3 +50,4 @@ DSP
extend_pitch
oscillator_bank
sinc_impulse_response
frequency_impulse_response
......@@ -87,3 +87,7 @@ class AutogradTestImpl(TestBaseMixin):
coeff = 0.9
self.assertTrue(gradcheck(F.deemphasis, (waveform, coeff)))
self.assertTrue(gradgradcheck(F.deemphasis, (waveform, coeff)))
def test_freq_ir(self):
mags = torch.tensor([0, 0.5, 1.0], device=self.device, dtype=self.dtype, requires_grad=True)
assert gradcheck(F.frequency_impulse_response, (mags,))
......@@ -35,3 +35,9 @@ def sinc_ir(cutoff: ArrayLike, window_size: int = 513, high_pass: bool = False):
filt *= -1
filt[..., half] = 1.0 + filt[..., half]
return filt
def freq_ir(magnitudes):
ir = np.fft.fftshift(np.fft.irfft(magnitudes), axes=-1)
window = np.hanning(ir.shape[-1])
return (ir * window).astype(magnitudes.dtype)
......@@ -8,7 +8,7 @@ from scipy import signal
from torchaudio.functional import lfilter
from torchaudio_unittest.common_utils import nested_params, TestBaseMixin
from .dsp_utils import oscillator_bank as oscillator_bank_np, sinc_ir as sinc_ir_np
from .dsp_utils import freq_ir as freq_ir_np, oscillator_bank as oscillator_bank_np, sinc_ir as sinc_ir_np
def _prod(l):
......@@ -470,6 +470,22 @@ class FunctionalTestImpl(TestBaseMixin):
deemphasized = F.deemphasis(preemphasized, coeff=coeff)
self.assertEqual(deemphasized, waveform)
def test_freq_ir_warns_negative_values(self):
"""frequency_impulse_response warns negative input value"""
magnitudes = -torch.ones((1, 30), device=self.device, dtype=self.dtype)
with self.assertWarnsRegex(UserWarning, "^.+should not contain negative values.$"):
F.frequency_impulse_response(magnitudes)
@parameterized.expand([((2, 3, 4),), ((1000,),)])
def test_freq_ir_reference(self, shape):
"""frequency_impulse_response produces the same result as reference implementation"""
magnitudes = torch.rand(shape, device=self.device, dtype=self.dtype)
hyp = F.frequency_impulse_response(magnitudes)
ref = freq_ir_np(magnitudes.cpu().numpy())
self.assertEqual(hyp, ref)
class Functional64OnlyTestImpl(TestBaseMixin):
@nested_params(
......
......@@ -98,3 +98,7 @@ class TorchScriptConsistencyTestImpl(TestBaseMixin):
waveform = torch.rand(3, 2, 100, device=self.device, dtype=self.dtype)
coeff = 0.9
self._assert_consistency(F.deemphasis, (waveform, coeff))
def test_freq_ir(self):
mags = torch.tensor([0, 0.5, 1.0], device=self.device, dtype=self.dtype)
self._assert_consistency(F.frequency_impulse_response, (mags,))
from ._dsp import adsr_envelope, extend_pitch, oscillator_bank, sinc_impulse_response
from ._dsp import adsr_envelope, extend_pitch, frequency_impulse_response, oscillator_bank, sinc_impulse_response
from .functional import add_noise, barkscale_fbanks, convolve, deemphasis, fftconvolve, preemphasis, speed
__all__ = [
"add_noise",
"adsr_envelope",
......@@ -9,6 +10,7 @@ __all__ = [
"deemphasis",
"extend_pitch",
"fftconvolve",
"frequency_impulse_response",
"oscillator_bank",
"preemphasis",
"sinc_impulse_response",
......
......@@ -286,3 +286,23 @@ def sinc_impulse_response(cutoff: torch.Tensor, window_size: int = 513, high_pas
filt = -filt
filt[..., half] = 1.0 + filt[..., half]
return filt
def frequency_impulse_response(magnitudes):
"""Create filter from desired frequency response
Args:
magnitudes: The desired frequency responses. Shape: `(..., num_fft_bins)`
Returns:
Tensor: Impulse response. Shape `(..., 2 * (num_fft_bins - 1))`
"""
if magnitudes.min() < 0.0:
# Negative magnitude does not make sense but allowing so that autograd works
# around 0.
# Should we raise error?
warnings.warn("The input frequency response should not contain negative values.")
ir = torch.fft.fftshift(torch.fft.irfft(magnitudes), dim=-1)
device, dtype = magnitudes.device, magnitudes.dtype
window = torch.hann_window(ir.size(-1), periodic=False, device=device, dtype=dtype).expand_as(ir)
return ir * window
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment