Commit b012b452 authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Bind StreamReader/Writer with PyBind11 (#3091)

Summary:
This commit is a cleanup and preparation for future
development.

We plan to pass more complicated objects between
StreamReader and StreamWriter, and TorchBind is not expressive enough
to define such intermediate objects, so we use PyBind11 to bind
StreamReader and StreamWriter.

Pull Request resolved: https://github.com/pytorch/audio/pull/3091

Reviewed By: xiaohui-zhang

Differential Revision: D43515714

Pulled By: mthrok

fbshipit-source-id: 9097bb104bbf8c1536a5fab6f87447c08b10a7f2
parent f6d1bc96
...@@ -82,6 +82,14 @@ process_packet_block ...@@ -82,6 +82,14 @@ process_packet_block
^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^
.. doxygenfunction:: torchaudio::io::StreamReader::process_packet_block .. doxygenfunction:: torchaudio::io::StreamReader::process_packet_block
process_all_packets
^^^^^^^^^^^^^^^^^^^
.. doxygenfunction:: torchaudio::io::StreamReader::process_all_packets
fill_buffer
^^^^^^^^^^^
.. doxygenfunction:: torchaudio::io::StreamReader::fill_buffer
Retrieval Methods Retrieval Methods
----------------- -----------------
......
...@@ -20,7 +20,6 @@ set( ...@@ -20,7 +20,6 @@ set(
stream_reader/stream_reader_wrapper.cpp stream_reader/stream_reader_wrapper.cpp
stream_reader/stream_reader_binding.cpp stream_reader/stream_reader_binding.cpp
stream_writer/stream_writer.cpp stream_writer/stream_writer.cpp
stream_writer/stream_writer_binding.cpp
utils.cpp utils.cpp
) )
......
#include <pybind11/stl.h>
#include <torch/extension.h> #include <torch/extension.h>
#include <torchaudio/csrc/ffmpeg/pybind/stream_reader.h> #include <torchaudio/csrc/ffmpeg/pybind/stream_reader.h>
#include <torchaudio/csrc/ffmpeg/pybind/stream_writer.h> #include <torchaudio/csrc/ffmpeg/pybind/stream_writer.h>
...@@ -8,7 +7,21 @@ namespace io { ...@@ -8,7 +7,21 @@ namespace io {
namespace { namespace {
PYBIND11_MODULE(_torchaudio_ffmpeg, m) { PYBIND11_MODULE(_torchaudio_ffmpeg, m) {
py::class_<StreamWriterFileObj>(m, "StreamWriterFileObj") py::class_<Chunk>(m, "Chunk", py::module_local())
.def_readwrite("frames", &Chunk::frames)
.def_readwrite("pts", &Chunk::pts);
py::class_<StreamWriter>(m, "StreamWriter", py::module_local())
.def(py::init<const std::string&, const c10::optional<std::string>&>())
.def("set_metadata", &StreamWriter::set_metadata)
.def("add_audio_stream", &StreamWriter::add_audio_stream)
.def("add_video_stream", &StreamWriter::add_video_stream)
.def("dump_format", &StreamWriter::dump_format)
.def("open", &StreamWriter::open)
.def("write_audio_chunk", &StreamWriter::write_audio_chunk)
.def("write_video_chunk", &StreamWriter::write_video_chunk)
.def("flush", &StreamWriter::flush)
.def("close", &StreamWriter::close);
py::class_<StreamWriterFileObj>(m, "StreamWriterFileObj", py::module_local())
.def(py::init<py::object, const c10::optional<std::string>&, int64_t>()) .def(py::init<py::object, const c10::optional<std::string>&, int64_t>())
.def("set_metadata", &StreamWriterFileObj::set_metadata) .def("set_metadata", &StreamWriterFileObj::set_metadata)
.def("add_audio_stream", &StreamWriterFileObj::add_audio_stream) .def("add_audio_stream", &StreamWriterFileObj::add_audio_stream)
...@@ -19,8 +32,53 @@ PYBIND11_MODULE(_torchaudio_ffmpeg, m) { ...@@ -19,8 +32,53 @@ PYBIND11_MODULE(_torchaudio_ffmpeg, m) {
.def("write_video_chunk", &StreamWriterFileObj::write_video_chunk) .def("write_video_chunk", &StreamWriterFileObj::write_video_chunk)
.def("flush", &StreamWriterFileObj::flush) .def("flush", &StreamWriterFileObj::flush)
.def("close", &StreamWriterFileObj::close); .def("close", &StreamWriterFileObj::close);
py::class_<StreamReaderFileObj, c10::intrusive_ptr<StreamReaderFileObj>>( py::class_<OutputStreamInfo>(m, "OutputStreamInfo", py::module_local())
m, "StreamReaderFileObj") .def_readonly("source_index", &OutputStreamInfo::source_index)
.def_readonly(
"filter_description", &OutputStreamInfo::filter_description);
py::class_<SrcStreamInfo>(m, "SourceStreamInfo", py::module_local())
.def_property_readonly(
"media_type",
[](const SrcStreamInfo& s) {
return av_get_media_type_string(s.media_type);
})
.def_readonly("codec_name", &SrcStreamInfo::codec_name)
.def_readonly("codec_long_name", &SrcStreamInfo::codec_long_name)
.def_readonly("format", &SrcStreamInfo::fmt_name)
.def_readonly("bit_rate", &SrcStreamInfo::bit_rate)
.def_readonly("num_frames", &SrcStreamInfo::num_frames)
.def_readonly("bits_per_sample", &SrcStreamInfo::bits_per_sample)
.def_readonly("metadata", &SrcStreamInfo::metadata)
.def_readonly("sample_rate", &SrcStreamInfo::sample_rate)
.def_readonly("num_channels", &SrcStreamInfo::num_channels)
.def_readonly("width", &SrcStreamInfo::width)
.def_readonly("height", &SrcStreamInfo::height)
.def_readonly("frame_rate", &SrcStreamInfo::frame_rate);
py::class_<StreamReader>(m, "StreamReader", py::module_local())
.def(py::init<
const std::string&,
const c10::optional<std::string>&,
const c10::optional<OptionDict>&>())
.def("num_src_streams", &StreamReader::num_src_streams)
.def("num_out_streams", &StreamReader::num_out_streams)
.def("find_best_audio_stream", &StreamReader::find_best_audio_stream)
.def("find_best_video_stream", &StreamReader::find_best_video_stream)
.def("get_metadata", &StreamReader::get_metadata)
.def("get_src_stream_info", &StreamReader::get_src_stream_info)
.def("get_out_stream_info", &StreamReader::get_out_stream_info)
.def("seek", &StreamReader::seek)
.def("add_audio_stream", &StreamReader::add_audio_stream)
.def("add_video_stream", &StreamReader::add_video_stream)
.def("remove_stream", &StreamReader::remove_stream)
.def(
"process_packet",
py::overload_cast<const c10::optional<double>&, const double>(
&StreamReader::process_packet))
.def("process_all_packets", &StreamReader::process_all_packets)
.def("fill_buffer", &StreamReader::fill_buffer)
.def("is_buffer_ready", &StreamReader::is_buffer_ready)
.def("pop_chunks", &StreamReader::pop_chunks);
py::class_<StreamReaderFileObj>(m, "StreamReaderFileObj", py::module_local())
.def(py::init< .def(py::init<
py::object, py::object,
const c10::optional<std::string>&, const c10::optional<std::string>&,
...@@ -41,7 +99,10 @@ PYBIND11_MODULE(_torchaudio_ffmpeg, m) { ...@@ -41,7 +99,10 @@ PYBIND11_MODULE(_torchaudio_ffmpeg, m) {
.def("add_audio_stream", &StreamReaderFileObj::add_audio_stream) .def("add_audio_stream", &StreamReaderFileObj::add_audio_stream)
.def("add_video_stream", &StreamReaderFileObj::add_video_stream) .def("add_video_stream", &StreamReaderFileObj::add_video_stream)
.def("remove_stream", &StreamReaderFileObj::remove_stream) .def("remove_stream", &StreamReaderFileObj::remove_stream)
.def("process_packet", &StreamReaderFileObj::process_packet) .def(
"process_packet",
py::overload_cast<const c10::optional<double>&, const double>(
&StreamReader::process_packet))
.def("process_all_packets", &StreamReaderFileObj::process_all_packets) .def("process_all_packets", &StreamReaderFileObj::process_all_packets)
.def("fill_buffer", &StreamReaderFileObj::fill_buffer) .def("fill_buffer", &StreamReaderFileObj::fill_buffer)
.def("is_buffer_ready", &StreamReaderFileObj::is_buffer_ready) .def("is_buffer_ready", &StreamReaderFileObj::is_buffer_ready)
......
...@@ -3,24 +3,6 @@ ...@@ -3,24 +3,6 @@
namespace torchaudio { namespace torchaudio {
namespace io { namespace io {
namespace {
// Flattens a SrcStreamInfo into the SrcInfoPyBind tuple.
// The element order here must match the SrcInfoPyBind definition.
SrcInfoPyBind convert_pybind(SrcStreamInfo ssi) {
  // av_get_media_type_string() returns NULL when the media type is unknown.
  // Converting a null pointer to std::string is undefined behavior, so fall
  // back to a fixed label in that case.
  const char* media_type = av_get_media_type_string(ssi.media_type);
  return SrcInfoPyBind(std::forward_as_tuple(
      media_type ? media_type : "unknown",
      ssi.codec_name,
      ssi.codec_long_name,
      ssi.fmt_name,
      ssi.bit_rate,
      ssi.num_frames,
      ssi.bits_per_sample,
      ssi.metadata,
      ssi.sample_rate,
      ssi.num_channels,
      ssi.width,
      ssi.height,
      ssi.frame_rate));
}
} // namespace
StreamReaderFileObj::StreamReaderFileObj( StreamReaderFileObj::StreamReaderFileObj(
py::object fileobj_, py::object fileobj_,
...@@ -28,11 +10,7 @@ StreamReaderFileObj::StreamReaderFileObj( ...@@ -28,11 +10,7 @@ StreamReaderFileObj::StreamReaderFileObj(
const c10::optional<std::map<std::string, std::string>>& option, const c10::optional<std::map<std::string, std::string>>& option,
int64_t buffer_size) int64_t buffer_size)
: FileObj(fileobj_, static_cast<int>(buffer_size), false), : FileObj(fileobj_, static_cast<int>(buffer_size), false),
StreamReaderBinding(pAVIO, format, option) {} StreamReader(pAVIO, format, option) {}
// Looks up the source stream info for stream `i` and returns it in the
// SrcInfoPyBind tuple form.
SrcInfoPyBind StreamReaderFileObj::get_src_stream_info(int64_t i) {
  // The underlying StreamReader API is called with an int index.
  const int index = static_cast<int>(i);
  return convert_pybind(StreamReader::get_src_stream_info(index));
}
} // namespace io } // namespace io
} // namespace torchaudio } // namespace torchaudio
#pragma once #pragma once
#include <torchaudio/csrc/ffmpeg/pybind/typedefs.h> #include <torchaudio/csrc/ffmpeg/pybind/typedefs.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/stream_reader_wrapper.h> #include <torchaudio/csrc/ffmpeg/stream_reader/stream_reader.h>
namespace torchaudio { namespace torchaudio {
namespace io { namespace io {
...@@ -8,15 +8,13 @@ namespace io { ...@@ -8,15 +8,13 @@ namespace io {
// The reason we inherit FileObj instead of making it an attribute // The reason we inherit FileObj instead of making it an attribute
// is so that FileObj is instantiated first. // is so that FileObj is instantiated first.
// AVIOContext must be initialized before AVFormat, and outlive AVFormat. // AVIOContext must be initialized before AVFormat, and outlive AVFormat.
class StreamReaderFileObj : protected FileObj, public StreamReaderBinding { class StreamReaderFileObj : private FileObj, public StreamReader {
public: public:
StreamReaderFileObj( StreamReaderFileObj(
py::object fileobj, py::object fileobj,
const c10::optional<std::string>& format, const c10::optional<std::string>& format,
const c10::optional<std::map<std::string, std::string>>& option, const c10::optional<std::map<std::string, std::string>>& option,
int64_t buffer_size); int64_t buffer_size);
SrcInfoPyBind get_src_stream_info(int64_t i);
}; };
} // namespace io } // namespace io
......
...@@ -422,6 +422,39 @@ int StreamReader::process_packet_block(double timeout, double backoff) { ...@@ -422,6 +422,39 @@ int StreamReader::process_packet_block(double timeout, double backoff) {
} }
} }
void StreamReader::process_all_packets() {
int64_t ret = 0;
do {
ret = process_packet();
} while (!ret);
}
// Processes one packet and raises if the resulting status code is negative.
//
// When `timeout` holds a value, process_packet_block(timeout, backoff) is
// used; otherwise the zero-argument process_packet() is called and `backoff`
// is ignored. The non-negative status code is returned to the caller.
int StreamReader::process_packet(
    const c10::optional<double>& timeout,
    const double backoff) {
  const int code = timeout.has_value()
      ? process_packet_block(timeout.value(), backoff)
      : process_packet();
  TORCH_CHECK(
      code >= 0, "Failed to process a packet. (" + av_err2string(code) + "). ");
  return code;
}
// Keeps processing packets until is_buffer_ready() reports true.
//
// Returns 0 once the buffer is ready. If process_packet() returns a non-zero
// status first, that status is returned immediately.
int StreamReader::fill_buffer(
    const c10::optional<double>& timeout,
    const double backoff) {
  for (;;) {
    if (is_buffer_ready()) {
      return 0;
    }
    const int status = process_packet(timeout, backoff);
    if (status != 0) {
      return status;
    }
  }
}
// <0: Some error happened. // <0: Some error happened.
int StreamReader::drain() { int StreamReader::drain() {
int ret = 0, tmp = 0; int ret = 0, tmp = 0;
......
...@@ -283,7 +283,23 @@ class StreamReader { ...@@ -283,7 +283,23 @@ class StreamReader {
/// - ``>=0``: Keep retrying until the given time passes. /// - ``>=0``: Keep retrying until the given time passes.
/// - ``<0``: Keep retrying forever. /// - ``<0``: Keep retrying forever.
/// @param backoff Time to wait before retrying in milli seconds. /// @param backoff Time to wait before retrying in milli seconds.
int process_packet_block(double timeout, double backoff); int process_packet_block(const double timeout, const double backoff);
// High-level method used by Python bindings.
int process_packet(
const c10::optional<double>& timeout,
const double backoff);
/// Process packets until EOF
void process_all_packets();
/// Process packets until all the chunk buffers have at least one chunk
///
/// @param timeout See `process_packet_block()`
/// @param backoff See `process_packet_block()`
int fill_buffer(
const c10::optional<double>& timeout = {},
const double backoff = 10.);
///@} ///@}
......
...@@ -80,15 +80,13 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) { ...@@ -80,15 +80,13 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
.def("remove_stream", [](S s, int64_t i) { s->remove_stream(i); }) .def("remove_stream", [](S s, int64_t i) { s->remove_stream(i); })
.def( .def(
"process_packet", "process_packet",
[](S s, const c10::optional<double>& timeout, const double backoff) { [](S s, const c10::optional<double>& timeout, const double backoff)
return s->process_packet(timeout, backoff); -> int64_t { return s->process_packet(timeout, backoff); })
})
.def("process_all_packets", [](S s) { s->process_all_packets(); }) .def("process_all_packets", [](S s) { s->process_all_packets(); })
.def( .def(
"fill_buffer", "fill_buffer",
[](S s, const c10::optional<double>& timeout, const double backoff) { [](S s, const c10::optional<double>& timeout, const double backoff)
return s->fill_buffer(timeout, backoff); -> int64_t { return s->fill_buffer(timeout, backoff); })
})
.def("is_buffer_ready", [](S s) { return s->is_buffer_ready(); }) .def("is_buffer_ready", [](S s) { return s->is_buffer_ready(); })
.def("pop_chunks", [](S s) { return s->pop_chunks(); }); .def("pop_chunks", [](S s) { return s->pop_chunks(); });
} }
......
...@@ -36,39 +36,6 @@ OutInfo StreamReaderBinding::get_out_stream_info(int64_t i) { ...@@ -36,39 +36,6 @@ OutInfo StreamReaderBinding::get_out_stream_info(int64_t i) {
return convert(StreamReader::get_out_stream_info(static_cast<int>(i))); return convert(StreamReader::get_out_stream_info(static_cast<int>(i)));
} }
// Processes one packet through the base StreamReader and raises if the
// resulting status code is negative.
//
// With a timeout value the blocking variant is used; without one the plain
// StreamReader::process_packet() is called and `backoff` is ignored.
int64_t StreamReaderBinding::process_packet(
    const c10::optional<double>& timeout,
    const double backoff) {
  const int64_t code = timeout.has_value()
      ? StreamReader::process_packet_block(timeout.value(), backoff)
      : StreamReader::process_packet();
  TORCH_CHECK(
      code >= 0, "Failed to process a packet. (" + av_err2string(code) + "). ");
  return code;
}
// Calls process_packet() repeatedly, stopping at the first non-zero status.
void StreamReaderBinding::process_all_packets() {
  while (process_packet() == 0) {
    // Keep going: a zero status means there may be more packets to process.
  }
}
// Processes packets until is_buffer_ready() reports true.
//
// Returns 0 when the buffer becomes ready, or the first non-zero status
// returned by process_packet().
int64_t StreamReaderBinding::fill_buffer(
    const c10::optional<double>& timeout,
    const double backoff) {
  for (;;) {
    if (is_buffer_ready()) {
      return 0;
    }
    const int status = process_packet(timeout, backoff);
    if (status != 0) {
      return status;
    }
  }
}
std::vector<c10::optional<ChunkData>> StreamReaderBinding::pop_chunks() { std::vector<c10::optional<ChunkData>> StreamReaderBinding::pop_chunks() {
std::vector<c10::optional<ChunkData>> ret; std::vector<c10::optional<ChunkData>> ret;
ret.reserve(static_cast<size_t>(num_out_streams())); ret.reserve(static_cast<size_t>(num_out_streams()));
......
...@@ -31,24 +31,6 @@ using SrcInfo = std::tuple< ...@@ -31,24 +31,6 @@ using SrcInfo = std::tuple<
double // frame_rate double // frame_rate
>; >;
// Flat tuple describing one source stream.
// Elements are positional, so their order is part of the contract with any
// code that indexes into the tuple.
using SrcInfoPyBind = std::tuple<
std::string, // media_type
std::string, // codec name
std::string, // codec long name
std::string, // format name
int64_t, // bit_rate
int64_t, // num_frames
int64_t, // bits_per_sample
std::map<std::string, std::string>, // metadata
// Audio
double, // sample_rate
int64_t, // num_channels
// Video
int64_t, // width
int64_t, // height
double // frame_rate
>;
using OutInfo = std::tuple< using OutInfo = std::tuple<
int64_t, // source index int64_t, // source index
std::string // filter description std::string // filter description
...@@ -65,16 +47,6 @@ struct StreamReaderBinding : public StreamReader, ...@@ -65,16 +47,6 @@ struct StreamReaderBinding : public StreamReader,
SrcInfo get_src_stream_info(int64_t i); SrcInfo get_src_stream_info(int64_t i);
OutInfo get_out_stream_info(int64_t i); OutInfo get_out_stream_info(int64_t i);
int64_t process_packet(
const c10::optional<double>& timeout = {},
const double backoff = 10.);
void process_all_packets();
int64_t fill_buffer(
const c10::optional<double>& timeout = {},
const double backoff = 10.);
std::vector<c10::optional<ChunkData>> pop_chunks(); std::vector<c10::optional<ChunkData>> pop_chunks();
}; };
......
#include <torch/script.h>
#include <torchaudio/csrc/ffmpeg/binding_utils.h>
#include <torchaudio/csrc/ffmpeg/stream_writer/stream_writer.h>
namespace torchaudio {
namespace io {
namespace {
// Wrapper that combines StreamWriter with torch::CustomClassHolder so it can
// be registered through m.class_<> and held in a c10::intrusive_ptr.
struct StreamWriterBinding : public StreamWriter,
public torch::CustomClassHolder {
// Reuse StreamWriter's constructors unchanged.
using StreamWriter::StreamWriter;
};
using S = const c10::intrusive_ptr<StreamWriterBinding>&;
// Registers StreamWriterBinding as the custom class "ffmpeg_StreamWriter" in
// the "torchaudio" library fragment. Each method is a thin lambda that
// forwards to the corresponding StreamWriter method on the held object.
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.class_<StreamWriterBinding>("ffmpeg_StreamWriter")
// Constructor: (dst, format) -> intrusive_ptr<StreamWriterBinding>.
.def(torch::init<>(
[](const std::string& dst, const c10::optional<std::string>& format) {
return c10::make_intrusive<StreamWriterBinding>(dst, format);
}))
.def(
"add_audio_stream",
[](S s,
int64_t sample_rate,
int64_t num_channels,
const std::string& format,
const c10::optional<std::string>& encoder,
const c10::optional<OptionDictC10>& encoder_option,
const c10::optional<std::string>& encoder_format) {
s->add_audio_stream(
sample_rate,
num_channels,
format,
encoder,
// from_c10 presumably converts the c10 dictionary into the option type
// StreamWriter expects -- defined outside this file, confirm there.
from_c10(encoder_option),
encoder_format);
})
.def(
"add_video_stream",
[](S s,
double frame_rate,
int64_t width,
int64_t height,
const std::string& format,
const c10::optional<std::string>& encoder,
const c10::optional<OptionDictC10>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<std::string>& hw_accel) {
s->add_video_stream(
frame_rate,
width,
height,
format,
encoder,
from_c10(encoder_option),
encoder_format,
hw_accel);
})
.def(
"set_metadata",
[](S s, const OptionDictC10& metadata) {
s->set_metadata(from_c10(metadata));
})
// NOTE(review): `i` is passed through as int64_t here, while the
// write_*_chunk bindings below narrow explicitly with static_cast<int>.
.def("dump_format", [](S s, int64_t i) { s->dump_format(i); })
.def(
"open",
[](S s, const c10::optional<OptionDictC10>& option) {
s->open(from_c10(option));
})
.def("close", [](S s) { s->close(); })
.def(
"write_audio_chunk",
[](S s, int64_t i, const torch::Tensor& chunk) {
// The stream index arrives as int64_t and is narrowed to int.
s->write_audio_chunk(static_cast<int>(i), chunk);
})
.def(
"write_video_chunk",
[](S s, int64_t i, const torch::Tensor& chunk) {
s->write_video_chunk(static_cast<int>(i), chunk);
})
.def("flush", [](S s) { s->flush(); });
}
} // namespace
} // namespace io
} // namespace torchaudio
...@@ -9,13 +9,15 @@ from torchaudio.io import StreamWriter ...@@ -9,13 +9,15 @@ from torchaudio.io import StreamWriter
# Note: need to comply TorchScript syntax -- need annotation and no f-string nor global # Note: need to comply TorchScript syntax -- need annotation and no f-string nor global
def _info_audio( def info_audio(
s: torch.classes.torchaudio.ffmpeg_StreamReader, src: str,
): format: Optional[str],
) -> AudioMetaData:
s = torch.classes.torchaudio.ffmpeg_StreamReader(src, format, None)
i = s.find_best_audio_stream() i = s.find_best_audio_stream()
sinfo = s.get_src_stream_info(i) sinfo = s.get_src_stream_info(i)
if sinfo[5] == 0: if sinfo[5] == 0:
waveform, _ = _load_audio(s) waveform = _load_audio(s)
num_frames = waveform.size(1) num_frames = waveform.size(1)
else: else:
num_frames = sinfo[5] num_frames = sinfo[5]
...@@ -28,21 +30,26 @@ def _info_audio( ...@@ -28,21 +30,26 @@ def _info_audio(
) )
def info_audio(
    src: str,
    format: Optional[str],
) -> AudioMetaData:
    """Open ``src`` with the TorchBind StreamReader and return its audio metadata."""
    reader = torch.classes.torchaudio.ffmpeg_StreamReader(src, format, None)
    metadata = _info_audio(reader)
    return metadata
def info_audio_fileobj( def info_audio_fileobj(
src, src,
format: Optional[str], format: Optional[str],
buffer_size: int = 4096, buffer_size: int = 4096,
) -> AudioMetaData: ) -> AudioMetaData:
s = torchaudio.lib._torchaudio_ffmpeg.StreamReaderFileObj(src, format, None, buffer_size) s = torchaudio.lib._torchaudio_ffmpeg.StreamReaderFileObj(src, format, None, buffer_size)
return _info_audio(s) i = s.find_best_audio_stream()
sinfo = s.get_src_stream_info(i)
if sinfo.num_frames == 0:
waveform = _load_audio_fileobj(s)
num_frames = waveform.size(1)
else:
num_frames = sinfo.num_frames
return AudioMetaData(
int(sinfo.sample_rate),
num_frames,
sinfo.num_channels,
sinfo.bits_per_sample,
sinfo.codec_name.upper(),
)
def _get_load_filter( def _get_load_filter(
...@@ -79,10 +86,8 @@ def _load_audio( ...@@ -79,10 +86,8 @@ def _load_audio(
num_frames: int = -1, num_frames: int = -1,
convert: bool = True, convert: bool = True,
channels_first: bool = True, channels_first: bool = True,
) -> Tuple[torch.Tensor, int]: ) -> torch.Tensor:
i = s.find_best_audio_stream() i = s.find_best_audio_stream()
sinfo = s.get_src_stream_info(i)
sample_rate = int(sinfo[8])
option: Dict[str, str] = {} option: Dict[str, str] = {}
s.add_audio_stream(i, -1, -1, _get_load_filter(frame_offset, num_frames, convert), None, option) s.add_audio_stream(i, -1, -1, _get_load_filter(frame_offset, num_frames, convert), None, option)
s.process_all_packets() s.process_all_packets()
...@@ -93,7 +98,28 @@ def _load_audio( ...@@ -93,7 +98,28 @@ def _load_audio(
waveform = chunk[0] waveform = chunk[0]
if channels_first: if channels_first:
waveform = waveform.T waveform = waveform.T
return waveform, sample_rate return waveform
def _load_audio_fileobj(
    s,
    frame_offset: int = 0,
    num_frames: int = -1,
    convert: bool = True,
    channels_first: bool = True,
) -> torch.Tensor:
    """Decode the best audio stream of ``s`` and return it as a single tensor.

    Args:
        s: Stream reader object providing ``find_best_audio_stream``,
            ``add_audio_stream``, ``process_all_packets`` and ``pop_chunks``,
            whose chunks expose a ``frames`` attribute. Callers in this module
            pass a ``StreamReaderFileObj``.
        frame_offset: Forwarded to ``_get_load_filter``.
        num_frames: Forwarded to ``_get_load_filter``.
        convert: Forwarded to ``_get_load_filter``.
        channels_first: When true, the decoded tensor is transposed before
            being returned (presumably from time-first to channel-first
            layout -- confirm against the decoder output).

    Returns:
        The decoded waveform tensor.

    Raises:
        RuntimeError: If no chunk was produced for the audio stream.
    """
    i = s.find_best_audio_stream()
    option: Dict[str, str] = {}
    s.add_audio_stream(i, -1, -1, _get_load_filter(frame_offset, num_frames, convert), None, option)
    s.process_all_packets()
    chunk = s.pop_chunks()[0]
    if chunk is None:
        raise RuntimeError("Failed to decode audio.")
    waveform = chunk.frames
    if channels_first:
        waveform = waveform.T
    return waveform
def load_audio( def load_audio(
...@@ -105,7 +131,9 @@ def load_audio( ...@@ -105,7 +131,9 @@ def load_audio(
format: Optional[str] = None, format: Optional[str] = None,
) -> Tuple[torch.Tensor, int]: ) -> Tuple[torch.Tensor, int]:
s = torch.classes.torchaudio.ffmpeg_StreamReader(src, format, None) s = torch.classes.torchaudio.ffmpeg_StreamReader(src, format, None)
return _load_audio(s, frame_offset, num_frames, convert, channels_first) sample_rate = int(s.get_src_stream_info(s.find_best_audio_stream())[8])
waveform = _load_audio(s, frame_offset, num_frames, convert, channels_first)
return waveform, sample_rate
def load_audio_fileobj( def load_audio_fileobj(
...@@ -118,7 +146,9 @@ def load_audio_fileobj( ...@@ -118,7 +146,9 @@ def load_audio_fileobj(
buffer_size: int = 4096, buffer_size: int = 4096,
) -> Tuple[torch.Tensor, int]: ) -> Tuple[torch.Tensor, int]:
s = torchaudio.lib._torchaudio_ffmpeg.StreamReaderFileObj(src, format, None, buffer_size) s = torchaudio.lib._torchaudio_ffmpeg.StreamReaderFileObj(src, format, None, buffer_size)
return _load_audio(s, frame_offset, num_frames, convert, channels_first) sample_rate = int(s.get_src_stream_info(s.find_best_audio_stream()).sample_rate)
waveform = _load_audio_fileobj(s, frame_offset, num_frames, convert, channels_first)
return waveform, sample_rate
def _get_sample_format(dtype: torch.dtype) -> str: def _get_sample_format(dtype: torch.dtype) -> str:
......
...@@ -103,70 +103,44 @@ class SourceVideoStream(SourceStream): ...@@ -103,70 +103,44 @@ class SourceVideoStream(SourceStream):
"""Frame rate.""" """Frame rate."""
# Indices of SrcInfo returned by low-level `get_src_stream_info`
# Each constant is a position in that tuple, so the values must stay in sync
# with the element order of the tuple on the C++ side.
# - COMMON
_MEDIA_TYPE = 0
_CODEC = 1
_CODEC_LONG = 2
_FORMAT = 3
_BIT_RATE = 4
_NUM_FRAMES = 5
_BPS = 6
_METADATA = 7
# - AUDIO
_SAMPLE_RATE = 8
_NUM_CHANNELS = 9
# - VIDEO
_WIDTH = 10
_HEIGHT = 11
_FRAME_RATE = 12
def _parse_si(i): def _parse_si(i):
media_type = i[_MEDIA_TYPE] media_type = i.media_type
codec_name = i[_CODEC]
codec_long_name = i[_CODEC_LONG]
fmt = i[_FORMAT]
bit_rate = i[_BIT_RATE]
num_frames = i[_NUM_FRAMES]
bps = i[_BPS]
metadata = i[_METADATA]
if media_type == "audio": if media_type == "audio":
return SourceAudioStream( return SourceAudioStream(
media_type=media_type, media_type=i.media_type,
codec=codec_name, codec=i.codec_name,
codec_long_name=codec_long_name, codec_long_name=i.codec_long_name,
format=fmt, format=i.format,
bit_rate=bit_rate, bit_rate=i.bit_rate,
num_frames=num_frames, num_frames=i.num_frames,
bits_per_sample=bps, bits_per_sample=i.bits_per_sample,
metadata=metadata, metadata=i.metadata,
sample_rate=i[_SAMPLE_RATE], sample_rate=i.sample_rate,
num_channels=i[_NUM_CHANNELS], num_channels=i.num_channels,
) )
if media_type == "video": if media_type == "video":
return SourceVideoStream( return SourceVideoStream(
media_type=media_type, media_type=i.media_type,
codec=codec_name, codec=i.codec_name,
codec_long_name=codec_long_name, codec_long_name=i.codec_long_name,
format=fmt, format=i.format,
bit_rate=bit_rate, bit_rate=i.bit_rate,
num_frames=num_frames, num_frames=i.num_frames,
bits_per_sample=bps, bits_per_sample=i.bits_per_sample,
metadata=metadata, metadata=i.metadata,
width=i[_WIDTH], width=i.width,
height=i[_HEIGHT], height=i.height,
frame_rate=i[_FRAME_RATE], frame_rate=i.frame_rate,
) )
return SourceStream( return SourceStream(
media_type=media_type, media_type=i.media_type,
codec=codec_name, codec=i.codec_name,
codec_long_name=codec_long_name, codec_long_name=i.codec_long_name,
format=None, format=None,
bit_rate=None, bit_rate=None,
num_frames=None, num_frames=None,
bits_per_sample=None, bits_per_sample=None,
metadata=metadata, metadata=i.metadata,
) )
...@@ -182,10 +156,6 @@ class OutputStream: ...@@ -182,10 +156,6 @@ class OutputStream:
"""Description of filter graph applied to the source stream.""" """Description of filter graph applied to the source stream."""
def _parse_oi(i):
    """Build an ``OutputStream`` from the first two elements of ``i``."""
    first, second = i[0], i[1]
    return OutputStream(first, second)
def _get_afilter_desc(sample_rate: Optional[int], fmt: Optional[str]): def _get_afilter_desc(sample_rate: Optional[int], fmt: Optional[str]):
descs = [] descs = []
if sample_rate is not None: if sample_rate is not None:
...@@ -461,7 +431,7 @@ class StreamReader: ...@@ -461,7 +431,7 @@ class StreamReader:
): ):
torch._C._log_api_usage_once("torchaudio.io.StreamReader") torch._C._log_api_usage_once("torchaudio.io.StreamReader")
if isinstance(src, str): if isinstance(src, str):
self._be = torch.classes.torchaudio.ffmpeg_StreamReader(src, format, option) self._be = torchaudio.lib._torchaudio_ffmpeg.StreamReader(src, format, option)
elif hasattr(src, "read"): elif hasattr(src, "read"):
self._be = torchaudio.lib._torchaudio_ffmpeg.StreamReaderFileObj(src, format, option, buffer_size) self._be = torchaudio.lib._torchaudio_ffmpeg.StreamReaderFileObj(src, format, option, buffer_size)
else: else:
...@@ -533,7 +503,8 @@ class StreamReader: ...@@ -533,7 +503,8 @@ class StreamReader:
Returns: Returns:
OutputStream OutputStream
""" """
return _parse_oi(self._be.get_out_stream_info(i)) info = self._be.get_out_stream_info(i)
return OutputStream(info.source_index, info.filter_description)
def seek(self, timestamp: float, mode: str = "precise"): def seek(self, timestamp: float, mode: str = "precise"):
"""Seek the stream to the given timestamp [second] """Seek the stream to the given timestamp [second]
...@@ -843,7 +814,7 @@ class StreamReader: ...@@ -843,7 +814,7 @@ class StreamReader:
if chunk is None: if chunk is None:
ret.append(None) ret.append(None)
else: else:
ret.append(ChunkTensor(chunk[0], chunk[1])) ret.append(ChunkTensor(chunk.frames, chunk.pts))
return ret return ret
def fill_buffer(self, timeout: Optional[float] = None, backoff: float = 10.0) -> int: def fill_buffer(self, timeout: Optional[float] = None, backoff: float = 10.0) -> int:
......
...@@ -111,7 +111,7 @@ class StreamWriter: ...@@ -111,7 +111,7 @@ class StreamWriter:
): ):
torch._C._log_api_usage_once("torchaudio.io.StreamWriter") torch._C._log_api_usage_once("torchaudio.io.StreamWriter")
if isinstance(dst, str): if isinstance(dst, str):
self._s = torch.classes.torchaudio.ffmpeg_StreamWriter(dst, format) self._s = torchaudio.lib._torchaudio_ffmpeg.StreamWriter(dst, format)
elif hasattr(dst, "write"): elif hasattr(dst, "write"):
self._s = torchaudio.lib._torchaudio_ffmpeg.StreamWriterFileObj(dst, format, buffer_size) self._s = torchaudio.lib._torchaudio_ffmpeg.StreamWriterFileObj(dst, format, buffer_size)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment