Unverified Commit 5432a3f5 authored by moto, committed by GitHub
Browse files

[BC-Breaking] Default to native complex type when returning raw spectrogram (#1549)

* [BC-Breaking] Default to native complex type when returning raw spectrogram

Part of https://github.com/pytorch/audio/issues/1337 .

- This code changes the return type of spectrogram to be a native complex dtype,
when (and only when) returning a raw (complex-valued) spectrogram.
- Change `return_complex=False` to `return_complex=True` in spectrogram ops.
- `return_complex` is only effective when `power` is `None`. It is ignored when
`power` is not `None`, because the returned Tensor is then a power spectrogram,
which is a real-valued Tensor.
parent f2a4aac0
...@@ -51,20 +51,34 @@ class Functional(TempDirMixin, TestBaseMixin): ...@@ -51,20 +51,34 @@ class Functional(TempDirMixin, TestBaseMixin):
self.assertEqual(ts_output, output) self.assertEqual(ts_output, output)
def test_spectrogram(self): def test_spectrogram_complex(self):
def func(tensor): def func(tensor):
n_fft = 400 n_fft = 400
ws = 400 ws = 400
hop = 200 hop = 200
pad = 0 pad = 0
window = torch.hann_window(ws, device=tensor.device, dtype=tensor.dtype) window = torch.hann_window(ws, device=tensor.device, dtype=tensor.dtype)
power = 2. power = None
normalize = False normalize = False
return F.spectrogram(tensor, pad, window, n_fft, hop, ws, power, normalize) return F.spectrogram(tensor, pad, window, n_fft, hop, ws, power, normalize)
tensor = common_utils.get_whitenoise() tensor = common_utils.get_whitenoise()
self._assert_consistency(func, tensor) self._assert_consistency(func, tensor)
def test_spectrogram_real(self):
def func(tensor):
n_fft = 400
ws = 400
hop = 200
pad = 0
window = torch.hann_window(ws, device=tensor.device, dtype=tensor.dtype)
power = 2.
normalize = False
return F.spectrogram(tensor, pad, window, n_fft, hop, ws, power, normalize, return_complex=False)
tensor = common_utils.get_whitenoise()
self._assert_consistency(func, tensor)
@skipIfRocm @skipIfRocm
def test_griffinlim(self): def test_griffinlim(self):
def func(tensor): def func(tensor):
......
...@@ -49,7 +49,7 @@ def spectrogram( ...@@ -49,7 +49,7 @@ def spectrogram(
center: bool = True, center: bool = True,
pad_mode: str = "reflect", pad_mode: str = "reflect",
onesided: bool = True, onesided: bool = True,
return_complex: bool = False, return_complex: bool = True,
) -> Tensor: ) -> Tensor:
r"""Create a spectrogram or a batch of spectrograms from a raw audio signal. r"""Create a spectrogram or a batch of spectrograms from a raw audio signal.
The spectrogram can be either magnitude-only or complex. The spectrogram can be either magnitude-only or complex.
...@@ -76,8 +76,10 @@ def spectrogram( ...@@ -76,8 +76,10 @@ def spectrogram(
Indicates whether the resulting complex-valued Tensor should be represented with Indicates whether the resulting complex-valued Tensor should be represented with
native complex dtype, such as `torch.cfloat` and `torch.cdouble`, or real dtype native complex dtype, such as `torch.cfloat` and `torch.cdouble`, or real dtype
mimicking complex value with an extra dimension for real and imaginary parts. mimicking complex value with an extra dimension for real and imaginary parts.
This argument is only effective when ``power=None``. (See also ``torch.view_as_real``.)
See also ``torch.view_as_real``. This argument is only effective when ``power=None``. It is ignored for
cases where ``power`` is a number, because in those cases the returned tensor is a
power spectrogram, which is a real-valued tensor.
Returns: Returns:
Tensor: Dimension (..., freq, time), freq is Tensor: Dimension (..., freq, time), freq is
......
...@@ -63,8 +63,10 @@ class Spectrogram(torch.nn.Module): ...@@ -63,8 +63,10 @@ class Spectrogram(torch.nn.Module):
Indicates whether the resulting complex-valued Tensor should be represented with Indicates whether the resulting complex-valued Tensor should be represented with
native complex dtype, such as `torch.cfloat` and `torch.cdouble`, or real dtype native complex dtype, such as `torch.cfloat` and `torch.cdouble`, or real dtype
mimicking complex value with an extra dimension for real and imaginary parts. mimicking complex value with an extra dimension for real and imaginary parts.
This argument is only effective when ``power=None``. (See also ``torch.view_as_real``.)
See also ``torch.view_as_real``. This argument is only effective when ``power=None``. It is ignored for
cases where ``power`` is a number, because in those cases the returned tensor is a
power spectrogram, which is a real-valued tensor.
""" """
__constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized'] __constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized']
...@@ -80,7 +82,7 @@ class Spectrogram(torch.nn.Module): ...@@ -80,7 +82,7 @@ class Spectrogram(torch.nn.Module):
center: bool = True, center: bool = True,
pad_mode: str = "reflect", pad_mode: str = "reflect",
onesided: bool = True, onesided: bool = True,
return_complex: bool = False) -> None: return_complex: bool = True) -> None:
super(Spectrogram, self).__init__() super(Spectrogram, self).__init__()
self.n_fft = n_fft self.n_fft = n_fft
# number of FFT bins. the returned STFT result will have n_fft // 2 + 1 # number of FFT bins. the returned STFT result will have n_fft // 2 + 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment