Unverified Commit 0433b7aa authored by moto's avatar moto Committed by GitHub
Browse files

Make `F.phase_vocoder` and `T.TimeStretch` handle complex dtype (#1410)

1. `F.phase_vocoder` accepts Tensor with complex dtype.
    * The implementation path has been updated from #758 so that they share the same code path by internally converting the input Tensor to complex dtype and performing all the operation on top of it.
    * Adopted `torch.polar` for simpler Tensor generation from magnitude and angle.
2. Updated tests
    * librosa compatibility test for complex dtype and pseudo complex dtype
        * Extracted the output shape check test and moved it to functional so that it will be tested on all the combination of `{CPU | CUDA} x {complex64 | complex128}`
    * TorchScript compatibility test for `F.phase_vocoder` and `T.TimeStretch`.
    * batch consistency test for `T.TimeStretch`.
parent a6cdd6c7
......@@ -14,7 +14,7 @@ from torchaudio_unittest.common_utils import (
skipIfNoSox,
)
from .functional_impl import Lfilter, Spectrogram
from .functional_impl import Lfilter, Spectrogram, FunctionalComplex
class TestLFilterFloat32(Lfilter, common_utils.PytorchTestCase):
......@@ -41,6 +41,18 @@ class TestSpectrogramFloat64(Spectrogram, common_utils.PytorchTestCase):
device = torch.device('cpu')
class TestFunctionalComplex64(FunctionalComplex, common_utils.PytorchTestCase):
complex_dtype = torch.complex64
real_dtype = torch.float32
device = torch.device('cpu')
class TestFunctionalComplex128(FunctionalComplex, common_utils.PytorchTestCase):
complex_dtype = torch.complex128
real_dtype = torch.float64
device = torch.device('cpu')
class TestCreateFBMatrix(common_utils.TorchaudioTestCase):
def test_no_warning_high_n_freq(self):
with warnings.catch_warnings(record=True) as w:
......
......@@ -2,7 +2,7 @@ import torch
import unittest
from torchaudio_unittest import common_utils
from .functional_impl import Lfilter, Spectrogram
from .functional_impl import Lfilter, Spectrogram, FunctionalComplex
@common_utils.skipIfNoCuda
......@@ -31,3 +31,17 @@ class TestSpectrogramFloat32(Spectrogram, common_utils.PytorchTestCase):
class TestSpectrogramFloat64(Spectrogram, common_utils.PytorchTestCase):
dtype = torch.float64
device = torch.device('cuda')
@common_utils.skipIfNoCuda
class TestFunctionalComplex64(FunctionalComplex, common_utils.PytorchTestCase):
complex_dtype = torch.complex64
real_dtype = torch.float32
device = torch.device('cuda')
@common_utils.skipIfNoCuda
class TestFunctionalComplex128(FunctionalComplex, common_utils.PytorchTestCase):
complex_dtype = torch.complex64
real_dtype = torch.float32
device = torch.device('cuda')
......@@ -2,9 +2,11 @@
import torch
import torchaudio.functional as F
from parameterized import parameterized
import numpy as np
from scipy import signal
from torchaudio_unittest import common_utils
from torchaudio_unittest.common_utils import nested_params
class Lfilter(common_utils.TestBaseMixin):
......@@ -89,3 +91,39 @@ class Spectrogram(common_utils.TestBaseMixin):
)
spec.sum().backward()
assert not x.grad.isnan().sum()
class FunctionalComplex(common_utils.TestBaseMixin):
complex_dtype = None
real_dtype = None
device = None
@nested_params(
[0.5, 1.01, 1.3],
[True, False],
)
def test_phase_vocoder_shape(self, rate, test_pseudo_complex):
"""Verify the output shape of phase vocoder"""
hop_length = 256
num_freq = 1025
num_frames = 400
batch_size = 2
torch.random.manual_seed(42)
spec = torch.randn(
batch_size, num_freq, num_frames, dtype=self.complex_dtype, device=self.device)
if test_pseudo_complex:
spec = torch.view_as_real(spec)
phase_advance = torch.linspace(
0,
np.pi * hop_length,
num_freq,
dtype=self.real_dtype, device=self.device)[..., None]
spec_stretch = F.phase_vocoder(spec, rate=rate, phase_advance=phase_advance)
assert spec.dim() == spec_stretch.dim()
expected_shape = torch.Size([batch_size, num_freq, int(np.ceil(num_frames / rate))])
output_shape = (torch.view_as_complex(spec_stretch) if test_pseudo_complex else spec_stretch).shape
assert output_shape == expected_shape
import itertools
import unittest
from distutils.version import StrictVersion
......@@ -15,6 +14,9 @@ if LIBROSA_AVAILABLE:
import librosa
from torchaudio_unittest import common_utils
from torchaudio_unittest.common_utils import (
nested_params,
)
@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
......@@ -130,45 +132,36 @@ class TestFunctional(common_utils.TorchaudioTestCase):
@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
class TestPhaseVocoder(common_utils.TorchaudioTestCase):
@parameterized.expand(list(itertools.product(
[(2, 1025, 400, 2)],
class TestFunctionalComplex(common_utils.TorchaudioTestCase):
@nested_params(
[0.5, 1.01, 1.3],
[256]
)))
def test_phase_vocoder(self, shape, rate, hop_length):
[True, False],
)
def test_phase_vocoder(self, rate, test_pseudo_complex):
hop_length = 256
num_freq = 1025
num_frames = 400
torch.random.manual_seed(42)
# Due to cummulative sum, numerical error in using torch.float32 will
# result in bottom right values of the stretched sectrogram to not
# match with librosa.
torch.random.manual_seed(42)
complex_specgrams = torch.randn(*shape)
complex_specgrams = complex_specgrams.type(torch.float64)
spec = torch.randn(num_freq, num_frames, dtype=torch.complex128)
phase_advance = torch.linspace(
0,
np.pi * hop_length,
complex_specgrams.shape[-3],
num_freq,
dtype=torch.float64)[..., None]
complex_specgrams_stretch = F.phase_vocoder(complex_specgrams, rate=rate, phase_advance=phase_advance)
# == Test shape
expected_size = list(complex_specgrams.size())
expected_size[-2] = int(np.ceil(expected_size[-2] / rate))
stretched = F.phase_vocoder(
torch.view_as_real(spec) if test_pseudo_complex else spec,
rate=rate, phase_advance=phase_advance)
assert complex_specgrams.dim() == complex_specgrams_stretch.dim()
assert complex_specgrams_stretch.size() == torch.Size(expected_size)
# == Test values
index = [0] * (complex_specgrams.dim() - 3) + [slice(None)] * 3
mono_complex_specgram = complex_specgrams[index].numpy()
mono_complex_specgram = mono_complex_specgram[..., 0] + \
mono_complex_specgram[..., 1] * 1j
expected_complex_stretch = librosa.phase_vocoder(
mono_complex_specgram,
expected_stretched = librosa.phase_vocoder(
spec.numpy(),
rate=rate,
hop_length=hop_length)
complex_stretch = complex_specgrams_stretch[index].numpy()
complex_stretch = complex_stretch[..., 0] + 1j * complex_stretch[..., 1]
self.assertEqual(complex_stretch, torch.from_numpy(expected_complex_stretch), atol=1e-5, rtol=1e-5)
self.assertEqual(
torch.view_as_complex(stretched) if test_pseudo_complex else stretched,
torch.from_numpy(expected_stretched))
import torch
from torchaudio_unittest.common_utils import PytorchTestCase
from .torchscript_consistency_impl import Functional
from .torchscript_consistency_impl import Functional, FunctionalComplex
class TestFunctionalFloat32(Functional, PytorchTestCase):
......@@ -12,3 +12,15 @@ class TestFunctionalFloat32(Functional, PytorchTestCase):
class TestFunctionalFloat64(Functional, PytorchTestCase):
dtype = torch.float64
device = torch.device('cpu')
class TestFunctionalComplex64(FunctionalComplex, PytorchTestCase):
complex_dtype = torch.complex64
real_dtype = torch.float32
device = torch.device('cpu')
class TestFunctionalComplex128(FunctionalComplex, PytorchTestCase):
complex_dtype = torch.complex128
real_dtype = torch.float64
device = torch.device('cpu')
import torch
from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase
from .torchscript_consistency_impl import Functional
from .torchscript_consistency_impl import Functional, FunctionalComplex
@skipIfNoCuda
......@@ -14,3 +14,17 @@ class TestFunctionalFloat32(Functional, PytorchTestCase):
class TestFunctionalFloat64(Functional, PytorchTestCase):
dtype = torch.float64
device = torch.device('cuda')
@skipIfNoCuda
class TestFunctionalComplex64(FunctionalComplex, PytorchTestCase):
complex_dtype = torch.complex64
real_dtype = torch.float32
device = torch.device('cuda')
@skipIfNoCuda
class TestFunctionalComplex128(FunctionalComplex, PytorchTestCase):
complex_dtype = torch.complex128
real_dtype = torch.float64
device = torch.device('cuda')
......@@ -3,6 +3,7 @@ import unittest
import torch
import torchaudio.functional as F
from parameterized import parameterized
from torchaudio_unittest import common_utils
from torchaudio_unittest.common_utils import (
......@@ -551,21 +552,6 @@ class Functional(common_utils.TestBaseMixin):
tensor = common_utils.get_whitenoise(sample_rate=44100)
self._assert_consistency(func, tensor)
def test_phase_vocoder(self):
def func(tensor, device: torch.device = self.device):
rate = 0.5
hop_length = 256
phase_advance = torch.linspace(
0,
3.14 * hop_length,
tensor.shape[-3],
dtype=torch.float64,
).to(device)[..., None]
return F.phase_vocoder(tensor, rate, phase_advance)
tensor = torch.randn(2, 1025, 400, 2)
self._assert_consistency(func, tensor)
@common_utils.skipIfNoKaldi
def test_compute_kaldi_pitch(self):
if self.dtype != torch.float32 or self.device != torch.device('cpu'):
......@@ -577,3 +563,40 @@ class Functional(common_utils.TestBaseMixin):
tensor = common_utils.get_whitenoise(sample_rate=44100)
self._assert_consistency(func, tensor)
class FunctionalComplex:
complex_dtype = None
real_dtype = None
device = None
def _assert_consistency(self, func, tensor, test_pseudo_complex=False):
assert tensor.is_complex()
tensor = tensor.to(device=self.device, dtype=self.complex_dtype)
ts_func = torch.jit.script(func)
if test_pseudo_complex:
tensor = torch.view_as_real(tensor)
output = func(tensor)
ts_output = ts_func(tensor)
self.assertEqual(ts_output, output)
@parameterized.expand([(True, ), (False, )])
def test_phase_vocoder(self, test_paseudo_complex):
def func(tensor):
is_complex = tensor.is_complex()
n_freq = tensor.size(-2 if is_complex else -3)
rate = 0.5
hop_length = 256
phase_advance = torch.linspace(
0,
3.14 * hop_length,
n_freq,
dtype=(torch.real(tensor) if is_complex else tensor).dtype,
device=tensor.device,
)[..., None]
return F.phase_vocoder(tensor, rate, phase_advance)
tensor = torch.view_as_complex(torch.randn(2, 1025, 400, 2))
self._assert_consistency(func, tensor, test_paseudo_complex)
"""Test numerical consistency among single input and batched input."""
import torch
import torchaudio
from parameterized import parameterized
from torchaudio_unittest import common_utils
......@@ -130,40 +131,31 @@ class TestTransforms(common_utils.TorchaudioTestCase):
computed = torchaudio.transforms.MFCC()(waveform.repeat(3, 1, 1))
self.assertEqual(computed, expected, atol=1e-4, rtol=1e-5)
def test_batch_TimeStretch(self):
test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
waveform, _ = torchaudio.load(test_filepath) # (2, 278756), 44100
@parameterized.expand([(True, ), (False, )])
def test_batch_TimeStretch(self, test_pseudo_complex):
rate = 2
num_freq = 1025
num_frames = 400
complex_specgrams = torch.view_as_real(
torch.stft(
input=waveform,
n_fft=2048,
hop_length=512,
win_length=2048,
window=torch.hann_window(2048),
center=True,
pad_mode='reflect',
normalized=True,
onesided=True,
return_complex=True,
)
)
spec = torch.randn(num_freq, num_frames, dtype=torch.complex64)
pattern = [3, 1, 1, 1]
if test_pseudo_complex:
spec = torch.view_as_real(spec)
pattern += [1]
# Single then transform then batch
expected = torchaudio.transforms.TimeStretch(
fixed_rate=rate,
n_freq=1025,
n_freq=num_freq,
hop_length=512,
)(complex_specgrams).repeat(3, 1, 1, 1, 1)
)(spec).repeat(*pattern)
# Batch then transform
computed = torchaudio.transforms.TimeStretch(
fixed_rate=rate,
n_freq=1025,
n_freq=num_freq,
hop_length=512,
)(complex_specgrams.repeat(3, 1, 1, 1, 1))
)(spec.repeat(*pattern))
self.assertEqual(computed, expected, atol=1e-5, rtol=1e-5)
......
import torch
from torchaudio_unittest.common_utils import PytorchTestCase
from .torchscript_consistency_impl import Transforms
from .torchscript_consistency_impl import Transforms, TransformsComplex
class TestTransformsFloat32(Transforms, PytorchTestCase):
......@@ -12,3 +12,15 @@ class TestTransformsFloat32(Transforms, PytorchTestCase):
class TestTransformsFloat64(Transforms, PytorchTestCase):
dtype = torch.float64
device = torch.device('cpu')
class TestTransformsComplex64(TransformsComplex, PytorchTestCase):
complex_dtype = torch.complex64
real_dtype = torch.float32
device = torch.device('cpu')
class TestTransformsComplex128(TransformsComplex, PytorchTestCase):
complex_dtype = torch.complex128
real_dtype = torch.float64
device = torch.device('cpu')
import torch
from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase
from .torchscript_consistency_impl import Transforms
from .torchscript_consistency_impl import Transforms, TransformsComplex
@skipIfNoCuda
......@@ -14,3 +14,17 @@ class TestTransformsFloat32(Transforms, PytorchTestCase):
class TestTransformsFloat64(Transforms, PytorchTestCase):
dtype = torch.float64
device = torch.device('cuda')
@skipIfNoCuda
class TestTransformsComplex64(TransformsComplex, PytorchTestCase):
complex_dtype = torch.complex64
real_dtype = torch.float32
device = torch.device('cuda')
@skipIfNoCuda
class TestTransformsComplex128(TransformsComplex, PytorchTestCase):
complex_dtype = torch.complex128
real_dtype = torch.float64
device = torch.device('cuda')
......@@ -2,6 +2,7 @@
import torch
import torchaudio.transforms as T
from parameterized import parameterized
from torchaudio_unittest import common_utils
from torchaudio_unittest.common_utils import (
......@@ -62,16 +63,6 @@ class Transforms(common_utils.TestBaseMixin):
tensor = torch.rand((1, 10))
self._assert_consistency(T.MuLawDecoding(), tensor)
def test_TimeStretch(self):
n_freq = 400
hop_length = 512
fixed_rate = 1.3
tensor = torch.rand((10, 2, n_freq, 10, 2))
self._assert_consistency(
T.TimeStretch(n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate),
tensor,
)
def test_Fade(self):
waveform = common_utils.get_whitenoise()
fade_in_len = 3000
......@@ -103,3 +94,34 @@ class Transforms(common_utils.TestBaseMixin):
sample_rate = 44100
waveform = common_utils.get_whitenoise(sample_rate=sample_rate)
self._assert_consistency(T.SpectralCentroid(sample_rate=sample_rate), waveform)
class TransformsComplex:
complex_dtype = None
real_dtype = None
device = None
def _assert_consistency(self, transform, tensor, test_pseudo_complex=False):
assert tensor.is_complex()
tensor = tensor.to(device=self.device, dtype=self.complex_dtype)
transform = transform.to(device=self.device, dtype=self.real_dtype)
ts_transform = torch.jit.script(transform)
if test_pseudo_complex:
tensor = torch.view_as_real(tensor)
output = transform(tensor)
ts_output = ts_transform(tensor)
self.assertEqual(ts_output, output)
@parameterized.expand([(True, ), (False, )])
def test_TimeStretch(self, test_pseudo_complex):
n_freq = 400
hop_length = 512
fixed_rate = 1.3
tensor = torch.view_as_complex(torch.rand((10, 2, n_freq, 10, 2)))
self._assert_consistency(
T.TimeStretch(n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate),
tensor,
test_pseudo_complex
)
......@@ -565,14 +565,29 @@ def phase_vocoder(
factor of ``rate``.
Args:
complex_specgrams (Tensor): Dimension of `(..., freq, time, complex=2)`
complex_specgrams (Tensor):
Either a real tensor of dimension of ``(..., freq, num_frame, complex=2)``
or a tensor of dimension ``(..., freq, num_frame)`` with complex dtype.
rate (float): Speed-up factor
phase_advance (Tensor): Expected phase advance in each bin. Dimension of (freq, 1)
Returns:
Tensor: Complex Specgrams Stretch with dimension of `(..., freq, ceil(time/rate), complex=2)`
Tensor:
Stretched spectrogram. The resulting tensor is of the same dtype as the input
spectrogram, but the number of frames is changed to ``ceil(num_frame / rate)``.
Example
Example - With Tensor of complex dtype
>>> freq, hop_length = 1025, 512
>>> # (channel, freq, time)
>>> complex_specgrams = torch.randn(2, freq, 300, dtype=torch.cfloat)
>>> rate = 1.3 # Speed up by 30%
>>> phase_advance = torch.linspace(
>>> 0, math.pi * hop_length, freq)[..., None]
>>> x = phase_vocoder(complex_specgrams, rate, phase_advance)
>>> x.shape # with 231 == ceil(300 / 1.3)
torch.Size([2, 1025, 231])
Example - With Tensor of real dtype and extra dimension for complex field
>>> freq, hop_length = 1025, 512
>>> # (channel, freq, time, complex=2)
>>> complex_specgrams = torch.randn(2, freq, 300, 2)
......@@ -583,32 +598,48 @@ def phase_vocoder(
>>> x.shape # with 231 == ceil(300 / 1.3)
torch.Size([2, 1025, 231, 2])
"""
if rate == 1.0:
return complex_specgrams
if not complex_specgrams.is_complex() and complex_specgrams.size(-1) != 2:
raise ValueError(
"complex_specgrams must be either native complex tensors or "
"real valued tensors with shape (..., 2)")
is_complex = complex_specgrams.is_complex()
if not is_complex:
complex_specgrams = torch.view_as_complex(complex_specgrams)
# pack batch
shape = complex_specgrams.size()
complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-3:]))
time_steps = torch.arange(0,
complex_specgrams.size(-2),
complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-2:]))
# Figures out the corresponding real dtype, i.e. complex128 -> float64, complex64 -> float32
# Note torch.real is a view so it does not incur any memory copy.
real_dtype = torch.real(complex_specgrams).dtype
time_steps = torch.arange(
0,
complex_specgrams.size(-1),
rate,
device=complex_specgrams.device,
dtype=complex_specgrams.dtype)
dtype=real_dtype)
alphas = time_steps % 1.0
phase_0 = angle(complex_specgrams[..., :1, :])
phase_0 = complex_specgrams[..., :1].angle()
# Time Padding
complex_specgrams = torch.nn.functional.pad(complex_specgrams, [0, 0, 0, 2])
complex_specgrams = torch.nn.functional.pad(complex_specgrams, [0, 2])
# (new_bins, freq, 2)
complex_specgrams_0 = complex_specgrams.index_select(-2, time_steps.long())
complex_specgrams_1 = complex_specgrams.index_select(-2, (time_steps + 1).long())
complex_specgrams_0 = complex_specgrams.index_select(-1, time_steps.long())
complex_specgrams_1 = complex_specgrams.index_select(-1, (time_steps + 1).long())
angle_0 = angle(complex_specgrams_0)
angle_1 = angle(complex_specgrams_1)
angle_0 = complex_specgrams_0.angle()
angle_1 = complex_specgrams_1.angle()
norm_0 = torch.norm(complex_specgrams_0, p=2, dim=-1)
norm_1 = torch.norm(complex_specgrams_1, p=2, dim=-1)
norm_0 = complex_specgrams_0.abs()
norm_1 = complex_specgrams_1.abs()
phase = angle_1 - angle_0 - phase_advance
phase = phase - 2 * math.pi * torch.round(phase / (2 * math.pi))
......@@ -620,14 +651,13 @@ def phase_vocoder(
mag = alphas * norm_1 + (1 - alphas) * norm_0
real_stretch = mag * torch.cos(phase_acc)
imag_stretch = mag * torch.sin(phase_acc)
complex_specgrams_stretch = torch.stack([real_stretch, imag_stretch], dim=-1)
complex_specgrams_stretch = torch.polar(mag, phase_acc)
# unpack batch
complex_specgrams_stretch = complex_specgrams_stretch.reshape(shape[:-3] + complex_specgrams_stretch.shape[1:])
complex_specgrams_stretch = complex_specgrams_stretch.reshape(shape[:-2] + complex_specgrams_stretch.shape[1:])
if not is_complex:
return torch.view_as_real(complex_specgrams_stretch)
return complex_specgrams_stretch
......
......@@ -729,26 +729,24 @@ class TimeStretch(torch.nn.Module):
def forward(self, complex_specgrams: Tensor, overriding_rate: Optional[float] = None) -> Tensor:
r"""
Args:
complex_specgrams (Tensor): complex spectrogram (..., freq, time, complex=2).
complex_specgrams (Tensor):
Either a real tensor of dimension of ``(..., freq, num_frame, complex=2)``
or a tensor of dimension ``(..., freq, num_frame)`` with complex dtype.
overriding_rate (float or None, optional): speed up to apply to this batch.
If no rate is passed, use ``self.fixed_rate``. (Default: ``None``)
Returns:
Tensor: Stretched complex spectrogram of dimension (..., freq, ceil(time/rate), complex=2).
Tensor:
Stretched spectrogram. The resulting tensor is of the same dtype as the input
spectrogram, but the number of frames is changed to ``ceil(num_frame / rate)``.
"""
assert complex_specgrams.size(-1) == 2, "complex_specgrams should be a complex tensor, shape (..., complex=2)"
if overriding_rate is None:
if self.fixed_rate is None:
raise ValueError(
"If no fixed_rate is specified, must pass a valid rate to the forward method.")
rate = self.fixed_rate
if rate is None:
raise ValueError("If no fixed_rate is specified"
", must pass a valid rate to the forward method.")
else:
rate = overriding_rate
if rate == 1.0:
return complex_specgrams
return F.phase_vocoder(complex_specgrams, rate, self.phase_advance)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment