You need to sign in or sign up before continuing.
Unverified Commit d3e146fd authored by moto's avatar moto Committed by GitHub
Browse files

[BC-Breaking] Drop pseudo complex support from phase_vocoder / TimeStretch (#1957)

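Following the plan laid out in #1337, this commit drops support for the pseudo-complex type (real tensors with a trailing dimension of size 2) from `F.phase_vocoder` and `T.TimeStretch`; both now accept only tensors of native complex dtype.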
parent 5ec6ada6
...@@ -429,11 +429,8 @@ class Functional(TestBaseMixin): ...@@ -429,11 +429,8 @@ class Functional(TestBaseMixin):
def test_resample_waveform_upsample_accuracy(self, resampling_method, i): def test_resample_waveform_upsample_accuracy(self, resampling_method, i):
self._test_resample_waveform_accuracy(up_scale_factor=1.0 + i / 20.0, resampling_method=resampling_method) self._test_resample_waveform_accuracy(up_scale_factor=1.0 + i / 20.0, resampling_method=resampling_method)
@nested_params( @nested_params([0.5, 1.01, 1.3])
[0.5, 1.01, 1.3], def test_phase_vocoder_shape(self, rate):
[True, False],
)
def test_phase_vocoder_shape(self, rate, test_pseudo_complex):
"""Verify the output shape of phase vocoder""" """Verify the output shape of phase vocoder"""
hop_length = 256 hop_length = 256
num_freq = 1025 num_freq = 1025
...@@ -443,8 +440,6 @@ class Functional(TestBaseMixin): ...@@ -443,8 +440,6 @@ class Functional(TestBaseMixin):
torch.random.manual_seed(42) torch.random.manual_seed(42)
spec = torch.randn( spec = torch.randn(
batch_size, num_freq, num_frames, dtype=self.complex_dtype, device=self.device) batch_size, num_freq, num_frames, dtype=self.complex_dtype, device=self.device)
if test_pseudo_complex:
spec = torch.view_as_real(spec)
phase_advance = torch.linspace( phase_advance = torch.linspace(
0, 0,
...@@ -456,7 +451,7 @@ class Functional(TestBaseMixin): ...@@ -456,7 +451,7 @@ class Functional(TestBaseMixin):
assert spec.dim() == spec_stretch.dim() assert spec.dim() == spec_stretch.dim()
expected_shape = torch.Size([batch_size, num_freq, int(np.ceil(num_frames / rate))]) expected_shape = torch.Size([batch_size, num_freq, int(np.ceil(num_frames / rate))])
output_shape = (torch.view_as_complex(spec_stretch) if test_pseudo_complex else spec_stretch).shape output_shape = spec_stretch.shape
assert output_shape == expected_shape assert output_shape == expected_shape
@parameterized.expand( @parameterized.expand(
......
...@@ -126,11 +126,8 @@ class Functional(TestBaseMixin): ...@@ -126,11 +126,8 @@ class Functional(TestBaseMixin):
@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available") @unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
class FunctionalComplex(TestBaseMixin): class FunctionalComplex(TestBaseMixin):
@nested_params( @nested_params([0.5, 1.01, 1.3])
[0.5, 1.01, 1.3], def test_phase_vocoder(self, rate):
[True, False],
)
def test_phase_vocoder(self, rate, test_pseudo_complex):
hop_length = 256 hop_length = 256
num_freq = 1025 num_freq = 1025
num_frames = 400 num_frames = 400
...@@ -147,15 +144,11 @@ class FunctionalComplex(TestBaseMixin): ...@@ -147,15 +144,11 @@ class FunctionalComplex(TestBaseMixin):
device=self.device, device=self.device,
dtype=torch.float64)[..., None] dtype=torch.float64)[..., None]
stretched = F.phase_vocoder( stretched = F.phase_vocoder(spec, rate=rate, phase_advance=phase_advance)
torch.view_as_real(spec) if test_pseudo_complex else spec,
rate=rate, phase_advance=phase_advance)
expected_stretched = librosa.phase_vocoder( expected_stretched = librosa.phase_vocoder(
spec.cpu().numpy(), spec.cpu().numpy(),
rate=rate, rate=rate,
hop_length=hop_length) hop_length=hop_length)
self.assertEqual( self.assertEqual(stretched, torch.from_numpy(expected_stretched))
torch.view_as_complex(stretched) if test_pseudo_complex else stretched,
torch.from_numpy(expected_stretched))
...@@ -3,7 +3,6 @@ import unittest ...@@ -3,7 +3,6 @@ import unittest
import torch import torch
import torchaudio.functional as F import torchaudio.functional as F
from parameterized import parameterized
from torchaudio_unittest import common_utils from torchaudio_unittest import common_utils
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
...@@ -31,14 +30,11 @@ class Functional(TempDirMixin, TestBaseMixin): ...@@ -31,14 +30,11 @@ class Functional(TempDirMixin, TestBaseMixin):
output = output.shape output = output.shape
self.assertEqual(ts_output, output) self.assertEqual(ts_output, output)
def _assert_consistency_complex(self, func, tensor, test_pseudo_complex=False): def _assert_consistency_complex(self, func, tensor):
assert tensor.is_complex() assert tensor.is_complex()
tensor = tensor.to(device=self.device, dtype=self.complex_dtype) tensor = tensor.to(device=self.device, dtype=self.complex_dtype)
ts_func = torch_script(func) ts_func = torch_script(func)
if test_pseudo_complex:
tensor = torch.view_as_real(tensor)
torch.random.manual_seed(40) torch.random.manual_seed(40)
output = func(tensor) output = func(tensor)
...@@ -641,25 +637,22 @@ class Functional(TempDirMixin, TestBaseMixin): ...@@ -641,25 +637,22 @@ class Functional(TempDirMixin, TestBaseMixin):
self._assert_consistency(func, tensor) self._assert_consistency(func, tensor)
self._assert_consistency(func_beta, tensor) self._assert_consistency(func_beta, tensor)
@parameterized.expand([(True, ), (False, )]) def test_phase_vocoder(self):
def test_phase_vocoder(self, test_paseudo_complex):
def func(tensor): def func(tensor):
is_complex = tensor.is_complex() n_freq = tensor.size(-2)
n_freq = tensor.size(-2 if is_complex else -3)
rate = 0.5 rate = 0.5
hop_length = 256 hop_length = 256
phase_advance = torch.linspace( phase_advance = torch.linspace(
0, 0,
3.14 * hop_length, 3.14 * hop_length,
n_freq, n_freq,
dtype=(torch.real(tensor) if is_complex else tensor).dtype, dtype=torch.real(tensor).dtype,
device=tensor.device, device=tensor.device,
)[..., None] )[..., None]
return F.phase_vocoder(tensor, rate, phase_advance) return F.phase_vocoder(tensor, rate, phase_advance)
tensor = torch.view_as_complex(torch.randn(2, 1025, 400, 2)) tensor = torch.view_as_complex(torch.randn(2, 1025, 400, 2))
self._assert_consistency_complex(func, tensor, test_paseudo_complex) self._assert_consistency_complex(func, tensor)
class FunctionalFloat32Only(TestBaseMixin): class FunctionalFloat32Only(TestBaseMixin):
......
...@@ -226,11 +226,8 @@ class AutogradTestMixin(TestBaseMixin): ...@@ -226,11 +226,8 @@ class AutogradTestMixin(TestBaseMixin):
spectrogram = get_spectrogram(waveform, n_fft=n_fft, power=None) spectrogram = get_spectrogram(waveform, n_fft=n_fft, power=None)
self.assert_grad(transform, [spectrogram]) self.assert_grad(transform, [spectrogram])
@nested_params( @nested_params([0.7, 0.8, 0.9, 1.0, 1.3])
[0.7, 0.8, 0.9, 1.0, 1.3], def test_timestretch_non_zero(self, rate):
[False, True],
)
def test_timestretch_non_zero(self, rate, test_pseudo_complex):
"""Verify that ``T.TimeStretch`` does not fail if it's not close to 0 """Verify that ``T.TimeStretch`` does not fail if it's not close to 0
``T.TimeStrech`` is not differentiable around 0, so this test checks the differentiability ``T.TimeStrech`` is not differentiable around 0, so this test checks the differentiability
...@@ -254,8 +251,6 @@ class AutogradTestMixin(TestBaseMixin): ...@@ -254,8 +251,6 @@ class AutogradTestMixin(TestBaseMixin):
epsilon = 1e-2 epsilon = 1e-2
too_close = spectrogram.abs() < epsilon too_close = spectrogram.abs() < epsilon
spectrogram[too_close] = epsilon * spectrogram[too_close] / spectrogram[too_close].abs() spectrogram[too_close] = epsilon * spectrogram[too_close] / spectrogram[too_close].abs()
if test_pseudo_complex:
spectrogram = torch.view_as_real(spectrogram)
self.assert_grad(transform, [spectrogram]) self.assert_grad(transform, [spectrogram])
def test_psd(self): def test_psd(self):
......
...@@ -124,20 +124,16 @@ class TestTransforms(common_utils.TorchaudioTestCase): ...@@ -124,20 +124,16 @@ class TestTransforms(common_utils.TorchaudioTestCase):
self.assert_batch_consistency(transform, waveform, atol=1e-4, rtol=1e-5) self.assert_batch_consistency(transform, waveform, atol=1e-4, rtol=1e-5)
@parameterized.expand([(True, ), (False, )]) def test_batch_TimeStretch(self):
def test_batch_TimeStretch(self, test_pseudo_complex):
rate = 2 rate = 2
num_freq = 1025 num_freq = 1025
num_frames = 400
batch = 3 batch = 3
spec = torch.randn(batch, num_freq, num_frames, dtype=torch.complex64) tensor = common_utils.get_whitenoise(sample_rate=8000, n_channels=batch)
if test_pseudo_complex: spec = common_utils.get_spectrogram(tensor, n_fft=num_freq)
spec = torch.view_as_real(spec)
transform = T.TimeStretch( transform = T.TimeStretch(
fixed_rate=rate, fixed_rate=rate,
n_freq=num_freq, n_freq=num_freq // 2 + 1,
hop_length=512 hop_length=512
) )
......
...@@ -24,15 +24,13 @@ class Transforms(TestBaseMixin): ...@@ -24,15 +24,13 @@ class Transforms(TestBaseMixin):
ts_output = ts_transform(tensor, *args) ts_output = ts_transform(tensor, *args)
self.assertEqual(ts_output, output) self.assertEqual(ts_output, output)
def _assert_consistency_complex(self, transform, tensor, test_pseudo_complex=False, *args): def _assert_consistency_complex(self, transform, tensor, *args):
assert tensor.is_complex() assert tensor.is_complex()
tensor = tensor.to(device=self.device, dtype=self.complex_dtype) tensor = tensor.to(device=self.device, dtype=self.complex_dtype)
transform = transform.to(device=self.device, dtype=self.dtype) transform = transform.to(device=self.device, dtype=self.dtype)
ts_transform = torch_script(transform) ts_transform = torch_script(transform)
if test_pseudo_complex:
tensor = torch.view_as_real(tensor)
output = transform(tensor, *args) output = transform(tensor, *args)
ts_output = ts_transform(tensor, *args) ts_output = ts_transform(tensor, *args)
self.assertEqual(ts_output, output) self.assertEqual(ts_output, output)
...@@ -120,16 +118,21 @@ class Transforms(TestBaseMixin): ...@@ -120,16 +118,21 @@ class Transforms(TestBaseMixin):
waveform = common_utils.get_whitenoise(sample_rate=sample_rate) waveform = common_utils.get_whitenoise(sample_rate=sample_rate)
self._assert_consistency(T.SpectralCentroid(sample_rate=sample_rate), waveform) self._assert_consistency(T.SpectralCentroid(sample_rate=sample_rate), waveform)
@parameterized.expand([(True, ), (False, )]) def test_TimeStretch(self):
def test_TimeStretch(self, test_pseudo_complex): n_fft = 1025
n_freq = 400 n_freq = n_fft // 2 + 1
hop_length = 512 hop_length = 512
fixed_rate = 1.3 fixed_rate = 1.3
tensor = torch.view_as_complex(torch.rand((10, 2, n_freq, 10, 2))) tensor = torch.rand((10, 2, n_freq, 10), dtype=torch.cfloat)
batch = 10
num_channels = 2
waveform = common_utils.get_whitenoise(sample_rate=8000, n_channels=batch * num_channels)
tensor = common_utils.get_spectrogram(waveform, n_fft=n_fft)
tensor = tensor.reshape(batch, num_channels, n_freq, -1)
self._assert_consistency_complex( self._assert_consistency_complex(
T.TimeStretch(n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate), T.TimeStretch(n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate),
tensor, tensor,
test_pseudo_complex
) )
def test_PitchShift(self): def test_PitchShift(self):
...@@ -152,7 +155,7 @@ class Transforms(TestBaseMixin): ...@@ -152,7 +155,7 @@ class Transforms(TestBaseMixin):
spectrogram = common_utils.get_spectrogram(tensor, n_fft=400, hop_length=100) spectrogram = common_utils.get_spectrogram(tensor, n_fft=400, hop_length=100)
spectrogram = spectrogram.to(self.device) spectrogram = spectrogram.to(self.device)
mask = torch.rand(spectrogram.shape[-2:], device=self.device) mask = torch.rand(spectrogram.shape[-2:], device=self.device)
self._assert_consistency_complex(T.PSD(), spectrogram, False, mask) self._assert_consistency_complex(T.PSD(), spectrogram, mask)
class TransformsFloat32Only(TestBaseMixin): class TransformsFloat32Only(TestBaseMixin):
...@@ -188,5 +191,5 @@ class TransformsFloat64Only(TestBaseMixin): ...@@ -188,5 +191,5 @@ class TransformsFloat64Only(TestBaseMixin):
mask_n = torch.rand(spectrogram.shape[-2:], device=self.device) mask_n = torch.rand(spectrogram.shape[-2:], device=self.device)
self._assert_consistency_complex( self._assert_consistency_complex(
T.MVDR(solution=solution, online=online), T.MVDR(solution=solution, online=online),
spectrogram, False, mask_s, mask_n spectrogram, mask_s, mask_n
) )
...@@ -714,8 +714,7 @@ def phase_vocoder( ...@@ -714,8 +714,7 @@ def phase_vocoder(
Args: Args:
complex_specgrams (Tensor): complex_specgrams (Tensor):
Either a real tensor of dimension of `(..., freq, num_frame, complex=2)` A tensor of dimension `(..., freq, num_frame)` with complex dtype.
or a tensor of dimension `(..., freq, num_frame)` with complex dtype.
rate (float): Speed-up factor rate (float): Speed-up factor
phase_advance (Tensor): Expected phase advance in each bin. Dimension of `(freq, 1)` phase_advance (Tensor): Expected phase advance in each bin. Dimension of `(freq, 1)`
...@@ -724,7 +723,7 @@ def phase_vocoder( ...@@ -724,7 +723,7 @@ def phase_vocoder(
Stretched spectrogram. The resulting tensor is of the same dtype as the input Stretched spectrogram. The resulting tensor is of the same dtype as the input
spectrogram, but the number of frames is changed to ``ceil(num_frame / rate)``. spectrogram, but the number of frames is changed to ``ceil(num_frame / rate)``.
Example - With Tensor of complex dtype Example
>>> freq, hop_length = 1025, 512 >>> freq, hop_length = 1025, 512
>>> # (channel, freq, time) >>> # (channel, freq, time)
>>> complex_specgrams = torch.randn(2, freq, 300, dtype=torch.cfloat) >>> complex_specgrams = torch.randn(2, freq, 300, dtype=torch.cfloat)
...@@ -734,41 +733,10 @@ def phase_vocoder( ...@@ -734,41 +733,10 @@ def phase_vocoder(
>>> x = phase_vocoder(complex_specgrams, rate, phase_advance) >>> x = phase_vocoder(complex_specgrams, rate, phase_advance)
>>> x.shape # with 231 == ceil(300 / 1.3) >>> x.shape # with 231 == ceil(300 / 1.3)
torch.Size([2, 1025, 231]) torch.Size([2, 1025, 231])
Example - With Tensor of real dtype and extra dimension for complex field
>>> freq, hop_length = 1025, 512
>>> # (channel, freq, time, complex=2)
>>> complex_specgrams = torch.randn(2, freq, 300, 2)
>>> rate = 1.3 # Speed up by 30%
>>> phase_advance = torch.linspace(
>>> 0, math.pi * hop_length, freq)[..., None]
>>> x = phase_vocoder(complex_specgrams, rate, phase_advance)
>>> x.shape # with 231 == ceil(300 / 1.3)
torch.Size([2, 1025, 231, 2])
""" """
if rate == 1.0: if rate == 1.0:
return complex_specgrams return complex_specgrams
if not complex_specgrams.is_complex():
warnings.warn(
"The support for pseudo complex type in `torchaudio.functional.phase_vocoder` and "
"`torchaudio.transforms.TimeStretch` is now deprecated and will be removed "
"from 0.11 release."
"Please migrate to native complex type by converting the input tensor with "
"`torch.view_as_complex`. "
"Please refer to https://github.com/pytorch/audio/issues/1337 "
"for more details about torchaudio's plan to migrate to native complex type."
)
if complex_specgrams.size(-1) != 2:
raise ValueError(
"complex_specgrams must be either native complex tensors or "
"real valued tensors with shape (..., 2)")
is_complex = complex_specgrams.is_complex()
if not is_complex:
complex_specgrams = torch.view_as_complex(complex_specgrams)
# pack batch # pack batch
shape = complex_specgrams.size() shape = complex_specgrams.size()
complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-2:])) complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-2:]))
...@@ -813,9 +781,6 @@ def phase_vocoder( ...@@ -813,9 +781,6 @@ def phase_vocoder(
# unpack batch # unpack batch
complex_specgrams_stretch = complex_specgrams_stretch.reshape(shape[:-2] + complex_specgrams_stretch.shape[1:]) complex_specgrams_stretch = complex_specgrams_stretch.reshape(shape[:-2] + complex_specgrams_stretch.shape[1:])
if not is_complex:
return torch.view_as_real(complex_specgrams_stretch)
return complex_specgrams_stretch return complex_specgrams_stretch
......
...@@ -972,8 +972,7 @@ class TimeStretch(torch.nn.Module): ...@@ -972,8 +972,7 @@ class TimeStretch(torch.nn.Module):
r""" r"""
Args: Args:
complex_specgrams (Tensor): complex_specgrams (Tensor):
Either a real tensor of dimension of `(..., freq, num_frame, complex=2)` A tensor of dimension `(..., freq, num_frame)` with complex dtype.
or a tensor of dimension `(..., freq, num_frame)` with complex dtype.
overriding_rate (float or None, optional): speed up to apply to this batch. overriding_rate (float or None, optional): speed up to apply to this batch.
If no rate is passed, use ``self.fixed_rate``. (Default: ``None``) If no rate is passed, use ``self.fixed_rate``. (Default: ``None``)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment