Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
78c3480e
"megatron/vscode:/vscode.git/clone" did not exist on "270d6412bfe0b96af13ef861e47a3e8b45dfecee"
Unverified
Commit
78c3480e
authored
Apr 09, 2021
by
moto
Committed by
GitHub
Apr 09, 2021
Browse files
Adopt native complex dtype in griffnlim (#1368)
parent
35d68fdd
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
59 additions
and
32 deletions
+59
-32
test/torchaudio_unittest/transforms/autograd_test_impl.py
test/torchaudio_unittest/transforms/autograd_test_impl.py
+26
-6
torchaudio/functional/functional.py
torchaudio/functional/functional.py
+33
-26
No files found.
test/torchaudio_unittest/transforms/autograd_test_impl.py
View file @
78c3480e
...
...
@@ -8,9 +8,23 @@ import torchaudio.transforms as T
from
torchaudio_unittest.common_utils
import
(
TestBaseMixin
,
get_whitenoise
,
get_spectrogram
,
nested_params
,
)
class
_DeterministicWrapper
(
torch
.
nn
.
Module
):
"""Helper transform wrapper to make the given transform deterministic"""
def
__init__
(
self
,
transform
,
seed
=
0
):
super
().
__init__
()
self
.
seed
=
seed
self
.
transform
=
transform
def
forward
(
self
,
input
:
torch
.
Tensor
):
torch
.
random
.
manual_seed
(
self
.
seed
)
return
self
.
transform
(
input
)
class
AutogradTestMixin
(
TestBaseMixin
):
def
assert_grad
(
self
,
...
...
@@ -65,14 +79,20 @@ class AutogradTestMixin(TestBaseMixin):
waveform
=
get_whitenoise
(
sample_rate
=
sample_rate
,
duration
=
0.05
,
n_channels
=
2
)
self
.
assert_grad
(
transform
,
[
waveform
],
nondet_tol
=
1e-10
)
@
parameterized
.
expand
([(
0
,
),
(
0.99
,
)])
def
test_griffinlim
(
self
,
momentum
):
@
nested_params
(
[
0
,
0.99
],
[
False
,
True
],
)
def
test_griffinlim
(
self
,
momentum
,
rand_init
):
n_fft
=
400
n_frames
=
5
power
=
1
n_iter
=
3
spec
=
torch
.
rand
(
n_fft
//
2
+
1
,
n_frames
)
*
n_fft
transform
=
T
.
GriffinLim
(
n_fft
=
n_fft
,
n_iter
=
n_iter
,
momentum
=
momentum
,
rand_init
=
False
)
self
.
assert_grad
(
transform
,
[
spec
],
nondet_tol
=
1e-10
)
spec
=
get_spectrogram
(
get_whitenoise
(
sample_rate
=
8000
,
duration
=
0.05
,
n_channels
=
2
),
n_fft
=
n_fft
,
power
=
power
)
transform
=
_DeterministicWrapper
(
T
.
GriffinLim
(
n_fft
=
n_fft
,
n_iter
=
n_iter
,
momentum
=
momentum
,
rand_init
=
rand_init
,
power
=
power
))
self
.
assert_grad
(
transform
,
[
spec
])
@
parameterized
.
expand
([(
False
,
),
(
True
,
)])
def
test_mfcc
(
self
,
log_mels
):
...
...
torchaudio/functional/functional.py
View file @
78c3480e
...
...
@@ -125,6 +125,16 @@ def spectrogram(
return
spec_f
def
_get_complex_dtype
(
real_dtype
:
torch
.
dtype
):
if
real_dtype
==
torch
.
double
:
return
torch
.
cdouble
if
real_dtype
==
torch
.
float
:
return
torch
.
cfloat
if
real_dtype
==
torch
.
half
:
return
torch
.
complex32
raise
ValueError
(
f
'Unexpected dtype
{
real_dtype
}
'
)
def
griffinlim
(
specgram
:
Tensor
,
window
:
Tensor
,
...
...
@@ -180,23 +190,19 @@ def griffinlim(
specgram
=
specgram
.
pow
(
1
/
power
)
# randomly initialize the phase
batch
,
freq
,
frames
=
specgram
.
size
()
# initialize the phase
if
rand_init
:
angles
=
2
*
math
.
pi
*
torch
.
rand
(
batch
,
freq
,
frames
)
angles
=
torch
.
rand
(
specgram
.
size
(),
dtype
=
_get_complex_dtype
(
specgram
.
dtype
),
device
=
specgram
.
device
)
else
:
angles
=
torch
.
zeros
(
batch
,
freq
,
frames
)
angles
=
torch
.
stack
([
angles
.
cos
(),
angles
.
sin
()],
dim
=-
1
)
\
.
to
(
dtype
=
specgram
.
dtype
,
device
=
specgram
.
device
)
specgram
=
specgram
.
unsqueeze
(
-
1
).
expand_as
(
angles
)
angles
=
torch
.
full
(
specgram
.
size
(),
1
,
dtype
=
_get_complex_dtype
(
specgram
.
dtype
),
device
=
specgram
.
device
)
# And initialize the previous iterate to 0
rebuilt
=
torch
.
tensor
(
0.
)
tprev
=
torch
.
tensor
(
0.
,
dtype
=
specgram
.
dtype
,
device
=
specgram
.
device
)
for
_
in
range
(
n_iter
):
# Store the previous iterate
tprev
=
rebuilt
# Invert with our current estimate of the phases
inverse
=
torch
.
istft
(
specgram
*
angles
,
n_fft
=
n_fft
,
...
...
@@ -206,8 +212,7 @@ def griffinlim(
length
=
length
)
# Rebuild the spectrogram
rebuilt
=
torch
.
view_as_real
(
torch
.
stft
(
rebuilt
=
torch
.
stft
(
input
=
inverse
,
n_fft
=
n_fft
,
hop_length
=
hop_length
,
...
...
@@ -219,13 +224,15 @@ def griffinlim(
onesided
=
True
,
return_complex
=
True
,
)
)
# Update our phase estimates
angles
=
rebuilt
if
momentum
:
angles
=
angles
-
tprev
.
mul_
(
momentum
/
(
1
+
momentum
))
angles
=
angles
.
div
(
complex_norm
(
angles
).
add
(
1e-16
).
unsqueeze
(
-
1
).
expand_as
(
angles
))
angles
=
angles
.
div
(
angles
.
abs
().
add
(
1e-16
))
# Store the previous iterate
tprev
=
rebuilt
# Return the final phase estimates
waveform
=
torch
.
istft
(
specgram
*
angles
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment