Commit 0fde7c57 authored by Sean Kim's avatar Sean Kim Committed by Facebook GitHub Bot
Browse files

Add Union normalization parameter on spectrogram and inverse spectrogram (#2554)

Summary:
Allow the ``normalized`` parameter to accept a str in addition to a bool, which enables frame-length-based normalization that aligns with the torch implementation of stft. Addresses issue https://github.com/pytorch/audio/issues/2104

Pull Request resolved: https://github.com/pytorch/audio/pull/2554

Reviewed By: carolineechen, mthrok

Differential Revision: D38247554

Pulled By: skim0514

fbshipit-source-id: c243c7a6b8fda2a1e565cef4600f7c5a06baf602
parent 338e3104
......@@ -245,6 +245,77 @@ class Functional(TestBaseMixin):
spec.sum().backward()
assert not x.grad.isnan().sum()
@parameterized.expand(
    [
        (1024,),
        (2048,),
        (4096,),
    ]
)
def test_spectrogram_normalization_hann_window(self, nfft):
    """Check the relative scaling between the ``normalized`` modes of ``F.spectrogram``.

    The un-normalized complex spectrogram (``normalized=False``) is taken as the
    reference; this test does not verify the accuracy of the spectrogram itself and
    assumes torch.stft and the existing spectrogram math are correct.

    Expected relations for a hann window of length ``nfft``:

    * ``normalized=True`` and ``normalized="window"`` give identical results, each
      equal to the reference divided by ``sqrt(3 * nfft / 8)`` (the square root of
      the hann window's sum of squares).
    * ``normalized="frame_length"`` relies on the normalization in torch.stft and
      is expected to equal the reference divided by ``sqrt(nfft)``.

    https://github.com/pytorch/pytorch/issues/81428
    """
    signal = torch.rand(1, 22050)

    def _complex_spec(mode):
        # ``power=None`` requests the complex spectrum, so the scale factors can be
        # compared directly. A fresh window is built for every call.
        return F.spectrogram(
            signal,
            pad=0,
            window=torch.hann_window(nfft, device=signal.device, dtype=signal.dtype),
            n_fft=nfft,
            hop_length=4,
            win_length=nfft,
            power=None,
            normalized=mode,
        )

    reference = _complex_spec(False)
    by_flag = _complex_spec(True)
    by_window = _complex_spec("window")
    by_frame = _complex_spec("frame_length")

    window_scale = math.sqrt(3 * nfft / 8)
    frame_scale = math.sqrt(nfft)

    # ``True`` must be an alias of ``"window"``.
    self.assertEqual(by_flag, by_window)
    self.assertEqual(by_flag, reference / window_scale)
    self.assertEqual(by_frame, reference / frame_scale)
def test_compute_deltas_one_channel(self):
specgram = torch.tensor([[[1.0, 2.0, 3.0, 4.0]]], dtype=self.dtype, device=self.device)
expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5]]], dtype=self.dtype, device=self.device)
......
......@@ -46,7 +46,15 @@ class Functional(TempDirMixin, TestBaseMixin):
self.assertEqual(ts_output, output)
def test_spectrogram(self):
@parameterized.expand(
[
(True,),
(False,),
("window",),
("frame_length",),
]
)
def test_spectrogram(self, normalize):
waveform = common_utils.get_whitenoise()
n_fft = 400
ws = 400
......@@ -54,12 +62,19 @@ class Functional(TempDirMixin, TestBaseMixin):
pad = 0
window = torch.hann_window(ws, device=waveform.device, dtype=waveform.dtype)
power = None
normalize = False
self._assert_consistency(
F.spectrogram, (waveform, pad, window, n_fft, hop, ws, power, normalize, True, "reflect", True, True)
)
def test_inverse_spectrogram(self):
@parameterized.expand(
[
(True,),
(False,),
("window",),
("frame_length",),
]
)
def test_inverse_spectrogram(self, normalize):
waveform = common_utils.get_whitenoise(sample_rate=8000, duration=0.05)
specgram = common_utils.get_spectrogram(waveform, n_fft=400, hop_length=200)
length = 400
......@@ -68,7 +83,6 @@ class Functional(TempDirMixin, TestBaseMixin):
ws = 400
pad = 0
window = torch.hann_window(ws, device=specgram.device, dtype=torch.float64)
normalize = False
self._assert_consistency_complex(
F.inverse_spectrogram, (specgram, length, pad, window, n_fft, hop, ws, normalize, True, "reflect", True)
)
......
......@@ -54,7 +54,7 @@ def spectrogram(
hop_length: int,
win_length: int,
power: Optional[float],
normalized: bool,
normalized: Union[bool, str],
center: bool = True,
pad_mode: str = "reflect",
onesided: bool = True,
......@@ -77,7 +77,11 @@ def spectrogram(
power (float or None): Exponent for the magnitude spectrogram,
(must be > 0) e.g., 1 for energy, 2 for power, etc.
If None, then the complex spectrum is returned instead.
normalized (bool): Whether to normalize by magnitude after stft
normalized (bool or str): Whether to normalize by magnitude after stft. If input is str, choices are
``"window"`` and ``"frame_length"``, if specific normalization type is desirable. ``True`` maps to
``"window"``. When normalized on ``"window"``, waveform is normalized upon the window's L2 energy. If
normalized on ``"frame_length"``, waveform is normalized by dividing by
:math:`(\text{frame\_length})^{0.5}`.
center (bool, optional): whether to pad :attr:`waveform` on both sides so
that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
Default: ``True``
......@@ -104,6 +108,8 @@ def spectrogram(
# TODO add "with torch.no_grad():" back when JIT supports it
waveform = torch.nn.functional.pad(waveform, (pad, pad), "constant")
frame_length_norm, window_norm = _get_spec_norms(normalized)
# pack batch
shape = waveform.size()
waveform = waveform.reshape(-1, shape[-1])
......@@ -117,7 +123,7 @@ def spectrogram(
window=window,
center=center,
pad_mode=pad_mode,
normalized=False,
normalized=frame_length_norm,
onesided=onesided,
return_complex=True,
)
......@@ -125,7 +131,7 @@ def spectrogram(
# unpack batch
spec_f = spec_f.reshape(shape[:-1] + spec_f.shape[-2:])
if normalized:
if window_norm:
spec_f /= window.pow(2.0).sum().sqrt()
if power is not None:
if power == 1.0:
......@@ -142,7 +148,7 @@ def inverse_spectrogram(
n_fft: int,
hop_length: int,
win_length: int,
normalized: bool,
normalized: Union[bool, str],
center: bool = True,
pad_mode: str = "reflect",
onesided: bool = True,
......@@ -162,7 +168,9 @@ def inverse_spectrogram(
n_fft (int): Size of FFT
hop_length (int): Length of hop between STFT windows
win_length (int): Window size
normalized (bool): Whether the stft output was normalized by magnitude
normalized (bool or str): Whether the stft output was normalized by magnitude. If input is str, choices are
``"window"`` and ``"frame_length"``, dependent on normalization mode. ``True`` maps to
``"window"``.
center (bool, optional): whether the waveform was padded on both sides so
that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
Default: ``True``
......@@ -176,10 +184,12 @@ def inverse_spectrogram(
Tensor: Dimension `(..., time)`. Least squares estimation of the original signal.
"""
frame_length_norm, window_norm = _get_spec_norms(normalized)
if not spectrogram.is_complex():
raise ValueError("Expected `spectrogram` to be complex dtype.")
if normalized:
if window_norm:
spectrogram = spectrogram * window.pow(2.0).sum().sqrt()
# pack batch
......@@ -194,7 +204,7 @@ def inverse_spectrogram(
win_length=win_length,
window=window,
center=center,
normalized=False,
normalized=frame_length_norm,
onesided=onesided,
length=length + 2 * pad if length is not None else None,
return_complex=False,
......@@ -210,6 +220,23 @@ def inverse_spectrogram(
return waveform
def _get_spec_norms(normalized: Union[str, bool]):
frame_length_norm, window_norm = False, False
if torch.jit.isinstance(normalized, str):
if normalized not in ["frame_length", "window"]:
raise ValueError("Invalid normalized parameter: {}".format(normalized))
if normalized == "frame_length":
frame_length_norm = True
elif normalized == "window":
window_norm = True
elif torch.jit.isinstance(normalized, bool):
if normalized:
window_norm = True
else:
raise TypeError("Input type not supported")
return frame_length_norm, window_norm
def _get_complex_dtype(real_dtype: torch.dtype):
if real_dtype == torch.double:
return torch.cdouble
......
......@@ -2,7 +2,7 @@
import math
import warnings
from typing import Callable, Optional
from typing import Callable, Optional, Union
import torch
from torch import Tensor
......@@ -37,7 +37,9 @@ class Spectrogram(torch.nn.Module):
power (float or None, optional): Exponent for the magnitude spectrogram,
(must be > 0) e.g., 1 for energy, 2 for power, etc.
If None, then the complex spectrum is returned instead. (Default: ``2``)
normalized (bool, optional): Whether to normalize by magnitude after stft. (Default: ``False``)
normalized (bool or str, optional): Whether to normalize by magnitude after stft. If input is str, choices are
``"window"`` and ``"frame_length"``, if specific normalization type is desirable. ``True`` maps to
``"window"``. (Default: ``False``)
wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``)
center (bool, optional): whether to pad :attr:`waveform` on both sides so
that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
......@@ -65,7 +67,7 @@ class Spectrogram(torch.nn.Module):
pad: int = 0,
window_fn: Callable[..., Tensor] = torch.hann_window,
power: Optional[float] = 2.0,
normalized: bool = False,
normalized: Union[bool, str] = False,
wkwargs: Optional[dict] = None,
center: bool = True,
pad_mode: str = "reflect",
......@@ -132,8 +134,9 @@ class InverseSpectrogram(torch.nn.Module):
pad (int, optional): Two sided padding of signal. (Default: ``0``)
window_fn (Callable[..., Tensor], optional): A function to create a window tensor
that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
normalized (bool, optional): Whether the spectrogram was normalized by magnitude after stft.
(Default: ``False``)
normalized (bool or str, optional): Whether the stft output was normalized by magnitude. If input is str,
choices are ``"window"`` and ``"frame_length"``, dependent on normalization mode. ``True`` maps to
``"window"``. (Default: ``False``)
wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``)
center (bool, optional): whether the signal in spectrogram was padded on both sides so
that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
......@@ -159,7 +162,7 @@ class InverseSpectrogram(torch.nn.Module):
hop_length: Optional[int] = None,
pad: int = 0,
window_fn: Callable[..., Tensor] = torch.hann_window,
normalized: bool = False,
normalized: Union[bool, str] = False,
wkwargs: Optional[dict] = None,
center: bool = True,
pad_mode: str = "reflect",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment