Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
b8187b4c
Unverified
Commit
b8187b4c
authored
May 15, 2020
by
moto
Committed by
GitHub
May 15, 2020
Browse files
Adopt PyTorch's test util to batch consistency test (#643)
parent
44af0dea
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
49 additions
and
47 deletions
+49
-47
test/test_batch_consistency.py
test/test_batch_consistency.py
+49
-47
No files found.
test/test_batch_consistency.py
View file @
b8187b4c
"""Test numerical consistency among single input and batched input."""
import
unittest
import
platform
import
torch
from
torch.testing._internal.common_utils
import
TestCase
import
torchaudio
import
torchaudio.functional
as
F
import
common_utils
def
_test_batch_consistency
(
functional
,
tensor
,
*
args
,
batch_size
=
1
,
atol
=
1e-8
,
rtol
=
1e-5
,
seed
=
42
,
**
kwargs
):
class
TestFunctional
(
TestCase
):
"""Test functions defined in `functional` module"""
def
assert_batch_consistency
(
self
,
functional
,
tensor
,
*
args
,
batch_size
=
1
,
atol
=
1e-8
,
rtol
=
1e-5
,
seed
=
42
,
**
kwargs
):
# run then batch the result
torch
.
random
.
manual_seed
(
seed
)
expected
=
functional
(
tensor
.
clone
(),
*
args
,
**
kwargs
)
...
...
@@ -20,16 +23,15 @@ def _test_batch_consistency(functional, tensor, *args, batch_size=1, atol=1e-8,
pattern
=
[
batch_size
]
+
[
1
]
*
tensor
.
dim
()
computed
=
functional
(
tensor
.
repeat
(
pattern
),
*
args
,
**
kwargs
)
torch
.
testing
.
assert_allclose
(
computed
,
expected
,
rtol
=
rtol
,
atol
=
atol
)
self
.
assertEqual
(
computed
,
expected
,
rtol
=
rtol
,
atol
=
atol
)
def
_test_batch
(
functional
,
tensor
,
*
args
,
atol
=
1e-8
,
rtol
=
1e-5
,
seed
=
42
,
**
kwargs
):
_test_batch_consistency
(
functional
,
tensor
,
*
args
,
batch_size
=
1
,
atol
=
atol
,
rtol
=
rtol
,
seed
=
seed
,
**
kwargs
)
_test_batch_consistency
(
functional
,
tensor
,
*
args
,
batch_size
=
3
,
atol
=
atol
,
rtol
=
rtol
,
seed
=
seed
,
**
kwargs
)
def
assert_batch_consistencies
(
self
,
functional
,
tensor
,
*
args
,
atol
=
1e-8
,
rtol
=
1e-5
,
seed
=
42
,
**
kwargs
):
self
.
assert_batch_consistency
(
functional
,
tensor
,
*
args
,
batch_size
=
1
,
atol
=
atol
,
rtol
=
rtol
,
seed
=
seed
,
**
kwargs
)
self
.
assert_batch_consistency
(
functional
,
tensor
,
*
args
,
batch_size
=
3
,
atol
=
atol
,
rtol
=
rtol
,
seed
=
seed
,
**
kwargs
)
class
TestFunctional
(
unittest
.
TestCase
):
"""Test functions defined in `functional` module"""
def
test_griffinlim
(
self
):
n_fft
=
400
ws
=
400
...
...
@@ -41,7 +43,7 @@ class TestFunctional(unittest.TestCase):
n_iter
=
32
length
=
1000
tensor
=
torch
.
rand
((
1
,
201
,
6
))
_test_batch
(
self
.
assert_batch_consistencies
(
F
.
griffinlim
,
tensor
,
window
,
n_fft
,
hop
,
ws
,
power
,
normalize
,
n_iter
,
momentum
,
length
,
0
,
atol
=
5e-5
)
...
...
@@ -55,7 +57,7 @@ class TestFunctional(unittest.TestCase):
for
filename
in
filenames
:
filepath
=
common_utils
.
get_asset_path
(
filename
)
waveform
,
sample_rate
=
torchaudio
.
load
(
filepath
)
_test_batch
(
F
.
detect_pitch_frequency
,
waveform
,
sample_rate
)
self
.
assert_batch_consistencies
(
F
.
detect_pitch_frequency
,
waveform
,
sample_rate
)
def
test_istft
(
self
):
stft
=
torch
.
tensor
([
...
...
@@ -63,39 +65,39 @@ class TestFunctional(unittest.TestCase):
[[
0.
,
0.
],
[
0.
,
0.
],
[
0.
,
0.
],
[
0.
,
0.
],
[
0.
,
0.
]],
[[
0.
,
0.
],
[
0.
,
0.
],
[
0.
,
0.
],
[
0.
,
0.
],
[
0.
,
0.
]]
])
_test_batch
(
F
.
istft
,
stft
,
n_fft
=
4
,
length
=
4
)
self
.
assert_batch_consistencies
(
F
.
istft
,
stft
,
n_fft
=
4
,
length
=
4
)
def
test_contrast
(
self
):
waveform
=
torch
.
rand
(
2
,
100
)
-
0.5
_test_batch
(
F
.
contrast
,
waveform
,
enhancement_amount
=
80.
)
self
.
assert_batch_consistencies
(
F
.
contrast
,
waveform
,
enhancement_amount
=
80.
)
def
test_dcshift
(
self
):
waveform
=
torch
.
rand
(
2
,
100
)
-
0.5
_test_batch
(
F
.
dcshift
,
waveform
,
shift
=
0.5
,
limiter_gain
=
0.05
)
self
.
assert_batch_consistencies
(
F
.
dcshift
,
waveform
,
shift
=
0.5
,
limiter_gain
=
0.05
)
def
test_overdrive
(
self
):
waveform
=
torch
.
rand
(
2
,
100
)
-
0.5
_test_batch
(
F
.
overdrive
,
waveform
,
gain
=
45
,
colour
=
30
)
self
.
assert_batch_consistencies
(
F
.
overdrive
,
waveform
,
gain
=
45
,
colour
=
30
)
def
test_phaser
(
self
):
filepath
=
common_utils
.
get_asset_path
(
"whitenoise.wav"
)
waveform
,
sample_rate
=
torchaudio
.
load
(
filepath
)
_test_batch
(
F
.
phaser
,
waveform
,
sample_rate
)
self
.
assert_batch_consistencies
(
F
.
phaser
,
waveform
,
sample_rate
)
def
test_sliding_window_cmn
(
self
):
waveform
=
torch
.
randn
(
2
,
1024
)
-
0.5
_test_batch
(
F
.
sliding_window_cmn
,
waveform
,
center
=
True
,
norm_vars
=
True
)
_test_batch
(
F
.
sliding_window_cmn
,
waveform
,
center
=
True
,
norm_vars
=
False
)
_test_batch
(
F
.
sliding_window_cmn
,
waveform
,
center
=
False
,
norm_vars
=
True
)
_test_batch
(
F
.
sliding_window_cmn
,
waveform
,
center
=
False
,
norm_vars
=
False
)
self
.
assert_batch_consistencies
(
F
.
sliding_window_cmn
,
waveform
,
center
=
True
,
norm_vars
=
True
)
self
.
assert_batch_consistencies
(
F
.
sliding_window_cmn
,
waveform
,
center
=
True
,
norm_vars
=
False
)
self
.
assert_batch_consistencies
(
F
.
sliding_window_cmn
,
waveform
,
center
=
False
,
norm_vars
=
True
)
self
.
assert_batch_consistencies
(
F
.
sliding_window_cmn
,
waveform
,
center
=
False
,
norm_vars
=
False
)
def
test_vad
(
self
):
filepath
=
common_utils
.
get_asset_path
(
"vad-go-mono-32000.wav"
)
waveform
,
sample_rate
=
torchaudio
.
load
(
filepath
)
_test_batch
(
F
.
vad
,
waveform
,
sample_rate
=
sample_rate
)
self
.
assert_batch_consistencies
(
F
.
vad
,
waveform
,
sample_rate
=
sample_rate
)
class
TestTransforms
(
unittest
.
TestCase
):
class
TestTransforms
(
TestCase
):
"""Test suite for classes defined in `transforms` module"""
def
test_batch_AmplitudeToDB
(
self
):
spec
=
torch
.
rand
((
6
,
201
))
...
...
@@ -106,7 +108,7 @@ class TestTransforms(unittest.TestCase):
# Batch then transform
computed
=
torchaudio
.
transforms
.
AmplitudeToDB
()(
spec
.
repeat
(
3
,
1
,
1
))
torch
.
testing
.
assert_allclose
(
computed
,
expected
)
self
.
assertEqual
(
computed
,
expected
)
def
test_batch_Resample
(
self
):
waveform
=
torch
.
randn
(
2
,
2786
)
...
...
@@ -117,7 +119,7 @@ class TestTransforms(unittest.TestCase):
# Batch then transform
computed
=
torchaudio
.
transforms
.
Resample
()(
waveform
.
repeat
(
3
,
1
,
1
))
torch
.
testing
.
assert_allclose
(
computed
,
expected
)
self
.
assertEqual
(
computed
,
expected
)
def
test_batch_MelScale
(
self
):
specgram
=
torch
.
randn
(
2
,
31
,
2786
)
...
...
@@ -129,7 +131,7 @@ class TestTransforms(unittest.TestCase):
computed
=
torchaudio
.
transforms
.
MelScale
()(
specgram
.
repeat
(
3
,
1
,
1
,
1
))
# shape = (3, 2, 201, 1394)
torch
.
testing
.
assert_allclose
(
computed
,
expected
)
self
.
assertEqual
(
computed
,
expected
)
def
test_batch_InverseMelScale
(
self
):
n_mels
=
32
...
...
@@ -146,7 +148,7 @@ class TestTransforms(unittest.TestCase):
# Because InverseMelScale runs SGD on randomly initialized values so they do not yield
# exactly same result. For this reason, tolerance is very relaxed here.
torch
.
testing
.
assert_allclose
(
computed
,
expected
,
atol
=
1.0
,
rtol
=
1e-5
)
self
.
assertEqual
(
computed
,
expected
,
atol
=
1.0
,
rtol
=
1e-5
)
def
test_batch_compute_deltas
(
self
):
specgram
=
torch
.
randn
(
2
,
31
,
2786
)
...
...
@@ -158,7 +160,7 @@ class TestTransforms(unittest.TestCase):
computed
=
torchaudio
.
transforms
.
ComputeDeltas
()(
specgram
.
repeat
(
3
,
1
,
1
,
1
))
# shape = (3, 2, 201, 1394)
torch
.
testing
.
assert_allclose
(
computed
,
expected
)
self
.
assertEqual
(
computed
,
expected
)
def
test_batch_mulaw
(
self
):
test_filepath
=
common_utils
.
get_asset_path
(
'steam-train-whistle-daniel_simon.wav'
)
...
...
@@ -173,7 +175,7 @@ class TestTransforms(unittest.TestCase):
computed
=
torchaudio
.
transforms
.
MuLawEncoding
()(
waveform_batched
)
# shape = (3, 2, 201, 1394)
torch
.
testing
.
assert_allclose
(
computed
,
expected
)
self
.
assertEqual
(
computed
,
expected
)
# Single then transform then batch
waveform_decoded
=
torchaudio
.
transforms
.
MuLawDecoding
()(
waveform_encoded
)
...
...
@@ -183,7 +185,7 @@ class TestTransforms(unittest.TestCase):
computed
=
torchaudio
.
transforms
.
MuLawDecoding
()(
computed
)
# shape = (3, 2, 201, 1394)
torch
.
testing
.
assert_allclose
(
computed
,
expected
)
self
.
assertEqual
(
computed
,
expected
)
def
test_batch_spectrogram
(
self
):
test_filepath
=
common_utils
.
get_asset_path
(
'steam-train-whistle-daniel_simon.wav'
)
...
...
@@ -194,7 +196,7 @@ class TestTransforms(unittest.TestCase):
# Batch then transform
computed
=
torchaudio
.
transforms
.
Spectrogram
()(
waveform
.
repeat
(
3
,
1
,
1
))
torch
.
testing
.
assert_allclose
(
computed
,
expected
)
self
.
assertEqual
(
computed
,
expected
)
def
test_batch_melspectrogram
(
self
):
test_filepath
=
common_utils
.
get_asset_path
(
'steam-train-whistle-daniel_simon.wav'
)
...
...
@@ -205,7 +207,7 @@ class TestTransforms(unittest.TestCase):
# Batch then transform
computed
=
torchaudio
.
transforms
.
MelSpectrogram
()(
waveform
.
repeat
(
3
,
1
,
1
))
torch
.
testing
.
assert_allclose
(
computed
,
expected
)
self
.
assertEqual
(
computed
,
expected
)
def
test_batch_mfcc
(
self
):
test_filepath
=
common_utils
.
get_asset_path
(
'steam-train-whistle-daniel_simon.wav'
)
...
...
@@ -216,7 +218,7 @@ class TestTransforms(unittest.TestCase):
# Batch then transform
computed
=
torchaudio
.
transforms
.
MFCC
()(
waveform
.
repeat
(
3
,
1
,
1
))
torch
.
testing
.
assert_allclose
(
computed
,
expected
,
atol
=
1e-4
,
rtol
=
1e-5
)
self
.
assertEqual
(
computed
,
expected
,
atol
=
1e-4
,
rtol
=
1e-5
)
def
test_batch_TimeStretch
(
self
):
test_filepath
=
common_utils
.
get_asset_path
(
'steam-train-whistle-daniel_simon.wav'
)
...
...
@@ -250,7 +252,7 @@ class TestTransforms(unittest.TestCase):
hop_length
=
512
,
)(
complex_specgrams
.
repeat
(
3
,
1
,
1
,
1
,
1
))
torch
.
testing
.
assert_allclose
(
computed
,
expected
,
atol
=
1e-5
,
rtol
=
1e-5
)
self
.
assertEqual
(
computed
,
expected
,
atol
=
1e-5
,
rtol
=
1e-5
)
def
test_batch_Fade
(
self
):
test_filepath
=
common_utils
.
get_asset_path
(
'steam-train-whistle-daniel_simon.wav'
)
...
...
@@ -263,7 +265,7 @@ class TestTransforms(unittest.TestCase):
# Batch then transform
computed
=
torchaudio
.
transforms
.
Fade
(
fade_in_len
,
fade_out_len
)(
waveform
.
repeat
(
3
,
1
,
1
))
torch
.
testing
.
assert_allclose
(
computed
,
expected
)
self
.
assertEqual
(
computed
,
expected
)
def
test_batch_Vol
(
self
):
test_filepath
=
common_utils
.
get_asset_path
(
'steam-train-whistle-daniel_simon.wav'
)
...
...
@@ -274,7 +276,7 @@ class TestTransforms(unittest.TestCase):
# Batch then transform
computed
=
torchaudio
.
transforms
.
Vol
(
gain
=
1.1
)(
waveform
.
repeat
(
3
,
1
,
1
))
torch
.
testing
.
assert_allclose
(
computed
,
expected
)
self
.
assertEqual
(
computed
,
expected
)
if
__name__
==
'__main__'
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment