Unverified Commit 5432a3f5 authored by moto's avatar moto Committed by GitHub
Browse files

[BC-Breaking] Default to native complex type when returning raw spectrogram (#1549)

* [BC-Breaking] Default to native complex type when returning raw spectrogram

Part of https://github.com/pytorch/audio/issues/1337 .

- This commit changes the return type of spectrogram to a native complex dtype
when (and only when) returning a raw (complex-valued) spectrogram.
- Change the default from `return_complex=False` to `return_complex=True` in spectrogram ops.
- `return_complex` is only effective when `power` is `None`. It is ignored when
`power` is not `None`, because the returned Tensor is then a power spectrogram,
which is real-valued.
parent f2a4aac0
......@@ -51,20 +51,34 @@ class Functional(TempDirMixin, TestBaseMixin):
self.assertEqual(ts_output, output)
def test_spectrogram(self):
def test_spectrogram_complex(self):
def func(tensor):
n_fft = 400
ws = 400
hop = 200
pad = 0
window = torch.hann_window(ws, device=tensor.device, dtype=tensor.dtype)
power = 2.
power = None
normalize = False
return F.spectrogram(tensor, pad, window, n_fft, hop, ws, power, normalize)
tensor = common_utils.get_whitenoise()
self._assert_consistency(func, tensor)
def test_spectrogram_real(self):
def func(tensor):
n_fft = 400
ws = 400
hop = 200
pad = 0
window = torch.hann_window(ws, device=tensor.device, dtype=tensor.dtype)
power = 2.
normalize = False
return F.spectrogram(tensor, pad, window, n_fft, hop, ws, power, normalize, return_complex=False)
tensor = common_utils.get_whitenoise()
self._assert_consistency(func, tensor)
@skipIfRocm
def test_griffinlim(self):
def func(tensor):
......
......@@ -49,7 +49,7 @@ def spectrogram(
center: bool = True,
pad_mode: str = "reflect",
onesided: bool = True,
return_complex: bool = False,
return_complex: bool = True,
) -> Tensor:
r"""Create a spectrogram or a batch of spectrograms from a raw audio signal.
The spectrogram can be either magnitude-only or complex.
......@@ -76,8 +76,10 @@ def spectrogram(
Indicates whether the resulting complex-valued Tensor should be represented with
native complex dtype, such as `torch.cfloat` and `torch.cdouble`, or real dtype
mimicking complex value with an extra dimension for real and imaginary parts.
This argument is only effective when ``power=None``.
See also ``torch.view_as_real``.
(See also ``torch.view_as_real``.)
This argument is only effective when ``power=None``. It is ignored when
``power`` is a number, because in those cases the returned tensor is a
power spectrogram, which is real-valued.
Returns:
Tensor: Dimension (..., freq, time), freq is
......
......@@ -63,8 +63,10 @@ class Spectrogram(torch.nn.Module):
Indicates whether the resulting complex-valued Tensor should be represented with
native complex dtype, such as `torch.cfloat` and `torch.cdouble`, or real dtype
mimicking complex value with an extra dimension for real and imaginary parts.
This argument is only effective when ``power=None``.
See also ``torch.view_as_real``.
(See also ``torch.view_as_real``.)
This argument is only effective when ``power=None``. It is ignored when
``power`` is a number, because in those cases the returned tensor is a
power spectrogram, which is real-valued.
"""
__constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized']
......@@ -80,7 +82,7 @@ class Spectrogram(torch.nn.Module):
center: bool = True,
pad_mode: str = "reflect",
onesided: bool = True,
return_complex: bool = False) -> None:
return_complex: bool = True) -> None:
super(Spectrogram, self).__init__()
self.n_fft = n_fft
# number of FFT bins. the returned STFT result will have n_fft // 2 + 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment