Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
56e2835e
Commit
56e2835e
authored
Jul 10, 2019
by
jamarshon
Committed by
cpuhrsch
Jul 10, 2019
Browse files
ISTFT (#135)
parent
cf6ce7dc
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
352 additions
and
0 deletions
+352
-0
test/common_utils.py
test/common_utils.py
+32
-0
test/test_functional.py
test/test_functional.py
+187
-0
torchaudio/functional.py
torchaudio/functional.py
+133
-0
No files found.
test/common_utils.py
View file @
56e2835e
import
os
import
os
from
shutil
import
copytree
from
shutil
import
copytree
import
tempfile
import
tempfile
import
torch
TEST_DIR_PATH
=
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
))
TEST_DIR_PATH
=
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
))
...
@@ -16,3 +17,34 @@ def create_temp_assets_dir():
...
@@ -16,3 +17,34 @@ def create_temp_assets_dir():
copytree
(
os
.
path
.
join
(
TEST_DIR_PATH
,
"assets"
),
copytree
(
os
.
path
.
join
(
TEST_DIR_PATH
,
"assets"
),
os
.
path
.
join
(
tmp_dir
.
name
,
"assets"
))
os
.
path
.
join
(
tmp_dir
.
name
,
"assets"
))
return
tmp_dir
.
name
,
tmp_dir
return
tmp_dir
.
name
,
tmp_dir
def
random_float_tensor
(
seed
,
size
,
a
=
22695477
,
c
=
1
,
m
=
2
**
32
):
""" Generates random tensors given a seed and size
https://en.wikipedia.org/wiki/Linear_congruential_generator
X_{n + 1} = (a * X_n + c) % m
Using Borland C/C++ values
The tensor will have values between [0,1)
Inputs:
seed (int): an int
size (Tuple[int]): the size of the output tensor
a (int): the multiplier constant to the generator
c (int): the additive constant to the generator
m (int): the modulus constant to the generator
"""
num_elements
=
1
for
s
in
size
:
num_elements
*=
s
arr
=
[(
a
*
seed
+
c
)
%
m
]
for
i
in
range
(
num_elements
-
1
):
arr
.
append
((
a
*
arr
[
i
]
+
c
)
%
m
)
return
torch
.
tensor
(
arr
).
float
().
view
(
size
)
/
m
def
random_int_tensor
(
seed
,
size
,
low
=
0
,
high
=
2
**
32
,
a
=
22695477
,
c
=
1
,
m
=
2
**
32
):
""" Same as random_float_tensor but integers between [low, high)
"""
return
torch
.
floor
(
random_float_tensor
(
seed
,
size
,
a
,
c
,
m
)
*
(
high
-
low
))
+
low
test/test_functional.py
0 → 100644
View file @
56e2835e
import
math
import
torch
import
torchaudio
import
unittest
import
test.common_utils
class
TestFunctional
(
unittest
.
TestCase
):
data_sizes
=
[(
2
,
20
),
(
3
,
15
),
(
4
,
10
)]
number_of_trials
=
100
def
_compare_estimate
(
self
,
sound
,
estimate
,
atol
=
1e-6
,
rtol
=
1e-8
):
# trim sound for case when constructed signal is shorter than original
sound
=
sound
[...,
:
estimate
.
size
(
-
1
)]
self
.
assertTrue
(
sound
.
shape
==
estimate
.
shape
,
(
sound
.
shape
,
estimate
.
shape
))
self
.
assertTrue
(
torch
.
allclose
(
sound
,
estimate
,
atol
=
atol
,
rtol
=
rtol
))
def
_test_istft_is_inverse_of_stft
(
self
,
kwargs
):
# generates a random sound signal for each tril and then does the stft/istft
# operation to check whether we can reconstruct signal
for
data_size
in
self
.
data_sizes
:
for
i
in
range
(
self
.
number_of_trials
):
sound
=
test
.
common_utils
.
random_float_tensor
(
i
,
data_size
)
stft
=
torch
.
stft
(
sound
,
**
kwargs
)
estimate
=
torchaudio
.
functional
.
istft
(
stft
,
length
=
sound
.
size
(
1
),
**
kwargs
)
self
.
_compare_estimate
(
sound
,
estimate
)
def
test_istft_is_inverse_of_stft1
(
self
):
# hann_window, centered, normalized, onesided
kwargs1
=
{
'n_fft'
:
12
,
'hop_length'
:
4
,
'win_length'
:
12
,
'window'
:
torch
.
hann_window
(
12
),
'center'
:
True
,
'pad_mode'
:
'reflect'
,
'normalized'
:
True
,
'onesided'
:
True
,
}
self
.
_test_istft_is_inverse_of_stft
(
kwargs1
)
def
test_istft_is_inverse_of_stft2
(
self
):
# hann_window, centered, not normalized, not onesided
kwargs2
=
{
'n_fft'
:
12
,
'hop_length'
:
2
,
'win_length'
:
8
,
'window'
:
torch
.
hann_window
(
8
),
'center'
:
True
,
'pad_mode'
:
'reflect'
,
'normalized'
:
False
,
'onesided'
:
False
,
}
self
.
_test_istft_is_inverse_of_stft
(
kwargs2
)
def
test_istft_is_inverse_of_stft3
(
self
):
# hamming_window, centered, normalized, not onesided
kwargs3
=
{
'n_fft'
:
15
,
'hop_length'
:
3
,
'win_length'
:
11
,
'window'
:
torch
.
hamming_window
(
11
),
'center'
:
True
,
'pad_mode'
:
'constant'
,
'normalized'
:
True
,
'onesided'
:
False
,
}
self
.
_test_istft_is_inverse_of_stft
(
kwargs3
)
def
test_istft_is_inverse_of_stft4
(
self
):
# hamming_window, not centered, not normalized, onesided
# window same size as n_fft
kwargs4
=
{
'n_fft'
:
5
,
'hop_length'
:
2
,
'win_length'
:
5
,
'window'
:
torch
.
hamming_window
(
5
),
'center'
:
False
,
'pad_mode'
:
'constant'
,
'normalized'
:
False
,
'onesided'
:
True
,
}
self
.
_test_istft_is_inverse_of_stft
(
kwargs4
)
def
test_istft_is_inverse_of_stft5
(
self
):
# hamming_window, not centered, not normalized, not onesided
# window same size as n_fft
kwargs5
=
{
'n_fft'
:
3
,
'hop_length'
:
2
,
'win_length'
:
3
,
'window'
:
torch
.
hamming_window
(
3
),
'center'
:
False
,
'pad_mode'
:
'reflect'
,
'normalized'
:
False
,
'onesided'
:
False
,
}
self
.
_test_istft_is_inverse_of_stft
(
kwargs5
)
def
test_istft_of_ones
(
self
):
# stft = torch.stft(torch.ones(4), 4)
stft
=
torch
.
tensor
([
[[
4.
,
0.
],
[
4.
,
0.
],
[
4.
,
0.
],
[
4.
,
0.
],
[
4.
,
0.
]],
[[
0.
,
0.
],
[
0.
,
0.
],
[
0.
,
0.
],
[
0.
,
0.
],
[
0.
,
0.
]],
[[
0.
,
0.
],
[
0.
,
0.
],
[
0.
,
0.
],
[
0.
,
0.
],
[
0.
,
0.
]]
])
estimate
=
torchaudio
.
functional
.
istft
(
stft
,
n_fft
=
4
,
length
=
4
)
self
.
_compare_estimate
(
torch
.
ones
(
4
),
estimate
)
def
test_istft_of_zeros
(
self
):
# stft = torch.stft(torch.zeros(4), 4)
stft
=
torch
.
zeros
((
3
,
5
,
2
))
estimate
=
torchaudio
.
functional
.
istft
(
stft
,
n_fft
=
4
,
length
=
4
)
self
.
_compare_estimate
(
torch
.
zeros
(
4
),
estimate
)
def
test_istft_requires_overlap_windows
(
self
):
# the window is size 1 but it hops 20 so there is a gap which throw an error
stft
=
torch
.
zeros
((
3
,
5
,
2
))
self
.
assertRaises
(
AssertionError
,
torchaudio
.
functional
.
istft
,
stft
,
n_fft
=
4
,
hop_length
=
20
,
win_length
=
1
,
window
=
torch
.
ones
(
1
))
def
test_istft_requires_nola
(
self
):
stft
=
torch
.
zeros
((
3
,
5
,
2
))
kwargs_ok
=
{
'n_fft'
:
4
,
'win_length'
:
4
,
'window'
:
torch
.
ones
(
4
),
}
kwargs_not_ok
=
{
'n_fft'
:
4
,
'win_length'
:
4
,
'window'
:
torch
.
zeros
(
4
),
}
# A window of ones meets NOLA but a window of zeros does not. This should
# throw an error.
torchaudio
.
functional
.
istft
(
stft
,
**
kwargs_ok
)
self
.
assertRaises
(
AssertionError
,
torchaudio
.
functional
.
istft
,
stft
,
**
kwargs_not_ok
)
def
test_istft_requires_non_empty
(
self
):
self
.
assertRaises
(
AssertionError
,
torchaudio
.
functional
.
istft
,
torch
.
zeros
((
3
,
0
,
2
)),
2
)
self
.
assertRaises
(
AssertionError
,
torchaudio
.
functional
.
istft
,
torch
.
zeros
((
0
,
3
,
2
)),
2
)
def
_test_istft_of_sine
(
self
,
amplitude
,
L
,
n
):
# stft of amplitude*sin(2*pi/L*n*x) with the hop length and window size equaling L
x
=
torch
.
arange
(
2
*
L
+
1
,
dtype
=
torch
.
get_default_dtype
())
sound
=
amplitude
*
torch
.
sin
(
2
*
math
.
pi
/
L
*
x
*
n
)
# stft = torch.stft(sound, L, hop_length=L, win_length=L,
# window=torch.ones(L), center=False, normalized=False)
stft
=
torch
.
zeros
((
L
//
2
+
1
,
2
,
2
))
stft_largest_val
=
(
amplitude
*
L
)
/
2.0
if
n
<
stft
.
size
(
0
):
stft
[
n
,
:,
1
]
=
-
stft_largest_val
if
0
<=
L
-
n
<
stft
.
size
(
0
):
# symmetric about L // 2
stft
[
L
-
n
,
:,
1
]
=
stft_largest_val
estimate
=
torchaudio
.
functional
.
istft
(
stft
,
L
,
hop_length
=
L
,
win_length
=
L
,
window
=
torch
.
ones
(
L
),
center
=
False
,
normalized
=
False
)
# There is a larger error due to the scaling of amplitude
self
.
_compare_estimate
(
sound
,
estimate
,
atol
=
1e-3
)
def
test_istft_of_sine
(
self
):
self
.
_test_istft_of_sine
(
amplitude
=
123
,
L
=
5
,
n
=
1
)
self
.
_test_istft_of_sine
(
amplitude
=
150
,
L
=
5
,
n
=
2
)
self
.
_test_istft_of_sine
(
amplitude
=
111
,
L
=
5
,
n
=
3
)
self
.
_test_istft_of_sine
(
amplitude
=
160
,
L
=
7
,
n
=
4
)
self
.
_test_istft_of_sine
(
amplitude
=
145
,
L
=
8
,
n
=
5
)
self
.
_test_istft_of_sine
(
amplitude
=
80
,
L
=
9
,
n
=
6
)
self
.
_test_istft_of_sine
(
amplitude
=
99
,
L
=
10
,
n
=
7
)
if
__name__
==
'__main__'
:
unittest
.
main
()
torchaudio/functional.py
View file @
56e2835e
...
@@ -7,6 +7,7 @@ __all__ = [
...
@@ -7,6 +7,7 @@ __all__ = [
'pad_trim'
,
'pad_trim'
,
'downmix_mono'
,
'downmix_mono'
,
'LC2CL'
,
'LC2CL'
,
'istft'
,
'spectrogram'
,
'spectrogram'
,
'create_fb_matrix'
,
'create_fb_matrix'
,
'spectrogram_to_DB'
,
'spectrogram_to_DB'
,
...
@@ -105,6 +106,138 @@ def _stft(input, n_fft, hop_length, win_length, window, center, pad_mode, normal
...
@@ -105,6 +106,138 @@ def _stft(input, n_fft, hop_length, win_length, window, center, pad_mode, normal
return
torch
.
stft
(
input
,
n_fft
,
hop_length
,
win_length
,
window
,
center
,
pad_mode
,
normalized
,
onesided
)
return
torch
.
stft
(
input
,
n_fft
,
hop_length
,
win_length
,
window
,
center
,
pad_mode
,
normalized
,
onesided
)
def
istft
(
stft_matrix
,
# type: Tensor
n_fft
,
# type: int
hop_length
=
None
,
# type: Optional[int]
win_length
=
None
,
# type: Optional[int]
window
=
None
,
# type: Optional[Tensor]
center
=
True
,
# type: bool
pad_mode
=
'reflect'
,
# type: str
normalized
=
False
,
# type: bool
onesided
=
True
,
# type: bool
length
=
None
# type: Optional[int]
):
# type: (...) -> Tensor
r
""" Inverse short time Fourier Transform. This is expected to be the inverse of torch.stft.
It has the same parameters (+ additional optional parameter of :attr:`length`) and it should return the
least squares estimation of the original signal. The algorithm will check using the NOLA condition (
nonzero overlap).
Important consideration in the parameters :attr:`window` and :attr:`center` so that the envelop
created by the summation of all the windows is never zero at certain point in time. Specifically,
:math:`\sum_{t=-\ infty}^{\ infty} w^2[n-t\times hop\_length] \neq 0`.
Since stft discards elements at the end of the signal if they do not fit in a frame, the
istft may return a shorter signal than the original signal (can occur if :attr:`center` is False
since the signal isn't padded).
If :attr:`center` is True, then there will be padding e.g. 'constant', 'reflect', etc. Left padding
can be trimmed off exactly because they can be calculated but right padding cannot be calculated
without additional information.
Example: Suppose the last window is:
[17, 18, 0, 0, 0] vs [18, 0, 0, 0, 0]
The n_frames, hop_length, win_length are all the same which prevents the calculation of right padding.
These additional values could be zeros or a reflection of the signal so providing :attr:`length`
could be useful. If :attr:`length` is None then padding will be aggressively removed (some loss of signal).
[1] D. W. Griffin and J. S. Lim, “Signal estimation from modified short-time Fourier transform,”
IEEE Trans. ASSP, vol.32, no.2, pp.236–243, Apr. 1984.
Inputs:
stft_matrix (Tensor): output of stft where each row of a batch is a frequency and each column is
a window. it has a shape of either (batch, fft_size, n_frames, 2) or (fft_size, n_frames, 2)
n_fft (int): size of Fourier transform
hop_length (Optional[int]): the distance between neighboring sliding window frames. (Default: win_length // 4)
win_length (Optional[int]): the size of window frame and STFT filter. (Default: n_fft)
window (Optional[Tensor]): the optional window function. (Default: torch.ones(win_length))
center (bool): whether :attr:`input` was padded on both sides so
that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`
pad_mode (str): controls the padding method used when :attr:`center` is ``True``
normalized (bool): whether the STFT was normalized
onesided (bool): whether the STFT is onesided
length (Optional[int]): the amount to trim the signal by (i.e. the
original signal length). (Default: whole signal)
Outputs:
Tensor: least squares estimation of the original signal of size (batch, signal_length) or (signal_length)
"""
stft_matrix_dim
=
stft_matrix
.
dim
()
assert
3
<=
stft_matrix_dim
<=
4
,
(
'Incorrect stft dimension: %d'
%
(
stft_matrix_dim
))
if
stft_matrix_dim
==
3
:
# add a batch dimension
stft_matrix
=
stft_matrix
.
unsqueeze
(
0
)
device
=
stft_matrix
.
device
fft_size
=
stft_matrix
.
size
(
1
)
assert
(
onesided
and
n_fft
//
2
+
1
==
fft_size
)
or
(
not
onesided
and
n_fft
==
fft_size
),
(
'one_sided implies that n_fft // 2 + 1 == fft_size and not one_sided implies n_fft == fft_size. '
+
'Given values were onesided: %s, n_fft: %d, fft_size: %d'
%
(
'True'
if
onesided
else
False
,
n_fft
,
fft_size
))
# use stft defaults for Optionals
if
win_length
is
None
:
win_length
=
n_fft
if
hop_length
is
None
:
hop_length
=
int
(
win_length
//
4
)
# There must be overlap
assert
0
<
hop_length
<=
win_length
assert
0
<
win_length
<=
n_fft
if
window
is
None
:
window
=
torch
.
ones
(
win_length
)
assert
window
.
dim
()
==
1
and
window
.
size
(
0
)
==
win_length
if
win_length
!=
n_fft
:
# center window with pad left and right zeros
left
=
(
n_fft
-
win_length
)
//
2
window
=
torch
.
nn
.
functional
.
pad
(
window
,
(
left
,
n_fft
-
win_length
-
left
))
assert
window
.
size
(
0
)
==
n_fft
# win_length and n_fft are synonymous from here on
stft_matrix
=
stft_matrix
.
transpose
(
1
,
2
)
# size (batch, n_frames, fft_size, 2)
stft_matrix
=
torch
.
irfft
(
stft_matrix
,
1
,
normalized
,
onesided
,
signal_sizes
=
(
n_fft
,))
# size (batch, n_frames, n_fft)
assert
stft_matrix
.
size
(
2
)
==
n_fft
n_frames
=
stft_matrix
.
size
(
1
)
ytmp
=
stft_matrix
*
window
.
view
(
1
,
1
,
n_fft
)
# size (batch, n_frames, n_fft)
# each column of a batch is a frame which needs to be overlap added at the right place
ytmp
=
ytmp
.
transpose
(
1
,
2
)
# size (batch, n_fft, n_frames)
eye
=
torch
.
eye
(
n_fft
,
requires_grad
=
False
,
device
=
device
).
unsqueeze
(
1
)
# size (n_fft, 1, n_fft)
# this does overlap add where the frames of ytmp are added such that the i'th frame of
# ytmp is added starting at i*hop_length in the output
y
=
torch
.
nn
.
functional
.
conv_transpose1d
(
ytmp
,
eye
,
stride
=
hop_length
,
padding
=
0
)
# size (batch, 1, expected_signal_len)
# do the same for the window function
window_sq
=
window
.
pow
(
2
).
view
(
n_fft
,
1
).
repeat
((
1
,
n_frames
)).
unsqueeze
(
0
)
# size (1, n_fft, n_frames)
window_envelop
=
torch
.
nn
.
functional
.
conv_transpose1d
(
window_sq
,
eye
,
stride
=
hop_length
,
padding
=
0
)
# size (1, 1, expected_signal_len)
expected_signal_len
=
n_fft
+
hop_length
*
(
n_frames
-
1
)
assert
y
.
size
(
2
)
==
expected_signal_len
assert
window_envelop
.
size
(
2
)
==
expected_signal_len
half_n_fft
=
n_fft
//
2
# we need to trim the front padding away if center
start
=
half_n_fft
if
center
else
0
end
=
-
half_n_fft
if
length
is
None
else
start
+
length
y
=
y
[:,
:,
start
:
end
]
window_envelop
=
window_envelop
[:,
:,
start
:
end
]
# check NOLA non-zero overlap condition
window_envelop_lowest
=
window_envelop
.
abs
().
min
()
assert
window_envelop_lowest
>
1e-11
,
(
'window overlap add min: %f'
%
(
window_envelop_lowest
))
y
=
(
y
/
window_envelop
).
squeeze
(
1
)
# size (batch, expected_signal_len)
if
stft_matrix_dim
==
3
:
# remove the batch dimension
y
=
y
.
squeeze
(
0
)
return
y
@
torch
.
jit
.
script
@
torch
.
jit
.
script
def
spectrogram
(
sig
,
pad
,
window
,
n_fft
,
hop
,
ws
,
power
,
normalize
):
def
spectrogram
(
sig
,
pad
,
window
,
n_fft
,
hop
,
ws
,
power
,
normalize
):
# type: (Tensor, int, Tensor, int, int, int, int, bool) -> Tensor
# type: (Tensor, int, Tensor, int, int, int, int, bool) -> Tensor
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment