Unverified Commit a7797d5c authored by moto's avatar moto Committed by GitHub
Browse files

Fix nan gradient by using native complex abs op (#1013)

parent 6b07bcf8
import math
import unittest
import torch
import torchaudio
......@@ -8,7 +7,7 @@ from parameterized import parameterized
import pytest
from torchaudio_unittest import common_utils
from .functional_impl import Lfilter
from .functional_impl import Lfilter, Spectrogram
class TestLFilterFloat32(Lfilter, common_utils.PytorchTestCase):
......@@ -21,6 +20,16 @@ class TestLFilterFloat64(Lfilter, common_utils.PytorchTestCase):
device = torch.device('cpu')
class TestSpectrogramFloat32(Spectrogram, common_utils.PytorchTestCase):
    """Bind the shared ``Spectrogram`` test cases to float32 on the CPU."""

    device = torch.device('cpu')
    dtype = torch.float32
class TestSpectrogramFloat64(Spectrogram, common_utils.PytorchTestCase):
    """Bind the shared ``Spectrogram`` test cases to float64 on the CPU."""

    device = torch.device('cpu')
    dtype = torch.float64
class TestCreateFBMatrix(common_utils.TorchaudioTestCase):
def test_no_warning_high_n_freq(self):
with pytest.warns(None) as w:
......
import torch
from torchaudio_unittest import common_utils
from .functional_impl import Lfilter
from .functional_impl import Lfilter, Spectrogram
@common_utils.skipIfNoCuda
......@@ -14,3 +14,15 @@ class TestLFilterFloat32(Lfilter, common_utils.PytorchTestCase):
class TestLFilterFloat64(Lfilter, common_utils.PytorchTestCase):
dtype = torch.float64
device = torch.device('cuda')
@common_utils.skipIfNoCuda
class TestSpectrogramFloat32(Spectrogram, common_utils.PytorchTestCase):
    """Bind the shared ``Spectrogram`` test cases to float32 on a CUDA device."""

    device = torch.device('cuda')
    dtype = torch.float32
@common_utils.skipIfNoCuda
class TestSpectrogramFloat64(Spectrogram, common_utils.PytorchTestCase):
    """Bind the shared ``Spectrogram`` test cases to float64 on a CUDA device."""

    device = torch.device('cuda')
    dtype = torch.float64
"""Test defintion common to CPU and CUDA"""
import torch
import torchaudio.functional as F
from parameterized import parameterized
from torchaudio_unittest import common_utils
......@@ -29,3 +30,25 @@ class Lfilter(common_utils.TestBaseMixin):
assert output_signal.max() <= 1
output_signal = F.lfilter(input_signal, a_coeffs, b_coeffs, clamp=False)
assert output_signal.max() > 1
class Spectrogram(common_utils.TestBaseMixin):
    """Spectrogram test cases shared by the CPU and CUDA test modules.

    Concrete subclasses supply ``dtype`` and ``device`` class attributes;
    the input tensor is created with them so that each subclass exercises
    its own dtype/device combination.
    """

    @parameterized.expand([(0., ), (1., ), (2., ), (3., )])
    def test_grad_at_zero(self, power):
        """The gradient of power spectrogram should not be nan but zero near x=0

        https://github.com/pytorch/audio/issues/993
        """
        # Build the all-zero input on the dtype/device selected by the concrete
        # test class; otherwise every variant would run on the default
        # dtype and the CPU.
        x = torch.zeros(
            1, 22050,
            dtype=self.dtype,
            device=self.device,
            requires_grad=True,
        )
        spec = F.spectrogram(
            x,
            pad=0,
            window=None,
            n_fft=2048,
            hop_length=None,
            win_length=None,
            power=power,
            normalized=False,
        )
        # Reduce to a scalar so that backward() can be called without arguments.
        spec.sum().backward()
        assert not torch.isnan(x.grad).any()
......@@ -70,8 +70,7 @@ def spectrogram(
waveform = waveform.reshape(-1, shape[-1])
# default values are consistent with librosa.core.spectrum._spectrogram
spec_f = torch.view_as_real(
torch.stft(
spec_f = torch.stft(
input=waveform,
n_fft=n_fft,
hop_length=hop_length,
......@@ -83,17 +82,17 @@ def spectrogram(
onesided=True,
return_complex=True,
)
)
# unpack batch
spec_f = spec_f.reshape(shape[:-1] + spec_f.shape[-3:])
spec_f = spec_f.reshape(shape[:-1] + spec_f.shape[-2:])
if normalized:
spec_f /= window.pow(2.).sum().sqrt()
if power is not None:
spec_f = complex_norm(spec_f, power=power)
return spec_f
if power == 1.0:
return spec_f.abs()
return spec_f.abs().pow(power)
return torch.view_as_real(spec_f)
def griffinlim(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment