Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
e43ee196
Unverified
Commit
e43ee196
authored
Jul 06, 2020
by
moto
Committed by
GitHub
Jul 06, 2020
Browse files
Replace torchaudio.load in test with scipy func (#762)
parent
4b583eab
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
29 additions
and
29 deletions
+29
-29
test/functional_cpu_test.py
test/functional_cpu_test.py
+1
-3
test/kaldi_compatibility_impl.py
test/kaldi_compatibility_impl.py
+19
-16
test/test_librosa_compatibility.py
test/test_librosa_compatibility.py
+5
-4
test/test_transforms.py
test/test_transforms.py
+3
-3
test/torchscript_consistency_impl.py
test/torchscript_consistency_impl.py
+1
-3
No files found.
test/functional_cpu_test.py
View file @
e43ee196
...
@@ -299,8 +299,6 @@ class TestIstft(common_utils.TorchaudioTestCase):
...
@@ -299,8 +299,6 @@ class TestIstft(common_utils.TorchaudioTestCase):
class
TestDetectPitchFrequency
(
common_utils
.
TorchaudioTestCase
):
class
TestDetectPitchFrequency
(
common_utils
.
TorchaudioTestCase
):
backend
=
'default'
def
test_pitch
(
self
):
def
test_pitch
(
self
):
test_filepath_100
=
common_utils
.
get_asset_path
(
"100Hz_44100Hz_16bit_05sec.wav"
)
test_filepath_100
=
common_utils
.
get_asset_path
(
"100Hz_44100Hz_16bit_05sec.wav"
)
test_filepath_440
=
common_utils
.
get_asset_path
(
"440Hz_44100Hz_16bit_05sec.wav"
)
test_filepath_440
=
common_utils
.
get_asset_path
(
"440Hz_44100Hz_16bit_05sec.wav"
)
...
@@ -312,7 +310,7 @@ class TestDetectPitchFrequency(common_utils.TorchaudioTestCase):
...
@@ -312,7 +310,7 @@ class TestDetectPitchFrequency(common_utils.TorchaudioTestCase):
]
]
for
filename
,
freq_ref
in
tests
:
for
filename
,
freq_ref
in
tests
:
waveform
,
sample_rate
=
torchaudio
.
load
(
filename
)
waveform
,
sample_rate
=
common_utils
.
load
_wav
(
filename
)
freq
=
torchaudio
.
functional
.
detect_pitch_frequency
(
waveform
,
sample_rate
)
freq
=
torchaudio
.
functional
.
detect_pitch_frequency
(
waveform
,
sample_rate
)
...
...
test/kaldi_compatibility_impl.py
View file @
e43ee196
...
@@ -5,11 +5,16 @@ import kaldi_io
...
@@ -5,11 +5,16 @@ import kaldi_io
import
torch
import
torch
import
torchaudio.functional
as
F
import
torchaudio.functional
as
F
import
torchaudio.compliance.kaldi
import
torchaudio.compliance.kaldi
from
.
import
common_utils
from
.common_utils
import
load_params
from
parameterized
import
parameterized
from
parameterized
import
parameterized
from
.common_utils
import
(
TestBaseMixin
,
load_params
,
skipIfNoExec
,
get_asset_path
,
load_wav
)
def
_convert_args
(
**
kwargs
):
def
_convert_args
(
**
kwargs
):
args
=
[]
args
=
[]
...
@@ -43,14 +48,12 @@ def _run_kaldi(command, input_type, input_value):
...
@@ -43,14 +48,12 @@ def _run_kaldi(command, input_type, input_value):
return
torch
.
from_numpy
(
result
.
copy
())
# copy supresses some torch warning
return
torch
.
from_numpy
(
result
.
copy
())
# copy supresses some torch warning
class
Kaldi
(
common_utils
.
TestBaseMixin
):
class
Kaldi
(
TestBaseMixin
):
backend
=
'sox'
def
assert_equal
(
self
,
output
,
*
,
expected
,
rtol
=
None
,
atol
=
None
):
def
assert_equal
(
self
,
output
,
*
,
expected
,
rtol
=
None
,
atol
=
None
):
expected
=
expected
.
to
(
dtype
=
self
.
dtype
,
device
=
self
.
device
)
expected
=
expected
.
to
(
dtype
=
self
.
dtype
,
device
=
self
.
device
)
self
.
assertEqual
(
output
,
expected
,
rtol
=
rtol
,
atol
=
atol
)
self
.
assertEqual
(
output
,
expected
,
rtol
=
rtol
,
atol
=
atol
)
@
common_utils
.
skipIfNoExec
(
'apply-cmvn-sliding'
)
@
skipIfNoExec
(
'apply-cmvn-sliding'
)
def
test_sliding_window_cmn
(
self
):
def
test_sliding_window_cmn
(
self
):
"""sliding_window_cmn should be numerically compatible with apply-cmvn-sliding"""
"""sliding_window_cmn should be numerically compatible with apply-cmvn-sliding"""
kwargs
=
{
kwargs
=
{
...
@@ -67,33 +70,33 @@ class Kaldi(common_utils.TestBaseMixin):
...
@@ -67,33 +70,33 @@ class Kaldi(common_utils.TestBaseMixin):
self
.
assert_equal
(
result
,
expected
=
kaldi_result
)
self
.
assert_equal
(
result
,
expected
=
kaldi_result
)
@
parameterized
.
expand
(
load_params
(
'kaldi_test_fbank_args.json'
))
@
parameterized
.
expand
(
load_params
(
'kaldi_test_fbank_args.json'
))
@
common_utils
.
skipIfNoExec
(
'compute-fbank-feats'
)
@
skipIfNoExec
(
'compute-fbank-feats'
)
def
test_fbank
(
self
,
kwargs
):
def
test_fbank
(
self
,
kwargs
):
"""fbank should be numerically compatible with compute-fbank-feats"""
"""fbank should be numerically compatible with compute-fbank-feats"""
wave_file
=
common_utils
.
get_asset_path
(
'kaldi_file.wav'
)
wave_file
=
get_asset_path
(
'kaldi_file.wav'
)
waveform
=
torchaudio
.
load_wav
(
wave_file
)[
0
].
to
(
dtype
=
self
.
dtype
,
device
=
self
.
device
)
waveform
=
load_wav
(
wave_file
,
normalize
=
False
)[
0
].
to
(
dtype
=
self
.
dtype
,
device
=
self
.
device
)
result
=
torchaudio
.
compliance
.
kaldi
.
fbank
(
waveform
,
**
kwargs
)
result
=
torchaudio
.
compliance
.
kaldi
.
fbank
(
waveform
,
**
kwargs
)
command
=
[
'compute-fbank-feats'
]
+
_convert_args
(
**
kwargs
)
+
[
'scp:-'
,
'ark:-'
]
command
=
[
'compute-fbank-feats'
]
+
_convert_args
(
**
kwargs
)
+
[
'scp:-'
,
'ark:-'
]
kaldi_result
=
_run_kaldi
(
command
,
'scp'
,
wave_file
)
kaldi_result
=
_run_kaldi
(
command
,
'scp'
,
wave_file
)
self
.
assert_equal
(
result
,
expected
=
kaldi_result
,
rtol
=
1e-4
,
atol
=
1e-8
)
self
.
assert_equal
(
result
,
expected
=
kaldi_result
,
rtol
=
1e-4
,
atol
=
1e-8
)
@
parameterized
.
expand
(
load_params
(
'kaldi_test_spectrogram_args.json'
))
@
parameterized
.
expand
(
load_params
(
'kaldi_test_spectrogram_args.json'
))
@
common_utils
.
skipIfNoExec
(
'compute-spectrogram-feats'
)
@
skipIfNoExec
(
'compute-spectrogram-feats'
)
def
test_spectrogram
(
self
,
kwargs
):
def
test_spectrogram
(
self
,
kwargs
):
"""spectrogram should be numerically compatible with compute-spectrogram-feats"""
"""spectrogram should be numerically compatible with compute-spectrogram-feats"""
wave_file
=
common_utils
.
get_asset_path
(
'kaldi_file.wav'
)
wave_file
=
get_asset_path
(
'kaldi_file.wav'
)
waveform
=
torchaudio
.
load_wav
(
wave_file
)[
0
].
to
(
dtype
=
self
.
dtype
,
device
=
self
.
device
)
waveform
=
load_wav
(
wave_file
,
normalize
=
False
)[
0
].
to
(
dtype
=
self
.
dtype
,
device
=
self
.
device
)
result
=
torchaudio
.
compliance
.
kaldi
.
spectrogram
(
waveform
,
**
kwargs
)
result
=
torchaudio
.
compliance
.
kaldi
.
spectrogram
(
waveform
,
**
kwargs
)
command
=
[
'compute-spectrogram-feats'
]
+
_convert_args
(
**
kwargs
)
+
[
'scp:-'
,
'ark:-'
]
command
=
[
'compute-spectrogram-feats'
]
+
_convert_args
(
**
kwargs
)
+
[
'scp:-'
,
'ark:-'
]
kaldi_result
=
_run_kaldi
(
command
,
'scp'
,
wave_file
)
kaldi_result
=
_run_kaldi
(
command
,
'scp'
,
wave_file
)
self
.
assert_equal
(
result
,
expected
=
kaldi_result
,
rtol
=
1e-4
,
atol
=
1e-8
)
self
.
assert_equal
(
result
,
expected
=
kaldi_result
,
rtol
=
1e-4
,
atol
=
1e-8
)
@
parameterized
.
expand
(
load_params
(
'kaldi_test_mfcc_args.json'
))
@
parameterized
.
expand
(
load_params
(
'kaldi_test_mfcc_args.json'
))
@
common_utils
.
skipIfNoExec
(
'compute-mfcc-feats'
)
@
skipIfNoExec
(
'compute-mfcc-feats'
)
def
test_mfcc
(
self
,
kwargs
):
def
test_mfcc
(
self
,
kwargs
):
"""mfcc should be numerically compatible with compute-mfcc-feats"""
"""mfcc should be numerically compatible with compute-mfcc-feats"""
wave_file
=
common_utils
.
get_asset_path
(
'kaldi_file.wav'
)
wave_file
=
get_asset_path
(
'kaldi_file.wav'
)
waveform
=
torchaudio
.
load_wav
(
wave_file
)[
0
].
to
(
dtype
=
self
.
dtype
,
device
=
self
.
device
)
waveform
=
load_wav
(
wave_file
,
normalize
=
False
)[
0
].
to
(
dtype
=
self
.
dtype
,
device
=
self
.
device
)
result
=
torchaudio
.
compliance
.
kaldi
.
mfcc
(
waveform
,
**
kwargs
)
result
=
torchaudio
.
compliance
.
kaldi
.
mfcc
(
waveform
,
**
kwargs
)
command
=
[
'compute-mfcc-feats'
]
+
_convert_args
(
**
kwargs
)
+
[
'scp:-'
,
'ark:-'
]
command
=
[
'compute-mfcc-feats'
]
+
_convert_args
(
**
kwargs
)
+
[
'scp:-'
,
'ark:-'
]
kaldi_result
=
_run_kaldi
(
command
,
'scp'
,
wave_file
)
kaldi_result
=
_run_kaldi
(
command
,
'scp'
,
wave_file
)
...
...
test/test_librosa_compatibility.py
View file @
e43ee196
...
@@ -160,7 +160,8 @@ class TestTransforms(common_utils.TorchaudioTestCase):
...
@@ -160,7 +160,8 @@ class TestTransforms(common_utils.TorchaudioTestCase):
"""Test suite for functions in `transforms` module."""
"""Test suite for functions in `transforms` module."""
def
assert_compatibilities
(
self
,
n_fft
,
hop_length
,
power
,
n_mels
,
n_mfcc
,
sample_rate
):
def
assert_compatibilities
(
self
,
n_fft
,
hop_length
,
power
,
n_mels
,
n_mfcc
,
sample_rate
):
common_utils
.
set_audio_backend
(
'default'
)
common_utils
.
set_audio_backend
(
'default'
)
sound
,
sample_rate
=
_load_audio_asset
(
'sinewave.wav'
)
path
=
common_utils
.
get_asset_path
(
'sinewave.wav'
)
sound
,
sample_rate
=
common_utils
.
load_wav
(
path
)
sound_librosa
=
sound
.
cpu
().
numpy
().
squeeze
()
# (64000)
sound_librosa
=
sound
.
cpu
().
numpy
().
squeeze
()
# (64000)
# test core spectrogram
# test core spectrogram
...
@@ -300,9 +301,9 @@ class TestTransforms(common_utils.TorchaudioTestCase):
...
@@ -300,9 +301,9 @@ class TestTransforms(common_utils.TorchaudioTestCase):
hop_length
=
n_fft
//
4
hop_length
=
n_fft
//
4
# Prepare mel spectrogram input. We use torchaudio to compute one.
# Prepare mel spectrogram input. We use torchaudio to compute one.
common_utils
.
s
et_a
udio_backend
(
'default
'
)
path
=
common_utils
.
g
et_a
sset_path
(
'steam-train-whistle-daniel_simon.wav
'
)
sound
,
sample_rate
=
_load_audio_asset
(
sound
,
sample_rate
=
common_utils
.
load_wav
(
path
)
'steam-train-whistle-daniel_simon.wav'
,
offset
=
2
**
10
,
num_frames
=
2
**
14
)
sound
=
sound
[:,
2
**
10
:
2
**
10
+
2
**
14
]
sound
=
sound
.
mean
(
dim
=
0
,
keepdim
=
True
)
sound
=
sound
.
mean
(
dim
=
0
,
keepdim
=
True
)
spec_orig
=
F
.
spectrogram
(
spec_orig
=
F
.
spectrogram
(
sound
,
pad
=
0
,
window
=
torch
.
hann_window
(
n_fft
),
n_fft
=
n_fft
,
sound
,
pad
=
0
,
window
=
torch
.
hann_window
(
n_fft
),
n_fft
=
n_fft
,
...
...
test/test_transforms.py
View file @
e43ee196
...
@@ -45,7 +45,7 @@ class Tester(common_utils.TorchaudioTestCase):
...
@@ -45,7 +45,7 @@ class Tester(common_utils.TorchaudioTestCase):
def
test_AmplitudeToDB
(
self
):
def
test_AmplitudeToDB
(
self
):
filepath
=
common_utils
.
get_asset_path
(
'steam-train-whistle-daniel_simon.wav'
)
filepath
=
common_utils
.
get_asset_path
(
'steam-train-whistle-daniel_simon.wav'
)
waveform
,
sample_rate
=
torchaudio
.
load
(
filepath
)
waveform
=
common_utils
.
load
_wav
(
filepath
)
[
0
]
mag_to_db_transform
=
transforms
.
AmplitudeToDB
(
'magnitude'
,
80.
)
mag_to_db_transform
=
transforms
.
AmplitudeToDB
(
'magnitude'
,
80.
)
power_to_db_transform
=
transforms
.
AmplitudeToDB
(
'power'
,
80.
)
power_to_db_transform
=
transforms
.
AmplitudeToDB
(
'power'
,
80.
)
...
@@ -115,7 +115,7 @@ class Tester(common_utils.TorchaudioTestCase):
...
@@ -115,7 +115,7 @@ class Tester(common_utils.TorchaudioTestCase):
self
.
assertTrue
(
mel_transform2
.
mel_scale
.
fb
.
sum
(
1
).
ge
(
0.
).
all
())
self
.
assertTrue
(
mel_transform2
.
mel_scale
.
fb
.
sum
(
1
).
ge
(
0.
).
all
())
# check on multi-channel audio
# check on multi-channel audio
filepath
=
common_utils
.
get_asset_path
(
'steam-train-whistle-daniel_simon.wav'
)
filepath
=
common_utils
.
get_asset_path
(
'steam-train-whistle-daniel_simon.wav'
)
x_stereo
,
sr_stereo
=
torchaudio
.
load
(
filepath
)
# (2, 278756), 44100
x_stereo
=
common_utils
.
load
_wav
(
filepath
)
[
0
]
# (2, 278756), 44100
spectrogram_stereo
=
s2db
(
mel_transform
(
x_stereo
))
# (2, 128, 1394)
spectrogram_stereo
=
s2db
(
mel_transform
(
x_stereo
))
# (2, 128, 1394)
self
.
assertTrue
(
spectrogram_stereo
.
dim
()
==
3
)
self
.
assertTrue
(
spectrogram_stereo
.
dim
()
==
3
)
self
.
assertTrue
(
spectrogram_stereo
.
size
(
0
)
==
2
)
self
.
assertTrue
(
spectrogram_stereo
.
size
(
0
)
==
2
)
...
@@ -166,7 +166,7 @@ class Tester(common_utils.TorchaudioTestCase):
...
@@ -166,7 +166,7 @@ class Tester(common_utils.TorchaudioTestCase):
def
test_resample_size
(
self
):
def
test_resample_size
(
self
):
input_path
=
common_utils
.
get_asset_path
(
'sinewave.wav'
)
input_path
=
common_utils
.
get_asset_path
(
'sinewave.wav'
)
waveform
,
sample_rate
=
torchaudio
.
load
(
input_path
)
waveform
,
sample_rate
=
common_utils
.
load
_wav
(
input_path
)
upsample_rate
=
sample_rate
*
2
upsample_rate
=
sample_rate
*
2
downsample_rate
=
sample_rate
//
2
downsample_rate
=
sample_rate
//
2
...
...
test/torchscript_consistency_impl.py
View file @
e43ee196
...
@@ -2,7 +2,6 @@
...
@@ -2,7 +2,6 @@
import
unittest
import
unittest
import
torch
import
torch
import
torchaudio
import
torchaudio.functional
as
F
import
torchaudio.functional
as
F
import
torchaudio.transforms
as
T
import
torchaudio.transforms
as
T
...
@@ -616,6 +615,5 @@ class Transforms(common_utils.TestBaseMixin):
...
@@ -616,6 +615,5 @@ class Transforms(common_utils.TestBaseMixin):
def
test_Vad
(
self
):
def
test_Vad
(
self
):
filepath
=
common_utils
.
get_asset_path
(
"vad-go-mono-32000.wav"
)
filepath
=
common_utils
.
get_asset_path
(
"vad-go-mono-32000.wav"
)
common_utils
.
set_audio_backend
(
'default'
)
waveform
,
sample_rate
=
common_utils
.
load_wav
(
filepath
)
waveform
,
sample_rate
=
torchaudio
.
load
(
filepath
)
self
.
_assert_consistency
(
T
.
Vad
(
sample_rate
=
sample_rate
),
waveform
)
self
.
_assert_consistency
(
T
.
Vad
(
sample_rate
=
sample_rate
),
waveform
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment