Unverified Commit dab7f64b authored by Jeremy Chen's avatar Jeremy Chen Committed by GitHub
Browse files

Remove istft (#841)



* `istft` has been migrated to `pytorch`, and `torchaudio.functional.istft` has been deprecated in 0.6.0 release. This PR removes it
Co-authored-by: default avatarJeremy Chen <jeremyyy@fb.com>
parent 870811c7
...@@ -8,11 +8,6 @@ torchaudio.functional ...@@ -8,11 +8,6 @@ torchaudio.functional
Functions to perform common audio operations. Functions to perform common audio operations.
:hidden:`istft`
~~~~~~~~~~~~~~~
.. autofunction:: istft
:hidden:`spectrogram` :hidden:`spectrogram`
~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~
......
...@@ -11,31 +11,6 @@ from . import common_utils ...@@ -11,31 +11,6 @@ from . import common_utils
from .functional_impl import Lfilter from .functional_impl import Lfilter
def random_float_tensor(seed, size, a=22695477, c=1, m=2 ** 32):
""" Generates random tensors given a seed and size
https://en.wikipedia.org/wiki/Linear_congruential_generator
X_{n + 1} = (a * X_n + c) % m
Using Borland C/C++ values
The tensor will have values between [0,1)
Inputs:
seed (int): an int
size (Tuple[int]): the size of the output tensor
a (int): the multiplier constant to the generator
c (int): the additive constant to the generator
m (int): the modulus constant to the generator
"""
num_elements = 1
for s in size:
num_elements *= s
arr = [(a * seed + c) % m]
for i in range(num_elements - 1):
arr.append((a * arr[i] + c) % m)
return torch.tensor(arr).float().view(size) / m
class TestLFilterFloat32(Lfilter, common_utils.PytorchTestCase): class TestLFilterFloat32(Lfilter, common_utils.PytorchTestCase):
dtype = torch.float32 dtype = torch.float32
device = torch.device('cpu') device = torch.device('cpu')
...@@ -63,242 +38,6 @@ class TestComputeDeltas(common_utils.TorchaudioTestCase): ...@@ -63,242 +38,6 @@ class TestComputeDeltas(common_utils.TorchaudioTestCase):
torch.testing.assert_allclose(computed, expected) torch.testing.assert_allclose(computed, expected)
def _compare_estimate(sound, estimate, atol=1e-6, rtol=1e-8):
# trim sound for case when constructed signal is shorter than original
sound = sound[..., :estimate.size(-1)]
torch.testing.assert_allclose(estimate, sound, atol=atol, rtol=rtol)
def _test_istft_is_inverse_of_stft(kwargs):
# generates a random sound signal for each tril and then does the stft/istft
# operation to check whether we can reconstruct signal
for data_size in [(2, 20), (3, 15), (4, 10)]:
for i in range(100):
sound = random_float_tensor(i, data_size)
stft = torch.stft(sound, **kwargs)
estimate = torchaudio.functional.istft(stft, length=sound.size(1), **kwargs)
_compare_estimate(sound, estimate)
class TestIstft(common_utils.TorchaudioTestCase):
"""Test suite for correctness of istft with various input"""
number_of_trials = 100
def test_istft_is_inverse_of_stft1(self):
# hann_window, centered, normalized, onesided
kwargs1 = {
'n_fft': 12,
'hop_length': 4,
'win_length': 12,
'window': torch.hann_window(12),
'center': True,
'pad_mode': 'reflect',
'normalized': True,
'onesided': True,
}
_test_istft_is_inverse_of_stft(kwargs1)
def test_istft_is_inverse_of_stft2(self):
# hann_window, centered, not normalized, not onesided
kwargs2 = {
'n_fft': 12,
'hop_length': 2,
'win_length': 8,
'window': torch.hann_window(8),
'center': True,
'pad_mode': 'reflect',
'normalized': False,
'onesided': False,
}
_test_istft_is_inverse_of_stft(kwargs2)
def test_istft_is_inverse_of_stft3(self):
# hamming_window, centered, normalized, not onesided
kwargs3 = {
'n_fft': 15,
'hop_length': 3,
'win_length': 11,
'window': torch.hamming_window(11),
'center': True,
'pad_mode': 'constant',
'normalized': True,
'onesided': False,
}
_test_istft_is_inverse_of_stft(kwargs3)
def test_istft_is_inverse_of_stft4(self):
# hamming_window, not centered, not normalized, onesided
# window same size as n_fft
kwargs4 = {
'n_fft': 5,
'hop_length': 2,
'win_length': 5,
'window': torch.hamming_window(5),
'center': False,
'pad_mode': 'constant',
'normalized': False,
'onesided': True,
}
_test_istft_is_inverse_of_stft(kwargs4)
def test_istft_is_inverse_of_stft5(self):
# hamming_window, not centered, not normalized, not onesided
# window same size as n_fft
kwargs5 = {
'n_fft': 3,
'hop_length': 2,
'win_length': 3,
'window': torch.hamming_window(3),
'center': False,
'pad_mode': 'reflect',
'normalized': False,
'onesided': False,
}
_test_istft_is_inverse_of_stft(kwargs5)
def test_istft_of_ones(self):
# stft = torch.stft(torch.ones(4), 4)
stft = torch.tensor([
[[4., 0.], [4., 0.], [4., 0.], [4., 0.], [4., 0.]],
[[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]],
[[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]]
])
estimate = torchaudio.functional.istft(stft, n_fft=4, length=4)
_compare_estimate(torch.ones(4), estimate)
def test_istft_of_zeros(self):
# stft = torch.stft(torch.zeros(4), 4)
stft = torch.zeros((3, 5, 2))
estimate = torchaudio.functional.istft(stft, n_fft=4, length=4)
_compare_estimate(torch.zeros(4), estimate)
def test_istft_requires_overlap_windows(self):
# the window is size 1 but it hops 20 so there is a gap which throw an error
stft = torch.zeros((3, 5, 2))
self.assertRaises(RuntimeError, torchaudio.functional.istft, stft, n_fft=4,
hop_length=20, win_length=1, window=torch.ones(1))
def test_istft_requires_nola(self):
stft = torch.zeros((3, 5, 2))
kwargs_ok = {
'n_fft': 4,
'win_length': 4,
'window': torch.ones(4),
}
kwargs_not_ok = {
'n_fft': 4,
'win_length': 4,
'window': torch.zeros(4),
}
# A window of ones meets NOLA but a window of zeros does not. This should
# throw an error.
torchaudio.functional.istft(stft, **kwargs_ok)
self.assertRaises(RuntimeError, torchaudio.functional.istft, stft, **kwargs_not_ok)
def test_istft_requires_non_empty(self):
self.assertRaises(RuntimeError, torchaudio.functional.istft, torch.zeros((3, 0, 2)), 2)
self.assertRaises(RuntimeError, torchaudio.functional.istft, torch.zeros((0, 3, 2)), 2)
def _test_istft_of_sine(self, amplitude, L, n):
# stft of amplitude*sin(2*pi/L*n*x) with the hop length and window size equaling L
x = torch.arange(2 * L + 1, dtype=torch.get_default_dtype())
sound = amplitude * torch.sin(2 * math.pi / L * x * n)
# stft = torch.stft(sound, L, hop_length=L, win_length=L,
# window=torch.ones(L), center=False, normalized=False)
stft = torch.zeros((L // 2 + 1, 2, 2))
stft_largest_val = (amplitude * L) / 2.0
if n < stft.size(0):
stft[n, :, 1] = -stft_largest_val
if 0 <= L - n < stft.size(0):
# symmetric about L // 2
stft[L - n, :, 1] = stft_largest_val
estimate = torchaudio.functional.istft(stft, L, hop_length=L, win_length=L,
window=torch.ones(L), center=False, normalized=False)
# There is a larger error due to the scaling of amplitude
_compare_estimate(sound, estimate, atol=1e-3)
def test_istft_of_sine(self):
self._test_istft_of_sine(amplitude=123, L=5, n=1)
self._test_istft_of_sine(amplitude=150, L=5, n=2)
self._test_istft_of_sine(amplitude=111, L=5, n=3)
self._test_istft_of_sine(amplitude=160, L=7, n=4)
self._test_istft_of_sine(amplitude=145, L=8, n=5)
self._test_istft_of_sine(amplitude=80, L=9, n=6)
self._test_istft_of_sine(amplitude=99, L=10, n=7)
def _test_linearity_of_istft(self, data_size, kwargs, atol=1e-6, rtol=1e-8):
for i in range(self.number_of_trials):
tensor1 = random_float_tensor(i, data_size)
tensor2 = random_float_tensor(i * 2, data_size)
a, b = torch.rand(2)
istft1 = torchaudio.functional.istft(tensor1, **kwargs)
istft2 = torchaudio.functional.istft(tensor2, **kwargs)
istft = a * istft1 + b * istft2
estimate = torchaudio.functional.istft(a * tensor1 + b * tensor2, **kwargs)
_compare_estimate(istft, estimate, atol, rtol)
def test_linearity_of_istft1(self):
# hann_window, centered, normalized, onesided
kwargs1 = {
'n_fft': 12,
'window': torch.hann_window(12),
'center': True,
'pad_mode': 'reflect',
'normalized': True,
'onesided': True,
}
data_size = (2, 7, 7, 2)
self._test_linearity_of_istft(data_size, kwargs1)
def test_linearity_of_istft2(self):
# hann_window, centered, not normalized, not onesided
kwargs2 = {
'n_fft': 12,
'window': torch.hann_window(12),
'center': True,
'pad_mode': 'reflect',
'normalized': False,
'onesided': False,
}
data_size = (2, 12, 7, 2)
self._test_linearity_of_istft(data_size, kwargs2)
def test_linearity_of_istft3(self):
# hamming_window, centered, normalized, not onesided
kwargs3 = {
'n_fft': 12,
'window': torch.hamming_window(12),
'center': True,
'pad_mode': 'constant',
'normalized': True,
'onesided': False,
}
data_size = (2, 12, 7, 2)
self._test_linearity_of_istft(data_size, kwargs3)
def test_linearity_of_istft4(self):
# hamming_window, not centered, not normalized, onesided
kwargs4 = {
'n_fft': 12,
'window': torch.hamming_window(12),
'center': False,
'pad_mode': 'constant',
'normalized': False,
'onesided': True,
}
data_size = (2, 7, 3, 2)
self._test_linearity_of_istft(data_size, kwargs4, atol=1e-5, rtol=1e-8)
class TestDetectPitchFrequency(common_utils.TorchaudioTestCase): class TestDetectPitchFrequency(common_utils.TorchaudioTestCase):
@parameterized.expand([(100,), (440,)]) @parameterized.expand([(100,), (440,)])
def test_pitch(self, frequency): def test_pitch(self, frequency):
......
...@@ -59,14 +59,6 @@ class TestFunctional(common_utils.TorchaudioTestCase): ...@@ -59,14 +59,6 @@ class TestFunctional(common_utils.TorchaudioTestCase):
n_channels=n_channels, duration=5) n_channels=n_channels, duration=5)
self.assert_batch_consistencies(F.detect_pitch_frequency, waveform, sample_rate) self.assert_batch_consistencies(F.detect_pitch_frequency, waveform, sample_rate)
def test_istft(self):
stft = torch.tensor([
[[4., 0.], [4., 0.], [4., 0.], [4., 0.], [4., 0.]],
[[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]],
[[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]]
])
self.assert_batch_consistencies(F.istft, stft, n_fft=4, length=4)
def test_contrast(self): def test_contrast(self):
waveform = torch.rand(2, 100) - 0.5 waveform = torch.rand(2, 100) - 0.5
self.assert_batch_consistencies(F.contrast, waveform, enhancement_amount=80.) self.assert_batch_consistencies(F.contrast, waveform, enhancement_amount=80.)
......
...@@ -8,7 +8,6 @@ import torch ...@@ -8,7 +8,6 @@ import torch
from torch import Tensor from torch import Tensor
__all__ = [ __all__ = [
"istft",
"spectrogram", "spectrogram",
"griffinlim", "griffinlim",
"amplitude_to_DB", "amplitude_to_DB",
...@@ -45,79 +44,6 @@ __all__ = [ ...@@ -45,79 +44,6 @@ __all__ = [
] ]
def istft(
stft_matrix: Tensor,
n_fft: int,
hop_length: Optional[int] = None,
win_length: Optional[int] = None,
window: Optional[Tensor] = None,
center: bool = True,
pad_mode: Optional[str] = None,
normalized: bool = False,
onesided: bool = True,
length: Optional[int] = None,
) -> Tensor:
r"""Inverse short time Fourier Transform. This is expected to be the inverse of torch.stft.
It has the same parameters (+ additional optional parameter of ``length``) and it should return the
least squares estimation of the original signal. The algorithm will check using the NOLA condition (
nonzero overlap).
Important consideration in the parameters ``window`` and ``center`` so that the envelop
created by the summation of all the windows is never zero at certain point in time. Specifically,
:math:`\sum_{t=-\infty}^{\infty} w^2[n-t\times hop\_length] \cancel{=} 0`.
Since stft discards elements at the end of the signal if they do not fit in a frame, the
istft may return a shorter signal than the original signal (can occur if ``center`` is False
since the signal isn't padded).
If ``center`` is True, then there will be padding e.g. 'constant', 'reflect', etc. Left padding
can be trimmed off exactly because they can be calculated but right padding cannot be calculated
without additional information.
Example: Suppose the last window is:
[17, 18, 0, 0, 0] vs [18, 0, 0, 0, 0]
The n_frame, hop_length, win_length are all the same which prevents the calculation of right padding.
These additional values could be zeros or a reflection of the signal so providing ``length``
could be useful. If ``length`` is ``None`` then padding will be aggressively removed
(some loss of signal).
[1] D. W. Griffin and J. S. Lim, "Signal estimation from modified short-time Fourier transform,"
IEEE Trans. ASSP, vol.32, no.2, pp.236-243, Apr. 1984.
Args:
stft_matrix (Tensor): Output of stft where each row of a channel is a frequency and each
column is a window. It has a size of either (..., fft_size, n_frame, 2)
n_fft (int): Size of Fourier transform
hop_length (int or None, optional): The distance between neighboring sliding window frames.
(Default: ``win_length // 4``)
win_length (int or None, optional): The size of window frame and STFT filter. (Default: ``n_fft``)
window (Tensor or None, optional): The optional window function.
(Default: ``torch.ones(win_length)``)
center (bool, optional): Whether ``input`` was padded on both sides so
that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
(Default: ``True``)
pad_mode: This argument was ignored and to be removed.
normalized (bool, optional): Whether the STFT was normalized. (Default: ``False``)
onesided (bool, optional): Whether the STFT is onesided. (Default: ``True``)
length (int or None, optional): The amount to trim the signal by (i.e. the
original signal length). (Default: whole signal)
Returns:
Tensor: Least squares estimation of the original signal of size (..., signal_length)
"""
warnings.warn(
'istft has been moved to PyTorch and will be removed from torchaudio, '
'please use torch.istft instead.')
if pad_mode is not None:
warnings.warn(
'The parameter `pad_mode` was ignored in isftft, and is thus being deprecated. '
'Please set `pad_mode` to None to suppress this warning.')
return torch.istft(
input=stft_matrix, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window,
center=center, normalized=normalized, onesided=onesided, length=length)
def spectrogram( def spectrogram(
waveform: Tensor, waveform: Tensor,
pad: int, pad: int,
...@@ -250,12 +176,12 @@ def griffinlim( ...@@ -250,12 +176,12 @@ def griffinlim(
tprev = rebuilt tprev = rebuilt
# Invert with our current estimate of the phases # Invert with our current estimate of the phases
inverse = istft(specgram * angles, inverse = torch.istft(specgram * angles,
n_fft=n_fft, n_fft=n_fft,
hop_length=hop_length, hop_length=hop_length,
win_length=win_length, win_length=win_length,
window=window, window=window,
length=length).float() length=length).float()
# Rebuild the spectrogram # Rebuild the spectrogram
rebuilt = torch.stft(inverse, n_fft, hop_length, win_length, window, rebuilt = torch.stft(inverse, n_fft, hop_length, win_length, window,
...@@ -268,12 +194,12 @@ def griffinlim( ...@@ -268,12 +194,12 @@ def griffinlim(
angles = angles.div(complex_norm(angles).add(1e-16).unsqueeze(-1).expand_as(angles)) angles = angles.div(complex_norm(angles).add(1e-16).unsqueeze(-1).expand_as(angles))
# Return the final phase estimates # Return the final phase estimates
waveform = istft(specgram * angles, waveform = torch.istft(specgram * angles,
n_fft=n_fft, n_fft=n_fft,
hop_length=hop_length, hop_length=hop_length,
win_length=win_length, win_length=win_length,
window=window, window=window,
length=length) length=length)
# unpack batch # unpack batch
waveform = waveform.reshape(shape[:-2] + waveform.shape[-1:]) waveform = waveform.reshape(shape[:-2] + waveform.shape[-1:])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment