Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
9835db75
"vscode:/vscode.git/clone" did not exist on "718f25ae6e50de0ed9efbbd895c0cb3bb503abd6"
Unverified
Commit
9835db75
authored
May 13, 2020
by
moto
Committed by
GitHub
May 13, 2020
Browse files
Use istft from torch (#523)
parent
2dd04029
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
17 additions
and
106 deletions
+17
-106
test/test_functional.py
test/test_functional.py
+4
-4
torchaudio/functional.py
torchaudio/functional.py
+13
-102
No files found.
test/test_functional.py
View file @
9835db75
...
...
@@ -172,7 +172,7 @@ class TestIstft(unittest.TestCase):
def
test_istft_requires_overlap_windows
(
self
):
# the window is size 1 but it hops 20, so there is a gap, which throws an error
stft
=
torch
.
zeros
((
3
,
5
,
2
))
self
.
assertRaises
(
Assertion
Error
,
torchaudio
.
functional
.
istft
,
stft
,
n_fft
=
4
,
self
.
assertRaises
(
Runtime
Error
,
torchaudio
.
functional
.
istft
,
stft
,
n_fft
=
4
,
hop_length
=
20
,
win_length
=
1
,
window
=
torch
.
ones
(
1
))
def
test_istft_requires_nola
(
self
):
...
...
@@ -192,11 +192,11 @@ class TestIstft(unittest.TestCase):
# A window of ones meets NOLA but a window of zeros does not. This should
# throw an error.
torchaudio
.
functional
.
istft
(
stft
,
**
kwargs_ok
)
self
.
assertRaises
(
Assertion
Error
,
torchaudio
.
functional
.
istft
,
stft
,
**
kwargs_not_ok
)
self
.
assertRaises
(
Runtime
Error
,
torchaudio
.
functional
.
istft
,
stft
,
**
kwargs_not_ok
)
def
test_istft_requires_non_empty
(
self
):
self
.
assertRaises
(
Assertion
Error
,
torchaudio
.
functional
.
istft
,
torch
.
zeros
((
3
,
0
,
2
)),
2
)
self
.
assertRaises
(
Assertion
Error
,
torchaudio
.
functional
.
istft
,
torch
.
zeros
((
0
,
3
,
2
)),
2
)
self
.
assertRaises
(
Runtime
Error
,
torchaudio
.
functional
.
istft
,
torch
.
zeros
((
3
,
0
,
2
)),
2
)
self
.
assertRaises
(
Runtime
Error
,
torchaudio
.
functional
.
istft
,
torch
.
zeros
((
0
,
3
,
2
)),
2
)
def
_test_istft_of_sine
(
self
,
amplitude
,
L
,
n
):
# stft of amplitude*sin(2*pi/L*n*x) with the hop length and window size equaling L
...
...
torchaudio/functional.py
View file @
9835db75
...
...
@@ -2,6 +2,7 @@
import
math
from
typing
import
Optional
,
Tuple
import
warnings
import
torch
from
torch
import
Tensor
...
...
@@ -49,7 +50,7 @@ def istft(
win_length
:
Optional
[
int
]
=
None
,
window
:
Optional
[
Tensor
]
=
None
,
center
:
bool
=
True
,
pad_mode
:
str
=
"reflect"
,
pad_mode
:
Optional
[
str
]
=
None
,
normalized
:
bool
=
False
,
onesided
:
bool
=
True
,
length
:
Optional
[
int
]
=
None
,
...
...
@@ -94,8 +95,7 @@ def istft(
center (bool, optional): Whether ``input`` was padded on both sides so
that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
(Default: ``True``)
pad_mode (str, optional): Controls the padding method used when ``center`` is True. (Default:
``"reflect"``)
pad_mode: This argument was ignored and is to be removed.
normalized (bool, optional): Whether the STFT was normalized. (Default: ``False``)
onesided (bool, optional): Whether the STFT is onesided. (Default: ``True``)
length (int or None, optional): The amount to trim the signal by (i.e. the
...
...
@@ -104,105 +104,16 @@ def istft(
Returns:
Tensor: Least squares estimation of the original signal of size (..., signal_length)
"""
stft_matrix_dim
=
stft_matrix
.
dim
()
assert
3
<=
stft_matrix_dim
,
"Incorrect stft dimension: %d"
%
(
stft_matrix_dim
)
assert
stft_matrix
.
numel
()
>
0
if
stft_matrix_dim
==
3
:
# add a channel dimension
stft_matrix
=
stft_matrix
.
unsqueeze
(
0
)
# pack batch
shape
=
stft_matrix
.
size
()
stft_matrix
=
stft_matrix
.
reshape
(
-
1
,
shape
[
-
3
],
shape
[
-
2
],
shape
[
-
1
])
dtype
=
stft_matrix
.
dtype
device
=
stft_matrix
.
device
fft_size
=
stft_matrix
.
size
(
1
)
assert
(
onesided
and
n_fft
//
2
+
1
==
fft_size
)
or
(
not
onesided
and
n_fft
==
fft_size
),
(
"one_sided implies that n_fft // 2 + 1 == fft_size and not one_sided implies n_fft == fft_size. "
+
"Given values were onesided: %s, n_fft: %d, fft_size: %d"
%
(
"True"
if
onesided
else
False
,
n_fft
,
fft_size
)
)
# use stft defaults for Optionals
if
win_length
is
None
:
win_length
=
n_fft
if
hop_length
is
None
:
hop_length
=
int
(
win_length
//
4
)
# There must be overlap
assert
0
<
hop_length
<=
win_length
assert
0
<
win_length
<=
n_fft
if
window
is
None
:
window
=
torch
.
ones
(
win_length
,
device
=
device
,
dtype
=
dtype
)
assert
window
.
dim
()
==
1
and
window
.
size
(
0
)
==
win_length
if
win_length
!=
n_fft
:
# center window with pad left and right zeros
left
=
(
n_fft
-
win_length
)
//
2
window
=
torch
.
nn
.
functional
.
pad
(
window
,
(
left
,
n_fft
-
win_length
-
left
))
assert
window
.
size
(
0
)
==
n_fft
# win_length and n_fft are synonymous from here on
stft_matrix
=
stft_matrix
.
transpose
(
1
,
2
)
# size (channel, n_frame, fft_size, 2)
stft_matrix
=
torch
.
irfft
(
stft_matrix
,
1
,
normalized
,
onesided
,
signal_sizes
=
(
n_fft
,)
)
# size (channel, n_frame, n_fft)
assert
stft_matrix
.
size
(
2
)
==
n_fft
n_frame
=
stft_matrix
.
size
(
1
)
ytmp
=
stft_matrix
*
window
.
view
(
1
,
1
,
n_fft
)
# size (channel, n_frame, n_fft)
# each column of a channel is a frame which needs to be overlap added at the right place
ytmp
=
ytmp
.
transpose
(
1
,
2
)
# size (channel, n_fft, n_frame)
# this does overlap add where the frames of ytmp are added such that the i'th frame of
# ytmp is added starting at i*hop_length in the output
y
=
torch
.
nn
.
functional
.
fold
(
ytmp
,
(
1
,
(
n_frame
-
1
)
*
hop_length
+
n_fft
),
(
1
,
n_fft
),
stride
=
(
1
,
hop_length
)
).
squeeze
(
2
)
# do the same for the window function
window_sq
=
(
window
.
pow
(
2
).
view
(
n_fft
,
1
).
repeat
((
1
,
n_frame
)).
unsqueeze
(
0
)
)
# size (1, n_fft, n_frame)
window_envelop
=
torch
.
nn
.
functional
.
fold
(
window_sq
,
(
1
,
(
n_frame
-
1
)
*
hop_length
+
n_fft
),
(
1
,
n_fft
),
stride
=
(
1
,
hop_length
)
).
squeeze
(
2
)
# size (1, 1, expected_signal_len)
expected_signal_len
=
n_fft
+
hop_length
*
(
n_frame
-
1
)
assert
y
.
size
(
2
)
==
expected_signal_len
assert
window_envelop
.
size
(
2
)
==
expected_signal_len
half_n_fft
=
n_fft
//
2
# we need to trim the front padding away if center
start
=
half_n_fft
if
center
else
0
end
=
-
half_n_fft
if
length
is
None
else
start
+
length
y
=
y
[:,
:,
start
:
end
]
window_envelop
=
window_envelop
[:,
:,
start
:
end
]
# check NOLA non-zero overlap condition
window_envelop_lowest
=
window_envelop
.
abs
().
min
()
assert
window_envelop_lowest
>
1e-11
,
"window overlap add min: %f"
%
(
window_envelop_lowest
)
y
=
(
y
/
window_envelop
).
squeeze
(
1
)
# size (channel, expected_signal_len)
# unpack batch
y
=
y
.
reshape
(
shape
[:
-
3
]
+
y
.
shape
[
-
1
:])
if
stft_matrix_dim
==
3
:
# remove the channel dimension
y
=
y
.
squeeze
(
0
)
return
y
warnings
.
warn
(
'istft has been moved to PyTorch and will be removed from torchaudio, '
'please use torch.istft instead.'
)
if
pad_mode
is
not
None
:
warnings
.
warn
(
'The parameter `pad_mode` was ignored in isftft, and is thus being deprecated. '
'Please set `pad_mode` to None to suppress this warning.'
)
return
torch
.
istft
(
input
=
stft_matrix
,
n_fft
=
n_fft
,
hop_length
=
hop_length
,
win_length
=
win_length
,
window
=
window
,
center
=
center
,
normalized
=
normalized
,
onesided
=
onesided
,
length
=
length
)
def
spectrogram
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment