Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
hehl2
Torchaudio
Commits
413bd18e
"docs/tutorials/config.md" did not exist on "809d1a9b009cb4333f9414427a079854f7f50852"
Unverified
Commit
413bd18e
authored
Apr 02, 2020
by
moto
Committed by
GitHub
Apr 02, 2020
Browse files
Extract JIT tests from test_transforms to the dedicated test module (#496)
parent
eb5b5a02
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
107 additions
and
97 deletions
+107
-97
test/test_torchscript_consistency.py
test/test_torchscript_consistency.py
+107
-0
test/test_transforms.py
test/test_transforms.py
+0
-97
No files found.
test/test_torchscript_consistency.py
View file @
413bd18e
...
@@ -5,6 +5,7 @@ import unittest
...
@@ -5,6 +5,7 @@ import unittest
import
torch
import
torch
import
torchaudio
import
torchaudio
import
torchaudio.functional
as
F
import
torchaudio.functional
as
F
import
torchaudio.transforms
import
common_utils
import
common_utils
...
@@ -149,3 +150,109 @@ class TestFunctional(unittest.TestCase):
...
@@ -149,3 +150,109 @@ class TestFunctional(unittest.TestCase):
_test_torchscript_functional_shape
(
F
.
dither
,
tensor
)
_test_torchscript_functional_shape
(
F
.
dither
,
tensor
)
_test_torchscript_functional_shape
(
F
.
dither
,
tensor
,
"RPDF"
)
_test_torchscript_functional_shape
(
F
.
dither
,
tensor
,
"RPDF"
)
_test_torchscript_functional_shape
(
F
.
dither
,
tensor
,
"GPDF"
)
_test_torchscript_functional_shape
(
F
.
dither
,
tensor
,
"GPDF"
)
RUN_CUDA
=
torch
.
cuda
.
is_available
()
print
(
"Run test with cuda:"
,
RUN_CUDA
)
def
_test_script_module
(
f
,
tensor
,
*
args
,
**
kwargs
):
py_method
=
f
(
*
args
,
**
kwargs
)
jit_method
=
torch
.
jit
.
script
(
py_method
)
py_out
=
py_method
(
tensor
)
jit_out
=
jit_method
(
tensor
)
assert
torch
.
allclose
(
jit_out
,
py_out
)
if
RUN_CUDA
:
tensor
=
tensor
.
to
(
"cuda"
)
py_method
=
py_method
.
cuda
()
jit_method
=
torch
.
jit
.
script
(
py_method
)
py_out
=
py_method
(
tensor
)
jit_out
=
jit_method
(
tensor
)
assert
torch
.
allclose
(
jit_out
,
py_out
)
class
TestTransforms
(
unittest
.
TestCase
):
def
test_Spectrogram
(
self
):
tensor
=
torch
.
rand
((
1
,
1000
))
_test_script_module
(
torchaudio
.
transforms
.
Spectrogram
,
tensor
)
def
test_GriffinLim
(
self
):
tensor
=
torch
.
rand
((
1
,
201
,
6
))
_test_script_module
(
torchaudio
.
transforms
.
GriffinLim
,
tensor
,
length
=
1000
,
rand_init
=
False
)
def
test_AmplitudeToDB
(
self
):
spec
=
torch
.
rand
((
6
,
201
))
_test_script_module
(
torchaudio
.
transforms
.
AmplitudeToDB
,
spec
)
def
test_MelScale
(
self
):
spec_f
=
torch
.
rand
((
1
,
6
,
201
))
_test_script_module
(
torchaudio
.
transforms
.
MelScale
,
spec_f
)
def
test_MelSpectrogram
(
self
):
tensor
=
torch
.
rand
((
1
,
1000
))
_test_script_module
(
torchaudio
.
transforms
.
MelSpectrogram
,
tensor
)
def
test_MFCC
(
self
):
tensor
=
torch
.
rand
((
1
,
1000
))
_test_script_module
(
torchaudio
.
transforms
.
MFCC
,
tensor
)
def
test_Resample
(
self
):
tensor
=
torch
.
rand
((
2
,
1000
))
sample_rate
=
100.
sample_rate_2
=
50.
_test_script_module
(
torchaudio
.
transforms
.
Resample
,
tensor
,
sample_rate
,
sample_rate_2
)
def
test_ComplexNorm
(
self
):
tensor
=
torch
.
rand
((
1
,
2
,
201
,
2
))
_test_script_module
(
torchaudio
.
transforms
.
ComplexNorm
,
tensor
)
def
test_MuLawEncoding
(
self
):
tensor
=
torch
.
rand
((
1
,
10
))
_test_script_module
(
torchaudio
.
transforms
.
MuLawEncoding
,
tensor
)
def
test_MuLawDecoding
(
self
):
tensor
=
torch
.
rand
((
1
,
10
))
_test_script_module
(
torchaudio
.
transforms
.
MuLawDecoding
,
tensor
)
def
test_TimeStretch
(
self
):
n_freq
=
400
hop_length
=
512
fixed_rate
=
1.3
tensor
=
torch
.
rand
((
10
,
2
,
n_freq
,
10
,
2
))
_test_script_module
(
torchaudio
.
transforms
.
TimeStretch
,
tensor
,
n_freq
=
n_freq
,
hop_length
=
hop_length
,
fixed_rate
=
fixed_rate
)
def
test_Fade
(
self
):
test_filepath
=
os
.
path
.
join
(
common_utils
.
TEST_DIR_PATH
,
'assets'
,
'steam-train-whistle-daniel_simon.wav'
)
waveform
,
_
=
torchaudio
.
load
(
test_filepath
)
fade_in_len
=
3000
fade_out_len
=
3000
_test_script_module
(
torchaudio
.
transforms
.
Fade
,
waveform
,
fade_in_len
,
fade_out_len
)
def
test_FrequencyMasking
(
self
):
tensor
=
torch
.
rand
((
10
,
2
,
50
,
10
,
2
))
_test_script_module
(
torchaudio
.
transforms
.
FrequencyMasking
,
tensor
,
freq_mask_param
=
60
,
iid_masks
=
False
)
def
test_TimeMasking
(
self
):
tensor
=
torch
.
rand
((
10
,
2
,
50
,
10
,
2
))
_test_script_module
(
torchaudio
.
transforms
.
TimeMasking
,
tensor
,
time_mask_param
=
30
,
iid_masks
=
False
)
def
test_Vol
(
self
):
test_filepath
=
os
.
path
.
join
(
common_utils
.
TEST_DIR_PATH
,
'assets'
,
'steam-train-whistle-daniel_simon.wav'
)
waveform
,
_
=
torchaudio
.
load
(
test_filepath
)
_test_script_module
(
torchaudio
.
transforms
.
Vol
,
waveform
,
1.1
)
test/test_transforms.py
View file @
413bd18e
...
@@ -10,33 +10,6 @@ import torchaudio.functional as F
...
@@ -10,33 +10,6 @@ import torchaudio.functional as F
from
common_utils
import
AudioBackendScope
,
BACKENDS
,
create_temp_assets_dir
from
common_utils
import
AudioBackendScope
,
BACKENDS
,
create_temp_assets_dir
RUN_CUDA
=
torch
.
cuda
.
is_available
()
print
(
"Run test with cuda:"
,
RUN_CUDA
)
def
_test_script_module
(
f
,
tensor
,
*
args
,
**
kwargs
):
py_method
=
f
(
*
args
,
**
kwargs
)
jit_method
=
torch
.
jit
.
script
(
py_method
)
py_out
=
py_method
(
tensor
)
jit_out
=
jit_method
(
tensor
)
assert
torch
.
allclose
(
jit_out
,
py_out
)
if
RUN_CUDA
:
tensor
=
tensor
.
to
(
"cuda"
)
py_method
=
py_method
.
cuda
()
jit_method
=
torch
.
jit
.
script
(
py_method
)
py_out
=
py_method
(
tensor
)
jit_out
=
jit_method
(
tensor
)
assert
torch
.
allclose
(
jit_out
,
py_out
)
class
Tester
(
unittest
.
TestCase
):
class
Tester
(
unittest
.
TestCase
):
# create a sinewave signal for testing
# create a sinewave signal for testing
...
@@ -57,14 +30,6 @@ class Tester(unittest.TestCase):
...
@@ -57,14 +30,6 @@ class Tester(unittest.TestCase):
waveform
=
waveform
.
to
(
torch
.
get_default_dtype
())
waveform
=
waveform
.
to
(
torch
.
get_default_dtype
())
return
waveform
/
factor
return
waveform
/
factor
def
test_scriptmodule_Spectrogram
(
self
):
tensor
=
torch
.
rand
((
1
,
1000
))
_test_script_module
(
transforms
.
Spectrogram
,
tensor
)
def
test_scriptmodule_GriffinLim
(
self
):
tensor
=
torch
.
rand
((
1
,
201
,
6
))
_test_script_module
(
transforms
.
GriffinLim
,
tensor
,
length
=
1000
,
rand_init
=
False
)
def
test_mu_law_companding
(
self
):
def
test_mu_law_companding
(
self
):
quantization_channels
=
256
quantization_channels
=
256
...
@@ -79,10 +44,6 @@ class Tester(unittest.TestCase):
...
@@ -79,10 +44,6 @@ class Tester(unittest.TestCase):
waveform_exp
=
transforms
.
MuLawDecoding
(
quantization_channels
)(
waveform_mu
)
waveform_exp
=
transforms
.
MuLawDecoding
(
quantization_channels
)(
waveform_mu
)
self
.
assertTrue
(
waveform_exp
.
min
()
>=
-
1.
and
waveform_exp
.
max
()
<=
1.
)
self
.
assertTrue
(
waveform_exp
.
min
()
>=
-
1.
and
waveform_exp
.
max
()
<=
1.
)
def
test_scriptmodule_AmplitudeToDB
(
self
):
spec
=
torch
.
rand
((
6
,
201
))
_test_script_module
(
transforms
.
AmplitudeToDB
,
spec
)
def
test_batch_AmplitudeToDB
(
self
):
def
test_batch_AmplitudeToDB
(
self
):
spec
=
torch
.
rand
((
6
,
201
))
spec
=
torch
.
rand
((
6
,
201
))
...
@@ -106,10 +67,6 @@ class Tester(unittest.TestCase):
...
@@ -106,10 +67,6 @@ class Tester(unittest.TestCase):
self
.
assertTrue
(
torch
.
allclose
(
mag_to_db_torch
,
power_to_db_torch
))
self
.
assertTrue
(
torch
.
allclose
(
mag_to_db_torch
,
power_to_db_torch
))
def
test_scriptmodule_MelScale
(
self
):
spec_f
=
torch
.
rand
((
1
,
6
,
201
))
_test_script_module
(
transforms
.
MelScale
,
spec_f
)
def
test_melscale_load_save
(
self
):
def
test_melscale_load_save
(
self
):
specgram
=
torch
.
ones
(
1
,
1000
,
100
)
specgram
=
torch
.
ones
(
1
,
1000
,
100
)
melscale_transform
=
transforms
.
MelScale
()
melscale_transform
=
transforms
.
MelScale
()
...
@@ -124,10 +81,6 @@ class Tester(unittest.TestCase):
...
@@ -124,10 +81,6 @@ class Tester(unittest.TestCase):
self
.
assertEqual
(
fb_copy
.
size
(),
(
1000
,
128
))
self
.
assertEqual
(
fb_copy
.
size
(),
(
1000
,
128
))
self
.
assertTrue
(
torch
.
allclose
(
fb
,
fb_copy
))
self
.
assertTrue
(
torch
.
allclose
(
fb
,
fb_copy
))
def
test_scriptmodule_MelSpectrogram
(
self
):
tensor
=
torch
.
rand
((
1
,
1000
))
_test_script_module
(
transforms
.
MelSpectrogram
,
tensor
)
def
test_melspectrogram_load_save
(
self
):
def
test_melspectrogram_load_save
(
self
):
waveform
=
self
.
waveform
.
float
()
waveform
=
self
.
waveform
.
float
()
mel_spectrogram_transform
=
transforms
.
MelSpectrogram
()
mel_spectrogram_transform
=
transforms
.
MelSpectrogram
()
...
@@ -186,10 +139,6 @@ class Tester(unittest.TestCase):
...
@@ -186,10 +139,6 @@ class Tester(unittest.TestCase):
self
.
assertTrue
(
fb_matrix_transform
.
fb
.
sum
(
1
).
ge
(
0.
).
all
())
self
.
assertTrue
(
fb_matrix_transform
.
fb
.
sum
(
1
).
ge
(
0.
).
all
())
self
.
assertEqual
(
fb_matrix_transform
.
fb
.
size
(),
(
400
,
100
))
self
.
assertEqual
(
fb_matrix_transform
.
fb
.
size
(),
(
400
,
100
))
def
test_scriptmodule_MFCC
(
self
):
tensor
=
torch
.
rand
((
1
,
1000
))
_test_script_module
(
transforms
.
MFCC
,
tensor
)
def
test_mfcc
(
self
):
def
test_mfcc
(
self
):
audio_orig
=
self
.
waveform
.
clone
()
audio_orig
=
self
.
waveform
.
clone
()
audio_scaled
=
self
.
scale
(
audio_orig
)
# (1, 16000)
audio_scaled
=
self
.
scale
(
audio_orig
)
# (1, 16000)
...
@@ -226,13 +175,6 @@ class Tester(unittest.TestCase):
...
@@ -226,13 +175,6 @@ class Tester(unittest.TestCase):
self
.
assertTrue
(
torch_mfcc_norm_none
.
allclose
(
norm_check
))
self
.
assertTrue
(
torch_mfcc_norm_none
.
allclose
(
norm_check
))
def
test_scriptmodule_Resample
(
self
):
tensor
=
torch
.
rand
((
2
,
1000
))
sample_rate
=
100.
sample_rate_2
=
50.
_test_script_module
(
transforms
.
Resample
,
tensor
,
sample_rate
,
sample_rate_2
)
def
test_batch_Resample
(
self
):
def
test_batch_Resample
(
self
):
waveform
=
torch
.
randn
(
2
,
2786
)
waveform
=
torch
.
randn
(
2
,
2786
)
...
@@ -245,10 +187,6 @@ class Tester(unittest.TestCase):
...
@@ -245,10 +187,6 @@ class Tester(unittest.TestCase):
self
.
assertTrue
(
computed
.
shape
==
expected
.
shape
,
(
computed
.
shape
,
expected
.
shape
))
self
.
assertTrue
(
computed
.
shape
==
expected
.
shape
,
(
computed
.
shape
,
expected
.
shape
))
self
.
assertTrue
(
torch
.
allclose
(
computed
,
expected
))
self
.
assertTrue
(
torch
.
allclose
(
computed
,
expected
))
def
test_scriptmodule_ComplexNorm
(
self
):
tensor
=
torch
.
rand
((
1
,
2
,
201
,
2
))
_test_script_module
(
transforms
.
ComplexNorm
,
tensor
)
def
test_resample_size
(
self
):
def
test_resample_size
(
self
):
input_path
=
os
.
path
.
join
(
self
.
test_dirpath
,
'assets'
,
'sinewave.wav'
)
input_path
=
os
.
path
.
join
(
self
.
test_dirpath
,
'assets'
,
'sinewave.wav'
)
waveform
,
sample_rate
=
torchaudio
.
load
(
input_path
)
waveform
,
sample_rate
=
torchaudio
.
load
(
input_path
)
...
@@ -349,14 +287,6 @@ class Tester(unittest.TestCase):
...
@@ -349,14 +287,6 @@ class Tester(unittest.TestCase):
self
.
assertTrue
(
computed
.
shape
==
expected
.
shape
,
(
computed
.
shape
,
expected
.
shape
))
self
.
assertTrue
(
computed
.
shape
==
expected
.
shape
,
(
computed
.
shape
,
expected
.
shape
))
self
.
assertTrue
(
torch
.
allclose
(
computed
,
expected
))
self
.
assertTrue
(
torch
.
allclose
(
computed
,
expected
))
def
test_scriptmodule_MuLawEncoding
(
self
):
tensor
=
torch
.
rand
((
1
,
10
))
_test_script_module
(
transforms
.
MuLawEncoding
,
tensor
)
def
test_scriptmodule_MuLawDecoding
(
self
):
tensor
=
torch
.
rand
((
1
,
10
))
_test_script_module
(
transforms
.
MuLawDecoding
,
tensor
)
def
test_batch_mulaw
(
self
):
def
test_batch_mulaw
(
self
):
waveform
,
sample_rate
=
torchaudio
.
load
(
self
.
test_filepath
)
# (2, 278756), 44100
waveform
,
sample_rate
=
torchaudio
.
load
(
self
.
test_filepath
)
# (2, 278756), 44100
...
@@ -424,13 +354,6 @@ class Tester(unittest.TestCase):
...
@@ -424,13 +354,6 @@ class Tester(unittest.TestCase):
self
.
assertTrue
(
computed
.
shape
==
expected
.
shape
,
(
computed
.
shape
,
expected
.
shape
))
self
.
assertTrue
(
computed
.
shape
==
expected
.
shape
,
(
computed
.
shape
,
expected
.
shape
))
self
.
assertTrue
(
torch
.
allclose
(
computed
,
expected
,
atol
=
1e-5
))
self
.
assertTrue
(
torch
.
allclose
(
computed
,
expected
,
atol
=
1e-5
))
def
test_scriptmodule_TimeStretch
(
self
):
n_freq
=
400
hop_length
=
512
fixed_rate
=
1.3
tensor
=
torch
.
rand
((
10
,
2
,
n_freq
,
10
,
2
))
_test_script_module
(
transforms
.
TimeStretch
,
tensor
,
n_freq
=
n_freq
,
hop_length
=
hop_length
,
fixed_rate
=
fixed_rate
)
def
test_batch_TimeStretch
(
self
):
def
test_batch_TimeStretch
(
self
):
waveform
,
sample_rate
=
torchaudio
.
load
(
self
.
test_filepath
)
waveform
,
sample_rate
=
torchaudio
.
load
(
self
.
test_filepath
)
...
@@ -475,26 +398,6 @@ class Tester(unittest.TestCase):
...
@@ -475,26 +398,6 @@ class Tester(unittest.TestCase):
self
.
assertTrue
(
computed
.
shape
==
expected
.
shape
,
(
computed
.
shape
,
expected
.
shape
))
self
.
assertTrue
(
computed
.
shape
==
expected
.
shape
,
(
computed
.
shape
,
expected
.
shape
))
self
.
assertTrue
(
torch
.
allclose
(
computed
,
expected
))
self
.
assertTrue
(
torch
.
allclose
(
computed
,
expected
))
def
test_scriptmodule_Fade
(
self
):
waveform
,
sample_rate
=
torchaudio
.
load
(
self
.
test_filepath
)
fade_in_len
=
3000
fade_out_len
=
3000
_test_script_module
(
transforms
.
Fade
,
waveform
,
fade_in_len
,
fade_out_len
)
def
test_scriptmodule_FrequencyMasking
(
self
):
tensor
=
torch
.
rand
((
10
,
2
,
50
,
10
,
2
))
_test_script_module
(
transforms
.
FrequencyMasking
,
tensor
,
freq_mask_param
=
60
,
iid_masks
=
False
)
def
test_scriptmodule_TimeMasking
(
self
):
tensor
=
torch
.
rand
((
10
,
2
,
50
,
10
,
2
))
_test_script_module
(
transforms
.
TimeMasking
,
tensor
,
time_mask_param
=
30
,
iid_masks
=
False
)
def
test_scriptmodule_Vol
(
self
):
waveform
,
sample_rate
=
torchaudio
.
load
(
self
.
test_filepath
)
_test_script_module
(
transforms
.
Vol
,
waveform
,
1.1
)
def
test_batch_Vol
(
self
):
def
test_batch_Vol
(
self
):
waveform
,
sample_rate
=
torchaudio
.
load
(
self
.
test_filepath
)
waveform
,
sample_rate
=
torchaudio
.
load
(
self
.
test_filepath
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment