Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
9835db75
Unverified
Commit
9835db75
authored
May 13, 2020
by
moto
Committed by
GitHub
May 13, 2020
Browse files
Use istft from torch (#523)
parent
2dd04029
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
17 additions
and
106 deletions
+17
-106
test/test_functional.py
test/test_functional.py
+4
-4
torchaudio/functional.py
torchaudio/functional.py
+13
-102
No files found.
test/test_functional.py
View file @
9835db75
...
...
@@ -172,7 +172,7 @@ class TestIstft(unittest.TestCase):
def
test_istft_requires_overlap_windows
(
self
):
# the window is size 1 but it hops 20 so there is a gap which throw an error
stft
=
torch
.
zeros
((
3
,
5
,
2
))
self
.
assertRaises
(
Assertion
Error
,
torchaudio
.
functional
.
istft
,
stft
,
n_fft
=
4
,
self
.
assertRaises
(
Runtime
Error
,
torchaudio
.
functional
.
istft
,
stft
,
n_fft
=
4
,
hop_length
=
20
,
win_length
=
1
,
window
=
torch
.
ones
(
1
))
def
test_istft_requires_nola
(
self
):
...
...
@@ -192,11 +192,11 @@ class TestIstft(unittest.TestCase):
# A window of ones meets NOLA but a window of zeros does not. This should
# throw an error.
torchaudio
.
functional
.
istft
(
stft
,
**
kwargs_ok
)
self
.
assertRaises
(
Assertion
Error
,
torchaudio
.
functional
.
istft
,
stft
,
**
kwargs_not_ok
)
self
.
assertRaises
(
Runtime
Error
,
torchaudio
.
functional
.
istft
,
stft
,
**
kwargs_not_ok
)
def
test_istft_requires_non_empty
(
self
):
self
.
assertRaises
(
Assertion
Error
,
torchaudio
.
functional
.
istft
,
torch
.
zeros
((
3
,
0
,
2
)),
2
)
self
.
assertRaises
(
Assertion
Error
,
torchaudio
.
functional
.
istft
,
torch
.
zeros
((
0
,
3
,
2
)),
2
)
self
.
assertRaises
(
Runtime
Error
,
torchaudio
.
functional
.
istft
,
torch
.
zeros
((
3
,
0
,
2
)),
2
)
self
.
assertRaises
(
Runtime
Error
,
torchaudio
.
functional
.
istft
,
torch
.
zeros
((
0
,
3
,
2
)),
2
)
def
_test_istft_of_sine
(
self
,
amplitude
,
L
,
n
):
# stft of amplitude*sin(2*pi/L*n*x) with the hop length and window size equaling L
...
...
torchaudio/functional.py
View file @
9835db75
...
...
@@ -2,6 +2,7 @@
import
math
from
typing
import
Optional
,
Tuple
import
warnings
import
torch
from
torch
import
Tensor
...
...
@@ -49,7 +50,7 @@ def istft(
win_length
:
Optional
[
int
]
=
None
,
window
:
Optional
[
Tensor
]
=
None
,
center
:
bool
=
True
,
pad_mode
:
str
=
"reflect"
,
pad_mode
:
Optional
[
str
]
=
None
,
normalized
:
bool
=
False
,
onesided
:
bool
=
True
,
length
:
Optional
[
int
]
=
None
,
...
...
@@ -94,8 +95,7 @@ def istft(
center (bool, optional): Whether ``input`` was padded on both sides so
that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
(Default: ``True``)
pad_mode (str, optional): Controls the padding method used when ``center`` is True. (Default:
``"reflect"``)
pad_mode: This argument was ignored and to be removed.
normalized (bool, optional): Whether the STFT was normalized. (Default: ``False``)
onesided (bool, optional): Whether the STFT is onesided. (Default: ``True``)
length (int or None, optional): The amount to trim the signal by (i.e. the
...
...
@@ -104,105 +104,16 @@ def istft(
Returns:
Tensor: Least squares estimation of the original signal of size (..., signal_length)
"""
stft_matrix_dim
=
stft_matrix
.
dim
()
assert
3
<=
stft_matrix_dim
,
"Incorrect stft dimension: %d"
%
(
stft_matrix_dim
)
assert
stft_matrix
.
numel
()
>
0
if
stft_matrix_dim
==
3
:
# add a channel dimension
stft_matrix
=
stft_matrix
.
unsqueeze
(
0
)
# pack batch
shape
=
stft_matrix
.
size
()
stft_matrix
=
stft_matrix
.
reshape
(
-
1
,
shape
[
-
3
],
shape
[
-
2
],
shape
[
-
1
])
dtype
=
stft_matrix
.
dtype
device
=
stft_matrix
.
device
fft_size
=
stft_matrix
.
size
(
1
)
assert
(
onesided
and
n_fft
//
2
+
1
==
fft_size
)
or
(
not
onesided
and
n_fft
==
fft_size
),
(
"one_sided implies that n_fft // 2 + 1 == fft_size and not one_sided implies n_fft == fft_size. "
+
"Given values were onesided: %s, n_fft: %d, fft_size: %d"
%
(
"True"
if
onesided
else
False
,
n_fft
,
fft_size
)
)
# use stft defaults for Optionals
if
win_length
is
None
:
win_length
=
n_fft
if
hop_length
is
None
:
hop_length
=
int
(
win_length
//
4
)
# There must be overlap
assert
0
<
hop_length
<=
win_length
assert
0
<
win_length
<=
n_fft
if
window
is
None
:
window
=
torch
.
ones
(
win_length
,
device
=
device
,
dtype
=
dtype
)
assert
window
.
dim
()
==
1
and
window
.
size
(
0
)
==
win_length
if
win_length
!=
n_fft
:
# center window with pad left and right zeros
left
=
(
n_fft
-
win_length
)
//
2
window
=
torch
.
nn
.
functional
.
pad
(
window
,
(
left
,
n_fft
-
win_length
-
left
))
assert
window
.
size
(
0
)
==
n_fft
# win_length and n_fft are synonymous from here on
stft_matrix
=
stft_matrix
.
transpose
(
1
,
2
)
# size (channel, n_frame, fft_size, 2)
stft_matrix
=
torch
.
irfft
(
stft_matrix
,
1
,
normalized
,
onesided
,
signal_sizes
=
(
n_fft
,)
)
# size (channel, n_frame, n_fft)
assert
stft_matrix
.
size
(
2
)
==
n_fft
n_frame
=
stft_matrix
.
size
(
1
)
ytmp
=
stft_matrix
*
window
.
view
(
1
,
1
,
n_fft
)
# size (channel, n_frame, n_fft)
# each column of a channel is a frame which needs to be overlap added at the right place
ytmp
=
ytmp
.
transpose
(
1
,
2
)
# size (channel, n_fft, n_frame)
# this does overlap add where the frames of ytmp are added such that the i'th frame of
# ytmp is added starting at i*hop_length in the output
y
=
torch
.
nn
.
functional
.
fold
(
ytmp
,
(
1
,
(
n_frame
-
1
)
*
hop_length
+
n_fft
),
(
1
,
n_fft
),
stride
=
(
1
,
hop_length
)
).
squeeze
(
2
)
# do the same for the window function
window_sq
=
(
window
.
pow
(
2
).
view
(
n_fft
,
1
).
repeat
((
1
,
n_frame
)).
unsqueeze
(
0
)
)
# size (1, n_fft, n_frame)
window_envelop
=
torch
.
nn
.
functional
.
fold
(
window_sq
,
(
1
,
(
n_frame
-
1
)
*
hop_length
+
n_fft
),
(
1
,
n_fft
),
stride
=
(
1
,
hop_length
)
).
squeeze
(
2
)
# size (1, 1, expected_signal_len)
expected_signal_len
=
n_fft
+
hop_length
*
(
n_frame
-
1
)
assert
y
.
size
(
2
)
==
expected_signal_len
assert
window_envelop
.
size
(
2
)
==
expected_signal_len
half_n_fft
=
n_fft
//
2
# we need to trim the front padding away if center
start
=
half_n_fft
if
center
else
0
end
=
-
half_n_fft
if
length
is
None
else
start
+
length
y
=
y
[:,
:,
start
:
end
]
window_envelop
=
window_envelop
[:,
:,
start
:
end
]
# check NOLA non-zero overlap condition
window_envelop_lowest
=
window_envelop
.
abs
().
min
()
assert
window_envelop_lowest
>
1e-11
,
"window overlap add min: %f"
%
(
window_envelop_lowest
)
y
=
(
y
/
window_envelop
).
squeeze
(
1
)
# size (channel, expected_signal_len)
# unpack batch
y
=
y
.
reshape
(
shape
[:
-
3
]
+
y
.
shape
[
-
1
:])
if
stft_matrix_dim
==
3
:
# remove the channel dimension
y
=
y
.
squeeze
(
0
)
return
y
warnings
.
warn
(
'istft has been moved to PyTorch and will be removed from torchaudio, '
'please use torch.istft instead.'
)
if
pad_mode
is
not
None
:
warnings
.
warn
(
'The parameter `pad_mode` was ignored in isftft, and is thus being deprecated. '
'Please set `pad_mode` to None to suppress this warning.'
)
return
torch
.
istft
(
input
=
stft_matrix
,
n_fft
=
n_fft
,
hop_length
=
hop_length
,
win_length
=
win_length
,
window
=
window
,
center
=
center
,
normalized
=
normalized
,
onesided
=
onesided
,
length
=
length
)
def
spectrogram
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment