Unverified Commit 1500d4ef authored by Vincent QB's avatar Vincent QB Committed by GitHub
Browse files

Complex STFT transform from spectrogram (#327)

* STFT transform and function from #285

* merge options in existing functionality.

* remove dimension 2 check. add test.

* using ...

* update spectrogram test.
parent 5211b843
...@@ -313,6 +313,18 @@ class Tester(unittest.TestCase): ...@@ -313,6 +313,18 @@ class Tester(unittest.TestCase):
computed = transform(specgram) computed = transform(specgram)
self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape)) self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))
def test_batch_spectrogram(self):
waveform, sample_rate = torchaudio.load(self.test_filepath)
# Single then transform then batch
expected = transforms.Spectrogram()(waveform).repeat(3, 1, 1, 1)
# Batch then transform
computed = transforms.Spectrogram()(waveform.repeat(3, 1, 1))
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
self.assertTrue(torch.allclose(computed, expected))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -96,8 +96,7 @@ def istft( ...@@ -96,8 +96,7 @@ def istft(
Args: Args:
stft_matrix (torch.Tensor): Output of stft where each row of a channel is a frequency and each stft_matrix (torch.Tensor): Output of stft where each row of a channel is a frequency and each
column is a window. it has a size of either (channel, fft_size, n_frame, 2) or ( column is a window. it has a size of either (..., fft_size, n_frame, 2)
fft_size, n_frame, 2)
n_fft (int): Size of Fourier transform n_fft (int): Size of Fourier transform
hop_length (Optional[int]): The distance between neighboring sliding window frames. hop_length (Optional[int]): The distance between neighboring sliding window frames.
(Default: ``win_length // 4``) (Default: ``win_length // 4``)
...@@ -218,14 +217,15 @@ def istft( ...@@ -218,14 +217,15 @@ def istft(
def spectrogram( def spectrogram(
waveform, pad, window, n_fft, hop_length, win_length, power, normalized waveform, pad, window, n_fft, hop_length, win_length, power, normalized
): ):
# type: (Tensor, int, Tensor, int, int, int, int, bool) -> Tensor # type: (Tensor, int, Tensor, int, int, int, Optional[int], bool) -> Tensor
r""" r"""
spectrogram(waveform, pad, window, n_fft, hop_length, win_length, power, normalized) spectrogram(waveform, pad, window, n_fft, hop_length, win_length, power, normalized)
Create a spectrogram from a raw audio signal. Create a spectrogram or a batch of spectrograms from a raw audio signal.
The spectrogram can be either magnitude-only or complex.
Args: Args:
waveform (torch.Tensor): Tensor of audio of dimension (channel, time) waveform (torch.Tensor): Tensor of audio of dimension (..., channel, time)
pad (int): Two sided padding of signal pad (int): Two sided padding of signal
window (torch.Tensor): Window tensor that is applied/multiplied to each frame/window window (torch.Tensor): Window tensor that is applied/multiplied to each frame/window
n_fft (int): Size of FFT n_fft (int): Size of FFT
...@@ -233,27 +233,36 @@ def spectrogram( ...@@ -233,27 +233,36 @@ def spectrogram(
win_length (int): Window size win_length (int): Window size
power (int): Exponent for the magnitude spectrogram, power (int): Exponent for the magnitude spectrogram,
(must be > 0) e.g., 1 for energy, 2 for power, etc. (must be > 0) e.g., 1 for energy, 2 for power, etc.
If None, then the complex spectrum is returned instead.
normalized (bool): Whether to normalize by magnitude after stft normalized (bool): Whether to normalize by magnitude after stft
Returns: Returns:
torch.Tensor: Dimension (channel, freq, time), where channel torch.Tensor: Dimension (..., channel, freq, time), where channel
is unchanged, freq is ``n_fft // 2 + 1`` and ``n_fft`` is the number of is unchanged, freq is ``n_fft // 2 + 1`` and ``n_fft`` is the number of
Fourier bins, and time is the number of window hops (n_frame). Fourier bins, and time is the number of window hops (n_frame).
""" """
assert waveform.dim() == 2
if pad > 0: if pad > 0:
# TODO add "with torch.no_grad():" back when JIT supports it # TODO add "with torch.no_grad():" back when JIT supports it
waveform = torch.nn.functional.pad(waveform, (pad, pad), "constant") waveform = torch.nn.functional.pad(waveform, (pad, pad), "constant")
# pack batch
shape = waveform.size()
waveform = waveform.reshape(-1, shape[-1])
# default values are consistent with librosa.core.spectrum._spectrogram # default values are consistent with librosa.core.spectrum._spectrogram
spec_f = _stft( spec_f = _stft(
waveform, n_fft, hop_length, win_length, window, True, "reflect", False, True waveform, n_fft, hop_length, win_length, window, True, "reflect", False, True
) )
# unpack batch
spec_f = spec_f.reshape(shape[:-1] + spec_f.shape[-3:])
if normalized: if normalized:
spec_f /= window.pow(2).sum().sqrt() spec_f /= window.pow(2).sum().sqrt()
spec_f = spec_f.pow(power).sum(-1) # get power of "complex" tensor if power is not None:
spec_f = spec_f.pow(power).sum(-1) # get power of "complex" tensor
return spec_f return spec_f
...@@ -431,11 +440,11 @@ def complex_norm(complex_tensor, power=1.0): ...@@ -431,11 +440,11 @@ def complex_norm(complex_tensor, power=1.0):
r"""Compute the norm of complex tensor input. r"""Compute the norm of complex tensor input.
Args: Args:
complex_tensor (torch.Tensor): Tensor shape of `(*, complex=2)` complex_tensor (torch.Tensor): Tensor shape of `(..., complex=2)`
power (float): Power of the norm. (Default: `1.0`). power (float): Power of the norm. (Default: `1.0`).
Returns: Returns:
torch.Tensor: Power of the normed input tensor. Shape of `(*, )` torch.Tensor: Power of the normed input tensor. Shape of `(..., )`
""" """
if power == 1.0: if power == 1.0:
return torch.norm(complex_tensor, 2, -1) return torch.norm(complex_tensor, 2, -1)
...@@ -448,10 +457,10 @@ def angle(complex_tensor): ...@@ -448,10 +457,10 @@ def angle(complex_tensor):
r"""Compute the angle of complex tensor input. r"""Compute the angle of complex tensor input.
Args: Args:
complex_tensor (torch.Tensor): Tensor shape of `(*, complex=2)` complex_tensor (torch.Tensor): Tensor shape of `(..., complex=2)`
Return: Return:
torch.Tensor: Angle of a complex tensor. Shape of `(*, )` torch.Tensor: Angle of a complex tensor. Shape of `(..., )`
""" """
return torch.atan2(complex_tensor[..., 1], complex_tensor[..., 0]) return torch.atan2(complex_tensor[..., 1], complex_tensor[..., 0])
...@@ -459,10 +468,10 @@ def angle(complex_tensor): ...@@ -459,10 +468,10 @@ def angle(complex_tensor):
@torch.jit.script @torch.jit.script
def magphase(complex_tensor, power=1.0): def magphase(complex_tensor, power=1.0):
# type: (Tensor, float) -> Tuple[Tensor, Tensor] # type: (Tensor, float) -> Tuple[Tensor, Tensor]
r"""Separate a complex-valued spectrogram with shape `(*, 2)` into its magnitude and phase. r"""Separate a complex-valued spectrogram with shape `(..., 2)` into its magnitude and phase.
Args: Args:
complex_tensor (torch.Tensor): Tensor shape of `(*, complex=2)` complex_tensor (torch.Tensor): Tensor shape of `(..., complex=2)`
power (float): Power of the norm. (Default: `1.0`) power (float): Power of the norm. (Default: `1.0`)
Returns: Returns:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment