Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
a9c4d0a8
Unverified
Commit
a9c4d0a8
authored
Apr 07, 2020
by
moto
Committed by
GitHub
Apr 07, 2020
Browse files
Refactor torchscript test helper function (#521)
parent
657f0a02
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
36 additions
and
39 deletions
+36
-39
test/test_torchscript_consistency.py
test/test_torchscript_consistency.py
+36
-39
No files found.
test/test_torchscript_consistency.py
View file @
a9c4d0a8
...
...
@@ -10,19 +10,16 @@ import torchaudio.transforms
import
common_utils
def
_
test_torchscrip
t_functional_
shape
(
py_method
,
*
args
,
**
kwargs
):
def
_
asser
t_functional_
consistency
(
py_method
,
*
args
,
shape_only
=
False
,
**
kwargs
):
jit_method
=
torch
.
jit
.
script
(
py_method
)
jit_out
=
jit_method
(
*
args
,
**
kwargs
)
py_out
=
py_method
(
*
args
,
**
kwargs
)
assert
jit_out
.
shape
==
py_out
.
shape
return
jit_out
,
py_out
def
_test_torchscript_functional
(
py_method
,
*
args
,
**
kwargs
):
jit_out
,
py_out
=
_test_torchscript_functional_shape
(
py_method
,
*
args
,
**
kwargs
)
torch
.
testing
.
assert_allclose
(
jit_out
,
py_out
)
if
shape_only
:
assert
jit_out
.
shape
==
py_out
.
shape
,
(
jit_out
.
shape
,
py_out
.
shape
)
else
:
torch
.
testing
.
assert_allclose
(
jit_out
,
py_out
)
def
_test_lfilter
(
waveform
):
...
...
@@ -58,7 +55,7 @@ def _test_lfilter(waveform):
],
device
=
waveform
.
device
,
)
_
test_torchscript_functional
(
F
.
lfilter
,
waveform
,
a_coeffs
,
b_coeffs
)
_
assert_functional_consistency
(
F
.
lfilter
,
waveform
,
a_coeffs
,
b_coeffs
)
class
TestFunctional
(
unittest
.
TestCase
):
...
...
@@ -73,7 +70,7 @@ class TestFunctional(unittest.TestCase):
power
=
2
normalize
=
False
_
test_torchscript_functional
(
_
assert_functional_consistency
(
F
.
spectrogram
,
tensor
,
pad
,
window
,
n_fft
,
hop
,
ws
,
power
,
normalize
)
...
...
@@ -89,7 +86,7 @@ class TestFunctional(unittest.TestCase):
n_iter
=
32
length
=
1000
_
test_torchscript_functional
(
_
assert_functional_consistency
(
F
.
griffinlim
,
tensor
,
window
,
n_fft
,
hop
,
ws
,
power
,
normalize
,
n_iter
,
momentum
,
length
,
0
)
...
...
@@ -100,13 +97,13 @@ class TestFunctional(unittest.TestCase):
win_length
=
2
*
7
+
1
specgram
=
torch
.
randn
(
channel
,
n_mfcc
,
time
)
_
test_torchscript_functional
(
F
.
compute_deltas
,
specgram
,
win_length
=
win_length
)
_
assert_functional_consistency
(
F
.
compute_deltas
,
specgram
,
win_length
=
win_length
)
def
test_detect_pitch_frequency
(
self
):
filepath
=
os
.
path
.
join
(
common_utils
.
TEST_DIR_PATH
,
'assets'
,
'steam-train-whistle-daniel_simon.mp3'
)
waveform
,
sample_rate
=
torchaudio
.
load
(
filepath
)
_
test_torchscript_functional
(
F
.
detect_pitch_frequency
,
waveform
,
sample_rate
)
_
assert_functional_consistency
(
F
.
detect_pitch_frequency
,
waveform
,
sample_rate
)
def
test_create_fb_matrix
(
self
):
n_stft
=
100
...
...
@@ -115,7 +112,7 @@ class TestFunctional(unittest.TestCase):
n_mels
=
10
sample_rate
=
16000
_
test_torchscript_functional
(
F
.
create_fb_matrix
,
n_stft
,
f_min
,
f_max
,
n_mels
,
sample_rate
)
_
assert_functional_consistency
(
F
.
create_fb_matrix
,
n_stft
,
f_min
,
f_max
,
n_mels
,
sample_rate
)
def
test_amplitude_to_DB
(
self
):
spec
=
torch
.
rand
((
6
,
201
))
...
...
@@ -124,39 +121,39 @@ class TestFunctional(unittest.TestCase):
db_multiplier
=
0.0
top_db
=
80.0
_
test_torchscript_functional
(
F
.
amplitude_to_DB
,
spec
,
multiplier
,
amin
,
db_multiplier
,
top_db
)
_
assert_functional_consistency
(
F
.
amplitude_to_DB
,
spec
,
multiplier
,
amin
,
db_multiplier
,
top_db
)
def
test_DB_to_amplitude
(
self
):
x
=
torch
.
rand
((
1
,
100
))
ref
=
1.
power
=
1.
_
test_torchscript_functional
(
F
.
DB_to_amplitude
,
x
,
ref
,
power
)
_
assert_functional_consistency
(
F
.
DB_to_amplitude
,
x
,
ref
,
power
)
def
test_create_dct
(
self
):
n_mfcc
=
40
n_mels
=
128
norm
=
"ortho"
_
test_torchscript_functional
(
F
.
create_dct
,
n_mfcc
,
n_mels
,
norm
)
_
assert_functional_consistency
(
F
.
create_dct
,
n_mfcc
,
n_mels
,
norm
)
def
test_mu_law_encoding
(
self
):
tensor
=
torch
.
rand
((
1
,
10
))
qc
=
256
_
test_torchscript_functional
(
F
.
mu_law_encoding
,
tensor
,
qc
)
_
assert_functional_consistency
(
F
.
mu_law_encoding
,
tensor
,
qc
)
def
test_mu_law_decoding
(
self
):
tensor
=
torch
.
rand
((
1
,
10
))
qc
=
256
_
test_torchscript_functional
(
F
.
mu_law_decoding
,
tensor
,
qc
)
_
assert_functional_consistency
(
F
.
mu_law_decoding
,
tensor
,
qc
)
def
test_complex_norm
(
self
):
complex_tensor
=
torch
.
randn
(
1
,
2
,
1025
,
400
,
2
)
power
=
2
_
test_torchscript_functional
(
F
.
complex_norm
,
complex_tensor
,
power
)
_
assert_functional_consistency
(
F
.
complex_norm
,
complex_tensor
,
power
)
def
test_mask_along_axis
(
self
):
specgram
=
torch
.
randn
(
2
,
1025
,
400
)
...
...
@@ -164,7 +161,7 @@ class TestFunctional(unittest.TestCase):
mask_value
=
30.
axis
=
2
_
test_torchscript_functional
(
F
.
mask_along_axis
,
specgram
,
mask_param
,
mask_value
,
axis
)
_
assert_functional_consistency
(
F
.
mask_along_axis
,
specgram
,
mask_param
,
mask_value
,
axis
)
def
test_mask_along_axis_iid
(
self
):
specgrams
=
torch
.
randn
(
4
,
2
,
1025
,
400
)
...
...
@@ -172,20 +169,20 @@ class TestFunctional(unittest.TestCase):
mask_value
=
30.
axis
=
2
_
test_torchscript_functional
(
F
.
mask_along_axis_iid
,
specgrams
,
mask_param
,
mask_value
,
axis
)
_
assert_functional_consistency
(
F
.
mask_along_axis_iid
,
specgrams
,
mask_param
,
mask_value
,
axis
)
def
test_gain
(
self
):
tensor
=
torch
.
rand
((
1
,
1000
))
gainDB
=
2.0
_
test_torchscript_functional
(
F
.
gain
,
tensor
,
gainDB
)
_
assert_functional_consistency
(
F
.
gain
,
tensor
,
gainDB
)
def
test_dither
(
self
):
tensor
=
torch
.
rand
((
2
,
1000
))
_
test_torchscrip
t_functional_
shape
(
F
.
dither
,
tensor
)
_
test_torchscrip
t_functional_
shape
(
F
.
dither
,
tensor
,
"RPDF"
)
_
test_torchscrip
t_functional_
shape
(
F
.
dither
,
tensor
,
"GPDF"
)
_
asser
t_functional_
consistency
(
F
.
dither
,
tensor
,
shape_only
=
True
)
_
asser
t_functional_
consistency
(
F
.
dither
,
tensor
,
"RPDF"
,
shape_only
=
True
)
_
asser
t_functional_
consistency
(
F
.
dither
,
tensor
,
"GPDF"
,
shape_only
=
True
)
def
test_lfilter
(
self
):
filepath
=
os
.
path
.
join
(
common_utils
.
TEST_DIR_PATH
,
'assets'
,
'whitenoise.wav'
)
...
...
@@ -203,14 +200,14 @@ class TestFunctional(unittest.TestCase):
filepath
=
os
.
path
.
join
(
common_utils
.
TEST_DIR_PATH
,
'assets'
,
'whitenoise.wav'
)
waveform
,
sample_rate
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
_
test_torchscript_functional
(
F
.
lowpass_biquad
,
waveform
,
sample_rate
,
cutoff_freq
)
_
assert_functional_consistency
(
F
.
lowpass_biquad
,
waveform
,
sample_rate
,
cutoff_freq
)
def
test_highpass
(
self
):
cutoff_freq
=
2000
filepath
=
os
.
path
.
join
(
common_utils
.
TEST_DIR_PATH
,
'assets'
,
'whitenoise.wav'
)
waveform
,
sample_rate
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
_
test_torchscript_functional
(
F
.
highpass_biquad
,
waveform
,
sample_rate
,
cutoff_freq
)
_
assert_functional_consistency
(
F
.
highpass_biquad
,
waveform
,
sample_rate
,
cutoff_freq
)
def
test_allpass
(
self
):
central_freq
=
1000
...
...
@@ -218,7 +215,7 @@ class TestFunctional(unittest.TestCase):
filepath
=
os
.
path
.
join
(
common_utils
.
TEST_DIR_PATH
,
'assets'
,
'whitenoise.wav'
)
waveform
,
sample_rate
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
_
test_torchscript_functional
(
F
.
allpass_biquad
,
waveform
,
sample_rate
,
central_freq
,
q
)
_
assert_functional_consistency
(
F
.
allpass_biquad
,
waveform
,
sample_rate
,
central_freq
,
q
)
def
test_bandpass_with_csg
(
self
):
central_freq
=
1000
...
...
@@ -227,7 +224,7 @@ class TestFunctional(unittest.TestCase):
filepath
=
os
.
path
.
join
(
common_utils
.
TEST_DIR_PATH
,
"assets"
,
"whitenoise.wav"
)
waveform
,
sample_rate
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
_
test_torchscript_functional
(
_
assert_functional_consistency
(
F
.
bandpass_biquad
,
waveform
,
sample_rate
,
central_freq
,
q
,
const_skirt_gain
)
def
test_bandpass_withou_csg
(
self
):
...
...
@@ -237,7 +234,7 @@ class TestFunctional(unittest.TestCase):
filepath
=
os
.
path
.
join
(
common_utils
.
TEST_DIR_PATH
,
"assets"
,
"whitenoise.wav"
)
waveform
,
sample_rate
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
_
test_torchscript_functional
(
_
assert_functional_consistency
(
F
.
bandpass_biquad
,
waveform
,
sample_rate
,
central_freq
,
q
,
const_skirt_gain
)
def
test_bandreject
(
self
):
...
...
@@ -246,7 +243,7 @@ class TestFunctional(unittest.TestCase):
filepath
=
os
.
path
.
join
(
common_utils
.
TEST_DIR_PATH
,
"assets"
,
"whitenoise.wav"
)
waveform
,
sample_rate
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
_
test_torchscript_functional
(
_
assert_functional_consistency
(
F
.
bandreject_biquad
,
waveform
,
sample_rate
,
central_freq
,
q
)
def
test_band_with_noise
(
self
):
...
...
@@ -256,7 +253,7 @@ class TestFunctional(unittest.TestCase):
filepath
=
os
.
path
.
join
(
common_utils
.
TEST_DIR_PATH
,
"assets"
,
"whitenoise.wav"
)
waveform
,
sample_rate
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
_
test_torchscript_functional
(
F
.
band_biquad
,
waveform
,
sample_rate
,
central_freq
,
q
,
noise
)
_
assert_functional_consistency
(
F
.
band_biquad
,
waveform
,
sample_rate
,
central_freq
,
q
,
noise
)
def
test_band_without_noise
(
self
):
central_freq
=
1000
...
...
@@ -265,7 +262,7 @@ class TestFunctional(unittest.TestCase):
filepath
=
os
.
path
.
join
(
common_utils
.
TEST_DIR_PATH
,
"assets"
,
"whitenoise.wav"
)
waveform
,
sample_rate
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
_
test_torchscript_functional
(
F
.
band_biquad
,
waveform
,
sample_rate
,
central_freq
,
q
,
noise
)
_
assert_functional_consistency
(
F
.
band_biquad
,
waveform
,
sample_rate
,
central_freq
,
q
,
noise
)
def
test_treble
(
self
):
gain
=
40
...
...
@@ -274,17 +271,17 @@ class TestFunctional(unittest.TestCase):
filepath
=
os
.
path
.
join
(
common_utils
.
TEST_DIR_PATH
,
"assets"
,
"whitenoise.wav"
)
waveform
,
sample_rate
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
_
test_torchscript_functional
(
F
.
treble_biquad
,
waveform
,
sample_rate
,
gain
,
central_freq
,
q
)
_
assert_functional_consistency
(
F
.
treble_biquad
,
waveform
,
sample_rate
,
gain
,
central_freq
,
q
)
def
test_deemph
(
self
):
filepath
=
os
.
path
.
join
(
common_utils
.
TEST_DIR_PATH
,
"assets"
,
"whitenoise.wav"
)
waveform
,
sample_rate
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
_
test_torchscript_functional
(
F
.
deemph_biquad
,
waveform
,
sample_rate
)
_
assert_functional_consistency
(
F
.
deemph_biquad
,
waveform
,
sample_rate
)
def
test_riaa
(
self
):
filepath
=
os
.
path
.
join
(
common_utils
.
TEST_DIR_PATH
,
"assets"
,
"whitenoise.wav"
)
waveform
,
sample_rate
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
_
test_torchscript_functional
(
F
.
riaa_biquad
,
waveform
,
sample_rate
)
_
assert_functional_consistency
(
F
.
riaa_biquad
,
waveform
,
sample_rate
)
def
test_equalizer
(
self
):
center_freq
=
300
...
...
@@ -293,7 +290,7 @@ class TestFunctional(unittest.TestCase):
filepath
=
os
.
path
.
join
(
common_utils
.
TEST_DIR_PATH
,
"assets"
,
"whitenoise.wav"
)
waveform
,
sample_rate
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
_
test_torchscript_functional
(
_
assert_functional_consistency
(
F
.
equalizer_biquad
,
waveform
,
sample_rate
,
center_freq
,
gain
,
q
)
def
test_perf_biquad_filtering
(
self
):
...
...
@@ -301,7 +298,7 @@ class TestFunctional(unittest.TestCase):
b
=
torch
.
tensor
([
0.4
,
0.2
,
0.9
])
filepath
=
os
.
path
.
join
(
common_utils
.
TEST_DIR_PATH
,
"assets"
,
"whitenoise.wav"
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
_
test_torchscript_functional
(
F
.
lfilter
,
waveform
,
a
,
b
)
_
assert_functional_consistency
(
F
.
lfilter
,
waveform
,
a
,
b
)
RUN_CUDA
=
torch
.
cuda
.
is_available
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment