Unverified Commit 1500d4ef authored by Vincent QB's avatar Vincent QB Committed by GitHub
Browse files

Complex STFT transform from spectrogram (#327)

* STFT transform and function from #285

* merge options in existing functionality.

* remove dimension 2 check. add test.

* using ...

* update spectrogram test.
parent 5211b843
......@@ -313,6 +313,18 @@ class Tester(unittest.TestCase):
computed = transform(specgram)
self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))
def test_batch_spectrogram(self):
    """Check that Spectrogram gives the same result on a batch as on a single item.

    The reference is built by transforming one waveform and then tiling the
    result three times; the candidate is built by tiling the waveform three
    times and transforming the whole batch in one call.
    """
    waveform, _sample_rate = torchaudio.load(self.test_filepath)

    # Reference: transform first, then tile along a new leading batch axis.
    single_spec = transforms.Spectrogram()(waveform)
    expected = single_spec.repeat(3, 1, 1, 1)

    # Candidate: tile the raw waveform first, then transform the batch.
    batched_waveform = waveform.repeat(3, 1, 1)
    computed = transforms.Spectrogram()(batched_waveform)

    shapes = (computed.shape, expected.shape)
    self.assertTrue(shapes[0] == shapes[1], shapes)
    self.assertTrue(torch.allclose(computed, expected))
if __name__ == '__main__':
unittest.main()
......@@ -96,8 +96,7 @@ def istft(
Args:
stft_matrix (torch.Tensor): Output of stft where each row of a channel is a frequency and each
column is a window. it has a size of either (channel, fft_size, n_frame, 2) or (
fft_size, n_frame, 2)
column is a window. It has a size of (..., fft_size, n_frame, 2)
n_fft (int): Size of Fourier transform
hop_length (Optional[int]): The distance between neighboring sliding window frames.
(Default: ``win_length // 4``)
......@@ -218,14 +217,15 @@ def istft(
def spectrogram(
        waveform, pad, window, n_fft, hop_length, win_length, power, normalized
):
    # type: (Tensor, int, Tensor, int, int, int, Optional[int], bool) -> Tensor
    r"""
    spectrogram(waveform, pad, window, n_fft, hop_length, win_length, power, normalized)

    Create a spectrogram or a batch of spectrograms from a raw audio signal.
    The spectrogram can be either magnitude-only or complex.

    Args:
        waveform (torch.Tensor): Tensor of audio of dimension (..., channel, time)
        pad (int): Two sided padding of signal
        window (torch.Tensor): Window tensor that is applied/multiplied to each frame/window
        n_fft (int): Size of FFT
        hop_length (int): Length of hop between STFT windows
        win_length (int): Window size
        power (Optional[int]): Exponent applied to the magnitude of the complex
            spectrum (must be > 0) e.g., 1 for energy, 2 for power, etc.
            If None, then the complex spectrum is returned instead.
        normalized (bool): Whether to normalize by magnitude after stft

    Returns:
        torch.Tensor: Dimension (..., channel, freq, time), where channel
        is unchanged, freq is ``n_fft // 2 + 1`` and ``n_fft`` is the number of
        Fourier bins, and time is the number of window hops (n_frame).
        When ``power`` is None the complex spectrum is returned, which keeps an
        extra trailing dimension of size 2 holding the real and imaginary parts.
    """
    # NOTE: no ``waveform.dim() == 2`` check here -- any number of leading
    # batch dimensions is accepted and restored on the output.
    if pad > 0:
        # TODO add "with torch.no_grad():" back when JIT supports it
        waveform = torch.nn.functional.pad(waveform, (pad, pad), "constant")

    # pack batch: collapse every leading dimension so _stft sees (batch, time)
    shape = waveform.size()
    waveform = waveform.reshape(-1, shape[-1])

    # default values are consistent with librosa.core.spectrum._spectrogram
    spec_f = _stft(
        waveform, n_fft, hop_length, win_length, window, True, "reflect", False, True
    )

    # unpack batch: restore the leading dimensions in front of (freq, time, 2)
    spec_f = spec_f.reshape(shape[:-1] + spec_f.shape[-3:])

    if normalized:
        spec_f /= window.pow(2).sum().sqrt()

    if power is not None:
        # Squared magnitude of the "complex" tensor: re ** 2 + im ** 2.
        spec_f = spec_f.pow(2).sum(-1)
        if power != 2:
            # |z| ** power == (|z| ** 2) ** (power / 2). Raising the real and
            # imaginary parts to ``power`` separately is only correct for 2.
            spec_f = spec_f.pow(0.5 * power)
    return spec_f
......@@ -431,11 +440,11 @@ def complex_norm(complex_tensor, power=1.0):
r"""Compute the norm of complex tensor input.
Args:
complex_tensor (torch.Tensor): Tensor shape of `(*, complex=2)`
complex_tensor (torch.Tensor): Tensor shape of `(..., complex=2)`
power (float): Power of the norm. (Default: `1.0`).
Returns:
torch.Tensor: Power of the normed input tensor. Shape of `(*, )`
torch.Tensor: Power of the normed input tensor. Shape of `(..., )`
"""
if power == 1.0:
return torch.norm(complex_tensor, 2, -1)
......@@ -448,10 +457,10 @@ def angle(complex_tensor):
r"""Compute the angle of complex tensor input.
Args:
complex_tensor (torch.Tensor): Tensor shape of `(*, complex=2)`
complex_tensor (torch.Tensor): Tensor shape of `(..., complex=2)`
Return:
torch.Tensor: Angle of a complex tensor. Shape of `(*, )`
torch.Tensor: Angle of a complex tensor. Shape of `(..., )`
"""
return torch.atan2(complex_tensor[..., 1], complex_tensor[..., 0])
......@@ -459,10 +468,10 @@ def angle(complex_tensor):
@torch.jit.script
def magphase(complex_tensor, power=1.0):
# type: (Tensor, float) -> Tuple[Tensor, Tensor]
r"""Separate a complex-valued spectrogram with shape `(*, 2)` into its magnitude and phase.
r"""Separate a complex-valued spectrogram with shape `(..., 2)` into its magnitude and phase.
Args:
complex_tensor (torch.Tensor): Tensor shape of `(*, complex=2)`
complex_tensor (torch.Tensor): Tensor shape of `(..., complex=2)`
power (float): Power of the norm. (Default: `1.0`)
Returns:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment