"docs/vscode:/vscode.git/clone" did not exist on "26f62dc5beab1ec9d36250c396726c4085ab36c2"
Unverified Commit 6a677ac8 authored by moto's avatar moto Committed by GitHub
Browse files

Add `return_complex` to F.spectrogram and T.Spectrogram (#1366)

* Update spectrogram to use complex
* Update autograd test
* Update TS test
* Update librosa test
parent 0433b7aa
......@@ -29,6 +29,10 @@ class AutogradTestMixin(TestBaseMixin):
assert gradgradcheck(transform, inputs_, nondet_tol=nondet_tol)
@parameterized.expand([
({'pad': 0, 'normalized': False, 'power': None, 'return_complex': True}, ),
({'pad': 3, 'normalized': False, 'power': None, 'return_complex': True}, ),
({'pad': 0, 'normalized': True, 'power': None, 'return_complex': True}, ),
({'pad': 3, 'normalized': True, 'power': None, 'return_complex': True}, ),
({'pad': 0, 'normalized': False, 'power': None}, ),
({'pad': 3, 'normalized': False, 'power': None}, ),
({'pad': 0, 'normalized': True, 'power': None}, ),
......
......@@ -45,6 +45,20 @@ class TestTransforms(common_utils.TorchaudioTestCase):
out_torch = spect_transform(sound).squeeze().cpu()
self.assertEqual(out_torch, torch.from_numpy(out_librosa), atol=1e-5, rtol=1e-5)
def test_spectrogram_complex(self):
    """Compare the magnitude of a complex ``Spectrogram`` against librosa.

    The transform is built with ``power=None`` and ``return_complex=True``;
    the absolute value of its output is checked against the spectrogram
    librosa computes with ``power=1`` for the same ``n_fft``/``hop_length``.
    """
    # STFT settings shared by the torchaudio transform and the librosa reference.
    stft_kwargs = {'n_fft': 400, 'hop_length': 200}
    waveform = common_utils.get_sinusoid(n_channels=1, sample_rate=16000)
    # librosa takes a NumPy array without the singleton channel dimension.
    waveform_np = waveform.cpu().numpy().squeeze()

    transform = torchaudio.transforms.Spectrogram(
        power=None, return_complex=True, **stft_kwargs)
    expected, _ = librosa.core.spectrum._spectrogram(
        y=waveform_np, power=1, **stft_kwargs)
    actual = transform(waveform).squeeze()

    # Only magnitudes are compared; the phase of the complex output is not checked.
    self.assertEqual(
        actual.abs(),
        torch.from_numpy(expected),
        atol=1e-5,
        rtol=1e-5,
    )
@parameterized.expand([
param(norm=norm, mel_scale=mel_scale, **p.kwargs)
for p in [
......
......@@ -25,6 +25,10 @@ class Transforms(common_utils.TestBaseMixin):
tensor = torch.rand((1, 1000))
self._assert_consistency(T.Spectrogram(), tensor)
def test_Spectrogram_return_complex(self):
    """Run the consistency check on ``T.Spectrogram`` with complex output.

    The transform is built with ``power=None`` and ``return_complex=True`` and
    handed to ``self._assert_consistency`` with a random ``(1, 1000)`` input.
    """
    waveform = torch.rand((1, 1000))
    transform = T.Spectrogram(power=None, return_complex=True)
    self._assert_consistency(transform, waveform)
@skipIfRocm
def test_GriffinLim(self):
tensor = torch.rand((1, 201, 6))
......
......@@ -48,7 +48,8 @@ def spectrogram(
normalized: bool,
center: bool = True,
pad_mode: str = "reflect",
onesided: bool = True
onesided: bool = True,
return_complex: bool = False,
) -> Tensor:
r"""Create a spectrogram or a batch of spectrograms from a raw audio signal.
The spectrogram can be either magnitude-only or complex.
......@@ -71,12 +72,22 @@ def spectrogram(
:attr:`center` is ``True``. Default: ``"reflect"``
onesided (bool, optional): controls whether to return half of results to
avoid redundancy. Default: ``True``
return_complex (bool, optional):
If ``True``, this function returns the resulting Tensor in complex dtype;
otherwise it returns the resulting Tensor in real dtype with an extra
dimension for the real and imaginary parts (see ``torch.view_as_real``).
When ``power`` is provided, the value must be ``False``, as the resulting
Tensor represents real-valued power. Default: ``False``
Returns:
Tensor: Dimension (..., freq, time), freq is
``n_fft // 2 + 1`` and ``n_fft`` is the number of
Fourier bins, and time is the number of window hops (n_frame).
"""
if power is not None and return_complex:
raise ValueError(
'When `power` is provided, the return value is real-valued. '
'Therefore, `return_complex` must be False.')
if pad > 0:
# TODO add "with torch.no_grad():" back when JIT supports it
......@@ -109,7 +120,9 @@ def spectrogram(
if power == 1.0:
return spec_f.abs()
return spec_f.abs().pow(power)
if not return_complex:
return torch.view_as_real(spec_f)
return spec_f
def griffinlim(
......
......@@ -52,6 +52,12 @@ class Spectrogram(torch.nn.Module):
:attr:`center` is ``True``. Default: ``"reflect"``
onesided (bool, optional): controls whether to return half of results to
avoid redundancy Default: ``True``
return_complex (bool, optional):
If ``True``, this transform returns the resulting Tensor in complex dtype;
otherwise it returns the resulting Tensor in real dtype with an extra
dimension for the real and imaginary parts (see ``torch.view_as_real``).
When ``power`` is provided, the value must be ``False``, as the resulting
Tensor represents real-valued power. Default: ``False``
"""
__constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized']
......@@ -66,7 +72,8 @@ class Spectrogram(torch.nn.Module):
wkwargs: Optional[dict] = None,
center: bool = True,
pad_mode: str = "reflect",
onesided: bool = True) -> None:
onesided: bool = True,
return_complex: bool = False) -> None:
super(Spectrogram, self).__init__()
self.n_fft = n_fft
# number of FFT bins. the returned STFT result will have n_fft // 2 + 1
......@@ -81,6 +88,7 @@ class Spectrogram(torch.nn.Module):
self.center = center
self.pad_mode = pad_mode
self.onesided = onesided
self.return_complex = return_complex
def forward(self, waveform: Tensor) -> Tensor:
r"""
......@@ -103,7 +111,8 @@ class Spectrogram(torch.nn.Module):
self.normalized,
self.center,
self.pad_mode,
self.onesided
self.onesided,
self.return_complex,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment