Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
41c76a17
Unverified
Commit
41c76a17
authored
Jan 27, 2021
by
moto
Committed by
GitHub
Jan 27, 2021
Browse files
Support file-like object in info (#1108)
parent
22e7e877
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
364 additions
and
15 deletions
+364
-15
test/torchaudio_unittest/soundfile_backend/info_test.py
test/torchaudio_unittest/soundfile_backend/info_test.py
+63
-0
test/torchaudio_unittest/sox_io_backend/info_test.py
test/torchaudio_unittest/sox_io_backend/info_test.py
+150
-1
torchaudio/backend/_soundfile_backend.py
torchaudio/backend/_soundfile_backend.py
+6
-4
torchaudio/backend/sox_io_backend.py
torchaudio/backend/sox_io_backend.py
+42
-7
torchaudio/csrc/pybind.cpp
torchaudio/csrc/pybind.cpp
+4
-0
torchaudio/csrc/sox/io.cpp
torchaudio/csrc/sox/io.cpp
+51
-1
torchaudio/csrc/sox/io.h
torchaudio/csrc/sox/io.h
+5
-1
torchaudio/csrc/sox/register.cpp
torchaudio/csrc/sox/register.cpp
+1
-1
torchaudio/csrc/sox/utils.cpp
torchaudio/csrc/sox/utils.cpp
+32
-0
torchaudio/csrc/sox/utils.h
torchaudio/csrc/sox/utils.h
+10
-0
No files found.
test/torchaudio_unittest/soundfile_backend/info_test.py
View file @
41c76a17
from
unittest.mock
import
patch
from
unittest.mock
import
patch
import
warnings
import
warnings
import
tarfile
import
torch
import
torch
from
torchaudio.backend
import
_soundfile_backend
as
soundfile_backend
from
torchaudio.backend
import
_soundfile_backend
as
soundfile_backend
...
@@ -125,3 +126,65 @@ class TestInfo(TempDirMixin, PytorchTestCase):
...
@@ -125,3 +126,65 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert
len
(
w
)
==
1
assert
len
(
w
)
==
1
assert
"UNSEEN_SUBTYPE subtype is unknown to TorchAudio"
in
str
(
w
[
-
1
].
message
)
assert
"UNSEEN_SUBTYPE subtype is unknown to TorchAudio"
in
str
(
w
[
-
1
].
message
)
assert
info
.
bits_per_sample
==
0
assert
info
.
bits_per_sample
==
0
@
skipIfNoModule
(
"soundfile"
)
class
TestFileObject
(
TempDirMixin
,
PytorchTestCase
):
def
_test_fileobj
(
self
,
ext
,
subtype
,
bits_per_sample
):
"""Query audio via file-like object works"""
duration
=
2
sample_rate
=
16000
num_channels
=
2
num_frames
=
sample_rate
*
duration
path
=
self
.
get_temp_path
(
f
'test.
{
ext
}
'
)
data
=
torch
.
randn
(
num_frames
,
num_channels
).
numpy
()
soundfile
.
write
(
path
,
data
,
sample_rate
,
subtype
=
subtype
)
with
open
(
path
,
'rb'
)
as
fileobj
:
info
=
soundfile_backend
.
info
(
fileobj
)
assert
info
.
sample_rate
==
sample_rate
assert
info
.
num_frames
==
num_frames
assert
info
.
num_channels
==
num_channels
assert
info
.
bits_per_sample
==
bits_per_sample
def
test_fileobj_wav
(
self
):
"""Loading audio via file-like object works"""
self
.
_test_fileobj
(
'wav'
,
'PCM_16'
,
16
)
@
skipIfFormatNotSupported
(
"FLAC"
)
def
test_fileobj_flac
(
self
):
"""Loading audio via file-like object works"""
self
.
_test_fileobj
(
'flac'
,
'PCM_16'
,
16
)
def
_test_tarobj
(
self
,
ext
,
subtype
,
bits_per_sample
):
"""Query compressed audio via file-like object works"""
duration
=
2
sample_rate
=
16000
num_channels
=
2
num_frames
=
sample_rate
*
duration
audio_file
=
f
'test.
{
ext
}
'
audio_path
=
self
.
get_temp_path
(
audio_file
)
archive_path
=
self
.
get_temp_path
(
'archive.tar.gz'
)
data
=
torch
.
randn
(
num_frames
,
num_channels
).
numpy
()
soundfile
.
write
(
audio_path
,
data
,
sample_rate
,
subtype
=
subtype
)
with
tarfile
.
TarFile
(
archive_path
,
'w'
)
as
tarobj
:
tarobj
.
add
(
audio_path
,
arcname
=
audio_file
)
with
tarfile
.
TarFile
(
archive_path
,
'r'
)
as
tarobj
:
fileobj
=
tarobj
.
extractfile
(
audio_file
)
info
=
soundfile_backend
.
info
(
fileobj
)
assert
info
.
sample_rate
==
sample_rate
assert
info
.
num_frames
==
num_frames
assert
info
.
num_channels
==
num_channels
assert
info
.
bits_per_sample
==
bits_per_sample
def
test_tarobj_wav
(
self
):
"""Query compressed audio via file-like object works"""
self
.
_test_tarobj
(
'wav'
,
'PCM_16'
,
16
)
@
skipIfFormatNotSupported
(
"FLAC"
)
def
test_tarobj_flac
(
self
):
"""Query compressed audio via file-like object works"""
self
.
_test_tarobj
(
'flac'
,
'PCM_16'
,
16
)
test/torchaudio_unittest/sox_io_backend/info_test.py
View file @
41c76a17
import
io
import
itertools
import
itertools
from
parameterized
import
p
ar
ameterized
import
t
ar
file
from
parameterized
import
parameterized
from
torchaudio.backend
import
sox_io_backend
from
torchaudio.backend
import
sox_io_backend
from
torchaudio._internal
import
module_utils
as
_mod_utils
from
torchaudio_unittest.common_utils
import
(
from
torchaudio_unittest.common_utils
import
(
TempDirMixin
,
TempDirMixin
,
HttpServerMixin
,
PytorchTestCase
,
PytorchTestCase
,
skipIfNoExec
,
skipIfNoExec
,
skipIfNoExtension
,
skipIfNoExtension
,
skipIfNoModule
,
get_asset_path
,
get_asset_path
,
get_wav_data
,
get_wav_data
,
save_wav
,
save_wav
,
...
@@ -18,6 +23,10 @@ from .common import (
...
@@ -18,6 +23,10 @@ from .common import (
)
)
if
_mod_utils
.
is_module_available
(
"requests"
):
import
requests
@
skipIfNoExec
(
'sox'
)
@
skipIfNoExec
(
'sox'
)
@
skipIfNoExtension
@
skipIfNoExtension
class
TestInfo
(
TempDirMixin
,
PytorchTestCase
):
class
TestInfo
(
TempDirMixin
,
PytorchTestCase
):
...
@@ -197,3 +206,143 @@ class TestLoadWithoutExtension(PytorchTestCase):
...
@@ -197,3 +206,143 @@ class TestLoadWithoutExtension(PytorchTestCase):
sinfo
=
sox_io_backend
.
info
(
path
,
format
=
"mp3"
)
sinfo
=
sox_io_backend
.
info
(
path
,
format
=
"mp3"
)
assert
sinfo
.
sample_rate
==
16000
assert
sinfo
.
sample_rate
==
16000
assert
sinfo
.
bits_per_sample
==
0
# bit_per_sample is irrelevant for compressed formats
assert
sinfo
.
bits_per_sample
==
0
# bit_per_sample is irrelevant for compressed formats
@
skipIfNoExtension
@
skipIfNoExec
(
'sox'
)
class
TestFileObject
(
TempDirMixin
,
PytorchTestCase
):
@
parameterized
.
expand
([
(
'wav'
,
32
),
(
'mp3'
,
0
),
(
'flac'
,
24
),
(
'vorbis'
,
0
),
(
'amb'
,
32
),
])
def
test_fileobj
(
self
,
ext
,
bits_per_sample
):
"""Querying audio via file object works"""
sample_rate
=
16000
num_channels
=
2
duration
=
3
format_
=
ext
if
ext
in
[
'mp3'
]
else
None
path
=
self
.
get_temp_path
(
f
'test.
{
ext
}
'
)
sox_utils
.
gen_audio_file
(
path
,
sample_rate
,
num_channels
=
2
,
duration
=
duration
)
with
open
(
path
,
'rb'
)
as
fileobj
:
sinfo
=
sox_io_backend
.
info
(
fileobj
,
format_
)
assert
sinfo
.
sample_rate
==
sample_rate
assert
sinfo
.
num_channels
==
num_channels
if
ext
not
in
[
'mp3'
,
'vorbis'
]:
# these container formats do not have length info
assert
sinfo
.
num_frames
==
sample_rate
*
duration
assert
sinfo
.
bits_per_sample
==
bits_per_sample
def
_test_bytesio
(
self
,
ext
,
bits_per_sample
,
duration
):
sample_rate
=
16000
num_channels
=
2
format_
=
ext
if
ext
in
[
'mp3'
]
else
None
path
=
self
.
get_temp_path
(
f
'test.
{
ext
}
'
)
sox_utils
.
gen_audio_file
(
path
,
sample_rate
,
num_channels
=
2
,
duration
=
duration
)
with
open
(
path
,
'rb'
)
as
file_
:
fileobj
=
io
.
BytesIO
(
file_
.
read
())
sinfo
=
sox_io_backend
.
info
(
fileobj
,
format_
)
assert
sinfo
.
sample_rate
==
sample_rate
assert
sinfo
.
num_channels
==
num_channels
if
ext
not
in
[
'mp3'
,
'vorbis'
]:
# these container formats do not have length info
assert
sinfo
.
num_frames
==
sample_rate
*
duration
assert
sinfo
.
bits_per_sample
==
bits_per_sample
@
parameterized
.
expand
([
(
'wav'
,
32
),
(
'mp3'
,
0
),
(
'flac'
,
24
),
(
'vorbis'
,
0
),
(
'amb'
,
32
),
])
def
test_bytesio
(
self
,
ext
,
bits_per_sample
):
"""Querying audio via ByteIO object works"""
self
.
_test_bytesio
(
ext
,
bits_per_sample
,
duration
=
3
)
@
parameterized
.
expand
([
(
'wav'
,
32
),
(
'mp3'
,
0
),
(
'flac'
,
24
),
(
'vorbis'
,
0
),
(
'amb'
,
32
),
])
def
test_bytesio_tiny
(
self
,
ext
,
bits_per_sample
):
"""Querying audio via ByteIO object works for small data"""
self
.
_test_bytesio
(
ext
,
bits_per_sample
,
duration
=
1
/
1600
)
@
parameterized
.
expand
([
(
'wav'
,
32
),
(
'mp3'
,
0
),
(
'flac'
,
24
),
(
'vorbis'
,
0
),
(
'amb'
,
32
),
])
def
test_tarfile
(
self
,
ext
,
bits_per_sample
):
"""Querying compressed audio via file-like object works"""
sample_rate
=
16000
num_channels
=
2
duration
=
3
format_
=
ext
if
ext
in
[
'mp3'
]
else
None
audio_file
=
f
'test.
{
ext
}
'
audio_path
=
self
.
get_temp_path
(
audio_file
)
archive_path
=
self
.
get_temp_path
(
'archive.tar.gz'
)
sox_utils
.
gen_audio_file
(
audio_path
,
sample_rate
,
num_channels
=
num_channels
,
duration
=
duration
)
with
tarfile
.
TarFile
(
archive_path
,
'w'
)
as
tarobj
:
tarobj
.
add
(
audio_path
,
arcname
=
audio_file
)
with
tarfile
.
TarFile
(
archive_path
,
'r'
)
as
tarobj
:
fileobj
=
tarobj
.
extractfile
(
audio_file
)
sinfo
=
sox_io_backend
.
info
(
fileobj
,
format
=
format_
)
assert
sinfo
.
sample_rate
==
sample_rate
assert
sinfo
.
num_channels
==
num_channels
if
ext
not
in
[
'mp3'
,
'vorbis'
]:
# these container formats do not have length info
assert
sinfo
.
num_frames
==
sample_rate
*
duration
assert
sinfo
.
bits_per_sample
==
bits_per_sample
@
skipIfNoExtension
@
skipIfNoExec
(
'sox'
)
@
skipIfNoModule
(
"requests"
)
class
TestFileObjectHttp
(
HttpServerMixin
,
PytorchTestCase
):
@
parameterized
.
expand
([
(
'wav'
,
32
),
(
'mp3'
,
0
),
(
'flac'
,
24
),
(
'vorbis'
,
0
),
(
'amb'
,
32
),
])
def
test_requests
(
self
,
ext
,
bits_per_sample
):
"""Querying compressed audio via requests works"""
sample_rate
=
16000
num_channels
=
2
duration
=
3
format_
=
ext
if
ext
in
[
'mp3'
]
else
None
audio_file
=
f
'test.
{
ext
}
'
audio_path
=
self
.
get_temp_path
(
audio_file
)
sox_utils
.
gen_audio_file
(
audio_path
,
sample_rate
,
num_channels
=
num_channels
,
duration
=
duration
)
url
=
self
.
get_url
(
audio_file
)
with
requests
.
get
(
url
,
stream
=
True
)
as
resp
:
sinfo
=
sox_io_backend
.
info
(
resp
.
raw
,
format
=
format_
)
assert
sinfo
.
sample_rate
==
sample_rate
assert
sinfo
.
num_channels
==
num_channels
if
ext
not
in
[
'mp3'
,
'vorbis'
]:
# these container formats do not have length info
assert
sinfo
.
num_frames
==
sample_rate
*
duration
assert
sinfo
.
bits_per_sample
==
bits_per_sample
torchaudio/backend/_soundfile_backend.py
View file @
41c76a17
...
@@ -55,10 +55,12 @@ def info(filepath: str, format: Optional[str] = None) -> AudioMetaData:
...
@@ -55,10 +55,12 @@ def info(filepath: str, format: Optional[str] = None) -> AudioMetaData:
"""Get signal information of an audio file.
"""Get signal information of an audio file.
Args:
Args:
filepath (str or pathlib.Path): Path to audio file.
filepath (path-like object or file-like object):
This functionalso handles ``pathlib.Path`` objects, but is annotated as ``str``
Source of audio data.
for the consistency with "sox_io" backend, which has a restriction on type annotation
Note:
for TorchScript compiler compatiblity.
* This argument is intentionally annotated as ``str`` only,
for the consistency with "sox_io" backend, which has a restriction
on type annotation due to TorchScript compiler compatiblity.
format (str, optional):
format (str, optional):
Not used. PySoundFile does not accept format hint.
Not used. PySoundFile does not accept format hint.
...
...
torchaudio/backend/sox_io_backend.py
View file @
41c76a17
...
@@ -10,6 +10,26 @@ import torchaudio
...
@@ -10,6 +10,26 @@ import torchaudio
from
.common
import
AudioMetaData
from
.common
import
AudioMetaData
@
torch
.
jit
.
unused
def
_info
(
filepath
:
str
,
format
:
Optional
[
str
]
=
None
,
)
->
AudioMetaData
:
if
hasattr
(
filepath
,
'read'
):
sinfo
=
torchaudio
.
_torchaudio
.
get_info_fileobj
(
filepath
,
format
)
sample_rate
,
num_channels
,
num_frames
,
bits_per_sample
=
sinfo
return
AudioMetaData
(
sample_rate
,
num_frames
,
num_channels
,
bits_per_sample
)
sinfo
=
torch
.
ops
.
torchaudio
.
sox_io_get_info
(
os
.
fspath
(
filepath
),
format
)
return
AudioMetaData
(
sinfo
.
get_sample_rate
(),
sinfo
.
get_num_frames
(),
sinfo
.
get_num_channels
(),
sinfo
.
get_bits_per_sample
(),
)
@
_mod_utils
.
requires_module
(
'torchaudio._torchaudio'
)
@
_mod_utils
.
requires_module
(
'torchaudio._torchaudio'
)
def
info
(
def
info
(
filepath
:
str
,
filepath
:
str
,
...
@@ -18,9 +38,21 @@ def info(
...
@@ -18,9 +38,21 @@ def info(
"""Get signal information of an audio file.
"""Get signal information of an audio file.
Args:
Args:
filepath (str or pathlib.Path):
filepath (path-like object or file-like object):
Path to audio file. This function also handles ``pathlib.Path`` objects,
Source of audio data. When the function is not compiled by TorchScript,
but is annotated as ``str`` for TorchScript compatibility.
(e.g. ``torch.jit.script``), the following types are accepted;
* ``path-like``: file path
* ``file-like``: Object with ``read(size: int) -> bytes`` method,
which returns byte string of at most ``size`` length.
When the function is compiled by TorchScript, only ``str`` type is allowed.
Note:
* When the input type is file-like object, this function cannot
get the correct length (``num_samples``) for certain formats,
such as ``mp3`` and ``vorbis``.
In this case, the value of ``num_samples`` is ``0``.
* This argument is intentionally annotated as ``str`` only due to
TorchScript compiler compatibility.
format (str, optional):
format (str, optional):
Override the format detection with the given format.
Override the format detection with the given format.
Providing the argument might help when libsox can not infer the format
Providing the argument might help when libsox can not infer the format
...
@@ -29,10 +61,13 @@ def info(
...
@@ -29,10 +61,13 @@ def info(
Returns:
Returns:
AudioMetaData: Metadata of the given audio.
AudioMetaData: Metadata of the given audio.
"""
"""
# Cast to str in case type is `pathlib.Path`
if
not
torch
.
jit
.
is_scripting
():
filepath
=
str
(
filepath
)
return
_info
(
filepath
,
format
)
sinfo
=
torch
.
ops
.
torchaudio
.
sox_io_get_info
(
filepath
,
format
)
sinfo
=
torch
.
ops
.
torchaudio
.
sox_io_get_info
(
filepath
,
format
)
return
AudioMetaData
(
sinfo
.
get_sample_rate
(),
sinfo
.
get_num_frames
(),
sinfo
.
get_num_channels
(),
return
AudioMetaData
(
sinfo
.
get_sample_rate
(),
sinfo
.
get_num_frames
(),
sinfo
.
get_num_channels
(),
sinfo
.
get_bits_per_sample
())
sinfo
.
get_bits_per_sample
())
...
...
torchaudio/csrc/pybind.cpp
View file @
41c76a17
...
@@ -100,6 +100,10 @@ PYBIND11_MODULE(_torchaudio, m) {
...
@@ -100,6 +100,10 @@ PYBIND11_MODULE(_torchaudio, m) {
"get_info"
,
"get_info"
,
&
torch
::
audio
::
get_info
,
&
torch
::
audio
::
get_info
,
"Gets information about an audio file"
);
"Gets information about an audio file"
);
m
.
def
(
"get_info_fileobj"
,
&
torchaudio
::
sox_io
::
get_info_fileobj
,
"Get metadata of audio in file object."
);
m
.
def
(
m
.
def
(
"load_audio_fileobj"
,
"load_audio_fileobj"
,
&
torchaudio
::
sox_io
::
load_audio_fileobj
,
&
torchaudio
::
sox_io
::
load_audio_fileobj
,
...
...
torchaudio/csrc/sox/io.cpp
View file @
41c76a17
...
@@ -36,7 +36,7 @@ int64_t SignalInfo::getBitsPerSample() const {
...
@@ -36,7 +36,7 @@ int64_t SignalInfo::getBitsPerSample() const {
return
bits_per_sample
;
return
bits_per_sample
;
}
}
c10
::
intrusive_ptr
<
SignalInfo
>
get_info
(
c10
::
intrusive_ptr
<
SignalInfo
>
get_info
_file
(
const
std
::
string
&
path
,
const
std
::
string
&
path
,
c10
::
optional
<
std
::
string
>&
format
)
{
c10
::
optional
<
std
::
string
>&
format
)
{
SoxFormat
sf
(
sox_open_read
(
SoxFormat
sf
(
sox_open_read
(
...
@@ -149,6 +149,56 @@ void save_audio_file(
...
@@ -149,6 +149,56 @@ void save_audio_file(
#ifdef TORCH_API_INCLUDE_EXTENSION_H
#ifdef TORCH_API_INCLUDE_EXTENSION_H
std
::
tuple
<
int64_t
,
int64_t
,
int64_t
,
int64_t
>
get_info_fileobj
(
py
::
object
fileobj
,
c10
::
optional
<
std
::
string
>&
format
)
{
// Prepare in-memory file object
// When libsox opens a file, it also reads the header.
// When opening a file there are two functions that might touch FILE* (and the
// underlying buffer).
// * `auto_detect_format`
// https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/formats.c#L43
// * `startread` handler of detected format.
// https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/formats.c#L574
// To see the handler of a particular format, go to
// https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/<FORMAT>.c
// For example, voribs can be found
// https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/vorbis.c#L97-L158
//
// `auto_detect_format` function only requires 256 bytes, but format-dependant
// `startread` handler might require more data. In case of vorbis, the size of
// header is unbounded, but typically 4kB maximum.
//
// "The header size is unbounded, although for streaming a rule-of-thumb of
// 4kB or less is recommended (and Xiph.Org's Vorbis encoder follows this
// suggestion)."
//
// See:
// https://xiph.org/vorbis/doc/Vorbis_I_spec.html
auto
capacity
=
4096
;
std
::
string
buffer
(
capacity
,
'\0'
);
auto
*
buf
=
const_cast
<
char
*>
(
buffer
.
data
());
auto
num_read
=
read_fileobj
(
&
fileobj
,
capacity
,
buf
);
// If the file is shorter than 256, then libsox cannot read the header.
auto
buf_size
=
(
num_read
>
256
)
?
num_read
:
256
;
SoxFormat
sf
(
sox_open_mem_read
(
buf
,
buf_size
,
/*signal=*/
nullptr
,
/*encoding=*/
nullptr
,
/*filetype=*/
format
.
has_value
()
?
format
.
value
().
c_str
()
:
nullptr
));
// In case of streamed data, length can be 0
validate_input_file
(
sf
,
/*check_length=*/
false
);
return
std
::
make_tuple
(
static_cast
<
int64_t
>
(
sf
->
signal
.
rate
),
static_cast
<
int64_t
>
(
sf
->
signal
.
channels
),
static_cast
<
int64_t
>
(
sf
->
signal
.
length
/
sf
->
signal
.
channels
),
static_cast
<
int64_t
>
(
sf
->
encoding
.
bits_per_sample
));
}
std
::
tuple
<
torch
::
Tensor
,
int64_t
>
load_audio_fileobj
(
std
::
tuple
<
torch
::
Tensor
,
int64_t
>
load_audio_fileobj
(
py
::
object
fileobj
,
py
::
object
fileobj
,
c10
::
optional
<
int64_t
>&
frame_offset
,
c10
::
optional
<
int64_t
>&
frame_offset
,
...
...
torchaudio/csrc/sox/io.h
View file @
41c76a17
...
@@ -28,7 +28,7 @@ struct SignalInfo : torch::CustomClassHolder {
...
@@ -28,7 +28,7 @@ struct SignalInfo : torch::CustomClassHolder {
int64_t
getBitsPerSample
()
const
;
int64_t
getBitsPerSample
()
const
;
};
};
c10
::
intrusive_ptr
<
SignalInfo
>
get_info
(
c10
::
intrusive_ptr
<
SignalInfo
>
get_info
_file
(
const
std
::
string
&
path
,
const
std
::
string
&
path
,
c10
::
optional
<
std
::
string
>&
format
);
c10
::
optional
<
std
::
string
>&
format
);
...
@@ -50,6 +50,10 @@ void save_audio_file(
...
@@ -50,6 +50,10 @@ void save_audio_file(
#ifdef TORCH_API_INCLUDE_EXTENSION_H
#ifdef TORCH_API_INCLUDE_EXTENSION_H
std
::
tuple
<
int64_t
,
int64_t
,
int64_t
,
int64_t
>
get_info_fileobj
(
py
::
object
fileobj
,
c10
::
optional
<
std
::
string
>&
format
);
std
::
tuple
<
torch
::
Tensor
,
int64_t
>
load_audio_fileobj
(
std
::
tuple
<
torch
::
Tensor
,
int64_t
>
load_audio_fileobj
(
py
::
object
fileobj
,
py
::
object
fileobj
,
c10
::
optional
<
int64_t
>&
frame_offset
,
c10
::
optional
<
int64_t
>&
frame_offset
,
...
...
torchaudio/csrc/sox/register.cpp
View file @
41c76a17
...
@@ -47,7 +47,7 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
...
@@ -47,7 +47,7 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
"get_bits_per_sample"
,
"get_bits_per_sample"
,
&
torchaudio
::
sox_io
::
SignalInfo
::
getBitsPerSample
);
&
torchaudio
::
sox_io
::
SignalInfo
::
getBitsPerSample
);
m
.
def
(
"torchaudio::sox_io_get_info"
,
&
torchaudio
::
sox_io
::
get_info
);
m
.
def
(
"torchaudio::sox_io_get_info"
,
&
torchaudio
::
sox_io
::
get_info
_file
);
m
.
def
(
m
.
def
(
"torchaudio::sox_io_load_audio_file("
"torchaudio::sox_io_load_audio_file("
"str path,"
"str path,"
...
...
torchaudio/csrc/sox/utils.cpp
View file @
41c76a17
...
@@ -317,5 +317,37 @@ sox_encodinginfo_t get_encodinginfo(
...
@@ -317,5 +317,37 @@ sox_encodinginfo_t get_encodinginfo(
/*opposite_endian=*/
sox_false
};
/*opposite_endian=*/
sox_false
};
}
}
#ifdef TORCH_API_INCLUDE_EXTENSION_H
uint64_t
read_fileobj
(
py
::
object
*
fileobj
,
const
uint64_t
size
,
char
*
buffer
)
{
uint64_t
num_read
=
0
;
while
(
num_read
<
size
)
{
auto
request
=
size
-
num_read
;
auto
chunk
=
static_cast
<
std
::
string
>
(
static_cast
<
py
::
bytes
>
(
fileobj
->
attr
(
"read"
)(
request
)));
auto
chunk_len
=
chunk
.
length
();
if
(
chunk_len
==
0
)
{
break
;
}
if
(
chunk_len
>
request
)
{
std
::
ostringstream
message
;
message
<<
"Requested up to "
<<
request
<<
" bytes but, "
<<
"received "
<<
chunk_len
<<
" bytes. "
<<
"The given object does not confirm to read protocol of file object."
;
throw
std
::
runtime_error
(
message
.
str
());
}
std
::
cerr
<<
"req: "
<<
request
<<
", fetched: "
<<
chunk_len
<<
std
::
endl
;
std
::
cerr
<<
"buffer: "
<<
(
void
*
)
buffer
<<
std
::
endl
;
memcpy
(
buffer
,
chunk
.
data
(),
chunk_len
);
buffer
+=
chunk_len
;
num_read
+=
chunk_len
;
}
return
num_read
;
}
#endif // TORCH_API_INCLUDE_EXTENSION_H
}
// namespace sox_utils
}
// namespace sox_utils
}
// namespace torchaudio
}
// namespace torchaudio
torchaudio/csrc/sox/utils.h
View file @
41c76a17
...
@@ -4,6 +4,10 @@
...
@@ -4,6 +4,10 @@
#include <sox.h>
#include <sox.h>
#include <torch/script.h>
#include <torch/script.h>
#ifdef TORCH_API_INCLUDE_EXTENSION_H
#include <torch/extension.h>
#endif // TORCH_API_INCLUDE_EXTENSION_H
namespace
torchaudio
{
namespace
torchaudio
{
namespace
sox_utils
{
namespace
sox_utils
{
...
@@ -127,6 +131,12 @@ sox_encodinginfo_t get_encodinginfo(
...
@@ -127,6 +131,12 @@ sox_encodinginfo_t get_encodinginfo(
const
caffe2
::
TypeMeta
dtype
,
const
caffe2
::
TypeMeta
dtype
,
c10
::
optional
<
double
>&
compression
);
c10
::
optional
<
double
>&
compression
);
#ifdef TORCH_API_INCLUDE_EXTENSION_H
uint64_t
read_fileobj
(
py
::
object
*
fileobj
,
uint64_t
size
,
char
*
buffer
);
#endif // TORCH_API_INCLUDE_EXTENSION_H
}
// namespace sox_utils
}
// namespace sox_utils
}
// namespace torchaudio
}
// namespace torchaudio
#endif
#endif
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment