Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
2994ce2e
Unverified
Commit
2994ce2e
authored
Oct 09, 2023
by
moto
Committed by
GitHub
Oct 09, 2023
Browse files
Add bytes support to StreamReader (#3642)
Addresses
https://github.com/pytorch/audio/issues/3640
parent
ec13a815
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
92 additions
and
3 deletions
+92
-3
src/torchaudio/csrc/ffmpeg/pybind/pybind.cpp
src/torchaudio/csrc/ffmpeg/pybind/pybind.cpp
+80
-0
src/torchaudio/io/_stream_reader.py
src/torchaudio/io/_stream_reader.py
+8
-2
test/torchaudio_unittest/io/stream_reader_test.py
test/torchaudio_unittest/io/stream_reader_test.py
+4
-1
No files found.
src/torchaudio/csrc/ffmpeg/pybind/pybind.cpp
View file @
2994ce2e
...
@@ -186,6 +186,61 @@ struct StreamWriterFileObj : private FileObj, public StreamWriterCustomIO {
...
@@ -186,6 +186,61 @@ struct StreamWriterFileObj : private FileObj, public StreamWriterCustomIO {
py
::
hasattr
(
fileobj
,
"seek"
)
?
&
seek_func
:
nullptr
)
{}
py
::
hasattr
(
fileobj
,
"seek"
)
?
&
seek_func
:
nullptr
)
{}
};
};
//////////////////////////////////////////////////////////////////////////////
// StreamReader/Writer Bytes
//////////////////////////////////////////////////////////////////////////////
struct
BytesWrapper
{
std
::
string_view
src
;
size_t
index
=
0
;
};
static
int
read_bytes
(
void
*
opaque
,
uint8_t
*
buf
,
int
buf_size
)
{
BytesWrapper
*
wrapper
=
static_cast
<
BytesWrapper
*>
(
opaque
);
auto
num_read
=
FFMIN
(
wrapper
->
src
.
size
()
-
wrapper
->
index
,
buf_size
);
if
(
num_read
==
0
)
{
return
AVERROR_EOF
;
}
auto
head
=
wrapper
->
src
.
data
()
+
wrapper
->
index
;
memcpy
(
buf
,
head
,
num_read
);
wrapper
->
index
+=
num_read
;
return
num_read
;
}
static
int64_t
seek_bytes
(
void
*
opaque
,
int64_t
offset
,
int
whence
)
{
BytesWrapper
*
wrapper
=
static_cast
<
BytesWrapper
*>
(
opaque
);
if
(
whence
==
AVSEEK_SIZE
)
{
return
wrapper
->
src
.
size
();
}
if
(
whence
==
SEEK_SET
)
{
wrapper
->
index
=
offset
;
}
else
if
(
whence
==
SEEK_CUR
)
{
wrapper
->
index
+=
offset
;
}
else
if
(
whence
==
SEEK_END
)
{
wrapper
->
index
=
wrapper
->
src
.
size
()
+
offset
;
}
else
{
TORCH_INTERNAL_ASSERT
(
false
,
"Unexpected whence value: "
,
whence
);
}
return
static_cast
<
int64_t
>
(
wrapper
->
index
);
}
struct
StreamReaderBytes
:
private
BytesWrapper
,
public
StreamReaderCustomIO
{
StreamReaderBytes
(
std
::
string_view
src
,
const
c10
::
optional
<
std
::
string
>&
format
,
const
c10
::
optional
<
std
::
map
<
std
::
string
,
std
::
string
>>&
option
,
int64_t
buffer_size
)
:
BytesWrapper
{
src
},
StreamReaderCustomIO
(
this
,
format
,
buffer_size
,
read_bytes
,
seek_bytes
,
option
)
{}
};
#ifndef TORCHAUDIO_FFMPEG_EXT_NAME
#ifndef TORCHAUDIO_FFMPEG_EXT_NAME
#error TORCHAUDIO_FFMPEG_EXT_NAME must be defined.
#error TORCHAUDIO_FFMPEG_EXT_NAME must be defined.
#endif
#endif
...
@@ -353,6 +408,31 @@ PYBIND11_MODULE(TORCHAUDIO_FFMPEG_EXT_NAME, m) {
...
@@ -353,6 +408,31 @@ PYBIND11_MODULE(TORCHAUDIO_FFMPEG_EXT_NAME, m) {
.
def
(
"fill_buffer"
,
&
StreamReaderFileObj
::
fill_buffer
)
.
def
(
"fill_buffer"
,
&
StreamReaderFileObj
::
fill_buffer
)
.
def
(
"is_buffer_ready"
,
&
StreamReaderFileObj
::
is_buffer_ready
)
.
def
(
"is_buffer_ready"
,
&
StreamReaderFileObj
::
is_buffer_ready
)
.
def
(
"pop_chunks"
,
&
StreamReaderFileObj
::
pop_chunks
);
.
def
(
"pop_chunks"
,
&
StreamReaderFileObj
::
pop_chunks
);
py
::
class_
<
StreamReaderBytes
>
(
m
,
"StreamReaderBytes"
,
py
::
module_local
())
.
def
(
py
::
init
<
std
::
string_view
,
const
c10
::
optional
<
std
::
string
>&
,
const
c10
::
optional
<
OptionDict
>&
,
int64_t
>
())
.
def
(
"num_src_streams"
,
&
StreamReaderBytes
::
num_src_streams
)
.
def
(
"num_out_streams"
,
&
StreamReaderBytes
::
num_out_streams
)
.
def
(
"find_best_audio_stream"
,
&
StreamReaderBytes
::
find_best_audio_stream
)
.
def
(
"find_best_video_stream"
,
&
StreamReaderBytes
::
find_best_video_stream
)
.
def
(
"get_metadata"
,
&
StreamReaderBytes
::
get_metadata
)
.
def
(
"get_src_stream_info"
,
&
StreamReaderBytes
::
get_src_stream_info
)
.
def
(
"get_out_stream_info"
,
&
StreamReaderBytes
::
get_out_stream_info
)
.
def
(
"seek"
,
&
StreamReaderBytes
::
seek
)
.
def
(
"add_audio_stream"
,
&
StreamReaderBytes
::
add_audio_stream
)
.
def
(
"add_video_stream"
,
&
StreamReaderBytes
::
add_video_stream
)
.
def
(
"remove_stream"
,
&
StreamReaderBytes
::
remove_stream
)
.
def
(
"process_packet"
,
py
::
overload_cast
<
const
c10
::
optional
<
double
>&
,
const
double
>
(
&
StreamReader
::
process_packet
))
.
def
(
"process_all_packets"
,
&
StreamReaderBytes
::
process_all_packets
)
.
def
(
"fill_buffer"
,
&
StreamReaderBytes
::
fill_buffer
)
.
def
(
"is_buffer_ready"
,
&
StreamReaderBytes
::
is_buffer_ready
)
.
def
(
"pop_chunks"
,
&
StreamReaderBytes
::
pop_chunks
);
}
}
}
// namespace
}
// namespace
...
...
src/torchaudio/io/_stream_reader.py
View file @
2994ce2e
...
@@ -10,6 +10,7 @@ from torch.utils._pytree import tree_map
...
@@ -10,6 +10,7 @@ from torch.utils._pytree import tree_map
if
torchaudio
.
_extension
.
_FFMPEG_EXT
is
not
None
:
if
torchaudio
.
_extension
.
_FFMPEG_EXT
is
not
None
:
_StreamReader
=
torchaudio
.
_extension
.
_FFMPEG_EXT
.
StreamReader
_StreamReader
=
torchaudio
.
_extension
.
_FFMPEG_EXT
.
StreamReader
_StreamReaderBytes
=
torchaudio
.
_extension
.
_FFMPEG_EXT
.
StreamReaderBytes
_StreamReaderFileObj
=
torchaudio
.
_extension
.
_FFMPEG_EXT
.
StreamReaderFileObj
_StreamReaderFileObj
=
torchaudio
.
_extension
.
_FFMPEG_EXT
.
StreamReaderFileObj
...
@@ -447,12 +448,14 @@ class StreamReader:
...
@@ -447,12 +448,14 @@ class StreamReader:
For the detailed usage of this class, please refer to the tutorial.
For the detailed usage of this class, please refer to the tutorial.
Args:
Args:
src (str, path-like or file-like object): The media source.
src (str, path-like
, bytes
or file-like object): The media source.
If string-type, it must be a resource indicator that FFmpeg can
If string-type, it must be a resource indicator that FFmpeg can
handle. This includes a file path, URL, device identifier or
handle. This includes a file path, URL, device identifier or
filter expression. The supported value depends on the FFmpeg found
filter expression. The supported value depends on the FFmpeg found
in the system.
in the system.
If bytes, it must be an encoded media data in contiguous memory.
If file-like object, it must support `read` method with the signature
If file-like object, it must support `read` method with the signature
`read(size: int) -> bytes`.
`read(size: int) -> bytes`.
Additionally, if the file-like object has `seek` method, it uses
Additionally, if the file-like object has `seek` method, it uses
...
@@ -518,7 +521,10 @@ class StreamReader:
...
@@ -518,7 +521,10 @@ class StreamReader:
option
:
Optional
[
Dict
[
str
,
str
]]
=
None
,
option
:
Optional
[
Dict
[
str
,
str
]]
=
None
,
buffer_size
:
int
=
4096
,
buffer_size
:
int
=
4096
,
):
):
if
hasattr
(
src
,
"read"
):
self
.
src
=
src
if
isinstance
(
src
,
bytes
):
self
.
_be
=
_StreamReaderBytes
(
src
,
format
,
option
,
buffer_size
)
elif
hasattr
(
src
,
"read"
):
self
.
_be
=
_StreamReaderFileObj
(
src
,
format
,
option
,
buffer_size
)
self
.
_be
=
_StreamReaderFileObj
(
src
,
format
,
option
,
buffer_size
)
else
:
else
:
self
.
_be
=
_StreamReader
(
str
(
src
),
format
,
option
)
self
.
_be
=
_StreamReader
(
str
(
src
),
format
,
option
)
...
...
test/torchaudio_unittest/io/stream_reader_test.py
View file @
2994ce2e
...
@@ -77,7 +77,7 @@ class ChunkTensorTest(TorchaudioTestCase):
...
@@ -77,7 +77,7 @@ class ChunkTensorTest(TorchaudioTestCase):
# Helper decorator and Mixin to duplicate the tests for fileobj
# Helper decorator and Mixin to duplicate the tests for fileobj
_media_source
=
parameterized_class
(
_media_source
=
parameterized_class
(
(
"test_type"
,),
(
"test_type"
,),
[(
"str"
,),
(
"fileobj"
,)],
[(
"str"
,),
(
"fileobj"
,)
,
(
"bytes"
,)
],
class_name_func
=
lambda
cls
,
_
,
params
:
f
'
{
cls
.
__name__
}
_
{
params
[
"test_type"
]
}
'
,
class_name_func
=
lambda
cls
,
_
,
params
:
f
'
{
cls
.
__name__
}
_
{
params
[
"test_type"
]
}
'
,
)
)
...
@@ -95,6 +95,9 @@ class _MediaSourceMixin:
...
@@ -95,6 +95,9 @@ class _MediaSourceMixin:
self
.
src
=
path
self
.
src
=
path
elif
self
.
test_type
==
"fileobj"
:
elif
self
.
test_type
==
"fileobj"
:
self
.
src
=
open
(
path
,
"rb"
)
self
.
src
=
open
(
path
,
"rb"
)
elif
self
.
test_type
==
"bytes"
:
with
open
(
path
,
"rb"
)
as
f
:
self
.
src
=
f
.
read
()
return
self
.
src
return
self
.
src
def
tearDown
(
self
):
def
tearDown
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment