Unverified Commit 2994ce2e authored by moto's avatar moto Committed by GitHub
Browse files

Add bytes support to StreamReader (#3642)

Addresses https://github.com/pytorch/audio/issues/3640
parent ec13a815
...@@ -186,6 +186,61 @@ struct StreamWriterFileObj : private FileObj, public StreamWriterCustomIO { ...@@ -186,6 +186,61 @@ struct StreamWriterFileObj : private FileObj, public StreamWriterCustomIO {
py::hasattr(fileobj, "seek") ? &seek_func : nullptr) {} py::hasattr(fileobj, "seek") ? &seek_func : nullptr) {}
}; };
//////////////////////////////////////////////////////////////////////////////
// StreamReader/Writer Bytes
//////////////////////////////////////////////////////////////////////////////
// Read cursor over an in-memory byte sequence.
//
// `src` is a std::string_view, i.e. a non-owning view: the bytes it refers to
// are not copied, so they must stay alive for as long as this object is used.
struct BytesWrapper {
  // The bytes being read (not owned).
  std::string_view src;
  // Offset into `src` of the next byte to hand out; starts at the beginning.
  size_t index{0};
};
// Read callback for an in-memory byte source.
//
// Copies up to `buf_size` bytes, starting at the current cursor of the
// BytesWrapper pointed to by `opaque`, into `buf`, then advances the cursor
// by the number of bytes copied.
//
// Returns the number of bytes copied, or AVERROR_EOF when nothing can be
// read: the cursor is at or beyond the end of the data, or `buf_size` is not
// positive.
static int read_bytes(void* opaque, uint8_t* buf, int buf_size) {
  BytesWrapper* wrapper = static_cast<BytesWrapper*>(opaque);
  const size_t total = wrapper->src.size();
  // Compare before subtracting: `total - index` is unsigned and would wrap
  // around to a huge value if the cursor were ever past the end, which would
  // make the memcpy below read outside of `src`.
  if (buf_size <= 0 || wrapper->index >= total) {
    return AVERROR_EOF;
  }
  const size_t remaining = total - wrapper->index;
  const size_t requested = static_cast<size_t>(buf_size);
  const size_t num_read = remaining < requested ? remaining : requested;
  memcpy(buf, wrapper->src.data() + wrapper->index, num_read);
  wrapper->index += num_read;
  // num_read <= buf_size, so the value always fits in an int.
  return static_cast<int>(num_read);
}
// Seek callback for an in-memory byte source.
//
// `opaque` must point to a BytesWrapper. `whence` is one of:
//   AVSEEK_SIZE - report the total size of the data; the cursor is not moved.
//   SEEK_SET    - `offset` is relative to the start of the data.
//   SEEK_CUR    - `offset` is relative to the current cursor.
//   SEEK_END    - `offset` is relative to the end of the data.
// Any other value trips TORCH_INTERNAL_ASSERT.
//
// Returns the new cursor position (or the size for AVSEEK_SIZE). If the
// requested position lies outside [0, size], the cursor is left untouched and
// a negative value is returned.
static int64_t seek_bytes(void* opaque, int64_t offset, int whence) {
  BytesWrapper* wrapper = static_cast<BytesWrapper*>(opaque);
  const int64_t size = static_cast<int64_t>(wrapper->src.size());
  if (whence == AVSEEK_SIZE) {
    return size;
  }
  // Position that `offset` is measured from, kept signed so that negative
  // offsets are handled arithmetically instead of wrapping an unsigned index.
  int64_t base = 0;
  if (whence == SEEK_SET) {
    base = 0;
  } else if (whence == SEEK_CUR) {
    base = static_cast<int64_t>(wrapper->index);
  } else if (whence == SEEK_END) {
    base = size;
  } else {
    TORCH_INTERNAL_ASSERT(false, "Unexpected whence value: ", whence);
  }
  // Reject targets before the start or past the end. Written as bounds on
  // `offset` so that `base + offset` is only evaluated once it is known not
  // to overflow. An out-of-range cursor would otherwise make the next read
  // access memory outside of `src`.
  if (offset < -base || offset > size - base) {
    // A negative return value reports failure to the caller of this callback.
    return -1;
  }
  const int64_t target = base + offset;
  wrapper->index = static_cast<size_t>(target);
  return target;
}
// StreamReaderCustomIO specialization that reads from an in-memory byte
// sequence through the read_bytes / seek_bytes callbacks.
//
// The private BytesWrapper base holds the view and the read cursor. `src` is
// stored as a std::string_view without copying, so the bytes it refers to
// must outlive this object.
struct StreamReaderBytes : private BytesWrapper, public StreamReaderCustomIO {
  // src:         view over the bytes to read from (not copied).
  // format:      forwarded unchanged to StreamReaderCustomIO.
  // option:      forwarded unchanged to StreamReaderCustomIO.
  // buffer_size: forwarded unchanged to StreamReaderCustomIO.
  StreamReaderBytes(
      std::string_view src,
      const c10::optional<std::string>& format,
      const c10::optional<std::map<std::string, std::string>>& option,
      int64_t buffer_size)
      : BytesWrapper{src},
        StreamReaderCustomIO(
            // read_bytes / seek_bytes cast the opaque pointer back to
            // BytesWrapper*, so pass exactly that base subobject rather than
            // the derived `this`; the round trip then does not depend on
            // where the base is placed inside StreamReaderBytes.
            static_cast<BytesWrapper*>(this),
            format,
            buffer_size,
            read_bytes,
            seek_bytes,
            option) {}
};
#ifndef TORCHAUDIO_FFMPEG_EXT_NAME #ifndef TORCHAUDIO_FFMPEG_EXT_NAME
#error TORCHAUDIO_FFMPEG_EXT_NAME must be defined. #error TORCHAUDIO_FFMPEG_EXT_NAME must be defined.
#endif #endif
...@@ -353,6 +408,31 @@ PYBIND11_MODULE(TORCHAUDIO_FFMPEG_EXT_NAME, m) { ...@@ -353,6 +408,31 @@ PYBIND11_MODULE(TORCHAUDIO_FFMPEG_EXT_NAME, m) {
.def("fill_buffer", &StreamReaderFileObj::fill_buffer) .def("fill_buffer", &StreamReaderFileObj::fill_buffer)
.def("is_buffer_ready", &StreamReaderFileObj::is_buffer_ready) .def("is_buffer_ready", &StreamReaderFileObj::is_buffer_ready)
.def("pop_chunks", &StreamReaderFileObj::pop_chunks); .def("pop_chunks", &StreamReaderFileObj::pop_chunks);
// Expose StreamReaderBytes on module `m` under the Python name
// "StreamReaderBytes". py::module_local() keeps the registration local to
// this extension module.
py::class_<StreamReaderBytes>(m, "StreamReaderBytes", py::module_local())
// Constructor arguments, in order: the source bytes, an optional format
// string, an optional OptionDict, and the buffer size.
.def(py::init<
std::string_view,
const c10::optional<std::string>&,
const c10::optional<OptionDict>&,
int64_t>())
.def("num_src_streams", &StreamReaderBytes::num_src_streams)
.def("num_out_streams", &StreamReaderBytes::num_out_streams)
.def("find_best_audio_stream", &StreamReaderBytes::find_best_audio_stream)
.def("find_best_video_stream", &StreamReaderBytes::find_best_video_stream)
.def("get_metadata", &StreamReaderBytes::get_metadata)
.def("get_src_stream_info", &StreamReaderBytes::get_src_stream_info)
.def("get_out_stream_info", &StreamReaderBytes::get_out_stream_info)
.def("seek", &StreamReaderBytes::seek)
.def("add_audio_stream", &StreamReaderBytes::add_audio_stream)
.def("add_video_stream", &StreamReaderBytes::add_video_stream)
.def("remove_stream", &StreamReaderBytes::remove_stream)
// py::overload_cast picks the process_packet overload that takes
// (const c10::optional<double>&, const double); presumably the name is
// overloaded on StreamReader - confirm against its declaration.
.def(
"process_packet",
py::overload_cast<const c10::optional<double>&, const double>(
&StreamReader::process_packet))
.def("process_all_packets", &StreamReaderBytes::process_all_packets)
.def("fill_buffer", &StreamReaderBytes::fill_buffer)
.def("is_buffer_ready", &StreamReaderBytes::is_buffer_ready)
.def("pop_chunks", &StreamReaderBytes::pop_chunks);
} }
} // namespace } // namespace
......
...@@ -10,6 +10,7 @@ from torch.utils._pytree import tree_map ...@@ -10,6 +10,7 @@ from torch.utils._pytree import tree_map
if torchaudio._extension._FFMPEG_EXT is not None: if torchaudio._extension._FFMPEG_EXT is not None:
_StreamReader = torchaudio._extension._FFMPEG_EXT.StreamReader _StreamReader = torchaudio._extension._FFMPEG_EXT.StreamReader
_StreamReaderBytes = torchaudio._extension._FFMPEG_EXT.StreamReaderBytes
_StreamReaderFileObj = torchaudio._extension._FFMPEG_EXT.StreamReaderFileObj _StreamReaderFileObj = torchaudio._extension._FFMPEG_EXT.StreamReaderFileObj
...@@ -447,12 +448,14 @@ class StreamReader: ...@@ -447,12 +448,14 @@ class StreamReader:
For the detailed usage of this class, please refer to the tutorial. For the detailed usage of this class, please refer to the tutorial.
Args: Args:
src (str, path-like or file-like object): The media source. src (str, path-like, bytes or file-like object): The media source.
If string-type, it must be a resource indicator that FFmpeg can If string-type, it must be a resource indicator that FFmpeg can
handle. This includes a file path, URL, device identifier or handle. This includes a file path, URL, device identifier or
filter expression. The supported value depends on the FFmpeg found filter expression. The supported value depends on the FFmpeg found
in the system. in the system.
If bytes, it must be encoded media data in contiguous memory.
If file-like object, it must support `read` method with the signature If file-like object, it must support `read` method with the signature
`read(size: int) -> bytes`. `read(size: int) -> bytes`.
Additionally, if the file-like object has `seek` method, it uses Additionally, if the file-like object has `seek` method, it uses
...@@ -518,7 +521,10 @@ class StreamReader: ...@@ -518,7 +521,10 @@ class StreamReader:
option: Optional[Dict[str, str]] = None, option: Optional[Dict[str, str]] = None,
buffer_size: int = 4096, buffer_size: int = 4096,
): ):
if hasattr(src, "read"): self.src = src
if isinstance(src, bytes):
self._be = _StreamReaderBytes(src, format, option, buffer_size)
elif hasattr(src, "read"):
self._be = _StreamReaderFileObj(src, format, option, buffer_size) self._be = _StreamReaderFileObj(src, format, option, buffer_size)
else: else:
self._be = _StreamReader(str(src), format, option) self._be = _StreamReader(str(src), format, option)
......
...@@ -77,7 +77,7 @@ class ChunkTensorTest(TorchaudioTestCase): ...@@ -77,7 +77,7 @@ class ChunkTensorTest(TorchaudioTestCase):
# Helper decorator and Mixin to duplicate the tests for fileobj # Helper decorator and Mixin to duplicate the tests for fileobj
_media_source = parameterized_class( _media_source = parameterized_class(
("test_type",), ("test_type",),
[("str",), ("fileobj",)], [("str",), ("fileobj",), ("bytes",)],
class_name_func=lambda cls, _, params: f'{cls.__name__}_{params["test_type"]}', class_name_func=lambda cls, _, params: f'{cls.__name__}_{params["test_type"]}',
) )
...@@ -95,6 +95,9 @@ class _MediaSourceMixin: ...@@ -95,6 +95,9 @@ class _MediaSourceMixin:
self.src = path self.src = path
elif self.test_type == "fileobj": elif self.test_type == "fileobj":
self.src = open(path, "rb") self.src = open(path, "rb")
elif self.test_type == "bytes":
with open(path, "rb") as f:
self.src = f.read()
return self.src return self.src
def tearDown(self): def tearDown(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment