Commit 3882c395 authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Add extend_pitch (#2863)

Summary:
Add `extend_pitch` function, which can be used to augment fundamental frequencies with their harmonic overtones or inharmonic partials. It can be used for amplitudes as well.

For example usages, see https://output.circle-artifacts.com/output/job/4ad0c29a-d75a-4244-baad-f5499f11d94b/artifacts/0/docs/tutorials/synthesis_tutorial.html

Part of https://github.com/pytorch/audio/issues/2835
Extracted from https://github.com/pytorch/audio/issues/2808

Pull Request resolved: https://github.com/pytorch/audio/pull/2863

Reviewed By: carolineechen

Differential Revision: D41543880

Pulled By: mthrok

fbshipit-source-id: 4f20e55770b0b3bee825ec07c73f9ec7cb181109
parent 8ba323bb
...@@ -32,4 +32,5 @@ DSP ...@@ -32,4 +32,5 @@ DSP
:nosignatures: :nosignatures:
adsr_envelope adsr_envelope
extend_pitch
oscillator_bank oscillator_bank
...@@ -54,3 +54,11 @@ class AutogradTestImpl(TestBaseMixin): ...@@ -54,3 +54,11 @@ class AutogradTestImpl(TestBaseMixin):
amps = torch.linspace(-5, 5, numel, dtype=self.dtype, device=self.device, requires_grad=True).reshape(shape) amps = torch.linspace(-5, 5, numel, dtype=self.dtype, device=self.device, requires_grad=True).reshape(shape)
assert gradcheck(F.oscillator_bank, (freq, amps, sample_rate)) assert gradcheck(F.oscillator_bank, (freq, amps, sample_rate))
def test_extend_pitch(self):
    """Gradient-check ``F.extend_pitch`` for an integer and a tensor pattern."""
    num_frames, num_pitches = 5, 7
    # Both tensors share device/dtype and must track gradients for gradcheck.
    tensor_opts = {"device": self.device, "dtype": self.dtype, "requires_grad": True}
    base = torch.ones((num_frames, 1), **tensor_opts)
    multipliers = torch.linspace(1, num_pitches, num_pitches, **tensor_opts)
    # First the harmonic (int) pattern, then the explicit tensor of multipliers.
    for pattern in (num_pitches, multipliers):
        assert gradcheck(F.extend_pitch, (base, pattern))
...@@ -290,6 +290,27 @@ class FunctionalTestImpl(TestBaseMixin): ...@@ -290,6 +290,27 @@ class FunctionalTestImpl(TestBaseMixin):
) )
self.assertEqual(out, torch.tensor(expected, device=self.device, dtype=self.dtype)) self.assertEqual(out, torch.tensor(expected, device=self.device, dtype=self.dtype))
def test_extend_pitch(self):
    """``F.extend_pitch`` gives the same result for int, list and tensor patterns."""
    num_frames, num_pitches = 5, 7
    base = torch.ones((num_frames, 1), device=self.device, dtype=self.dtype)
    # Multipliers 1..num_pitches; with an all-ones base every output row equals them.
    multipliers = list(range(1, num_pitches + 1))
    expected = torch.tensor([multipliers] * num_frames).to(dtype=self.dtype, device=self.device)

    # An int pattern appends the harmonic overtones.
    self.assertEqual(F.extend_pitch(base, num_pitches), expected)

    # The same result is obtained from an explicit list of multipliers ...
    self.assertEqual(F.extend_pitch(base, multipliers), expected)

    # ... and from a tensor of multipliers.
    multipliers_tensor = torch.tensor(multipliers).to(dtype=self.dtype, device=self.device)
    self.assertEqual(F.extend_pitch(base, multipliers_tensor), expected)
class Functional64OnlyTestImpl(TestBaseMixin): class Functional64OnlyTestImpl(TestBaseMixin):
@nested_params( @nested_params(
......
...@@ -65,3 +65,14 @@ class TorchScriptConsistencyTestImpl(TestBaseMixin): ...@@ -65,3 +65,14 @@ class TorchScriptConsistencyTestImpl(TestBaseMixin):
amps = torch.ones_like(freq) amps = torch.ones_like(freq)
self._assert_consistency(F.oscillator_bank, (freq, amps, sample_rate, "sum")) self._assert_consistency(F.oscillator_bank, (freq, amps, sample_rate, "sum"))
def test_extend_pitch(self):
    """Check consistency of ``F.extend_pitch`` for every supported pattern type."""
    num_frames, num_pitches = 5, 7
    base = torch.ones((num_frames, 1), device=self.device, dtype=self.dtype)
    # Float multipliers 1.0 .. num_pitches.
    multipliers = [float(k) for k in range(1, num_pitches + 1)]
    # int pattern, list-of-float pattern, tensor pattern -- in that order.
    for pattern in (num_pitches, multipliers, torch.tensor(multipliers)):
        self._assert_consistency(F.extend_pitch, (base, pattern))
from ._dsp import adsr_envelope, oscillator_bank from ._dsp import adsr_envelope, extend_pitch, oscillator_bank
from .functional import add_noise, barkscale_fbanks, convolve, fftconvolve from .functional import add_noise, barkscale_fbanks, convolve, fftconvolve
__all__ = [ __all__ = [
...@@ -6,6 +6,7 @@ __all__ = [ ...@@ -6,6 +6,7 @@ __all__ = [
"adsr_envelope", "adsr_envelope",
"barkscale_fbanks", "barkscale_fbanks",
"convolve", "convolve",
"extend_pitch",
"fftconvolve", "fftconvolve",
"oscillator_bank", "oscillator_bank",
] ]
import warnings import warnings
from typing import Optional from typing import List, Optional, Union
import torch import torch
...@@ -180,3 +180,70 @@ def adsr_envelope( ...@@ -180,3 +180,70 @@ def adsr_envelope(
torch.linspace(sustain, 0, num_r + 1, out=out[-num_r - 1 :]) torch.linspace(sustain, 0, num_r + 1, out=out[-num_r - 1 :])
return out return out
def extend_pitch(
    base: torch.Tensor,
    pattern: Union[int, List[float], torch.Tensor],
) -> torch.Tensor:
    """Extend the given time series values with multipliers of them.

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    Given a series of fundamental frequencies (pitch), this function appends
    its harmonic overtones or inharmonic partials.

    Args:
        base (torch.Tensor):
            Base time series, like fundamental frequencies (Hz). Shape: `(..., time, 1)`.
        pattern (int, list of floats or torch.Tensor):
            If ``int``, the number of pitch series after the operation.
            `pattern - 1` tones are added, so that the resulting Tensor contains
            up to `pattern`-th overtones of the given series.

            If list of float or ``torch.Tensor``, it must be one dimensional,
            representing the custom multiplier of the fundamental frequency.
            The multipliers are converted to the dtype and device of ``base``.

    Returns:
        Tensor: Oscillator frequencies (Hz). Shape: `(..., time, num_tones)`.

    Example
        >>> # fundamental frequency
        >>> f0 = torch.linspace(1, 5, 5).unsqueeze(-1)
        >>> f0
        tensor([[1.],
                [2.],
                [3.],
                [4.],
                [5.]])
        >>> # Add harmonic overtones, up to 3rd.
        >>> f = extend_pitch(f0, 3)
        >>> f.shape
        torch.Size([5, 3])
        >>> f
        tensor([[ 1.,  2.,  3.],
                [ 2.,  4.,  6.],
                [ 3.,  6.,  9.],
                [ 4.,  8., 12.],
                [ 5., 10., 15.]])
        >>> # Add custom (inharmonic) partials.
        >>> f = extend_pitch(f0, torch.tensor([1, 2.1, 3.3, 4.5]))
        >>> f.shape
        torch.Size([5, 4])
        >>> f
        tensor([[ 1.0000,  2.1000,  3.3000,  4.5000],
                [ 2.0000,  4.2000,  6.6000,  9.0000],
                [ 3.0000,  6.3000,  9.9000, 13.5000],
                [ 4.0000,  8.4000, 13.2000, 18.0000],
                [ 5.0000, 10.5000, 16.5000, 22.5000]])
    """
    if isinstance(pattern, torch.Tensor):
        # matmul does not type-promote, so align the multipliers with ``base``
        # exactly like the int / list branches do. ``Tensor.to`` returns the
        # same tensor when dtype and device already match, so gradients still
        # flow back to ``pattern``.
        mult = pattern.to(dtype=base.dtype, device=base.device)
    elif isinstance(pattern, int):
        # Harmonic series: multipliers 1, 2, ..., pattern.
        mult = torch.linspace(1.0, float(pattern), pattern, device=base.device, dtype=base.dtype)
    else:
        mult = torch.tensor(pattern, dtype=base.dtype, device=base.device)
    # (..., time, 1) @ (1, num_tones) -> (..., time, num_tones)
    h_freq = base @ mult.unsqueeze(0)
    return h_freq
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment