Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
hehl2
Torchaudio
Commits
b152ee61
"docker/vscode:/vscode.git/clone" did not exist on "e9b624fe227d2e01d3aff057b4a49f0cae58da13"
Unverified
Commit
b152ee61
authored
Jan 29, 2021
by
moto
Committed by
GitHub
Jan 29, 2021
Browse files
Add encoding attribute to AudioMetaData (#1206)
parent
674a71d1
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
327 additions
and
136 deletions
+327
-136
test/torchaudio_unittest/soundfile_backend/info_test.py
test/torchaudio_unittest/soundfile_backend/info_test.py
+12
-22
test/torchaudio_unittest/sox_io_backend/info_test.py
test/torchaudio_unittest/sox_io_backend/info_test.py
+206
-89
torchaudio/backend/_soundfile_backend.py
torchaudio/backend/_soundfile_backend.py
+38
-9
torchaudio/backend/common.py
torchaudio/backend/common.py
+10
-1
torchaudio/backend/sox_io_backend.py
torchaudio/backend/sox_io_backend.py
+5
-6
torchaudio/csrc/sox/io.cpp
torchaudio/csrc/sox/io.cpp
+49
-6
torchaudio/csrc/sox/io.h
torchaudio/csrc/sox/io.h
+5
-2
torchaudio/csrc/sox/register.cpp
torchaudio/csrc/sox/register.cpp
+2
-1
No files found.
test/torchaudio_unittest/soundfile_backend/info_test.py
View file @
b152ee61
...
@@ -13,6 +13,8 @@ from torchaudio_unittest.common_utils import (
...
@@ -13,6 +13,8 @@ from torchaudio_unittest.common_utils import (
get_wav_data
,
get_wav_data
,
save_wav
,
save_wav
,
)
)
# TODO refactor and move these to common location
from
torchaudio_unittest.sox_io_backend.info_test
import
get_encoding
,
get_bits_per_sample
from
.common
import
skipIfFormatNotSupported
,
parameterize
from
.common
import
skipIfFormatNotSupported
,
parameterize
if
_mod_utils
.
is_module_available
(
"soundfile"
):
if
_mod_utils
.
is_module_available
(
"soundfile"
):
...
@@ -22,11 +24,10 @@ if _mod_utils.is_module_available("soundfile"):
...
@@ -22,11 +24,10 @@ if _mod_utils.is_module_available("soundfile"):
@
skipIfNoModule
(
"soundfile"
)
@
skipIfNoModule
(
"soundfile"
)
class
TestInfo
(
TempDirMixin
,
PytorchTestCase
):
class
TestInfo
(
TempDirMixin
,
PytorchTestCase
):
@
parameterize
(
@
parameterize
(
[
(
"float32"
,
32
),
(
"int32"
,
32
),
(
"int16"
,
16
),
(
"uint8"
,
8
)
],
[
8000
,
16000
],
[
1
,
2
],
[
"float32"
,
"int32"
,
"int16"
,
"uint8"
],
[
8000
,
16000
],
[
1
,
2
],
)
)
def
test_wav
(
self
,
dtype
_and_bit_depth
,
sample_rate
,
num_channels
):
def
test_wav
(
self
,
dtype
,
sample_rate
,
num_channels
):
"""`soundfile_backend.info` can check wav file correctly"""
"""`soundfile_backend.info` can check wav file correctly"""
dtype
,
bits_per_sample
=
dtype_and_bit_depth
duration
=
1
duration
=
1
path
=
self
.
get_temp_path
(
"data.wav"
)
path
=
self
.
get_temp_path
(
"data.wav"
)
data
=
get_wav_data
(
data
=
get_wav_data
(
...
@@ -37,25 +38,8 @@ class TestInfo(TempDirMixin, PytorchTestCase):
...
@@ -37,25 +38,8 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert
info
.
sample_rate
==
sample_rate
assert
info
.
sample_rate
==
sample_rate
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_channels
==
num_channels
assert
info
.
num_channels
==
num_channels
assert
info
.
bits_per_sample
==
bits_per_sample
assert
info
.
bits_per_sample
==
get_bits_per_sample
(
"wav"
,
dtype
)
assert
info
.
encoding
==
get_encoding
(
"wav"
,
dtype
)
@
parameterize
(
[(
"float32"
,
32
),
(
"int32"
,
32
),
(
"int16"
,
16
),
(
"uint8"
,
8
)],
[
8000
,
16000
],
[
1
,
2
],
)
def
test_wav_multiple_channels
(
self
,
dtype_and_bit_depth
,
sample_rate
,
num_channels
):
"""`soundfile_backend.info` can check wav file with channels more than 2 correctly"""
dtype
,
bits_per_sample
=
dtype_and_bit_depth
duration
=
1
path
=
self
.
get_temp_path
(
"data.wav"
)
data
=
get_wav_data
(
dtype
,
num_channels
,
normalize
=
False
,
num_frames
=
duration
*
sample_rate
)
save_wav
(
path
,
data
,
sample_rate
)
info
=
soundfile_backend
.
info
(
path
)
assert
info
.
sample_rate
==
sample_rate
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_channels
==
num_channels
assert
info
.
bits_per_sample
==
bits_per_sample
@
parameterize
([
8000
,
16000
],
[
1
,
2
])
@
parameterize
([
8000
,
16000
],
[
1
,
2
])
@
skipIfFormatNotSupported
(
"FLAC"
)
@
skipIfFormatNotSupported
(
"FLAC"
)
...
@@ -72,6 +56,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
...
@@ -72,6 +56,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert
info
.
num_frames
==
num_frames
assert
info
.
num_frames
==
num_frames
assert
info
.
num_channels
==
num_channels
assert
info
.
num_channels
==
num_channels
assert
info
.
bits_per_sample
==
16
assert
info
.
bits_per_sample
==
16
assert
info
.
encoding
==
"FLAC"
@
parameterize
([
8000
,
16000
],
[
1
,
2
])
@
parameterize
([
8000
,
16000
],
[
1
,
2
])
@
skipIfFormatNotSupported
(
"OGG"
)
@
skipIfFormatNotSupported
(
"OGG"
)
...
@@ -88,6 +73,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
...
@@ -88,6 +73,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_channels
==
num_channels
assert
info
.
num_channels
==
num_channels
assert
info
.
bits_per_sample
==
0
assert
info
.
bits_per_sample
==
0
assert
info
.
encoding
==
"VORBIS"
@
parameterize
([
8000
,
16000
],
[
1
,
2
],
[(
'PCM_24'
,
24
),
(
'PCM_32'
,
32
)])
@
parameterize
([
8000
,
16000
],
[
1
,
2
],
[(
'PCM_24'
,
24
),
(
'PCM_32'
,
32
)])
@
skipIfFormatNotSupported
(
"NIST"
)
@
skipIfFormatNotSupported
(
"NIST"
)
...
@@ -105,6 +91,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
...
@@ -105,6 +91,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_channels
==
num_channels
assert
info
.
num_channels
==
num_channels
assert
info
.
bits_per_sample
==
bits_per_sample
assert
info
.
bits_per_sample
==
bits_per_sample
assert
info
.
encoding
==
"PCM_S"
def
test_unknown_subtype_warning
(
self
):
def
test_unknown_subtype_warning
(
self
):
"""soundfile_backend.info issues a warning when the subtype is unknown
"""soundfile_backend.info issues a warning when the subtype is unknown
...
@@ -118,6 +105,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
...
@@ -118,6 +105,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
frames
=
356
frames
=
356
channels
=
2
channels
=
2
subtype
=
'UNSEEN_SUBTYPE'
subtype
=
'UNSEEN_SUBTYPE'
format
=
'UNKNOWN'
return
MockSoundFileInfo
()
return
MockSoundFileInfo
()
with
patch
(
"soundfile.info"
,
_mock_info_func
):
with
patch
(
"soundfile.info"
,
_mock_info_func
):
...
@@ -147,6 +135,7 @@ class TestFileObject(TempDirMixin, PytorchTestCase):
...
@@ -147,6 +135,7 @@ class TestFileObject(TempDirMixin, PytorchTestCase):
assert
info
.
num_frames
==
num_frames
assert
info
.
num_frames
==
num_frames
assert
info
.
num_channels
==
num_channels
assert
info
.
num_channels
==
num_channels
assert
info
.
bits_per_sample
==
bits_per_sample
assert
info
.
bits_per_sample
==
bits_per_sample
assert
info
.
encoding
==
"FLAC"
if
ext
==
'flac'
else
"PCM_S"
def
test_fileobj_wav
(
self
):
def
test_fileobj_wav
(
self
):
"""Loading audio via file-like object works"""
"""Loading audio via file-like object works"""
...
@@ -179,6 +168,7 @@ class TestFileObject(TempDirMixin, PytorchTestCase):
...
@@ -179,6 +168,7 @@ class TestFileObject(TempDirMixin, PytorchTestCase):
assert
info
.
num_frames
==
num_frames
assert
info
.
num_frames
==
num_frames
assert
info
.
num_channels
==
num_channels
assert
info
.
num_channels
==
num_channels
assert
info
.
bits_per_sample
==
bits_per_sample
assert
info
.
bits_per_sample
==
bits_per_sample
assert
info
.
encoding
==
"FLAC"
if
ext
==
'flac'
else
"PCM_S"
def
test_tarobj_wav
(
self
):
def
test_tarobj_wav
(
self
):
"""Query compressed audio via file-like object works"""
"""Query compressed audio via file-like object works"""
...
...
test/torchaudio_unittest/sox_io_backend/info_test.py
View file @
b152ee61
import
io
import
io
import
os
import
itertools
import
itertools
import
tarfile
import
tarfile
...
@@ -27,6 +28,30 @@ if _mod_utils.is_module_available("requests"):
...
@@ -27,6 +28,30 @@ if _mod_utils.is_module_available("requests"):
import
requests
import
requests
def
get_encoding
(
ext
,
dtype
):
exts
=
{
'mp3'
,
'flac'
,
'vorbis'
,
}
encodings
=
{
'float32'
:
'PCM_F'
,
'int32'
:
'PCM_S'
,
'int16'
:
'PCM_S'
,
'uint8'
:
'PCM_U'
,
}
return
ext
.
upper
()
if
ext
in
exts
else
encodings
[
dtype
]
def
get_bits_per_sample
(
ext
,
dtype
):
bits_per_samples
=
{
'flac'
:
24
,
'mp3'
:
0
,
'vorbis'
:
0
,
}
return
bits_per_samples
.
get
(
ext
,
sox_utils
.
get_bit_depth
(
dtype
))
@
skipIfNoExec
(
'sox'
)
@
skipIfNoExec
(
'sox'
)
@
skipIfNoExtension
@
skipIfNoExtension
class
TestInfo
(
TempDirMixin
,
PytorchTestCase
):
class
TestInfo
(
TempDirMixin
,
PytorchTestCase
):
...
@@ -46,6 +71,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
...
@@ -46,6 +71,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_channels
==
num_channels
assert
info
.
num_channels
==
num_channels
assert
info
.
bits_per_sample
==
sox_utils
.
get_bit_depth
(
dtype
)
assert
info
.
bits_per_sample
==
sox_utils
.
get_bit_depth
(
dtype
)
assert
info
.
encoding
==
get_encoding
(
'wav'
,
dtype
)
@
parameterized
.
expand
(
list
(
itertools
.
product
(
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
'float32'
,
'int32'
,
'int16'
,
'uint8'
],
[
'float32'
,
'int32'
,
'int16'
,
'uint8'
],
...
@@ -63,6 +89,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
...
@@ -63,6 +89,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_channels
==
num_channels
assert
info
.
num_channels
==
num_channels
assert
info
.
bits_per_sample
==
sox_utils
.
get_bit_depth
(
dtype
)
assert
info
.
bits_per_sample
==
sox_utils
.
get_bit_depth
(
dtype
)
assert
info
.
encoding
==
get_encoding
(
'wav'
,
dtype
)
@
parameterized
.
expand
(
list
(
itertools
.
product
(
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
8000
,
16000
],
[
8000
,
16000
],
...
@@ -83,6 +110,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
...
@@ -83,6 +110,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
# assert info.num_frames == sample_rate * duration
# assert info.num_frames == sample_rate * duration
assert
info
.
num_channels
==
num_channels
assert
info
.
num_channels
==
num_channels
assert
info
.
bits_per_sample
==
0
# bit_per_sample is irrelevant for compressed formats
assert
info
.
bits_per_sample
==
0
# bit_per_sample is irrelevant for compressed formats
assert
info
.
encoding
==
"MP3"
@
parameterized
.
expand
(
list
(
itertools
.
product
(
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
8000
,
16000
],
[
8000
,
16000
],
...
@@ -102,6 +130,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
...
@@ -102,6 +130,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_channels
==
num_channels
assert
info
.
num_channels
==
num_channels
assert
info
.
bits_per_sample
==
24
# FLAC standard
assert
info
.
bits_per_sample
==
24
# FLAC standard
assert
info
.
encoding
==
"FLAC"
@
parameterized
.
expand
(
list
(
itertools
.
product
(
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
8000
,
16000
],
[
8000
,
16000
],
...
@@ -121,6 +150,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
...
@@ -121,6 +150,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_channels
==
num_channels
assert
info
.
num_channels
==
num_channels
assert
info
.
bits_per_sample
==
0
# bit_per_sample is irrelevant for compressed formats
assert
info
.
bits_per_sample
==
0
# bit_per_sample is irrelevant for compressed formats
assert
info
.
encoding
==
"VORBIS"
@
parameterized
.
expand
(
list
(
itertools
.
product
(
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
8000
,
16000
],
[
8000
,
16000
],
...
@@ -137,9 +167,10 @@ class TestInfo(TempDirMixin, PytorchTestCase):
...
@@ -137,9 +167,10 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_channels
==
num_channels
assert
info
.
num_channels
==
num_channels
assert
info
.
bits_per_sample
==
bits_per_sample
assert
info
.
bits_per_sample
==
bits_per_sample
assert
info
.
encoding
==
"PCM_S"
@
parameterized
.
expand
(
list
(
itertools
.
product
(
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
'float32'
,
'int32'
,
'int16'
,
'uint8'
],
[
'int32'
,
'int16'
,
'uint8'
],
[
8000
,
16000
],
[
8000
,
16000
],
[
1
,
2
],
[
1
,
2
],
)),
name_func
=
name_func
)
)),
name_func
=
name_func
)
...
@@ -156,6 +187,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
...
@@ -156,6 +187,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_channels
==
num_channels
assert
info
.
num_channels
==
num_channels
assert
info
.
bits_per_sample
==
bits_per_sample
assert
info
.
bits_per_sample
==
bits_per_sample
assert
info
.
encoding
==
get_encoding
(
"amb"
,
dtype
)
def
test_amr_nb
(
self
):
def
test_amr_nb
(
self
):
"""`sox_io_backend.info` can check amr-nb file correctly"""
"""`sox_io_backend.info` can check amr-nb file correctly"""
...
@@ -171,6 +203,41 @@ class TestInfo(TempDirMixin, PytorchTestCase):
...
@@ -171,6 +203,41 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_channels
==
num_channels
assert
info
.
num_channels
==
num_channels
assert
info
.
bits_per_sample
==
0
assert
info
.
bits_per_sample
==
0
assert
info
.
encoding
==
"AMR_NB"
def
test_ulaw
(
self
):
"""`sox_io_backend.info` can check ulaw file correctly"""
duration
=
1
num_channels
=
1
sample_rate
=
8000
path
=
self
.
get_temp_path
(
'data.wav'
)
sox_utils
.
gen_audio_file
(
path
,
sample_rate
=
sample_rate
,
num_channels
=
num_channels
,
bit_depth
=
8
,
encoding
=
'u-law'
,
duration
=
duration
)
info
=
sox_io_backend
.
info
(
path
)
assert
info
.
sample_rate
==
sample_rate
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_channels
==
num_channels
assert
info
.
bits_per_sample
==
8
assert
info
.
encoding
==
"ULAW"
def
test_alaw
(
self
):
"""`sox_io_backend.info` can check ulaw file correctly"""
duration
=
1
num_channels
=
1
sample_rate
=
8000
path
=
self
.
get_temp_path
(
'data.wav'
)
sox_utils
.
gen_audio_file
(
path
,
sample_rate
=
sample_rate
,
num_channels
=
num_channels
,
bit_depth
=
8
,
encoding
=
'a-law'
,
duration
=
duration
)
info
=
sox_io_backend
.
info
(
path
)
assert
info
.
sample_rate
==
sample_rate
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_channels
==
num_channels
assert
info
.
bits_per_sample
==
8
assert
info
.
encoding
==
"ALAW"
@
skipIfNoExtension
@
skipIfNoExtension
...
@@ -188,6 +255,7 @@ class TestInfoOpus(PytorchTestCase):
...
@@ -188,6 +255,7 @@ class TestInfoOpus(PytorchTestCase):
assert
info
.
num_frames
==
32768
assert
info
.
num_frames
==
32768
assert
info
.
num_channels
==
num_channels
assert
info
.
num_channels
==
num_channels
assert
info
.
bits_per_sample
==
0
# bit_per_sample is irrelevant for compressed formats
assert
info
.
bits_per_sample
==
0
# bit_per_sample is irrelevant for compressed formats
assert
info
.
encoding
==
"OPUS"
@
skipIfNoExtension
@
skipIfNoExtension
...
@@ -205,144 +273,193 @@ class TestLoadWithoutExtension(PytorchTestCase):
...
@@ -205,144 +273,193 @@ class TestLoadWithoutExtension(PytorchTestCase):
path
=
get_asset_path
(
"mp3_without_ext"
)
path
=
get_asset_path
(
"mp3_without_ext"
)
sinfo
=
sox_io_backend
.
info
(
path
,
format
=
"mp3"
)
sinfo
=
sox_io_backend
.
info
(
path
,
format
=
"mp3"
)
assert
sinfo
.
sample_rate
==
16000
assert
sinfo
.
sample_rate
==
16000
assert
sinfo
.
num_frames
==
81216
assert
sinfo
.
num_channels
==
1
assert
sinfo
.
bits_per_sample
==
0
# bit_per_sample is irrelevant for compressed formats
assert
sinfo
.
bits_per_sample
==
0
# bit_per_sample is irrelevant for compressed formats
assert
sinfo
.
encoding
==
"MP3"
class
FileObjTestBase
(
TempDirMixin
):
def
_gen_file
(
self
,
ext
,
dtype
,
sample_rate
,
num_channels
,
num_frames
):
path
=
self
.
get_temp_path
(
f
'test.
{
ext
}
'
)
bit_depth
=
sox_utils
.
get_bit_depth
(
dtype
)
duration
=
num_frames
/
sample_rate
sox_utils
.
gen_audio_file
(
path
,
sample_rate
,
num_channels
=
num_channels
,
encoding
=
sox_utils
.
get_encoding
(
dtype
),
bit_depth
=
bit_depth
,
duration
=
duration
)
return
path
@
skipIfNoExtension
@
skipIfNoExtension
@
skipIfNoExec
(
'sox'
)
@
skipIfNoExec
(
'sox'
)
class
TestFileObject
(
TempDirMixin
,
PytorchTestCase
):
class
TestFileObject
(
FileObjTestBase
,
PytorchTestCase
):
def
_query_fileobj
(
self
,
ext
,
dtype
,
sample_rate
,
num_channels
,
num_frames
):
path
=
self
.
_gen_file
(
ext
,
dtype
,
sample_rate
,
num_channels
,
num_frames
)
format_
=
ext
if
ext
in
[
'mp3'
]
else
None
with
open
(
path
,
'rb'
)
as
fileobj
:
return
sox_io_backend
.
info
(
fileobj
,
format_
)
def
_query_bytesio
(
self
,
ext
,
dtype
,
sample_rate
,
num_channels
,
num_frames
):
path
=
self
.
_gen_file
(
ext
,
dtype
,
sample_rate
,
num_channels
,
num_frames
)
format_
=
ext
if
ext
in
[
'mp3'
]
else
None
with
open
(
path
,
'rb'
)
as
file_
:
fileobj
=
io
.
BytesIO
(
file_
.
read
())
return
sox_io_backend
.
info
(
fileobj
,
format_
)
def
_query_tarfile
(
self
,
ext
,
dtype
,
sample_rate
,
num_channels
,
num_frames
):
audio_path
=
self
.
_gen_file
(
ext
,
dtype
,
sample_rate
,
num_channels
,
num_frames
)
audio_file
=
os
.
path
.
basename
(
audio_path
)
archive_path
=
self
.
get_temp_path
(
'archive.tar.gz'
)
with
tarfile
.
TarFile
(
archive_path
,
'w'
)
as
tarobj
:
tarobj
.
add
(
audio_path
,
arcname
=
audio_file
)
format_
=
ext
if
ext
in
[
'mp3'
]
else
None
with
tarfile
.
TarFile
(
archive_path
,
'r'
)
as
tarobj
:
fileobj
=
tarobj
.
extractfile
(
audio_file
)
return
sox_io_backend
.
info
(
fileobj
,
format_
)
@
parameterized
.
expand
([
@
parameterized
.
expand
([
(
'wav'
,
32
),
(
'wav'
,
"float32"
),
(
'mp3'
,
0
),
(
'wav'
,
"int32"
),
(
'flac'
,
24
),
(
'wav'
,
"int16"
),
(
'vorbis'
,
0
),
(
'wav'
,
"uint8"
),
(
'amb'
,
32
),
(
'mp3'
,
"float32"
),
(
'flac'
,
"float32"
),
(
'vorbis'
,
"float32"
),
(
'amb'
,
"int16"
),
])
])
def
test_fileobj
(
self
,
ext
,
bits_per_sampl
e
):
def
test_fileobj
(
self
,
ext
,
dtyp
e
):
"""Querying audio via file object works"""
"""Querying audio via file object works"""
sample_rate
=
16000
sample_rate
=
16000
num_frames
=
3
*
sample_rate
num_channels
=
2
num_channels
=
2
duration
=
3
sinfo
=
self
.
_query_fileobj
(
ext
,
dtype
,
sample_rate
,
num_channels
,
num_frames
)
format_
=
ext
if
ext
in
[
'mp3'
]
else
None
path
=
self
.
get_temp_path
(
f
'test.
{
ext
}
'
)
sox_utils
.
gen_audio_file
(
path
,
sample_rate
,
num_channels
=
2
,
duration
=
duration
)
w
it
h
open
(
path
,
'rb'
)
as
fileobj
:
b
it
s_per_sample
=
get_bits_per_sample
(
ext
,
dtype
)
sinfo
=
sox_io_backend
.
info
(
fileobj
,
format_
)
num_frames
=
0
if
ext
in
[
'mp3'
,
'vorbis'
]
else
num_frames
assert
sinfo
.
sample_rate
==
sample_rate
assert
sinfo
.
sample_rate
==
sample_rate
assert
sinfo
.
num_channels
==
num_channels
assert
sinfo
.
num_channels
==
num_channels
if
ext
not
in
[
'mp3'
,
'vorbis'
]:
# these container formats do not have length info
assert
sinfo
.
num_frames
==
num_frames
assert
sinfo
.
num_frames
==
sample_rate
*
duration
assert
sinfo
.
bits_per_sample
==
bits_per_sample
assert
sinfo
.
bits_per_sample
==
bits_per_sample
assert
sinfo
.
encoding
==
get_encoding
(
ext
,
dtype
)
def
_test_bytesio
(
self
,
ext
,
bits_per_sample
,
duration
):
@
parameterized
.
expand
([
(
'wav'
,
"float32"
),
(
'wav'
,
"int32"
),
(
'wav'
,
"int16"
),
(
'wav'
,
"uint8"
),
(
'mp3'
,
"float32"
),
(
'flac'
,
"float32"
),
(
'vorbis'
,
"float32"
),
(
'amb'
,
"int16"
),
])
def
test_bytesio
(
self
,
ext
,
dtype
):
"""Querying audio via ByteIO object works for small data"""
sample_rate
=
16000
sample_rate
=
16000
num_frames
=
3
*
sample_rate
num_channels
=
2
num_channels
=
2
format_
=
ext
if
ext
in
[
'mp3'
]
else
None
sinfo
=
self
.
_query_bytesio
(
ext
,
dtype
,
sample_rate
,
num_channels
,
num_frames
)
path
=
self
.
get_temp_path
(
f
'test.
{
ext
}
'
)
sox_utils
.
gen_audio_file
(
bits_per_sample
=
get_bits_per_sample
(
ext
,
dtype
)
path
,
sample_rate
,
num_channels
=
2
,
num_frames
=
0
if
ext
in
[
'mp3'
,
'vorbis'
]
else
num_frames
duration
=
duration
)
with
open
(
path
,
'rb'
)
as
file_
:
fileobj
=
io
.
BytesIO
(
file_
.
read
())
sinfo
=
sox_io_backend
.
info
(
fileobj
,
format_
)
assert
sinfo
.
sample_rate
==
sample_rate
assert
sinfo
.
sample_rate
==
sample_rate
assert
sinfo
.
num_channels
==
num_channels
assert
sinfo
.
num_channels
==
num_channels
if
ext
not
in
[
'mp3'
,
'vorbis'
]:
# these container formats do not have length info
assert
sinfo
.
num_frames
==
num_frames
assert
sinfo
.
num_frames
==
sample_rate
*
duration
assert
sinfo
.
bits_per_sample
==
bits_per_sample
assert
sinfo
.
bits_per_sample
==
bits_per_sample
assert
sinfo
.
encoding
==
get_encoding
(
ext
,
dtype
)
@
parameterized
.
expand
([
@
parameterized
.
expand
([
(
'wav'
,
32
),
(
'wav'
,
"float32"
),
(
'mp3'
,
0
),
(
'wav'
,
"int32"
),
(
'flac'
,
24
),
(
'wav'
,
"int16"
),
(
'vorbis'
,
0
),
(
'wav'
,
"uint8"
),
(
'amb'
,
32
),
(
'mp3'
,
"float32"
),
(
'flac'
,
"float32"
),
(
'vorbis'
,
"float32"
),
(
'amb'
,
"int16"
),
])
])
def
test_bytesio
(
self
,
ext
,
bits_per_sample
):
def
test_bytesio_tiny
(
self
,
ext
,
dtype
):
"""Querying audio via ByteIO object works"""
self
.
_test_bytesio
(
ext
,
bits_per_sample
,
duration
=
3
)
@
parameterized
.
expand
([
(
'wav'
,
32
),
(
'mp3'
,
0
),
(
'flac'
,
24
),
(
'vorbis'
,
0
),
(
'amb'
,
32
),
])
def
test_bytesio_tiny
(
self
,
ext
,
bits_per_sample
):
"""Querying audio via ByteIO object works for small data"""
"""Querying audio via ByteIO object works for small data"""
self
.
_test_bytesio
(
ext
,
bits_per_sample
,
duration
=
1
/
1600
)
sample_rate
=
8000
num_frames
=
4
num_channels
=
2
sinfo
=
self
.
_query_bytesio
(
ext
,
dtype
,
sample_rate
,
num_channels
,
num_frames
)
bits_per_sample
=
get_bits_per_sample
(
ext
,
dtype
)
num_frames
=
0
if
ext
in
[
'mp3'
,
'vorbis'
]
else
num_frames
assert
sinfo
.
sample_rate
==
sample_rate
assert
sinfo
.
num_channels
==
num_channels
assert
sinfo
.
num_frames
==
num_frames
assert
sinfo
.
bits_per_sample
==
bits_per_sample
assert
sinfo
.
encoding
==
get_encoding
(
ext
,
dtype
)
@
parameterized
.
expand
([
@
parameterized
.
expand
([
(
'wav'
,
32
),
(
'wav'
,
"float32"
),
(
'mp3'
,
0
),
(
'wav'
,
"int32"
),
(
'flac'
,
24
),
(
'wav'
,
"int16"
),
(
'vorbis'
,
0
),
(
'wav'
,
"uint8"
),
(
'amb'
,
32
),
(
'mp3'
,
"float32"
),
(
'flac'
,
"float32"
),
(
'vorbis'
,
"float32"
),
(
'amb'
,
"int16"
),
])
])
def
test_tarfile
(
self
,
ext
,
bits_per_sampl
e
):
def
test_tarfile
(
self
,
ext
,
dtyp
e
):
"""Querying compressed audio via file-like object works"""
"""Querying compressed audio via file-like object works"""
sample_rate
=
16000
sample_rate
=
16000
num_frames
=
3.0
*
sample_rate
num_channels
=
2
num_channels
=
2
duration
=
3
sinfo
=
self
.
_query_tarfile
(
ext
,
dtype
,
sample_rate
,
num_channels
,
num_frames
)
format_
=
ext
if
ext
in
[
'mp3'
]
else
None
audio_file
=
f
'test.
{
ext
}
'
audio_path
=
self
.
get_temp_path
(
audio_file
)
archive_path
=
self
.
get_temp_path
(
'archive.tar.gz'
)
sox_utils
.
gen_audio_file
(
bits_per_sample
=
get_bits_per_sample
(
ext
,
dtype
)
audio_path
,
sample_rate
,
num_channels
=
num_channels
,
duration
=
duration
)
num_frames
=
0
if
ext
in
[
'mp3'
,
'vorbis'
]
else
num_frames
with
tarfile
.
TarFile
(
archive_path
,
'w'
)
as
tarobj
:
tarobj
.
add
(
audio_path
,
arcname
=
audio_file
)
with
tarfile
.
TarFile
(
archive_path
,
'r'
)
as
tarobj
:
fileobj
=
tarobj
.
extractfile
(
audio_file
)
sinfo
=
sox_io_backend
.
info
(
fileobj
,
format
=
format_
)
assert
sinfo
.
sample_rate
==
sample_rate
assert
sinfo
.
sample_rate
==
sample_rate
assert
sinfo
.
num_channels
==
num_channels
assert
sinfo
.
num_channels
==
num_channels
if
ext
not
in
[
'mp3'
,
'vorbis'
]:
# these container formats do not have length info
assert
sinfo
.
num_frames
==
num_frames
assert
sinfo
.
num_frames
==
sample_rate
*
duration
assert
sinfo
.
bits_per_sample
==
bits_per_sample
assert
sinfo
.
bits_per_sample
==
bits_per_sample
assert
sinfo
.
encoding
==
get_encoding
(
ext
,
dtype
)
@
skipIfNoExtension
@
skipIfNoExtension
@
skipIfNoExec
(
'sox'
)
@
skipIfNoExec
(
'sox'
)
@
skipIfNoModule
(
"requests"
)
@
skipIfNoModule
(
"requests"
)
class
TestFileObjectHttp
(
HttpServerMixin
,
PytorchTestCase
):
class
TestFileObjectHttp
(
HttpServerMixin
,
FileObjTestBase
,
PytorchTestCase
):
def
_query_http
(
self
,
ext
,
dtype
,
sample_rate
,
num_channels
,
num_frames
):
audio_path
=
self
.
_gen_file
(
ext
,
dtype
,
sample_rate
,
num_channels
,
num_frames
)
audio_file
=
os
.
path
.
basename
(
audio_path
)
url
=
self
.
get_url
(
audio_file
)
format_
=
ext
if
ext
in
[
'mp3'
]
else
None
with
requests
.
get
(
url
,
stream
=
True
)
as
resp
:
return
sox_io_backend
.
info
(
resp
.
raw
,
format
=
format_
)
@
parameterized
.
expand
([
@
parameterized
.
expand
([
(
'wav'
,
32
),
(
'wav'
,
"float32"
),
(
'mp3'
,
0
),
(
'wav'
,
"int32"
),
(
'flac'
,
24
),
(
'wav'
,
"int16"
),
(
'vorbis'
,
0
),
(
'wav'
,
"uint8"
),
(
'amb'
,
32
),
(
'mp3'
,
"float32"
),
(
'flac'
,
"float32"
),
(
'vorbis'
,
"float32"
),
(
'amb'
,
"int16"
),
])
])
def
test_requests
(
self
,
ext
,
bits_per_sampl
e
):
def
test_requests
(
self
,
ext
,
dtyp
e
):
"""Querying compressed audio via requests works"""
"""Querying compressed audio via requests works"""
sample_rate
=
16000
sample_rate
=
16000
num_frames
=
3.0
*
sample_rate
num_channels
=
2
num_channels
=
2
duration
=
3
sinfo
=
self
.
_query_http
(
ext
,
dtype
,
sample_rate
,
num_channels
,
num_frames
)
format_
=
ext
if
ext
in
[
'mp3'
]
else
None
audio_file
=
f
'test.
{
ext
}
'
audio_path
=
self
.
get_temp_path
(
audio_file
)
sox_utils
.
gen_audio_file
(
audio_path
,
sample_rate
,
num_channels
=
num_channels
,
duration
=
duration
)
url
=
self
.
get_url
(
audio_file
)
bits_per_sample
=
get_bits_per_sample
(
ext
,
dtype
)
with
requests
.
get
(
url
,
stream
=
True
)
as
resp
:
num_frames
=
0
if
ext
in
[
'mp3'
,
'vorbis'
]
else
num_frames
sinfo
=
sox_io_backend
.
info
(
resp
.
raw
,
format
=
format_
)
assert
sinfo
.
sample_rate
==
sample_rate
assert
sinfo
.
sample_rate
==
sample_rate
assert
sinfo
.
num_channels
==
num_channels
assert
sinfo
.
num_channels
==
num_channels
if
ext
not
in
[
'mp3'
,
'vorbis'
]:
# these container formats do not have length info
assert
sinfo
.
num_frames
==
num_frames
assert
sinfo
.
num_frames
==
sample_rate
*
duration
assert
sinfo
.
bits_per_sample
==
bits_per_sample
assert
sinfo
.
bits_per_sample
==
bits_per_sample
assert
sinfo
.
encoding
==
get_encoding
(
ext
,
dtype
)
torchaudio/backend/_soundfile_backend.py
View file @
b152ee61
...
@@ -50,6 +50,37 @@ _SUBTYPE_TO_BITS_PER_SAMPLE = {
...
@@ -50,6 +50,37 @@ _SUBTYPE_TO_BITS_PER_SAMPLE = {
}
}
def
_get_bit_depth
(
subtype
):
if
subtype
not
in
_SUBTYPE_TO_BITS_PER_SAMPLE
:
warnings
.
warn
(
f
"The
{
subtype
}
subtype is unknown to TorchAudio. As a result, the bits_per_sample "
"attribute will be set to 0. If you are seeing this warning, please "
"report by opening an issue on github (after checking for existing/closed ones). "
"You may otherwise ignore this warning."
)
return
_SUBTYPE_TO_BITS_PER_SAMPLE
.
get
(
subtype
,
0
)
_SUBTYPE_TO_ENCODING
=
{
'PCM_S8'
:
'PCM_S'
,
'PCM_16'
:
'PCM_S'
,
'PCM_24'
:
'PCM_S'
,
'PCM_32'
:
'PCM_S'
,
'PCM_U8'
:
'PCM_U'
,
'FLOAT'
:
'PCM_F'
,
'DOUBLE'
:
'PCM_F'
,
'ULAW'
:
'ULAW'
,
'ALAW'
:
'ALAW'
,
'VORBIS'
:
'VORBIS'
,
}
def
_get_encoding
(
format
:
str
,
subtype
:
str
):
if
format
==
'FLAC'
:
return
'FLAC'
return
_SUBTYPE_TO_ENCODING
.
get
(
subtype
,
'UNKNOWN'
)
@
_mod_utils
.
requires_module
(
"soundfile"
)
@
_mod_utils
.
requires_module
(
"soundfile"
)
def
info
(
filepath
:
str
,
format
:
Optional
[
str
]
=
None
)
->
AudioMetaData
:
def
info
(
filepath
:
str
,
format
:
Optional
[
str
]
=
None
)
->
AudioMetaData
:
"""Get signal information of an audio file.
"""Get signal information of an audio file.
...
@@ -68,15 +99,13 @@ def info(filepath: str, format: Optional[str] = None) -> AudioMetaData:
...
@@ -68,15 +99,13 @@ def info(filepath: str, format: Optional[str] = None) -> AudioMetaData:
AudioMetaData: meta data of the given audio.
AudioMetaData: meta data of the given audio.
"""
"""
sinfo
=
soundfile
.
info
(
filepath
)
sinfo
=
soundfile
.
info
(
filepath
)
if
sinfo
.
subtype
not
in
_SUBTYPE_TO_BITS_PER_SAMPLE
:
return
AudioMetaData
(
warnings
.
warn
(
sinfo
.
samplerate
,
f
"The
{
sinfo
.
subtype
}
subtype is unknown to TorchAudio. As a result, the bits_per_sample "
sinfo
.
frames
,
"attribute will be set to 0. If you are seeing this warning, please "
sinfo
.
channels
,
"report by opening an issue on github (after checking for existing/closed ones). "
bits_per_sample
=
_get_bit_depth
(
sinfo
.
subtype
),
"You may otherwise ignore this warning."
encoding
=
_get_encoding
(
sinfo
.
format
,
sinfo
.
subtype
),
)
)
bits_per_sample
=
_SUBTYPE_TO_BITS_PER_SAMPLE
.
get
(
sinfo
.
subtype
,
0
)
return
AudioMetaData
(
sinfo
.
samplerate
,
sinfo
.
frames
,
sinfo
.
channels
,
bits_per_sample
=
bits_per_sample
)
_SUBTYPE2DTYPE
=
{
_SUBTYPE2DTYPE
=
{
...
...
torchaudio/backend/common.py
View file @
b152ee61
...
@@ -14,12 +14,21 @@ class AudioMetaData:
...
@@ -14,12 +14,21 @@ class AudioMetaData:
:ivar int num_channels: The number of channels
:ivar int num_channels: The number of channels
:ivar int bits_per_sample: The number of bits per sample. This is 0 for lossy formats,
:ivar int bits_per_sample: The number of bits per sample. This is 0 for lossy formats,
or when it cannot be accurately inferred.
or when it cannot be accurately inferred.
:ivar str encoding: Audio encoding.
"""
"""
def
__init__
(
self
,
sample_rate
:
int
,
num_frames
:
int
,
num_channels
:
int
,
bits_per_sample
:
int
):
def
__init__
(
self
,
sample_rate
:
int
,
num_frames
:
int
,
num_channels
:
int
,
bits_per_sample
:
int
,
encoding
:
str
,
):
self
.
sample_rate
=
sample_rate
self
.
sample_rate
=
sample_rate
self
.
num_frames
=
num_frames
self
.
num_frames
=
num_frames
self
.
num_channels
=
num_channels
self
.
num_channels
=
num_channels
self
.
bits_per_sample
=
bits_per_sample
self
.
bits_per_sample
=
bits_per_sample
self
.
encoding
=
encoding
@
_mod_utils
.
deprecated
(
'Please migrate to `AudioMetaData`.'
,
'0.9.0'
)
@
_mod_utils
.
deprecated
(
'Please migrate to `AudioMetaData`.'
,
'0.9.0'
)
...
...
torchaudio/backend/sox_io_backend.py
View file @
b152ee61
...
@@ -17,17 +17,15 @@ def _info(
...
@@ -17,17 +17,15 @@ def _info(
format
:
Optional
[
str
]
=
None
,
format
:
Optional
[
str
]
=
None
,
)
->
AudioMetaData
:
)
->
AudioMetaData
:
if
hasattr
(
filepath
,
'read'
):
if
hasattr
(
filepath
,
'read'
):
sinfo
=
torchaudio
.
_torchaudio
.
get_info_fileobj
(
sinfo
=
torchaudio
.
_torchaudio
.
get_info_fileobj
(
filepath
,
format
)
filepath
,
format
)
return
AudioMetaData
(
*
sinfo
)
sample_rate
,
num_channels
,
num_frames
,
bits_per_sample
=
sinfo
return
AudioMetaData
(
sample_rate
,
num_frames
,
num_channels
,
bits_per_sample
)
sinfo
=
torch
.
ops
.
torchaudio
.
sox_io_get_info
(
os
.
fspath
(
filepath
),
format
)
sinfo
=
torch
.
ops
.
torchaudio
.
sox_io_get_info
(
os
.
fspath
(
filepath
),
format
)
return
AudioMetaData
(
return
AudioMetaData
(
sinfo
.
get_sample_rate
(),
sinfo
.
get_sample_rate
(),
sinfo
.
get_num_frames
(),
sinfo
.
get_num_frames
(),
sinfo
.
get_num_channels
(),
sinfo
.
get_num_channels
(),
sinfo
.
get_bits_per_sample
(),
sinfo
.
get_bits_per_sample
(),
sinfo
.
get_encoding
(),
)
)
...
@@ -69,7 +67,8 @@ def info(
...
@@ -69,7 +67,8 @@ def info(
sinfo
.
get_sample_rate
(),
sinfo
.
get_sample_rate
(),
sinfo
.
get_num_frames
(),
sinfo
.
get_num_frames
(),
sinfo
.
get_num_channels
(),
sinfo
.
get_num_channels
(),
sinfo
.
get_bits_per_sample
())
sinfo
.
get_bits_per_sample
(),
sinfo
.
get_encoding
())
@
_mod_utils
.
requires_module
(
'torchaudio._torchaudio'
)
@
_mod_utils
.
requires_module
(
'torchaudio._torchaudio'
)
...
...
torchaudio/csrc/sox/io.cpp
View file @
b152ee61
...
@@ -14,11 +14,13 @@ SignalInfo::SignalInfo(
...
@@ -14,11 +14,13 @@ SignalInfo::SignalInfo(
const
int64_t
sample_rate_
,
const
int64_t
sample_rate_
,
const
int64_t
num_channels_
,
const
int64_t
num_channels_
,
const
int64_t
num_frames_
,
const
int64_t
num_frames_
,
const
int64_t
bits_per_sample_
)
const
int64_t
bits_per_sample_
,
const
std
::
string
encoding_
)
:
sample_rate
(
sample_rate_
),
:
sample_rate
(
sample_rate_
),
num_channels
(
num_channels_
),
num_channels
(
num_channels_
),
num_frames
(
num_frames_
),
num_frames
(
num_frames_
),
bits_per_sample
(
bits_per_sample_
){};
bits_per_sample
(
bits_per_sample_
),
encoding
(
encoding_
){};
int64_t
SignalInfo
::
getSampleRate
()
const
{
int64_t
SignalInfo
::
getSampleRate
()
const
{
return
sample_rate
;
return
sample_rate
;
...
@@ -36,6 +38,45 @@ int64_t SignalInfo::getBitsPerSample() const {
...
@@ -36,6 +38,45 @@ int64_t SignalInfo::getBitsPerSample() const {
return
bits_per_sample
;
return
bits_per_sample
;
}
}
std
::
string
SignalInfo
::
getEncoding
()
const
{
return
encoding
;
}
namespace
{
std
::
string
get_encoding
(
sox_encoding_t
encoding
)
{
switch
(
encoding
)
{
case
SOX_ENCODING_UNKNOWN
:
return
"UNKNOWN"
;
case
SOX_ENCODING_SIGN2
:
return
"PCM_S"
;
case
SOX_ENCODING_UNSIGNED
:
return
"PCM_U"
;
case
SOX_ENCODING_FLOAT
:
return
"PCM_F"
;
case
SOX_ENCODING_FLAC
:
return
"FLAC"
;
case
SOX_ENCODING_ULAW
:
return
"ULAW"
;
case
SOX_ENCODING_ALAW
:
return
"ALAW"
;
case
SOX_ENCODING_MP3
:
return
"MP3"
;
case
SOX_ENCODING_VORBIS
:
return
"VORBIS"
;
case
SOX_ENCODING_AMR_WB
:
return
"AMR_WB"
;
case
SOX_ENCODING_AMR_NB
:
return
"AMR_NB"
;
case
SOX_ENCODING_OPUS
:
return
"OPUS"
;
default:
return
"UNKNOWN"
;
}
}
}
// namespace
c10
::
intrusive_ptr
<
SignalInfo
>
get_info_file
(
c10
::
intrusive_ptr
<
SignalInfo
>
get_info_file
(
const
std
::
string
&
path
,
const
std
::
string
&
path
,
c10
::
optional
<
std
::
string
>&
format
)
{
c10
::
optional
<
std
::
string
>&
format
)
{
...
@@ -53,7 +94,8 @@ c10::intrusive_ptr<SignalInfo> get_info_file(
...
@@ -53,7 +94,8 @@ c10::intrusive_ptr<SignalInfo> get_info_file(
static_cast
<
int64_t
>
(
sf
->
signal
.
rate
),
static_cast
<
int64_t
>
(
sf
->
signal
.
rate
),
static_cast
<
int64_t
>
(
sf
->
signal
.
channels
),
static_cast
<
int64_t
>
(
sf
->
signal
.
channels
),
static_cast
<
int64_t
>
(
sf
->
signal
.
length
/
sf
->
signal
.
channels
),
static_cast
<
int64_t
>
(
sf
->
signal
.
length
/
sf
->
signal
.
channels
),
static_cast
<
int64_t
>
(
sf
->
encoding
.
bits_per_sample
));
static_cast
<
int64_t
>
(
sf
->
encoding
.
bits_per_sample
),
get_encoding
(
sf
->
encoding
.
encoding
));
}
}
namespace
{
namespace
{
...
@@ -157,7 +199,7 @@ void save_audio_file(
...
@@ -157,7 +199,7 @@ void save_audio_file(
#ifdef TORCH_API_INCLUDE_EXTENSION_H
#ifdef TORCH_API_INCLUDE_EXTENSION_H
std
::
tuple
<
int64_t
,
int64_t
,
int64_t
,
int64_t
>
get_info_fileobj
(
std
::
tuple
<
int64_t
,
int64_t
,
int64_t
,
int64_t
,
std
::
string
>
get_info_fileobj
(
py
::
object
fileobj
,
py
::
object
fileobj
,
c10
::
optional
<
std
::
string
>&
format
)
{
c10
::
optional
<
std
::
string
>&
format
)
{
// Prepare in-memory file object
// Prepare in-memory file object
...
@@ -202,9 +244,10 @@ std::tuple<int64_t, int64_t, int64_t, int64_t> get_info_fileobj(
...
@@ -202,9 +244,10 @@ std::tuple<int64_t, int64_t, int64_t, int64_t> get_info_fileobj(
return
std
::
make_tuple
(
return
std
::
make_tuple
(
static_cast
<
int64_t
>
(
sf
->
signal
.
rate
),
static_cast
<
int64_t
>
(
sf
->
signal
.
rate
),
static_cast
<
int64_t
>
(
sf
->
signal
.
channels
),
static_cast
<
int64_t
>
(
sf
->
signal
.
length
/
sf
->
signal
.
channels
),
static_cast
<
int64_t
>
(
sf
->
signal
.
length
/
sf
->
signal
.
channels
),
static_cast
<
int64_t
>
(
sf
->
encoding
.
bits_per_sample
));
static_cast
<
int64_t
>
(
sf
->
signal
.
channels
),
static_cast
<
int64_t
>
(
sf
->
encoding
.
bits_per_sample
),
get_encoding
(
sf
->
encoding
.
encoding
));
}
}
std
::
tuple
<
torch
::
Tensor
,
int64_t
>
load_audio_fileobj
(
std
::
tuple
<
torch
::
Tensor
,
int64_t
>
load_audio_fileobj
(
...
...
torchaudio/csrc/sox/io.h
View file @
b152ee61
...
@@ -16,16 +16,19 @@ struct SignalInfo : torch::CustomClassHolder {
...
@@ -16,16 +16,19 @@ struct SignalInfo : torch::CustomClassHolder {
int64_t
num_channels
;
int64_t
num_channels
;
int64_t
num_frames
;
int64_t
num_frames
;
int64_t
bits_per_sample
;
int64_t
bits_per_sample
;
std
::
string
encoding
;
SignalInfo
(
SignalInfo
(
const
int64_t
sample_rate_
,
const
int64_t
sample_rate_
,
const
int64_t
num_channels_
,
const
int64_t
num_channels_
,
const
int64_t
num_frames_
,
const
int64_t
num_frames_
,
const
int64_t
bits_per_sample_
);
const
int64_t
bits_per_sample_
,
const
std
::
string
encoding_
);
int64_t
getSampleRate
()
const
;
int64_t
getSampleRate
()
const
;
int64_t
getNumChannels
()
const
;
int64_t
getNumChannels
()
const
;
int64_t
getNumFrames
()
const
;
int64_t
getNumFrames
()
const
;
int64_t
getBitsPerSample
()
const
;
int64_t
getBitsPerSample
()
const
;
std
::
string
getEncoding
()
const
;
};
};
c10
::
intrusive_ptr
<
SignalInfo
>
get_info_file
(
c10
::
intrusive_ptr
<
SignalInfo
>
get_info_file
(
...
@@ -51,7 +54,7 @@ void save_audio_file(
...
@@ -51,7 +54,7 @@ void save_audio_file(
#ifdef TORCH_API_INCLUDE_EXTENSION_H
#ifdef TORCH_API_INCLUDE_EXTENSION_H
std
::
tuple
<
int64_t
,
int64_t
,
int64_t
,
int64_t
>
get_info_fileobj
(
std
::
tuple
<
int64_t
,
int64_t
,
int64_t
,
int64_t
,
std
::
string
>
get_info_fileobj
(
py
::
object
fileobj
,
py
::
object
fileobj
,
c10
::
optional
<
std
::
string
>&
format
);
c10
::
optional
<
std
::
string
>&
format
);
...
...
torchaudio/csrc/sox/register.cpp
View file @
b152ee61
...
@@ -45,7 +45,8 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
...
@@ -45,7 +45,8 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
.
def
(
"get_num_frames"
,
&
torchaudio
::
sox_io
::
SignalInfo
::
getNumFrames
)
.
def
(
"get_num_frames"
,
&
torchaudio
::
sox_io
::
SignalInfo
::
getNumFrames
)
.
def
(
.
def
(
"get_bits_per_sample"
,
"get_bits_per_sample"
,
&
torchaudio
::
sox_io
::
SignalInfo
::
getBitsPerSample
);
&
torchaudio
::
sox_io
::
SignalInfo
::
getBitsPerSample
)
.
def
(
"get_encoding"
,
&
torchaudio
::
sox_io
::
SignalInfo
::
getEncoding
);
m
.
def
(
"torchaudio::sox_io_get_info"
,
&
torchaudio
::
sox_io
::
get_info_file
);
m
.
def
(
"torchaudio::sox_io_get_info"
,
&
torchaudio
::
sox_io
::
get_info_file
);
m
.
def
(
m
.
def
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment