Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
413bd18e
Unverified
Commit
413bd18e
authored
Apr 02, 2020
by
moto
Committed by
GitHub
Apr 02, 2020
Browse files
Extract JIT tests from test_transforms to the dedicated test module (#496)
parent
eb5b5a02
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
107 additions
and
97 deletions
+107
-97
test/test_torchscript_consistency.py
test/test_torchscript_consistency.py
+107
-0
test/test_transforms.py
test/test_transforms.py
+0
-97
No files found.
test/test_torchscript_consistency.py
View file @
413bd18e
...
@@ -5,6 +5,7 @@ import unittest
...
@@ -5,6 +5,7 @@ import unittest
import
torch
import
torch
import
torchaudio
import
torchaudio
import
torchaudio.functional
as
F
import
torchaudio.functional
as
F
import
torchaudio.transforms
import
common_utils
import
common_utils
...
@@ -149,3 +150,109 @@ class TestFunctional(unittest.TestCase):
...
@@ -149,3 +150,109 @@ class TestFunctional(unittest.TestCase):
_test_torchscript_functional_shape
(
F
.
dither
,
tensor
)
_test_torchscript_functional_shape
(
F
.
dither
,
tensor
)
_test_torchscript_functional_shape
(
F
.
dither
,
tensor
,
"RPDF"
)
_test_torchscript_functional_shape
(
F
.
dither
,
tensor
,
"RPDF"
)
_test_torchscript_functional_shape
(
F
.
dither
,
tensor
,
"GPDF"
)
_test_torchscript_functional_shape
(
F
.
dither
,
tensor
,
"GPDF"
)
RUN_CUDA
=
torch
.
cuda
.
is_available
()
print
(
"Run test with cuda:"
,
RUN_CUDA
)
def
_test_script_module
(
f
,
tensor
,
*
args
,
**
kwargs
):
py_method
=
f
(
*
args
,
**
kwargs
)
jit_method
=
torch
.
jit
.
script
(
py_method
)
py_out
=
py_method
(
tensor
)
jit_out
=
jit_method
(
tensor
)
assert
torch
.
allclose
(
jit_out
,
py_out
)
if
RUN_CUDA
:
tensor
=
tensor
.
to
(
"cuda"
)
py_method
=
py_method
.
cuda
()
jit_method
=
torch
.
jit
.
script
(
py_method
)
py_out
=
py_method
(
tensor
)
jit_out
=
jit_method
(
tensor
)
assert
torch
.
allclose
(
jit_out
,
py_out
)
class
TestTransforms
(
unittest
.
TestCase
):
def
test_Spectrogram
(
self
):
tensor
=
torch
.
rand
((
1
,
1000
))
_test_script_module
(
torchaudio
.
transforms
.
Spectrogram
,
tensor
)
def
test_GriffinLim
(
self
):
tensor
=
torch
.
rand
((
1
,
201
,
6
))
_test_script_module
(
torchaudio
.
transforms
.
GriffinLim
,
tensor
,
length
=
1000
,
rand_init
=
False
)
def
test_AmplitudeToDB
(
self
):
spec
=
torch
.
rand
((
6
,
201
))
_test_script_module
(
torchaudio
.
transforms
.
AmplitudeToDB
,
spec
)
def
test_MelScale
(
self
):
spec_f
=
torch
.
rand
((
1
,
6
,
201
))
_test_script_module
(
torchaudio
.
transforms
.
MelScale
,
spec_f
)
def
test_MelSpectrogram
(
self
):
tensor
=
torch
.
rand
((
1
,
1000
))
_test_script_module
(
torchaudio
.
transforms
.
MelSpectrogram
,
tensor
)
def
test_MFCC
(
self
):
tensor
=
torch
.
rand
((
1
,
1000
))
_test_script_module
(
torchaudio
.
transforms
.
MFCC
,
tensor
)
def
test_Resample
(
self
):
tensor
=
torch
.
rand
((
2
,
1000
))
sample_rate
=
100.
sample_rate_2
=
50.
_test_script_module
(
torchaudio
.
transforms
.
Resample
,
tensor
,
sample_rate
,
sample_rate_2
)
def
test_ComplexNorm
(
self
):
tensor
=
torch
.
rand
((
1
,
2
,
201
,
2
))
_test_script_module
(
torchaudio
.
transforms
.
ComplexNorm
,
tensor
)
def
test_MuLawEncoding
(
self
):
tensor
=
torch
.
rand
((
1
,
10
))
_test_script_module
(
torchaudio
.
transforms
.
MuLawEncoding
,
tensor
)
def
test_MuLawDecoding
(
self
):
tensor
=
torch
.
rand
((
1
,
10
))
_test_script_module
(
torchaudio
.
transforms
.
MuLawDecoding
,
tensor
)
def
test_TimeStretch
(
self
):
n_freq
=
400
hop_length
=
512
fixed_rate
=
1.3
tensor
=
torch
.
rand
((
10
,
2
,
n_freq
,
10
,
2
))
_test_script_module
(
torchaudio
.
transforms
.
TimeStretch
,
tensor
,
n_freq
=
n_freq
,
hop_length
=
hop_length
,
fixed_rate
=
fixed_rate
)
def
test_Fade
(
self
):
test_filepath
=
os
.
path
.
join
(
common_utils
.
TEST_DIR_PATH
,
'assets'
,
'steam-train-whistle-daniel_simon.wav'
)
waveform
,
_
=
torchaudio
.
load
(
test_filepath
)
fade_in_len
=
3000
fade_out_len
=
3000
_test_script_module
(
torchaudio
.
transforms
.
Fade
,
waveform
,
fade_in_len
,
fade_out_len
)
def
test_FrequencyMasking
(
self
):
tensor
=
torch
.
rand
((
10
,
2
,
50
,
10
,
2
))
_test_script_module
(
torchaudio
.
transforms
.
FrequencyMasking
,
tensor
,
freq_mask_param
=
60
,
iid_masks
=
False
)
def
test_TimeMasking
(
self
):
tensor
=
torch
.
rand
((
10
,
2
,
50
,
10
,
2
))
_test_script_module
(
torchaudio
.
transforms
.
TimeMasking
,
tensor
,
time_mask_param
=
30
,
iid_masks
=
False
)
def
test_Vol
(
self
):
test_filepath
=
os
.
path
.
join
(
common_utils
.
TEST_DIR_PATH
,
'assets'
,
'steam-train-whistle-daniel_simon.wav'
)
waveform
,
_
=
torchaudio
.
load
(
test_filepath
)
_test_script_module
(
torchaudio
.
transforms
.
Vol
,
waveform
,
1.1
)
test/test_transforms.py
View file @
413bd18e
...
@@ -10,33 +10,6 @@ import torchaudio.functional as F
...
@@ -10,33 +10,6 @@ import torchaudio.functional as F
from
common_utils
import
AudioBackendScope
,
BACKENDS
,
create_temp_assets_dir
from
common_utils
import
AudioBackendScope
,
BACKENDS
,
create_temp_assets_dir
RUN_CUDA
=
torch
.
cuda
.
is_available
()
print
(
"Run test with cuda:"
,
RUN_CUDA
)
def
_test_script_module
(
f
,
tensor
,
*
args
,
**
kwargs
):
py_method
=
f
(
*
args
,
**
kwargs
)
jit_method
=
torch
.
jit
.
script
(
py_method
)
py_out
=
py_method
(
tensor
)
jit_out
=
jit_method
(
tensor
)
assert
torch
.
allclose
(
jit_out
,
py_out
)
if
RUN_CUDA
:
tensor
=
tensor
.
to
(
"cuda"
)
py_method
=
py_method
.
cuda
()
jit_method
=
torch
.
jit
.
script
(
py_method
)
py_out
=
py_method
(
tensor
)
jit_out
=
jit_method
(
tensor
)
assert
torch
.
allclose
(
jit_out
,
py_out
)
class
Tester
(
unittest
.
TestCase
):
class
Tester
(
unittest
.
TestCase
):
# create a sinewave signal for testing
# create a sinewave signal for testing
...
@@ -57,14 +30,6 @@ class Tester(unittest.TestCase):
...
@@ -57,14 +30,6 @@ class Tester(unittest.TestCase):
waveform
=
waveform
.
to
(
torch
.
get_default_dtype
())
waveform
=
waveform
.
to
(
torch
.
get_default_dtype
())
return
waveform
/
factor
return
waveform
/
factor
def
test_scriptmodule_Spectrogram
(
self
):
tensor
=
torch
.
rand
((
1
,
1000
))
_test_script_module
(
transforms
.
Spectrogram
,
tensor
)
def
test_scriptmodule_GriffinLim
(
self
):
tensor
=
torch
.
rand
((
1
,
201
,
6
))
_test_script_module
(
transforms
.
GriffinLim
,
tensor
,
length
=
1000
,
rand_init
=
False
)
def
test_mu_law_companding
(
self
):
def
test_mu_law_companding
(
self
):
quantization_channels
=
256
quantization_channels
=
256
...
@@ -79,10 +44,6 @@ class Tester(unittest.TestCase):
...
@@ -79,10 +44,6 @@ class Tester(unittest.TestCase):
waveform_exp
=
transforms
.
MuLawDecoding
(
quantization_channels
)(
waveform_mu
)
waveform_exp
=
transforms
.
MuLawDecoding
(
quantization_channels
)(
waveform_mu
)
self
.
assertTrue
(
waveform_exp
.
min
()
>=
-
1.
and
waveform_exp
.
max
()
<=
1.
)
self
.
assertTrue
(
waveform_exp
.
min
()
>=
-
1.
and
waveform_exp
.
max
()
<=
1.
)
def
test_scriptmodule_AmplitudeToDB
(
self
):
spec
=
torch
.
rand
((
6
,
201
))
_test_script_module
(
transforms
.
AmplitudeToDB
,
spec
)
def
test_batch_AmplitudeToDB
(
self
):
def
test_batch_AmplitudeToDB
(
self
):
spec
=
torch
.
rand
((
6
,
201
))
spec
=
torch
.
rand
((
6
,
201
))
...
@@ -106,10 +67,6 @@ class Tester(unittest.TestCase):
...
@@ -106,10 +67,6 @@ class Tester(unittest.TestCase):
self
.
assertTrue
(
torch
.
allclose
(
mag_to_db_torch
,
power_to_db_torch
))
self
.
assertTrue
(
torch
.
allclose
(
mag_to_db_torch
,
power_to_db_torch
))
def
test_scriptmodule_MelScale
(
self
):
spec_f
=
torch
.
rand
((
1
,
6
,
201
))
_test_script_module
(
transforms
.
MelScale
,
spec_f
)
def
test_melscale_load_save
(
self
):
def
test_melscale_load_save
(
self
):
specgram
=
torch
.
ones
(
1
,
1000
,
100
)
specgram
=
torch
.
ones
(
1
,
1000
,
100
)
melscale_transform
=
transforms
.
MelScale
()
melscale_transform
=
transforms
.
MelScale
()
...
@@ -124,10 +81,6 @@ class Tester(unittest.TestCase):
...
@@ -124,10 +81,6 @@ class Tester(unittest.TestCase):
self
.
assertEqual
(
fb_copy
.
size
(),
(
1000
,
128
))
self
.
assertEqual
(
fb_copy
.
size
(),
(
1000
,
128
))
self
.
assertTrue
(
torch
.
allclose
(
fb
,
fb_copy
))
self
.
assertTrue
(
torch
.
allclose
(
fb
,
fb_copy
))
def
test_scriptmodule_MelSpectrogram
(
self
):
tensor
=
torch
.
rand
((
1
,
1000
))
_test_script_module
(
transforms
.
MelSpectrogram
,
tensor
)
def
test_melspectrogram_load_save
(
self
):
def
test_melspectrogram_load_save
(
self
):
waveform
=
self
.
waveform
.
float
()
waveform
=
self
.
waveform
.
float
()
mel_spectrogram_transform
=
transforms
.
MelSpectrogram
()
mel_spectrogram_transform
=
transforms
.
MelSpectrogram
()
...
@@ -186,10 +139,6 @@ class Tester(unittest.TestCase):
...
@@ -186,10 +139,6 @@ class Tester(unittest.TestCase):
self
.
assertTrue
(
fb_matrix_transform
.
fb
.
sum
(
1
).
ge
(
0.
).
all
())
self
.
assertTrue
(
fb_matrix_transform
.
fb
.
sum
(
1
).
ge
(
0.
).
all
())
self
.
assertEqual
(
fb_matrix_transform
.
fb
.
size
(),
(
400
,
100
))
self
.
assertEqual
(
fb_matrix_transform
.
fb
.
size
(),
(
400
,
100
))
def
test_scriptmodule_MFCC
(
self
):
tensor
=
torch
.
rand
((
1
,
1000
))
_test_script_module
(
transforms
.
MFCC
,
tensor
)
def
test_mfcc
(
self
):
def
test_mfcc
(
self
):
audio_orig
=
self
.
waveform
.
clone
()
audio_orig
=
self
.
waveform
.
clone
()
audio_scaled
=
self
.
scale
(
audio_orig
)
# (1, 16000)
audio_scaled
=
self
.
scale
(
audio_orig
)
# (1, 16000)
...
@@ -226,13 +175,6 @@ class Tester(unittest.TestCase):
...
@@ -226,13 +175,6 @@ class Tester(unittest.TestCase):
self
.
assertTrue
(
torch_mfcc_norm_none
.
allclose
(
norm_check
))
self
.
assertTrue
(
torch_mfcc_norm_none
.
allclose
(
norm_check
))
def
test_scriptmodule_Resample
(
self
):
tensor
=
torch
.
rand
((
2
,
1000
))
sample_rate
=
100.
sample_rate_2
=
50.
_test_script_module
(
transforms
.
Resample
,
tensor
,
sample_rate
,
sample_rate_2
)
def
test_batch_Resample
(
self
):
def
test_batch_Resample
(
self
):
waveform
=
torch
.
randn
(
2
,
2786
)
waveform
=
torch
.
randn
(
2
,
2786
)
...
@@ -245,10 +187,6 @@ class Tester(unittest.TestCase):
...
@@ -245,10 +187,6 @@ class Tester(unittest.TestCase):
self
.
assertTrue
(
computed
.
shape
==
expected
.
shape
,
(
computed
.
shape
,
expected
.
shape
))
self
.
assertTrue
(
computed
.
shape
==
expected
.
shape
,
(
computed
.
shape
,
expected
.
shape
))
self
.
assertTrue
(
torch
.
allclose
(
computed
,
expected
))
self
.
assertTrue
(
torch
.
allclose
(
computed
,
expected
))
def
test_scriptmodule_ComplexNorm
(
self
):
tensor
=
torch
.
rand
((
1
,
2
,
201
,
2
))
_test_script_module
(
transforms
.
ComplexNorm
,
tensor
)
def
test_resample_size
(
self
):
def
test_resample_size
(
self
):
input_path
=
os
.
path
.
join
(
self
.
test_dirpath
,
'assets'
,
'sinewave.wav'
)
input_path
=
os
.
path
.
join
(
self
.
test_dirpath
,
'assets'
,
'sinewave.wav'
)
waveform
,
sample_rate
=
torchaudio
.
load
(
input_path
)
waveform
,
sample_rate
=
torchaudio
.
load
(
input_path
)
...
@@ -349,14 +287,6 @@ class Tester(unittest.TestCase):
...
@@ -349,14 +287,6 @@ class Tester(unittest.TestCase):
self
.
assertTrue
(
computed
.
shape
==
expected
.
shape
,
(
computed
.
shape
,
expected
.
shape
))
self
.
assertTrue
(
computed
.
shape
==
expected
.
shape
,
(
computed
.
shape
,
expected
.
shape
))
self
.
assertTrue
(
torch
.
allclose
(
computed
,
expected
))
self
.
assertTrue
(
torch
.
allclose
(
computed
,
expected
))
def
test_scriptmodule_MuLawEncoding
(
self
):
tensor
=
torch
.
rand
((
1
,
10
))
_test_script_module
(
transforms
.
MuLawEncoding
,
tensor
)
def
test_scriptmodule_MuLawDecoding
(
self
):
tensor
=
torch
.
rand
((
1
,
10
))
_test_script_module
(
transforms
.
MuLawDecoding
,
tensor
)
def
test_batch_mulaw
(
self
):
def
test_batch_mulaw
(
self
):
waveform
,
sample_rate
=
torchaudio
.
load
(
self
.
test_filepath
)
# (2, 278756), 44100
waveform
,
sample_rate
=
torchaudio
.
load
(
self
.
test_filepath
)
# (2, 278756), 44100
...
@@ -424,13 +354,6 @@ class Tester(unittest.TestCase):
...
@@ -424,13 +354,6 @@ class Tester(unittest.TestCase):
self
.
assertTrue
(
computed
.
shape
==
expected
.
shape
,
(
computed
.
shape
,
expected
.
shape
))
self
.
assertTrue
(
computed
.
shape
==
expected
.
shape
,
(
computed
.
shape
,
expected
.
shape
))
self
.
assertTrue
(
torch
.
allclose
(
computed
,
expected
,
atol
=
1e-5
))
self
.
assertTrue
(
torch
.
allclose
(
computed
,
expected
,
atol
=
1e-5
))
def
test_scriptmodule_TimeStretch
(
self
):
n_freq
=
400
hop_length
=
512
fixed_rate
=
1.3
tensor
=
torch
.
rand
((
10
,
2
,
n_freq
,
10
,
2
))
_test_script_module
(
transforms
.
TimeStretch
,
tensor
,
n_freq
=
n_freq
,
hop_length
=
hop_length
,
fixed_rate
=
fixed_rate
)
def
test_batch_TimeStretch
(
self
):
def
test_batch_TimeStretch
(
self
):
waveform
,
sample_rate
=
torchaudio
.
load
(
self
.
test_filepath
)
waveform
,
sample_rate
=
torchaudio
.
load
(
self
.
test_filepath
)
...
@@ -475,26 +398,6 @@ class Tester(unittest.TestCase):
...
@@ -475,26 +398,6 @@ class Tester(unittest.TestCase):
self
.
assertTrue
(
computed
.
shape
==
expected
.
shape
,
(
computed
.
shape
,
expected
.
shape
))
self
.
assertTrue
(
computed
.
shape
==
expected
.
shape
,
(
computed
.
shape
,
expected
.
shape
))
self
.
assertTrue
(
torch
.
allclose
(
computed
,
expected
))
self
.
assertTrue
(
torch
.
allclose
(
computed
,
expected
))
def
test_scriptmodule_Fade
(
self
):
waveform
,
sample_rate
=
torchaudio
.
load
(
self
.
test_filepath
)
fade_in_len
=
3000
fade_out_len
=
3000
_test_script_module
(
transforms
.
Fade
,
waveform
,
fade_in_len
,
fade_out_len
)
def
test_scriptmodule_FrequencyMasking
(
self
):
tensor
=
torch
.
rand
((
10
,
2
,
50
,
10
,
2
))
_test_script_module
(
transforms
.
FrequencyMasking
,
tensor
,
freq_mask_param
=
60
,
iid_masks
=
False
)
def
test_scriptmodule_TimeMasking
(
self
):
tensor
=
torch
.
rand
((
10
,
2
,
50
,
10
,
2
))
_test_script_module
(
transforms
.
TimeMasking
,
tensor
,
time_mask_param
=
30
,
iid_masks
=
False
)
def
test_scriptmodule_Vol
(
self
):
waveform
,
sample_rate
=
torchaudio
.
load
(
self
.
test_filepath
)
_test_script_module
(
transforms
.
Vol
,
waveform
,
1.1
)
def
test_batch_Vol
(
self
):
def
test_batch_Vol
(
self
):
waveform
,
sample_rate
=
torchaudio
.
load
(
self
.
test_filepath
)
waveform
,
sample_rate
=
torchaudio
.
load
(
self
.
test_filepath
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment