Unverified Commit 9835db75 authored by moto's avatar moto Committed by GitHub
Browse files

Use istft from torch (#523)

parent 2dd04029
...@@ -172,7 +172,7 @@ class TestIstft(unittest.TestCase): ...@@ -172,7 +172,7 @@ class TestIstft(unittest.TestCase):
def test_istft_requires_overlap_windows(self): def test_istft_requires_overlap_windows(self):
# the window is size 1 but it hops 20 so there is a gap which throw an error # the window is size 1 but it hops 20 so there is a gap which throw an error
stft = torch.zeros((3, 5, 2)) stft = torch.zeros((3, 5, 2))
self.assertRaises(AssertionError, torchaudio.functional.istft, stft, n_fft=4, self.assertRaises(RuntimeError, torchaudio.functional.istft, stft, n_fft=4,
hop_length=20, win_length=1, window=torch.ones(1)) hop_length=20, win_length=1, window=torch.ones(1))
def test_istft_requires_nola(self): def test_istft_requires_nola(self):
...@@ -192,11 +192,11 @@ class TestIstft(unittest.TestCase): ...@@ -192,11 +192,11 @@ class TestIstft(unittest.TestCase):
# A window of ones meets NOLA but a window of zeros does not. This should # A window of ones meets NOLA but a window of zeros does not. This should
# throw an error. # throw an error.
torchaudio.functional.istft(stft, **kwargs_ok) torchaudio.functional.istft(stft, **kwargs_ok)
self.assertRaises(AssertionError, torchaudio.functional.istft, stft, **kwargs_not_ok) self.assertRaises(RuntimeError, torchaudio.functional.istft, stft, **kwargs_not_ok)
def test_istft_requires_non_empty(self): def test_istft_requires_non_empty(self):
self.assertRaises(AssertionError, torchaudio.functional.istft, torch.zeros((3, 0, 2)), 2) self.assertRaises(RuntimeError, torchaudio.functional.istft, torch.zeros((3, 0, 2)), 2)
self.assertRaises(AssertionError, torchaudio.functional.istft, torch.zeros((0, 3, 2)), 2) self.assertRaises(RuntimeError, torchaudio.functional.istft, torch.zeros((0, 3, 2)), 2)
def _test_istft_of_sine(self, amplitude, L, n): def _test_istft_of_sine(self, amplitude, L, n):
# stft of amplitude*sin(2*pi/L*n*x) with the hop length and window size equaling L # stft of amplitude*sin(2*pi/L*n*x) with the hop length and window size equaling L
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
import math import math
from typing import Optional, Tuple from typing import Optional, Tuple
import warnings
import torch import torch
from torch import Tensor from torch import Tensor
...@@ -49,7 +50,7 @@ def istft( ...@@ -49,7 +50,7 @@ def istft(
win_length: Optional[int] = None, win_length: Optional[int] = None,
window: Optional[Tensor] = None, window: Optional[Tensor] = None,
center: bool = True, center: bool = True,
pad_mode: str = "reflect", pad_mode: Optional[str] = None,
normalized: bool = False, normalized: bool = False,
onesided: bool = True, onesided: bool = True,
length: Optional[int] = None, length: Optional[int] = None,
...@@ -94,8 +95,7 @@ def istft( ...@@ -94,8 +95,7 @@ def istft(
center (bool, optional): Whether ``input`` was padded on both sides so center (bool, optional): Whether ``input`` was padded on both sides so
that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`. that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
(Default: ``True``) (Default: ``True``)
pad_mode (str, optional): Controls the padding method used when ``center`` is True. (Default: pad_mode: This argument was ignored and to be removed.
``"reflect"``)
normalized (bool, optional): Whether the STFT was normalized. (Default: ``False``) normalized (bool, optional): Whether the STFT was normalized. (Default: ``False``)
onesided (bool, optional): Whether the STFT is onesided. (Default: ``True``) onesided (bool, optional): Whether the STFT is onesided. (Default: ``True``)
length (int or None, optional): The amount to trim the signal by (i.e. the length (int or None, optional): The amount to trim the signal by (i.e. the
...@@ -104,105 +104,16 @@ def istft( ...@@ -104,105 +104,16 @@ def istft(
Returns: Returns:
Tensor: Least squares estimation of the original signal of size (..., signal_length) Tensor: Least squares estimation of the original signal of size (..., signal_length)
""" """
stft_matrix_dim = stft_matrix.dim() warnings.warn(
assert 3 <= stft_matrix_dim, "Incorrect stft dimension: %d" % (stft_matrix_dim) 'istft has been moved to PyTorch and will be removed from torchaudio, '
assert stft_matrix.numel() > 0 'please use torch.istft instead.')
if pad_mode is not None:
if stft_matrix_dim == 3: warnings.warn(
# add a channel dimension 'The parameter `pad_mode` was ignored in isftft, and is thus being deprecated. '
stft_matrix = stft_matrix.unsqueeze(0) 'Please set `pad_mode` to None to suppress this warning.')
return torch.istft(
# pack batch input=stft_matrix, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window,
shape = stft_matrix.size() center=center, normalized=normalized, onesided=onesided, length=length)
stft_matrix = stft_matrix.reshape(-1, shape[-3], shape[-2], shape[-1])
dtype = stft_matrix.dtype
device = stft_matrix.device
fft_size = stft_matrix.size(1)
assert (onesided and n_fft // 2 + 1 == fft_size) or (
not onesided and n_fft == fft_size
), (
"one_sided implies that n_fft // 2 + 1 == fft_size and not one_sided implies n_fft == fft_size. "
+ "Given values were onesided: %s, n_fft: %d, fft_size: %d"
% ("True" if onesided else False, n_fft, fft_size)
)
# use stft defaults for Optionals
if win_length is None:
win_length = n_fft
if hop_length is None:
hop_length = int(win_length // 4)
# There must be overlap
assert 0 < hop_length <= win_length
assert 0 < win_length <= n_fft
if window is None:
window = torch.ones(win_length, device=device, dtype=dtype)
assert window.dim() == 1 and window.size(0) == win_length
if win_length != n_fft:
# center window with pad left and right zeros
left = (n_fft - win_length) // 2
window = torch.nn.functional.pad(window, (left, n_fft - win_length - left))
assert window.size(0) == n_fft
# win_length and n_fft are synonymous from here on
stft_matrix = stft_matrix.transpose(1, 2) # size (channel, n_frame, fft_size, 2)
stft_matrix = torch.irfft(
stft_matrix, 1, normalized, onesided, signal_sizes=(n_fft,)
) # size (channel, n_frame, n_fft)
assert stft_matrix.size(2) == n_fft
n_frame = stft_matrix.size(1)
ytmp = stft_matrix * window.view(1, 1, n_fft) # size (channel, n_frame, n_fft)
# each column of a channel is a frame which needs to be overlap added at the right place
ytmp = ytmp.transpose(1, 2) # size (channel, n_fft, n_frame)
# this does overlap add where the frames of ytmp are added such that the i'th frame of
# ytmp is added starting at i*hop_length in the output
y = torch.nn.functional.fold(
ytmp, (1, (n_frame - 1) * hop_length + n_fft), (1, n_fft), stride=(1, hop_length)
).squeeze(2)
# do the same for the window function
window_sq = (
window.pow(2).view(n_fft, 1).repeat((1, n_frame)).unsqueeze(0)
) # size (1, n_fft, n_frame)
window_envelop = torch.nn.functional.fold(
window_sq, (1, (n_frame - 1) * hop_length + n_fft), (1, n_fft), stride=(1, hop_length)
).squeeze(2) # size (1, 1, expected_signal_len)
expected_signal_len = n_fft + hop_length * (n_frame - 1)
assert y.size(2) == expected_signal_len
assert window_envelop.size(2) == expected_signal_len
half_n_fft = n_fft // 2
# we need to trim the front padding away if center
start = half_n_fft if center else 0
end = -half_n_fft if length is None else start + length
y = y[:, :, start:end]
window_envelop = window_envelop[:, :, start:end]
# check NOLA non-zero overlap condition
window_envelop_lowest = window_envelop.abs().min()
assert window_envelop_lowest > 1e-11, "window overlap add min: %f" % (
window_envelop_lowest
)
y = (y / window_envelop).squeeze(1) # size (channel, expected_signal_len)
# unpack batch
y = y.reshape(shape[:-3] + y.shape[-1:])
if stft_matrix_dim == 3: # remove the channel dimension
y = y.squeeze(0)
return y
def spectrogram( def spectrogram(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment