Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
a91cae7c
Unverified
Commit
a91cae7c
authored
Apr 02, 2020
by
moto
Committed by
GitHub
Apr 02, 2020
Browse files
Extract batch test from test_functional and move to the dedicated module (#491)
parent
413bd18e
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
104 additions
and
92 deletions
+104
-92
test/test_batch_consistency.py
test/test_batch_consistency.py
+104
-0
test/test_functional.py
test/test_functional.py
+0
-92
No files found.
test/test_batch_consistency.py
0 → 100644
View file @
a91cae7c
"""Test numerical consistency among single input and batched input."""
import
os
import
unittest
import
torch
import
torchaudio
import
torchaudio.functional
as
F
import
common_utils
def
_test_batch_shape
(
functional
,
tensor
,
*
args
,
**
kwargs
):
kwargs_compare
=
{}
if
'atol'
in
kwargs
:
atol
=
kwargs
[
'atol'
]
del
kwargs
[
'atol'
]
kwargs_compare
[
'atol'
]
=
atol
if
'rtol'
in
kwargs
:
rtol
=
kwargs
[
'rtol'
]
del
kwargs
[
'rtol'
]
kwargs_compare
[
'rtol'
]
=
rtol
# Single then transform then batch
torch
.
random
.
manual_seed
(
42
)
expected
=
functional
(
tensor
.
clone
(),
*
args
,
**
kwargs
)
expected
=
expected
.
unsqueeze
(
0
).
unsqueeze
(
0
)
# 1-Batch then transform
tensors
=
tensor
.
unsqueeze
(
0
).
unsqueeze
(
0
)
torch
.
random
.
manual_seed
(
42
)
computed
=
functional
(
tensors
.
clone
(),
*
args
,
**
kwargs
)
assert
expected
.
shape
==
computed
.
shape
,
(
expected
.
shape
,
computed
.
shape
)
assert
torch
.
allclose
(
expected
,
computed
,
**
kwargs_compare
)
return
tensors
,
expected
def
_test_batch
(
functional
,
tensor
,
*
args
,
**
kwargs
):
tensors
,
expected
=
_test_batch_shape
(
functional
,
tensor
,
*
args
,
**
kwargs
)
kwargs_compare
=
{}
if
'atol'
in
kwargs
:
atol
=
kwargs
[
'atol'
]
del
kwargs
[
'atol'
]
kwargs_compare
[
'atol'
]
=
atol
if
'rtol'
in
kwargs
:
rtol
=
kwargs
[
'rtol'
]
del
kwargs
[
'rtol'
]
kwargs_compare
[
'rtol'
]
=
rtol
# 3-Batch then transform
ind
=
[
3
]
+
[
1
]
*
(
int
(
tensors
.
dim
())
-
1
)
tensors
=
tensor
.
repeat
(
*
ind
)
ind
=
[
3
]
+
[
1
]
*
(
int
(
expected
.
dim
())
-
1
)
expected
=
expected
.
repeat
(
*
ind
)
torch
.
random
.
manual_seed
(
42
)
computed
=
functional
(
tensors
.
clone
(),
*
args
,
**
kwargs
)
class
TestFunctional
(
unittest
.
TestCase
):
"""Test functions defined in `functional` module"""
def
test_griffinlim
(
self
):
n_fft
=
400
ws
=
400
hop
=
200
window
=
torch
.
hann_window
(
ws
)
power
=
2
normalize
=
False
momentum
=
0.99
n_iter
=
32
length
=
1000
tensor
=
torch
.
rand
((
1
,
201
,
6
))
_test_batch
(
F
.
griffinlim
,
tensor
,
window
,
n_fft
,
hop
,
ws
,
power
,
normalize
,
n_iter
,
momentum
,
length
,
0
,
atol
=
5e-5
)
def
test_detect_pitch_frequency
(
self
):
filenames
=
[
'steam-train-whistle-daniel_simon.wav'
,
# 2ch 44100Hz
# Files from https://www.mediacollege.com/audio/tone/download/
'100Hz_44100Hz_16bit_05sec.wav'
,
# 1ch
'440Hz_44100Hz_16bit_05sec.wav'
,
# 1ch
]
for
filename
in
filenames
:
filepath
=
os
.
path
.
join
(
common_utils
.
TEST_DIR_PATH
,
'assets'
,
filename
)
waveform
,
sample_rate
=
torchaudio
.
load
(
filepath
)
_test_batch
(
F
.
detect_pitch_frequency
,
waveform
,
sample_rate
)
def
test_istft
(
self
):
stft
=
torch
.
tensor
([
[[
4.
,
0.
],
[
4.
,
0.
],
[
4.
,
0.
],
[
4.
,
0.
],
[
4.
,
0.
]],
[[
0.
,
0.
],
[
0.
,
0.
],
[
0.
,
0.
],
[
0.
,
0.
],
[
0.
,
0.
]],
[[
0.
,
0.
],
[
0.
,
0.
],
[
0.
,
0.
],
[
0.
,
0.
],
[
0.
,
0.
]]
])
_test_batch
(
F
.
istft
,
stft
,
n_fft
=
4
,
length
=
4
)
test/test_functional.py
View file @
a91cae7c
...
@@ -23,25 +23,6 @@ class TestFunctional(unittest.TestCase):
...
@@ -23,25 +23,6 @@ class TestFunctional(unittest.TestCase):
'steam-train-whistle-daniel_simon.wav'
)
'steam-train-whistle-daniel_simon.wav'
)
waveform_train
,
sr_train
=
torchaudio
.
load
(
test_filepath
)
waveform_train
,
sr_train
=
torchaudio
.
load
(
test_filepath
)
def
test_batch_griffinlim
(
self
):
torch
.
random
.
manual_seed
(
42
)
tensor
=
torch
.
rand
((
1
,
201
,
6
))
n_fft
=
400
ws
=
400
hop
=
200
window
=
torch
.
hann_window
(
ws
)
power
=
2
normalize
=
False
momentum
=
0.99
n_iter
=
32
length
=
1000
self
.
_test_batch
(
F
.
griffinlim
,
tensor
,
window
,
n_fft
,
hop
,
ws
,
power
,
normalize
,
n_iter
,
momentum
,
length
,
0
,
atol
=
5e-5
)
def
_test_compute_deltas
(
self
,
specgram
,
expected
,
win_length
=
3
,
atol
=
1e-6
,
rtol
=
1e-8
):
def
_test_compute_deltas
(
self
,
specgram
,
expected
,
win_length
=
3
,
atol
=
1e-6
,
rtol
=
1e-8
):
computed
=
F
.
compute_deltas
(
specgram
,
win_length
=
win_length
)
computed
=
F
.
compute_deltas
(
specgram
,
win_length
=
win_length
)
self
.
assertTrue
(
computed
.
shape
==
expected
.
shape
,
(
computed
.
shape
,
expected
.
shape
))
self
.
assertTrue
(
computed
.
shape
==
expected
.
shape
,
(
computed
.
shape
,
expected
.
shape
))
...
@@ -58,10 +39,6 @@ class TestFunctional(unittest.TestCase):
...
@@ -58,10 +39,6 @@ class TestFunctional(unittest.TestCase):
[
0.5
,
1.0
,
1.0
,
0.5
]]])
[
0.5
,
1.0
,
1.0
,
0.5
]]])
self
.
_test_compute_deltas
(
specgram
,
expected
)
self
.
_test_compute_deltas
(
specgram
,
expected
)
def
test_batch_pitch
(
self
):
waveform
,
sample_rate
=
torchaudio
.
load
(
self
.
test_filepath
)
self
.
_test_batch
(
F
.
detect_pitch_frequency
,
waveform
,
sample_rate
)
def
_compare_estimate
(
self
,
sound
,
estimate
,
atol
=
1e-6
,
rtol
=
1e-8
):
def
_compare_estimate
(
self
,
sound
,
estimate
,
atol
=
1e-6
,
rtol
=
1e-8
):
# trim sound for case when constructed signal is shorter than original
# trim sound for case when constructed signal is shorter than original
sound
=
sound
[...,
:
estimate
.
size
(
-
1
)]
sound
=
sound
[...,
:
estimate
.
size
(
-
1
)]
...
@@ -298,16 +275,6 @@ class TestFunctional(unittest.TestCase):
...
@@ -298,16 +275,6 @@ class TestFunctional(unittest.TestCase):
data_size
=
(
2
,
7
,
3
,
2
)
data_size
=
(
2
,
7
,
3
,
2
)
self
.
_test_linearity_of_istft
(
data_size
,
kwargs4
,
atol
=
1e-5
,
rtol
=
1e-8
)
self
.
_test_linearity_of_istft
(
data_size
,
kwargs4
,
atol
=
1e-5
,
rtol
=
1e-8
)
def
test_batch_istft
(
self
):
stft
=
torch
.
tensor
([
[[
4.
,
0.
],
[
4.
,
0.
],
[
4.
,
0.
],
[
4.
,
0.
],
[
4.
,
0.
]],
[[
0.
,
0.
],
[
0.
,
0.
],
[
0.
,
0.
],
[
0.
,
0.
],
[
0.
,
0.
]],
[[
0.
,
0.
],
[
0.
,
0.
],
[
0.
,
0.
],
[
0.
,
0.
],
[
0.
,
0.
]]
])
self
.
_test_batch
(
F
.
istft
,
stft
,
n_fft
=
4
,
length
=
4
)
@
unittest
.
skipIf
(
"sox"
not
in
BACKENDS
,
"sox not available"
)
@
unittest
.
skipIf
(
"sox"
not
in
BACKENDS
,
"sox not available"
)
@
AudioBackendScope
(
"sox"
)
@
AudioBackendScope
(
"sox"
)
def
test_gain
(
self
):
def
test_gain
(
self
):
...
@@ -383,65 +350,6 @@ class TestFunctional(unittest.TestCase):
...
@@ -383,65 +350,6 @@ class TestFunctional(unittest.TestCase):
s
=
((
freq
-
freq_ref
).
abs
()
>
threshold
).
sum
()
s
=
((
freq
-
freq_ref
).
abs
()
>
threshold
).
sum
()
self
.
assertFalse
(
s
)
self
.
assertFalse
(
s
)
# Convert to stereo and batch for testing purposes
self
.
_test_batch
(
F
.
detect_pitch_frequency
,
waveform
,
sample_rate
)
def
_test_batch_shape
(
self
,
functional
,
tensor
,
*
args
,
**
kwargs
):
kwargs_compare
=
{}
if
'atol'
in
kwargs
:
atol
=
kwargs
[
'atol'
]
del
kwargs
[
'atol'
]
kwargs_compare
[
'atol'
]
=
atol
if
'rtol'
in
kwargs
:
rtol
=
kwargs
[
'rtol'
]
del
kwargs
[
'rtol'
]
kwargs_compare
[
'rtol'
]
=
rtol
# Single then transform then batch
torch
.
random
.
manual_seed
(
42
)
expected
=
functional
(
tensor
.
clone
(),
*
args
,
**
kwargs
)
expected
=
expected
.
unsqueeze
(
0
).
unsqueeze
(
0
)
# 1-Batch then transform
tensors
=
tensor
.
unsqueeze
(
0
).
unsqueeze
(
0
)
torch
.
random
.
manual_seed
(
42
)
computed
=
functional
(
tensors
.
clone
(),
*
args
,
**
kwargs
)
self
.
_compare_estimate
(
computed
,
expected
,
**
kwargs_compare
)
return
tensors
,
expected
def
_test_batch
(
self
,
functional
,
tensor
,
*
args
,
**
kwargs
):
tensors
,
expected
=
self
.
_test_batch_shape
(
functional
,
tensor
,
*
args
,
**
kwargs
)
kwargs_compare
=
{}
if
'atol'
in
kwargs
:
atol
=
kwargs
[
'atol'
]
del
kwargs
[
'atol'
]
kwargs_compare
[
'atol'
]
=
atol
if
'rtol'
in
kwargs
:
rtol
=
kwargs
[
'rtol'
]
del
kwargs
[
'rtol'
]
kwargs_compare
[
'rtol'
]
=
rtol
# 3-Batch then transform
ind
=
[
3
]
+
[
1
]
*
(
int
(
tensors
.
dim
())
-
1
)
tensors
=
tensor
.
repeat
(
*
ind
)
ind
=
[
3
]
+
[
1
]
*
(
int
(
expected
.
dim
())
-
1
)
expected
=
expected
.
repeat
(
*
ind
)
torch
.
random
.
manual_seed
(
42
)
computed
=
functional
(
tensors
.
clone
(),
*
args
,
**
kwargs
)
def
test_DB_to_amplitude
(
self
):
def
test_DB_to_amplitude
(
self
):
# Make some noise
# Make some noise
x
=
torch
.
rand
(
1000
)
x
=
torch
.
rand
(
1000
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment