Unverified Commit d3e146fd, authored by moto, committed by GitHub
Browse files

[BC-Breaking] Drop pseudo complex support from phase_vocoder / TimeStretch (#1957)

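Following the migration plan in #1337, this commit drops support for the pseudo-complex type (real tensors of shape `(..., freq, num_frame, complex=2)`) from `F.phase_vocoder` and `T.TimeStretch`. Both now accept only tensors of shape `(..., freq, num_frame)` with a native complex dtype.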
parent 5ec6ada6
......@@ -429,11 +429,8 @@ class Functional(TestBaseMixin):
def test_resample_waveform_upsample_accuracy(self, resampling_method, i):
self._test_resample_waveform_accuracy(up_scale_factor=1.0 + i / 20.0, resampling_method=resampling_method)
@nested_params(
[0.5, 1.01, 1.3],
[True, False],
)
def test_phase_vocoder_shape(self, rate, test_pseudo_complex):
@nested_params([0.5, 1.01, 1.3])
def test_phase_vocoder_shape(self, rate):
"""Verify the output shape of phase vocoder"""
hop_length = 256
num_freq = 1025
......@@ -443,8 +440,6 @@ class Functional(TestBaseMixin):
torch.random.manual_seed(42)
spec = torch.randn(
batch_size, num_freq, num_frames, dtype=self.complex_dtype, device=self.device)
if test_pseudo_complex:
spec = torch.view_as_real(spec)
phase_advance = torch.linspace(
0,
......@@ -456,7 +451,7 @@ class Functional(TestBaseMixin):
assert spec.dim() == spec_stretch.dim()
expected_shape = torch.Size([batch_size, num_freq, int(np.ceil(num_frames / rate))])
output_shape = (torch.view_as_complex(spec_stretch) if test_pseudo_complex else spec_stretch).shape
output_shape = spec_stretch.shape
assert output_shape == expected_shape
@parameterized.expand(
......
......@@ -126,11 +126,8 @@ class Functional(TestBaseMixin):
@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
class FunctionalComplex(TestBaseMixin):
@nested_params(
[0.5, 1.01, 1.3],
[True, False],
)
def test_phase_vocoder(self, rate, test_pseudo_complex):
@nested_params([0.5, 1.01, 1.3])
def test_phase_vocoder(self, rate):
hop_length = 256
num_freq = 1025
num_frames = 400
......@@ -147,15 +144,11 @@ class FunctionalComplex(TestBaseMixin):
device=self.device,
dtype=torch.float64)[..., None]
stretched = F.phase_vocoder(
torch.view_as_real(spec) if test_pseudo_complex else spec,
rate=rate, phase_advance=phase_advance)
stretched = F.phase_vocoder(spec, rate=rate, phase_advance=phase_advance)
expected_stretched = librosa.phase_vocoder(
spec.cpu().numpy(),
rate=rate,
hop_length=hop_length)
self.assertEqual(
torch.view_as_complex(stretched) if test_pseudo_complex else stretched,
torch.from_numpy(expected_stretched))
self.assertEqual(stretched, torch.from_numpy(expected_stretched))
......@@ -3,7 +3,6 @@ import unittest
import torch
import torchaudio.functional as F
from parameterized import parameterized
from torchaudio_unittest import common_utils
from torchaudio_unittest.common_utils import (
......@@ -31,14 +30,11 @@ class Functional(TempDirMixin, TestBaseMixin):
output = output.shape
self.assertEqual(ts_output, output)
def _assert_consistency_complex(self, func, tensor, test_pseudo_complex=False):
def _assert_consistency_complex(self, func, tensor):
assert tensor.is_complex()
tensor = tensor.to(device=self.device, dtype=self.complex_dtype)
ts_func = torch_script(func)
if test_pseudo_complex:
tensor = torch.view_as_real(tensor)
torch.random.manual_seed(40)
output = func(tensor)
......@@ -641,25 +637,22 @@ class Functional(TempDirMixin, TestBaseMixin):
self._assert_consistency(func, tensor)
self._assert_consistency(func_beta, tensor)
@parameterized.expand([(True, ), (False, )])
def test_phase_vocoder(self, test_paseudo_complex):
def test_phase_vocoder(self):
def func(tensor):
is_complex = tensor.is_complex()
n_freq = tensor.size(-2 if is_complex else -3)
n_freq = tensor.size(-2)
rate = 0.5
hop_length = 256
phase_advance = torch.linspace(
0,
3.14 * hop_length,
n_freq,
dtype=(torch.real(tensor) if is_complex else tensor).dtype,
dtype=torch.real(tensor).dtype,
device=tensor.device,
)[..., None]
return F.phase_vocoder(tensor, rate, phase_advance)
tensor = torch.view_as_complex(torch.randn(2, 1025, 400, 2))
self._assert_consistency_complex(func, tensor, test_paseudo_complex)
self._assert_consistency_complex(func, tensor)
class FunctionalFloat32Only(TestBaseMixin):
......
......@@ -226,11 +226,8 @@ class AutogradTestMixin(TestBaseMixin):
spectrogram = get_spectrogram(waveform, n_fft=n_fft, power=None)
self.assert_grad(transform, [spectrogram])
@nested_params(
[0.7, 0.8, 0.9, 1.0, 1.3],
[False, True],
)
def test_timestretch_non_zero(self, rate, test_pseudo_complex):
@nested_params([0.7, 0.8, 0.9, 1.0, 1.3])
def test_timestretch_non_zero(self, rate):
"""Verify that ``T.TimeStretch`` does not fail if it's not close to 0
``T.TimeStretch`` is not differentiable around 0, so this test checks the differentiability
......@@ -254,8 +251,6 @@ class AutogradTestMixin(TestBaseMixin):
epsilon = 1e-2
too_close = spectrogram.abs() < epsilon
spectrogram[too_close] = epsilon * spectrogram[too_close] / spectrogram[too_close].abs()
if test_pseudo_complex:
spectrogram = torch.view_as_real(spectrogram)
self.assert_grad(transform, [spectrogram])
def test_psd(self):
......
......@@ -124,20 +124,16 @@ class TestTransforms(common_utils.TorchaudioTestCase):
self.assert_batch_consistency(transform, waveform, atol=1e-4, rtol=1e-5)
@parameterized.expand([(True, ), (False, )])
def test_batch_TimeStretch(self, test_pseudo_complex):
def test_batch_TimeStretch(self):
rate = 2
num_freq = 1025
num_frames = 400
batch = 3
spec = torch.randn(batch, num_freq, num_frames, dtype=torch.complex64)
if test_pseudo_complex:
spec = torch.view_as_real(spec)
tensor = common_utils.get_whitenoise(sample_rate=8000, n_channels=batch)
spec = common_utils.get_spectrogram(tensor, n_fft=num_freq)
transform = T.TimeStretch(
fixed_rate=rate,
n_freq=num_freq,
n_freq=num_freq // 2 + 1,
hop_length=512
)
......
......@@ -24,15 +24,13 @@ class Transforms(TestBaseMixin):
ts_output = ts_transform(tensor, *args)
self.assertEqual(ts_output, output)
def _assert_consistency_complex(self, transform, tensor, test_pseudo_complex=False, *args):
def _assert_consistency_complex(self, transform, tensor, *args):
assert tensor.is_complex()
tensor = tensor.to(device=self.device, dtype=self.complex_dtype)
transform = transform.to(device=self.device, dtype=self.dtype)
ts_transform = torch_script(transform)
if test_pseudo_complex:
tensor = torch.view_as_real(tensor)
output = transform(tensor, *args)
ts_output = ts_transform(tensor, *args)
self.assertEqual(ts_output, output)
......@@ -120,16 +118,21 @@ class Transforms(TestBaseMixin):
waveform = common_utils.get_whitenoise(sample_rate=sample_rate)
self._assert_consistency(T.SpectralCentroid(sample_rate=sample_rate), waveform)
@parameterized.expand([(True, ), (False, )])
def test_TimeStretch(self, test_pseudo_complex):
n_freq = 400
def test_TimeStretch(self):
n_fft = 1025
n_freq = n_fft // 2 + 1
hop_length = 512
fixed_rate = 1.3
tensor = torch.view_as_complex(torch.rand((10, 2, n_freq, 10, 2)))
tensor = torch.rand((10, 2, n_freq, 10), dtype=torch.cfloat)
batch = 10
num_channels = 2
waveform = common_utils.get_whitenoise(sample_rate=8000, n_channels=batch * num_channels)
tensor = common_utils.get_spectrogram(waveform, n_fft=n_fft)
tensor = tensor.reshape(batch, num_channels, n_freq, -1)
self._assert_consistency_complex(
T.TimeStretch(n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate),
tensor,
test_pseudo_complex
)
def test_PitchShift(self):
......@@ -152,7 +155,7 @@ class Transforms(TestBaseMixin):
spectrogram = common_utils.get_spectrogram(tensor, n_fft=400, hop_length=100)
spectrogram = spectrogram.to(self.device)
mask = torch.rand(spectrogram.shape[-2:], device=self.device)
self._assert_consistency_complex(T.PSD(), spectrogram, False, mask)
self._assert_consistency_complex(T.PSD(), spectrogram, mask)
class TransformsFloat32Only(TestBaseMixin):
......@@ -188,5 +191,5 @@ class TransformsFloat64Only(TestBaseMixin):
mask_n = torch.rand(spectrogram.shape[-2:], device=self.device)
self._assert_consistency_complex(
T.MVDR(solution=solution, online=online),
spectrogram, False, mask_s, mask_n
spectrogram, mask_s, mask_n
)
......@@ -714,8 +714,7 @@ def phase_vocoder(
Args:
complex_specgrams (Tensor):
Either a real tensor of dimension of `(..., freq, num_frame, complex=2)`
or a tensor of dimension `(..., freq, num_frame)` with complex dtype.
A tensor of dimension `(..., freq, num_frame)` with complex dtype.
rate (float): Speed-up factor
phase_advance (Tensor): Expected phase advance in each bin. Dimension of `(freq, 1)`
......@@ -724,7 +723,7 @@ def phase_vocoder(
Stretched spectrogram. The resulting tensor is of the same dtype as the input
spectrogram, but the number of frames is changed to ``ceil(num_frame / rate)``.
Example - With Tensor of complex dtype
Example
>>> freq, hop_length = 1025, 512
>>> # (channel, freq, time)
>>> complex_specgrams = torch.randn(2, freq, 300, dtype=torch.cfloat)
......@@ -734,41 +733,10 @@ def phase_vocoder(
>>> x = phase_vocoder(complex_specgrams, rate, phase_advance)
>>> x.shape # with 231 == ceil(300 / 1.3)
torch.Size([2, 1025, 231])
Example - With Tensor of real dtype and extra dimension for complex field
>>> freq, hop_length = 1025, 512
>>> # (channel, freq, time, complex=2)
>>> complex_specgrams = torch.randn(2, freq, 300, 2)
>>> rate = 1.3 # Speed up by 30%
>>> phase_advance = torch.linspace(
>>> 0, math.pi * hop_length, freq)[..., None]
>>> x = phase_vocoder(complex_specgrams, rate, phase_advance)
>>> x.shape # with 231 == ceil(300 / 1.3)
torch.Size([2, 1025, 231, 2])
"""
if rate == 1.0:
return complex_specgrams
if not complex_specgrams.is_complex():
warnings.warn(
"The support for pseudo complex type in `torchaudio.functional.phase_vocoder` and "
"`torchaudio.transforms.TimeStretch` is now deprecated and will be removed "
"from 0.11 release."
"Please migrate to native complex type by converting the input tensor with "
"`torch.view_as_complex`. "
"Please refer to https://github.com/pytorch/audio/issues/1337 "
"for more details about torchaudio's plan to migrate to native complex type."
)
if complex_specgrams.size(-1) != 2:
raise ValueError(
"complex_specgrams must be either native complex tensors or "
"real valued tensors with shape (..., 2)")
is_complex = complex_specgrams.is_complex()
if not is_complex:
complex_specgrams = torch.view_as_complex(complex_specgrams)
# pack batch
shape = complex_specgrams.size()
complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-2:]))
......@@ -813,9 +781,6 @@ def phase_vocoder(
# unpack batch
complex_specgrams_stretch = complex_specgrams_stretch.reshape(shape[:-2] + complex_specgrams_stretch.shape[1:])
if not is_complex:
return torch.view_as_real(complex_specgrams_stretch)
return complex_specgrams_stretch
......
......@@ -972,8 +972,7 @@ class TimeStretch(torch.nn.Module):
r"""
Args:
complex_specgrams (Tensor):
Either a real tensor of dimension of `(..., freq, num_frame, complex=2)`
or a tensor of dimension `(..., freq, num_frame)` with complex dtype.
A tensor of dimension `(..., freq, num_frame)` with complex dtype.
overriding_rate (float or None, optional): speed up to apply to this batch.
If no rate is passed, use ``self.fixed_rate``. (Default: ``None``)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment