"git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "8d858c380e59fb307e5e8774ec7fa1866384345c"
Unverified Commit 22e7e877 authored by Fabian-Robert Stöter's avatar Fabian-Robert Stöter Committed by GitHub
Browse files

Expose additional stft arguments to spectrogram (#892)

parent a4d643ea
...@@ -39,7 +39,10 @@ def spectrogram( ...@@ -39,7 +39,10 @@ def spectrogram(
hop_length: int, hop_length: int,
win_length: int, win_length: int,
power: Optional[float], power: Optional[float],
normalized: bool normalized: bool,
center: bool = True,
pad_mode: str = "reflect",
onesided: bool = True
) -> Tensor: ) -> Tensor:
r"""Create a spectrogram or a batch of spectrograms from a raw audio signal. r"""Create a spectrogram or a batch of spectrograms from a raw audio signal.
The spectrogram can be either magnitude-only or complex. The spectrogram can be either magnitude-only or complex.
...@@ -55,6 +58,13 @@ def spectrogram( ...@@ -55,6 +58,13 @@ def spectrogram(
(must be > 0) e.g., 1 for energy, 2 for power, etc. (must be > 0) e.g., 1 for energy, 2 for power, etc.
If None, then the complex spectrum is returned instead. If None, then the complex spectrum is returned instead.
normalized (bool): Whether to normalize by magnitude after stft normalized (bool): Whether to normalize by magnitude after stft
center (bool, optional): whether to pad :attr:`waveform` on both sides so
that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
Default: ``True``
pad_mode (string, optional): controls the padding method used when
:attr:`center` is ``True``. Default: ``"reflect"``
onesided (bool, optional): controls whether to return half of results to
avoid redundancy. Default: ``True``
Returns: Returns:
Tensor: Dimension (..., freq, time), freq is Tensor: Dimension (..., freq, time), freq is
...@@ -77,10 +87,10 @@ def spectrogram( ...@@ -77,10 +87,10 @@ def spectrogram(
hop_length=hop_length, hop_length=hop_length,
win_length=win_length, win_length=win_length,
window=window, window=window,
center=True, center=center,
pad_mode="reflect", pad_mode=pad_mode,
normalized=False, normalized=False,
onesided=True, onesided=onesided,
return_complex=True, return_complex=True,
) )
......
...@@ -47,6 +47,13 @@ class Spectrogram(torch.nn.Module): ...@@ -47,6 +47,13 @@ class Spectrogram(torch.nn.Module):
If None, then the complex spectrum is returned instead. (Default: ``2``) If None, then the complex spectrum is returned instead. (Default: ``2``)
normalized (bool, optional): Whether to normalize by magnitude after stft. (Default: ``False``) normalized (bool, optional): Whether to normalize by magnitude after stft. (Default: ``False``)
wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``) wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``)
center (bool, optional): whether to pad :attr:`waveform` on both sides so
that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
Default: ``True``
pad_mode (string, optional): controls the padding method used when
:attr:`center` is ``True``. Default: ``"reflect"``
onesided (bool, optional): controls whether to return half of results to
avoid redundancy Default: ``True``
""" """
__constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized'] __constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized']
...@@ -58,7 +65,10 @@ class Spectrogram(torch.nn.Module): ...@@ -58,7 +65,10 @@ class Spectrogram(torch.nn.Module):
window_fn: Callable[..., Tensor] = torch.hann_window, window_fn: Callable[..., Tensor] = torch.hann_window,
power: Optional[float] = 2., power: Optional[float] = 2.,
normalized: bool = False, normalized: bool = False,
wkwargs: Optional[dict] = None) -> None: wkwargs: Optional[dict] = None,
center: bool = True,
pad_mode: str = "reflect",
onesided: bool = True) -> None:
super(Spectrogram, self).__init__() super(Spectrogram, self).__init__()
self.n_fft = n_fft self.n_fft = n_fft
# number of FFT bins. the returned STFT result will have n_fft // 2 + 1 # number of FFT bins. the returned STFT result will have n_fft // 2 + 1
...@@ -70,6 +80,9 @@ class Spectrogram(torch.nn.Module): ...@@ -70,6 +80,9 @@ class Spectrogram(torch.nn.Module):
self.pad = pad self.pad = pad
self.power = power self.power = power
self.normalized = normalized self.normalized = normalized
self.center = center
self.pad_mode = pad_mode
self.onesided = onesided
def forward(self, waveform: Tensor) -> Tensor: def forward(self, waveform: Tensor) -> Tensor:
r""" r"""
...@@ -81,8 +94,19 @@ class Spectrogram(torch.nn.Module): ...@@ -81,8 +94,19 @@ class Spectrogram(torch.nn.Module):
``n_fft // 2 + 1`` where ``n_fft`` is the number of ``n_fft // 2 + 1`` where ``n_fft`` is the number of
Fourier bins, and time is the number of window hops (n_frame). Fourier bins, and time is the number of window hops (n_frame).
""" """
return F.spectrogram(waveform, self.pad, self.window, self.n_fft, self.hop_length, return F.spectrogram(
self.win_length, self.power, self.normalized) waveform,
self.pad,
self.window,
self.n_fft,
self.hop_length,
self.win_length,
self.power,
self.normalized,
self.center,
self.pad_mode,
self.onesided
)
class GriffinLim(torch.nn.Module): class GriffinLim(torch.nn.Module):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment