Unverified Commit f5dbb002 authored by nateanl, committed by GitHub

Add PitchShift to functional and transform (#1629)

parent 0ea6d10d
@@ -211,6 +211,11 @@ vad
.. autofunction:: phase_vocoder
:hidden:`pitch_shift`
-----------------------
.. autofunction:: pitch_shift
:hidden:`compute_deltas`
------------------------
...
@@ -101,6 +101,13 @@ Transforms are common audio transforms. They can be chained together using :clas
.. automethod:: forward
:hidden:`PitchShift`
~~~~~~~~~~~~~~~~~~~~
.. autoclass:: PitchShift
.. automethod:: forward
:hidden:`Fade`
~~~~~~~~~~~~~~
...
@@ -422,6 +422,16 @@ class Functional(TestBaseMixin):
        assert F.edit_distance(seq1, seq2) == distance
        assert F.edit_distance(seq2, seq1) == distance
@nested_params(
[-4, -2, 0, 2, 4],
)
def test_pitch_shift_shape(self, n_steps):
sample_rate = 16000
torch.random.manual_seed(42)
waveform = torch.rand(2, 44100 * 1, dtype=self.dtype, device=self.device)
waveform_shift = F.pitch_shift(waveform, sample_rate, n_steps)
assert waveform.size() == waveform_shift.size()
class FunctionalCPUOnly(TestBaseMixin):
    def test_create_fb_matrix_no_warning_high_n_freq(self):
...
@@ -187,3 +187,15 @@ class TestTransforms(common_utils.TorchaudioTestCase):
        # Batch then transform
        computed = torchaudio.transforms.SpectralCentroid(sample_rate)(waveform.repeat(3, 1, 1))
        self.assertEqual(computed, expected)
def test_batch_pitch_shift(self):
sample_rate = 44100
n_steps = 4
waveform = common_utils.get_whitenoise(sample_rate=sample_rate)
# Single then transform then batch
expected = torchaudio.transforms.PitchShift(sample_rate, n_steps)(waveform).repeat(3, 1, 1)
# Batch then transform
computed = torchaudio.transforms.PitchShift(sample_rate, n_steps)(waveform.repeat(3, 1, 1))
self.assertEqual(computed, expected)
@@ -135,3 +135,12 @@ class Transforms(TempDirMixin, TestBaseMixin):
            tensor,
            test_pseudo_complex
        )
def test_PitchShift(self):
sample_rate = 8000
n_steps = 4
waveform = common_utils.get_whitenoise(sample_rate=sample_rate)
self._assert_consistency(
T.PitchShift(sample_rate=sample_rate, n_steps=n_steps),
waveform
)
@@ -21,6 +21,7 @@ from .functional import (
    apply_codec,
    resample,
    edit_distance,
pitch_shift,
)
from .filtering import (
    allpass_biquad,
@@ -90,4 +91,5 @@ __all__ = [
    'apply_codec',
    'resample',
    'edit_distance',
'pitch_shift',
]
@@ -36,6 +36,7 @@ __all__ = [
    "apply_codec",
    "resample",
    "edit_distance",
"pitch_shift",
]
@@ -1488,3 +1489,76 @@ def edit_distance(seq1: Sequence, seq2: Sequence) -> int:
            dnew, dold = dold, dnew
    return int(dold[-1])
def pitch_shift(
waveform: Tensor,
sample_rate: int,
n_steps: int,
bins_per_octave: int = 12,
n_fft: int = 512,
win_length: Optional[int] = None,
hop_length: Optional[int] = None,
window: Optional[Tensor] = None,
) -> Tensor:
"""
Shift the pitch of a waveform by ``n_steps`` steps.
Args:
waveform (Tensor): The input waveform of shape `(..., time)`.
sample_rate (float): Sample rate of `waveform`.
n_steps (int): The (fractional) steps to shift `waveform`.
bins_per_octave (int, optional): The number of steps per octave (Default: ``12``).
n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins (Default: ``512``).
win_length (int or None, optional): Window size. If None, then ``n_fft`` is used. (Default: ``None``).
hop_length (int or None, optional): Length of hop between STFT windows. If None, then
``win_length // 4`` is used (Default: ``None``).
window (Tensor or None, optional): Window tensor that is applied/multiplied to each frame/window.
If None, then ``torch.hann_window(win_length)`` is used (Default: ``None``).
Returns:
Tensor: The pitch-shifted audio waveform of shape `(..., time)`.
"""
    if win_length is None:
        win_length = n_fft
    # Resolve win_length before hop_length so the default matches the
    # documented ``win_length // 4`` (and the PitchShift transform below).
    if hop_length is None:
        hop_length = win_length // 4
    if window is None:
        window = torch.hann_window(window_length=win_length, device=waveform.device)
# pack batch
shape = waveform.size()
waveform = waveform.reshape(-1, shape[-1])
ori_len = shape[-1]
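    # A shift of ``n_steps`` corresponds to time-stretching by rate = 2^(-n_steps / bins_per_octave),
    # e.g. n_steps=12 with bins_per_octave=12 gives rate 0.5 (one octave up after resampling).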
rate = 2.0 ** (-float(n_steps) / bins_per_octave)
spec_f = torch.stft(input=waveform,
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
window=window,
center=True,
pad_mode='reflect',
normalized=False,
onesided=True,
return_complex=True)
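    # Expected phase advance per hop for each frequency bin; the phase vocoder uses it
    # to stretch the spectrogram in time by ``rate`` without changing pitch.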
phase_advance = torch.linspace(0, math.pi * hop_length, spec_f.shape[-2], device=spec_f.device)[..., None]
spec_stretch = phase_vocoder(spec_f, rate, phase_advance)
len_stretch = int(round(ori_len / rate))
waveform_stretch = torch.istft(spec_stretch,
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
window=window,
length=len_stretch)
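    # Resampling the stretched waveform back to the original sample rate shifts the pitch
    # while restoring (approximately) the original duration.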
waveform_shift = resample(waveform_stretch, sample_rate / rate, float(sample_rate))
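    # Trim or zero-pad so the output length exactly matches the input length.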
shift_len = waveform_shift.size()[-1]
if shift_len > ori_len:
waveform_shift = waveform_shift[..., :ori_len]
else:
waveform_shift = torch.nn.functional.pad(waveform_shift, [0, ori_len - shift_len])
# unpack batch
waveform_shift = waveform_shift.view(shape[:-1] + waveform_shift.shape[-1:])
return waveform_shift
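For illustration, a minimal usage sketch of the functional interface (``'test.wav'`` is a placeholder file name):

    >>> waveform, sample_rate = torchaudio.load('test.wav')
    >>> # shift up by 4 semitones; the output has the same shape as the input
    >>> shifted = torchaudio.functional.pitch_shift(waveform, sample_rate, n_steps=4)
    >>> shifted.shape == waveform.shape
    True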
@@ -34,6 +34,7 @@ __all__ = [
    'SpectralCentroid',
    'Vol',
    'ComputeDeltas',
'PitchShift',
]
@@ -1210,3 +1211,56 @@ class SpectralCentroid(torch.nn.Module):
        return F.spectral_centroid(waveform, self.sample_rate, self.pad, self.window, self.n_fft, self.hop_length,
                                   self.win_length)
class PitchShift(torch.nn.Module):
    r"""Shift the pitch of a waveform by ``n_steps`` steps.

    Args:
        sample_rate (int): Sample rate of the input waveform.
        n_steps (int): The number of steps to shift the waveform.
        bins_per_octave (int, optional): The number of steps per octave (Default: ``12``).
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins (Default: ``512``).
        win_length (int or None, optional): Window size. If None, then ``n_fft`` is used (Default: ``None``).
        hop_length (int or None, optional): Length of hop between STFT windows. If None, then ``win_length // 4``
            is used (Default: ``None``).
        window_fn (Callable[..., Tensor], optional): A function to create a window tensor
            that is applied/multiplied to each frame/window (Default: ``torch.hann_window``).
        wkwargs (dict or None, optional): Arguments for the window function (Default: ``None``).

    Example
        >>> waveform, sample_rate = torchaudio.load('test.wav')
        >>> waveform_shift = transforms.PitchShift(sample_rate, 4)(waveform)  # (channel, time)
    """
__constants__ = ['sample_rate', 'n_steps', 'bins_per_octave', 'n_fft', 'win_length', 'hop_length']
def __init__(self,
sample_rate: int,
n_steps: int,
bins_per_octave: int = 12,
n_fft: int = 512,
win_length: Optional[int] = None,
hop_length: Optional[int] = None,
window_fn: Callable[..., Tensor] = torch.hann_window,
wkwargs: Optional[dict] = None) -> None:
super(PitchShift, self).__init__()
self.n_steps = n_steps
self.bins_per_octave = bins_per_octave
self.sample_rate = sample_rate
self.n_fft = n_fft
self.win_length = win_length if win_length is not None else n_fft
self.hop_length = hop_length if hop_length is not None else self.win_length // 4
window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
self.register_buffer('window', window)
def forward(self, waveform: Tensor) -> Tensor:
r"""
Args:
waveform (Tensor): Tensor of audio of dimension (..., time).
Returns:
Tensor: The pitch-shifted audio of shape `(..., time)`.
"""
return F.pitch_shift(waveform, self.sample_rate, self.n_steps, self.bins_per_octave, self.n_fft,
self.win_length, self.hop_length, self.window)
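Because ``PitchShift`` is a standard ``torch.nn.Module``, it can be chained with other transforms; a minimal sketch (parameter values are arbitrary):

    >>> import torch
    >>> import torchaudio
    >>> transform = torch.nn.Sequential(
    ...     torchaudio.transforms.PitchShift(sample_rate=16000, n_steps=2),
    ...     torchaudio.transforms.Vol(gain=0.5),
    ... )
    >>> shifted_and_attenuated = transform(torch.rand(1, 16000))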