"src/vscode:/vscode.git/clone" did not exist on "43f1090a0f5879416d83e1b0991502a26fc27ec6"
Commit 2271a7ae authored by Kiran Sanjeevan's avatar Kiran Sanjeevan Committed by cpuhrsch
Browse files

torchaudio-contrib: Adding (some) functionals (#131)

parent dc452aab
...@@ -18,7 +18,7 @@ run_tests() { ...@@ -18,7 +18,7 @@ run_tests() {
for FILE in $TEST_FILES; do for FILE in $TEST_FILES; do
# run each file on a separate process. if one fails, just keep going and # run each file on a separate process. if one fails, just keep going and
# return the final exit status. # return the final exit status.
python -m unittest -v $FILE python -m pytest -v $FILE
STATUS=$? STATUS=$?
EXIT_STATUS="$(($EXIT_STATUS+STATUS))" EXIT_STATUS="$(($EXIT_STATUS+STATUS))"
done done
......
...@@ -10,3 +10,6 @@ flake8 ...@@ -10,3 +10,6 @@ flake8
# Used for comparison of outputs in tests # Used for comparison of outputs in tests
librosa librosa
scipy scipy
# Unit tests with pytest
pytest
\ No newline at end of file
...@@ -3,7 +3,6 @@ from shutil import copytree ...@@ -3,7 +3,6 @@ from shutil import copytree
import tempfile import tempfile
import torch import torch
TEST_DIR_PATH = os.path.dirname(os.path.realpath(__file__)) TEST_DIR_PATH = os.path.dirname(os.path.realpath(__file__))
......
...@@ -5,6 +5,16 @@ import torchaudio ...@@ -5,6 +5,16 @@ import torchaudio
import unittest import unittest
import test.common_utils import test.common_utils
from torchaudio.common_utils import IMPORT_LIBROSA
if IMPORT_LIBROSA:
import numpy as np
import librosa
import pytest
import torchaudio.functional as F
xfail = pytest.mark.xfail
class TestFunctional(unittest.TestCase): class TestFunctional(unittest.TestCase):
data_sizes = [(2, 20), (3, 15), (4, 10)] data_sizes = [(2, 20), (3, 15), (4, 10)]
...@@ -183,5 +193,109 @@ class TestFunctional(unittest.TestCase): ...@@ -183,5 +193,109 @@ class TestFunctional(unittest.TestCase):
self._test_istft_of_sine(amplitude=99, L=10, n=7) self._test_istft_of_sine(amplitude=99, L=10, n=7)
def _num_stft_bins(signal_len, fft_len, hop_length, pad):
return (signal_len + 2 * pad - fft_len + hop_length) // hop_length
@pytest.mark.parametrize('fft_length', [512])
@pytest.mark.parametrize('hop_length', [256])
@pytest.mark.parametrize('waveform', [
    (torch.randn(1, 100000)),
    (torch.randn(1, 2, 100000)),
    # 100 samples is shorter than fft_length; presumably the reflect padding
    # inside torch.stft cannot pad past the signal length, hence the expected
    # RuntimeError — TODO confirm against torch.stft behavior.
    pytest.param(torch.randn(1, 100), marks=xfail(raises=RuntimeError)),
])
@pytest.mark.parametrize('pad_mode', [
    # 'constant',
    'reflect',
])
@unittest.skipIf(not IMPORT_LIBROSA, 'Librosa is not available')
def test_stft(waveform, fft_length, hop_length, pad_mode):
    """
    Test STFT for multi-channel signals: the output shape must extend the
    input's leading dims with (num_freqs, time, complex=2), and values must
    match librosa.stft applied along the last axis.

    Padding: Value in having padding outside of torch.stft?
    """
    pad = fft_length // 2
    window = torch.hann_window(fft_length)
    complex_spec = F.stft(waveform,
                          fft_length=fft_length,
                          hop_length=hop_length,
                          window=window,
                          pad_mode=pad_mode)
    mag_spec, phase_spec = F.magphase(complex_spec)

    # == Test shape
    # Leading dims (batch/channel) pass through; F.stft appends
    # (fft_length // 2 + 1 frequency bins, num frames, complex=2).
    expected_size = list(waveform.size()[:-1])
    expected_size += [fft_length // 2 + 1, _num_stft_bins(
        waveform.size(-1), fft_length, hop_length, pad), 2]

    assert complex_spec.dim() == waveform.dim() + 2
    assert complex_spec.size() == torch.Size(expected_size)

    # == Test values
    fft_config = dict(n_fft=fft_length, hop_length=hop_length, pad_mode=pad_mode)
    # note that librosa *automatically* pad with fft_length // 2.
    expected_complex_spec = np.apply_along_axis(librosa.stft, -1,
                                                waveform.numpy(), **fft_config)
    expected_mag_spec, _ = librosa.magphase(expected_complex_spec)
    # Convert torch's real/imag last-dim layout to np.complex for comparison.
    complex_spec = complex_spec.numpy()
    complex_spec = complex_spec[..., 0] + 1j * complex_spec[..., 1]

    assert np.allclose(complex_spec, expected_complex_spec, atol=1e-5)
    assert np.allclose(mag_spec.numpy(), expected_mag_spec, atol=1e-5)
@pytest.mark.parametrize('rate', [0.5, 1.01, 1.3])
@pytest.mark.parametrize('complex_specgrams', [
    torch.randn(1, 2, 1025, 400, 2),
    torch.randn(1, 1025, 400, 2)
])
@pytest.mark.parametrize('hop_length', [256])
@unittest.skipIf(not IMPORT_LIBROSA, 'Librosa is not available')
def test_phase_vocoder(complex_specgrams, rate, hop_length):
    """Compare F.phase_vocoder to librosa.phase_vocoder: output shape must
    have the time axis scaled to ceil(time / rate), and values on a single
    channel must match librosa's result."""
    # Due to the cumulative sum, numerical error when using torch.float32
    # would make the bottom-right values of the stretched spectrogram
    # disagree with librosa, so run the comparison in float64.
    complex_specgrams = complex_specgrams.type(torch.float64)
    phase_advance = torch.linspace(0, np.pi * hop_length, complex_specgrams.shape[-3], dtype=torch.float64)[..., None]

    complex_specgrams_stretch = F.phase_vocoder(complex_specgrams, rate=rate, phase_advance=phase_advance)

    # == Test shape: only the time axis (-2, before the complex dim) changes.
    expected_size = list(complex_specgrams.size())
    expected_size[-2] = int(np.ceil(expected_size[-2] / rate))

    assert complex_specgrams.dim() == complex_specgrams_stretch.dim()
    assert complex_specgrams_stretch.size() == torch.Size(expected_size)

    # == Test values
    # Select the first element of every leading dim, keeping the trailing
    # (num_freqs, time, complex) axes — one mono spectrogram for librosa.
    index = [0] * (complex_specgrams.dim() - 3) + [slice(None)] * 3
    mono_complex_specgram = complex_specgrams[index].numpy()
    mono_complex_specgram = mono_complex_specgram[..., 0] + \
        mono_complex_specgram[..., 1] * 1j
    expected_complex_stretch = librosa.phase_vocoder(mono_complex_specgram,
                                                     rate=rate,
                                                     hop_length=hop_length)

    complex_stretch = complex_specgrams_stretch[index].numpy()
    complex_stretch = complex_stretch[..., 0] + 1j * complex_stretch[..., 1]

    assert np.allclose(complex_stretch, expected_complex_stretch, atol=1e-5)
@pytest.mark.parametrize('complex_tensor', [
    torch.randn(1, 2, 1025, 400, 2),
    torch.randn(1025, 400, 2)
])
@pytest.mark.parametrize('power', [1, 2, 0.7])
def test_complex_norm(complex_tensor, power):
    """F.complex_norm on a (*, 2) tensor must equal |z| ** power computed
    directly from the real/imag components."""
    reference = complex_tensor.pow(2).sum(-1).pow(power / 2)
    computed = F.complex_norm(complex_tensor, power)
    assert torch.allclose(reference, computed, atol=1e-5)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -450,3 +450,161 @@ def mu_law_expanding(x_mu, qc): ...@@ -450,3 +450,161 @@ def mu_law_expanding(x_mu, qc):
x = ((x_mu) / mu) * 2 - 1. x = ((x_mu) / mu) * 2 - 1.
x = torch.sign(x) * (torch.exp(torch.abs(x) * torch.log1p(mu)) - 1.) / mu x = torch.sign(x) * (torch.exp(torch.abs(x) * torch.log1p(mu)) - 1.) / mu
return x return x
def stft(waveforms, fft_length, hop_length=None, win_length=None, window=None,
         center=True, pad_mode='reflect', normalized=False, onesided=True):
    """Compute a short time Fourier transform of the input waveform(s).

    It wraps `torch.stft` after reshaping the input audio to allow for
    `waveforms` that `.dim()` >= 3.
    It follows most of the `torch.stft` default values, but for `window`,
    which defaults to hann window.

    Args:
        waveforms (torch.Tensor): Audio signal of size `(*, channel, time)`
        fft_length (int): FFT size [sample].
        hop_length (int): Hop size [sample] between STFT frames.
            (Defaults to `fft_length // 4`, 75%-overlapping windows by `torch.stft`).
        win_length (int): Size of STFT window. (Defaults to `fft_length` by `torch.stft`).
        window (torch.Tensor): window function. (Defaults to Hann Window of size
            `win_length` *unlike* `torch.stft`).
        center (bool): Whether to pad `waveforms` on both sides so that the `t`-th
            frame is centered at time `t * hop_length`. (Defaults to `True` by `torch.stft`)
        pad_mode (str): padding method (see `torch.nn.functional.pad`).
            (Defaults to `'reflect'` by `torch.stft`).
        normalized (bool): Whether the results are normalized. (Defaults to `False` by `torch.stft`).
        onesided (bool): Whether the half + 1 frequency bins are returned to remove
            the symmetric part of STFT of real-valued signal.
            (Defaults to `True` by `torch.stft`).

    Returns:
        torch.Tensor: `(*, channel, num_freqs, time, complex=2)`

    Example:
        >>> waveforms = torch.randn(16, 2, 10000)  # (batch, channel, time)
        >>> x = stft(waveforms, 2048, 512)
        >>> x.shape
        torch.Size([16, 2, 1025, 20, 2])
    """
    leading_dims = waveforms.shape[:-1]
    # Collapse all leading dims into one batch dim: torch.stft expects
    # 1D or 2D (batch, time) input.
    waveforms = waveforms.reshape(-1, waveforms.size(-1))

    if window is None:
        # Default Hann window. Create it with the input's dtype/device:
        # torch.stft requires the window to match the signal.
        window = torch.hann_window(fft_length if win_length is None else win_length,
                                   dtype=waveforms.dtype,
                                   device=waveforms.device)

    # Recent PyTorch removed the implicit real-valued (*, 2) output of
    # torch.stft; request a complex result and view it as real pairs to
    # keep the original (*, num_freqs, time, complex=2) layout.
    complex_specgrams = torch.view_as_real(
        torch.stft(waveforms,
                   n_fft=fft_length,
                   hop_length=hop_length,
                   win_length=win_length,
                   window=window,
                   center=center,
                   pad_mode=pad_mode,
                   normalized=normalized,
                   onesided=onesided,
                   return_complex=True))

    # Restore the original leading dims in front of (num_freqs, time, 2).
    complex_specgrams = complex_specgrams.reshape(
        leading_dims +
        complex_specgrams.shape[1:])

    return complex_specgrams
def complex_norm(complex_tensor, power=1.0):
    """Compute the norm of a complex tensor input.

    Args:
        complex_tensor (Tensor): Tensor shape of `(*, complex=2)`
        power (float): Power of the norm. Defaults to `1.0`.

    Returns:
        Tensor: power of the normed input tensor, shape of `(*, )`
    """
    # L2 norm over the trailing (real, imag) pair gives the magnitude |z|.
    magnitude = torch.norm(complex_tensor, 2, -1)
    if power == 1.0:
        return magnitude
    return magnitude.pow(power)
def angle(complex_tensor):
    """
    Return the phase angle (in radians) of a complex tensor with shape (*, 2).
    """
    real = complex_tensor[..., 0]
    imag = complex_tensor[..., 1]
    return torch.atan2(imag, real)
def magphase(complex_tensor, power=1.):
    """
    Separate a complex-valued spectrogram with shape (*, 2) into its
    magnitude (raised to `power`) and phase components.
    """
    return complex_norm(complex_tensor, power), angle(complex_tensor)
def phase_vocoder(complex_specgrams, rate, phase_advance):
    """
    Phase vocoder. Given a STFT tensor, speed up in time
    without modifying pitch by a factor of `rate`.

    Args:
        complex_specgrams (Tensor):
            (*, channel, num_freqs, time, complex=2)
        rate (float): Speed-up factor (rate > 1 shortens the time axis).
        phase_advance (Tensor): Expected phase advance in
            each bin. (num_freqs, 1).

    Returns:
        complex_specgrams_stretch (Tensor):
            (*, channel, num_freqs, ceil(time/rate), complex=2).

    Example:
        >>> num_freqs, hop_length = 1025, 512
        >>> # (batch, channel, num_freqs, time, complex=2)
        >>> complex_specgrams = torch.randn(16, 1, num_freqs, 300, 2)
        >>> rate = 1.3  # Speed up by 30%
        >>> phase_advance = torch.linspace(
        >>>    0, math.pi * hop_length, num_freqs)[..., None]
        >>> x = phase_vocoder(complex_specgrams, rate, phase_advance)
        >>> x.shape  # with 231 == ceil(300 / 1.3)
        torch.Size([16, 1, 1025, 231, 2])
    """
    ndim = complex_specgrams.dim()
    # Full slices for every dim before the time axis (-2). Must be a tuple:
    # indexing with a Python *list* is deprecated advanced indexing in
    # recent PyTorch and no longer means "one index per dimension".
    time_slice = (slice(None),) * (ndim - 2)

    # Fractional source-frame positions for each output frame.
    time_steps = torch.arange(0,
                              complex_specgrams.size(-2),
                              rate,
                              device=complex_specgrams.device,
                              dtype=complex_specgrams.dtype)

    # Linear-interpolation weights between floor(t) and floor(t) + 1.
    alphas = time_steps % 1.

    # Phase of the first frame seeds the phase accumulator below.
    phase_0 = angle(complex_specgrams[time_slice + (slice(1),)])

    # Pad two frames at the end of the time axis so time_steps + 1 is in range.
    complex_specgrams = torch.nn.functional.pad(complex_specgrams, [0, 0, 0, 2])

    # The two frames bracketing each fractional time step: (*, num_freqs, new_time, 2).
    complex_specgrams_0 = complex_specgrams[time_slice + (time_steps.long(),)]
    complex_specgrams_1 = complex_specgrams[time_slice + ((time_steps + 1).long(),)]

    angle_0 = angle(complex_specgrams_0)
    angle_1 = angle(complex_specgrams_1)

    norm_0 = torch.norm(complex_specgrams_0, dim=-1)
    norm_1 = torch.norm(complex_specgrams_1, dim=-1)

    # Instantaneous phase deviation from the expected per-hop advance,
    # wrapped to within +/- pi.
    phase = angle_1 - angle_0 - phase_advance
    phase = phase - 2 * math.pi * torch.round(phase / (2 * math.pi))

    # Compute Phase Accum: re-add the expected advance and integrate over time.
    phase = phase + phase_advance
    phase = torch.cat([phase_0, phase[time_slice + (slice(-1),)]], dim=-1)
    phase_acc = torch.cumsum(phase, -1)

    # Linearly interpolate magnitudes, then rebuild real/imag components.
    mag = alphas * norm_1 + (1 - alphas) * norm_0

    real_stretch = mag * torch.cos(phase_acc)
    imag_stretch = mag * torch.sin(phase_acc)

    complex_specgrams_stretch = torch.stack([real_stretch, imag_stretch], dim=-1)

    return complex_specgrams_stretch
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment