Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
b152ee61
Unverified
Commit
b152ee61
authored
Jan 29, 2021
by
moto
Committed by
GitHub
Jan 29, 2021
Browse files
Add encoding attribute to AudioMetaData (#1206)
parent
674a71d1
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
327 additions
and
136 deletions
+327
-136
test/torchaudio_unittest/soundfile_backend/info_test.py
test/torchaudio_unittest/soundfile_backend/info_test.py
+12
-22
test/torchaudio_unittest/sox_io_backend/info_test.py
test/torchaudio_unittest/sox_io_backend/info_test.py
+206
-89
torchaudio/backend/_soundfile_backend.py
torchaudio/backend/_soundfile_backend.py
+38
-9
torchaudio/backend/common.py
torchaudio/backend/common.py
+10
-1
torchaudio/backend/sox_io_backend.py
torchaudio/backend/sox_io_backend.py
+5
-6
torchaudio/csrc/sox/io.cpp
torchaudio/csrc/sox/io.cpp
+49
-6
torchaudio/csrc/sox/io.h
torchaudio/csrc/sox/io.h
+5
-2
torchaudio/csrc/sox/register.cpp
torchaudio/csrc/sox/register.cpp
+2
-1
No files found.
test/torchaudio_unittest/soundfile_backend/info_test.py
View file @
b152ee61
...
...
@@ -13,6 +13,8 @@ from torchaudio_unittest.common_utils import (
get_wav_data
,
save_wav
,
)
# TODO refactor and move these to common location
from
torchaudio_unittest.sox_io_backend.info_test
import
get_encoding
,
get_bits_per_sample
from
.common
import
skipIfFormatNotSupported
,
parameterize
if
_mod_utils
.
is_module_available
(
"soundfile"
):
...
...
@@ -22,11 +24,10 @@ if _mod_utils.is_module_available("soundfile"):
@
skipIfNoModule
(
"soundfile"
)
class
TestInfo
(
TempDirMixin
,
PytorchTestCase
):
@
parameterize
(
[
(
"float32"
,
32
),
(
"int32"
,
32
),
(
"int16"
,
16
),
(
"uint8"
,
8
)
],
[
8000
,
16000
],
[
1
,
2
],
[
"float32"
,
"int32"
,
"int16"
,
"uint8"
],
[
8000
,
16000
],
[
1
,
2
],
)
def
test_wav
(
self
,
dtype
_and_bit_depth
,
sample_rate
,
num_channels
):
def
test_wav
(
self
,
dtype
,
sample_rate
,
num_channels
):
"""`soundfile_backend.info` can check wav file correctly"""
dtype
,
bits_per_sample
=
dtype_and_bit_depth
duration
=
1
path
=
self
.
get_temp_path
(
"data.wav"
)
data
=
get_wav_data
(
...
...
@@ -37,25 +38,8 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert
info
.
sample_rate
==
sample_rate
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_channels
==
num_channels
assert
info
.
bits_per_sample
==
bits_per_sample
@
parameterize
(
[(
"float32"
,
32
),
(
"int32"
,
32
),
(
"int16"
,
16
),
(
"uint8"
,
8
)],
[
8000
,
16000
],
[
1
,
2
],
)
def
test_wav_multiple_channels
(
self
,
dtype_and_bit_depth
,
sample_rate
,
num_channels
):
"""`soundfile_backend.info` can check wav file with channels more than 2 correctly"""
dtype
,
bits_per_sample
=
dtype_and_bit_depth
duration
=
1
path
=
self
.
get_temp_path
(
"data.wav"
)
data
=
get_wav_data
(
dtype
,
num_channels
,
normalize
=
False
,
num_frames
=
duration
*
sample_rate
)
save_wav
(
path
,
data
,
sample_rate
)
info
=
soundfile_backend
.
info
(
path
)
assert
info
.
sample_rate
==
sample_rate
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_channels
==
num_channels
assert
info
.
bits_per_sample
==
bits_per_sample
assert
info
.
bits_per_sample
==
get_bits_per_sample
(
"wav"
,
dtype
)
assert
info
.
encoding
==
get_encoding
(
"wav"
,
dtype
)
@
parameterize
([
8000
,
16000
],
[
1
,
2
])
@
skipIfFormatNotSupported
(
"FLAC"
)
...
...
@@ -72,6 +56,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert
info
.
num_frames
==
num_frames
assert
info
.
num_channels
==
num_channels
assert
info
.
bits_per_sample
==
16
assert
info
.
encoding
==
"FLAC"
@
parameterize
([
8000
,
16000
],
[
1
,
2
])
@
skipIfFormatNotSupported
(
"OGG"
)
...
...
@@ -88,6 +73,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_channels
==
num_channels
assert
info
.
bits_per_sample
==
0
assert
info
.
encoding
==
"VORBIS"
@
parameterize
([
8000
,
16000
],
[
1
,
2
],
[(
'PCM_24'
,
24
),
(
'PCM_32'
,
32
)])
@
skipIfFormatNotSupported
(
"NIST"
)
...
...
@@ -105,6 +91,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_channels
==
num_channels
assert
info
.
bits_per_sample
==
bits_per_sample
assert
info
.
encoding
==
"PCM_S"
def
test_unknown_subtype_warning
(
self
):
"""soundfile_backend.info issues a warning when the subtype is unknown
...
...
@@ -118,6 +105,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
frames
=
356
channels
=
2
subtype
=
'UNSEEN_SUBTYPE'
format
=
'UNKNOWN'
return
MockSoundFileInfo
()
with
patch
(
"soundfile.info"
,
_mock_info_func
):
...
...
@@ -147,6 +135,7 @@ class TestFileObject(TempDirMixin, PytorchTestCase):
assert
info
.
num_frames
==
num_frames
assert
info
.
num_channels
==
num_channels
assert
info
.
bits_per_sample
==
bits_per_sample
assert
info
.
encoding
==
"FLAC"
if
ext
==
'flac'
else
"PCM_S"
def
test_fileobj_wav
(
self
):
"""Loading audio via file-like object works"""
...
...
@@ -179,6 +168,7 @@ class TestFileObject(TempDirMixin, PytorchTestCase):
assert
info
.
num_frames
==
num_frames
assert
info
.
num_channels
==
num_channels
assert
info
.
bits_per_sample
==
bits_per_sample
assert
info
.
encoding
==
"FLAC"
if
ext
==
'flac'
else
"PCM_S"
def
test_tarobj_wav
(
self
):
"""Query compressed audio via file-like object works"""
...
...
test/torchaudio_unittest/sox_io_backend/info_test.py
View file @
b152ee61
import
io
import
os
import
itertools
import
tarfile
...
...
@@ -27,6 +28,30 @@ if _mod_utils.is_module_available("requests"):
import
requests
def
get_encoding
(
ext
,
dtype
):
exts
=
{
'mp3'
,
'flac'
,
'vorbis'
,
}
encodings
=
{
'float32'
:
'PCM_F'
,
'int32'
:
'PCM_S'
,
'int16'
:
'PCM_S'
,
'uint8'
:
'PCM_U'
,
}
return
ext
.
upper
()
if
ext
in
exts
else
encodings
[
dtype
]
def
get_bits_per_sample
(
ext
,
dtype
):
bits_per_samples
=
{
'flac'
:
24
,
'mp3'
:
0
,
'vorbis'
:
0
,
}
return
bits_per_samples
.
get
(
ext
,
sox_utils
.
get_bit_depth
(
dtype
))
@
skipIfNoExec
(
'sox'
)
@
skipIfNoExtension
class
TestInfo
(
TempDirMixin
,
PytorchTestCase
):
...
...
@@ -46,6 +71,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_channels
==
num_channels
assert
info
.
bits_per_sample
==
sox_utils
.
get_bit_depth
(
dtype
)
assert
info
.
encoding
==
get_encoding
(
'wav'
,
dtype
)
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
'float32'
,
'int32'
,
'int16'
,
'uint8'
],
...
...
@@ -63,6 +89,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_channels
==
num_channels
assert
info
.
bits_per_sample
==
sox_utils
.
get_bit_depth
(
dtype
)
assert
info
.
encoding
==
get_encoding
(
'wav'
,
dtype
)
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
8000
,
16000
],
...
...
@@ -83,6 +110,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
# assert info.num_frames == sample_rate * duration
assert
info
.
num_channels
==
num_channels
assert
info
.
bits_per_sample
==
0
# bit_per_sample is irrelevant for compressed formats
assert
info
.
encoding
==
"MP3"
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
8000
,
16000
],
...
...
@@ -102,6 +130,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_channels
==
num_channels
assert
info
.
bits_per_sample
==
24
# FLAC standard
assert
info
.
encoding
==
"FLAC"
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
8000
,
16000
],
...
...
@@ -121,6 +150,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_channels
==
num_channels
assert
info
.
bits_per_sample
==
0
# bit_per_sample is irrelevant for compressed formats
assert
info
.
encoding
==
"VORBIS"
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
8000
,
16000
],
...
...
@@ -137,9 +167,10 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_channels
==
num_channels
assert
info
.
bits_per_sample
==
bits_per_sample
assert
info
.
encoding
==
"PCM_S"
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
'float32'
,
'int32'
,
'int16'
,
'uint8'
],
[
'int32'
,
'int16'
,
'uint8'
],
[
8000
,
16000
],
[
1
,
2
],
)),
name_func
=
name_func
)
...
...
@@ -156,6 +187,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_channels
==
num_channels
assert
info
.
bits_per_sample
==
bits_per_sample
assert
info
.
encoding
==
get_encoding
(
"amb"
,
dtype
)
def
test_amr_nb
(
self
):
"""`sox_io_backend.info` can check amr-nb file correctly"""
...
...
@@ -171,6 +203,41 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_channels
==
num_channels
assert
info
.
bits_per_sample
==
0
assert
info
.
encoding
==
"AMR_NB"
def
test_ulaw
(
self
):
"""`sox_io_backend.info` can check ulaw file correctly"""
duration
=
1
num_channels
=
1
sample_rate
=
8000
path
=
self
.
get_temp_path
(
'data.wav'
)
sox_utils
.
gen_audio_file
(
path
,
sample_rate
=
sample_rate
,
num_channels
=
num_channels
,
bit_depth
=
8
,
encoding
=
'u-law'
,
duration
=
duration
)
info
=
sox_io_backend
.
info
(
path
)
assert
info
.
sample_rate
==
sample_rate
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_channels
==
num_channels
assert
info
.
bits_per_sample
==
8
assert
info
.
encoding
==
"ULAW"
def
test_alaw
(
self
):
"""`sox_io_backend.info` can check ulaw file correctly"""
duration
=
1
num_channels
=
1
sample_rate
=
8000
path
=
self
.
get_temp_path
(
'data.wav'
)
sox_utils
.
gen_audio_file
(
path
,
sample_rate
=
sample_rate
,
num_channels
=
num_channels
,
bit_depth
=
8
,
encoding
=
'a-law'
,
duration
=
duration
)
info
=
sox_io_backend
.
info
(
path
)
assert
info
.
sample_rate
==
sample_rate
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_channels
==
num_channels
assert
info
.
bits_per_sample
==
8
assert
info
.
encoding
==
"ALAW"
@
skipIfNoExtension
...
...
@@ -188,6 +255,7 @@ class TestInfoOpus(PytorchTestCase):
assert
info
.
num_frames
==
32768
assert
info
.
num_channels
==
num_channels
assert
info
.
bits_per_sample
==
0
# bit_per_sample is irrelevant for compressed formats
assert
info
.
encoding
==
"OPUS"
@
skipIfNoExtension
...
...
@@ -205,144 +273,193 @@ class TestLoadWithoutExtension(PytorchTestCase):
path
=
get_asset_path
(
"mp3_without_ext"
)
sinfo
=
sox_io_backend
.
info
(
path
,
format
=
"mp3"
)
assert
sinfo
.
sample_rate
==
16000
assert
sinfo
.
num_frames
==
81216
assert
sinfo
.
num_channels
==
1
assert
sinfo
.
bits_per_sample
==
0
# bit_per_sample is irrelevant for compressed formats
assert
sinfo
.
encoding
==
"MP3"
class
FileObjTestBase
(
TempDirMixin
):
def
_gen_file
(
self
,
ext
,
dtype
,
sample_rate
,
num_channels
,
num_frames
):
path
=
self
.
get_temp_path
(
f
'test.
{
ext
}
'
)
bit_depth
=
sox_utils
.
get_bit_depth
(
dtype
)
duration
=
num_frames
/
sample_rate
sox_utils
.
gen_audio_file
(
path
,
sample_rate
,
num_channels
=
num_channels
,
encoding
=
sox_utils
.
get_encoding
(
dtype
),
bit_depth
=
bit_depth
,
duration
=
duration
)
return
path
@
skipIfNoExtension
@
skipIfNoExec
(
'sox'
)
class
TestFileObject
(
TempDirMixin
,
PytorchTestCase
):
class
TestFileObject
(
FileObjTestBase
,
PytorchTestCase
):
def
_query_fileobj
(
self
,
ext
,
dtype
,
sample_rate
,
num_channels
,
num_frames
):
path
=
self
.
_gen_file
(
ext
,
dtype
,
sample_rate
,
num_channels
,
num_frames
)
format_
=
ext
if
ext
in
[
'mp3'
]
else
None
with
open
(
path
,
'rb'
)
as
fileobj
:
return
sox_io_backend
.
info
(
fileobj
,
format_
)
def
_query_bytesio
(
self
,
ext
,
dtype
,
sample_rate
,
num_channels
,
num_frames
):
path
=
self
.
_gen_file
(
ext
,
dtype
,
sample_rate
,
num_channels
,
num_frames
)
format_
=
ext
if
ext
in
[
'mp3'
]
else
None
with
open
(
path
,
'rb'
)
as
file_
:
fileobj
=
io
.
BytesIO
(
file_
.
read
())
return
sox_io_backend
.
info
(
fileobj
,
format_
)
def
_query_tarfile
(
self
,
ext
,
dtype
,
sample_rate
,
num_channels
,
num_frames
):
audio_path
=
self
.
_gen_file
(
ext
,
dtype
,
sample_rate
,
num_channels
,
num_frames
)
audio_file
=
os
.
path
.
basename
(
audio_path
)
archive_path
=
self
.
get_temp_path
(
'archive.tar.gz'
)
with
tarfile
.
TarFile
(
archive_path
,
'w'
)
as
tarobj
:
tarobj
.
add
(
audio_path
,
arcname
=
audio_file
)
format_
=
ext
if
ext
in
[
'mp3'
]
else
None
with
tarfile
.
TarFile
(
archive_path
,
'r'
)
as
tarobj
:
fileobj
=
tarobj
.
extractfile
(
audio_file
)
return
sox_io_backend
.
info
(
fileobj
,
format_
)
@
parameterized
.
expand
([
(
'wav'
,
32
),
(
'mp3'
,
0
),
(
'flac'
,
24
),
(
'vorbis'
,
0
),
(
'amb'
,
32
),
(
'wav'
,
"float32"
),
(
'wav'
,
"int32"
),
(
'wav'
,
"int16"
),
(
'wav'
,
"uint8"
),
(
'mp3'
,
"float32"
),
(
'flac'
,
"float32"
),
(
'vorbis'
,
"float32"
),
(
'amb'
,
"int16"
),
])
def
test_fileobj
(
self
,
ext
,
bits_per_sampl
e
):
def
test_fileobj
(
self
,
ext
,
dtyp
e
):
"""Querying audio via file object works"""
sample_rate
=
16000
num_frames
=
3
*
sample_rate
num_channels
=
2
duration
=
3
format_
=
ext
if
ext
in
[
'mp3'
]
else
None
path
=
self
.
get_temp_path
(
f
'test.
{
ext
}
'
)
sox_utils
.
gen_audio_file
(
path
,
sample_rate
,
num_channels
=
2
,
duration
=
duration
)
sinfo
=
self
.
_query_fileobj
(
ext
,
dtype
,
sample_rate
,
num_channels
,
num_frames
)
w
it
h
open
(
path
,
'rb'
)
as
fileobj
:
sinfo
=
sox_io_backend
.
info
(
fileobj
,
format_
)
b
it
s_per_sample
=
get_bits_per_sample
(
ext
,
dtype
)
num_frames
=
0
if
ext
in
[
'mp3'
,
'vorbis'
]
else
num_frames
assert
sinfo
.
sample_rate
==
sample_rate
assert
sinfo
.
num_channels
==
num_channels
if
ext
not
in
[
'mp3'
,
'vorbis'
]:
# these container formats do not have length info
assert
sinfo
.
num_frames
==
sample_rate
*
duration
assert
sinfo
.
num_frames
==
num_frames
assert
sinfo
.
bits_per_sample
==
bits_per_sample
assert
sinfo
.
encoding
==
get_encoding
(
ext
,
dtype
)
def
_test_bytesio
(
self
,
ext
,
bits_per_sample
,
duration
):
@
parameterized
.
expand
([
(
'wav'
,
"float32"
),
(
'wav'
,
"int32"
),
(
'wav'
,
"int16"
),
(
'wav'
,
"uint8"
),
(
'mp3'
,
"float32"
),
(
'flac'
,
"float32"
),
(
'vorbis'
,
"float32"
),
(
'amb'
,
"int16"
),
])
def
test_bytesio
(
self
,
ext
,
dtype
):
"""Querying audio via ByteIO object works for small data"""
sample_rate
=
16000
num_frames
=
3
*
sample_rate
num_channels
=
2
format_
=
ext
if
ext
in
[
'mp3'
]
else
None
path
=
self
.
get_temp_path
(
f
'test.
{
ext
}
'
)
sinfo
=
self
.
_query_bytesio
(
ext
,
dtype
,
sample_rate
,
num_channels
,
num_frames
)
sox_utils
.
gen_audio_file
(
path
,
sample_rate
,
num_channels
=
2
,
duration
=
duration
)
with
open
(
path
,
'rb'
)
as
file_
:
fileobj
=
io
.
BytesIO
(
file_
.
read
())
sinfo
=
sox_io_backend
.
info
(
fileobj
,
format_
)
bits_per_sample
=
get_bits_per_sample
(
ext
,
dtype
)
num_frames
=
0
if
ext
in
[
'mp3'
,
'vorbis'
]
else
num_frames
assert
sinfo
.
sample_rate
==
sample_rate
assert
sinfo
.
num_channels
==
num_channels
if
ext
not
in
[
'mp3'
,
'vorbis'
]:
# these container formats do not have length info
assert
sinfo
.
num_frames
==
sample_rate
*
duration
assert
sinfo
.
num_frames
==
num_frames
assert
sinfo
.
bits_per_sample
==
bits_per_sample
assert
sinfo
.
encoding
==
get_encoding
(
ext
,
dtype
)
@
parameterized
.
expand
([
(
'wav'
,
32
),
(
'mp3'
,
0
),
(
'flac'
,
24
),
(
'vorbis'
,
0
),
(
'amb'
,
32
),
(
'wav'
,
"float32"
),
(
'wav'
,
"int32"
),
(
'wav'
,
"int16"
),
(
'wav'
,
"uint8"
),
(
'mp3'
,
"float32"
),
(
'flac'
,
"float32"
),
(
'vorbis'
,
"float32"
),
(
'amb'
,
"int16"
),
])
def
test_bytesio
(
self
,
ext
,
bits_per_sample
):
"""Querying audio via ByteIO object works"""
self
.
_test_bytesio
(
ext
,
bits_per_sample
,
duration
=
3
)
@
parameterized
.
expand
([
(
'wav'
,
32
),
(
'mp3'
,
0
),
(
'flac'
,
24
),
(
'vorbis'
,
0
),
(
'amb'
,
32
),
])
def
test_bytesio_tiny
(
self
,
ext
,
bits_per_sample
):
def
test_bytesio_tiny
(
self
,
ext
,
dtype
):
"""Querying audio via ByteIO object works for small data"""
self
.
_test_bytesio
(
ext
,
bits_per_sample
,
duration
=
1
/
1600
)
sample_rate
=
8000
num_frames
=
4
num_channels
=
2
sinfo
=
self
.
_query_bytesio
(
ext
,
dtype
,
sample_rate
,
num_channels
,
num_frames
)
bits_per_sample
=
get_bits_per_sample
(
ext
,
dtype
)
num_frames
=
0
if
ext
in
[
'mp3'
,
'vorbis'
]
else
num_frames
assert
sinfo
.
sample_rate
==
sample_rate
assert
sinfo
.
num_channels
==
num_channels
assert
sinfo
.
num_frames
==
num_frames
assert
sinfo
.
bits_per_sample
==
bits_per_sample
assert
sinfo
.
encoding
==
get_encoding
(
ext
,
dtype
)
@
parameterized
.
expand
([
(
'wav'
,
32
),
(
'mp3'
,
0
),
(
'flac'
,
24
),
(
'vorbis'
,
0
),
(
'amb'
,
32
),
(
'wav'
,
"float32"
),
(
'wav'
,
"int32"
),
(
'wav'
,
"int16"
),
(
'wav'
,
"uint8"
),
(
'mp3'
,
"float32"
),
(
'flac'
,
"float32"
),
(
'vorbis'
,
"float32"
),
(
'amb'
,
"int16"
),
])
def
test_tarfile
(
self
,
ext
,
bits_per_sampl
e
):
def
test_tarfile
(
self
,
ext
,
dtyp
e
):
"""Querying compressed audio via file-like object works"""
sample_rate
=
16000
num_frames
=
3.0
*
sample_rate
num_channels
=
2
duration
=
3
format_
=
ext
if
ext
in
[
'mp3'
]
else
None
audio_file
=
f
'test.
{
ext
}
'
audio_path
=
self
.
get_temp_path
(
audio_file
)
archive_path
=
self
.
get_temp_path
(
'archive.tar.gz'
)
sinfo
=
self
.
_query_tarfile
(
ext
,
dtype
,
sample_rate
,
num_channels
,
num_frames
)
sox_utils
.
gen_audio_file
(
audio_path
,
sample_rate
,
num_channels
=
num_channels
,
duration
=
duration
)
with
tarfile
.
TarFile
(
archive_path
,
'w'
)
as
tarobj
:
tarobj
.
add
(
audio_path
,
arcname
=
audio_file
)
with
tarfile
.
TarFile
(
archive_path
,
'r'
)
as
tarobj
:
fileobj
=
tarobj
.
extractfile
(
audio_file
)
sinfo
=
sox_io_backend
.
info
(
fileobj
,
format
=
format_
)
bits_per_sample
=
get_bits_per_sample
(
ext
,
dtype
)
num_frames
=
0
if
ext
in
[
'mp3'
,
'vorbis'
]
else
num_frames
assert
sinfo
.
sample_rate
==
sample_rate
assert
sinfo
.
num_channels
==
num_channels
if
ext
not
in
[
'mp3'
,
'vorbis'
]:
# these container formats do not have length info
assert
sinfo
.
num_frames
==
sample_rate
*
duration
assert
sinfo
.
num_frames
==
num_frames
assert
sinfo
.
bits_per_sample
==
bits_per_sample
assert
sinfo
.
encoding
==
get_encoding
(
ext
,
dtype
)
@
skipIfNoExtension
@
skipIfNoExec
(
'sox'
)
@
skipIfNoModule
(
"requests"
)
class
TestFileObjectHttp
(
HttpServerMixin
,
PytorchTestCase
):
class
TestFileObjectHttp
(
HttpServerMixin
,
FileObjTestBase
,
PytorchTestCase
):
def
_query_http
(
self
,
ext
,
dtype
,
sample_rate
,
num_channels
,
num_frames
):
audio_path
=
self
.
_gen_file
(
ext
,
dtype
,
sample_rate
,
num_channels
,
num_frames
)
audio_file
=
os
.
path
.
basename
(
audio_path
)
url
=
self
.
get_url
(
audio_file
)
format_
=
ext
if
ext
in
[
'mp3'
]
else
None
with
requests
.
get
(
url
,
stream
=
True
)
as
resp
:
return
sox_io_backend
.
info
(
resp
.
raw
,
format
=
format_
)
@
parameterized
.
expand
([
(
'wav'
,
32
),
(
'mp3'
,
0
),
(
'flac'
,
24
),
(
'vorbis'
,
0
),
(
'amb'
,
32
),
(
'wav'
,
"float32"
),
(
'wav'
,
"int32"
),
(
'wav'
,
"int16"
),
(
'wav'
,
"uint8"
),
(
'mp3'
,
"float32"
),
(
'flac'
,
"float32"
),
(
'vorbis'
,
"float32"
),
(
'amb'
,
"int16"
),
])
def
test_requests
(
self
,
ext
,
bits_per_sampl
e
):
def
test_requests
(
self
,
ext
,
dtyp
e
):
"""Querying compressed audio via requests works"""
sample_rate
=
16000
num_frames
=
3.0
*
sample_rate
num_channels
=
2
duration
=
3
format_
=
ext
if
ext
in
[
'mp3'
]
else
None
audio_file
=
f
'test.
{
ext
}
'
audio_path
=
self
.
get_temp_path
(
audio_file
)
sox_utils
.
gen_audio_file
(
audio_path
,
sample_rate
,
num_channels
=
num_channels
,
duration
=
duration
)
sinfo
=
self
.
_query_http
(
ext
,
dtype
,
sample_rate
,
num_channels
,
num_frames
)
url
=
self
.
get_url
(
audio_file
)
with
requests
.
get
(
url
,
stream
=
True
)
as
resp
:
sinfo
=
sox_io_backend
.
info
(
resp
.
raw
,
format
=
format_
)
bits_per_sample
=
get_bits_per_sample
(
ext
,
dtype
)
num_frames
=
0
if
ext
in
[
'mp3'
,
'vorbis'
]
else
num_frames
assert
sinfo
.
sample_rate
==
sample_rate
assert
sinfo
.
num_channels
==
num_channels
if
ext
not
in
[
'mp3'
,
'vorbis'
]:
# these container formats do not have length info
assert
sinfo
.
num_frames
==
sample_rate
*
duration
assert
sinfo
.
num_frames
==
num_frames
assert
sinfo
.
bits_per_sample
==
bits_per_sample
assert
sinfo
.
encoding
==
get_encoding
(
ext
,
dtype
)
torchaudio/backend/_soundfile_backend.py
View file @
b152ee61
...
...
@@ -50,6 +50,37 @@ _SUBTYPE_TO_BITS_PER_SAMPLE = {
}
def
_get_bit_depth
(
subtype
):
if
subtype
not
in
_SUBTYPE_TO_BITS_PER_SAMPLE
:
warnings
.
warn
(
f
"The
{
subtype
}
subtype is unknown to TorchAudio. As a result, the bits_per_sample "
"attribute will be set to 0. If you are seeing this warning, please "
"report by opening an issue on github (after checking for existing/closed ones). "
"You may otherwise ignore this warning."
)
return
_SUBTYPE_TO_BITS_PER_SAMPLE
.
get
(
subtype
,
0
)
_SUBTYPE_TO_ENCODING
=
{
'PCM_S8'
:
'PCM_S'
,
'PCM_16'
:
'PCM_S'
,
'PCM_24'
:
'PCM_S'
,
'PCM_32'
:
'PCM_S'
,
'PCM_U8'
:
'PCM_U'
,
'FLOAT'
:
'PCM_F'
,
'DOUBLE'
:
'PCM_F'
,
'ULAW'
:
'ULAW'
,
'ALAW'
:
'ALAW'
,
'VORBIS'
:
'VORBIS'
,
}
def
_get_encoding
(
format
:
str
,
subtype
:
str
):
if
format
==
'FLAC'
:
return
'FLAC'
return
_SUBTYPE_TO_ENCODING
.
get
(
subtype
,
'UNKNOWN'
)
@
_mod_utils
.
requires_module
(
"soundfile"
)
def
info
(
filepath
:
str
,
format
:
Optional
[
str
]
=
None
)
->
AudioMetaData
:
"""Get signal information of an audio file.
...
...
@@ -68,15 +99,13 @@ def info(filepath: str, format: Optional[str] = None) -> AudioMetaData:
AudioMetaData: meta data of the given audio.
"""
sinfo
=
soundfile
.
info
(
filepath
)
if
sinfo
.
subtype
not
in
_SUBTYPE_TO_BITS_PER_SAMPLE
:
warnings
.
warn
(
f
"The
{
sinfo
.
subtype
}
subtype is unknown to TorchAudio. As a result, the bits_per_sample "
"attribute will be set to 0. If you are seeing this warning, please "
"report by opening an issue on github (after checking for existing/closed ones). "
"You may otherwise ignore this warning."
)
bits_per_sample
=
_SUBTYPE_TO_BITS_PER_SAMPLE
.
get
(
sinfo
.
subtype
,
0
)
return
AudioMetaData
(
sinfo
.
samplerate
,
sinfo
.
frames
,
sinfo
.
channels
,
bits_per_sample
=
bits_per_sample
)
return
AudioMetaData
(
sinfo
.
samplerate
,
sinfo
.
frames
,
sinfo
.
channels
,
bits_per_sample
=
_get_bit_depth
(
sinfo
.
subtype
),
encoding
=
_get_encoding
(
sinfo
.
format
,
sinfo
.
subtype
),
)
_SUBTYPE2DTYPE
=
{
...
...
torchaudio/backend/common.py
View file @
b152ee61
...
...
@@ -14,12 +14,21 @@ class AudioMetaData:
:ivar int num_channels: The number of channels
:ivar int bits_per_sample: The number of bits per sample. This is 0 for lossy formats,
or when it cannot be accurately inferred.
:ivar str encoding: Audio encoding.
"""
def
__init__
(
self
,
sample_rate
:
int
,
num_frames
:
int
,
num_channels
:
int
,
bits_per_sample
:
int
):
def
__init__
(
self
,
sample_rate
:
int
,
num_frames
:
int
,
num_channels
:
int
,
bits_per_sample
:
int
,
encoding
:
str
,
):
self
.
sample_rate
=
sample_rate
self
.
num_frames
=
num_frames
self
.
num_channels
=
num_channels
self
.
bits_per_sample
=
bits_per_sample
self
.
encoding
=
encoding
@
_mod_utils
.
deprecated
(
'Please migrate to `AudioMetaData`.'
,
'0.9.0'
)
...
...
torchaudio/backend/sox_io_backend.py
View file @
b152ee61
...
...
@@ -17,17 +17,15 @@ def _info(
format
:
Optional
[
str
]
=
None
,
)
->
AudioMetaData
:
if
hasattr
(
filepath
,
'read'
):
sinfo
=
torchaudio
.
_torchaudio
.
get_info_fileobj
(
filepath
,
format
)
sample_rate
,
num_channels
,
num_frames
,
bits_per_sample
=
sinfo
return
AudioMetaData
(
sample_rate
,
num_frames
,
num_channels
,
bits_per_sample
)
sinfo
=
torchaudio
.
_torchaudio
.
get_info_fileobj
(
filepath
,
format
)
return
AudioMetaData
(
*
sinfo
)
sinfo
=
torch
.
ops
.
torchaudio
.
sox_io_get_info
(
os
.
fspath
(
filepath
),
format
)
return
AudioMetaData
(
sinfo
.
get_sample_rate
(),
sinfo
.
get_num_frames
(),
sinfo
.
get_num_channels
(),
sinfo
.
get_bits_per_sample
(),
sinfo
.
get_encoding
(),
)
...
...
@@ -69,7 +67,8 @@ def info(
sinfo
.
get_sample_rate
(),
sinfo
.
get_num_frames
(),
sinfo
.
get_num_channels
(),
sinfo
.
get_bits_per_sample
())
sinfo
.
get_bits_per_sample
(),
sinfo
.
get_encoding
())
@
_mod_utils
.
requires_module
(
'torchaudio._torchaudio'
)
...
...
torchaudio/csrc/sox/io.cpp
View file @
b152ee61
...
...
@@ -14,11 +14,13 @@ SignalInfo::SignalInfo(
const
int64_t
sample_rate_
,
const
int64_t
num_channels_
,
const
int64_t
num_frames_
,
const
int64_t
bits_per_sample_
)
const
int64_t
bits_per_sample_
,
const
std
::
string
encoding_
)
:
sample_rate
(
sample_rate_
),
num_channels
(
num_channels_
),
num_frames
(
num_frames_
),
bits_per_sample
(
bits_per_sample_
){};
bits_per_sample
(
bits_per_sample_
),
encoding
(
encoding_
){};
int64_t
SignalInfo
::
getSampleRate
()
const
{
return
sample_rate
;
...
...
@@ -36,6 +38,45 @@ int64_t SignalInfo::getBitsPerSample() const {
return
bits_per_sample
;
}
std
::
string
SignalInfo
::
getEncoding
()
const
{
return
encoding
;
}
namespace
{
std
::
string
get_encoding
(
sox_encoding_t
encoding
)
{
switch
(
encoding
)
{
case
SOX_ENCODING_UNKNOWN
:
return
"UNKNOWN"
;
case
SOX_ENCODING_SIGN2
:
return
"PCM_S"
;
case
SOX_ENCODING_UNSIGNED
:
return
"PCM_U"
;
case
SOX_ENCODING_FLOAT
:
return
"PCM_F"
;
case
SOX_ENCODING_FLAC
:
return
"FLAC"
;
case
SOX_ENCODING_ULAW
:
return
"ULAW"
;
case
SOX_ENCODING_ALAW
:
return
"ALAW"
;
case
SOX_ENCODING_MP3
:
return
"MP3"
;
case
SOX_ENCODING_VORBIS
:
return
"VORBIS"
;
case
SOX_ENCODING_AMR_WB
:
return
"AMR_WB"
;
case
SOX_ENCODING_AMR_NB
:
return
"AMR_NB"
;
case
SOX_ENCODING_OPUS
:
return
"OPUS"
;
default:
return
"UNKNOWN"
;
}
}
}
// namespace
c10
::
intrusive_ptr
<
SignalInfo
>
get_info_file
(
const
std
::
string
&
path
,
c10
::
optional
<
std
::
string
>&
format
)
{
...
...
@@ -53,7 +94,8 @@ c10::intrusive_ptr<SignalInfo> get_info_file(
static_cast
<
int64_t
>
(
sf
->
signal
.
rate
),
static_cast
<
int64_t
>
(
sf
->
signal
.
channels
),
static_cast
<
int64_t
>
(
sf
->
signal
.
length
/
sf
->
signal
.
channels
),
static_cast
<
int64_t
>
(
sf
->
encoding
.
bits_per_sample
));
static_cast
<
int64_t
>
(
sf
->
encoding
.
bits_per_sample
),
get_encoding
(
sf
->
encoding
.
encoding
));
}
namespace
{
...
...
@@ -157,7 +199,7 @@ void save_audio_file(
#ifdef TORCH_API_INCLUDE_EXTENSION_H
std
::
tuple
<
int64_t
,
int64_t
,
int64_t
,
int64_t
>
get_info_fileobj
(
std
::
tuple
<
int64_t
,
int64_t
,
int64_t
,
int64_t
,
std
::
string
>
get_info_fileobj
(
py
::
object
fileobj
,
c10
::
optional
<
std
::
string
>&
format
)
{
// Prepare in-memory file object
...
...
@@ -202,9 +244,10 @@ std::tuple<int64_t, int64_t, int64_t, int64_t> get_info_fileobj(
return
std
::
make_tuple
(
static_cast
<
int64_t
>
(
sf
->
signal
.
rate
),
static_cast
<
int64_t
>
(
sf
->
signal
.
channels
),
static_cast
<
int64_t
>
(
sf
->
signal
.
length
/
sf
->
signal
.
channels
),
static_cast
<
int64_t
>
(
sf
->
encoding
.
bits_per_sample
));
static_cast
<
int64_t
>
(
sf
->
signal
.
channels
),
static_cast
<
int64_t
>
(
sf
->
encoding
.
bits_per_sample
),
get_encoding
(
sf
->
encoding
.
encoding
));
}
std
::
tuple
<
torch
::
Tensor
,
int64_t
>
load_audio_fileobj
(
...
...
torchaudio/csrc/sox/io.h
View file @
b152ee61
...
...
@@ -16,16 +16,19 @@ struct SignalInfo : torch::CustomClassHolder {
int64_t
num_channels
;
int64_t
num_frames
;
int64_t
bits_per_sample
;
std
::
string
encoding
;
SignalInfo
(
const
int64_t
sample_rate_
,
const
int64_t
num_channels_
,
const
int64_t
num_frames_
,
const
int64_t
bits_per_sample_
);
const
int64_t
bits_per_sample_
,
const
std
::
string
encoding_
);
int64_t
getSampleRate
()
const
;
int64_t
getNumChannels
()
const
;
int64_t
getNumFrames
()
const
;
int64_t
getBitsPerSample
()
const
;
std
::
string
getEncoding
()
const
;
};
c10
::
intrusive_ptr
<
SignalInfo
>
get_info_file
(
...
...
@@ -51,7 +54,7 @@ void save_audio_file(
#ifdef TORCH_API_INCLUDE_EXTENSION_H
std
::
tuple
<
int64_t
,
int64_t
,
int64_t
,
int64_t
>
get_info_fileobj
(
std
::
tuple
<
int64_t
,
int64_t
,
int64_t
,
int64_t
,
std
::
string
>
get_info_fileobj
(
py
::
object
fileobj
,
c10
::
optional
<
std
::
string
>&
format
);
...
...
torchaudio/csrc/sox/register.cpp
View file @
b152ee61
...
...
@@ -45,7 +45,8 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
.
def
(
"get_num_frames"
,
&
torchaudio
::
sox_io
::
SignalInfo
::
getNumFrames
)
.
def
(
"get_bits_per_sample"
,
&
torchaudio
::
sox_io
::
SignalInfo
::
getBitsPerSample
);
&
torchaudio
::
sox_io
::
SignalInfo
::
getBitsPerSample
)
.
def
(
"get_encoding"
,
&
torchaudio
::
sox_io
::
SignalInfo
::
getEncoding
);
m
.
def
(
"torchaudio::sox_io_get_info"
,
&
torchaudio
::
sox_io
::
get_info_file
);
m
.
def
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment