Commit df2262b5 authored by Sean Kim, committed by Facebook GitHub Bot

Modifying Pitchshift for faster resampling (#2441)

Summary:
Split the existing PitchShift transform into multiple helper functions so that the resampling kernel can be cached, speeding up the overall process. Addresses https://github.com/pytorch/audio/issues/2359.
Existing unit tests pass.

Edit: the functional and transforms unit tests pass. Adopted lazy initialization to avoid a BC-breaking change.
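
For illustration of the intended effect (example not part of the commit; shapes and values arbitrary): constructing the transform once and applying it repeatedly now reuses a cached resampling kernel instead of regenerating it on every call.

    import torch
    import torchaudio.transforms as T

    pitch_shift = T.PitchShift(sample_rate=16000, n_steps=4)
    waveform = torch.randn(1, 16000)

    out = pitch_shift(waveform)  # first call lazily builds and caches the sinc kernel
    out = pitch_shift(waveform)  # later calls skip kernel generation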

Pull Request resolved: https://github.com/pytorch/audio/pull/2441

Reviewed By: carolineechen

Differential Revision: D36905582

Pulled By: skim0514

fbshipit-source-id: 6780db3ac8a29d59017a6abe7e82ce1fd17aaac2
parent 4d2fa190
@@ -139,7 +139,10 @@ class Transforms(TestBaseMixin):
         sample_rate = 8000
         n_steps = 4
         waveform = common_utils.get_whitenoise(sample_rate=sample_rate)
-        self._assert_consistency(T.PitchShift(sample_rate=sample_rate, n_steps=n_steps), waveform)
+        pitch_shift = T.PitchShift(sample_rate=sample_rate, n_steps=n_steps)
+        # dry-run for initializing parameters
+        pitch_shift(waveform)
+        self._assert_consistency(pitch_shift, waveform)

     def test_PSD(self):
         tensor = common_utils.get_whitenoise(sample_rate=8000, n_channels=4)
...
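Note on the dry run above: PitchShift now carries a lazily initialized kernel parameter, and a lazy module cannot be handed to torch.jit.script until its parameters are materialized by a first forward call. A sketch of the required order (illustrative, mirroring the test):

    pitch_shift = T.PitchShift(sample_rate=8000, n_steps=4)
    pitch_shift(waveform)  # materializes pitch_shift.kernel
    scripted = torch.jit.script(pitch_shift)  # safe only after the dry run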
 import torch
 import torchaudio.transforms as T
+from torchaudio.functional.functional import _get_sinc_resample_kernel
 from parameterized import param, parameterized
 from torchaudio_unittest.common_utils import (
     get_spectrogram,
@@ -147,3 +148,21 @@ class TransformsTestBase(TestBaseMixin):
         mask_n = torch.rand(specgram.shape[-2:])
         specgram_enhanced = transform(specgram, mask_s, mask_n)
         assert specgram_enhanced.dtype == dtype
+
+    def test_pitch_shift_resample_kernel(self):
+        """The resampling kernel in PitchShift is identical to what the helper function generates.
+
+        There should be no numerical difference caused by dtype conversion.
+        """
+        sample_rate = 8000
+        trans = T.PitchShift(sample_rate=sample_rate, n_steps=4)
+        trans.to(self.dtype).to(self.device)
+        # dry run to initialize the kernel
+        trans(torch.randn(2, 8000, dtype=self.dtype, device=self.device))
+
+        expected, _ = _get_sinc_resample_kernel(
+            trans.orig_freq,
+            sample_rate,
+            trans.gcd,
+            device=self.device,
+            dtype=self.dtype)
+        self.assertEqual(trans.kernel, expected)
@@ -4,7 +4,7 @@ import io
 import math
 import warnings
 from collections.abc import Sequence
-from typing import Optional, Tuple, Union
+from typing import Optional, Tuple, Union, List
 import torch
 import torchaudio
@@ -1389,10 +1389,10 @@ def _get_sinc_resample_kernel(
     orig_freq: int,
     new_freq: int,
     gcd: int,
-    lowpass_filter_width: int,
-    rolloff: float,
-    resampling_method: str,
-    beta: Optional[float],
+    lowpass_filter_width: int = 6,
+    rolloff: float = 0.99,
+    resampling_method: str = "sinc_interpolation",
+    beta: Optional[float] = None,
     device: torch.device = torch.device("cpu"),
     dtype: Optional[torch.dtype] = None,
 ):
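These defaults mirror those of the public `resample` functional, so callers that only know the two frequencies and their gcd (as the cached PitchShift path below does) can omit the filter arguments. An illustrative call, using values from an n_steps=4 shift at 8 kHz:

    kernel, width = _get_sinc_resample_kernel(
        10079,  # orig_freq: int(8000 / 2 ** (-4 / 12))
        8000,   # new_freq
        1,      # math.gcd(10079, 8000)
    )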
@@ -1635,6 +1635,39 @@ def pitch_shift(
     Returns:
         Tensor: The pitch-shifted audio waveform of shape `(..., time)`.
     """
+    waveform_stretch = _stretch_waveform(
+        waveform,
+        n_steps,
+        bins_per_octave,
+        n_fft,
+        win_length,
+        hop_length,
+        window,
+    )
+    rate = 2.0 ** (-float(n_steps) / bins_per_octave)
+    waveform_shift = resample(waveform_stretch, int(sample_rate / rate), sample_rate)
+    return _fix_waveform_shape(waveform_shift, waveform.size())
+
+
+def _stretch_waveform(
+    waveform: Tensor,
+    n_steps: int,
+    bins_per_octave: int = 12,
+    n_fft: int = 512,
+    win_length: Optional[int] = None,
+    hop_length: Optional[int] = None,
+    window: Optional[Tensor] = None,
+) -> Tensor:
+    """
+    Pitch-shift helper that preprocesses and time-stretches the waveform before the resampling step.
+
+    Args:
+        See pitch_shift arg descriptions.
+
+    Returns:
+        Tensor: The preprocessed waveform, stretched prior to resampling.
+    """
     if hop_length is None:
         hop_length = n_fft // 4
     if win_length is None:
@@ -1666,7 +1699,24 @@ def pitch_shift(
     waveform_stretch = torch.istft(
         spec_stretch, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window, length=len_stretch
     )
-    waveform_shift = resample(waveform_stretch, int(sample_rate / rate), sample_rate)
+    return waveform_stretch
+
+
+def _fix_waveform_shape(
+    waveform_shift: Tensor,
+    shape: List[int],
+) -> Tensor:
+    """
+    PitchShift helper applied after the resampling step to restore the original waveform shape.
+
+    Args:
+        waveform_shift (Tensor): The waveform after stretching and resampling
+        shape (List[int]): The shape of the initial waveform
+
+    Returns:
+        Tensor: The pitch-shifted audio waveform of shape `(..., time)`.
+    """
+    ori_len = shape[-1]
     shift_len = waveform_shift.size()[-1]
     if shift_len > ori_len:
         waveform_shift = waveform_shift[..., :ori_len]
...
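For intuition on the frequencies involved (worked example, values illustrative): the phase vocoder stretches time by rate = 2.0 ** (-n_steps / bins_per_octave), and resampling from int(sample_rate / rate) back to sample_rate restores the original duration while shifting the pitch.

    import math

    sample_rate, n_steps, bins_per_octave = 8000, 4, 12
    rate = 2.0 ** (-float(n_steps) / bins_per_octave)  # ~0.7937 for 4 semitones up
    orig_freq = int(sample_rate / rate)                # 10079
    gcd = math.gcd(orig_freq, sample_rate)             # 1 here; the kernel depends only on (orig_freq, new_freq, gcd)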
@@ -6,10 +6,15 @@ from typing import Callable, Optional
 import torch
 from torch import Tensor
+from torch.nn.modules.lazy import LazyModuleMixin
+from torch.nn.parameter import UninitializedParameter
 from torchaudio import functional as F
 from torchaudio.functional.functional import (
     _apply_sinc_resample_kernel,
     _get_sinc_resample_kernel,
+    _stretch_waveform,
+    _fix_waveform_shape,
 )

 __all__ = []
@@ -1511,7 +1516,7 @@ class SpectralCentroid(torch.nn.Module):
     )


-class PitchShift(torch.nn.Module):
+class PitchShift(LazyModuleMixin, torch.nn.Module):
     r"""Shift the pitch of a waveform by ``n_steps`` steps.

     .. devices:: CPU CUDA
@@ -1537,6 +1542,9 @@ class PitchShift(torch.nn.Module):
     """
     __constants__ = ["sample_rate", "n_steps", "bins_per_octave", "n_fft", "win_length", "hop_length"]
+    kernel: UninitializedParameter
+    width: int

     def __init__(
         self,
         sample_rate: int,
@@ -1548,7 +1556,7 @@ class PitchShift(torch.nn.Module):
         window_fn: Callable[..., Tensor] = torch.hann_window,
         wkwargs: Optional[dict] = None,
     ) -> None:
-        super(PitchShift, self).__init__()
+        super().__init__()
         self.n_steps = n_steps
         self.bins_per_octave = bins_per_octave
         self.sample_rate = sample_rate
@@ -1557,6 +1565,27 @@ class PitchShift(torch.nn.Module):
         self.hop_length = hop_length if hop_length is not None else self.win_length // 4
         window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
         self.register_buffer("window", window)
+        rate = 2.0 ** (-float(n_steps) / bins_per_octave)
+        self.orig_freq = int(sample_rate / rate)
+        self.gcd = math.gcd(int(self.orig_freq), int(sample_rate))
+
+        if self.orig_freq != sample_rate:
+            self.width = -1
+            self.kernel = UninitializedParameter(device=None, dtype=None)
+
+    def initialize_parameters(self, input):
+        if self.has_uninitialized_params():
+            if self.orig_freq != self.sample_rate:
+                with torch.no_grad():
+                    kernel, self.width = _get_sinc_resample_kernel(
+                        self.orig_freq,
+                        self.sample_rate,
+                        self.gcd,
+                        dtype=input.dtype,
+                        device=input.device,
+                    )
+                    self.kernel.materialize(kernel.shape)
+                    self.kernel.copy_(kernel)

     def forward(self, waveform: Tensor) -> Tensor:
         r"""
@@ -1566,10 +1595,10 @@ class PitchShift(torch.nn.Module):
         Returns:
             Tensor: The pitch-shifted audio of shape `(..., time)`.
         """
-        return F.pitch_shift(
+        shape = waveform.size()
+
+        waveform_stretch = _stretch_waveform(
             waveform,
-            self.sample_rate,
             self.n_steps,
             self.bins_per_octave,
             self.n_fft,
@@ -1578,6 +1607,23 @@ class PitchShift(torch.nn.Module):
             self.window,
         )
+
+        if self.orig_freq != self.sample_rate:
+            waveform_shift = _apply_sinc_resample_kernel(
+                waveform_stretch,
+                self.orig_freq,
+                self.sample_rate,
+                self.gcd,
+                self.kernel,
+                self.width,
+            )
+        else:
+            waveform_shift = waveform_stretch
+
+        return _fix_waveform_shape(
+            waveform_shift,
+            shape,
+        )


 class RNNTLoss(torch.nn.Module):
     """Compute the RNN Transducer loss from *Sequence Transduction with Recurrent Neural Networks*
...
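Since forward now reimplements the functional path with a cached kernel, a quick equivalence check against F.pitch_shift (illustrative, not part of the commit) would be: both paths build the kernel with the same defaults, so the outputs should agree within floating-point tolerance.

    import torch
    import torchaudio.functional as F
    import torchaudio.transforms as T

    wave = torch.randn(1, 8000)
    trans = T.PitchShift(sample_rate=8000, n_steps=4)
    torch.testing.assert_close(trans(wave), F.pitch_shift(wave, 8000, 4))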