Unverified commit 78c3480e, authored by moto and committed by GitHub
Browse files

Adopt native complex dtype in griffinlim (#1368)

parent 35d68fdd
...@@ -8,9 +8,23 @@ import torchaudio.transforms as T ...@@ -8,9 +8,23 @@ import torchaudio.transforms as T
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
TestBaseMixin, TestBaseMixin,
get_whitenoise, get_whitenoise,
get_spectrogram,
nested_params,
) )
class _DeterministicWrapper(torch.nn.Module):
"""Helper transform wrapper to make the given transform deterministic"""
def __init__(self, transform, seed=0):
super().__init__()
self.seed = seed
self.transform = transform
def forward(self, input: torch.Tensor):
torch.random.manual_seed(self.seed)
return self.transform(input)
class AutogradTestMixin(TestBaseMixin): class AutogradTestMixin(TestBaseMixin):
def assert_grad( def assert_grad(
self, self,
...@@ -65,14 +79,20 @@ class AutogradTestMixin(TestBaseMixin): ...@@ -65,14 +79,20 @@ class AutogradTestMixin(TestBaseMixin):
waveform = get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2) waveform = get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2)
self.assert_grad(transform, [waveform], nondet_tol=1e-10) self.assert_grad(transform, [waveform], nondet_tol=1e-10)
@parameterized.expand([(0, ), (0.99, )]) @nested_params(
def test_griffinlim(self, momentum): [0, 0.99],
[False, True],
)
def test_griffinlim(self, momentum, rand_init):
n_fft = 400 n_fft = 400
n_frames = 5 power = 1
n_iter = 3 n_iter = 3
spec = torch.rand(n_fft // 2 + 1, n_frames) * n_fft spec = get_spectrogram(
transform = T.GriffinLim(n_fft=n_fft, n_iter=n_iter, momentum=momentum, rand_init=False) get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2),
self.assert_grad(transform, [spec], nondet_tol=1e-10) n_fft=n_fft, power=power)
transform = _DeterministicWrapper(
T.GriffinLim(n_fft=n_fft, n_iter=n_iter, momentum=momentum, rand_init=rand_init, power=power))
self.assert_grad(transform, [spec])
@parameterized.expand([(False, ), (True, )]) @parameterized.expand([(False, ), (True, )])
def test_mfcc(self, log_mels): def test_mfcc(self, log_mels):
......
...@@ -125,6 +125,16 @@ def spectrogram( ...@@ -125,6 +125,16 @@ def spectrogram(
return spec_f return spec_f
def _get_complex_dtype(real_dtype: torch.dtype):
if real_dtype == torch.double:
return torch.cdouble
if real_dtype == torch.float:
return torch.cfloat
if real_dtype == torch.half:
return torch.complex32
raise ValueError(f'Unexpected dtype {real_dtype}')
def griffinlim( def griffinlim(
specgram: Tensor, specgram: Tensor,
window: Tensor, window: Tensor,
...@@ -180,23 +190,19 @@ def griffinlim( ...@@ -180,23 +190,19 @@ def griffinlim(
specgram = specgram.pow(1 / power) specgram = specgram.pow(1 / power)
# randomly initialize the phase # initialize the phase
batch, freq, frames = specgram.size()
if rand_init: if rand_init:
angles = 2 * math.pi * torch.rand(batch, freq, frames) angles = torch.rand(
specgram.size(),
dtype=_get_complex_dtype(specgram.dtype), device=specgram.device)
else: else:
angles = torch.zeros(batch, freq, frames) angles = torch.full(
angles = torch.stack([angles.cos(), angles.sin()], dim=-1) \ specgram.size(), 1,
.to(dtype=specgram.dtype, device=specgram.device) dtype=_get_complex_dtype(specgram.dtype), device=specgram.device)
specgram = specgram.unsqueeze(-1).expand_as(angles)
# And initialize the previous iterate to 0 # And initialize the previous iterate to 0
rebuilt = torch.tensor(0.) tprev = torch.tensor(0., dtype=specgram.dtype, device=specgram.device)
for _ in range(n_iter): for _ in range(n_iter):
# Store the previous iterate
tprev = rebuilt
# Invert with our current estimate of the phases # Invert with our current estimate of the phases
inverse = torch.istft(specgram * angles, inverse = torch.istft(specgram * angles,
n_fft=n_fft, n_fft=n_fft,
...@@ -206,26 +212,27 @@ def griffinlim( ...@@ -206,26 +212,27 @@ def griffinlim(
length=length) length=length)
# Rebuild the spectrogram # Rebuild the spectrogram
rebuilt = torch.view_as_real( rebuilt = torch.stft(
torch.stft( input=inverse,
input=inverse, n_fft=n_fft,
n_fft=n_fft, hop_length=hop_length,
hop_length=hop_length, win_length=win_length,
win_length=win_length, window=window,
window=window, center=True,
center=True, pad_mode='reflect',
pad_mode='reflect', normalized=False,
normalized=False, onesided=True,
onesided=True, return_complex=True,
return_complex=True,
)
) )
# Update our phase estimates # Update our phase estimates
angles = rebuilt angles = rebuilt
if momentum: if momentum:
angles = angles - tprev.mul_(momentum / (1 + momentum)) angles = angles - tprev.mul_(momentum / (1 + momentum))
angles = angles.div(complex_norm(angles).add(1e-16).unsqueeze(-1).expand_as(angles)) angles = angles.div(angles.abs().add(1e-16))
# Store the previous iterate
tprev = rebuilt
# Return the final phase estimates # Return the final phase estimates
waveform = torch.istft(specgram * angles, waveform = torch.istft(specgram * angles,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment