Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
2271a7ae
"src/vscode:/vscode.git/clone" did not exist on "43f1090a0f5879416d83e1b0991502a26fc27ec6"
Commit
2271a7ae
authored
Jul 24, 2019
by
Kiran Sanjeevan
Committed by
cpuhrsch
Jul 24, 2019
Browse files
torchaudio-contrib: Adding (some) functionals (#131)
parent
dc452aab
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
276 additions
and
2 deletions
+276
-2
build_tools/travis/test_script.sh
build_tools/travis/test_script.sh
+1
-1
requirements.txt
requirements.txt
+3
-0
test/common_utils.py
test/common_utils.py
+0
-1
test/test_functional.py
test/test_functional.py
+114
-0
torchaudio/functional.py
torchaudio/functional.py
+158
-0
No files found.
build_tools/travis/test_script.sh
View file @
2271a7ae
...
@@ -18,7 +18,7 @@ run_tests() {
...
@@ -18,7 +18,7 @@ run_tests() {
for
FILE
in
$TEST_FILES
;
do
for
FILE
in
$TEST_FILES
;
do
# run each file on a separate process. if one fails, just keep going and
# run each file on a separate process. if one fails, just keep going and
# return the final exit status.
# return the final exit status.
python
-m
unit
test
-v
$FILE
python
-m
py
test
-v
$FILE
STATUS
=
$?
STATUS
=
$?
EXIT_STATUS
=
"
$((
$EXIT_STATUS
+
STATUS
))
"
EXIT_STATUS
=
"
$((
$EXIT_STATUS
+
STATUS
))
"
done
done
...
...
requirements.txt
View file @
2271a7ae
...
@@ -10,3 +10,6 @@ flake8
...
@@ -10,3 +10,6 @@ flake8
# Used for comparison of outputs in tests
# Used for comparison of outputs in tests
librosa
librosa
scipy
scipy
# Unit tests with pytest
pytest
\ No newline at end of file
test/common_utils.py
View file @
2271a7ae
...
@@ -3,7 +3,6 @@ from shutil import copytree
...
@@ -3,7 +3,6 @@ from shutil import copytree
import
tempfile
import
tempfile
import
torch
import
torch
TEST_DIR_PATH
=
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
))
TEST_DIR_PATH
=
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
))
...
...
test/test_functional.py
View file @
2271a7ae
...
@@ -5,6 +5,16 @@ import torchaudio
...
@@ -5,6 +5,16 @@ import torchaudio
import
unittest
import
unittest
import
test.common_utils
import
test.common_utils
from
torchaudio.common_utils
import
IMPORT_LIBROSA
if
IMPORT_LIBROSA
:
import
numpy
as
np
import
librosa
import
pytest
import
torchaudio.functional
as
F
xfail
=
pytest
.
mark
.
xfail
class
TestFunctional
(
unittest
.
TestCase
):
class
TestFunctional
(
unittest
.
TestCase
):
data_sizes
=
[(
2
,
20
),
(
3
,
15
),
(
4
,
10
)]
data_sizes
=
[(
2
,
20
),
(
3
,
15
),
(
4
,
10
)]
...
@@ -183,5 +193,109 @@ class TestFunctional(unittest.TestCase):
...
@@ -183,5 +193,109 @@ class TestFunctional(unittest.TestCase):
self
.
_test_istft_of_sine
(
amplitude
=
99
,
L
=
10
,
n
=
7
)
self
.
_test_istft_of_sine
(
amplitude
=
99
,
L
=
10
,
n
=
7
)
def
_num_stft_bins
(
signal_len
,
fft_len
,
hop_length
,
pad
):
return
(
signal_len
+
2
*
pad
-
fft_len
+
hop_length
)
//
hop_length
@
pytest
.
mark
.
parametrize
(
'fft_length'
,
[
512
])
@
pytest
.
mark
.
parametrize
(
'hop_length'
,
[
256
])
@
pytest
.
mark
.
parametrize
(
'waveform'
,
[
(
torch
.
randn
(
1
,
100000
)),
(
torch
.
randn
(
1
,
2
,
100000
)),
pytest
.
param
(
torch
.
randn
(
1
,
100
),
marks
=
xfail
(
raises
=
RuntimeError
)),
])
@
pytest
.
mark
.
parametrize
(
'pad_mode'
,
[
# 'constant',
'reflect'
,
])
@
unittest
.
skipIf
(
not
IMPORT_LIBROSA
,
'Librosa is not available'
)
def
test_stft
(
waveform
,
fft_length
,
hop_length
,
pad_mode
):
"""
Test STFT for multi-channel signals.
Padding: Value in having padding outside of torch.stft?
"""
pad
=
fft_length
//
2
window
=
torch
.
hann_window
(
fft_length
)
complex_spec
=
F
.
stft
(
waveform
,
fft_length
=
fft_length
,
hop_length
=
hop_length
,
window
=
window
,
pad_mode
=
pad_mode
)
mag_spec
,
phase_spec
=
F
.
magphase
(
complex_spec
)
# == Test shape
expected_size
=
list
(
waveform
.
size
()[:
-
1
])
expected_size
+=
[
fft_length
//
2
+
1
,
_num_stft_bins
(
waveform
.
size
(
-
1
),
fft_length
,
hop_length
,
pad
),
2
]
assert
complex_spec
.
dim
()
==
waveform
.
dim
()
+
2
assert
complex_spec
.
size
()
==
torch
.
Size
(
expected_size
)
# == Test values
fft_config
=
dict
(
n_fft
=
fft_length
,
hop_length
=
hop_length
,
pad_mode
=
pad_mode
)
# note that librosa *automatically* pad with fft_length // 2.
expected_complex_spec
=
np
.
apply_along_axis
(
librosa
.
stft
,
-
1
,
waveform
.
numpy
(),
**
fft_config
)
expected_mag_spec
,
_
=
librosa
.
magphase
(
expected_complex_spec
)
# Convert torch to np.complex
complex_spec
=
complex_spec
.
numpy
()
complex_spec
=
complex_spec
[...,
0
]
+
1j
*
complex_spec
[...,
1
]
assert
np
.
allclose
(
complex_spec
,
expected_complex_spec
,
atol
=
1e-5
)
assert
np
.
allclose
(
mag_spec
.
numpy
(),
expected_mag_spec
,
atol
=
1e-5
)
@
pytest
.
mark
.
parametrize
(
'rate'
,
[
0.5
,
1.01
,
1.3
])
@
pytest
.
mark
.
parametrize
(
'complex_specgrams'
,
[
torch
.
randn
(
1
,
2
,
1025
,
400
,
2
),
torch
.
randn
(
1
,
1025
,
400
,
2
)
])
@
pytest
.
mark
.
parametrize
(
'hop_length'
,
[
256
])
@
unittest
.
skipIf
(
not
IMPORT_LIBROSA
,
'Librosa is not available'
)
def
test_phase_vocoder
(
complex_specgrams
,
rate
,
hop_length
):
# Due to cummulative sum, numerical error in using torch.float32 will
# result in bottom right values of the stretched sectrogram to not
# match with librosa.
complex_specgrams
=
complex_specgrams
.
type
(
torch
.
float64
)
phase_advance
=
torch
.
linspace
(
0
,
np
.
pi
*
hop_length
,
complex_specgrams
.
shape
[
-
3
],
dtype
=
torch
.
float64
)[...,
None
]
complex_specgrams_stretch
=
F
.
phase_vocoder
(
complex_specgrams
,
rate
=
rate
,
phase_advance
=
phase_advance
)
# == Test shape
expected_size
=
list
(
complex_specgrams
.
size
())
expected_size
[
-
2
]
=
int
(
np
.
ceil
(
expected_size
[
-
2
]
/
rate
))
assert
complex_specgrams
.
dim
()
==
complex_specgrams_stretch
.
dim
()
assert
complex_specgrams_stretch
.
size
()
==
torch
.
Size
(
expected_size
)
# == Test values
index
=
[
0
]
*
(
complex_specgrams
.
dim
()
-
3
)
+
[
slice
(
None
)]
*
3
mono_complex_specgram
=
complex_specgrams
[
index
].
numpy
()
mono_complex_specgram
=
mono_complex_specgram
[...,
0
]
+
\
mono_complex_specgram
[...,
1
]
*
1j
expected_complex_stretch
=
librosa
.
phase_vocoder
(
mono_complex_specgram
,
rate
=
rate
,
hop_length
=
hop_length
)
complex_stretch
=
complex_specgrams_stretch
[
index
].
numpy
()
complex_stretch
=
complex_stretch
[...,
0
]
+
1j
*
complex_stretch
[...,
1
]
assert
np
.
allclose
(
complex_stretch
,
expected_complex_stretch
,
atol
=
1e-5
)
@
pytest
.
mark
.
parametrize
(
'complex_tensor'
,
[
torch
.
randn
(
1
,
2
,
1025
,
400
,
2
),
torch
.
randn
(
1025
,
400
,
2
)
])
@
pytest
.
mark
.
parametrize
(
'power'
,
[
1
,
2
,
0.7
])
def
test_complex_norm
(
complex_tensor
,
power
):
expected_norm_tensor
=
complex_tensor
.
pow
(
2
).
sum
(
-
1
).
pow
(
power
/
2
)
norm_tensor
=
F
.
complex_norm
(
complex_tensor
,
power
)
assert
torch
.
allclose
(
expected_norm_tensor
,
norm_tensor
,
atol
=
1e-5
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
torchaudio/functional.py
View file @
2271a7ae
...
@@ -450,3 +450,161 @@ def mu_law_expanding(x_mu, qc):
...
@@ -450,3 +450,161 @@ def mu_law_expanding(x_mu, qc):
x
=
((
x_mu
)
/
mu
)
*
2
-
1.
x
=
((
x_mu
)
/
mu
)
*
2
-
1.
x
=
torch
.
sign
(
x
)
*
(
torch
.
exp
(
torch
.
abs
(
x
)
*
torch
.
log1p
(
mu
))
-
1.
)
/
mu
x
=
torch
.
sign
(
x
)
*
(
torch
.
exp
(
torch
.
abs
(
x
)
*
torch
.
log1p
(
mu
))
-
1.
)
/
mu
return
x
return
x
def
stft
(
waveforms
,
fft_length
,
hop_length
=
None
,
win_length
=
None
,
window
=
None
,
center
=
True
,
pad_mode
=
'reflect'
,
normalized
=
False
,
onesided
=
True
):
"""Compute a short time Fourier transform of the input waveform(s).
It wraps `torch.stft` after reshaping the input audio to allow for `waveforms` that `.dim()` >= 3.
It follows most of the `torch.stft` default values, but for `window`, which defaults to hann window.
Args:
waveforms (torch.Tensor): Audio signal of size `(*, channel, time)`
fft_length (int): FFT size [sample].
hop_length (int): Hop size [sample] between STFT frames.
(Defaults to `fft_length // 4`, 75%-overlapping windows by `torch.stft`).
win_length (int): Size of STFT window. (Defaults to `fft_length` by `torch.stft`).
window (torch.Tensor): window function. (Defaults to Hann Window of size `win_length` *unlike* `torch.stft`).
center (bool): Whether to pad `waveforms` on both sides so that the `t`-th frame is centered
at time `t * hop_length`. (Defaults to `True` by `torch.stft`)
pad_mode (str): padding method (see `torch.nn.functional.pad`). (Defaults to `'reflect'` by `torch.stft`).
normalized (bool): Whether the results are normalized. (Defaults to `False` by `torch.stft`).
onesided (bool): Whether the half + 1 frequency bins are returned to removethe symmetric part of STFT
of real-valued signal. (Defaults to `True` by `torch.stft`).
Returns:
torch.Tensor: `(*, channel, num_freqs, time, complex=2)`
Example:
>>> waveforms = torch.randn(16, 2, 10000) # (batch, channel, time)
>>> x = stft(waveforms, 2048, 512)
>>> x.shape
torch.Size([16, 2, 1025, 20])
"""
leading_dims
=
waveforms
.
shape
[:
-
1
]
waveforms
=
waveforms
.
reshape
(
-
1
,
waveforms
.
size
(
-
1
))
if
window
is
None
:
if
win_length
is
None
:
window
=
torch
.
hann_window
(
fft_length
)
else
:
window
=
torch
.
hann_window
(
win_length
)
complex_specgrams
=
torch
.
stft
(
waveforms
,
n_fft
=
fft_length
,
hop_length
=
hop_length
,
win_length
=
win_length
,
window
=
window
,
center
=
center
,
pad_mode
=
pad_mode
,
normalized
=
normalized
,
onesided
=
onesided
)
complex_specgrams
=
complex_specgrams
.
reshape
(
leading_dims
+
complex_specgrams
.
shape
[
1
:])
return
complex_specgrams
def
complex_norm
(
complex_tensor
,
power
=
1.0
):
"""Compute the norm of complex tensor input
Args:
complex_tensor (Tensor): Tensor shape of `(*, complex=2)`
power (float): Power of the norm. Defaults to `1.0`.
Returns:
Tensor: power of the normed input tensor, shape of `(*, )`
"""
if
power
==
1.0
:
return
torch
.
norm
(
complex_tensor
,
2
,
-
1
)
return
torch
.
norm
(
complex_tensor
,
2
,
-
1
).
pow
(
power
)
def
angle
(
complex_tensor
):
"""
Return angle of a complex tensor with shape (*, 2).
"""
return
torch
.
atan2
(
complex_tensor
[...,
1
],
complex_tensor
[...,
0
])
def
magphase
(
complex_tensor
,
power
=
1.
):
"""
Separate a complex-valued spectrogram with shape (*,2)
into its magnitude and phase.
"""
mag
=
complex_norm
(
complex_tensor
,
power
)
phase
=
angle
(
complex_tensor
)
return
mag
,
phase
def
phase_vocoder
(
complex_specgrams
,
rate
,
phase_advance
):
"""
Phase vocoder. Given a STFT tensor, speed up in time
without modifying pitch by a factor of `rate`.
Args:
complex_specgrams (Tensor):
(*, channel, num_freqs, time, complex=2)
rate (float): Speed-up factor.
phase_advance (Tensor): Expected phase advance in
each bin. (num_freqs, 1).
Returns:
complex_specgrams_stretch (Tensor):
(*, channel, num_freqs, ceil(time/rate), complex=2).
Example:
>>> num_freqs, hop_length = 1025, 512
>>> # (batch, channel, num_freqs, time, complex=2)
>>> complex_specgrams = torch.randn(16, 1, num_freqs, 300, 2)
>>> rate = 1.3 # Slow down by 30%
>>> phase_advance = torch.linspace(
>>> 0, math.pi * hop_length, num_freqs)[..., None]
>>> x = phase_vocoder(complex_specgrams, rate, phase_advance)
>>> x.shape # with 231 == ceil(300 / 1.3)
torch.Size([16, 1, 1025, 231, 2])
"""
ndim
=
complex_specgrams
.
dim
()
time_slice
=
[
slice
(
None
)]
*
(
ndim
-
2
)
time_steps
=
torch
.
arange
(
0
,
complex_specgrams
.
size
(
-
2
),
rate
,
device
=
complex_specgrams
.
device
,
dtype
=
complex_specgrams
.
dtype
)
alphas
=
time_steps
%
1.
phase_0
=
angle
(
complex_specgrams
[
time_slice
+
[
slice
(
1
)]])
# Time Padding
complex_specgrams
=
torch
.
nn
.
functional
.
pad
(
complex_specgrams
,
[
0
,
0
,
0
,
2
])
# (new_bins, num_freqs, 2)
complex_specgrams_0
=
complex_specgrams
[
time_slice
+
[
time_steps
.
long
()]]
complex_specgrams_1
=
complex_specgrams
[
time_slice
+
[(
time_steps
+
1
).
long
()]]
angle_0
=
angle
(
complex_specgrams_0
)
angle_1
=
angle
(
complex_specgrams_1
)
norm_0
=
torch
.
norm
(
complex_specgrams_0
,
dim
=-
1
)
norm_1
=
torch
.
norm
(
complex_specgrams_1
,
dim
=-
1
)
phase
=
angle_1
-
angle_0
-
phase_advance
phase
=
phase
-
2
*
math
.
pi
*
torch
.
round
(
phase
/
(
2
*
math
.
pi
))
# Compute Phase Accum
phase
=
phase
+
phase_advance
phase
=
torch
.
cat
([
phase_0
,
phase
[
time_slice
+
[
slice
(
-
1
)]]],
dim
=-
1
)
phase_acc
=
torch
.
cumsum
(
phase
,
-
1
)
mag
=
alphas
*
norm_1
+
(
1
-
alphas
)
*
norm_0
real_stretch
=
mag
*
torch
.
cos
(
phase_acc
)
imag_stretch
=
mag
*
torch
.
sin
(
phase_acc
)
complex_specgrams_stretch
=
torch
.
stack
([
real_stretch
,
imag_stretch
],
dim
=-
1
)
return
complex_specgrams_stretch
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment