Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
a6efd497
Unverified
Commit
a6efd497
authored
Aug 11, 2020
by
moto
Committed by
GitHub
Aug 11, 2020
Browse files
Add SPHERE format support (#871)
parent
2a6b6b55
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
93 additions
and
1 deletion
+93
-1
test/torchaudio_unittest/sox_io_backend/info_test.py
test/torchaudio_unittest/sox_io_backend/info_test.py
+14
-0
test/torchaudio_unittest/sox_io_backend/load_test.py
test/torchaudio_unittest/sox_io_backend/load_test.py
+30
-0
test/torchaudio_unittest/sox_io_backend/save_test.py
test/torchaudio_unittest/sox_io_backend/save_test.py
+40
-0
torchaudio/backend/sox_io_backend.py
torchaudio/backend/sox_io_backend.py
+3
-1
torchaudio/csrc/sox_utils.cpp
torchaudio/csrc/sox_utils.cpp
+6
-0
No files found.
test/torchaudio_unittest/sox_io_backend/info_test.py
View file @
a6efd497
...
...
@@ -108,6 +108,20 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_channels
==
num_channels
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
8000
,
16000
],
[
1
,
2
],
)),
name_func
=
name_func
)
def
test_sphere
(
self
,
sample_rate
,
num_channels
):
"""`sox_io_backend.info` can check sph file correctly"""
duration
=
1
path
=
self
.
get_temp_path
(
'data.sph'
)
sox_utils
.
gen_audio_file
(
path
,
sample_rate
,
num_channels
,
duration
=
duration
)
info
=
sox_io_backend
.
info
(
path
)
assert
info
.
sample_rate
==
sample_rate
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_channels
==
num_channels
@
skipIfNoExtension
class
TestInfoOpus
(
PytorchTestCase
):
...
...
test/torchaudio_unittest/sox_io_backend/load_test.py
View file @
a6efd497
...
...
@@ -120,6 +120,28 @@ class LoadTestBase(TempDirMixin, PytorchTestCase):
assert
sr
==
sample_rate
self
.
assertEqual
(
data
,
data_ref
,
atol
=
4e-05
,
rtol
=
1.3e-06
)
def
assert_sphere
(
self
,
sample_rate
,
num_channels
,
duration
):
"""`sox_io_backend.load` can load sph format.
This test takes the same strategy as mp3 to compare the result
"""
path
=
self
.
get_temp_path
(
'1.original.sph'
)
ref_path
=
self
.
get_temp_path
(
'2.reference.wav'
)
# 1. Generate sph with sox
sox_utils
.
gen_audio_file
(
path
,
sample_rate
,
num_channels
,
bit_depth
=
32
,
duration
=
duration
)
# 2. Convert to wav with sox
sox_utils
.
convert_audio_file
(
path
,
ref_path
)
# 3. Load sph with torchaudio
data
,
sr
=
sox_io_backend
.
load
(
path
)
# 4. Load wav with scipy
data_ref
=
load_wav
(
ref_path
)[
0
]
# 5. Compare
assert
sr
==
sample_rate
self
.
assertEqual
(
data
,
data_ref
,
atol
=
4e-05
,
rtol
=
1.3e-06
)
@
skipIfNoExec
(
'sox'
)
@
skipIfNoExtension
...
...
@@ -230,6 +252,14 @@ class TestLoad(LoadTestBase):
assert
sample_rate
==
sr
self
.
assertEqual
(
expected
,
found
)
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
8000
,
16000
],
[
1
,
2
],
)),
name_func
=
name_func
)
def
test_sphere
(
self
,
sample_rate
,
num_channels
):
"""`sox_io_backend.load` can load sph format correctly."""
self
.
assert_sphere
(
sample_rate
,
num_channels
,
duration
=
1
)
@
skipIfNoExec
(
'sox'
)
@
skipIfNoExtension
...
...
test/torchaudio_unittest/sox_io_backend/save_test.py
View file @
a6efd497
...
...
@@ -168,6 +168,38 @@ class SaveTestBase(TempDirMixin, PytorchTestCase):
else
:
raise
error
def
assert_sphere
(
self
,
sample_rate
,
num_channels
,
duration
):
"""`sox_io_backend.save` can save sph format.
This test takes the same strategy as mp3 to compare the result
"""
src_path
=
self
.
get_temp_path
(
'1.reference.wav'
)
flc_path
=
self
.
get_temp_path
(
'2.1.torchaudio.sph'
)
wav_path
=
self
.
get_temp_path
(
'2.2.torchaudio.wav'
)
flc_path_sox
=
self
.
get_temp_path
(
'3.1.sox.sph'
)
wav_path_sox
=
self
.
get_temp_path
(
'3.2.sox.wav'
)
# 1. Generate original wav
data
=
get_wav_data
(
'float32'
,
num_channels
,
normalize
=
True
,
num_frames
=
duration
*
sample_rate
)
save_wav
(
src_path
,
data
,
sample_rate
)
# 2.1. Convert the original wav to sph with torchaudio
sox_io_backend
.
save
(
flc_path
,
load_wav
(
src_path
)[
0
],
sample_rate
)
# 2.2. Convert the sph to wav with Sox
# converting to 32 bit because sph file has 24 bit depth which scipy cannot handle.
sox_utils
.
convert_audio_file
(
flc_path
,
wav_path
,
bit_depth
=
32
)
# 2.3. Load
found
=
load_wav
(
wav_path
)[
0
]
# 3.1. Convert the original wav to sph with SoX
sox_utils
.
convert_audio_file
(
src_path
,
flc_path_sox
)
# 3.2. Convert the sph to wav with Sox
# converting to 32 bit because sph file has 24 bit depth which scipy cannot handle.
sox_utils
.
convert_audio_file
(
flc_path_sox
,
wav_path_sox
,
bit_depth
=
32
)
# 3.3. Load
expected
=
load_wav
(
wav_path_sox
)[
0
]
self
.
assertEqual
(
found
,
expected
)
@
skipIfNoExec
(
'sox'
)
@
skipIfNoExtension
...
...
@@ -262,6 +294,14 @@ class TestSave(SaveTestBase):
self.assert_vorbis(sample_rate, num_channels, quality_level, two_hours)
'''
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
8000
,
16000
],
[
1
,
2
],
)),
name_func
=
name_func
)
def
test_sphere
(
self
,
sample_rate
,
num_channels
):
"""`sox_io_backend.save` can save sph format."""
self
.
assert_sphere
(
sample_rate
,
num_channels
,
duration
=
1
)
@
skipIfNoExec
(
'sox'
)
@
skipIfNoExtension
...
...
torchaudio/backend/sox_io_backend.py
View file @
a6efd497
...
...
@@ -58,6 +58,7 @@ def load(
* FLAC
* OGG/VORBIS
* OPUS
* SPHERE
To load ``MP3``, ``FLAC``, ``OGG/VORBIS``, ``OPUS`` and other codecs ``libsox`` does not
handle natively, your installation of ``torchaudio`` has to be linked to ``libsox``
...
...
@@ -132,6 +133,7 @@ def save(
* MP3
* FLAC
* OGG/VORBIS
* SPHERE
To save ``MP3``, ``FLAC``, ``OGG/VORBIS``, and other codecs ``libsox`` does not
handle natively, your installation of ``torchaudio`` has to be linked to ``libsox``
...
...
@@ -158,7 +160,7 @@ def save(
"""
if
compression
is
None
:
ext
=
str
(
filepath
)[
-
3
:].
lower
()
if
ext
==
'wav'
:
if
ext
in
[
'wav'
,
'sph'
]
:
compression
=
0.
elif
ext
==
'mp3'
:
compression
=
-
4.5
...
...
torchaudio/csrc/sox_utils.cpp
View file @
a6efd497
...
...
@@ -234,6 +234,8 @@ sox_encoding_t get_encoding(
return
SOX_ENCODING_FLOAT
;
throw
std
::
runtime_error
(
"Unsupported dtype."
);
}
if
(
filetype
==
"sph"
)
return
SOX_ENCODING_SIGN2
;
throw
std
::
runtime_error
(
"Unsupported file type."
);
}
...
...
@@ -257,6 +259,8 @@ unsigned get_precision(
return
32
;
throw
std
::
runtime_error
(
"Unsupported dtype."
);
}
if
(
filetype
==
"sph"
)
return
32
;
throw
std
::
runtime_error
(
"Unsupported file type."
);
}
...
...
@@ -285,6 +289,8 @@ sox_encodinginfo_t get_encodinginfo(
return
compression
;
if
(
filetype
==
"wav"
)
return
0.
;
if
(
filetype
==
"sph"
)
return
0.
;
throw
std
::
runtime_error
(
"Unsupported file type."
);
}();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment