"docs/zh_cn/git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "7156604eba7350bb6c43f29d0fb2ee336f8d9dff"
Commit df2262b5 authored by Sean Kim, committed by Facebook GitHub Bot

Modifying PitchShift for faster resampling (#2441)

Summary:
Split the existing PitchShift into multiple helper functions in order to cache the resampling kernel and speed up the overall process, addressing https://github.com/pytorch/audio/issues/2359.
Existing unit tests pass.

edit: functional and transforms unit tests pass. Adopted lazy initialization to avoid a BC-breaking change.

Pull Request resolved: https://github.com/pytorch/audio/pull/2441

Reviewed By: carolineechen

Differential Revision: D36905582

Pulled By: skim0514

fbshipit-source-id: 6780db3ac8a29d59017a6abe7e82ce1fd17aaac2
parent 4d2fa190
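For context on the diff below: after this commit, repeated calls to the same PitchShift instance reuse a cached sinc resampling kernel instead of rebuilding it on every call. A minimal usage sketch, with hypothetical shapes and parameters:

    import torch
    import torchaudio.transforms as T

    transform = T.PitchShift(sample_rate=16000, n_steps=4)
    waveforms = [torch.randn(1, 16000) for _ in range(8)]

    # The first call materializes and caches the resampling kernel;
    # subsequent calls reuse it rather than recomputing it each time.
    shifted = [transform(w) for w in waveforms]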
@@ -139,7 +139,10 @@ class Transforms(TestBaseMixin):
         sample_rate = 8000
         n_steps = 4
         waveform = common_utils.get_whitenoise(sample_rate=sample_rate)
-        self._assert_consistency(T.PitchShift(sample_rate=sample_rate, n_steps=n_steps), waveform)
+        pitch_shift = T.PitchShift(sample_rate=sample_rate, n_steps=n_steps)
+        # dry-run for initializing parameters
+        pitch_shift(waveform)
+        self._assert_consistency(pitch_shift, waveform)

     def test_PSD(self):
         tensor = common_utils.get_whitenoise(sample_rate=8000, n_channels=4)
...
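The dry run added above is needed because PitchShift becomes a lazy module in this commit: its kernel is uninitialized until the first forward pass, so it must be run once before the TorchScript consistency check. A minimal sketch of the pattern, assuming an 8 kHz mono input:

    import torch
    import torchaudio.transforms as T

    pitch_shift = T.PitchShift(sample_rate=8000, n_steps=4)
    pitch_shift(torch.randn(1, 8000))         # first call materializes the kernel
    scripted = torch.jit.script(pitch_shift)  # safe to script once initialized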
 import torch
 import torchaudio.transforms as T
+from torchaudio.functional.functional import _get_sinc_resample_kernel
 from parameterized import param, parameterized
 from torchaudio_unittest.common_utils import (
     get_spectrogram,
@@ -147,3 +148,21 @@ class TransformsTestBase(TestBaseMixin):
         mask_n = torch.rand(specgram.shape[-2:])
         specgram_enhanced = transform(specgram, mask_s, mask_n)
         assert specgram_enhanced.dtype == dtype
+
+    def test_pitch_shift_resample_kernel(self):
+        """The resampling kernel in PitchShift is identical to the one the helper function generates.
+
+        There should be no numerical difference caused by dtype conversion.
+        """
+        sample_rate = 8000
+        trans = T.PitchShift(sample_rate=sample_rate, n_steps=4)
+        trans.to(self.dtype).to(self.device)
+        # dry run to initialize the kernel
+        trans(torch.randn(2, 8000, dtype=self.dtype, device=self.device))
+
+        expected, _ = _get_sinc_resample_kernel(
+            trans.orig_freq,
+            sample_rate,
+            trans.gcd,
+            device=self.device,
+            dtype=self.dtype,
+        )
+        self.assertEqual(trans.kernel, expected)
@@ -4,7 +4,7 @@ import io
 import math
 import warnings
 from collections.abc import Sequence
-from typing import Optional, Tuple, Union
+from typing import Optional, Tuple, Union, List
 import torch
 import torchaudio
@@ -1389,10 +1389,10 @@ def _get_sinc_resample_kernel(
     orig_freq: int,
     new_freq: int,
     gcd: int,
-    lowpass_filter_width: int,
-    rolloff: float,
-    resampling_method: str,
-    beta: Optional[float],
+    lowpass_filter_width: int = 6,
+    rolloff: float = 0.99,
+    resampling_method: str = "sinc_interpolation",
+    beta: Optional[float] = None,
     device: torch.device = torch.device("cpu"),
     dtype: Optional[torch.dtype] = None,
 ):
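With these defaults in place, internal callers can obtain a kernel by passing only the two frequencies and their GCD. A hedged sketch of such a call; the (kernel, width) return pair matches the new unit test above:

    import math
    from torchaudio.functional.functional import _get_sinc_resample_kernel

    orig_freq, new_freq = 10079, 8000  # a hypothetical pitch-shift resampling pair
    kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, math.gcd(orig_freq, new_freq))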
@@ -1635,6 +1635,39 @@ def pitch_shift(
     Returns:
         Tensor: The pitch-shifted audio waveform of shape `(..., time)`.
     """
+    waveform_stretch = _stretch_waveform(
+        waveform,
+        n_steps,
+        bins_per_octave,
+        n_fft,
+        win_length,
+        hop_length,
+        window,
+    )
+    rate = 2.0 ** (-float(n_steps) / bins_per_octave)
+    waveform_shift = resample(waveform_stretch, int(sample_rate / rate), sample_rate)
+
+    return _fix_waveform_shape(waveform_shift, waveform.size())
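To make the resampling step concrete: shifting up by n_steps semitones time-stretches the waveform by 1/rate, then resampling from int(sample_rate / rate) back to sample_rate restores the original duration at the new pitch. Worked numbers for the parameters used in the tests:

    # n_steps=4, bins_per_octave=12, sample_rate=8000
    rate = 2.0 ** (-4.0 / 12)     # ~0.7937
    orig_freq = int(8000 / rate)  # 10079: the rate the stretched waveform is resampled *from*
    # resample(waveform_stretch, 10079, 8000) then returns the signal to 8000 Hz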
+
+
+def _stretch_waveform(
+    waveform: Tensor,
+    n_steps: int,
+    bins_per_octave: int = 12,
+    n_fft: int = 512,
+    win_length: Optional[int] = None,
+    hop_length: Optional[int] = None,
+    window: Optional[Tensor] = None,
+) -> Tensor:
+    """Pitch-shift helper function that preprocesses and time-stretches the waveform before the resampling step.
+
+    Args:
+        See the pitch_shift arg descriptions.
+
+    Returns:
+        Tensor: The preprocessed waveform, stretched prior to resampling.
+    """
+    if hop_length is None:
+        hop_length = n_fft // 4
+    if win_length is None:
@@ -1666,7 +1699,24 @@ def pitch_shift(
     waveform_stretch = torch.istft(
         spec_stretch, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window, length=len_stretch
     )
-    waveform_shift = resample(waveform_stretch, int(sample_rate / rate), sample_rate)
+    return waveform_stretch
+
+
+def _fix_waveform_shape(
+    waveform_shift: Tensor,
+    shape: List[int],
+) -> Tensor:
+    """PitchShift helper function that restores the original waveform shape after the resampling step.
+
+    Args:
+        waveform_shift (Tensor): The waveform after stretching and resampling
+        shape (List[int]): The shape of the initial waveform
+
+    Returns:
+        Tensor: The pitch-shifted audio waveform of shape `(..., time)`.
+    """
+    ori_len = shape[-1]
+    shift_len = waveform_shift.size()[-1]
+    if shift_len > ori_len:
+        waveform_shift = waveform_shift[..., :ori_len]
...
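The diff is collapsed before the branch handling outputs shorter than the original; presumably it pads on the right. A minimal sketch of the full shape fix under that assumption (the function name here is illustrative, not the library's):

    import torch

    def fix_waveform_shape_sketch(waveform_shift, shape):
        # Trim, or right-pad with zeros, so the output length matches the input.
        ori_len = shape[-1]
        shift_len = waveform_shift.size()[-1]
        if shift_len > ori_len:
            return waveform_shift[..., :ori_len]
        return torch.nn.functional.pad(waveform_shift, [0, ori_len - shift_len])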
@@ -6,10 +6,15 @@ from typing import Callable, Optional
 import torch
 from torch import Tensor
+from torch.nn.modules.lazy import LazyModuleMixin
+from torch.nn.parameter import UninitializedParameter
 from torchaudio import functional as F
 from torchaudio.functional.functional import (
     _apply_sinc_resample_kernel,
     _get_sinc_resample_kernel,
+    _stretch_waveform,
+    _fix_waveform_shape,
 )

 __all__ = []
@@ -1511,7 +1516,7 @@ class SpectralCentroid(torch.nn.Module):
     )


-class PitchShift(torch.nn.Module):
+class PitchShift(LazyModuleMixin, torch.nn.Module):
     r"""Shift the pitch of a waveform by ``n_steps`` steps.

     .. devices:: CPU CUDA
@@ -1537,6 +1542,9 @@ class PitchShift(torch.nn.Module):
     """
     __constants__ = ["sample_rate", "n_steps", "bins_per_octave", "n_fft", "win_length", "hop_length"]
+    kernel: UninitializedParameter
+    width: int

     def __init__(
         self,
         sample_rate: int,
@@ -1548,7 +1556,7 @@ class PitchShift(torch.nn.Module):
         window_fn: Callable[..., Tensor] = torch.hann_window,
         wkwargs: Optional[dict] = None,
     ) -> None:
-        super(PitchShift, self).__init__()
+        super().__init__()
         self.n_steps = n_steps
         self.bins_per_octave = bins_per_octave
         self.sample_rate = sample_rate
@@ -1557,6 +1565,27 @@ class PitchShift(torch.nn.Module):
         self.hop_length = hop_length if hop_length is not None else self.win_length // 4
         window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
         self.register_buffer("window", window)
+        rate = 2.0 ** (-float(n_steps) / bins_per_octave)
+        self.orig_freq = int(sample_rate / rate)
+        self.gcd = math.gcd(int(self.orig_freq), int(sample_rate))
+
+        if self.orig_freq != sample_rate:
+            self.width = -1
+            self.kernel = UninitializedParameter(device=None, dtype=None)
+
+    def initialize_parameters(self, input):
+        if self.has_uninitialized_params():
+            if self.orig_freq != self.sample_rate:
+                with torch.no_grad():
+                    kernel, self.width = _get_sinc_resample_kernel(
+                        self.orig_freq,
+                        self.sample_rate,
+                        self.gcd,
+                        dtype=input.dtype,
+                        device=input.device,
+                    )
+                    self.kernel.materialize(kernel.shape)
+                    self.kernel.copy_(kernel)
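LazyModuleMixin calls initialize_parameters with the forward arguments through a pre-forward hook, so the kernel above is built on the first call with the input's dtype and device. A self-contained illustration of the same pattern; LazyScale is a hypothetical module, not part of torchaudio:

    import torch
    from torch.nn.modules.lazy import LazyModuleMixin
    from torch.nn.parameter import UninitializedParameter

    class LazyScale(LazyModuleMixin, torch.nn.Module):
        scale: UninitializedParameter

        def __init__(self):
            super().__init__()
            self.scale = UninitializedParameter()

        def initialize_parameters(self, input):
            # Runs once, just before the first real forward pass.
            if self.has_uninitialized_params():
                with torch.no_grad():
                    self.scale.materialize(input.shape[-1:])
                    self.scale.fill_(1.0)

        def forward(self, input):
            return input * self.scale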
     def forward(self, waveform: Tensor) -> Tensor:
         r"""
@@ -1566,10 +1595,10 @@ class PitchShift(torch.nn.Module):
         Returns:
             Tensor: The pitch-shifted audio of shape `(..., time)`.
         """
-        return F.pitch_shift(
+        shape = waveform.size()
+        waveform_stretch = _stretch_waveform(
             waveform,
-            self.sample_rate,
             self.n_steps,
             self.bins_per_octave,
             self.n_fft,
@@ -1578,6 +1607,23 @@ class PitchShift(torch.nn.Module):
             self.window,
         )
+
+        if self.orig_freq != self.sample_rate:
+            waveform_shift = _apply_sinc_resample_kernel(
+                waveform_stretch,
+                self.orig_freq,
+                self.sample_rate,
+                self.gcd,
+                self.kernel,
+                self.width,
+            )
+        else:
+            waveform_shift = waveform_stretch
+
+        return _fix_waveform_shape(
+            waveform_shift,
+            shape,
+        )
 class RNNTLoss(torch.nn.Module):
     """Compute the RNN Transducer loss from *Sequence Transduction with Recurrent Neural Networks*
...
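End to end, the new forward path matches F.pitch_shift numerically while reusing the cached kernel. Note the else branch above: with n_steps=0 the rate is 1.0, orig_freq equals sample_rate, and the resampling (and kernel) is skipped entirely. A brief usage sketch with hypothetical shapes:

    import torch
    import torchaudio.transforms as T

    identity_shift = T.PitchShift(sample_rate=8000, n_steps=0)  # rate 1.0: no resampling
    x = torch.randn(2, 8000)
    y = identity_shift(x)
    assert y.shape == x.shape  # _fix_waveform_shape keeps the original length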