Unverified Commit 6e0af713 authored by nateanl's avatar nateanl Committed by GitHub
Browse files

Add InverseSpectrogram to transforms and functional (#1652)



- Provide InverseSpectrogram module that corresponds to Spectrogram module
- Add length parameter to the forward method in transforms
Co-authored-by: default avatardgenzel <dgenzel@fb.com>
Co-authored-by: default avatarZhaoheng Ni <zni@fb.com>
parent 084455a3
...@@ -79,6 +79,37 @@ class Functional(TempDirMixin, TestBaseMixin): ...@@ -79,6 +79,37 @@ class Functional(TempDirMixin, TestBaseMixin):
tensor = common_utils.get_whitenoise() tensor = common_utils.get_whitenoise()
self._assert_consistency(func, tensor) self._assert_consistency(func, tensor)
def test_inverse_spectrogram_complex(self):
def func(tensor):
length = 400
n_fft = 400
hop = 200
ws = 400
pad = 0
window = torch.hann_window(ws, device=tensor.device, dtype=torch.float64)
normalize = False
return F.inverse_spectrogram(tensor, length, pad, window, n_fft, hop, ws, normalize)
waveform = common_utils.get_whitenoise(sample_rate=8000, duration=0.05)
tensor = common_utils.get_spectrogram(waveform, n_fft=400, hop_length=200)
self._assert_consistency_complex(func, tensor)
def test_inverse_spectrogram_real(self):
def func(tensor):
length = 400
n_fft = 400
hop = 200
ws = 400
pad = 0
window = torch.hann_window(ws, device=tensor.device, dtype=tensor.dtype)
normalize = False
return F.inverse_spectrogram(tensor, length, pad, window, n_fft, hop, ws, normalize)
waveform = common_utils.get_whitenoise(sample_rate=8000, duration=0.05)
tensor = common_utils.get_spectrogram(waveform, n_fft=400, hop_length=200)
tensor = torch.view_as_real(tensor)
self._assert_consistency(func, tensor)
@skipIfRocm @skipIfRocm
def test_griffinlim(self): def test_griffinlim(self):
def func(tensor): def func(tensor):
......
...@@ -76,6 +76,19 @@ class AutogradTestMixin(TestBaseMixin): ...@@ -76,6 +76,19 @@ class AutogradTestMixin(TestBaseMixin):
waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2) waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
self.assert_grad(transform, [waveform], nondet_tol=1e-10) self.assert_grad(transform, [waveform], nondet_tol=1e-10)
@parameterized.expand([(False, ), (True, )])
def test_inverse_spectrogram(self, return_complex):
# create a realistic input:
waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
length = waveform.shape[-1]
spectrogram = get_spectrogram(waveform, n_fft=400)
if not return_complex:
spectrogram = torch.view_as_real(spectrogram)
# test
inv_transform = T.InverseSpectrogram(n_fft=400)
self.assert_grad(inv_transform, [spectrogram, length])
def test_melspectrogram(self): def test_melspectrogram(self):
# replication_pad1d_backward_cuda is not deteministic and # replication_pad1d_backward_cuda is not deteministic and
# gives very small (~2.7756e-17) difference. # gives very small (~2.7756e-17) difference.
......
...@@ -107,6 +107,17 @@ class TestTransforms(common_utils.TorchaudioTestCase): ...@@ -107,6 +107,17 @@ class TestTransforms(common_utils.TorchaudioTestCase):
computed = torchaudio.transforms.Spectrogram()(waveform.repeat(3, 1, 1)) computed = torchaudio.transforms.Spectrogram()(waveform.repeat(3, 1, 1))
self.assertEqual(computed, expected) self.assertEqual(computed, expected)
def test_batch_inverse_spectrogram(self):
waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=2)
transform = torchaudio.transforms.Spectrogram(power=None)(waveform)
# Single then transform then batch
expected = torchaudio.transforms.InverseSpectrogram()(transform).repeat(3, 1, 1)
# Batch then transform
computed = torchaudio.transforms.InverseSpectrogram()(transform.repeat(3, 1, 1, 1))
self.assertEqual(computed, expected)
def test_batch_melspectrogram(self): def test_batch_melspectrogram(self):
waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=2) waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=2)
......
...@@ -50,6 +50,17 @@ class Transforms(TempDirMixin, TestBaseMixin): ...@@ -50,6 +50,17 @@ class Transforms(TempDirMixin, TestBaseMixin):
tensor = torch.rand((1, 1000)) tensor = torch.rand((1, 1000))
self._assert_consistency(T.Spectrogram(power=None, return_complex=True), tensor) self._assert_consistency(T.Spectrogram(power=None, return_complex=True), tensor)
def test_InverseSpectrogram(self):
tensor = common_utils.get_whitenoise(sample_rate=8000)
spectrogram = common_utils.get_spectrogram(tensor, n_fft=400, hop_length=100)
self._assert_consistency_complex(T.InverseSpectrogram(n_fft=400, hop_length=100), spectrogram)
def test_InverseSpectrogram_pseudocomplex(self):
tensor = common_utils.get_whitenoise(sample_rate=8000)
spectrogram = common_utils.get_spectrogram(tensor, n_fft=400, hop_length=100)
spectrogram = torch.view_as_real(spectrogram)
self._assert_consistency(T.InverseSpectrogram(n_fft=400, hop_length=100), spectrogram)
@skipIfRocm @skipIfRocm
def test_GriffinLim(self): def test_GriffinLim(self):
tensor = torch.rand((1, 201, 6)) tensor = torch.rand((1, 201, 6))
......
import torch import torch
import torchaudio.transforms as T import torchaudio.transforms as T
from parameterized import parameterized, param
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
TestBaseMixin, TestBaseMixin,
get_whitenoise, get_whitenoise,
...@@ -82,3 +82,29 @@ class TransformsTestBase(TestBaseMixin): ...@@ -82,3 +82,29 @@ class TransformsTestBase(TestBaseMixin):
transform = T.Resample(16000, 44100, resampling_method, dtype=dtype) transform = T.Resample(16000, 44100, resampling_method, dtype=dtype)
assert transform.kernel.dtype == dtype if dtype is not None else torch.float32 assert transform.kernel.dtype == dtype if dtype is not None else torch.float32
@parameterized.expand([
param(n_fft=300, center=True, onesided=True),
param(n_fft=400, center=True, onesided=False),
param(n_fft=400, center=True, onesided=False),
param(n_fft=300, center=True, onesided=False),
param(n_fft=400, hop_length=10),
param(n_fft=800, win_length=400, hop_length=20),
param(n_fft=800, win_length=400, hop_length=20, normalized=True),
param(),
param(n_fft=400, pad=32),
# These tests do not work - cause runtime error
# See https://github.com/pytorch/pytorch/issues/62323
# param(n_fft=400, center=False, onesided=True),
# param(n_fft=400, center=False, onesided=False),
])
def test_roundtrip_spectrogram(self, **args):
"""Test the spectrogram + inverse spectrogram results in approximate identity."""
waveform = get_whitenoise(sample_rate=8000, duration=0.5, dtype=self.dtype)
s = T.Spectrogram(**args, power=None)
inv_s = T.InverseSpectrogram(**args)
transformed = s.forward(waveform)
restored = inv_s.forward(transformed, length=waveform.shape[-1])
self.assertEqual(waveform, restored, atol=1e-6, rtol=1e-6)
...@@ -10,6 +10,7 @@ from .functional import ( ...@@ -10,6 +10,7 @@ from .functional import (
linear_fbanks, linear_fbanks,
DB_to_amplitude, DB_to_amplitude,
detect_pitch_frequency, detect_pitch_frequency,
inverse_spectrogram,
griffinlim, griffinlim,
magphase, magphase,
mask_along_axis, mask_along_axis,
...@@ -70,6 +71,7 @@ __all__ = [ ...@@ -70,6 +71,7 @@ __all__ = [
'phase_vocoder', 'phase_vocoder',
'sliding_window_cmn', 'sliding_window_cmn',
'spectrogram', 'spectrogram',
'inverse_spectrogram',
'spectral_centroid', 'spectral_centroid',
'allpass_biquad', 'allpass_biquad',
'band_biquad', 'band_biquad',
......
...@@ -13,6 +13,7 @@ import torchaudio ...@@ -13,6 +13,7 @@ import torchaudio
__all__ = [ __all__ = [
"spectrogram", "spectrogram",
"inverse_spectrogram",
"griffinlim", "griffinlim",
"amplitude_to_DB", "amplitude_to_DB",
"DB_to_amplitude", "DB_to_amplitude",
...@@ -135,6 +136,86 @@ def spectrogram( ...@@ -135,6 +136,86 @@ def spectrogram(
return spec_f return spec_f
def inverse_spectrogram(
spectrogram: Tensor,
length: Optional[int],
pad: int,
window: Tensor,
n_fft: int,
hop_length: int,
win_length: int,
normalized: bool,
center: bool = True,
pad_mode: str = "reflect",
onesided: bool = True,
) -> Tensor:
r"""Create an inverse spectrogram or a batch of inverse spectrograms from the provided
complex-valued spectrogram.
Args:
spectrogram (Tensor): Complex tensor of audio of dimension (..., freq, time).
length (int, optional): The output length of the waveform.
pad (int): Two sided padding of signal. It is only effective when ``length`` is provided.
window (Tensor): Window tensor that is applied/multiplied to each frame/window
n_fft (int): Size of FFT
hop_length (int): Length of hop between STFT windows
win_length (int): Window size
normalized (bool): Whether the stft output was normalized by magnitude
center (bool, optional): whether the waveform was padded on both sides so
that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
Default: ``True``
pad_mode (string, optional): controls the padding method used when
:attr:`center` is ``True``. This parameter is provided for compatibility with the
spectrogram function and is not used. Default: ``"reflect"``
onesided (bool, optional): controls whether spectrogram was done in onesided mode.
Default: ``True``
Returns:
Tensor: Dimension (..., time). Least squares estimation of the original signal.
"""
if spectrogram.dtype == torch.float32 or spectrogram.dtype == torch.float64:
warnings.warn(
"The use of pseudo complex type in inverse_spectrogram is now deprecated. "
"Please migrate to native complex type by using a complex tensor as input. "
"If the input is generated via spectrogram() function or transform, please use "
"return_complex=True as an argument to that function. "
"Please refer to https://github.com/pytorch/audio/issues/1337 "
"for more details about torchaudio's plan to migrate to native complex type."
)
spectrogram = torch.view_as_complex(spectrogram)
if normalized:
spectrogram = spectrogram * window.pow(2.).sum().sqrt()
# pack batch
shape = spectrogram.size()
spectrogram = spectrogram.reshape(-1, shape[-2], shape[-1])
# default values are consistent with librosa.core.spectrum._spectrogram
waveform = torch.istft(
input=spectrogram,
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
window=window,
center=center,
normalized=False,
onesided=onesided,
length=length + 2 * pad if length is not None else None,
return_complex=False,
)
if length is not None and pad > 0:
# remove padding from front and back
waveform = waveform[:, pad:-pad]
# unpack batch
waveform = waveform.reshape(shape[:-2] + waveform.shape[-1:])
return waveform
def _get_complex_dtype(real_dtype: torch.dtype): def _get_complex_dtype(real_dtype: torch.dtype):
if real_dtype == torch.double: if real_dtype == torch.double:
return torch.cdouble return torch.cdouble
......
...@@ -15,6 +15,7 @@ from .functional.functional import ( ...@@ -15,6 +15,7 @@ from .functional.functional import (
__all__ = [ __all__ = [
'Spectrogram', 'Spectrogram',
'InverseSpectrogram',
'GriffinLim', 'GriffinLim',
'AmplitudeToDB', 'AmplitudeToDB',
'MelScale', 'MelScale',
...@@ -133,6 +134,85 @@ class Spectrogram(torch.nn.Module): ...@@ -133,6 +134,85 @@ class Spectrogram(torch.nn.Module):
) )
class InverseSpectrogram(torch.nn.Module):
r"""Create an inverse spectrogram to recover an audio signal from a spectrogram.
Args:
n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
win_length (int or None, optional): Window size. (Default: ``n_fft``)
hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
pad (int, optional): Two sided padding of signal. (Default: ``0``)
window_fn (Callable[..., Tensor], optional): A function to create a window tensor
that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
normalized (bool, optional): Whether the spectrogram was normalized by magnitude after stft.
(Default: ``False``)
wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``)
center (bool, optional): whether the signal in spectrogram was padded on both sides so
that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
(Default: ``True``)
pad_mode (string, optional): controls the padding method used when
:attr:`center` is ``True``. (Default: ``"reflect"``)
onesided (bool, optional): controls whether spectrogram was used to return half of results to
avoid redundancy (Default: ``True``)
Example
>>> batch, freq, time = 2, 257, 100
>>> length = 25344
>>> spectrogram = torch.randn(batch, freq, time, dtype=torch.cdouble)
>>> transform = transforms.InverseSpectrogram(n_fft=512)
>>> waveform = transform(spectrogram, length)
"""
__constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized']
def __init__(self,
n_fft: int = 400,
win_length: Optional[int] = None,
hop_length: Optional[int] = None,
pad: int = 0,
window_fn: Callable[..., Tensor] = torch.hann_window,
normalized: bool = False,
wkwargs: Optional[dict] = None,
center: bool = True,
pad_mode: str = "reflect",
onesided: bool = True) -> None:
super(InverseSpectrogram, self).__init__()
self.n_fft = n_fft
# number of FFT bins. the returned STFT result will have n_fft // 2 + 1
# number of frequencies due to onesided=True in torch.stft
self.win_length = win_length if win_length is not None else n_fft
self.hop_length = hop_length if hop_length is not None else self.win_length // 2
window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
self.register_buffer('window', window)
self.pad = pad
self.normalized = normalized
self.center = center
self.pad_mode = pad_mode
self.onesided = onesided
def forward(self, spectrogram: Tensor, length: Optional[int] = None) -> Tensor:
r"""
Args:
spectrogram (Tensor): Complex tensor of audio of dimension (..., freq, time).
length (int, optional): The output length of the waveform.
Returns:
Tensor: Dimension (..., time), Least squares estimation of the original signal.
"""
return F.inverse_spectrogram(
spectrogram,
length,
self.pad,
self.window,
self.n_fft,
self.hop_length,
self.win_length,
self.normalized,
self.center,
self.pad_mode,
self.onesided,
)
class GriffinLim(torch.nn.Module): class GriffinLim(torch.nn.Module):
r"""Compute waveform from a linear scale magnitude spectrogram using the Griffin-Lim transformation. r"""Compute waveform from a linear scale magnitude spectrogram using the Griffin-Lim transformation.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment