Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
a7797d5c
Unverified
Commit
a7797d5c
authored
Jan 06, 2021
by
moto
Committed by
GitHub
Jan 06, 2021
Browse files
Fix nan gradient by using native complex abs op (#1013)
parent
6b07bcf8
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
63 additions
and
20 deletions
+63
-20
test/torchaudio_unittest/functional/functional_cpu_test.py
test/torchaudio_unittest/functional/functional_cpu_test.py
+11
-2
test/torchaudio_unittest/functional/functional_cuda_test.py
test/torchaudio_unittest/functional/functional_cuda_test.py
+13
-1
test/torchaudio_unittest/functional/functional_impl.py
test/torchaudio_unittest/functional/functional_impl.py
+23
-0
torchaudio/functional/functional.py
torchaudio/functional/functional.py
+16
-17
No files found.
test/torchaudio_unittest/functional/functional_cpu_test.py
View file @
a7797d5c
import
math
import
unittest
import
torch
import
torchaudio
...
...
@@ -8,7 +7,7 @@ from parameterized import parameterized
import
pytest
from
torchaudio_unittest
import
common_utils
from
.functional_impl
import
Lfilter
from
.functional_impl
import
Lfilter
,
Spectrogram
class
TestLFilterFloat32
(
Lfilter
,
common_utils
.
PytorchTestCase
):
...
...
@@ -21,6 +20,16 @@ class TestLFilterFloat64(Lfilter, common_utils.PytorchTestCase):
device
=
torch
.
device
(
'cpu'
)
class
TestSpectrogramFloat32
(
Spectrogram
,
common_utils
.
PytorchTestCase
):
dtype
=
torch
.
float32
device
=
torch
.
device
(
'cpu'
)
class
TestSpectrogramFloat64
(
Spectrogram
,
common_utils
.
PytorchTestCase
):
dtype
=
torch
.
float64
device
=
torch
.
device
(
'cpu'
)
class
TestCreateFBMatrix
(
common_utils
.
TorchaudioTestCase
):
def
test_no_warning_high_n_freq
(
self
):
with
pytest
.
warns
(
None
)
as
w
:
...
...
test/torchaudio_unittest/functional/functional_cuda_test.py
View file @
a7797d5c
import
torch
from
torchaudio_unittest
import
common_utils
from
.functional_impl
import
Lfilter
from
.functional_impl
import
Lfilter
,
Spectrogram
@
common_utils
.
skipIfNoCuda
...
...
@@ -14,3 +14,15 @@ class TestLFilterFloat32(Lfilter, common_utils.PytorchTestCase):
class
TestLFilterFloat64
(
Lfilter
,
common_utils
.
PytorchTestCase
):
dtype
=
torch
.
float64
device
=
torch
.
device
(
'cuda'
)
@
common_utils
.
skipIfNoCuda
class
TestSpectrogramFloat32
(
Spectrogram
,
common_utils
.
PytorchTestCase
):
dtype
=
torch
.
float32
device
=
torch
.
device
(
'cuda'
)
@
common_utils
.
skipIfNoCuda
class
TestSpectrogramFloat64
(
Spectrogram
,
common_utils
.
PytorchTestCase
):
dtype
=
torch
.
float64
device
=
torch
.
device
(
'cuda'
)
test/torchaudio_unittest/functional/functional_impl.py
View file @
a7797d5c
"""Test definition common to CPU and CUDA"""
import
torch
import
torchaudio.functional
as
F
from
parameterized
import
parameterized
from
torchaudio_unittest
import
common_utils
...
...
@@ -29,3 +30,25 @@ class Lfilter(common_utils.TestBaseMixin):
assert
output_signal
.
max
()
<=
1
output_signal
=
F
.
lfilter
(
input_signal
,
a_coeffs
,
b_coeffs
,
clamp
=
False
)
assert
output_signal
.
max
()
>
1
class
Spectrogram
(
common_utils
.
TestBaseMixin
):
@
parameterized
.
expand
([(
0.
,
),
(
1.
,
),
(
2.
,
),
(
3.
,
)])
def
test_grad_at_zero
(
self
,
power
):
"""The gradient of the power spectrogram should be zero, not NaN, near x=0
https://github.com/pytorch/audio/issues/993
"""
x
=
torch
.
zeros
(
1
,
22050
,
requires_grad
=
True
)
spec
=
F
.
spectrogram
(
x
,
pad
=
0
,
window
=
None
,
n_fft
=
2048
,
hop_length
=
None
,
win_length
=
None
,
power
=
power
,
normalized
=
False
,
)
spec
.
sum
().
backward
()
assert
not
x
.
grad
.
isnan
().
sum
()
torchaudio/functional/functional.py
View file @
a7797d5c
...
...
@@ -70,8 +70,7 @@ def spectrogram(
waveform
=
waveform
.
reshape
(
-
1
,
shape
[
-
1
])
# default values are consistent with librosa.core.spectrum._spectrogram
spec_f
=
torch
.
view_as_real
(
torch
.
stft
(
spec_f
=
torch
.
stft
(
input
=
waveform
,
n_fft
=
n_fft
,
hop_length
=
hop_length
,
...
...
@@ -83,17 +82,17 @@ def spectrogram(
onesided
=
True
,
return_complex
=
True
,
)
)
# unpack batch
spec_f
=
spec_f
.
reshape
(
shape
[:
-
1
]
+
spec_f
.
shape
[
-
3
:])
spec_f
=
spec_f
.
reshape
(
shape
[:
-
1
]
+
spec_f
.
shape
[
-
2
:])
if
normalized
:
spec_f
/=
window
.
pow
(
2.
).
sum
().
sqrt
()
if
power
is
not
None
:
spec_f
=
complex_norm
(
spec_f
,
power
=
power
)
return
spec_f
if
power
==
1.0
:
return
spec_f
.
abs
()
return
spec_f
.
abs
().
pow
(
power
)
return
torch
.
view_as_real
(
spec_f
)
def
griffinlim
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment