Commit 0fde7c57 authored by Sean Kim's avatar Sean Kim Committed by Facebook GitHub Bot
Browse files

Add Union normalization parameter on spectrogram and inverse spectrogram (#2554)

Summary:
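Allow the `normalized` parameter of spectrogram and inverse spectrogram to accept a `str` in addition to a `bool`. The new `"frame_length"` option enables frame-length-based normalization, which aligns with the normalization used by the torch implementation of stft; `"window"` (equivalent to `True`) keeps the existing window-based normalization. Addresses issue https://github.com/pytorch/audio/issues/2104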

Pull Request resolved: https://github.com/pytorch/audio/pull/2554

Reviewed By: carolineechen, mthrok

Differential Revision: D38247554

Pulled By: skim0514

fbshipit-source-id: c243c7a6b8fda2a1e565cef4600f7c5a06baf602
parent 338e3104
...@@ -245,6 +245,77 @@ class Functional(TestBaseMixin): ...@@ -245,6 +245,77 @@ class Functional(TestBaseMixin):
spec.sum().backward() spec.sum().backward()
assert not x.grad.isnan().sum() assert not x.grad.isnan().sum()
@parameterized.expand(
    [
        (1024,),
        (2048,),
        (4096,),
    ]
)
def test_spectrogram_normalization_hann_window(self, nfft):
    """Check that the ``normalized`` modes of ``F.spectrogram`` scale consistently relative to one another.

    This test assumes that ``torch.stft`` and the existing math behind spectrogram are correct. It does not
    test the accuracy of the spectrogram itself; it only tests that the relative normalization factors match
    the mathematical prediction for the chosen frame length and ``normalized`` value.

    Using the unnormalized output (``normalized=False``) as the baseline:

    * ``True`` and ``"window"`` must give the same result, equal to the baseline divided by the square root
      of the sum of squares of the hann window, which is calculated to be ``sqrt(3 * nfft / 8)``.
    * ``"frame_length"`` uses the normalization in ``torch.stft``, so it is assumed to equal the baseline
      divided by ``sqrt(nfft)``.

    https://github.com/pytorch/pytorch/issues/81428
    """
    # Build the input on the dtype/device under test, like the other tests in this class do,
    # so that every dtype/device variant of the suite actually exercises its own configuration.
    x = torch.rand(1, 22050, dtype=self.dtype, device=self.device)
    # The same window is used by every call; only ``normalized`` varies.
    window = torch.hann_window(nfft, device=x.device, dtype=x.dtype)

    def _spec(normalized):
        # power=None requests the complex spectrum, so the scale factors apply linearly.
        return F.spectrogram(
            x,
            pad=0,
            window=window,
            n_fft=nfft,
            hop_length=4,
            win_length=nfft,
            power=None,
            normalized=normalized,
        )

    spec_false = _spec(False)
    spec_true = _spec(True)
    spec_window = _spec("window")
    spec_frame = _spec("frame_length")

    norm_factor = math.sqrt(3 * nfft / 8)
    frame_norm_factor = math.sqrt(nfft)

    self.assertEqual(spec_true, spec_window)
    self.assertEqual(spec_true, spec_false / norm_factor)
    self.assertEqual(spec_frame, spec_false / frame_norm_factor)
def test_compute_deltas_one_channel(self): def test_compute_deltas_one_channel(self):
specgram = torch.tensor([[[1.0, 2.0, 3.0, 4.0]]], dtype=self.dtype, device=self.device) specgram = torch.tensor([[[1.0, 2.0, 3.0, 4.0]]], dtype=self.dtype, device=self.device)
expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5]]], dtype=self.dtype, device=self.device) expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5]]], dtype=self.dtype, device=self.device)
......
...@@ -46,7 +46,15 @@ class Functional(TempDirMixin, TestBaseMixin): ...@@ -46,7 +46,15 @@ class Functional(TempDirMixin, TestBaseMixin):
self.assertEqual(ts_output, output) self.assertEqual(ts_output, output)
def test_spectrogram(self): @parameterized.expand(
[
(True,),
(False,),
("window",),
("frame_length",),
]
)
def test_spectrogram(self, normalize):
waveform = common_utils.get_whitenoise() waveform = common_utils.get_whitenoise()
n_fft = 400 n_fft = 400
ws = 400 ws = 400
...@@ -54,12 +62,19 @@ class Functional(TempDirMixin, TestBaseMixin): ...@@ -54,12 +62,19 @@ class Functional(TempDirMixin, TestBaseMixin):
pad = 0 pad = 0
window = torch.hann_window(ws, device=waveform.device, dtype=waveform.dtype) window = torch.hann_window(ws, device=waveform.device, dtype=waveform.dtype)
power = None power = None
normalize = False
self._assert_consistency( self._assert_consistency(
F.spectrogram, (waveform, pad, window, n_fft, hop, ws, power, normalize, True, "reflect", True, True) F.spectrogram, (waveform, pad, window, n_fft, hop, ws, power, normalize, True, "reflect", True, True)
) )
def test_inverse_spectrogram(self): @parameterized.expand(
[
(True,),
(False,),
("window",),
("frame_length",),
]
)
def test_inverse_spectrogram(self, normalize):
waveform = common_utils.get_whitenoise(sample_rate=8000, duration=0.05) waveform = common_utils.get_whitenoise(sample_rate=8000, duration=0.05)
specgram = common_utils.get_spectrogram(waveform, n_fft=400, hop_length=200) specgram = common_utils.get_spectrogram(waveform, n_fft=400, hop_length=200)
length = 400 length = 400
...@@ -68,7 +83,6 @@ class Functional(TempDirMixin, TestBaseMixin): ...@@ -68,7 +83,6 @@ class Functional(TempDirMixin, TestBaseMixin):
ws = 400 ws = 400
pad = 0 pad = 0
window = torch.hann_window(ws, device=specgram.device, dtype=torch.float64) window = torch.hann_window(ws, device=specgram.device, dtype=torch.float64)
normalize = False
self._assert_consistency_complex( self._assert_consistency_complex(
F.inverse_spectrogram, (specgram, length, pad, window, n_fft, hop, ws, normalize, True, "reflect", True) F.inverse_spectrogram, (specgram, length, pad, window, n_fft, hop, ws, normalize, True, "reflect", True)
) )
......
...@@ -54,7 +54,7 @@ def spectrogram( ...@@ -54,7 +54,7 @@ def spectrogram(
hop_length: int, hop_length: int,
win_length: int, win_length: int,
power: Optional[float], power: Optional[float],
normalized: bool, normalized: Union[bool, str],
center: bool = True, center: bool = True,
pad_mode: str = "reflect", pad_mode: str = "reflect",
onesided: bool = True, onesided: bool = True,
...@@ -77,7 +77,11 @@ def spectrogram( ...@@ -77,7 +77,11 @@ def spectrogram(
power (float or None): Exponent for the magnitude spectrogram, power (float or None): Exponent for the magnitude spectrogram,
(must be > 0) e.g., 1 for energy, 2 for power, etc. (must be > 0) e.g., 1 for energy, 2 for power, etc.
If None, then the complex spectrum is returned instead. If None, then the complex spectrum is returned instead.
normalized (bool): Whether to normalize by magnitude after stft normalized (bool or str): Whether to normalize by magnitude after stft. If input is str, choices are
``"window"`` and ``"frame_length"``, if specific normalization type is desirable. ``True`` maps to
``"window"``. When normalized on ``"window"``, waveform is normalized upon the window's L2 energy. If
normalized on ``"frame_length"``, waveform is normalized by dividing by
:math:`(\text{frame\_length})^{0.5}`.
center (bool, optional): whether to pad :attr:`waveform` on both sides so center (bool, optional): whether to pad :attr:`waveform` on both sides so
that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`. that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
Default: ``True`` Default: ``True``
...@@ -104,6 +108,8 @@ def spectrogram( ...@@ -104,6 +108,8 @@ def spectrogram(
# TODO add "with torch.no_grad():" back when JIT supports it # TODO add "with torch.no_grad():" back when JIT supports it
waveform = torch.nn.functional.pad(waveform, (pad, pad), "constant") waveform = torch.nn.functional.pad(waveform, (pad, pad), "constant")
frame_length_norm, window_norm = _get_spec_norms(normalized)
# pack batch # pack batch
shape = waveform.size() shape = waveform.size()
waveform = waveform.reshape(-1, shape[-1]) waveform = waveform.reshape(-1, shape[-1])
...@@ -117,7 +123,7 @@ def spectrogram( ...@@ -117,7 +123,7 @@ def spectrogram(
window=window, window=window,
center=center, center=center,
pad_mode=pad_mode, pad_mode=pad_mode,
normalized=False, normalized=frame_length_norm,
onesided=onesided, onesided=onesided,
return_complex=True, return_complex=True,
) )
...@@ -125,7 +131,7 @@ def spectrogram( ...@@ -125,7 +131,7 @@ def spectrogram(
# unpack batch # unpack batch
spec_f = spec_f.reshape(shape[:-1] + spec_f.shape[-2:]) spec_f = spec_f.reshape(shape[:-1] + spec_f.shape[-2:])
if normalized: if window_norm:
spec_f /= window.pow(2.0).sum().sqrt() spec_f /= window.pow(2.0).sum().sqrt()
if power is not None: if power is not None:
if power == 1.0: if power == 1.0:
...@@ -142,7 +148,7 @@ def inverse_spectrogram( ...@@ -142,7 +148,7 @@ def inverse_spectrogram(
n_fft: int, n_fft: int,
hop_length: int, hop_length: int,
win_length: int, win_length: int,
normalized: bool, normalized: Union[bool, str],
center: bool = True, center: bool = True,
pad_mode: str = "reflect", pad_mode: str = "reflect",
onesided: bool = True, onesided: bool = True,
...@@ -162,7 +168,9 @@ def inverse_spectrogram( ...@@ -162,7 +168,9 @@ def inverse_spectrogram(
n_fft (int): Size of FFT n_fft (int): Size of FFT
hop_length (int): Length of hop between STFT windows hop_length (int): Length of hop between STFT windows
win_length (int): Window size win_length (int): Window size
normalized (bool): Whether the stft output was normalized by magnitude normalized (bool or str): Whether the stft output was normalized by magnitude. If input is str, choices are
``"window"`` and ``"frame_length"``, dependent on normalization mode. ``True`` maps to
``"window"``.
center (bool, optional): whether the waveform was padded on both sides so center (bool, optional): whether the waveform was padded on both sides so
that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`. that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
Default: ``True`` Default: ``True``
...@@ -176,10 +184,12 @@ def inverse_spectrogram( ...@@ -176,10 +184,12 @@ def inverse_spectrogram(
Tensor: Dimension `(..., time)`. Least squares estimation of the original signal. Tensor: Dimension `(..., time)`. Least squares estimation of the original signal.
""" """
frame_length_norm, window_norm = _get_spec_norms(normalized)
if not spectrogram.is_complex(): if not spectrogram.is_complex():
raise ValueError("Expected `spectrogram` to be complex dtype.") raise ValueError("Expected `spectrogram` to be complex dtype.")
if normalized: if window_norm:
spectrogram = spectrogram * window.pow(2.0).sum().sqrt() spectrogram = spectrogram * window.pow(2.0).sum().sqrt()
# pack batch # pack batch
...@@ -194,7 +204,7 @@ def inverse_spectrogram( ...@@ -194,7 +204,7 @@ def inverse_spectrogram(
win_length=win_length, win_length=win_length,
window=window, window=window,
center=center, center=center,
normalized=False, normalized=frame_length_norm,
onesided=onesided, onesided=onesided,
length=length + 2 * pad if length is not None else None, length=length + 2 * pad if length is not None else None,
return_complex=False, return_complex=False,
...@@ -210,6 +220,23 @@ def inverse_spectrogram( ...@@ -210,6 +220,23 @@ def inverse_spectrogram(
return waveform return waveform
def _get_spec_norms(normalized: Union[str, bool]):
frame_length_norm, window_norm = False, False
if torch.jit.isinstance(normalized, str):
if normalized not in ["frame_length", "window"]:
raise ValueError("Invalid normalized parameter: {}".format(normalized))
if normalized == "frame_length":
frame_length_norm = True
elif normalized == "window":
window_norm = True
elif torch.jit.isinstance(normalized, bool):
if normalized:
window_norm = True
else:
raise TypeError("Input type not supported")
return frame_length_norm, window_norm
def _get_complex_dtype(real_dtype: torch.dtype): def _get_complex_dtype(real_dtype: torch.dtype):
if real_dtype == torch.double: if real_dtype == torch.double:
return torch.cdouble return torch.cdouble
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
import math import math
import warnings import warnings
from typing import Callable, Optional from typing import Callable, Optional, Union
import torch import torch
from torch import Tensor from torch import Tensor
...@@ -37,7 +37,9 @@ class Spectrogram(torch.nn.Module): ...@@ -37,7 +37,9 @@ class Spectrogram(torch.nn.Module):
power (float or None, optional): Exponent for the magnitude spectrogram, power (float or None, optional): Exponent for the magnitude spectrogram,
(must be > 0) e.g., 1 for energy, 2 for power, etc. (must be > 0) e.g., 1 for energy, 2 for power, etc.
If None, then the complex spectrum is returned instead. (Default: ``2``) If None, then the complex spectrum is returned instead. (Default: ``2``)
normalized (bool, optional): Whether to normalize by magnitude after stft. (Default: ``False``) normalized (bool or str, optional): Whether to normalize by magnitude after stft. If input is str, choices are
``"window"`` and ``"frame_length"``, if specific normalization type is desirable. ``True`` maps to
``"window"``. (Default: ``False``)
wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``) wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``)
center (bool, optional): whether to pad :attr:`waveform` on both sides so center (bool, optional): whether to pad :attr:`waveform` on both sides so
that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`. that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
...@@ -65,7 +67,7 @@ class Spectrogram(torch.nn.Module): ...@@ -65,7 +67,7 @@ class Spectrogram(torch.nn.Module):
pad: int = 0, pad: int = 0,
window_fn: Callable[..., Tensor] = torch.hann_window, window_fn: Callable[..., Tensor] = torch.hann_window,
power: Optional[float] = 2.0, power: Optional[float] = 2.0,
normalized: bool = False, normalized: Union[bool, str] = False,
wkwargs: Optional[dict] = None, wkwargs: Optional[dict] = None,
center: bool = True, center: bool = True,
pad_mode: str = "reflect", pad_mode: str = "reflect",
...@@ -132,8 +134,9 @@ class InverseSpectrogram(torch.nn.Module): ...@@ -132,8 +134,9 @@ class InverseSpectrogram(torch.nn.Module):
pad (int, optional): Two sided padding of signal. (Default: ``0``) pad (int, optional): Two sided padding of signal. (Default: ``0``)
window_fn (Callable[..., Tensor], optional): A function to create a window tensor window_fn (Callable[..., Tensor], optional): A function to create a window tensor
that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``) that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
normalized (bool, optional): Whether the spectrogram was normalized by magnitude after stft. normalized (bool or str, optional): Whether the stft output was normalized by magnitude. If input is str,
(Default: ``False``) choices are ``"window"`` and ``"frame_length"``, dependent on normalization mode. ``True`` maps to
``"window"``. (Default: ``False``)
wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``) wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``)
center (bool, optional): whether the signal in spectrogram was padded on both sides so center (bool, optional): whether the signal in spectrogram was padded on both sides so
that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`. that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
...@@ -159,7 +162,7 @@ class InverseSpectrogram(torch.nn.Module): ...@@ -159,7 +162,7 @@ class InverseSpectrogram(torch.nn.Module):
hop_length: Optional[int] = None, hop_length: Optional[int] = None,
pad: int = 0, pad: int = 0,
window_fn: Callable[..., Tensor] = torch.hann_window, window_fn: Callable[..., Tensor] = torch.hann_window,
normalized: bool = False, normalized: Union[bool, str] = False,
wkwargs: Optional[dict] = None, wkwargs: Optional[dict] = None,
center: bool = True, center: bool = True,
pad_mode: str = "reflect", pad_mode: str = "reflect",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment