Unverified commit 0433b7aa, authored by moto and committed by GitHub

Make `F.phase_vocoder` and `T.TimeStretch` handle complex dtype (#1410)

1. `F.phase_vocoder` accepts Tensor with complex dtype.
    * The implementation from #758 has been updated so that real (pseudo-complex) and native complex inputs share the same code path: the input Tensor is internally converted to complex dtype and all the operations are performed on it.
    * Adopted `torch.polar` for simpler Tensor generation from magnitude and angle.
2. Updated tests
    * librosa compatibility tests for complex dtype and pseudo-complex dtype
        * Extracted the output-shape check and moved it to the functional test suite so that it runs on every combination of `{CPU | CUDA} x {complex64 | complex128}`
    * TorchScript compatibility test for `F.phase_vocoder` and `T.TimeStretch`.
    * batch consistency test for `T.TimeStretch`.
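
In practice the new surface looks like the following (a minimal sketch, not part of the diff below; it assumes a torchaudio build that includes this change):

```python
import math
import torch
import torchaudio.functional as F

freq, hop_length = 1025, 512
spec = torch.randn(2, freq, 300, dtype=torch.cfloat)  # (channel, freq, time), native complex
phase_advance = torch.linspace(0, math.pi * hop_length, freq)[..., None]

# Native complex input -> native complex output.
stretched = F.phase_vocoder(spec, rate=1.3, phase_advance=phase_advance)
print(stretched.shape)  # torch.Size([2, 1025, 231]), 231 == ceil(300 / 1.3)

# The pseudo-complex layout (..., freq, time, 2) keeps working and goes
# through the same complex code path internally.
stretched_pseudo = F.phase_vocoder(
    torch.view_as_real(spec), rate=1.3, phase_advance=phase_advance)
print(stretched_pseudo.shape)  # torch.Size([2, 1025, 231, 2])
```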
parent a6cdd6c7
@@ -14,7 +14,7 @@ from torchaudio_unittest.common_utils import (
     skipIfNoSox,
 )
-from .functional_impl import Lfilter, Spectrogram
+from .functional_impl import Lfilter, Spectrogram, FunctionalComplex


 class TestLFilterFloat32(Lfilter, common_utils.PytorchTestCase):
@@ -41,6 +41,18 @@ class TestSpectrogramFloat64(Spectrogram, common_utils.PytorchTestCase):
     device = torch.device('cpu')


+class TestFunctionalComplex64(FunctionalComplex, common_utils.PytorchTestCase):
+    complex_dtype = torch.complex64
+    real_dtype = torch.float32
+    device = torch.device('cpu')
+
+
+class TestFunctionalComplex128(FunctionalComplex, common_utils.PytorchTestCase):
+    complex_dtype = torch.complex128
+    real_dtype = torch.float64
+    device = torch.device('cpu')
+
+
 class TestCreateFBMatrix(common_utils.TorchaudioTestCase):
     def test_no_warning_high_n_freq(self):
         with warnings.catch_warnings(record=True) as w:
...
@@ -2,7 +2,7 @@ import torch
 import unittest

 from torchaudio_unittest import common_utils
-from .functional_impl import Lfilter, Spectrogram
+from .functional_impl import Lfilter, Spectrogram, FunctionalComplex


 @common_utils.skipIfNoCuda
@@ -31,3 +31,17 @@ class TestSpectrogramFloat32(Spectrogram, common_utils.PytorchTestCase):
 class TestSpectrogramFloat64(Spectrogram, common_utils.PytorchTestCase):
     dtype = torch.float64
     device = torch.device('cuda')
+
+
+@common_utils.skipIfNoCuda
+class TestFunctionalComplex64(FunctionalComplex, common_utils.PytorchTestCase):
+    complex_dtype = torch.complex64
+    real_dtype = torch.float32
+    device = torch.device('cuda')
+
+
+@common_utils.skipIfNoCuda
+class TestFunctionalComplex128(FunctionalComplex, common_utils.PytorchTestCase):
+    complex_dtype = torch.complex128
+    real_dtype = torch.float64
+    device = torch.device('cuda')
...
@@ -2,9 +2,11 @@
 import torch
 import torchaudio.functional as F
 from parameterized import parameterized
+import numpy as np
 from scipy import signal

 from torchaudio_unittest import common_utils
+from torchaudio_unittest.common_utils import nested_params


 class Lfilter(common_utils.TestBaseMixin):
@@ -89,3 +91,39 @@ class Spectrogram(common_utils.TestBaseMixin):
         )
         spec.sum().backward()
         assert not x.grad.isnan().sum()
+
+
+class FunctionalComplex(common_utils.TestBaseMixin):
+    complex_dtype = None
+    real_dtype = None
+    device = None
+
+    @nested_params(
+        [0.5, 1.01, 1.3],
+        [True, False],
+    )
+    def test_phase_vocoder_shape(self, rate, test_pseudo_complex):
+        """Verify the output shape of phase vocoder"""
+        hop_length = 256
+        num_freq = 1025
+        num_frames = 400
+        batch_size = 2
+
+        torch.random.manual_seed(42)
+        spec = torch.randn(
+            batch_size, num_freq, num_frames, dtype=self.complex_dtype, device=self.device)
+        if test_pseudo_complex:
+            spec = torch.view_as_real(spec)
+
+        phase_advance = torch.linspace(
+            0,
+            np.pi * hop_length,
+            num_freq,
+            dtype=self.real_dtype, device=self.device)[..., None]
+
+        spec_stretch = F.phase_vocoder(spec, rate=rate, phase_advance=phase_advance)
+        assert spec.dim() == spec_stretch.dim()
+        expected_shape = torch.Size([batch_size, num_freq, int(np.ceil(num_frames / rate))])
+        output_shape = (torch.view_as_complex(spec_stretch) if test_pseudo_complex else spec_stretch).shape
+        assert output_shape == expected_shape
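
The `nested_params` decorator used in the shape test above is a torchaudio test helper; roughly, it expands a test over the cartesian product of the given parameter lists. A hypothetical equivalent in terms of `parameterized` (an illustration of the behavior, not torchaudio's actual implementation):

```python
import itertools
from parameterized import parameterized

def nested_params_sketch(*param_lists):
    # Run the decorated test once per element of the cartesian product,
    # e.g. ([0.5, 1.01, 1.3], [True, False]) -> 6 test cases.
    return parameterized.expand(list(itertools.product(*param_lists)))
```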
...
-import itertools
 import unittest
 from distutils.version import StrictVersion
@@ -15,6 +14,9 @@ if LIBROSA_AVAILABLE:
     import librosa

 from torchaudio_unittest import common_utils
+from torchaudio_unittest.common_utils import (
+    nested_params,
+)


 @unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
@@ -130,45 +132,36 @@ class TestFunctional(common_utils.TorchaudioTestCase):

 @unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
-class TestPhaseVocoder(common_utils.TorchaudioTestCase):
-    @parameterized.expand(list(itertools.product(
-        [(2, 1025, 400, 2)],
-        [0.5, 1.01, 1.3],
-        [256]
-    )))
-    def test_phase_vocoder(self, shape, rate, hop_length):
+class TestFunctionalComplex(common_utils.TorchaudioTestCase):
+    @nested_params(
+        [0.5, 1.01, 1.3],
+        [True, False],
+    )
+    def test_phase_vocoder(self, rate, test_pseudo_complex):
+        hop_length = 256
+        num_freq = 1025
+        num_frames = 400
+        torch.random.manual_seed(42)
         # Due to cumulative sum, numerical error in using torch.float32 will
         # result in bottom right values of the stretched spectrogram to not
         # match with librosa.
-        torch.random.manual_seed(42)
-        complex_specgrams = torch.randn(*shape)
-        complex_specgrams = complex_specgrams.type(torch.float64)
+        spec = torch.randn(num_freq, num_frames, dtype=torch.complex128)
         phase_advance = torch.linspace(
             0,
             np.pi * hop_length,
-            complex_specgrams.shape[-3],
+            num_freq,
             dtype=torch.float64)[..., None]

-        complex_specgrams_stretch = F.phase_vocoder(complex_specgrams, rate=rate, phase_advance=phase_advance)
-
-        # == Test shape
-        expected_size = list(complex_specgrams.size())
-        expected_size[-2] = int(np.ceil(expected_size[-2] / rate))
-
-        assert complex_specgrams.dim() == complex_specgrams_stretch.dim()
-        assert complex_specgrams_stretch.size() == torch.Size(expected_size)
-
-        # == Test values
-        index = [0] * (complex_specgrams.dim() - 3) + [slice(None)] * 3
-        mono_complex_specgram = complex_specgrams[index].numpy()
-        mono_complex_specgram = mono_complex_specgram[..., 0] + \
-            mono_complex_specgram[..., 1] * 1j
-        expected_complex_stretch = librosa.phase_vocoder(
-            mono_complex_specgram,
+        stretched = F.phase_vocoder(
+            torch.view_as_real(spec) if test_pseudo_complex else spec,
+            rate=rate, phase_advance=phase_advance)

+        expected_stretched = librosa.phase_vocoder(
+            spec.numpy(),
             rate=rate,
             hop_length=hop_length)
-        complex_stretch = complex_specgrams_stretch[index].numpy()
-        complex_stretch = complex_stretch[..., 0] + 1j * complex_stretch[..., 1]
-
-        self.assertEqual(complex_stretch, torch.from_numpy(expected_complex_stretch), atol=1e-5, rtol=1e-5)
+        self.assertEqual(
+            torch.view_as_complex(stretched) if test_pseudo_complex else stretched,
+            torch.from_numpy(expected_stretched))
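
Both branches of the librosa test above rely on the real/complex view pair that the pseudo-complex path is built on. A minimal sketch of that round trip (plain PyTorch, nothing specific to this commit):

```python
import torch

spec = torch.randn(1025, 400, dtype=torch.complex128)

# Pseudo-complex: a real tensor whose last dimension holds (real, imag).
pseudo = torch.view_as_real(spec)           # (1025, 400, 2), float64
roundtrip = torch.view_as_complex(pseudo)   # (1025, 400), complex128

# Both are views of the same storage, so the round trip is exact and free.
assert torch.equal(roundtrip, spec)
```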
...
 import torch

 from torchaudio_unittest.common_utils import PytorchTestCase
-from .torchscript_consistency_impl import Functional
+from .torchscript_consistency_impl import Functional, FunctionalComplex


 class TestFunctionalFloat32(Functional, PytorchTestCase):
@@ -12,3 +12,15 @@ class TestFunctionalFloat32(Functional, PytorchTestCase):
 class TestFunctionalFloat64(Functional, PytorchTestCase):
     dtype = torch.float64
     device = torch.device('cpu')
+
+
+class TestFunctionalComplex64(FunctionalComplex, PytorchTestCase):
+    complex_dtype = torch.complex64
+    real_dtype = torch.float32
+    device = torch.device('cpu')
+
+
+class TestFunctionalComplex128(FunctionalComplex, PytorchTestCase):
+    complex_dtype = torch.complex128
+    real_dtype = torch.float64
+    device = torch.device('cpu')
...
 import torch

 from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase
-from .torchscript_consistency_impl import Functional
+from .torchscript_consistency_impl import Functional, FunctionalComplex


 @skipIfNoCuda
@@ -14,3 +14,17 @@ class TestFunctionalFloat32(Functional, PytorchTestCase):
 class TestFunctionalFloat64(Functional, PytorchTestCase):
     dtype = torch.float64
     device = torch.device('cuda')
+
+
+@skipIfNoCuda
+class TestFunctionalComplex64(FunctionalComplex, PytorchTestCase):
+    complex_dtype = torch.complex64
+    real_dtype = torch.float32
+    device = torch.device('cuda')
+
+
+@skipIfNoCuda
+class TestFunctionalComplex128(FunctionalComplex, PytorchTestCase):
+    complex_dtype = torch.complex128
+    real_dtype = torch.float64
+    device = torch.device('cuda')
...
@@ -3,6 +3,7 @@ import unittest

 import torch
 import torchaudio.functional as F
+from parameterized import parameterized

 from torchaudio_unittest import common_utils
 from torchaudio_unittest.common_utils import (
@@ -551,21 +552,6 @@ class Functional(common_utils.TestBaseMixin):
         tensor = common_utils.get_whitenoise(sample_rate=44100)
         self._assert_consistency(func, tensor)

-    def test_phase_vocoder(self):
-        def func(tensor, device: torch.device = self.device):
-            rate = 0.5
-            hop_length = 256
-            phase_advance = torch.linspace(
-                0,
-                3.14 * hop_length,
-                tensor.shape[-3],
-                dtype=torch.float64,
-            ).to(device)[..., None]
-            return F.phase_vocoder(tensor, rate, phase_advance)
-
-        tensor = torch.randn(2, 1025, 400, 2)
-        self._assert_consistency(func, tensor)
-
     @common_utils.skipIfNoKaldi
     def test_compute_kaldi_pitch(self):
         if self.dtype != torch.float32 or self.device != torch.device('cpu'):
@@ -577,3 +563,40 @@ class Functional(common_utils.TestBaseMixin):
         tensor = common_utils.get_whitenoise(sample_rate=44100)
         self._assert_consistency(func, tensor)
+
+
+class FunctionalComplex:
+    complex_dtype = None
+    real_dtype = None
+    device = None
+
+    def _assert_consistency(self, func, tensor, test_pseudo_complex=False):
+        assert tensor.is_complex()
+        tensor = tensor.to(device=self.device, dtype=self.complex_dtype)
+        ts_func = torch.jit.script(func)
+
+        if test_pseudo_complex:
+            tensor = torch.view_as_real(tensor)
+        output = func(tensor)
+        ts_output = ts_func(tensor)
+        self.assertEqual(ts_output, output)
+
+    @parameterized.expand([(True, ), (False, )])
+    def test_phase_vocoder(self, test_pseudo_complex):
+        def func(tensor):
+            is_complex = tensor.is_complex()
+            n_freq = tensor.size(-2 if is_complex else -3)
+            rate = 0.5
+            hop_length = 256
+            phase_advance = torch.linspace(
+                0,
+                3.14 * hop_length,
+                n_freq,
+                dtype=(torch.real(tensor) if is_complex else tensor).dtype,
+                device=tensor.device,
+            )[..., None]
+            return F.phase_vocoder(tensor, rate, phase_advance)
+
+        tensor = torch.view_as_complex(torch.randn(2, 1025, 400, 2))
+        self._assert_consistency(func, tensor, test_pseudo_complex)
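
In miniature, the consistency check above amounts to scripting a wrapper around `F.phase_vocoder` and comparing it with eager mode. A sketch under the assumption that this commit is installed:

```python
import math
import torch
import torchaudio.functional as F

def func(tensor: torch.Tensor) -> torch.Tensor:
    rate = 0.5
    hop_length = 256
    n_freq = tensor.size(-2) if tensor.is_complex() else tensor.size(-3)
    phase_advance = torch.linspace(
        0., math.pi * hop_length, n_freq,
        dtype=torch.float64, device=tensor.device)[..., None]
    return F.phase_vocoder(tensor, rate, phase_advance)

ts_func = torch.jit.script(func)  # fails if phase_vocoder is not TorchScript-able

spec = torch.randn(2, 1025, 400, dtype=torch.complex128)
assert torch.allclose(func(spec), ts_func(spec))
```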
"""Test numerical consistency among single input and batched input.""" """Test numerical consistency among single input and batched input."""
import torch import torch
import torchaudio import torchaudio
from parameterized import parameterized
from torchaudio_unittest import common_utils from torchaudio_unittest import common_utils
...@@ -130,40 +131,31 @@ class TestTransforms(common_utils.TorchaudioTestCase): ...@@ -130,40 +131,31 @@ class TestTransforms(common_utils.TorchaudioTestCase):
computed = torchaudio.transforms.MFCC()(waveform.repeat(3, 1, 1)) computed = torchaudio.transforms.MFCC()(waveform.repeat(3, 1, 1))
self.assertEqual(computed, expected, atol=1e-4, rtol=1e-5) self.assertEqual(computed, expected, atol=1e-4, rtol=1e-5)
def test_batch_TimeStretch(self): @parameterized.expand([(True, ), (False, )])
test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav') def test_batch_TimeStretch(self, test_pseudo_complex):
waveform, _ = torchaudio.load(test_filepath) # (2, 278756), 44100
rate = 2 rate = 2
num_freq = 1025
num_frames = 400
complex_specgrams = torch.view_as_real( spec = torch.randn(num_freq, num_frames, dtype=torch.complex64)
torch.stft( pattern = [3, 1, 1, 1]
input=waveform, if test_pseudo_complex:
n_fft=2048, spec = torch.view_as_real(spec)
hop_length=512, pattern += [1]
win_length=2048,
window=torch.hann_window(2048),
center=True,
pad_mode='reflect',
normalized=True,
onesided=True,
return_complex=True,
)
)
# Single then transform then batch # Single then transform then batch
expected = torchaudio.transforms.TimeStretch( expected = torchaudio.transforms.TimeStretch(
fixed_rate=rate, fixed_rate=rate,
n_freq=1025, n_freq=num_freq,
hop_length=512, hop_length=512,
)(complex_specgrams).repeat(3, 1, 1, 1, 1) )(spec).repeat(*pattern)
# Batch then transform # Batch then transform
computed = torchaudio.transforms.TimeStretch( computed = torchaudio.transforms.TimeStretch(
fixed_rate=rate, fixed_rate=rate,
n_freq=1025, n_freq=num_freq,
hop_length=512, hop_length=512,
)(complex_specgrams.repeat(3, 1, 1, 1, 1)) )(spec.repeat(*pattern))
self.assertEqual(computed, expected, atol=1e-5, rtol=1e-5) self.assertEqual(computed, expected, atol=1e-5, rtol=1e-5)
......
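
The batched test reduces to one idea: `transform(x).repeat(...)` must match `transform(x.repeat(...))`. A standalone sketch with a native complex spectrogram (again assuming this commit):

```python
import torch
import torchaudio

stretch = torchaudio.transforms.TimeStretch(
    fixed_rate=2.0, n_freq=1025, hop_length=512)

spec = torch.randn(1025, 400, dtype=torch.complex64)

expected = stretch(spec).repeat(3, 1, 1)   # transform, then batch
computed = stretch(spec.repeat(3, 1, 1))   # batch, then transform
assert torch.allclose(computed, expected, atol=1e-5)
```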
 import torch

 from torchaudio_unittest.common_utils import PytorchTestCase
-from .torchscript_consistency_impl import Transforms
+from .torchscript_consistency_impl import Transforms, TransformsComplex


 class TestTransformsFloat32(Transforms, PytorchTestCase):
@@ -12,3 +12,15 @@ class TestTransformsFloat32(Transforms, PytorchTestCase):
 class TestTransformsFloat64(Transforms, PytorchTestCase):
     dtype = torch.float64
     device = torch.device('cpu')
+
+
+class TestTransformsComplex64(TransformsComplex, PytorchTestCase):
+    complex_dtype = torch.complex64
+    real_dtype = torch.float32
+    device = torch.device('cpu')
+
+
+class TestTransformsComplex128(TransformsComplex, PytorchTestCase):
+    complex_dtype = torch.complex128
+    real_dtype = torch.float64
+    device = torch.device('cpu')
...
 import torch

 from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase
-from .torchscript_consistency_impl import Transforms
+from .torchscript_consistency_impl import Transforms, TransformsComplex


 @skipIfNoCuda
@@ -14,3 +14,17 @@ class TestTransformsFloat32(Transforms, PytorchTestCase):
 class TestTransformsFloat64(Transforms, PytorchTestCase):
     dtype = torch.float64
     device = torch.device('cuda')
+
+
+@skipIfNoCuda
+class TestTransformsComplex64(TransformsComplex, PytorchTestCase):
+    complex_dtype = torch.complex64
+    real_dtype = torch.float32
+    device = torch.device('cuda')
+
+
+@skipIfNoCuda
+class TestTransformsComplex128(TransformsComplex, PytorchTestCase):
+    complex_dtype = torch.complex128
+    real_dtype = torch.float64
+    device = torch.device('cuda')
...
@@ -2,6 +2,7 @@

 import torch
 import torchaudio.transforms as T
+from parameterized import parameterized

 from torchaudio_unittest import common_utils
 from torchaudio_unittest.common_utils import (
@@ -62,16 +63,6 @@ class Transforms(common_utils.TestBaseMixin):
         tensor = torch.rand((1, 10))
         self._assert_consistency(T.MuLawDecoding(), tensor)

-    def test_TimeStretch(self):
-        n_freq = 400
-        hop_length = 512
-        fixed_rate = 1.3
-        tensor = torch.rand((10, 2, n_freq, 10, 2))
-        self._assert_consistency(
-            T.TimeStretch(n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate),
-            tensor,
-        )
-
     def test_Fade(self):
         waveform = common_utils.get_whitenoise()
         fade_in_len = 3000
@@ -103,3 +94,34 @@ class Transforms(common_utils.TestBaseMixin):
         sample_rate = 44100
         waveform = common_utils.get_whitenoise(sample_rate=sample_rate)
         self._assert_consistency(T.SpectralCentroid(sample_rate=sample_rate), waveform)
+
+
+class TransformsComplex:
+    complex_dtype = None
+    real_dtype = None
+    device = None
+
+    def _assert_consistency(self, transform, tensor, test_pseudo_complex=False):
+        assert tensor.is_complex()
+        tensor = tensor.to(device=self.device, dtype=self.complex_dtype)
+        transform = transform.to(device=self.device, dtype=self.real_dtype)
+        ts_transform = torch.jit.script(transform)
+
+        if test_pseudo_complex:
+            tensor = torch.view_as_real(tensor)
+        output = transform(tensor)
+        ts_output = ts_transform(tensor)
+        self.assertEqual(ts_output, output)
+
+    @parameterized.expand([(True, ), (False, )])
+    def test_TimeStretch(self, test_pseudo_complex):
+        n_freq = 400
+        hop_length = 512
+        fixed_rate = 1.3
+        tensor = torch.view_as_complex(torch.rand((10, 2, n_freq, 10, 2)))
+        self._assert_consistency(
+            T.TimeStretch(n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate),
+            tensor,
+            test_pseudo_complex,
+        )
...
@@ -565,14 +565,29 @@ def phase_vocoder(
     factor of ``rate``.

     Args:
-        complex_specgrams (Tensor): Dimension of `(..., freq, time, complex=2)`
+        complex_specgrams (Tensor):
+            Either a real tensor of dimension ``(..., freq, num_frame, complex=2)``
+            or a tensor of dimension ``(..., freq, num_frame)`` with complex dtype.
         rate (float): Speed-up factor
         phase_advance (Tensor): Expected phase advance in each bin. Dimension of (freq, 1)

     Returns:
-        Tensor: Complex Specgrams Stretch with dimension of `(..., freq, ceil(time/rate), complex=2)`
+        Tensor:
+            Stretched spectrogram. The resulting tensor is of the same dtype as the input
+            spectrogram, but the number of frames is changed to ``ceil(num_frame / rate)``.

-    Example
+    Example - With Tensor of complex dtype
+        >>> freq, hop_length = 1025, 512
+        >>> # (channel, freq, time)
+        >>> complex_specgrams = torch.randn(2, freq, 300, dtype=torch.cfloat)
+        >>> rate = 1.3 # Speed up by 30%
+        >>> phase_advance = torch.linspace(
+        >>>    0, math.pi * hop_length, freq)[..., None]
+        >>> x = phase_vocoder(complex_specgrams, rate, phase_advance)
+        >>> x.shape # with 231 == ceil(300 / 1.3)
+        torch.Size([2, 1025, 231])
+
+    Example - With Tensor of real dtype and extra dimension for complex field
         >>> freq, hop_length = 1025, 512
         >>> # (channel, freq, time, complex=2)
         >>> complex_specgrams = torch.randn(2, freq, 300, 2)
@@ -583,32 +598,48 @@ def phase_vocoder(
         >>> x.shape # with 231 == ceil(300 / 1.3)
         torch.Size([2, 1025, 231, 2])
     """
+    if rate == 1.0:
+        return complex_specgrams
+
+    if not complex_specgrams.is_complex() and complex_specgrams.size(-1) != 2:
+        raise ValueError(
+            "complex_specgrams must be either native complex tensors or "
+            "real valued tensors with shape (..., 2)")
+
+    is_complex = complex_specgrams.is_complex()
+
+    if not is_complex:
+        complex_specgrams = torch.view_as_complex(complex_specgrams)
+
     # pack batch
     shape = complex_specgrams.size()
-    complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-3:]))
+    complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-2:]))

-    time_steps = torch.arange(0,
-                              complex_specgrams.size(-2),
-                              rate,
-                              device=complex_specgrams.device,
-                              dtype=complex_specgrams.dtype)
+    # Figures out the corresponding real dtype, i.e. complex128 -> float64, complex64 -> float32
+    # Note torch.real is a view so it does not incur any memory copy.
+    real_dtype = torch.real(complex_specgrams).dtype
+    time_steps = torch.arange(
+        0,
+        complex_specgrams.size(-1),
+        rate,
+        device=complex_specgrams.device,
+        dtype=real_dtype)

     alphas = time_steps % 1.0
-    phase_0 = angle(complex_specgrams[..., :1, :])
+    phase_0 = complex_specgrams[..., :1].angle()

     # Time Padding
-    complex_specgrams = torch.nn.functional.pad(complex_specgrams, [0, 0, 0, 2])
+    complex_specgrams = torch.nn.functional.pad(complex_specgrams, [0, 2])

     # (new_bins, freq, 2)
-    complex_specgrams_0 = complex_specgrams.index_select(-2, time_steps.long())
-    complex_specgrams_1 = complex_specgrams.index_select(-2, (time_steps + 1).long())
+    complex_specgrams_0 = complex_specgrams.index_select(-1, time_steps.long())
+    complex_specgrams_1 = complex_specgrams.index_select(-1, (time_steps + 1).long())

-    angle_0 = angle(complex_specgrams_0)
-    angle_1 = angle(complex_specgrams_1)
+    angle_0 = complex_specgrams_0.angle()
+    angle_1 = complex_specgrams_1.angle()

-    norm_0 = torch.norm(complex_specgrams_0, p=2, dim=-1)
-    norm_1 = torch.norm(complex_specgrams_1, p=2, dim=-1)
+    norm_0 = complex_specgrams_0.abs()
+    norm_1 = complex_specgrams_1.abs()

     phase = angle_1 - angle_0 - phase_advance
     phase = phase - 2 * math.pi * torch.round(phase / (2 * math.pi))
@@ -620,14 +651,13 @@ def phase_vocoder(

     mag = alphas * norm_1 + (1 - alphas) * norm_0

-    real_stretch = mag * torch.cos(phase_acc)
-    imag_stretch = mag * torch.sin(phase_acc)
-
-    complex_specgrams_stretch = torch.stack([real_stretch, imag_stretch], dim=-1)
+    complex_specgrams_stretch = torch.polar(mag, phase_acc)

     # unpack batch
-    complex_specgrams_stretch = complex_specgrams_stretch.reshape(shape[:-3] + complex_specgrams_stretch.shape[1:])
+    complex_specgrams_stretch = complex_specgrams_stretch.reshape(shape[:-2] + complex_specgrams_stretch.shape[1:])

+    if not is_complex:
+        return torch.view_as_real(complex_specgrams_stretch)
     return complex_specgrams_stretch
...
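
The heart of the refactor above: `torch.polar(mag, phase)` constructs `mag * exp(i * phase)` in one call, replacing the manual cos/sin/stack of the pseudo-complex code. The equivalence, as a quick check:

```python
import torch

mag = torch.rand(5, 7)
phase = torch.rand(5, 7) * 6.28

manual = torch.complex(mag * torch.cos(phase), mag * torch.sin(phase))
assert torch.allclose(torch.polar(mag, phase), manual)
```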
@@ -729,26 +729,24 @@ class TimeStretch(torch.nn.Module):
     def forward(self, complex_specgrams: Tensor, overriding_rate: Optional[float] = None) -> Tensor:
         r"""
         Args:
-            complex_specgrams (Tensor): complex spectrogram (..., freq, time, complex=2).
+            complex_specgrams (Tensor):
+                Either a real tensor of dimension ``(..., freq, num_frame, complex=2)``
+                or a tensor of dimension ``(..., freq, num_frame)`` with complex dtype.
             overriding_rate (float or None, optional): speed up to apply to this batch.
                 If no rate is passed, use ``self.fixed_rate``. (Default: ``None``)

         Returns:
-            Tensor: Stretched complex spectrogram of dimension (..., freq, ceil(time/rate), complex=2).
+            Tensor:
+                Stretched spectrogram. The resulting tensor is of the same dtype as the input
+                spectrogram, but the number of frames is changed to ``ceil(num_frame / rate)``.
         """
-        assert complex_specgrams.size(-1) == 2, "complex_specgrams should be a complex tensor, shape (..., complex=2)"
-
         if overriding_rate is None:
+            if self.fixed_rate is None:
+                raise ValueError(
+                    "If no fixed_rate is specified, must pass a valid rate to the forward method.")
             rate = self.fixed_rate
-            if rate is None:
-                raise ValueError("If no fixed_rate is specified"
-                                 ", must pass a valid rate to the forward method.")
         else:
             rate = overriding_rate
-
-        if rate == 1.0:
-            return complex_specgrams
-
         return F.phase_vocoder(complex_specgrams, rate, self.phase_advance)
...
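
End to end, the updated `T.TimeStretch` now composes directly with `torch.stft(..., return_complex=True)`, with no `view_as_real` shim in between. A usage sketch (the waveform here is a stand-in; any `(..., time)` tensor works):

```python
import torch
import torchaudio.transforms as T

n_fft, hop_length = 2048, 512
waveform = torch.randn(1, 44100)  # dummy one-second mono signal

spec = torch.stft(
    waveform, n_fft=n_fft, hop_length=hop_length,
    window=torch.hann_window(n_fft), return_complex=True)  # (1, 1025, frames)

stretch = T.TimeStretch(hop_length=hop_length, n_freq=n_fft // 2 + 1)
faster = stretch(spec, overriding_rate=1.3)  # ~30% fewer frames, complex dtype preserved
slower = stretch(spec, overriding_rate=0.9)
```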