Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
b95d60c2
Unverified
Commit
b95d60c2
authored
Mar 30, 2020
by
moto
Committed by
GitHub
Mar 30, 2020
Browse files
Extract JIT tests from test_functional to the dedicated test module (#480)
parent
d63d851e
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
151 additions
and
150 deletions
+151
-150
test/test_functional.py
test/test_functional.py
+0
-150
test/test_torchscript_consistency.py
test/test_torchscript_consistency.py
+151
-0
No files found.
test/test_functional.py
View file @
b95d60c2
...
@@ -17,21 +17,6 @@ if IMPORT_LIBROSA:
...
@@ -17,21 +17,6 @@ if IMPORT_LIBROSA:
import
librosa
import
librosa
def
_test_torchscript_functional_shape
(
py_method
,
*
args
,
**
kwargs
):
jit_method
=
torch
.
jit
.
script
(
py_method
)
jit_out
=
jit_method
(
*
args
,
**
kwargs
)
py_out
=
py_method
(
*
args
,
**
kwargs
)
assert
jit_out
.
shape
==
py_out
.
shape
return
jit_out
,
py_out
def
_test_torchscript_functional
(
py_method
,
*
args
,
**
kwargs
):
jit_out
,
py_out
=
_test_torchscript_functional_shape
(
py_method
,
*
args
,
**
kwargs
)
assert
torch
.
allclose
(
jit_out
,
py_out
)
class
TestFunctional
(
unittest
.
TestCase
):
class
TestFunctional
(
unittest
.
TestCase
):
data_sizes
=
[(
2
,
20
),
(
3
,
15
),
(
4
,
10
)]
data_sizes
=
[(
2
,
20
),
(
3
,
15
),
(
4
,
10
)]
number_of_trials
=
100
number_of_trials
=
100
...
@@ -43,38 +28,6 @@ class TestFunctional(unittest.TestCase):
...
@@ -43,38 +28,6 @@ class TestFunctional(unittest.TestCase):
'steam-train-whistle-daniel_simon.wav'
)
'steam-train-whistle-daniel_simon.wav'
)
waveform_train
,
sr_train
=
torchaudio
.
load
(
test_filepath
)
waveform_train
,
sr_train
=
torchaudio
.
load
(
test_filepath
)
def
test_torchscript_spectrogram
(
self
):
tensor
=
torch
.
rand
((
1
,
1000
))
n_fft
=
400
ws
=
400
hop
=
200
pad
=
0
window
=
torch
.
hann_window
(
ws
)
power
=
2
normalize
=
False
_test_torchscript_functional
(
F
.
spectrogram
,
tensor
,
pad
,
window
,
n_fft
,
hop
,
ws
,
power
,
normalize
)
def
test_torchscript_griffinlim
(
self
):
tensor
=
torch
.
rand
((
1
,
201
,
6
))
n_fft
=
400
ws
=
400
hop
=
200
window
=
torch
.
hann_window
(
ws
)
power
=
2
normalize
=
False
momentum
=
0.99
n_iter
=
32
length
=
1000
init
=
0
_test_torchscript_functional
(
F
.
griffinlim
,
tensor
,
window
,
n_fft
,
hop
,
ws
,
power
,
normalize
,
n_iter
,
momentum
,
length
,
0
)
@
unittest
.
skipIf
(
not
IMPORT_LIBROSA
,
'Librosa not available'
)
@
unittest
.
skipIf
(
not
IMPORT_LIBROSA
,
'Librosa not available'
)
def
test_griffinlim
(
self
):
def
test_griffinlim
(
self
):
...
@@ -138,26 +91,10 @@ class TestFunctional(unittest.TestCase):
...
@@ -138,26 +91,10 @@ class TestFunctional(unittest.TestCase):
[
0.5
,
1.0
,
1.0
,
0.5
]]])
[
0.5
,
1.0
,
1.0
,
0.5
]]])
self
.
_test_compute_deltas
(
specgram
,
expected
)
self
.
_test_compute_deltas
(
specgram
,
expected
)
def
test_compute_deltas_randn
(
self
):
channel
=
13
n_mfcc
=
channel
*
3
time
=
1021
win_length
=
2
*
7
+
1
specgram
=
torch
.
randn
(
channel
,
n_mfcc
,
time
)
computed
=
F
.
compute_deltas
(
specgram
,
win_length
=
win_length
)
self
.
assertTrue
(
computed
.
shape
==
specgram
.
shape
,
(
computed
.
shape
,
specgram
.
shape
))
_test_torchscript_functional
(
F
.
compute_deltas
,
specgram
,
win_length
=
win_length
)
def
test_batch_pitch
(
self
):
def
test_batch_pitch
(
self
):
waveform
,
sample_rate
=
torchaudio
.
load
(
self
.
test_filepath
)
waveform
,
sample_rate
=
torchaudio
.
load
(
self
.
test_filepath
)
self
.
_test_batch
(
F
.
detect_pitch_frequency
,
waveform
,
sample_rate
)
self
.
_test_batch
(
F
.
detect_pitch_frequency
,
waveform
,
sample_rate
)
def
test_jit_pitch
(
self
):
waveform
,
sample_rate
=
torchaudio
.
load
(
self
.
test_filepath
)
_test_torchscript_functional
(
F
.
detect_pitch_frequency
,
waveform
,
sample_rate
)
def
_compare_estimate
(
self
,
sound
,
estimate
,
atol
=
1e-6
,
rtol
=
1e-8
):
def
_compare_estimate
(
self
,
sound
,
estimate
,
atol
=
1e-6
,
rtol
=
1e-8
):
# trim sound for case when constructed signal is shorter than original
# trim sound for case when constructed signal is shorter than original
sound
=
sound
[...,
:
estimate
.
size
(
-
1
)]
sound
=
sound
[...,
:
estimate
.
size
(
-
1
)]
...
@@ -568,33 +505,6 @@ class TestFunctional(unittest.TestCase):
...
@@ -568,33 +505,6 @@ class TestFunctional(unittest.TestCase):
torch
.
random
.
manual_seed
(
42
)
torch
.
random
.
manual_seed
(
42
)
computed
=
functional
(
tensors
.
clone
(),
*
args
,
**
kwargs
)
computed
=
functional
(
tensors
.
clone
(),
*
args
,
**
kwargs
)
def
test_torchscript_create_fb_matrix
(
self
):
n_stft
=
100
f_min
=
0.0
f_max
=
20.0
n_mels
=
10
sample_rate
=
16000
_test_torchscript_functional
(
F
.
create_fb_matrix
,
n_stft
,
f_min
,
f_max
,
n_mels
,
sample_rate
)
def
test_torchscript_amplitude_to_DB
(
self
):
spec
=
torch
.
rand
((
6
,
201
))
multiplier
=
10.0
amin
=
1e-10
db_multiplier
=
0.0
top_db
=
80.0
_test_torchscript_functional
(
F
.
amplitude_to_DB
,
spec
,
multiplier
,
amin
,
db_multiplier
,
top_db
)
def
test_torchscript_DB_to_amplitude
(
self
):
x
=
torch
.
rand
((
1
,
100
))
ref
=
1.
power
=
1.
_test_torchscript_functional
(
F
.
DB_to_amplitude
,
x
,
ref
,
power
)
def
test_DB_to_amplitude
(
self
):
def
test_DB_to_amplitude
(
self
):
# Make some noise
# Make some noise
x
=
torch
.
rand
(
1000
)
x
=
torch
.
rand
(
1000
)
...
@@ -661,66 +571,6 @@ class TestFunctional(unittest.TestCase):
...
@@ -661,66 +571,6 @@ class TestFunctional(unittest.TestCase):
self
.
assertTrue
(
torch
.
allclose
(
ta_out
,
lr_out
,
atol
=
5e-5
))
self
.
assertTrue
(
torch
.
allclose
(
ta_out
,
lr_out
,
atol
=
5e-5
))
def
test_torchscript_create_dct
(
self
):
n_mfcc
=
40
n_mels
=
128
norm
=
"ortho"
_test_torchscript_functional
(
F
.
create_dct
,
n_mfcc
,
n_mels
,
norm
)
def
test_torchscript_mu_law_encoding
(
self
):
tensor
=
torch
.
rand
((
1
,
10
))
qc
=
256
_test_torchscript_functional
(
F
.
mu_law_encoding
,
tensor
,
qc
)
def
test_torchscript_mu_law_decoding
(
self
):
tensor
=
torch
.
rand
((
1
,
10
))
qc
=
256
_test_torchscript_functional
(
F
.
mu_law_decoding
,
tensor
,
qc
)
def
test_torchscript_complex_norm
(
self
):
complex_tensor
=
torch
.
randn
(
1
,
2
,
1025
,
400
,
2
)
power
=
2
_test_torchscript_functional
(
F
.
complex_norm
,
complex_tensor
,
power
)
def
test_mask_along_axis
(
self
):
specgram
=
torch
.
randn
(
2
,
1025
,
400
)
mask_param
=
100
mask_value
=
30.
axis
=
2
_test_torchscript_functional
(
F
.
mask_along_axis
,
specgram
,
mask_param
,
mask_value
,
axis
)
def
test_mask_along_axis_iid
(
self
):
specgrams
=
torch
.
randn
(
4
,
2
,
1025
,
400
)
mask_param
=
100
mask_value
=
30.
axis
=
2
_test_torchscript_functional
(
F
.
mask_along_axis_iid
,
specgrams
,
mask_param
,
mask_value
,
axis
)
def
test_torchscript_gain
(
self
):
tensor
=
torch
.
rand
((
1
,
1000
))
gainDB
=
2.0
_test_torchscript_functional
(
F
.
gain
,
tensor
,
gainDB
)
def
test_torchscript_dither
(
self
):
tensor
=
torch
.
rand
((
2
,
1000
))
_test_torchscript_functional_shape
(
F
.
dither
,
tensor
)
_test_torchscript_functional_shape
(
F
.
dither
,
tensor
,
"RPDF"
)
_test_torchscript_functional_shape
(
F
.
dither
,
tensor
,
"GPDF"
)
def
_num_stft_bins
(
signal_len
,
fft_len
,
hop_length
,
pad
):
def
_num_stft_bins
(
signal_len
,
fft_len
,
hop_length
,
pad
):
return
(
signal_len
+
2
*
pad
-
fft_len
+
hop_length
)
//
hop_length
return
(
signal_len
+
2
*
pad
-
fft_len
+
hop_length
)
//
hop_length
...
...
test/test_torchscript_consistency.py
0 → 100644
View file @
b95d60c2
"""Test suites for jit-ability and its numerical compatibility"""
import
os
import
unittest
import
torch
import
torchaudio
import
torchaudio.functional
as
F
import
common_utils
def
_test_torchscript_functional_shape
(
py_method
,
*
args
,
**
kwargs
):
jit_method
=
torch
.
jit
.
script
(
py_method
)
jit_out
=
jit_method
(
*
args
,
**
kwargs
)
py_out
=
py_method
(
*
args
,
**
kwargs
)
assert
jit_out
.
shape
==
py_out
.
shape
return
jit_out
,
py_out
def
_test_torchscript_functional
(
py_method
,
*
args
,
**
kwargs
):
jit_out
,
py_out
=
_test_torchscript_functional_shape
(
py_method
,
*
args
,
**
kwargs
)
assert
torch
.
allclose
(
jit_out
,
py_out
)
class
TestFunctional
(
unittest
.
TestCase
):
"""Test functions in `functional` module."""
def
test_spectrogram
(
self
):
tensor
=
torch
.
rand
((
1
,
1000
))
n_fft
=
400
ws
=
400
hop
=
200
pad
=
0
window
=
torch
.
hann_window
(
ws
)
power
=
2
normalize
=
False
_test_torchscript_functional
(
F
.
spectrogram
,
tensor
,
pad
,
window
,
n_fft
,
hop
,
ws
,
power
,
normalize
)
def
test_griffinlim
(
self
):
tensor
=
torch
.
rand
((
1
,
201
,
6
))
n_fft
=
400
ws
=
400
hop
=
200
window
=
torch
.
hann_window
(
ws
)
power
=
2
normalize
=
False
momentum
=
0.99
n_iter
=
32
length
=
1000
_test_torchscript_functional
(
F
.
griffinlim
,
tensor
,
window
,
n_fft
,
hop
,
ws
,
power
,
normalize
,
n_iter
,
momentum
,
length
,
0
)
def
test_compute_deltas
(
self
):
channel
=
13
n_mfcc
=
channel
*
3
time
=
1021
win_length
=
2
*
7
+
1
specgram
=
torch
.
randn
(
channel
,
n_mfcc
,
time
)
_test_torchscript_functional
(
F
.
compute_deltas
,
specgram
,
win_length
=
win_length
)
def
test_detect_pitch_frequency
(
self
):
filepath
=
os
.
path
.
join
(
common_utils
.
TEST_DIR_PATH
,
'assets'
,
'steam-train-whistle-daniel_simon.mp3'
)
waveform
,
sample_rate
=
torchaudio
.
load
(
filepath
)
_test_torchscript_functional
(
F
.
detect_pitch_frequency
,
waveform
,
sample_rate
)
def
test_create_fb_matrix
(
self
):
n_stft
=
100
f_min
=
0.0
f_max
=
20.0
n_mels
=
10
sample_rate
=
16000
_test_torchscript_functional
(
F
.
create_fb_matrix
,
n_stft
,
f_min
,
f_max
,
n_mels
,
sample_rate
)
def
test_amplitude_to_DB
(
self
):
spec
=
torch
.
rand
((
6
,
201
))
multiplier
=
10.0
amin
=
1e-10
db_multiplier
=
0.0
top_db
=
80.0
_test_torchscript_functional
(
F
.
amplitude_to_DB
,
spec
,
multiplier
,
amin
,
db_multiplier
,
top_db
)
def
test_DB_to_amplitude
(
self
):
x
=
torch
.
rand
((
1
,
100
))
ref
=
1.
power
=
1.
_test_torchscript_functional
(
F
.
DB_to_amplitude
,
x
,
ref
,
power
)
def
test_create_dct
(
self
):
n_mfcc
=
40
n_mels
=
128
norm
=
"ortho"
_test_torchscript_functional
(
F
.
create_dct
,
n_mfcc
,
n_mels
,
norm
)
def
test_mu_law_encoding
(
self
):
tensor
=
torch
.
rand
((
1
,
10
))
qc
=
256
_test_torchscript_functional
(
F
.
mu_law_encoding
,
tensor
,
qc
)
def
test_mu_law_decoding
(
self
):
tensor
=
torch
.
rand
((
1
,
10
))
qc
=
256
_test_torchscript_functional
(
F
.
mu_law_decoding
,
tensor
,
qc
)
def
test_complex_norm
(
self
):
complex_tensor
=
torch
.
randn
(
1
,
2
,
1025
,
400
,
2
)
power
=
2
_test_torchscript_functional
(
F
.
complex_norm
,
complex_tensor
,
power
)
def
test_mask_along_axis
(
self
):
specgram
=
torch
.
randn
(
2
,
1025
,
400
)
mask_param
=
100
mask_value
=
30.
axis
=
2
_test_torchscript_functional
(
F
.
mask_along_axis
,
specgram
,
mask_param
,
mask_value
,
axis
)
def
test_mask_along_axis_iid
(
self
):
specgrams
=
torch
.
randn
(
4
,
2
,
1025
,
400
)
mask_param
=
100
mask_value
=
30.
axis
=
2
_test_torchscript_functional
(
F
.
mask_along_axis_iid
,
specgrams
,
mask_param
,
mask_value
,
axis
)
def
test_gain
(
self
):
tensor
=
torch
.
rand
((
1
,
1000
))
gainDB
=
2.0
_test_torchscript_functional
(
F
.
gain
,
tensor
,
gainDB
)
def
test_dither
(
self
):
tensor
=
torch
.
rand
((
2
,
1000
))
_test_torchscript_functional_shape
(
F
.
dither
,
tensor
)
_test_torchscript_functional_shape
(
F
.
dither
,
tensor
,
"RPDF"
)
_test_torchscript_functional_shape
(
F
.
dither
,
tensor
,
"GPDF"
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment