Commit e502b10c authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Add oscillator_bank (#2848)

Summary:
This commit adds the `oscillator_bank` op, which is the core of (differentiable) digital signal processing ops.
The implementation itself is pretty simple, sum instantaneous frequencies, take sin and multiply with amplitudes.

Following the magenta implementation, amplitudes for frequencies outside of the [-Nyquist, Nyquist] range are suppressed.

The differentiability is tested within frequency range of [- Nyquist, Nyquist], and amplitude range of [-5, 5], which should be enough.

For example usages:
 - https://output.circle-artifacts.com/output/job/129f3e21-41ce-406b-bc6b-833efb3c3141/artifacts/0/docs/tutorials/oscillator_tutorial.html
 - https://output.circle-artifacts.com/output/job/129f3e21-41ce-406b-bc6b-833efb3c3141/artifacts/0/docs/tutorials/synthesis_tutorial.html

Part of https://github.com/pytorch/audio/issues/2835
Extracted from https://github.com/pytorch/audio/issues/2808

Pull Request resolved: https://github.com/pytorch/audio/pull/2848

Reviewed By: carolineechen

Differential Revision: D41353075

Pulled By: mthrok

fbshipit-source-id: 80e60772fb555760f2396f7df40458803c280225
parent e062110b
......@@ -23,3 +23,12 @@ fftconvolve
~~~~~~~~~~~
.. autofunction:: fftconvolve
DSP
~~~

.. autosummary::
   :toctree: generated
   :nosignatures:

   oscillator_bank
import torch
import torchaudio.prototype.functional as F
from parameterized import parameterized
from torch.autograd import gradcheck, gradgradcheck
from torchaudio_unittest.common_utils import nested_params, TestBaseMixin
......@@ -28,3 +29,28 @@ class AutogradTestImpl(TestBaseMixin):
self.assertTrue(gradcheck(F.add_noise, (waveform, noise, lengths, snr)))
self.assertTrue(gradgradcheck(F.add_noise, (waveform, noise, lengths, snr)))
@parameterized.expand(
    [
        (8000, (2, 3, 5, 7)),
        (8000, (8000, 1)),
    ]
)
def test_oscillator_bank(self, sample_rate, shape):
    """gradcheck/gradgradcheck-style autograd validation of oscillator_bank."""

    # math.prod is unavailable until Python 3.7 support is dropped.
    def _numel(dims):
        total = 1
        for d in dims:
            total *= d
        return total

    n = _numel(shape)
    # Use 1.9 (not 2) as the divisor so some values land above the Nyquist frequency.
    f_limit = sample_rate / 1.9
    freq = torch.linspace(
        -f_limit, f_limit, n, dtype=self.dtype, device=self.device, requires_grad=True
    ).reshape(shape)
    amps = torch.linspace(
        -5, 5, n, dtype=self.dtype, device=self.device, requires_grad=True
    ).reshape(shape)

    assert gradcheck(F.oscillator_bank, (freq, amps, sample_rate))
import numpy as np
from numpy.typing import ArrayLike
def oscillator_bank(
    frequencies: ArrayLike,
    amplitudes: ArrayLike,
    sample_rate: float,
    time_axis: int = -2,
) -> ArrayLike:
    """Reference implementation of oscillator_bank"""
    # Suppress oscillators whose |frequency| reaches or exceeds the Nyquist frequency.
    out_of_range = np.abs(frequencies) >= sample_rate / 2
    if np.any(out_of_range):
        amplitudes = np.where(out_of_range, 0.0, amplitudes)

    two_pi = 2.0 * np.pi
    # Radians per sample, wrapped into [0, 2*pi). Wrapping before the cumulative sum
    # only offsets each phase by a multiple of 2*pi, which sin() cannot observe.
    omega = frequencies * two_pi / sample_rate % two_pi
    phases = np.cumsum(omega, axis=time_axis, dtype=omega.dtype)
    return amplitudes * np.sin(phases)
import torch
from torchaudio_unittest.common_utils import PytorchTestCase
from .functional_test_impl import FunctionalTestImpl
from .functional_test_impl import Functional64OnlyTestImpl, FunctionalTestImpl
class FunctionalFloat32CPUTest(FunctionalTestImpl, PytorchTestCase):
......@@ -12,3 +12,8 @@ class FunctionalFloat32CPUTest(FunctionalTestImpl, PytorchTestCase):
class FunctionalFloat64CPUTest(FunctionalTestImpl, PytorchTestCase):
    # Run the shared FunctionalTestImpl tests with float64 on CPU.
    dtype = torch.float64
    device = torch.device("cpu")
class FunctionalFloat64OnlyCPUTest(Functional64OnlyTestImpl, PytorchTestCase):
    # Run the float64-only tests (reference comparisons that need
    # high-precision cumsum) on CPU.
    dtype = torch.float64
    device = torch.device("cpu")
import torch
from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda
from .functional_test_impl import FunctionalTestImpl
from .functional_test_impl import Functional64OnlyTestImpl, FunctionalTestImpl
@skipIfNoCuda
class FunctionalFloat32CUDATest(FunctionalTestImpl, PytorchTestCase):
    # Run the shared FunctionalTestImpl tests with float32 on the first CUDA device.
    dtype = torch.float32
    # Pin to device index 0; the earlier bare `torch.device("cuda")` assignment
    # was dead code immediately shadowed by this one, so it is removed.
    device = torch.device("cuda", 0)
@skipIfNoCuda
class FunctionalFloat64CUDATest(FunctionalTestImpl, PytorchTestCase):
    # Run the shared FunctionalTestImpl tests with float64 on the first CUDA device.
    dtype = torch.float64
    device = torch.device("cuda", 0)
@skipIfNoCuda
class FunctionalFloat64OnlyCUDATest(Functional64OnlyTestImpl, PytorchTestCase):
    # Run the float64-only tests on CUDA; skipped when CUDA is unavailable.
    dtype = torch.float64
    device = torch.device("cuda")
......@@ -5,6 +5,8 @@ from parameterized import parameterized
from scipy import signal
from torchaudio_unittest.common_utils import nested_params, TestBaseMixin
from .dsp_utils import oscillator_bank as oscillator_bank_np
class FunctionalTestImpl(TestBaseMixin):
@nested_params(
......@@ -109,3 +111,89 @@ class FunctionalTestImpl(TestBaseMixin):
with self.assertRaisesRegex(ValueError, "Length dimensions"):
F.add_noise(waveform, noise, lengths, snr)
@nested_params(
    [(2, 3), (2, 3, 5), (2, 3, 5, 7)],
    ["sum", "mean", "none"],
)
def test_oscillator_bank_smoke_test(self, shape, reduction):
    """oscillator_bank supports variable dimension inputs on different device/dtypes"""
    sample_rate = 8000
    # Random frequencies in [0, Nyquist) and amplitudes in [0, 1).
    freqs = sample_rate // 2 * torch.rand(shape, dtype=self.dtype, device=self.device)
    amps = torch.rand(shape, dtype=self.dtype, device=self.device)

    waveform = F.oscillator_bank(freqs, amps, sample_rate, reduction=reduction)

    # "sum"/"mean" collapse the last (oscillator) dimension; "none" keeps it.
    if reduction == "none":
        assert waveform.shape == shape
    else:
        assert waveform.shape == shape[:-1]
    assert waveform.dtype == self.dtype
    assert waveform.device == self.device
def test_oscillator_invalid(self):
    """oscillator_bank rejects/warns invalid inputs"""
    shape = [2, 3, 5]
    sample_rate = 8000
    freqs = torch.ones(*shape, dtype=self.dtype, device=self.device)
    amps = torch.rand(*shape, dtype=self.dtype, device=self.device)

    # Frequencies and amplitudes must share a shape.
    with self.assertRaises(ValueError):
        F.oscillator_bank(freqs[0], amps, sample_rate)

    # |f| >= Nyquist triggers a warning, for both signs.
    nyquist = sample_rate / 2
    for sign in (1.0, -1.0):
        with self.assertWarnsRegex(UserWarning, r"above nyquist frequency"):
            F.oscillator_bank(sign * nyquist * freqs, amps, sample_rate)
class Functional64OnlyTestImpl(TestBaseMixin):
    @nested_params(
        [1, 10, 100, 1000],
        [1, 2, 4, 8],
        [8000, 16000],
    )
    def test_oscillator_ref(self, f0, num_pitches, sample_rate):
        """oscillator_bank returns the matching values as reference implementation

        Note: It looks like NumPy performs cumsum on higher precision and thus this test
        does not pass on float32.
        """
        duration = 4.0
        num_frames = int(sample_rate * duration)
        # Harmonic series f0, 2*f0, ..., num_pitches*f0 with equal amplitudes.
        base_freqs = f0 * torch.arange(1, num_pitches + 1, device=self.device, dtype=self.dtype)
        base_amps = 1.0 / num_pitches * torch.ones_like(base_freqs)
        # Hold frequencies and amplitudes constant over every frame.
        freq = base_freqs[None, :].repeat(num_frames, 1)
        amps = base_amps[None, :].repeat(num_frames, 1)

        wavs_ref = oscillator_bank_np(freq.cpu().numpy(), amps.cpu().numpy(), sample_rate)
        wavs_hyp = F.oscillator_bank(freq, amps, sample_rate, reduction="none")

        def _debug_plot():
            # Debug code to see what goes wrong; kept (disabled) for future reference.
            # import matplotlib.pyplot as plt
            # fig, axes = plt.subplots(num_pitches, 3, sharex=True, sharey=True)
            # for p in range(num_pitches):
            #     (ax0, ax1, ax2) = axes[p] if num_pitches > 1 else axes
            #     spec_ref, ys, xs, _ = ax0.specgram(wavs_ref[:, p])
            #     spec_hyp, _, _, _ = ax1.specgram(wavs_hyp[:, p])
            #     spec_diff = spec_ref - spec_hyp
            #     ax2.imshow(spec_diff, aspect="auto", extent=[xs[0], xs[-1], ys[0], ys[-1]])
            # plt.show()
            pass

        try:
            self.assertEqual(wavs_hyp, wavs_ref)
        except AssertionError:
            _debug_plot()
            raise
from ._dsp import oscillator_bank
from .functional import add_noise, barkscale_fbanks, convolve, fftconvolve
# Public API of this module. The previous single-line assignment was dead code
# (immediately shadowed by the expanded list below), so only one assignment remains.
__all__ = [
    "add_noise",
    "barkscale_fbanks",
    "convolve",
    "fftconvolve",
    "oscillator_bank",
]
import warnings
import torch
def oscillator_bank(
    frequencies: torch.Tensor,
    amplitudes: torch.Tensor,
    sample_rate: float,
    reduction: str = "sum",
) -> torch.Tensor:
    """Synthesize waveform from the given instantaneous frequencies and amplitudes.

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    Note:
        The phase information of the output waveform is found by taking the cumulative sum
        of the given instantaneous frequencies (``frequencies``).
        This incurs roundoff error when the data type does not have enough precision.
        Using ``torch.float64`` can work around this.

        The following figure shows the difference between ``torch.float32`` and
        ``torch.float64`` when generating a sin wave of constant frequency and amplitude
        with sample rate 8000 [Hz].
        Notice that ``torch.float32`` version shows artifacts that are not seen in
        ``torch.float64`` version.

        .. image:: https://download.pytorch.org/torchaudio/doc-assets/oscillator_precision.png

    Args:
        frequencies (Tensor): Sample-wise oscillator frequencies (Hz). Shape `(..., time, N)`.
        amplitudes (Tensor): Sample-wise oscillator amplitude. Shape: `(..., time, N)`.
        sample_rate (float): Sample rate
        reduction (str): Reduction to perform.
            Valid values are ``"sum"``, ``"mean"`` or ``"none"``. Default: ``"sum"``

    Returns:
        Tensor:
            The resulting waveform.

            If ``reduction`` is ``"none"``, then the shape is
            `(..., time, N)`, otherwise the shape is `(..., time)`.

    Raises:
        ValueError: If the shapes of ``frequencies`` and ``amplitudes`` differ,
            or if ``reduction`` is not one of the valid values.
    """
    if frequencies.shape != amplitudes.shape:
        raise ValueError(
            "The shapes of `frequencies` and `amplitudes` must match. "
            f"Found: {frequencies.shape} and {amplitudes.shape} respectively."
        )
    reductions = ["sum", "mean", "none"]
    if reduction not in reductions:
        raise ValueError(f"The value of reduction must be either {reductions}. Found: {reduction}")

    # Oscillators at or above the Nyquist frequency cannot be represented;
    # mute them (matching the magenta/ddsp reference behavior) and warn.
    invalid = torch.abs(frequencies) >= sample_rate / 2
    if torch.any(invalid):
        warnings.warn(
            "Some frequencies are above nyquist frequency. "
            "Setting the corresponding amplitude to zero. "
            "This might cause numerically unstable gradient."
        )
        # zeros_like keeps dtype/device and avoids relying on the Python-scalar
        # overload of torch.where, which older PyTorch releases do not support.
        amplitudes = torch.where(invalid, torch.zeros_like(amplitudes), amplitudes)

    # Note:
    # In magenta/ddsp, there is an option to reduce the number of summation to reduce
    # the accumulation error.
    # https://github.com/magenta/ddsp/blob/7cb3c37f96a3e5b4a2b7e94fdcc801bfd556021b/ddsp/core.py#L950-L955
    # It mentions some performance penalty.
    # In torchaudio, a simple way to work around is to use float64.
    # We might add angular_cumsum if it turned out to be undesirable.

    pi2 = 2.0 * torch.pi
    # Convert Hz to radians-per-sample, wrapped into [0, 2*pi). Wrapping before the
    # cumulative sum only offsets each phase by a multiple of 2*pi (invisible to sin)
    # while keeping the accumulated magnitude — and thus roundoff — smaller.
    freqs = frequencies * pi2 / sample_rate % pi2
    # `dim` is the canonical PyTorch keyword; `axis` is only a NumPy-compat alias.
    phases = torch.cumsum(freqs, dim=-2)
    waveform = amplitudes * torch.sin(phases)

    if reduction == "sum":
        return waveform.sum(-1)
    if reduction == "mean":
        return waveform.mean(-1)
    return waveform
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment