Commit b46628ba authored by moto, committed by Facebook GitHub Bot

Cleanup ffmpeg bindings (#3095)

Summary: Pull Request resolved: https://github.com/pytorch/audio/pull/3095

Reviewed By: nateanl

Differential Revision: D43544998

Pulled By: mthrok

fbshipit-source-id: 4359cdbbdbee53084016a84129cb3d65900b0457
parent b012b452
@@ -9,7 +9,6 @@ set(
sources
ffmpeg.cpp
filter_graph.cpp
binding_utils.cpp
stream_reader/buffer/common.cpp
stream_reader/buffer/chunked_buffer.cpp
stream_reader/buffer/unchunked_buffer.cpp
@@ -17,9 +16,8 @@ set(
stream_reader/sink.cpp
stream_reader/stream_processor.cpp
stream_reader/stream_reader.cpp
stream_reader/stream_reader_wrapper.cpp
stream_reader/stream_reader_binding.cpp
stream_writer/stream_writer.cpp
compat.cpp
utils.cpp
)
@@ -40,10 +38,8 @@ torchaudio_library(
if (BUILD_TORCHAUDIO_PYTHON_EXTENSION)
set(
ext_sources
pybind/typedefs.cpp
pybind/fileobj.cpp
pybind/pybind.cpp
pybind/stream_reader.cpp
pybind/stream_writer.cpp
)
torchaudio_extension(
_torchaudio_ffmpeg
......
#include <torchaudio/csrc/ffmpeg/binding_utils.h>
namespace torchaudio::io {
OptionDictC10 to_c10(const OptionDict& src) {
OptionDictC10 ret;
for (auto const& [key, value] : src) {
ret.insert(key, value);
}
return ret;
}
OptionDict from_c10(const OptionDictC10& src) {
OptionDict ret;
for (const auto& it : src) {
ret.emplace(it.key(), it.value());
}
return ret;
}
c10::optional<OptionDict> from_c10(const c10::optional<OptionDictC10>& src) {
if (src) {
return {from_c10(src.value())};
}
return {};
}
} // namespace torchaudio::io
#pragma once
#include <torch/types.h>
namespace torchaudio::io {
using OptionDict = std::map<std::string, std::string>;
using OptionDictC10 = c10::Dict<std::string, std::string>;
OptionDictC10 to_c10(const OptionDict&);
OptionDict from_c10(const OptionDictC10&);
c10::optional<OptionDict> from_c10(const c10::optional<OptionDictC10>&);
} // namespace torchaudio::io
#include <torch/script.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/stream_reader.h>
#include <stdexcept>
namespace torchaudio {
namespace io {
namespace {
torch::Tensor _load_audio(
StreamReader& s,
int i,
const c10::optional<std::string>& filter,
const bool& channels_first) {
s.add_audio_stream(i, -1, -1, filter, {}, {});
s.process_all_packets();
auto chunk = s.pop_chunks()[0];
TORCH_CHECK(chunk, "Failed to decode audio.");
auto waveform = chunk.value().frames;
return channels_first ? waveform.transpose(0, 1) : waveform;
}
std::tuple<torch::Tensor, int64_t> load(
const std::string& src,
const c10::optional<std::string>& format,
const c10::optional<std::string>& filter,
const bool& channels_first) {
StreamReader s{src, format, {}};
auto i = s.find_best_audio_stream();
auto sample_rate = s.get_src_stream_info(i).sample_rate;
auto waveform = _load_audio(s, i, filter, channels_first);
return {waveform, static_cast<int64_t>(sample_rate)};
}
std::tuple<int64_t, int64_t, int64_t, int64_t, std::string> info(
const std::string& src,
const c10::optional<std::string>& format) {
StreamReader s{src, format, {}};
auto i = s.find_best_audio_stream();
auto sinfo = s.get_src_stream_info(i);
int64_t num_frames = [&]() {
if (sinfo.num_frames == 0) {
torch::Tensor waveform = _load_audio(s, i, {}, false);
return waveform.size(0);
}
return sinfo.num_frames;
}();
return {
static_cast<int64_t>(sinfo.sample_rate),
static_cast<int64_t>(num_frames),
static_cast<int64_t>(sinfo.num_channels),
static_cast<int64_t>(sinfo.bits_per_sample),
sinfo.codec_name};
}
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def("torchaudio::compat_load", &load);
m.def("torchaudio::compat_info", &info);
}
} // namespace
} // namespace io
} // namespace torchaudio
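The compat_load and compat_info ops registered above are what the Python layer further down dispatches to (see info_audio and load_audio). A minimal sketch of calling them directly through torch.ops, assuming the FFmpeg extension that registers them has been loaded; the file path is illustrative only:

import torch
import torchaudio  # assumed to trigger loading of the extension that registers the compat_* ops

# compat_info returns (sample_rate, num_frames, num_channels, bits_per_sample, codec_name)
sample_rate, num_frames, num_channels, bps, codec = torch.ops.torchaudio.compat_info("example.wav", None)

# compat_load returns (waveform, sample_rate); no filter, channels_first=True
waveform, sr = torch.ops.torchaudio.compat_load("example.wav", None, None, True)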
#include <torchaudio/csrc/ffmpeg/pybind/typedefs.h>
#include <torchaudio/csrc/ffmpeg/pybind/fileobj.h>
namespace torchaudio {
namespace io {
......
#include <torch/extension.h>
#include <torchaudio/csrc/ffmpeg/pybind/stream_reader.h>
#include <torchaudio/csrc/ffmpeg/pybind/stream_writer.h>
#include <torchaudio/csrc/ffmpeg/pybind/fileobj.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/stream_reader.h>
#include <torchaudio/csrc/ffmpeg/stream_writer/stream_writer.h>
namespace torchaudio {
namespace io {
namespace {
// The reason we inherit FileObj instead of making it an attribute
// is so that FileObj is instantiated first.
// AVIOContext must be initialized before AVFormat, and outlive AVFormat.
struct StreamReaderFileObj : private FileObj, public StreamReader {
StreamReaderFileObj(
py::object fileobj,
const c10::optional<std::string>& format,
const c10::optional<std::map<std::string, std::string>>& option,
int64_t buffer_size)
: FileObj(fileobj, static_cast<int>(buffer_size), false),
StreamReader(pAVIO, format, option) {}
};
struct StreamWriterFileObj : private FileObj, public StreamWriter {
StreamWriterFileObj(
py::object fileobj,
const c10::optional<std::string>& format,
int64_t buffer_size)
: FileObj(fileobj, static_cast<int>(buffer_size), true),
StreamWriter(pAVIO, format) {}
};
PYBIND11_MODULE(_torchaudio_ffmpeg, m) {
py::class_<Chunk>(m, "Chunk", py::module_local())
.def_readwrite("frames", &Chunk::frames)
......
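The StreamReaderFileObj / StreamWriterFileObj classes bound in this module wrap a Python file-like object; the file-object code paths further down construct them directly. A rough sketch, assuming the extension is importable as torchaudio.lib._torchaudio_ffmpeg as in the Python code below; the path and buffer size are illustrative:

import torchaudio  # assumes torchaudio was built with FFmpeg support

src = open("example.wav", "rb")  # any file-like object exposing read(size); path illustrative
s = torchaudio.lib._torchaudio_ffmpeg.StreamReaderFileObj(src, None, None, 4096)
i = s.find_best_audio_stream()
print(s.get_src_stream_info(i).sample_rate)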
#include <torchaudio/csrc/ffmpeg/pybind/stream_reader.h>
#include <torchaudio/csrc/ffmpeg/pybind/typedefs.h>
namespace torchaudio {
namespace io {
StreamReaderFileObj::StreamReaderFileObj(
py::object fileobj_,
const c10::optional<std::string>& format,
const c10::optional<std::map<std::string, std::string>>& option,
int64_t buffer_size)
: FileObj(fileobj_, static_cast<int>(buffer_size), false),
StreamReader(pAVIO, format, option) {}
} // namespace io
} // namespace torchaudio
#pragma once
#include <torchaudio/csrc/ffmpeg/pybind/typedefs.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/stream_reader.h>
namespace torchaudio {
namespace io {
// The reason we inherit FileObj instead of making it an attribute
// is so that FileObj is instantiated first.
// AVIOContext must be initialized before AVFormat, and outlive AVFormat.
class StreamReaderFileObj : private FileObj, public StreamReader {
public:
StreamReaderFileObj(
py::object fileobj,
const c10::optional<std::string>& format,
const c10::optional<std::map<std::string, std::string>>& option,
int64_t buffer_size);
};
} // namespace io
} // namespace torchaudio
#include <torchaudio/csrc/ffmpeg/pybind/stream_writer.h>
namespace torchaudio {
namespace io {
StreamWriterFileObj::StreamWriterFileObj(
py::object fileobj_,
const c10::optional<std::string>& format,
int64_t buffer_size)
: FileObj(fileobj_, static_cast<int>(buffer_size), true),
StreamWriter(pAVIO, format) {}
} // namespace io
} // namespace torchaudio
#pragma once
#include <torchaudio/csrc/ffmpeg/pybind/typedefs.h>
#include <torchaudio/csrc/ffmpeg/stream_writer/stream_writer.h>
namespace torchaudio {
namespace io {
class StreamWriterFileObj : private FileObj, public StreamWriter {
public:
StreamWriterFileObj(
py::object fileobj,
const c10::optional<std::string>& format,
int64_t buffer_size);
};
} // namespace io
} // namespace torchaudio
#include <torch/script.h>
#include <torchaudio/csrc/ffmpeg/binding_utils.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/stream_reader_wrapper.h>
#include <stdexcept>
namespace torchaudio {
namespace io {
namespace {
using S = const c10::intrusive_ptr<StreamReaderBinding>&;
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def("torchaudio::ffmpeg_init", []() { avdevice_register_all(); });
m.def("torchaudio::ffmpeg_get_log_level", []() -> int64_t {
return static_cast<int64_t>(av_log_get_level());
});
m.def("torchaudio::ffmpeg_set_log_level", [](int64_t level) {
av_log_set_level(static_cast<int>(level));
});
m.class_<StreamReaderBinding>("ffmpeg_StreamReader")
.def(torch::init<>([](const std::string& src,
const c10::optional<std::string>& format,
const c10::optional<OptionDictC10>& option) {
return c10::make_intrusive<StreamReaderBinding>(
src, format, from_c10(option));
}))
.def("num_src_streams", [](S self) { return self->num_src_streams(); })
.def("num_out_streams", [](S self) { return self->num_out_streams(); })
.def("get_metadata", [](S self) { return to_c10(self->get_metadata()); })
.def(
"get_src_stream_info",
[](S s, int64_t i) { return s->get_src_stream_info(i); })
.def(
"get_out_stream_info",
[](S s, int64_t i) { return s->get_out_stream_info(i); })
.def(
"find_best_audio_stream",
[](S s) { return s->find_best_audio_stream(); })
.def(
"find_best_video_stream",
[](S s) { return s->find_best_video_stream(); })
.def("seek", [](S s, double t, int64_t mode) { return s->seek(t, mode); })
.def(
"add_audio_stream",
[](S s,
int64_t i,
int64_t frames_per_chunk,
int64_t num_chunks,
const c10::optional<std::string>& filter_desc,
const c10::optional<std::string>& decoder,
const c10::optional<OptionDictC10>& decoder_option) {
s->add_audio_stream(
i,
frames_per_chunk,
num_chunks,
filter_desc,
decoder,
from_c10(decoder_option));
})
.def(
"add_video_stream",
[](S s,
int64_t i,
int64_t frames_per_chunk,
int64_t num_chunks,
const c10::optional<std::string>& filter_desc,
const c10::optional<std::string>& decoder,
const c10::optional<OptionDictC10>& decoder_option,
const c10::optional<std::string>& hw_accel) {
s->add_video_stream(
i,
frames_per_chunk,
num_chunks,
filter_desc,
decoder,
from_c10(decoder_option),
hw_accel);
})
.def("remove_stream", [](S s, int64_t i) { s->remove_stream(i); })
.def(
"process_packet",
[](S s, const c10::optional<double>& timeout, const double backoff)
-> int64_t { return s->process_packet(timeout, backoff); })
.def("process_all_packets", [](S s) { s->process_all_packets(); })
.def(
"fill_buffer",
[](S s, const c10::optional<double>& timeout, const double backoff)
-> int64_t { return s->fill_buffer(timeout, backoff); })
.def("is_buffer_ready", [](S s) { return s->is_buffer_ready(); })
.def("pop_chunks", [](S s) { return s->pop_chunks(); });
}
} // namespace
} // namespace io
} // namespace torchaudio
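For reference, the TorchScript class registered above is driven from Python the same way the pre-cleanup _load_audio helper further down does; a condensed sketch, with the source path purely illustrative:

import torch
import torchaudio  # assumed to load the extension that registers ffmpeg_StreamReader

s = torch.classes.torchaudio.ffmpeg_StreamReader("example.wav", None, None)
i = s.find_best_audio_stream()
s.add_audio_stream(i, -1, -1, None, None, {})  # no filter, default decoder, empty decoder options
s.process_all_packets()
chunk = s.pop_chunks()[0]
if chunk is None:
    raise RuntimeError("Failed to decode audio.")
frames, pts = chunk  # ChunkData is a (frames, pts) tuple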
#include <torchaudio/csrc/ffmpeg/binding_utils.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/stream_reader_wrapper.h>
namespace torchaudio {
namespace io {
namespace {
SrcInfo convert(SrcStreamInfo ssi) {
return SrcInfo(std::forward_as_tuple(
av_get_media_type_string(ssi.media_type),
ssi.codec_name,
ssi.codec_long_name,
ssi.fmt_name,
ssi.bit_rate,
ssi.num_frames,
ssi.bits_per_sample,
to_c10(ssi.metadata),
ssi.sample_rate,
ssi.num_channels,
ssi.width,
ssi.height,
ssi.frame_rate));
}
OutInfo convert(OutputStreamInfo osi) {
return OutInfo(
std::forward_as_tuple(osi.source_index, osi.filter_description));
}
} // namespace
SrcInfo StreamReaderBinding::get_src_stream_info(int64_t i) {
return convert(StreamReader::get_src_stream_info(static_cast<int>(i)));
}
OutInfo StreamReaderBinding::get_out_stream_info(int64_t i) {
return convert(StreamReader::get_out_stream_info(static_cast<int>(i)));
}
std::vector<c10::optional<ChunkData>> StreamReaderBinding::pop_chunks() {
std::vector<c10::optional<ChunkData>> ret;
ret.reserve(static_cast<size_t>(num_out_streams()));
for (auto& c : StreamReader::pop_chunks()) {
if (c) {
ret.emplace_back(std::forward_as_tuple(c->frames, c->pts));
} else {
ret.emplace_back();
}
}
return ret;
}
} // namespace io
} // namespace torchaudio
#pragma once
#include <torch/script.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/stream_reader.h>
namespace torchaudio {
namespace io {
// Because TorchScript requires the c10::Dict type to pass a dict, while
// PyBind11 requires std::map, the return tuple is duplicated.
// Even though all the PyBind-based implementations live in the `pybind`
// directory, std::map does not require the pybind11 header, so both variants
// are defined here for better locality/maintainability.
using SrcInfo = std::tuple<
std::string, // media_type
std::string, // codec name
std::string, // codec long name
std::string, // format name
int64_t, // bit_rate
int64_t, // num_frames
int64_t, // bits_per_sample
c10::Dict<std::string, std::string>, // metadata
// Audio
double, // sample_rate
int64_t, // num_channels
// Video
int64_t, // width
int64_t, // height
double // frame_rate
>;
using OutInfo = std::tuple<
int64_t, // source index
std::string // filter description
>;
using ChunkData = std::tuple<torch::Tensor, double>;
// Structure that implements a wrapper API around StreamReader, more suitable
// for binding the code (i.e. it receives/returns primitives)
struct StreamReaderBinding : public StreamReader,
public torch::CustomClassHolder {
using StreamReader::StreamReader;
SrcInfo get_src_stream_info(int64_t i);
OutInfo get_out_stream_info(int64_t i);
std::vector<c10::optional<ChunkData>> pop_chunks();
};
} // namespace io
} // namespace torchaudio
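Since SrcInfo is a flat tuple, TorchScript-side callers read it positionally (exactly what the pre-cleanup info_audio body further down does with sinfo[5], sinfo[8] and sinfo[9]); a small sketch of unpacking it, with the field order taken from the comments above and the reader constructed as in the earlier sketch:

sinfo = s.get_src_stream_info(s.find_best_audio_stream())  # s: an ffmpeg_StreamReader instance
(media_type, codec_name, codec_long_name, fmt_name, bit_rate, num_frames,
 bits_per_sample, metadata, sample_rate, num_channels, width, height, frame_rate) = sinfo
# num_frames is 0 when the container does not report a frame count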
@@ -97,6 +97,13 @@ std::string get_build_config() {
}
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def("torchaudio::ffmpeg_init", []() { avdevice_register_all(); });
m.def("torchaudio::ffmpeg_get_log_level", []() -> int64_t {
return static_cast<int64_t>(av_log_get_level());
});
m.def("torchaudio::ffmpeg_set_log_level", [](int64_t level) {
av_log_set_level(static_cast<int>(level));
});
m.def("torchaudio::ffmpeg_get_versions", &get_versions);
m.def("torchaudio::ffmpeg_get_muxers", []() { return get_muxers(false); });
m.def(
......
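The log-level ops added in this hunk are thin int-valued wrappers around av_log_get_level/av_log_set_level; a short sketch of using them from Python, assuming the extension is loaded (the value 32 corresponds to FFmpeg's AV_LOG_INFO and is illustrative):

import torch
import torchaudio  # assumed to load the extension that registers the ffmpeg_* ops

torch.ops.torchaudio.ffmpeg_set_log_level(32)  # 32 == AV_LOG_INFO, illustrative value
print(torch.ops.torchaudio.ffmpeg_get_log_level())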
@@ -13,21 +13,8 @@ def info_audio(
src: str,
format: Optional[str],
) -> AudioMetaData:
s = torch.classes.torchaudio.ffmpeg_StreamReader(src, format, None)
i = s.find_best_audio_stream()
sinfo = s.get_src_stream_info(i)
if sinfo[5] == 0:
waveform = _load_audio(s)
num_frames = waveform.size(1)
else:
num_frames = sinfo[5]
return AudioMetaData(
int(sinfo[8]),
num_frames,
sinfo[9],
sinfo[6],
sinfo[1].upper(),
)
i = torch.ops.torchaudio.compat_info(src, format)
return AudioMetaData(i[0], i[1], i[2], i[3], i[4].upper())
def info_audio_fileobj(
@@ -79,47 +66,19 @@ def _get_load_filter(
return "{},{}".format(atrim, aformat)
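# The combined filter description returned above has the shape "atrim=...,aformat=...".
# An illustrative value (the exact option strings come from the truncated code above
# and are assumptions here) might look like:
#   "atrim=start_sample=0:end_sample=16000,aformat=sample_fmts=fltp"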
# Note: needs to comply with TorchScript syntax -- requires type annotations and no f-strings or globals
def _load_audio(
s: torch.classes.torchaudio.ffmpeg_StreamReader,
frame_offset: int = 0,
num_frames: int = -1,
convert: bool = True,
channels_first: bool = True,
) -> torch.Tensor:
i = s.find_best_audio_stream()
option: Dict[str, str] = {}
s.add_audio_stream(i, -1, -1, _get_load_filter(frame_offset, num_frames, convert), None, option)
s.process_all_packets()
chunk = s.pop_chunks()[0]
if chunk is None:
raise RuntimeError("Failed to decode audio.")
assert chunk is not None
waveform = chunk[0]
if channels_first:
waveform = waveform.T
return waveform
def _load_audio_fileobj(
s: torch.classes.torchaudio.ffmpeg_StreamReader,
frame_offset: int = 0,
num_frames: int = -1,
convert: bool = True,
s: torchaudio.lib._torchaudio_ffmpeg.StreamReaderFileObj,
filter: Optional[str] = None,
channels_first: bool = True,
) -> torch.Tensor:
i = s.find_best_audio_stream()
option: Dict[str, str] = {}
s.add_audio_stream(i, -1, -1, _get_load_filter(frame_offset, num_frames, convert), None, option)
s.add_audio_stream(i, -1, -1, filter, None, None)
s.process_all_packets()
chunk = s.pop_chunks()[0]
if chunk is None:
raise RuntimeError("Failed to decode audio.")
assert chunk is not None
waveform = chunk.frames
if channels_first:
waveform = waveform.T
return waveform
return waveform.T if channels_first else waveform
def load_audio(
@@ -130,10 +89,8 @@ def load_audio(
channels_first: bool = True,
format: Optional[str] = None,
) -> Tuple[torch.Tensor, int]:
s = torch.classes.torchaudio.ffmpeg_StreamReader(src, format, None)
sample_rate = int(s.get_src_stream_info(s.find_best_audio_stream())[8])
waveform = _load_audio(s, frame_offset, num_frames, convert, channels_first)
return waveform, sample_rate
filter = _get_load_filter(frame_offset, num_frames, convert)
return torch.ops.torchaudio.compat_load(src, format, filter, channels_first)
def load_audio_fileobj(
@@ -147,7 +104,8 @@ def load_audio_fileobj(
) -> Tuple[torch.Tensor, int]:
s = torchaudio.lib._torchaudio_ffmpeg.StreamReaderFileObj(src, format, None, buffer_size)
sample_rate = int(s.get_src_stream_info(s.find_best_audio_stream()).sample_rate)
waveform = _load_audio_fileobj(s, frame_offset, num_frames, convert, channels_first)
filter = _get_load_filter(frame_offset, num_frames, convert)
waveform = _load_audio_fileobj(s, filter, channels_first)
return waveform, sample_rate
......