Unverified Commit a7797d5c authored by moto, committed by GitHub

Fix NaN gradient by using native complex abs op (#1013)

parent 6b07bcf8
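For context on the fix (see pytorch/audio#993): the old complex_norm path computed magnitudes as sqrt(real^2 + imag^2) on the view_as_real layout. The derivative of sqrt is infinite at 0, and the chain rule then multiplies inf by 0, so autograd returned nan for every bin where the spectrogram is exactly zero. The native complex abs op instead defines the gradient at z = 0 as 0 (via torch.sgn), which is what the new test below asserts. A minimal repro sketch, not part of this commit, assuming a PyTorch build with complex autograd support (>= 1.8):

import torch

# (real, imag) pairs, all exactly zero -- the worst case for the old path.
xy = torch.zeros(4, 2, requires_grad=True)

# Old path (what complex_norm did): sqrt(re^2 + im^2).
# d sqrt(s)/ds is infinite at s = 0, and inf * 0 -> nan under the chain rule.
mag_old = xy.pow(2).sum(-1).sqrt()
mag_old.sum().backward()
print(xy.grad)  # all nan

xy.grad = None

# New path: native complex abs, whose gradient at z = 0 is defined as 0.
mag_new = torch.view_as_complex(xy).abs()
mag_new.sum().backward()
print(xy.grad)  # all zeros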
 import math
-import unittest
 import torch
 import torchaudio
@@ -8,7 +7,7 @@ from parameterized import parameterized
 import pytest

 from torchaudio_unittest import common_utils
-from .functional_impl import Lfilter
+from .functional_impl import Lfilter, Spectrogram


 class TestLFilterFloat32(Lfilter, common_utils.PytorchTestCase):
@@ -21,6 +20,16 @@ class TestLFilterFloat64(Lfilter, common_utils.PytorchTestCase):
     device = torch.device('cpu')


+class TestSpectrogramFloat32(Spectrogram, common_utils.PytorchTestCase):
+    dtype = torch.float32
+    device = torch.device('cpu')
+
+
+class TestSpectrogramFloat64(Spectrogram, common_utils.PytorchTestCase):
+    dtype = torch.float64
+    device = torch.device('cpu')
+
+
 class TestCreateFBMatrix(common_utils.TorchaudioTestCase):
     def test_no_warning_high_n_freq(self):
         with pytest.warns(None) as w:
...
 import torch

 from torchaudio_unittest import common_utils
-from .functional_impl import Lfilter
+from .functional_impl import Lfilter, Spectrogram


 @common_utils.skipIfNoCuda
@@ -14,3 +14,15 @@ class TestLFilterFloat32(Lfilter, common_utils.PytorchTestCase):
 class TestLFilterFloat64(Lfilter, common_utils.PytorchTestCase):
     dtype = torch.float64
     device = torch.device('cuda')
+
+
+@common_utils.skipIfNoCuda
+class TestSpectrogramFloat32(Spectrogram, common_utils.PytorchTestCase):
+    dtype = torch.float32
+    device = torch.device('cuda')
+
+
+@common_utils.skipIfNoCuda
+class TestSpectrogramFloat64(Spectrogram, common_utils.PytorchTestCase):
+    dtype = torch.float64
+    device = torch.device('cuda')
"""Test defintion common to CPU and CUDA""" """Test defintion common to CPU and CUDA"""
import torch import torch
import torchaudio.functional as F import torchaudio.functional as F
from parameterized import parameterized
from torchaudio_unittest import common_utils from torchaudio_unittest import common_utils
...@@ -29,3 +30,25 @@ class Lfilter(common_utils.TestBaseMixin): ...@@ -29,3 +30,25 @@ class Lfilter(common_utils.TestBaseMixin):
assert output_signal.max() <= 1 assert output_signal.max() <= 1
output_signal = F.lfilter(input_signal, a_coeffs, b_coeffs, clamp=False) output_signal = F.lfilter(input_signal, a_coeffs, b_coeffs, clamp=False)
assert output_signal.max() > 1 assert output_signal.max() > 1
class Spectrogram(common_utils.TestBaseMixin):
@parameterized.expand([(0., ), (1., ), (2., ), (3., )])
def test_grad_at_zero(self, power):
"""The gradient of power spectrogram should not be nan but zero near x=0
https://github.com/pytorch/audio/issues/993
"""
x = torch.zeros(1, 22050, requires_grad=True)
spec = F.spectrogram(
x,
pad=0,
window=None,
n_fft=2048,
hop_length=None,
win_length=None,
power=power,
normalized=False,
)
spec.sum().backward()
assert not x.grad.isnan().sum()
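As a usage note, not part of the diff: the same property can be spot-checked through the public transform API, which routes through F.spectrogram. A hypothetical check, assuming a build that includes this fix:

import torch
import torchaudio

x = torch.zeros(1, 22050, requires_grad=True)
spec = torchaudio.transforms.Spectrogram(n_fft=2048, power=2.0)(x)
spec.sum().backward()
# Before this commit, x.grad was nan everywhere.
assert not x.grad.isnan().any()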
@@ -70,30 +70,29 @@ def spectrogram(
         waveform = waveform.reshape(-1, shape[-1])

     # default values are consistent with librosa.core.spectrum._spectrogram
-    spec_f = torch.view_as_real(
-        torch.stft(
-            input=waveform,
-            n_fft=n_fft,
-            hop_length=hop_length,
-            win_length=win_length,
-            window=window,
-            center=True,
-            pad_mode="reflect",
-            normalized=False,
-            onesided=True,
-            return_complex=True,
-        )
+    spec_f = torch.stft(
+        input=waveform,
+        n_fft=n_fft,
+        hop_length=hop_length,
+        win_length=win_length,
+        window=window,
+        center=True,
+        pad_mode="reflect",
+        normalized=False,
+        onesided=True,
+        return_complex=True,
     )

     # unpack batch
-    spec_f = spec_f.reshape(shape[:-1] + spec_f.shape[-3:])
+    spec_f = spec_f.reshape(shape[:-1] + spec_f.shape[-2:])

     if normalized:
         spec_f /= window.pow(2.).sum().sqrt()
     if power is not None:
-        spec_f = complex_norm(spec_f, power=power)
-
-    return spec_f
+        if power == 1.0:
+            return spec_f.abs()
+        return spec_f.abs().pow(power)
+    return torch.view_as_real(spec_f)


 def griffinlim(
...
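One design note on the new control flow, inferred from the diff: the STFT result now stays a complex tensor internally; magnitudes come from the native abs; the power == 1.0 branch skips a redundant pow kernel; and callers passing power=None still receive the legacy real layout with a trailing (real, imag) dimension via view_as_real. A hypothetical shape check, assuming the keyword signature used in the test above:

import torch
from torchaudio.functional import spectrogram

wave = torch.randn(1, 22050)
win = torch.hann_window(400)

# power=None: legacy real layout, trailing dim holds (real, imag).
raw = spectrogram(wave, pad=0, window=win, n_fft=400,
                  hop_length=200, win_length=400,
                  power=None, normalized=False)
print(raw.shape)  # torch.Size([1, 201, 111, 2])

# power=2.0: real-valued power spectrogram, no trailing pair dimension.
pw = spectrogram(wave, pad=0, window=win, n_fft=400,
                 hop_length=200, win_length=400,
                 power=2.0, normalized=False)
print(pw.shape)   # torch.Size([1, 201, 111])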