Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
80a8739a
Unverified
Commit
80a8739a
authored
Mar 17, 2021
by
Caroline Chen
Committed by
GitHub
Mar 17, 2021
Browse files
Refactor sox_io load_test (#1394)
parent
6bad3a66
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
58 additions
and
194 deletions
+58
-194
test/torchaudio_unittest/backend/sox_io/load_test.py
test/torchaudio_unittest/backend/sox_io/load_test.py
+58
-194
No files found.
test/torchaudio_unittest/backend/sox_io/load_test.py
View file @
80a8739a
...
@@ -29,217 +29,76 @@ if _mod_utils.is_module_available("requests"):
...
@@ -29,217 +29,76 @@ if _mod_utils.is_module_available("requests"):
class
LoadTestBase
(
TempDirMixin
,
PytorchTestCase
):
class
LoadTestBase
(
TempDirMixin
,
PytorchTestCase
):
def
assert_wav
(
self
,
dtype
,
sample_rate
,
num_channels
,
normalize
,
duration
):
def
assert_format
(
"""`sox_io_backend.load` can load wav format correctly.
self
,
format
:
str
,
Wav data loaded with sox_io backend should match those with scipy
sample_rate
:
float
,
"""
num_channels
:
int
,
path
=
self
.
get_temp_path
(
'reference.wav'
)
compression
:
float
=
None
,
data
=
get_wav_data
(
dtype
,
num_channels
,
normalize
=
normalize
,
num_frames
=
duration
*
sample_rate
)
bit_depth
:
int
=
None
,
save_wav
(
path
,
data
,
sample_rate
)
duration
:
float
=
1
,
expected
=
load_wav
(
path
,
normalize
=
normalize
)[
0
]
normalize
:
bool
=
True
,
data
,
sr
=
sox_io_backend
.
load
(
path
,
normalize
=
normalize
)
encoding
:
str
=
None
,
assert
sr
==
sample_rate
atol
:
float
=
4e-05
,
self
.
assertEqual
(
data
,
expected
)
rtol
:
float
=
1.3e-06
,
):
def
assert_24bit_wav
(
self
,
sample_rate
,
num_channels
,
normalize
,
duration
):
"""`sox_io_backend.load` can load given format correctly.
""" `sox_io_backend.load` can load 24-bit signed PCM wav format. Since torch does not support the ``int24`` dtype,
we implicitly cast the resulting tensor to the ``int32`` dtype.
file encodings introduce delay and boundary effects so
we create a reference wav file from the original file format
It is not possible to use #assert_wav method above, as #get_wav_data does not support
the 'int24' dtype. This is because torch does not support the ``int24`` dtype.
Hence, we must use the following workaround.
x
|
| 1. Generate 24-bit wav with Sox.
|
v 2. Convert 24-bit wav to 32-bit wav with Sox.
wav(24-bit) ----------------------> wav(32-bit)
| |
| 3. Load 24-bit wav with torchaudio| 4. Load 32-bit wav with scipy
| |
v v
tensor ----------> x <----------- tensor
5. Compare
# Underlying assumptions are:
# i. Sox properly converts from 24-bit to 32-bit
# ii. Loading 32-bit wav file with scipy is correct.
"""
path
=
self
.
get_temp_path
(
'1.original.wav'
)
ref_path
=
self
.
get_temp_path
(
'2.reference.wav'
)
# 1. Generate 24-bit signed wav with Sox
sox_utils
.
gen_audio_file
(
path
,
sample_rate
,
num_channels
,
bit_depth
=
24
,
duration
=
duration
)
# 2. Convert from 24-bit wav to 32-bit wav with sox
sox_utils
.
convert_audio_file
(
path
,
ref_path
,
bit_depth
=
32
)
# 3. Load 24-bit wav with torchaudio
data
,
sr
=
sox_io_backend
.
load
(
path
,
normalize
=
normalize
)
# 4. Load 32-bit wav with scipy
data_ref
=
load_wav
(
ref_path
,
normalize
=
normalize
)[
0
]
# 5. Compare
assert
sr
==
sample_rate
self
.
assertEqual
(
data
,
data_ref
,
atol
=
3e-03
,
rtol
=
1.3e-06
)
def
assert_mp3
(
self
,
sample_rate
,
num_channels
,
bit_rate
,
duration
):
"""`sox_io_backend.load` can load mp3 format.
mp3 encoding introduces delay and boundary effects so
we create reference wav file from mp3
x
x
|
|
| 1. Generate
mp3
with Sox
| 1. Generate
given format
with Sox
|
|
v 2. Convert to wav with Sox
v 2. Convert to wav with Sox
mp3 --------
----------------------> wav
given format
----------------------> wav
| |
| |
| 3. Load with torchaudio
| 4. Load with scipy
|
3. Load with torchaudio | 4. Load with scipy
| |
| |
v v
v v
tensor ----------> x <----------- tensor
tensor ----------> x <----------- tensor
5. Compare
5. Compare
Underlying assumptions are
:
Underlying assumptions are
;
i. Conversion of
mp3
to wav with Sox preserves data.
i. Conversion of
given format
to wav with Sox preserves data.
ii. Loading wav file with scipy is correct.
ii. Loading wav file with scipy is correct.
By combining i & ii, step 2. and 4. allows to load reference
mp3 d
at
a
By combining i & ii, step 2. and 4. allows to load reference
given form
at
without using torchaudio
data
without using torchaudio
"""
"""
path
=
self
.
get_temp_path
(
'1.original.mp3'
)
ref_path
=
self
.
get_temp_path
(
'2.reference.wav'
)
# 1. Generate mp3 with sox
path
=
self
.
get_temp_path
(
f
'1.original.
{
format
}
'
)
sox_utils
.
gen_audio_file
(
path
,
sample_rate
,
num_channels
,
compression
=
bit_rate
,
duration
=
duration
)
# 2. Convert to wav with sox
sox_utils
.
convert_audio_file
(
path
,
ref_path
)
# 3. Load mp3 with torchaudio
data
,
sr
=
sox_io_backend
.
load
(
path
)
# 4. Load wav with scipy
data_ref
=
load_wav
(
ref_path
)[
0
]
# 5. Compare
assert
sr
==
sample_rate
self
.
assertEqual
(
data
,
data_ref
,
atol
=
3e-03
,
rtol
=
1.3e-06
)
def
assert_flac
(
self
,
sample_rate
,
num_channels
,
compression_level
,
duration
):
"""`sox_io_backend.load` can load flac format.
This test takes the same strategy as mp3 to compare the result
"""
path
=
self
.
get_temp_path
(
'1.original.flac'
)
ref_path
=
self
.
get_temp_path
(
'2.reference.wav'
)
ref_path
=
self
.
get_temp_path
(
'2.reference.wav'
)
# 1. Generate
flac
with sox
# 1. Generate
the given format
with sox
sox_utils
.
gen_audio_file
(
sox_utils
.
gen_audio_file
(
path
,
sample_rate
,
num_channels
,
path
,
sample_rate
,
num_channels
,
encoding
=
encoding
,
compression
=
compression_level
,
bit_depth
=
16
,
duration
=
duration
)
compression
=
compression
,
bit_depth
=
bit_depth
,
duration
=
duration
,
)
# 2. Convert to wav with sox
# 2. Convert to wav with sox
sox_utils
.
convert_audio_file
(
path
,
ref_path
)
wav_bit_depth
=
32
if
bit_depth
==
24
else
None
# for 24-bit wav
# 3. Load flac with torchaudio
sox_utils
.
convert_audio_file
(
path
,
ref_path
,
bit_depth
=
wav_bit_depth
)
data
,
sr
=
sox_io_backend
.
load
(
path
)
# 3. Load the given format with torchaudio
# 4. Load wav with scipy
data_ref
=
load_wav
(
ref_path
)[
0
]
# 5. Compare
assert
sr
==
sample_rate
self
.
assertEqual
(
data
,
data_ref
,
atol
=
4e-05
,
rtol
=
1.3e-06
)
def
assert_vorbis
(
self
,
sample_rate
,
num_channels
,
quality_level
,
duration
):
"""`sox_io_backend.load` can load vorbis format.
This test takes the same strategy as mp3 to compare the result
"""
path
=
self
.
get_temp_path
(
'1.original.vorbis'
)
ref_path
=
self
.
get_temp_path
(
'2.reference.wav'
)
# 1. Generate vorbis with sox
sox_utils
.
gen_audio_file
(
path
,
sample_rate
,
num_channels
,
compression
=
quality_level
,
bit_depth
=
16
,
duration
=
duration
)
# 2. Convert to wav with sox
sox_utils
.
convert_audio_file
(
path
,
ref_path
)
# 3. Load vorbis with torchaudio
data
,
sr
=
sox_io_backend
.
load
(
path
)
# 4. Load wav with scipy
data_ref
=
load_wav
(
ref_path
)[
0
]
# 5. Compare
assert
sr
==
sample_rate
self
.
assertEqual
(
data
,
data_ref
,
atol
=
4e-05
,
rtol
=
1.3e-06
)
def
assert_sphere
(
self
,
sample_rate
,
num_channels
,
duration
):
"""`sox_io_backend.load` can load sph format.
This test takes the same strategy as mp3 to compare the result
"""
path
=
self
.
get_temp_path
(
'1.original.sph'
)
ref_path
=
self
.
get_temp_path
(
'2.reference.wav'
)
# 1. Generate sph with sox
sox_utils
.
gen_audio_file
(
path
,
sample_rate
,
num_channels
,
bit_depth
=
32
,
duration
=
duration
)
# 2. Convert to wav with sox
sox_utils
.
convert_audio_file
(
path
,
ref_path
)
# 3. Load sph with torchaudio
data
,
sr
=
sox_io_backend
.
load
(
path
)
# 4. Load wav with scipy
data_ref
=
load_wav
(
ref_path
)[
0
]
# 5. Compare
assert
sr
==
sample_rate
self
.
assertEqual
(
data
,
data_ref
,
atol
=
4e-05
,
rtol
=
1.3e-06
)
def
assert_amb
(
self
,
dtype
,
sample_rate
,
num_channels
,
normalize
,
duration
):
"""`sox_io_backend.load` can load amb format.
This test takes the same strategy as mp3 to compare the result
"""
path
=
self
.
get_temp_path
(
'1.original.amb'
)
ref_path
=
self
.
get_temp_path
(
'2.reference.wav'
)
# 1. Generate amb with sox
sox_utils
.
gen_audio_file
(
path
,
sample_rate
,
num_channels
,
encoding
=
sox_utils
.
get_encoding
(
dtype
),
bit_depth
=
sox_utils
.
get_bit_depth
(
dtype
),
duration
=
duration
)
# 2. Convert to wav with sox
sox_utils
.
convert_audio_file
(
path
,
ref_path
)
# 3. Load amb with torchaudio
data
,
sr
=
sox_io_backend
.
load
(
path
,
normalize
=
normalize
)
data
,
sr
=
sox_io_backend
.
load
(
path
,
normalize
=
normalize
)
# 4. Load wav with scipy
# 4. Load wav with scipy
data_ref
=
load_wav
(
ref_path
,
normalize
=
normalize
)[
0
]
data_ref
=
load_wav
(
ref_path
,
normalize
=
normalize
)[
0
]
# 5. Compare
# 5. Compare
assert
sr
==
sample_rate
assert
sr
==
sample_rate
self
.
assertEqual
(
data
,
data_ref
,
atol
=
4e-05
,
rtol
=
1.3e-06
)
self
.
assertEqual
(
data
,
data_ref
,
atol
=
atol
,
rtol
=
rtol
)
def
assert_
amr_nb
(
self
,
duration
):
def
assert_
wav
(
self
,
dtype
,
sample_rate
,
num_channels
,
normalize
,
duration
):
"""`sox_io_backend.load` can load
amr-nb
format.
"""`sox_io_backend.load` can load
wav
format
correctly
.
This test takes
th
e
s
ame strategy as mp3 to compare the result
Wav data loaded wi
th s
ox_io backend should match those with scipy
"""
"""
sample_rate
=
8000
path
=
self
.
get_temp_path
(
'reference.wav'
)
num_channels
=
1
data
=
get_wav_data
(
dtype
,
num_channels
,
normalize
=
normalize
,
num_frames
=
duration
*
sample_rate
)
path
=
self
.
get_temp_path
(
'1.original.amr-nb'
)
save_wav
(
path
,
data
,
sample_rate
)
ref_path
=
self
.
get_temp_path
(
'2.reference.wav'
)
expected
=
load_wav
(
path
,
normalize
=
normalize
)[
0
]
data
,
sr
=
sox_io_backend
.
load
(
path
,
normalize
=
normalize
)
# 1. Generate amr-nb with sox
sox_utils
.
gen_audio_file
(
path
,
sample_rate
,
num_channels
,
bit_depth
=
32
,
duration
=
duration
)
# 2. Convert to wav with sox
sox_utils
.
convert_audio_file
(
path
,
ref_path
)
# 3. Load amr-nb with torchaudio
data
,
sr
=
sox_io_backend
.
load
(
path
)
# 4. Load wav with scipy
data_ref
=
load_wav
(
ref_path
)[
0
]
# 5. Compare
assert
sr
==
sample_rate
assert
sr
==
sample_rate
self
.
assertEqual
(
data
,
data_ref
,
atol
=
4e-05
,
rtol
=
1.3e-06
)
self
.
assertEqual
(
data
,
expected
)
@
skipIfNoExec
(
'sox'
)
@
skipIfNoExec
(
'sox'
)
...
@@ -263,7 +122,7 @@ class TestLoad(LoadTestBase):
...
@@ -263,7 +122,7 @@ class TestLoad(LoadTestBase):
)),
name_func
=
name_func
)
)),
name_func
=
name_func
)
def
test_24bit_wav
(
self
,
sample_rate
,
num_channels
,
normalize
):
def
test_24bit_wav
(
self
,
sample_rate
,
num_channels
,
normalize
):
"""`sox_io_backend.load` can load 24bit wav format correctly. Corectly casts it to ``int32`` tensor dtype."""
"""`sox_io_backend.load` can load 24bit wav format correctly. Corectly casts it to ``int32`` tensor dtype."""
self
.
assert_
24bit_wav
(
sample_rate
,
num_channels
,
normalize
,
duration
=
1
)
self
.
assert_
format
(
"wav"
,
sample_rate
,
num_channels
,
bit_depth
=
24
,
normalize
=
normalize
,
duration
=
1
)
@
parameterized
.
expand
(
list
(
itertools
.
product
(
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
'int16'
],
[
'int16'
],
...
@@ -293,7 +152,7 @@ class TestLoad(LoadTestBase):
...
@@ -293,7 +152,7 @@ class TestLoad(LoadTestBase):
)),
name_func
=
name_func
)
)),
name_func
=
name_func
)
def
test_mp3
(
self
,
sample_rate
,
num_channels
,
bit_rate
):
def
test_mp3
(
self
,
sample_rate
,
num_channels
,
bit_rate
):
"""`sox_io_backend.load` can load mp3 format correctly."""
"""`sox_io_backend.load` can load mp3 format correctly."""
self
.
assert_
mp3
(
sample_rate
,
num_channels
,
bit_rate
,
duration
=
1
)
self
.
assert_
format
(
"mp3"
,
sample_rate
,
num_channels
,
compression
=
bit_rate
,
duration
=
1
,
atol
=
5e-05
)
@
parameterized
.
expand
(
list
(
itertools
.
product
(
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
16000
],
[
16000
],
...
@@ -303,7 +162,7 @@ class TestLoad(LoadTestBase):
...
@@ -303,7 +162,7 @@ class TestLoad(LoadTestBase):
def
test_mp3_large
(
self
,
sample_rate
,
num_channels
,
bit_rate
):
def
test_mp3_large
(
self
,
sample_rate
,
num_channels
,
bit_rate
):
"""`sox_io_backend.load` can load large mp3 file correctly."""
"""`sox_io_backend.load` can load large mp3 file correctly."""
two_hours
=
2
*
60
*
60
two_hours
=
2
*
60
*
60
self
.
assert_
mp3
(
sample_rate
,
num_channels
,
bit_rate
,
two_hours
)
self
.
assert_
format
(
"mp3"
,
sample_rate
,
num_channels
,
compression
=
bit_rate
,
duration
=
two_hours
,
atol
=
5e-05
)
@
parameterized
.
expand
(
list
(
itertools
.
product
(
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
8000
,
16000
],
[
8000
,
16000
],
...
@@ -312,7 +171,7 @@ class TestLoad(LoadTestBase):
...
@@ -312,7 +171,7 @@ class TestLoad(LoadTestBase):
)),
name_func
=
name_func
)
)),
name_func
=
name_func
)
def
test_flac
(
self
,
sample_rate
,
num_channels
,
compression_level
):
def
test_flac
(
self
,
sample_rate
,
num_channels
,
compression_level
):
"""`sox_io_backend.load` can load flac format correctly."""
"""`sox_io_backend.load` can load flac format correctly."""
self
.
assert_f
lac
(
sample_rate
,
num_channels
,
compression
_level
,
duration
=
1
)
self
.
assert_f
ormat
(
"flac"
,
sample_rate
,
num_channels
,
compression
=
compression_level
,
bit_depth
=
16
,
duration
=
1
)
@
parameterized
.
expand
(
list
(
itertools
.
product
(
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
16000
],
[
16000
],
...
@@ -322,7 +181,8 @@ class TestLoad(LoadTestBase):
...
@@ -322,7 +181,8 @@ class TestLoad(LoadTestBase):
def
test_flac_large
(
self
,
sample_rate
,
num_channels
,
compression_level
):
def
test_flac_large
(
self
,
sample_rate
,
num_channels
,
compression_level
):
"""`sox_io_backend.load` can load large flac file correctly."""
"""`sox_io_backend.load` can load large flac file correctly."""
two_hours
=
2
*
60
*
60
two_hours
=
2
*
60
*
60
self
.
assert_flac
(
sample_rate
,
num_channels
,
compression_level
,
two_hours
)
self
.
assert_format
(
"flac"
,
sample_rate
,
num_channels
,
compression
=
compression_level
,
bit_depth
=
16
,
duration
=
two_hours
)
@
parameterized
.
expand
(
list
(
itertools
.
product
(
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
8000
,
16000
],
[
8000
,
16000
],
...
@@ -331,7 +191,7 @@ class TestLoad(LoadTestBase):
...
@@ -331,7 +191,7 @@ class TestLoad(LoadTestBase):
)),
name_func
=
name_func
)
)),
name_func
=
name_func
)
def
test_vorbis
(
self
,
sample_rate
,
num_channels
,
quality_level
):
def
test_vorbis
(
self
,
sample_rate
,
num_channels
,
quality_level
):
"""`sox_io_backend.load` can load vorbis format correctly."""
"""`sox_io_backend.load` can load vorbis format correctly."""
self
.
assert_vorbis
(
sample_rate
,
num_channels
,
quality_level
,
duration
=
1
)
self
.
assert_
format
(
"
vorbis
"
,
sample_rate
,
num_channels
,
compression
=
quality_level
,
bit_depth
=
16
,
duration
=
1
)
@
parameterized
.
expand
(
list
(
itertools
.
product
(
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
16000
],
[
16000
],
...
@@ -341,7 +201,8 @@ class TestLoad(LoadTestBase):
...
@@ -341,7 +201,8 @@ class TestLoad(LoadTestBase):
def
test_vorbis_large
(
self
,
sample_rate
,
num_channels
,
quality_level
):
def
test_vorbis_large
(
self
,
sample_rate
,
num_channels
,
quality_level
):
"""`sox_io_backend.load` can load large vorbis file correctly."""
"""`sox_io_backend.load` can load large vorbis file correctly."""
two_hours
=
2
*
60
*
60
two_hours
=
2
*
60
*
60
self
.
assert_vorbis
(
sample_rate
,
num_channels
,
quality_level
,
two_hours
)
self
.
assert_format
(
"vorbis"
,
sample_rate
,
num_channels
,
compression
=
quality_level
,
bit_depth
=
16
,
duration
=
two_hours
)
@
parameterized
.
expand
(
list
(
itertools
.
product
(
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
'96k'
],
[
'96k'
],
...
@@ -366,7 +227,7 @@ class TestLoad(LoadTestBase):
...
@@ -366,7 +227,7 @@ class TestLoad(LoadTestBase):
)),
name_func
=
name_func
)
)),
name_func
=
name_func
)
def
test_sphere
(
self
,
sample_rate
,
num_channels
):
def
test_sphere
(
self
,
sample_rate
,
num_channels
):
"""`sox_io_backend.load` can load sph format correctly."""
"""`sox_io_backend.load` can load sph format correctly."""
self
.
assert_
sphere
(
sample_rate
,
num_channels
,
duration
=
1
)
self
.
assert_
format
(
"sph"
,
sample_rate
,
num_channels
,
bit_depth
=
32
,
duration
=
1
)
@
parameterized
.
expand
(
list
(
itertools
.
product
(
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
'float32'
,
'int32'
,
'int16'
],
[
'float32'
,
'int32'
,
'int16'
],
...
@@ -375,12 +236,15 @@ class TestLoad(LoadTestBase):
...
@@ -375,12 +236,15 @@ class TestLoad(LoadTestBase):
[
False
,
True
],
[
False
,
True
],
)),
name_func
=
name_func
)
)),
name_func
=
name_func
)
def
test_amb
(
self
,
dtype
,
sample_rate
,
num_channels
,
normalize
):
def
test_amb
(
self
,
dtype
,
sample_rate
,
num_channels
,
normalize
):
"""`sox_io_backend.load` can load sph format correctly."""
"""`sox_io_backend.load` can load amb format correctly."""
self
.
assert_amb
(
dtype
,
sample_rate
,
num_channels
,
normalize
,
duration
=
1
)
bit_depth
=
sox_utils
.
get_bit_depth
(
dtype
)
encoding
=
sox_utils
.
get_encoding
(
dtype
)
self
.
assert_format
(
"amb"
,
sample_rate
,
num_channels
,
bit_depth
=
bit_depth
,
duration
=
1
,
encoding
=
encoding
,
normalize
=
normalize
)
def
test_amr_nb
(
self
):
def
test_amr_nb
(
self
):
"""`sox_io_backend.load` can load amr_nb format correctly."""
"""`sox_io_backend.load` can load amr_nb format correctly."""
self
.
assert_
amr_nb
(
duration
=
1
)
self
.
assert_
format
(
"amr-nb"
,
sample_rate
=
8000
,
num_channels
=
1
,
bit_depth
=
32
,
duration
=
1
)
@
skipIfNoExec
(
'sox'
)
@
skipIfNoExec
(
'sox'
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment