Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
hehl2
Torchaudio
Commits
a6efd497
"vscode:/vscode.git/clone" did not exist on "a06af061f7df58320419b8958ec9730c1d4eba40"
Unverified
Commit
a6efd497
authored
Aug 11, 2020
by
moto
Committed by
GitHub
Aug 11, 2020
Browse files
Add SPHERE format support (#871)
parent
2a6b6b55
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
93 additions
and
1 deletion
+93
-1
test/torchaudio_unittest/sox_io_backend/info_test.py
test/torchaudio_unittest/sox_io_backend/info_test.py
+14
-0
test/torchaudio_unittest/sox_io_backend/load_test.py
test/torchaudio_unittest/sox_io_backend/load_test.py
+30
-0
test/torchaudio_unittest/sox_io_backend/save_test.py
test/torchaudio_unittest/sox_io_backend/save_test.py
+40
-0
torchaudio/backend/sox_io_backend.py
torchaudio/backend/sox_io_backend.py
+3
-1
torchaudio/csrc/sox_utils.cpp
torchaudio/csrc/sox_utils.cpp
+6
-0
No files found.
test/torchaudio_unittest/sox_io_backend/info_test.py
View file @
a6efd497
...
@@ -108,6 +108,20 @@ class TestInfo(TempDirMixin, PytorchTestCase):
...
@@ -108,6 +108,20 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_channels
==
num_channels
assert
info
.
num_channels
==
num_channels
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
8000
,
16000
],
[
1
,
2
],
)),
name_func
=
name_func
)
def
test_sphere
(
self
,
sample_rate
,
num_channels
):
"""`sox_io_backend.info` can check sph file correctly"""
duration
=
1
path
=
self
.
get_temp_path
(
'data.sph'
)
sox_utils
.
gen_audio_file
(
path
,
sample_rate
,
num_channels
,
duration
=
duration
)
info
=
sox_io_backend
.
info
(
path
)
assert
info
.
sample_rate
==
sample_rate
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_channels
==
num_channels
@
skipIfNoExtension
@
skipIfNoExtension
class
TestInfoOpus
(
PytorchTestCase
):
class
TestInfoOpus
(
PytorchTestCase
):
...
...
test/torchaudio_unittest/sox_io_backend/load_test.py
View file @
a6efd497
...
@@ -120,6 +120,28 @@ class LoadTestBase(TempDirMixin, PytorchTestCase):
...
@@ -120,6 +120,28 @@ class LoadTestBase(TempDirMixin, PytorchTestCase):
assert
sr
==
sample_rate
assert
sr
==
sample_rate
self
.
assertEqual
(
data
,
data_ref
,
atol
=
4e-05
,
rtol
=
1.3e-06
)
self
.
assertEqual
(
data
,
data_ref
,
atol
=
4e-05
,
rtol
=
1.3e-06
)
def
assert_sphere
(
self
,
sample_rate
,
num_channels
,
duration
):
"""`sox_io_backend.load` can load sph format.
This test takes the same strategy as mp3 to compare the result
"""
path
=
self
.
get_temp_path
(
'1.original.sph'
)
ref_path
=
self
.
get_temp_path
(
'2.reference.wav'
)
# 1. Generate sph with sox
sox_utils
.
gen_audio_file
(
path
,
sample_rate
,
num_channels
,
bit_depth
=
32
,
duration
=
duration
)
# 2. Convert to wav with sox
sox_utils
.
convert_audio_file
(
path
,
ref_path
)
# 3. Load sph with torchaudio
data
,
sr
=
sox_io_backend
.
load
(
path
)
# 4. Load wav with scipy
data_ref
=
load_wav
(
ref_path
)[
0
]
# 5. Compare
assert
sr
==
sample_rate
self
.
assertEqual
(
data
,
data_ref
,
atol
=
4e-05
,
rtol
=
1.3e-06
)
@
skipIfNoExec
(
'sox'
)
@
skipIfNoExec
(
'sox'
)
@
skipIfNoExtension
@
skipIfNoExtension
...
@@ -230,6 +252,14 @@ class TestLoad(LoadTestBase):
...
@@ -230,6 +252,14 @@ class TestLoad(LoadTestBase):
assert
sample_rate
==
sr
assert
sample_rate
==
sr
self
.
assertEqual
(
expected
,
found
)
self
.
assertEqual
(
expected
,
found
)
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
8000
,
16000
],
[
1
,
2
],
)),
name_func
=
name_func
)
def
test_sphere
(
self
,
sample_rate
,
num_channels
):
"""`sox_io_backend.load` can load sph format correctly."""
self
.
assert_sphere
(
sample_rate
,
num_channels
,
duration
=
1
)
@
skipIfNoExec
(
'sox'
)
@
skipIfNoExec
(
'sox'
)
@
skipIfNoExtension
@
skipIfNoExtension
...
...
test/torchaudio_unittest/sox_io_backend/save_test.py
View file @
a6efd497
...
@@ -168,6 +168,38 @@ class SaveTestBase(TempDirMixin, PytorchTestCase):
...
@@ -168,6 +168,38 @@ class SaveTestBase(TempDirMixin, PytorchTestCase):
else
:
else
:
raise
error
raise
error
def
assert_sphere
(
self
,
sample_rate
,
num_channels
,
duration
):
"""`sox_io_backend.save` can save sph format.
This test takes the same strategy as mp3 to compare the result
"""
src_path
=
self
.
get_temp_path
(
'1.reference.wav'
)
flc_path
=
self
.
get_temp_path
(
'2.1.torchaudio.sph'
)
wav_path
=
self
.
get_temp_path
(
'2.2.torchaudio.wav'
)
flc_path_sox
=
self
.
get_temp_path
(
'3.1.sox.sph'
)
wav_path_sox
=
self
.
get_temp_path
(
'3.2.sox.wav'
)
# 1. Generate original wav
data
=
get_wav_data
(
'float32'
,
num_channels
,
normalize
=
True
,
num_frames
=
duration
*
sample_rate
)
save_wav
(
src_path
,
data
,
sample_rate
)
# 2.1. Convert the original wav to sph with torchaudio
sox_io_backend
.
save
(
flc_path
,
load_wav
(
src_path
)[
0
],
sample_rate
)
# 2.2. Convert the sph to wav with Sox
# converting to 32 bit because sph file has 24 bit depth which scipy cannot handle.
sox_utils
.
convert_audio_file
(
flc_path
,
wav_path
,
bit_depth
=
32
)
# 2.3. Load
found
=
load_wav
(
wav_path
)[
0
]
# 3.1. Convert the original wav to sph with SoX
sox_utils
.
convert_audio_file
(
src_path
,
flc_path_sox
)
# 3.2. Convert the sph to wav with Sox
# converting to 32 bit because sph file has 24 bit depth which scipy cannot handle.
sox_utils
.
convert_audio_file
(
flc_path_sox
,
wav_path_sox
,
bit_depth
=
32
)
# 3.3. Load
expected
=
load_wav
(
wav_path_sox
)[
0
]
self
.
assertEqual
(
found
,
expected
)
@
skipIfNoExec
(
'sox'
)
@
skipIfNoExec
(
'sox'
)
@
skipIfNoExtension
@
skipIfNoExtension
...
@@ -262,6 +294,14 @@ class TestSave(SaveTestBase):
...
@@ -262,6 +294,14 @@ class TestSave(SaveTestBase):
self.assert_vorbis(sample_rate, num_channels, quality_level, two_hours)
self.assert_vorbis(sample_rate, num_channels, quality_level, two_hours)
'''
'''
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
8000
,
16000
],
[
1
,
2
],
)),
name_func
=
name_func
)
def
test_sphere
(
self
,
sample_rate
,
num_channels
):
"""`sox_io_backend.save` can save sph format."""
self
.
assert_sphere
(
sample_rate
,
num_channels
,
duration
=
1
)
@
skipIfNoExec
(
'sox'
)
@
skipIfNoExec
(
'sox'
)
@
skipIfNoExtension
@
skipIfNoExtension
...
...
torchaudio/backend/sox_io_backend.py
View file @
a6efd497
...
@@ -58,6 +58,7 @@ def load(
...
@@ -58,6 +58,7 @@ def load(
* FLAC
* FLAC
* OGG/VORBIS
* OGG/VORBIS
* OPUS
* OPUS
* SPHERE
To load ``MP3``, ``FLAC``, ``OGG/VORBIS``, ``OPUS`` and other codecs ``libsox`` does not
To load ``MP3``, ``FLAC``, ``OGG/VORBIS``, ``OPUS`` and other codecs ``libsox`` does not
handle natively, your installation of ``torchaudio`` has to be linked to ``libsox``
handle natively, your installation of ``torchaudio`` has to be linked to ``libsox``
...
@@ -132,6 +133,7 @@ def save(
...
@@ -132,6 +133,7 @@ def save(
* MP3
* MP3
* FLAC
* FLAC
* OGG/VORBIS
* OGG/VORBIS
* SPHERE
To save ``MP3``, ``FLAC``, ``OGG/VORBIS``, and other codecs ``libsox`` does not
To save ``MP3``, ``FLAC``, ``OGG/VORBIS``, and other codecs ``libsox`` does not
handle natively, your installation of ``torchaudio`` has to be linked to ``libsox``
handle natively, your installation of ``torchaudio`` has to be linked to ``libsox``
...
@@ -158,7 +160,7 @@ def save(
...
@@ -158,7 +160,7 @@ def save(
"""
"""
if
compression
is
None
:
if
compression
is
None
:
ext
=
str
(
filepath
)[
-
3
:].
lower
()
ext
=
str
(
filepath
)[
-
3
:].
lower
()
if
ext
==
'wav'
:
if
ext
in
[
'wav'
,
'sph'
]
:
compression
=
0.
compression
=
0.
elif
ext
==
'mp3'
:
elif
ext
==
'mp3'
:
compression
=
-
4.5
compression
=
-
4.5
...
...
torchaudio/csrc/sox_utils.cpp
View file @
a6efd497
...
@@ -234,6 +234,8 @@ sox_encoding_t get_encoding(
...
@@ -234,6 +234,8 @@ sox_encoding_t get_encoding(
return
SOX_ENCODING_FLOAT
;
return
SOX_ENCODING_FLOAT
;
throw
std
::
runtime_error
(
"Unsupported dtype."
);
throw
std
::
runtime_error
(
"Unsupported dtype."
);
}
}
if
(
filetype
==
"sph"
)
return
SOX_ENCODING_SIGN2
;
throw
std
::
runtime_error
(
"Unsupported file type."
);
throw
std
::
runtime_error
(
"Unsupported file type."
);
}
}
...
@@ -257,6 +259,8 @@ unsigned get_precision(
...
@@ -257,6 +259,8 @@ unsigned get_precision(
return
32
;
return
32
;
throw
std
::
runtime_error
(
"Unsupported dtype."
);
throw
std
::
runtime_error
(
"Unsupported dtype."
);
}
}
if
(
filetype
==
"sph"
)
return
32
;
throw
std
::
runtime_error
(
"Unsupported file type."
);
throw
std
::
runtime_error
(
"Unsupported file type."
);
}
}
...
@@ -285,6 +289,8 @@ sox_encodinginfo_t get_encodinginfo(
...
@@ -285,6 +289,8 @@ sox_encodinginfo_t get_encodinginfo(
return
compression
;
return
compression
;
if
(
filetype
==
"wav"
)
if
(
filetype
==
"wav"
)
return
0.
;
return
0.
;
if
(
filetype
==
"sph"
)
return
0.
;
throw
std
::
runtime_error
(
"Unsupported file type."
);
throw
std
::
runtime_error
(
"Unsupported file type."
);
}();
}();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment