Unverified Commit 5ec6ada6 authored by moto's avatar moto Committed by GitHub
Browse files

[BC-Breaking] Drop pseudo complex support from spectrogram (#1958)

Following the plan #1337, this commit drops the support for pseudo complex type from 
`F.spectrogram` and `T.Spectrogram`.

It also deprecates the use of `return_complex` argument.
parent f2eec77b
......@@ -47,7 +47,7 @@ class Functional(TempDirMixin, TestBaseMixin):
self.assertEqual(ts_output, output)
def test_spectrogram_complex(self):
def test_spectrogram(self):
def func(tensor):
n_fft = 400
ws = 400
......@@ -61,21 +61,7 @@ class Functional(TempDirMixin, TestBaseMixin):
tensor = common_utils.get_whitenoise()
self._assert_consistency(func, tensor)
def test_spectrogram_real(self):
def func(tensor):
n_fft = 400
ws = 400
hop = 200
pad = 0
window = torch.hann_window(ws, device=tensor.device, dtype=tensor.dtype)
power = 2.
normalize = False
return F.spectrogram(tensor, pad, window, n_fft, hop, ws, power, normalize, return_complex=False)
tensor = common_utils.get_whitenoise()
self._assert_consistency(func, tensor)
def test_inverse_spectrogram_complex(self):
def test_inverse_spectrogram(self):
def func(tensor):
length = 400
n_fft = 400
......@@ -90,22 +76,6 @@ class Functional(TempDirMixin, TestBaseMixin):
tensor = common_utils.get_spectrogram(waveform, n_fft=400, hop_length=200)
self._assert_consistency_complex(func, tensor)
def test_inverse_spectrogram_real(self):
def func(tensor):
length = 400
n_fft = 400
hop = 200
ws = 400
pad = 0
window = torch.hann_window(ws, device=tensor.device, dtype=tensor.dtype)
normalize = False
return F.inverse_spectrogram(tensor, length, pad, window, n_fft, hop, ws, normalize)
waveform = common_utils.get_whitenoise(sample_rate=8000, duration=0.05)
tensor = common_utils.get_spectrogram(waveform, n_fft=400, hop_length=200)
tensor = torch.view_as_real(tensor)
self._assert_consistency(func, tensor)
@skipIfRocm
def test_griffinlim(self):
def func(tensor):
......
......@@ -77,15 +77,11 @@ class AutogradTestMixin(TestBaseMixin):
waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
self.assert_grad(transform, [waveform], nondet_tol=1e-10)
@parameterized.expand([(False, ), (True, )])
def test_inverse_spectrogram(self, return_complex):
def test_inverse_spectrogram(self):
# create a realistic input:
waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
length = waveform.shape[-1]
spectrogram = get_spectrogram(waveform, n_fft=400)
if not return_complex:
spectrogram = torch.view_as_real(spectrogram)
# test
inv_transform = T.InverseSpectrogram(n_fft=400)
self.assert_grad(inv_transform, [spectrogram, length])
......
......@@ -50,12 +50,6 @@ class Transforms(TestBaseMixin):
spectrogram = common_utils.get_spectrogram(tensor, n_fft=400, hop_length=100)
self._assert_consistency_complex(T.InverseSpectrogram(n_fft=400, hop_length=100), spectrogram)
def test_InverseSpectrogram_pseudocomplex(self):
tensor = common_utils.get_whitenoise(sample_rate=8000)
spectrogram = common_utils.get_spectrogram(tensor, n_fft=400, hop_length=100)
spectrogram = torch.view_as_real(spectrogram)
self._assert_consistency(T.InverseSpectrogram(n_fft=400, hop_length=100), spectrogram)
@skipIfRocm
def test_GriffinLim(self):
tensor = torch.rand((1, 201, 6))
......
......@@ -53,7 +53,7 @@ def spectrogram(
center: bool = True,
pad_mode: str = "reflect",
onesided: bool = True,
return_complex: bool = True,
return_complex: Optional[bool] = None,
) -> Tensor:
r"""Create a spectrogram or a batch of spectrograms from a raw audio signal.
The spectrogram can be either magnitude-only or complex.
......@@ -77,25 +77,18 @@ def spectrogram(
onesided (bool, optional): controls whether to return half of results to
avoid redundancy. Default: ``True``
return_complex (bool, optional):
Indicates whether the resulting complex-valued Tensor should be represented with
native complex dtype, such as `torch.cfloat` and `torch.cdouble`, or real dtype
mimicking complex value with an extra dimension for real and imaginary parts.
(See also ``torch.view_as_real``.)
This argument is only effective when ``power=None``. It is ignored for
cases where ``power`` is a number as in those cases, the returned tensor is
power spectrogram, which is a real-valued tensor.
Deprecated and not used.
Returns:
Tensor: Dimension `(..., freq, time)`, freq is
``n_fft // 2 + 1`` and ``n_fft`` is the number of
Fourier bins, and time is the number of window hops (n_frame).
"""
if power is None and not return_complex:
if return_complex is not None:
warnings.warn(
"The use of pseudo complex type in spectrogram is now deprecated."
"Please migrate to native complex type by providing `return_complex=True`. "
"Please refer to https://github.com/pytorch/audio/issues/1337 "
"for more details about torchaudio's plan to migrate to native complex type."
"`return_complex` argument is now deprecated and is not effective."
"`torchaudio.functional.spectrogram(power=None)` always returns a tensor with "
"complex dtype. Please remove the argument in the function call."
)
if pad > 0:
......@@ -129,8 +122,6 @@ def spectrogram(
if power == 1.0:
return spec_f.abs()
return spec_f.abs().pow(power)
if not return_complex:
return torch.view_as_real(spec_f)
return spec_f
......@@ -172,16 +163,8 @@ def inverse_spectrogram(
Tensor: Dimension `(..., time)`. Least squares estimation of the original signal.
"""
if spectrogram.dtype == torch.float32 or spectrogram.dtype == torch.float64:
warnings.warn(
"The use of pseudo complex type in inverse_spectrogram is now deprecated. "
"Please migrate to native complex type by using a complex tensor as input. "
"If the input is generated via spectrogram() function or transform, please use "
"return_complex=True as an argument to that function. "
"Please refer to https://github.com/pytorch/audio/issues/1337 "
"for more details about torchaudio's plan to migrate to native complex type."
)
spectrogram = torch.view_as_complex(spectrogram)
if not spectrogram.is_complex():
raise ValueError("Expected `spectrogram` to be complex dtype.")
if normalized:
spectrogram = spectrogram * window.pow(2.).sum().sqrt()
......
......@@ -65,13 +65,7 @@ class Spectrogram(torch.nn.Module):
onesided (bool, optional): controls whether to return half of results to
avoid redundancy (Default: ``True``)
return_complex (bool, optional):
Indicates whether the resulting complex-valued Tensor should be represented with
native complex dtype, such as `torch.cfloat` and `torch.cdouble`, or real dtype
mimicking complex value with an extra dimension for real and imaginary parts.
(See also ``torch.view_as_real``.)
This argument is only effective when ``power=None``. It is ignored for
cases where ``power`` is a number as in those cases, the returned tensor is
power spectrogram, which is a real-valued tensor.
Deprecated and not used.
Example
>>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True)
......@@ -93,7 +87,7 @@ class Spectrogram(torch.nn.Module):
center: bool = True,
pad_mode: str = "reflect",
onesided: bool = True,
return_complex: bool = True) -> None:
return_complex: Optional[bool] = None) -> None:
super(Spectrogram, self).__init__()
self.n_fft = n_fft
# number of FFT bins. the returned STFT result will have n_fft // 2 + 1
......@@ -108,7 +102,12 @@ class Spectrogram(torch.nn.Module):
self.center = center
self.pad_mode = pad_mode
self.onesided = onesided
self.return_complex = return_complex
if return_complex is not None:
warnings.warn(
"`return_complex` argument is now deprecated and is not effective."
"`torchaudio.transforms.Spectrogram(power=None)` always returns a tensor with "
"complex dtype. Please remove the argument in the function call."
)
def forward(self, waveform: Tensor) -> Tensor:
r"""
......@@ -132,7 +131,6 @@ class Spectrogram(torch.nn.Module):
self.center,
self.pad_mode,
self.onesided,
self.return_complex,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment