Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
a20da5e3
Unverified
Commit
a20da5e3
authored
Jul 01, 2020
by
moto
Committed by
GitHub
Jul 01, 2020
Browse files
Refactor test utilities (#756)
parent
6b159054
Changes
16
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
421 additions
and
155 deletions
+421
-155
test/common_utils/__init__.py
test/common_utils/__init__.py
+31
-0
test/common_utils/backend_utils.py
test/common_utils/backend_utils.py
+41
-0
test/common_utils/data_utils.py
test/common_utils/data_utils.py
+78
-0
test/common_utils/parameterized_utils.py
test/common_utils/parameterized_utils.py
+10
-0
test/common_utils/sox_utils.py
test/common_utils/sox_utils.py
+0
-0
test/common_utils/test_case_utils.py
test/common_utils/test_case_utils.py
+75
-0
test/common_utils/wav_utils.py
test/common_utils/wav_utils.py
+86
-0
test/functional_cpu_test.py
test/functional_cpu_test.py
+28
-3
test/kaldi_compatibility_impl.py
test/kaldi_compatibility_impl.py
+5
-10
test/sox_io_backend/common.py
test/sox_io_backend/common.py
+2
-90
test/sox_io_backend/test_info.py
test/sox_io_backend/test_info.py
+9
-9
test/sox_io_backend/test_load.py
test/sox_io_backend/test_load.py
+15
-15
test/sox_io_backend/test_roundtrip.py
test/sox_io_backend/test_roundtrip.py
+4
-4
test/sox_io_backend/test_save.py
test/sox_io_backend/test_save.py
+15
-15
test/sox_io_backend/test_torchscript.py
test/sox_io_backend/test_torchscript.py
+8
-8
test/test_io.py
test/test_io.py
+14
-1
No files found.
test/common_utils/__init__.py
0 → 100644
View file @
a20da5e3
from
.data_utils
import
(
get_asset_path
,
get_whitenoise
,
get_sinusoid
,
)
from
.backend_utils
import
(
set_audio_backend
,
BACKENDS
,
BACKENDS_MP3
,
)
from
.test_case_utils
import
(
TempDirMixin
,
TestBaseMixin
,
PytorchTestCase
,
TorchaudioTestCase
,
skipIfNoCuda
,
skipIfNoExec
,
skipIfNoModule
,
skipIfNoExtension
,
skipIfNoSoxBackend
,
)
from
.wav_utils
import
(
get_wav_data
,
normalize_wav
,
load_wav
,
save_wav
,
)
from
.parameterized_utils
import
(
load_params
,
)
from
.
import
sox_utils
test/common_utils/backend_utils.py
0 → 100644
View file @
a20da5e3
import
unittest
import
torchaudio
from
.
import
data_utils
# Backends that torchaudio lists as available, captured once at import time.
BACKENDS = torchaudio.list_audio_backends()


def _filter_backends_with_mp3(backends):
    """Return the entries of ``backends`` that manage to load the mp3 test asset.

    Each candidate is made the active torchaudio backend and then asked to load
    'steam-train-whistle-daniel_simon.mp3'. A candidate is kept when the load
    succeeds and dropped when it raises RuntimeError or ImportError.

    NOTE: the active backend is not switched back afterwards, so whichever
    candidate was probed last stays selected.
    """
    mp3_path = data_utils.get_asset_path('steam-train-whistle-daniel_simon.mp3')

    def can_load_mp3(candidate):
        torchaudio.set_audio_backend(candidate)
        try:
            torchaudio.load(mp3_path)
        except (RuntimeError, ImportError):
            return False
        return True

    return list(filter(can_load_mp3, backends))


# Subset of BACKENDS for which the mp3 probe above succeeded.
BACKENDS_MP3 = _filter_backends_with_mp3(BACKENDS)
def set_audio_backend(backend):
    """Select the torchaudio backend, also accepting the extra value 'default'.

    Any value other than 'default' is forwarded to
    ``torchaudio.set_audio_backend`` untouched. For 'default', 'sox' is chosen
    when it is in ``BACKENDS``, otherwise 'soundfile'.

    Raises:
        unittest.SkipTest: 'default' was requested but neither 'sox' nor
            'soundfile' is in ``BACKENDS``.
    """
    if backend != 'default':
        torchaudio.set_audio_backend(backend)
        return

    # Preference order for the 'default' alias.
    for candidate in ('sox', 'soundfile'):
        if candidate in BACKENDS:
            torchaudio.set_audio_backend(candidate)
            return

    raise unittest.SkipTest('No default backend available')
test/common_utils.py
→
test/common_utils
/data_utils
.py
View file @
a20da5e3
import
os
import
os.path
import
shutil
import
tempfile
import
unittest
from
typing
import
Union
from
typing
import
Union
from
shutil
import
copytree
import
torch
import
torch
from
torch.testing._internal.common_utils
import
TestCase
as
PytorchTestCase
import
torchaudio
from
torchaudio._internal.module_utils
import
is_module_available
_TEST_DIR_PATH
=
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
))
BACKENDS
=
torchaudio
.
list_audio_backends
()
_TEST_DIR_PATH
=
os
.
path
.
realpath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'..'
))
def
get_asset_path
(
*
paths
):
def
get_asset_path
(
*
paths
):
...
@@ -19,138 +13,6 @@ def get_asset_path(*paths):
...
@@ -19,138 +13,6 @@ def get_asset_path(*paths):
return
os
.
path
.
join
(
_TEST_DIR_PATH
,
'assets'
,
*
paths
)
return
os
.
path
.
join
(
_TEST_DIR_PATH
,
'assets'
,
*
paths
)
def
create_temp_assets_dir
():
"""
Creates a temporary directory and moves all files from test/assets there.
Returns a Tuple[string, TemporaryDirectory] which is the folder path
and object.
"""
tmp_dir
=
tempfile
.
TemporaryDirectory
()
copytree
(
os
.
path
.
join
(
_TEST_DIR_PATH
,
"assets"
),
os
.
path
.
join
(
tmp_dir
.
name
,
"assets"
))
return
tmp_dir
.
name
,
tmp_dir
def
random_float_tensor
(
seed
,
size
,
a
=
22695477
,
c
=
1
,
m
=
2
**
32
):
""" Generates random tensors given a seed and size
https://en.wikipedia.org/wiki/Linear_congruential_generator
X_{n + 1} = (a * X_n + c) % m
Using Borland C/C++ values
The tensor will have values between [0,1)
Inputs:
seed (int): an int
size (Tuple[int]): the size of the output tensor
a (int): the multiplier constant to the generator
c (int): the additive constant to the generator
m (int): the modulus constant to the generator
"""
num_elements
=
1
for
s
in
size
:
num_elements
*=
s
arr
=
[(
a
*
seed
+
c
)
%
m
]
for
i
in
range
(
num_elements
-
1
):
arr
.
append
((
a
*
arr
[
i
]
+
c
)
%
m
)
return
torch
.
tensor
(
arr
).
float
().
view
(
size
)
/
m
def
filter_backends_with_mp3
(
backends
):
# Filter out backends that do not support mp3
test_filepath
=
get_asset_path
(
'steam-train-whistle-daniel_simon.mp3'
)
def
supports_mp3
(
backend
):
torchaudio
.
set_audio_backend
(
backend
)
try
:
torchaudio
.
load
(
test_filepath
)
return
True
except
(
RuntimeError
,
ImportError
):
return
False
return
[
backend
for
backend
in
backends
if
supports_mp3
(
backend
)]
BACKENDS_MP3
=
filter_backends_with_mp3
(
BACKENDS
)
def
set_audio_backend
(
backend
):
"""Allow additional backend value, 'default'"""
if
backend
==
'default'
:
if
'sox'
in
BACKENDS
:
be
=
'sox'
elif
'soundfile'
in
BACKENDS
:
be
=
'soundfile'
else
:
raise
unittest
.
SkipTest
(
'No default backend available'
)
else
:
be
=
backend
torchaudio
.
set_audio_backend
(
be
)
class
TempDirMixin
:
"""Mixin to provide easy access to temp dir"""
temp_dir_
=
None
base_temp_dir
=
None
temp_dir
=
None
@
classmethod
def
setUpClass
(
cls
):
super
().
setUpClass
()
# If TORCHAUDIO_TEST_TEMP_DIR is set, use it instead of temporary directory.
# this is handy for debugging.
key
=
'TORCHAUDIO_TEST_TEMP_DIR'
if
key
in
os
.
environ
:
cls
.
base_temp_dir
=
os
.
environ
[
key
]
else
:
cls
.
temp_dir_
=
tempfile
.
TemporaryDirectory
()
cls
.
base_temp_dir
=
cls
.
temp_dir_
.
name
@
classmethod
def
tearDownClass
(
cls
):
super
().
tearDownClass
()
if
isinstance
(
cls
.
temp_dir_
,
tempfile
.
TemporaryDirectory
):
cls
.
temp_dir_
.
cleanup
()
def
setUp
(
self
):
self
.
temp_dir
=
os
.
path
.
join
(
self
.
base_temp_dir
,
self
.
id
())
def
get_temp_path
(
self
,
*
paths
):
path
=
os
.
path
.
join
(
self
.
temp_dir
,
*
paths
)
os
.
makedirs
(
os
.
path
.
dirname
(
path
),
exist_ok
=
True
)
return
path
class
TestBaseMixin
:
"""Mixin to provide consistent way to define device/dtype/backend aware TestCase"""
dtype
=
None
device
=
None
backend
=
None
def
setUp
(
self
):
super
().
setUp
()
set_audio_backend
(
self
.
backend
)
class
TorchaudioTestCase
(
TestBaseMixin
,
PytorchTestCase
):
pass
def
skipIfNoExec
(
cmd
):
return
unittest
.
skipIf
(
shutil
.
which
(
cmd
)
is
None
,
f
'`
{
cmd
}
` is not available'
)
def
skipIfNoModule
(
module
,
display_name
=
None
):
display_name
=
display_name
or
module
return
unittest
.
skipIf
(
not
is_module_available
(
module
),
f
'"
{
display_name
}
" is not available'
)
skipIfNoSoxBackend
=
unittest
.
skipIf
(
'sox'
not
in
BACKENDS
,
'Sox backend not available'
)
skipIfNoCuda
=
unittest
.
skipIf
(
not
torch
.
cuda
.
is_available
(),
reason
=
'CUDA not available'
)
skipIfNoExtension
=
skipIfNoModule
(
'torchaudio._torchaudio'
,
'torchaudio C++ extension'
)
def
get_whitenoise
(
def
get_whitenoise
(
*
,
*
,
sample_rate
:
int
=
16000
,
sample_rate
:
int
=
16000
,
...
...
test/common_utils/parameterized_utils.py
0 → 100644
View file @
a20da5e3
import
json
from
parameterized
import
param
from
.data_utils
import
get_asset_path
def load_params(*paths):
    """Load test parameters from a JSON-lines asset file.

    Args:
        *paths: Path components forwarded to ``get_asset_path`` to locate the
            file.

    Returns:
        list: One ``param`` object per non-blank line, each built from the
        value decoded from that line with ``json.loads``.
    """
    # The encoding is explicit so decoding does not depend on the platform
    # default locale.
    with open(get_asset_path(*paths), 'r', encoding='utf-8') as file:
        # Whitespace-only lines (e.g. a trailing empty line) are skipped
        # instead of being handed to json.loads, which would reject them.
        return [param(json.loads(line)) for line in file if line.strip()]
test/
sox_io_backend
/sox_utils.py
→
test/
common_utils
/sox_utils.py
View file @
a20da5e3
File moved
test/common_utils/test_case_utils.py
0 → 100644
View file @
a20da5e3
import
shutil
import
os.path
import
tempfile
import
unittest
import
torch
from
torch.testing._internal.common_utils
import
TestCase
as
PytorchTestCase
import
torchaudio
from
torchaudio._internal.module_utils
import
is_module_available
from
.backend_utils
import
set_audio_backend
class TempDirMixin:
    """Mixin giving every test its own path under a per-class base directory."""

    # tempfile.TemporaryDirectory handle; stays None when the env override is used
    temp_dir_ = None
    # Directory shared by all tests of the class, assigned in setUpClass
    base_temp_dir = None
    # Per-test directory path (base_temp_dir joined with the test id), assigned in setUp
    temp_dir = None

    @classmethod
    def setUpClass(cls):
        """Pick the base directory: the env override if set, else a fresh temp dir."""
        super().setUpClass()
        # TORCHAUDIO_TEST_TEMP_DIR replaces the throwaway directory when present,
        # which is handy for debugging.
        base = os.environ.get('TORCHAUDIO_TEST_TEMP_DIR')
        if base is None:
            cls.temp_dir_ = tempfile.TemporaryDirectory()
            base = cls.temp_dir_.name
        cls.base_temp_dir = base

    @classmethod
    def tearDownClass(cls):
        """Remove the temporary directory if this class created one."""
        super().tearDownClass()
        handle = cls.temp_dir_
        if isinstance(handle, tempfile.TemporaryDirectory):
            handle.cleanup()

    def setUp(self):
        """Compute the per-test directory path from the test id."""
        super().setUp()
        self.temp_dir = os.path.join(self.base_temp_dir, self.id())

    def get_temp_path(self, *paths):
        """Return ``temp_dir`` joined with ``paths``, creating the parent directory.

        Only the directory containing the returned path is created; the path
        itself is not.
        """
        target = os.path.join(self.temp_dir, *paths)
        parent = os.path.dirname(target)
        os.makedirs(parent, exist_ok=True)
        return target
class TestBaseMixin:
    """Mixin providing a consistent way to define device/dtype/backend aware TestCase"""

    # Overridden by concrete test classes; all three default to None here.
    dtype = None
    device = None
    backend = None

    def setUp(self):
        """Run the parent setUp, then activate the backend named by ``self.backend``."""
        super().setUp()
        requested = self.backend
        set_audio_backend(requested)
class TorchaudioTestCase(TestBaseMixin, PytorchTestCase):
    """PytorchTestCase combined with TestBaseMixin; adds no members of its own."""
def skipIfNoExec(cmd):
    """Return a unittest skip decorator that triggers when ``cmd`` is not found.

    The lookup is done with ``shutil.which`` at the time this function is
    called, not when the decorated test runs.
    """
    missing = shutil.which(cmd) is None
    reason = f'`{cmd}` is not available'
    return unittest.skipIf(missing, reason)
def skipIfNoModule(module, display_name=None):
    """Return a unittest skip decorator that triggers when ``module`` is unavailable.

    Args:
        module: Module name passed to ``is_module_available``.
        display_name: Name shown in the skip reason; falls back to ``module``
            when omitted or otherwise falsy.
    """
    if not display_name:
        display_name = module
    unavailable = not is_module_available(module)
    return unittest.skipIf(unavailable, f'"{display_name}" is not available')
# Skip decorator: active when 'sox' is absent from torchaudio's backend list.
skipIfNoSoxBackend = unittest.skipIf(
    'sox' not in torchaudio.list_audio_backends(),
    'Sox backend not available',
)

# Skip decorator: active when torch.cuda.is_available() is False.
skipIfNoCuda = unittest.skipIf(
    not torch.cuda.is_available(),
    reason='CUDA not available',
)

# Skip decorator: active when the 'torchaudio._torchaudio' module is unavailable.
skipIfNoExtension = skipIfNoModule(
    'torchaudio._torchaudio',
    'torchaudio C++ extension',
)
test/common_utils/wav_utils.py
0 → 100644
View file @
a20da5e3
from
typing
import
Optional
import
torch
import
scipy.io.wavfile
def normalize_wav(tensor: torch.Tensor) -> torch.Tensor:
    """Convert int32 / int16 / uint8 samples to float32 scaled by the dtype's peak values.

    uint8 samples are first shifted down by 128. Positive values are then
    divided by the dtype's largest positive value and negative values by the
    magnitude of its most negative value. Tensors of any other dtype
    (including float32) are returned as the same object, untouched.
    """
    # dtype -> (offset subtracted first, divisor for positives, divisor for negatives)
    scaling = {
        torch.int32: (0, 2147483647., 2147483648.),
        torch.int16: (0, 32767., 32768.),
        torch.uint8: (128, 127., 128.),
    }
    if tensor.dtype not in scaling:
        return tensor

    offset, positive_peak, negative_peak = scaling[tensor.dtype]
    samples = tensor.to(torch.float32)
    if offset:
        samples = samples - offset
    samples[samples > 0] /= positive_peak
    samples[samples < 0] /= negative_peak
    return samples
def get_wav_data(
        dtype: str,
        num_channels: int,
        *,
        num_frames: Optional[int] = None,
        normalize: bool = True,
        channels_first: bool = True,
):
    """Generate linear signal of the given dtype and num_channels

    Data range is
        [-1.0, 1.0] for float32,
        [-2147483648, 2147483647] for int32
        [-32768, 32767] for int16
        [0, 255] for uint8

    num_frames allow to change the linear interpolation parameter.
    Default values are 256 for uint8, else 1 << 16.
    1 << 16 as default is so that int16 value range is completely covered.

    Args:
        dtype: One of 'uint8', 'float32', 'int32', 'int16'.
        num_channels: Number of times the linear ramp is repeated.
        num_frames: Number of points in the ramp (see above for defaults).
        normalize: When True the result is passed through ``normalize_wav``.
        channels_first: When False the two axes are swapped before
            normalization.

    Returns:
        The generated tensor.

    Raises:
        AttributeError: ``dtype`` is not the name of an attribute of ``torch``.
        ValueError: ``dtype`` names a torch attribute but is not one of the
            four supported sample formats.
    """
    # Supported sample formats and the (start, end) of their linear ramp.
    value_ranges = {
        'uint8': (0, 255),
        'float32': (-1., 1.),
        'int32': (-2147483648, 2147483647),
        'int16': (-32768, 32767),
    }
    dtype_ = getattr(torch, dtype)
    if dtype not in value_ranges:
        # Previously this fell through every branch and failed later with an
        # unbound local variable; fail early with a clear message instead.
        raise ValueError(
            f'Unsupported dtype: {dtype!r}. Expected one of {sorted(value_ranges)}')

    if num_frames is None:
        num_frames = 256 if dtype == 'uint8' else 1 << 16

    start, end = value_ranges[dtype]
    base = torch.linspace(start, end, num_frames, dtype=dtype_)
    data = base.repeat([num_channels, 1])
    if not channels_first:
        data = data.transpose(1, 0)
    if normalize:
        data = normalize_wav(data)
    return data
def load_wav(path: str, normalize: bool = True, channels_first: bool = True) -> tuple:
    """Load wav file without torchaudio

    Args:
        path: File handed to ``scipy.io.wavfile.read``.
        normalize: When True the samples are passed through ``normalize_wav``.
        channels_first: When True the two axes of the 2-D sample tensor are
            swapped before returning.

    Returns:
        tuple: ``(data, sample_rate)`` where ``data`` is a 2-D ``torch.Tensor``
        and ``sample_rate`` is the rate reported by scipy. (The previous
        ``-> torch.Tensor`` annotation did not match this.)
    """
    sample_rate, data = scipy.io.wavfile.read(path)
    # The array is copied before wrapping; presumably to hand torch a
    # writable buffer it owns - TODO confirm.
    data = torch.from_numpy(data.copy())
    if data.ndim == 1:
        # 1-D input gets a trailing axis so the result is always 2-D.
        data = data.unsqueeze(1)
    if normalize:
        data = normalize_wav(data)
    if channels_first:
        data = data.transpose(1, 0)
    return data, sample_rate
def save_wav(path, data, sample_rate, channels_first=True):
    """Save wav file without torchaudio

    When ``channels_first`` is True the two axes of ``data`` are swapped
    before the tensor is converted with ``.numpy()`` and written with
    ``scipy.io.wavfile.write``.
    """
    frames = data.transpose(1, 0) if channels_first else data
    scipy.io.wavfile.write(path, sample_rate, frames.numpy())
test/functional_cpu_test.py
View file @
a20da5e3
...
@@ -10,6 +10,31 @@ from . import common_utils
...
@@ -10,6 +10,31 @@ from . import common_utils
from
.functional_impl
import
Lfilter
from
.functional_impl
import
Lfilter
def random_float_tensor(seed, size, a=22695477, c=1, m=2 ** 32):
    """ Generates random tensors given a seed and size
    https://en.wikipedia.org/wiki/Linear_congruential_generator
    X_{n + 1} = (a * X_n + c) % m
    Using Borland C/C++ values

    The tensor will have values between [0,1)

    Inputs:
        seed (int): an int
        size (Tuple[int]): the size of the output tensor
        a (int): the multiplier constant to the generator
        c (int): the additive constant to the generator
        m (int): the modulus constant to the generator
    """
    num_elements = 1
    for s in size:
        num_elements *= s

    # Exactly num_elements values are produced. The earlier version always
    # emitted at least one, so a size containing 0 failed in .view(size).
    values = []
    state = seed
    for _ in range(num_elements):
        state = (a * state + c) % m
        values.append(state)

    return torch.tensor(values).float().view(size) / m
class
TestLFilterFloat32
(
Lfilter
,
common_utils
.
PytorchTestCase
):
class
TestLFilterFloat32
(
Lfilter
,
common_utils
.
PytorchTestCase
):
dtype
=
torch
.
float32
dtype
=
torch
.
float32
device
=
torch
.
device
(
'cpu'
)
device
=
torch
.
device
(
'cpu'
)
...
@@ -49,7 +74,7 @@ def _test_istft_is_inverse_of_stft(kwargs):
...
@@ -49,7 +74,7 @@ def _test_istft_is_inverse_of_stft(kwargs):
for
data_size
in
[(
2
,
20
),
(
3
,
15
),
(
4
,
10
)]:
for
data_size
in
[(
2
,
20
),
(
3
,
15
),
(
4
,
10
)]:
for
i
in
range
(
100
):
for
i
in
range
(
100
):
sound
=
common_utils
.
random_float_tensor
(
i
,
data_size
)
sound
=
random_float_tensor
(
i
,
data_size
)
stft
=
torch
.
stft
(
sound
,
**
kwargs
)
stft
=
torch
.
stft
(
sound
,
**
kwargs
)
estimate
=
torchaudio
.
functional
.
istft
(
stft
,
length
=
sound
.
size
(
1
),
**
kwargs
)
estimate
=
torchaudio
.
functional
.
istft
(
stft
,
length
=
sound
.
size
(
1
),
**
kwargs
)
...
@@ -211,8 +236,8 @@ class TestIstft(common_utils.TorchaudioTestCase):
...
@@ -211,8 +236,8 @@ class TestIstft(common_utils.TorchaudioTestCase):
def
_test_linearity_of_istft
(
self
,
data_size
,
kwargs
,
atol
=
1e-6
,
rtol
=
1e-8
):
def
_test_linearity_of_istft
(
self
,
data_size
,
kwargs
,
atol
=
1e-6
,
rtol
=
1e-8
):
for
i
in
range
(
self
.
number_of_trials
):
for
i
in
range
(
self
.
number_of_trials
):
tensor1
=
common_utils
.
random_float_tensor
(
i
,
data_size
)
tensor1
=
random_float_tensor
(
i
,
data_size
)
tensor2
=
common_utils
.
random_float_tensor
(
i
*
2
,
data_size
)
tensor2
=
random_float_tensor
(
i
*
2
,
data_size
)
a
,
b
=
torch
.
rand
(
2
)
a
,
b
=
torch
.
rand
(
2
)
istft1
=
torchaudio
.
functional
.
istft
(
tensor1
,
**
kwargs
)
istft1
=
torchaudio
.
functional
.
istft
(
tensor1
,
**
kwargs
)
istft2
=
torchaudio
.
functional
.
istft
(
tensor2
,
**
kwargs
)
istft2
=
torchaudio
.
functional
.
istft
(
tensor2
,
**
kwargs
)
...
...
test/kaldi_compatibility_impl.py
View file @
a20da5e3
"""Test suites for checking numerical compatibility against Kaldi"""
"""Test suites for checking numerical compatibility against Kaldi"""
import
json
import
subprocess
import
subprocess
import
kaldi_io
import
kaldi_io
...
@@ -8,7 +7,8 @@ import torchaudio.functional as F
...
@@ -8,7 +7,8 @@ import torchaudio.functional as F
import
torchaudio.compliance.kaldi
import
torchaudio.compliance.kaldi
from
.
import
common_utils
from
.
import
common_utils
from
parameterized
import
parameterized
,
param
from
.common_utils
import
load_params
from
parameterized
import
parameterized
def
_convert_args
(
**
kwargs
):
def
_convert_args
(
**
kwargs
):
...
@@ -43,11 +43,6 @@ def _run_kaldi(command, input_type, input_value):
...
@@ -43,11 +43,6 @@ def _run_kaldi(command, input_type, input_value):
return
torch
.
from_numpy
(
result
.
copy
())
# copy suppresses some torch warning
return
torch
.
from_numpy
(
result
.
copy
())
# copy suppresses some torch warning
def
_load_params
(
path
):
with
open
(
path
,
'r'
)
as
file
:
return
[
param
(
json
.
loads
(
line
))
for
line
in
file
]
class
Kaldi
(
common_utils
.
TestBaseMixin
):
class
Kaldi
(
common_utils
.
TestBaseMixin
):
backend
=
'sox'
backend
=
'sox'
...
@@ -71,7 +66,7 @@ class Kaldi(common_utils.TestBaseMixin):
...
@@ -71,7 +66,7 @@ class Kaldi(common_utils.TestBaseMixin):
kaldi_result
=
_run_kaldi
(
command
,
'ark'
,
tensor
)
kaldi_result
=
_run_kaldi
(
command
,
'ark'
,
tensor
)
self
.
assert_equal
(
result
,
expected
=
kaldi_result
)
self
.
assert_equal
(
result
,
expected
=
kaldi_result
)
@
parameterized
.
expand
(
_
load_params
(
common_utils
.
get_asset_path
(
'kaldi_test_fbank_args.json'
))
)
@
parameterized
.
expand
(
load_params
(
'kaldi_test_fbank_args.json'
))
@
common_utils
.
skipIfNoExec
(
'compute-fbank-feats'
)
@
common_utils
.
skipIfNoExec
(
'compute-fbank-feats'
)
def
test_fbank
(
self
,
kwargs
):
def
test_fbank
(
self
,
kwargs
):
"""fbank should be numerically compatible with compute-fbank-feats"""
"""fbank should be numerically compatible with compute-fbank-feats"""
...
@@ -82,7 +77,7 @@ class Kaldi(common_utils.TestBaseMixin):
...
@@ -82,7 +77,7 @@ class Kaldi(common_utils.TestBaseMixin):
kaldi_result
=
_run_kaldi
(
command
,
'scp'
,
wave_file
)
kaldi_result
=
_run_kaldi
(
command
,
'scp'
,
wave_file
)
self
.
assert_equal
(
result
,
expected
=
kaldi_result
,
rtol
=
1e-4
,
atol
=
1e-8
)
self
.
assert_equal
(
result
,
expected
=
kaldi_result
,
rtol
=
1e-4
,
atol
=
1e-8
)
@
parameterized
.
expand
(
_
load_params
(
common_utils
.
get_asset_path
(
'kaldi_test_spectrogram_args.json'
))
)
@
parameterized
.
expand
(
load_params
(
'kaldi_test_spectrogram_args.json'
))
@
common_utils
.
skipIfNoExec
(
'compute-spectrogram-feats'
)
@
common_utils
.
skipIfNoExec
(
'compute-spectrogram-feats'
)
def
test_spectrogram
(
self
,
kwargs
):
def
test_spectrogram
(
self
,
kwargs
):
"""spectrogram should be numerically compatible with compute-spectrogram-feats"""
"""spectrogram should be numerically compatible with compute-spectrogram-feats"""
...
@@ -93,7 +88,7 @@ class Kaldi(common_utils.TestBaseMixin):
...
@@ -93,7 +88,7 @@ class Kaldi(common_utils.TestBaseMixin):
kaldi_result
=
_run_kaldi
(
command
,
'scp'
,
wave_file
)
kaldi_result
=
_run_kaldi
(
command
,
'scp'
,
wave_file
)
self
.
assert_equal
(
result
,
expected
=
kaldi_result
,
rtol
=
1e-4
,
atol
=
1e-8
)
self
.
assert_equal
(
result
,
expected
=
kaldi_result
,
rtol
=
1e-4
,
atol
=
1e-8
)
@
parameterized
.
expand
(
_
load_params
(
common_utils
.
get_asset_path
(
'kaldi_test_mfcc_args.json'
))
)
@
parameterized
.
expand
(
load_params
(
'kaldi_test_mfcc_args.json'
))
@
common_utils
.
skipIfNoExec
(
'compute-mfcc-feats'
)
@
common_utils
.
skipIfNoExec
(
'compute-mfcc-feats'
)
def
test_mfcc
(
self
,
kwargs
):
def
test_mfcc
(
self
,
kwargs
):
"""mfcc should be numerically compatible with compute-mfcc-feats"""
"""mfcc should be numerically compatible with compute-mfcc-feats"""
...
...
test/sox_io_backend/common.py
View file @
a20da5e3
from
typing
import
Optional
def name_func(func, _, params):
    """Build a test name from the function name and its positional parameters.

    The result is ``func.__name__`` followed by an underscore and the ``str()``
    of every entry in ``params.args`` joined with underscores. The second
    argument is ignored.
    """
    suffix = "_".join(map(str, params.args))
    return f'{func.__name__}_{suffix}'
import
torch
import
scipy.io.wavfile
def
get_test_name
(
func
,
_
,
params
):
return
f
'
{
func
.
__name__
}
_
{
"_"
.
join
(
str
(
p
)
for
p
in
params
.
args
)
}
'
def
normalize_wav
(
tensor
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
tensor
.
dtype
==
torch
.
float32
:
pass
elif
tensor
.
dtype
==
torch
.
int32
:
tensor
=
tensor
.
to
(
torch
.
float32
)
tensor
[
tensor
>
0
]
/=
2147483647.
tensor
[
tensor
<
0
]
/=
2147483648.
elif
tensor
.
dtype
==
torch
.
int16
:
tensor
=
tensor
.
to
(
torch
.
float32
)
tensor
[
tensor
>
0
]
/=
32767.
tensor
[
tensor
<
0
]
/=
32768.
elif
tensor
.
dtype
==
torch
.
uint8
:
tensor
=
tensor
.
to
(
torch
.
float32
)
-
128
tensor
[
tensor
>
0
]
/=
127.
tensor
[
tensor
<
0
]
/=
128.
return
tensor
def
get_wav_data
(
dtype
:
str
,
num_channels
:
int
,
*
,
num_frames
:
Optional
[
int
]
=
None
,
normalize
:
bool
=
True
,
channels_first
:
bool
=
True
,
):
"""Generate linear signal of the given dtype and num_channels
Data range is
[-1.0, 1.0] for float32,
[-2147483648, 2147483647] for int32
[-32768, 32767] for int16
[0, 255] for uint8
num_frames allow to change the linear interpolation parameter.
Default values are 256 for uint8, else 1 << 16.
1 << 16 as default is so that int16 value range is completely covered.
"""
dtype_
=
getattr
(
torch
,
dtype
)
if
num_frames
is
None
:
if
dtype
==
'uint8'
:
num_frames
=
256
else
:
num_frames
=
1
<<
16
if
dtype
==
'uint8'
:
base
=
torch
.
linspace
(
0
,
255
,
num_frames
,
dtype
=
dtype_
)
if
dtype
==
'float32'
:
base
=
torch
.
linspace
(
-
1.
,
1.
,
num_frames
,
dtype
=
dtype_
)
if
dtype
==
'int32'
:
base
=
torch
.
linspace
(
-
2147483648
,
2147483647
,
num_frames
,
dtype
=
dtype_
)
if
dtype
==
'int16'
:
base
=
torch
.
linspace
(
-
32768
,
32767
,
num_frames
,
dtype
=
dtype_
)
data
=
base
.
repeat
([
num_channels
,
1
])
if
not
channels_first
:
data
=
data
.
transpose
(
1
,
0
)
if
normalize
:
data
=
normalize_wav
(
data
)
return
data
def
load_wav
(
path
:
str
,
normalize
=
True
,
channels_first
=
True
)
->
torch
.
Tensor
:
"""Load wav file without torchaudio"""
sample_rate
,
data
=
scipy
.
io
.
wavfile
.
read
(
path
)
data
=
torch
.
from_numpy
(
data
.
copy
())
if
data
.
ndim
==
1
:
data
=
data
.
unsqueeze
(
1
)
if
normalize
:
data
=
normalize_wav
(
data
)
if
channels_first
:
data
=
data
.
transpose
(
1
,
0
)
return
data
,
sample_rate
def
save_wav
(
path
,
data
,
sample_rate
,
channels_first
=
True
):
"""Save wav file without torchaudio"""
if
channels_first
:
data
=
data
.
transpose
(
1
,
0
)
scipy
.
io
.
wavfile
.
write
(
path
,
sample_rate
,
data
.
numpy
())
test/sox_io_backend/test_info.py
View file @
a20da5e3
...
@@ -8,13 +8,13 @@ from ..common_utils import (
...
@@ -8,13 +8,13 @@ from ..common_utils import (
PytorchTestCase
,
PytorchTestCase
,
skipIfNoExec
,
skipIfNoExec
,
skipIfNoExtension
,
skipIfNoExtension
,
)
sox_utils
,
from
.common
import
(
get_test_name
,
get_wav_data
,
get_wav_data
,
save_wav
,
save_wav
,
)
)
from
.
import
sox_utils
from
.common
import
(
name_func
,
)
@
skipIfNoExec
(
'sox'
)
@
skipIfNoExec
(
'sox'
)
...
@@ -24,7 +24,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
...
@@ -24,7 +24,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
[
'float32'
,
'int32'
,
'int16'
,
'uint8'
],
[
'float32'
,
'int32'
,
'int16'
,
'uint8'
],
[
8000
,
16000
],
[
8000
,
16000
],
[
1
,
2
],
[
1
,
2
],
)),
name_func
=
get_test_name
)
)),
name_func
=
name_func
)
def
test_wav
(
self
,
dtype
,
sample_rate
,
num_channels
):
def
test_wav
(
self
,
dtype
,
sample_rate
,
num_channels
):
"""`sox_io_backend.info` can check wav file correctly"""
"""`sox_io_backend.info` can check wav file correctly"""
duration
=
1
duration
=
1
...
@@ -40,7 +40,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
...
@@ -40,7 +40,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
[
'float32'
,
'int32'
,
'int16'
,
'uint8'
],
[
'float32'
,
'int32'
,
'int16'
,
'uint8'
],
[
8000
,
16000
],
[
8000
,
16000
],
[
4
,
8
,
16
,
32
],
[
4
,
8
,
16
,
32
],
)),
name_func
=
get_test_name
)
)),
name_func
=
name_func
)
def
test_wav_multiple_channels
(
self
,
dtype
,
sample_rate
,
num_channels
):
def
test_wav_multiple_channels
(
self
,
dtype
,
sample_rate
,
num_channels
):
"""`sox_io_backend.info` can check wav file with channels more than 2 correctly"""
"""`sox_io_backend.info` can check wav file with channels more than 2 correctly"""
duration
=
1
duration
=
1
...
@@ -56,7 +56,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
...
@@ -56,7 +56,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
[
8000
,
16000
],
[
8000
,
16000
],
[
1
,
2
],
[
1
,
2
],
[
96
,
128
,
160
,
192
,
224
,
256
,
320
],
[
96
,
128
,
160
,
192
,
224
,
256
,
320
],
)),
name_func
=
get_test_name
)
)),
name_func
=
name_func
)
def
test_mp3
(
self
,
sample_rate
,
num_channels
,
bit_rate
):
def
test_mp3
(
self
,
sample_rate
,
num_channels
,
bit_rate
):
"""`sox_io_backend.info` can check mp3 file correctly"""
"""`sox_io_backend.info` can check mp3 file correctly"""
duration
=
1
duration
=
1
...
@@ -75,7 +75,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
...
@@ -75,7 +75,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
[
8000
,
16000
],
[
8000
,
16000
],
[
1
,
2
],
[
1
,
2
],
list
(
range
(
9
)),
list
(
range
(
9
)),
)),
name_func
=
get_test_name
)
)),
name_func
=
name_func
)
def
test_flac
(
self
,
sample_rate
,
num_channels
,
compression_level
):
def
test_flac
(
self
,
sample_rate
,
num_channels
,
compression_level
):
"""`sox_io_backend.info` can check flac file correctly"""
"""`sox_io_backend.info` can check flac file correctly"""
duration
=
1
duration
=
1
...
@@ -93,7 +93,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
...
@@ -93,7 +93,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
[
8000
,
16000
],
[
8000
,
16000
],
[
1
,
2
],
[
1
,
2
],
[
-
1
,
0
,
1
,
2
,
3
,
3.6
,
5
,
10
],
[
-
1
,
0
,
1
,
2
,
3
,
3.6
,
5
,
10
],
)),
name_func
=
get_test_name
)
)),
name_func
=
name_func
)
def
test_vorbis
(
self
,
sample_rate
,
num_channels
,
quality_level
):
def
test_vorbis
(
self
,
sample_rate
,
num_channels
,
quality_level
):
"""`sox_io_backend.info` can check vorbis file correctly"""
"""`sox_io_backend.info` can check vorbis file correctly"""
duration
=
1
duration
=
1
...
...
test/sox_io_backend/test_load.py
View file @
a20da5e3
...
@@ -8,14 +8,14 @@ from ..common_utils import (
...
@@ -8,14 +8,14 @@ from ..common_utils import (
PytorchTestCase
,
PytorchTestCase
,
skipIfNoExec
,
skipIfNoExec
,
skipIfNoExtension
,
skipIfNoExtension
,
)
from
.common
import
(
get_test_name
,
get_wav_data
,
get_wav_data
,
load_wav
,
load_wav
,
save_wav
,
save_wav
,
sox_utils
,
)
from
.common
import
(
name_func
,
)
)
from
.
import
sox_utils
class
LoadTestBase
(
TempDirMixin
,
PytorchTestCase
):
class
LoadTestBase
(
TempDirMixin
,
PytorchTestCase
):
...
@@ -129,7 +129,7 @@ class TestLoad(LoadTestBase):
...
@@ -129,7 +129,7 @@ class TestLoad(LoadTestBase):
[
8000
,
16000
],
[
8000
,
16000
],
[
1
,
2
],
[
1
,
2
],
[
False
,
True
],
[
False
,
True
],
)),
name_func
=
get_test_name
)
)),
name_func
=
name_func
)
def
test_wav
(
self
,
dtype
,
sample_rate
,
num_channels
,
normalize
):
def
test_wav
(
self
,
dtype
,
sample_rate
,
num_channels
,
normalize
):
"""`sox_io_backend.load` can load wav format correctly."""
"""`sox_io_backend.load` can load wav format correctly."""
self
.
assert_wav
(
dtype
,
sample_rate
,
num_channels
,
normalize
,
duration
=
1
)
self
.
assert_wav
(
dtype
,
sample_rate
,
num_channels
,
normalize
,
duration
=
1
)
...
@@ -139,7 +139,7 @@ class TestLoad(LoadTestBase):
...
@@ -139,7 +139,7 @@ class TestLoad(LoadTestBase):
[
16000
],
[
16000
],
[
2
],
[
2
],
[
False
],
[
False
],
)),
name_func
=
get_test_name
)
)),
name_func
=
name_func
)
def
test_wav_large
(
self
,
dtype
,
sample_rate
,
num_channels
,
normalize
):
def
test_wav_large
(
self
,
dtype
,
sample_rate
,
num_channels
,
normalize
):
"""`sox_io_backend.load` can load large wav file correctly."""
"""`sox_io_backend.load` can load large wav file correctly."""
two_hours
=
2
*
60
*
60
two_hours
=
2
*
60
*
60
...
@@ -148,7 +148,7 @@ class TestLoad(LoadTestBase):
...
@@ -148,7 +148,7 @@ class TestLoad(LoadTestBase):
@
parameterized
.
expand
(
list
(
itertools
.
product
(
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
'float32'
,
'int32'
,
'int16'
,
'uint8'
],
[
'float32'
,
'int32'
,
'int16'
,
'uint8'
],
[
4
,
8
,
16
,
32
],
[
4
,
8
,
16
,
32
],
)),
name_func
=
get_test_name
)
)),
name_func
=
name_func
)
def
test_multiple_channels
(
self
,
dtype
,
num_channels
):
def
test_multiple_channels
(
self
,
dtype
,
num_channels
):
"""`sox_io_backend.load` can load wav file with more than 2 channels."""
"""`sox_io_backend.load` can load wav file with more than 2 channels."""
sample_rate
=
8000
sample_rate
=
8000
...
@@ -159,7 +159,7 @@ class TestLoad(LoadTestBase):
...
@@ -159,7 +159,7 @@ class TestLoad(LoadTestBase):
[
8000
,
16000
,
44100
],
[
8000
,
16000
,
44100
],
[
1
,
2
],
[
1
,
2
],
[
96
,
128
,
160
,
192
,
224
,
256
,
320
],
[
96
,
128
,
160
,
192
,
224
,
256
,
320
],
)),
name_func
=
get_test_name
)
)),
name_func
=
name_func
)
def
test_mp3
(
self
,
sample_rate
,
num_channels
,
bit_rate
):
def
test_mp3
(
self
,
sample_rate
,
num_channels
,
bit_rate
):
"""`sox_io_backend.load` can load mp3 format correctly."""
"""`sox_io_backend.load` can load mp3 format correctly."""
self
.
assert_mp3
(
sample_rate
,
num_channels
,
bit_rate
,
duration
=
1
)
self
.
assert_mp3
(
sample_rate
,
num_channels
,
bit_rate
,
duration
=
1
)
...
@@ -168,7 +168,7 @@ class TestLoad(LoadTestBase):
...
@@ -168,7 +168,7 @@ class TestLoad(LoadTestBase):
[
16000
],
[
16000
],
[
2
],
[
2
],
[
128
],
[
128
],
)),
name_func
=
get_test_name
)
)),
name_func
=
name_func
)
def
test_mp3_large
(
self
,
sample_rate
,
num_channels
,
bit_rate
):
def
test_mp3_large
(
self
,
sample_rate
,
num_channels
,
bit_rate
):
"""`sox_io_backend.load` can load large mp3 file correctly."""
"""`sox_io_backend.load` can load large mp3 file correctly."""
two_hours
=
2
*
60
*
60
two_hours
=
2
*
60
*
60
...
@@ -178,7 +178,7 @@ class TestLoad(LoadTestBase):
...
@@ -178,7 +178,7 @@ class TestLoad(LoadTestBase):
[
8000
,
16000
],
[
8000
,
16000
],
[
1
,
2
],
[
1
,
2
],
list
(
range
(
9
)),
list
(
range
(
9
)),
)),
name_func
=
get_test_name
)
)),
name_func
=
name_func
)
def
test_flac
(
self
,
sample_rate
,
num_channels
,
compression_level
):
def
test_flac
(
self
,
sample_rate
,
num_channels
,
compression_level
):
"""`sox_io_backend.load` can load flac format correctly."""
"""`sox_io_backend.load` can load flac format correctly."""
self
.
assert_flac
(
sample_rate
,
num_channels
,
compression_level
,
duration
=
1
)
self
.
assert_flac
(
sample_rate
,
num_channels
,
compression_level
,
duration
=
1
)
...
@@ -187,7 +187,7 @@ class TestLoad(LoadTestBase):
...
@@ -187,7 +187,7 @@ class TestLoad(LoadTestBase):
[
16000
],
[
16000
],
[
2
],
[
2
],
[
0
],
[
0
],
)),
name_func
=
get_test_name
)
)),
name_func
=
name_func
)
def
test_flac_large
(
self
,
sample_rate
,
num_channels
,
compression_level
):
def
test_flac_large
(
self
,
sample_rate
,
num_channels
,
compression_level
):
"""`sox_io_backend.load` can load large flac file correctly."""
"""`sox_io_backend.load` can load large flac file correctly."""
two_hours
=
2
*
60
*
60
two_hours
=
2
*
60
*
60
...
@@ -197,7 +197,7 @@ class TestLoad(LoadTestBase):
...
@@ -197,7 +197,7 @@ class TestLoad(LoadTestBase):
[
8000
,
16000
],
[
8000
,
16000
],
[
1
,
2
],
[
1
,
2
],
[
-
1
,
0
,
1
,
2
,
3
,
3.6
,
5
,
10
],
[
-
1
,
0
,
1
,
2
,
3
,
3.6
,
5
,
10
],
)),
name_func
=
get_test_name
)
)),
name_func
=
name_func
)
def
test_vorbis
(
self
,
sample_rate
,
num_channels
,
quality_level
):
def
test_vorbis
(
self
,
sample_rate
,
num_channels
,
quality_level
):
"""`sox_io_backend.load` can load vorbis format correctly."""
"""`sox_io_backend.load` can load vorbis format correctly."""
self
.
assert_vorbis
(
sample_rate
,
num_channels
,
quality_level
,
duration
=
1
)
self
.
assert_vorbis
(
sample_rate
,
num_channels
,
quality_level
,
duration
=
1
)
...
@@ -206,7 +206,7 @@ class TestLoad(LoadTestBase):
...
@@ -206,7 +206,7 @@ class TestLoad(LoadTestBase):
[
16000
],
[
16000
],
[
2
],
[
2
],
[
10
],
[
10
],
)),
name_func
=
get_test_name
)
)),
name_func
=
name_func
)
def
test_vorbis_large
(
self
,
sample_rate
,
num_channels
,
quality_level
):
def
test_vorbis_large
(
self
,
sample_rate
,
num_channels
,
quality_level
):
"""`sox_io_backend.load` can load large vorbis file correctly."""
"""`sox_io_backend.load` can load large vorbis file correctly."""
two_hours
=
2
*
60
*
60
two_hours
=
2
*
60
*
60
...
@@ -230,14 +230,14 @@ class TestLoadParams(TempDirMixin, PytorchTestCase):
...
@@ -230,14 +230,14 @@ class TestLoadParams(TempDirMixin, PytorchTestCase):
@
parameterized
.
expand
(
list
(
itertools
.
product
(
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
0
,
1
,
10
,
100
,
1000
],
[
0
,
1
,
10
,
100
,
1000
],
[
-
1
,
1
,
10
,
100
,
1000
],
[
-
1
,
1
,
10
,
100
,
1000
],
)),
name_func
=
get_test_name
)
)),
name_func
=
name_func
)
def
test_frame
(
self
,
frame_offset
,
num_frames
):
def
test_frame
(
self
,
frame_offset
,
num_frames
):
"""num_frames and frame_offset correctly specify the region of data"""
"""num_frames and frame_offset correctly specify the region of data"""
found
,
_
=
sox_io_backend
.
load
(
self
.
path
,
frame_offset
,
num_frames
)
found
,
_
=
sox_io_backend
.
load
(
self
.
path
,
frame_offset
,
num_frames
)
frame_end
=
None
if
num_frames
==
-
1
else
frame_offset
+
num_frames
frame_end
=
None
if
num_frames
==
-
1
else
frame_offset
+
num_frames
self
.
assertEqual
(
found
,
self
.
original
[:,
frame_offset
:
frame_end
])
self
.
assertEqual
(
found
,
self
.
original
[:,
frame_offset
:
frame_end
])
@
parameterized
.
expand
([(
True
,
),
(
False
,
)],
name_func
=
get_test_name
)
@
parameterized
.
expand
([(
True
,
),
(
False
,
)],
name_func
=
name_func
)
def
test_channels_first
(
self
,
channels_first
):
def
test_channels_first
(
self
,
channels_first
):
"""channels_first swaps axes"""
"""channels_first swaps axes"""
found
,
_
=
sox_io_backend
.
load
(
self
.
path
,
channels_first
=
channels_first
)
found
,
_
=
sox_io_backend
.
load
(
self
.
path
,
channels_first
=
channels_first
)
...
...
test/sox_io_backend/test_roundtrip.py
View file @
a20da5e3
...
@@ -8,10 +8,10 @@ from ..common_utils import (
...
@@ -8,10 +8,10 @@ from ..common_utils import (
PytorchTestCase
,
PytorchTestCase
,
skipIfNoExec
,
skipIfNoExec
,
skipIfNoExtension
,
skipIfNoExtension
,
get_wav_data
,
)
)
from
.common
import
(
from
.common
import
(
get_test_name
,
name_func
,
get_wav_data
,
)
)
...
@@ -23,7 +23,7 @@ class TestRoundTripIO(TempDirMixin, PytorchTestCase):
...
@@ -23,7 +23,7 @@ class TestRoundTripIO(TempDirMixin, PytorchTestCase):
[
'float32'
,
'int32'
,
'int16'
,
'uint8'
],
[
'float32'
,
'int32'
,
'int16'
,
'uint8'
],
[
8000
,
16000
],
[
8000
,
16000
],
[
1
,
2
],
[
1
,
2
],
)),
name_func
=
get_test_name
)
)),
name_func
=
name_func
)
def
test_wav
(
self
,
dtype
,
sample_rate
,
num_channels
):
def
test_wav
(
self
,
dtype
,
sample_rate
,
num_channels
):
"""save/load round trip should not degrade data for wav formats"""
"""save/load round trip should not degrade data for wav formats"""
original
=
get_wav_data
(
dtype
,
num_channels
,
normalize
=
False
)
original
=
get_wav_data
(
dtype
,
num_channels
,
normalize
=
False
)
...
@@ -39,7 +39,7 @@ class TestRoundTripIO(TempDirMixin, PytorchTestCase):
...
@@ -39,7 +39,7 @@ class TestRoundTripIO(TempDirMixin, PytorchTestCase):
[
8000
,
16000
],
[
8000
,
16000
],
[
1
,
2
],
[
1
,
2
],
list
(
range
(
9
)),
list
(
range
(
9
)),
)),
name_func
=
get_test_name
)
)),
name_func
=
name_func
)
def
test_flac
(
self
,
sample_rate
,
num_channels
,
compression_level
):
def
test_flac
(
self
,
sample_rate
,
num_channels
,
compression_level
):
"""save/load round trip should not degrade data for flac formats"""
"""save/load round trip should not degrade data for flac formats"""
original
=
get_wav_data
(
'float32'
,
num_channels
)
original
=
get_wav_data
(
'float32'
,
num_channels
)
...
...
test/sox_io_backend/test_save.py
View file @
a20da5e3
...
@@ -8,14 +8,14 @@ from ..common_utils import (
...
@@ -8,14 +8,14 @@ from ..common_utils import (
PytorchTestCase
,
PytorchTestCase
,
skipIfNoExec
,
skipIfNoExec
,
skipIfNoExtension
,
skipIfNoExtension
,
)
from
.common
import
(
get_test_name
,
get_wav_data
,
get_wav_data
,
load_wav
,
load_wav
,
save_wav
,
save_wav
,
sox_utils
,
)
from
.common
import
(
name_func
,
)
)
from
.
import
sox_utils
class
SaveTestBase
(
TempDirMixin
,
PytorchTestCase
):
class
SaveTestBase
(
TempDirMixin
,
PytorchTestCase
):
...
@@ -176,7 +176,7 @@ class TestSave(SaveTestBase):
...
@@ -176,7 +176,7 @@ class TestSave(SaveTestBase):
[
'float32'
,
'int32'
,
'int16'
,
'uint8'
],
[
'float32'
,
'int32'
,
'int16'
,
'uint8'
],
[
8000
,
16000
],
[
8000
,
16000
],
[
1
,
2
],
[
1
,
2
],
)),
name_func
=
get_test_name
)
)),
name_func
=
name_func
)
def
test_wav
(
self
,
dtype
,
sample_rate
,
num_channels
):
def
test_wav
(
self
,
dtype
,
sample_rate
,
num_channels
):
"""`sox_io_backend.save` can save wav format."""
"""`sox_io_backend.save` can save wav format."""
self
.
assert_wav
(
dtype
,
sample_rate
,
num_channels
,
num_frames
=
None
)
self
.
assert_wav
(
dtype
,
sample_rate
,
num_channels
,
num_frames
=
None
)
...
@@ -185,7 +185,7 @@ class TestSave(SaveTestBase):
...
@@ -185,7 +185,7 @@ class TestSave(SaveTestBase):
[
'float32'
],
[
'float32'
],
[
16000
],
[
16000
],
[
2
],
[
2
],
)),
name_func
=
get_test_name
)
)),
name_func
=
name_func
)
def
test_wav_large
(
self
,
dtype
,
sample_rate
,
num_channels
):
def
test_wav_large
(
self
,
dtype
,
sample_rate
,
num_channels
):
"""`sox_io_backend.save` can save large wav file."""
"""`sox_io_backend.save` can save large wav file."""
two_hours
=
2
*
60
*
60
*
sample_rate
two_hours
=
2
*
60
*
60
*
sample_rate
...
@@ -194,7 +194,7 @@ class TestSave(SaveTestBase):
...
@@ -194,7 +194,7 @@ class TestSave(SaveTestBase):
@
parameterized
.
expand
(
list
(
itertools
.
product
(
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
'float32'
,
'int32'
,
'int16'
,
'uint8'
],
[
'float32'
,
'int32'
,
'int16'
,
'uint8'
],
[
4
,
8
,
16
,
32
],
[
4
,
8
,
16
,
32
],
)),
name_func
=
get_test_name
)
)),
name_func
=
name_func
)
def
test_multiple_channels
(
self
,
dtype
,
num_channels
):
def
test_multiple_channels
(
self
,
dtype
,
num_channels
):
"""`sox_io_backend.save` can save wav with more than 2 channels."""
"""`sox_io_backend.save` can save wav with more than 2 channels."""
sample_rate
=
8000
sample_rate
=
8000
...
@@ -204,7 +204,7 @@ class TestSave(SaveTestBase):
...
@@ -204,7 +204,7 @@ class TestSave(SaveTestBase):
[
8000
,
16000
],
[
8000
,
16000
],
[
1
,
2
],
[
1
,
2
],
[
-
4.2
,
-
0.2
,
0
,
0.2
,
96
,
128
,
160
,
192
,
224
,
256
,
320
],
[
-
4.2
,
-
0.2
,
0
,
0.2
,
96
,
128
,
160
,
192
,
224
,
256
,
320
],
)),
name_func
=
get_test_name
)
)),
name_func
=
name_func
)
def
test_mp3
(
self
,
sample_rate
,
num_channels
,
bit_rate
):
def
test_mp3
(
self
,
sample_rate
,
num_channels
,
bit_rate
):
"""`sox_io_backend.save` can save mp3 format."""
"""`sox_io_backend.save` can save mp3 format."""
self
.
assert_mp3
(
sample_rate
,
num_channels
,
bit_rate
,
duration
=
1
)
self
.
assert_mp3
(
sample_rate
,
num_channels
,
bit_rate
,
duration
=
1
)
...
@@ -213,7 +213,7 @@ class TestSave(SaveTestBase):
...
@@ -213,7 +213,7 @@ class TestSave(SaveTestBase):
[
16000
],
[
16000
],
[
2
],
[
2
],
[
128
],
[
128
],
)),
name_func
=
get_test_name
)
)),
name_func
=
name_func
)
def
test_mp3_large
(
self
,
sample_rate
,
num_channels
,
bit_rate
):
def
test_mp3_large
(
self
,
sample_rate
,
num_channels
,
bit_rate
):
"""`sox_io_backend.save` can save large mp3 file."""
"""`sox_io_backend.save` can save large mp3 file."""
two_hours
=
2
*
60
*
60
two_hours
=
2
*
60
*
60
...
@@ -223,7 +223,7 @@ class TestSave(SaveTestBase):
...
@@ -223,7 +223,7 @@ class TestSave(SaveTestBase):
[
8000
,
16000
],
[
8000
,
16000
],
[
1
,
2
],
[
1
,
2
],
list
(
range
(
9
)),
list
(
range
(
9
)),
)),
name_func
=
get_test_name
)
)),
name_func
=
name_func
)
def
test_flac
(
self
,
sample_rate
,
num_channels
,
compression_level
):
def
test_flac
(
self
,
sample_rate
,
num_channels
,
compression_level
):
"""`sox_io_backend.save` can save flac format."""
"""`sox_io_backend.save` can save flac format."""
self
.
assert_flac
(
sample_rate
,
num_channels
,
compression_level
,
duration
=
1
)
self
.
assert_flac
(
sample_rate
,
num_channels
,
compression_level
,
duration
=
1
)
...
@@ -232,7 +232,7 @@ class TestSave(SaveTestBase):
...
@@ -232,7 +232,7 @@ class TestSave(SaveTestBase):
[
16000
],
[
16000
],
[
2
],
[
2
],
[
0
],
[
0
],
)),
name_func
=
get_test_name
)
)),
name_func
=
name_func
)
def
test_flac_large
(
self
,
sample_rate
,
num_channels
,
compression_level
):
def
test_flac_large
(
self
,
sample_rate
,
num_channels
,
compression_level
):
"""`sox_io_backend.save` can save large flac file."""
"""`sox_io_backend.save` can save large flac file."""
two_hours
=
2
*
60
*
60
two_hours
=
2
*
60
*
60
...
@@ -242,7 +242,7 @@ class TestSave(SaveTestBase):
...
@@ -242,7 +242,7 @@ class TestSave(SaveTestBase):
[
8000
,
16000
],
[
8000
,
16000
],
[
1
,
2
],
[
1
,
2
],
[
-
1
,
0
,
1
,
2
,
3
,
3.6
,
5
,
10
],
[
-
1
,
0
,
1
,
2
,
3
,
3.6
,
5
,
10
],
)),
name_func
=
get_test_name
)
)),
name_func
=
name_func
)
def
test_vorbis
(
self
,
sample_rate
,
num_channels
,
quality_level
):
def
test_vorbis
(
self
,
sample_rate
,
num_channels
,
quality_level
):
"""`sox_io_backend.save` can save vorbis format."""
"""`sox_io_backend.save` can save vorbis format."""
self
.
assert_vorbis
(
sample_rate
,
num_channels
,
quality_level
,
duration
=
20
)
self
.
assert_vorbis
(
sample_rate
,
num_channels
,
quality_level
,
duration
=
20
)
...
@@ -255,7 +255,7 @@ class TestSave(SaveTestBase):
...
@@ -255,7 +255,7 @@ class TestSave(SaveTestBase):
[16000],
[16000],
[2],
[2],
[10],
[10],
)), name_func=
get_test_name
)
)), name_func=
name_func
)
def test_vorbis_large(self, sample_rate, num_channels, quality_level):
def test_vorbis_large(self, sample_rate, num_channels, quality_level):
"""`sox_io_backend.save` can save large vorbis file correctly."""
"""`sox_io_backend.save` can save large vorbis file correctly."""
two_hours = 2 * 60 * 60
two_hours = 2 * 60 * 60
...
@@ -267,7 +267,7 @@ class TestSave(SaveTestBase):
...
@@ -267,7 +267,7 @@ class TestSave(SaveTestBase):
@
skipIfNoExtension
@
skipIfNoExtension
class
TestSaveParams
(
TempDirMixin
,
PytorchTestCase
):
class
TestSaveParams
(
TempDirMixin
,
PytorchTestCase
):
"""Test the correctness of optional parameters of `sox_io_backend.save`"""
"""Test the correctness of optional parameters of `sox_io_backend.save`"""
@
parameterized
.
expand
([(
True
,
),
(
False
,
)],
name_func
=
get_test_name
)
@
parameterized
.
expand
([(
True
,
),
(
False
,
)],
name_func
=
name_func
)
def
test_channels_first
(
self
,
channels_first
):
def
test_channels_first
(
self
,
channels_first
):
"""channels_first swaps axes"""
"""channels_first swaps axes"""
path
=
self
.
get_temp_path
(
'data.wav'
)
path
=
self
.
get_temp_path
(
'data.wav'
)
...
@@ -280,7 +280,7 @@ class TestSaveParams(TempDirMixin, PytorchTestCase):
...
@@ -280,7 +280,7 @@ class TestSaveParams(TempDirMixin, PytorchTestCase):
@
parameterized
.
expand
([
@
parameterized
.
expand
([
'float32'
,
'int32'
,
'int16'
,
'uint8'
'float32'
,
'int32'
,
'int16'
,
'uint8'
],
name_func
=
get_test_name
)
],
name_func
=
name_func
)
def
test_noncontiguous
(
self
,
dtype
):
def
test_noncontiguous
(
self
,
dtype
):
"""Noncontiguous tensors are saved correctly"""
"""Noncontiguous tensors are saved correctly"""
path
=
self
.
get_temp_path
(
'data.wav'
)
path
=
self
.
get_temp_path
(
'data.wav'
)
...
...
test/sox_io_backend/test_torchscript.py
View file @
a20da5e3
...
@@ -10,14 +10,14 @@ from ..common_utils import (
...
@@ -10,14 +10,14 @@ from ..common_utils import (
TorchaudioTestCase
,
TorchaudioTestCase
,
skipIfNoExec
,
skipIfNoExec
,
skipIfNoExtension
,
skipIfNoExtension
,
)
from
.common
import
(
get_test_name
,
get_wav_data
,
get_wav_data
,
save_wav
,
save_wav
,
load_wav
,
load_wav
,
sox_utils
,
)
from
.common
import
(
name_func
,
)
)
from
.
import
sox_utils
def
py_info_func
(
filepath
:
str
)
->
torch
.
classes
.
torchaudio
.
SignalInfo
:
def
py_info_func
(
filepath
:
str
)
->
torch
.
classes
.
torchaudio
.
SignalInfo
:
...
@@ -47,7 +47,7 @@ class SoxIO(TempDirMixin, TorchaudioTestCase):
...
@@ -47,7 +47,7 @@ class SoxIO(TempDirMixin, TorchaudioTestCase):
[
'float32'
,
'int32'
,
'int16'
,
'uint8'
],
[
'float32'
,
'int32'
,
'int16'
,
'uint8'
],
[
8000
,
16000
],
[
8000
,
16000
],
[
1
,
2
],
[
1
,
2
],
)),
name_func
=
get_test_name
)
)),
name_func
=
name_func
)
def
test_info_wav
(
self
,
dtype
,
sample_rate
,
num_channels
):
def
test_info_wav
(
self
,
dtype
,
sample_rate
,
num_channels
):
"""`sox_io_backend.info` is torchscript-able and returns the same result"""
"""`sox_io_backend.info` is torchscript-able and returns the same result"""
audio_path
=
self
.
get_temp_path
(
f
'
{
dtype
}
_
{
sample_rate
}
_
{
num_channels
}
.wav'
)
audio_path
=
self
.
get_temp_path
(
f
'
{
dtype
}
_
{
sample_rate
}
_
{
num_channels
}
.wav'
)
...
@@ -71,7 +71,7 @@ class SoxIO(TempDirMixin, TorchaudioTestCase):
...
@@ -71,7 +71,7 @@ class SoxIO(TempDirMixin, TorchaudioTestCase):
[
1
,
2
],
[
1
,
2
],
[
False
,
True
],
[
False
,
True
],
[
False
,
True
],
[
False
,
True
],
)),
name_func
=
get_test_name
)
)),
name_func
=
name_func
)
def
test_load_wav
(
self
,
dtype
,
sample_rate
,
num_channels
,
normalize
,
channels_first
):
def
test_load_wav
(
self
,
dtype
,
sample_rate
,
num_channels
,
normalize
,
channels_first
):
"""`sox_io_backend.load` is torchscript-able and returns the same result"""
"""`sox_io_backend.load` is torchscript-able and returns the same result"""
audio_path
=
self
.
get_temp_path
(
f
'test_load_
{
dtype
}
_
{
sample_rate
}
_
{
num_channels
}
_
{
normalize
}
.wav'
)
audio_path
=
self
.
get_temp_path
(
f
'test_load_
{
dtype
}
_
{
sample_rate
}
_
{
num_channels
}
_
{
normalize
}
.wav'
)
...
@@ -94,7 +94,7 @@ class SoxIO(TempDirMixin, TorchaudioTestCase):
...
@@ -94,7 +94,7 @@ class SoxIO(TempDirMixin, TorchaudioTestCase):
[
'float32'
,
'int32'
,
'int16'
,
'uint8'
],
[
'float32'
,
'int32'
,
'int16'
,
'uint8'
],
[
8000
,
16000
],
[
8000
,
16000
],
[
1
,
2
],
[
1
,
2
],
)),
name_func
=
get_test_name
)
)),
name_func
=
name_func
)
def
test_save_wav
(
self
,
dtype
,
sample_rate
,
num_channels
):
def
test_save_wav
(
self
,
dtype
,
sample_rate
,
num_channels
):
script_path
=
self
.
get_temp_path
(
'save_func.zip'
)
script_path
=
self
.
get_temp_path
(
'save_func.zip'
)
torch
.
jit
.
script
(
py_save_func
).
save
(
script_path
)
torch
.
jit
.
script
(
py_save_func
).
save
(
script_path
)
...
@@ -119,7 +119,7 @@ class SoxIO(TempDirMixin, TorchaudioTestCase):
...
@@ -119,7 +119,7 @@ class SoxIO(TempDirMixin, TorchaudioTestCase):
[
8000
,
16000
],
[
8000
,
16000
],
[
1
,
2
],
[
1
,
2
],
list
(
range
(
9
)),
list
(
range
(
9
)),
)),
name_func
=
get_test_name
)
)),
name_func
=
name_func
)
def
test_save_flac
(
self
,
sample_rate
,
num_channels
,
compression_level
):
def
test_save_flac
(
self
,
sample_rate
,
num_channels
,
compression_level
):
script_path
=
self
.
get_temp_path
(
'save_func.zip'
)
script_path
=
self
.
get_temp_path
(
'save_func.zip'
)
torch
.
jit
.
script
(
py_save_func
).
save
(
script_path
)
torch
.
jit
.
script
(
py_save_func
).
save
(
script_path
)
...
...
test/test_io.py
View file @
a20da5e3
import
os
import
os
import
math
import
math
import
shutil
import
tempfile
import
unittest
import
unittest
import
torch
import
torch
import
torchaudio
import
torchaudio
from
.common_utils
import
BACKENDS
,
BACKENDS_MP3
,
create_temp_assets_dir
from
.common_utils
import
BACKENDS
,
BACKENDS_MP3
,
get_asset_path
def create_temp_assets_dir():
    """Copy the test assets into a freshly created temporary directory.

    The directory tree located at ``get_asset_path()`` is copied (the
    originals are left in place) into an ``assets`` sub-directory of a new
    ``tempfile.TemporaryDirectory``.

    Returns:
        A ``(path, handle)`` tuple: the temporary directory's path as a
        string, and the ``TemporaryDirectory`` object itself.
    """
    handle = tempfile.TemporaryDirectory()
    root = handle.name

    source = get_asset_path()
    destination = os.path.join(root, "assets")
    shutil.copytree(source, destination)

    # The handle is returned alongside the path so the caller decides how
    # long the directory lives (e.g. by calling ``handle.cleanup()``).
    return root, handle
class
Test_LoadSave
(
unittest
.
TestCase
):
class
Test_LoadSave
(
unittest
.
TestCase
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment