Commit e2641452 authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Clean up the interface around dictionary (#2533)

Summary:
Python dictionary is bound to different types in TorchBind and PyBind.
StreamReader has methods that receive and return dictionary.

This commit cleans up the treatment of dictionaries and consolidates
the helper functions.

* The core implementation and TorchBind both use `c10::Dict`.
* PyBind version uses `std::map` and converts it to `c10::Dict`.
* The helper functions to convert `std::map` <-> `c10::Dict` are consolidated in pybind directory.
* The wrapper methods are implemented in `pybind` dir.

Pull Request resolved: https://github.com/pytorch/audio/pull/2533

Reviewed By: hwangjeff

Differential Revision: D37731866

Pulled By: mthrok

fbshipit-source-id: 5a5cf1372668f7d3aacc0bb461bc69fa07212f3f
parent 05d2580a
......@@ -11,15 +11,18 @@ namespace ffmpeg {
////////////////////////////////////////////////////////////////////////////////
// AVDictionary
////////////////////////////////////////////////////////////////////////////////
AVDictionary* get_option_dict(const OptionDict& option) {
AVDictionary* get_option_dict(const c10::optional<OptionDict>& option) {
AVDictionary* opt = nullptr;
for (const auto& it : option) {
av_dict_set(&opt, it.first.c_str(), it.second.c_str(), 0);
if (option) {
for (const auto& it : option.value()) {
av_dict_set(&opt, it.key().c_str(), it.value().c_str(), 0);
}
}
return opt;
}
void clean_up_dict(AVDictionary* p) {
if (p) {
std::vector<std::string> unused_keys;
// Check and copy unused keys, clean up the original dictionary
AVDictionaryEntry* t = nullptr;
......@@ -27,10 +30,10 @@ void clean_up_dict(AVDictionary* p) {
unused_keys.emplace_back(t->key);
}
av_dict_free(&p);
if (!unused_keys.empty()) {
throw std::runtime_error(
"Unexpected options: " + c10::Join(", ", unused_keys));
TORCH_CHECK(
unused_keys.empty(),
"Unexpected options: ",
c10::Join(", ", unused_keys));
}
}
......
......@@ -25,7 +25,7 @@ extern "C" {
namespace torchaudio {
namespace ffmpeg {
using OptionDict = std::map<std::string, std::string>;
using OptionDict = c10::Dict<std::string, std::string>;
// https://github.com/FFmpeg/FFmpeg/blob/4e6debe1df7d53f3f59b37449b82265d5c08a172/doc/APIchanges#L252-L260
// Starting from libavformat 59 (ffmpeg 5),
......@@ -76,7 +76,7 @@ class Wrapper {
// IIRC-semantic. Instead we provide helper functions.
// Convert standard dict to FFmpeg native type
AVDictionary* get_option_dict(const OptionDict& option);
AVDictionary* get_option_dict(const c10::optional<OptionDict>& option);
// Clean up the dict after use. If there is an unused key, throw runtime error
void clean_up_dict(AVDictionary* p);
......
......@@ -12,7 +12,7 @@ PYBIND11_MODULE(_torchaudio_ffmpeg, m) {
.def(py::init<
py::object,
const c10::optional<std::string>&,
const c10::optional<OptionDict>&,
const c10::optional<std::map<std::string, std::string>>&,
int64_t>())
.def("num_src_streams", &StreamReaderFileObj::num_src_streams)
.def("num_out_streams", &StreamReaderFileObj::num_out_streams)
......@@ -23,9 +23,7 @@ PYBIND11_MODULE(_torchaudio_ffmpeg, m) {
"find_best_video_stream",
&StreamReaderFileObj::find_best_video_stream)
.def("get_metadata", &StreamReaderFileObj::get_metadata)
.def(
"get_src_stream_info",
&StreamReaderFileObj::get_src_stream_info_pybind)
.def("get_src_stream_info", &StreamReaderFileObj::get_src_stream_info)
.def("get_out_stream_info", &StreamReaderFileObj::get_out_stream_info)
.def("seek", &StreamReaderFileObj::seek)
.def("add_audio_stream", &StreamReaderFileObj::add_audio_stream)
......
#include <torchaudio/csrc/ffmpeg/pybind/stream_reader.h>
#include <torchaudio/csrc/ffmpeg/pybind/typedefs.h>
namespace torchaudio {
namespace ffmpeg {
namespace {
// Repackage a SrcStreamInfo as the tuple type handed to the PyBind layer.
// The metadata dictionary is converted with dict2map (c10::Dict -> std::map);
// every other field is forwarded as-is, in the order SrcInfoPyBind expects.
SrcInfoPyBind convert_pybind(SrcStreamInfo ssi) {
  // NOTE(review): FFmpeg documents av_get_media_type_string as returning NULL
  // for an unknown media type -- confirm such a value cannot reach this point.
  const char* media_type = av_get_media_type_string(ssi.media_type);
  auto metadata = dict2map(ssi.metadata);
  return SrcInfoPyBind(std::forward_as_tuple(
      media_type,
      ssi.codec_name,
      ssi.codec_long_name,
      ssi.fmt_name,
      ssi.bit_rate,
      ssi.num_frames,
      ssi.bits_per_sample,
      std::move(metadata),
      ssi.sample_rate,
      ssi.num_channels,
      ssi.width,
      ssi.height,
      ssi.frame_rate));
}
} // namespace
StreamReaderFileObj::StreamReaderFileObj(
py::object fileobj_,
const c10::optional<std::string>& format,
const c10::optional<OptionDict>& option,
const c10::optional<std::map<std::string, std::string>>& option,
int64_t buffer_size)
: FileObj(fileobj_, static_cast<int>(buffer_size)),
StreamReaderBinding(get_input_format_context(
static_cast<std::string>(py::str(fileobj_.attr("__str__")())),
format,
option.value_or(OptionDict{}),
map2dict(option),
pAVIO)) {}
std::map<std::string, std::string> StreamReaderFileObj::get_metadata() const {
std::map<std::string, std::string> ret;
for (const auto& it : StreamReader::get_metadata()) {
ret.insert({it.key(), it.value()});
}
return ret;
return dict2map(StreamReader::get_metadata());
};
// PyBind-facing accessor for the i-th source stream: queries the core
// StreamReader and converts the result into the PyBind tuple type.
SrcInfoPyBind StreamReaderFileObj::get_src_stream_info(int64_t i) {
  // The core StreamReader API takes the stream index as a plain int.
  const int index = static_cast<int>(i);
  return convert_pybind(StreamReader::get_src_stream_info(index));
}
// PyBind-facing wrapper around StreamReader::add_audio_stream.
// The decoder options arrive as an optional std::map and are converted to the
// optional OptionDict used by the core implementation; all other arguments
// are passed through unchanged.
void StreamReaderFileObj::add_audio_stream(
    int64_t i,
    int64_t frames_per_chunk,
    int64_t num_chunks,
    const c10::optional<std::string>& filter_desc,
    const c10::optional<std::string>& decoder,
    const c10::optional<std::map<std::string, std::string>>& decoder_option) {
  const c10::optional<OptionDict> converted_option = map2dict(decoder_option);
  StreamReader::add_audio_stream(
      i, frames_per_chunk, num_chunks, filter_desc, decoder, converted_option);
}
// PyBind-facing wrapper around StreamReader::add_video_stream.
// The decoder options arrive as an optional std::map and are converted to the
// optional OptionDict used by the core implementation; all other arguments
// (including hw_accel) are passed through unchanged.
void StreamReaderFileObj::add_video_stream(
    int64_t i,
    int64_t frames_per_chunk,
    int64_t num_chunks,
    const c10::optional<std::string>& filter_desc,
    const c10::optional<std::string>& decoder,
    const c10::optional<std::map<std::string, std::string>>& decoder_option,
    const c10::optional<std::string>& hw_accel) {
  const c10::optional<OptionDict> converted_option = map2dict(decoder_option);
  StreamReader::add_video_stream(
      i,
      frames_per_chunk,
      num_chunks,
      filter_desc,
      decoder,
      converted_option,
      hw_accel);
}
} // namespace ffmpeg
} // namespace torchaudio
......@@ -13,10 +13,28 @@ class StreamReaderFileObj : protected FileObj, public StreamReaderBinding {
StreamReaderFileObj(
py::object fileobj,
const c10::optional<std::string>& format,
const c10::optional<OptionDict>& option,
const c10::optional<std::map<std::string, std::string>>& option,
int64_t buffer_size);
std::map<std::string, std::string> get_metadata() const;
SrcInfoPyBind get_src_stream_info(int64_t i);
void add_audio_stream(
int64_t i,
int64_t frames_per_chunk,
int64_t num_chunks,
const c10::optional<std::string>& filter_desc,
const c10::optional<std::string>& decoder,
const c10::optional<std::map<std::string, std::string>>& decoder_option);
void add_video_stream(
int64_t i,
int64_t frames_per_chunk,
int64_t num_chunks,
const c10::optional<std::string>& filter_desc,
const c10::optional<std::string>& decoder,
const c10::optional<std::map<std::string, std::string>>& decoder_option,
const c10::optional<std::string>& hw_accel);
};
} // namespace ffmpeg
......
......@@ -72,5 +72,25 @@ FileObj::FileObj(py::object fileobj_, int buffer_size)
buffer_size(buffer_size),
pAVIO(get_io_context(this, buffer_size)) {}
// Convert the optional std::map received from PyBind into the optional
// OptionDict (c10::Dict) used by the core implementation.
//
// Returns an empty optional when `src` holds no value; otherwise returns a
// dictionary containing every key/value pair of the map.
c10::optional<OptionDict> map2dict(
    const c10::optional<std::map<std::string, std::string>>& src) {
  if (!src) {
    return {};
  }
  OptionDict dict;
  for (const auto& it : src.value()) {
    // Insert the std::string objects directly. Going through c_str() would
    // re-scan each string, copy it again, and truncate it at the first
    // embedded NUL character.
    dict.insert(it.first, it.second);
  }
  return c10::optional<OptionDict>{dict};
}
// Copy every entry of an OptionDict (c10::Dict) into a std::map, the
// dictionary type used on the PyBind side.
std::map<std::string, std::string> dict2map(const OptionDict& src) {
  std::map<std::string, std::string> converted;
  for (const auto& entry : src) {
    converted.emplace(entry.key(), entry.value());
  }
  return converted;
}
} // namespace ffmpeg
} // namespace torchaudio
......@@ -12,5 +12,10 @@ struct FileObj {
FileObj(py::object fileobj, int buffer_size);
};
c10::optional<OptionDict> map2dict(
const c10::optional<std::map<std::string, std::string>>& src);
std::map<std::string, std::string> dict2map(const OptionDict& src);
} // namespace ffmpeg
} // namespace torchaudio
......@@ -68,7 +68,7 @@ const AVCodecHWConfig* get_cuda_config(const AVCodec* pCodec) {
void init_codec_context(
AVCodecContext* pCodecContext,
AVCodecParameters* pParams,
const OptionDict& decoder_option,
const c10::optional<OptionDict>& decoder_option,
const torch::Device& device,
AVBufferRefPtr& pHWBufferRef) {
int ret = avcodec_parameters_to_context(pCodecContext, pParams);
......@@ -128,7 +128,7 @@ void init_codec_context(
Decoder::Decoder(
AVCodecParameters* pParam,
const c10::optional<std::string>& decoder_name,
const OptionDict& decoder_option,
const c10::optional<OptionDict>& decoder_option,
const torch::Device& device)
: pCodecContext(get_decode_context(pParam->codec_id, decoder_name)) {
init_codec_context(
......
......@@ -14,7 +14,7 @@ class Decoder {
Decoder(
AVCodecParameters* pParam,
const c10::optional<std::string>& decoder_name,
const OptionDict& decoder_option,
const c10::optional<OptionDict>& decoder_option,
const torch::Device& device);
// Custom destructor to clean up the resources
~Decoder() = default;
......
......@@ -9,7 +9,7 @@ using KeyType = StreamProcessor::KeyType;
StreamProcessor::StreamProcessor(
AVCodecParameters* codecpar,
const c10::optional<std::string>& decoder_name,
const OptionDict& decoder_option,
const c10::optional<OptionDict>& decoder_option,
const torch::Device& device)
: decoder(codecpar, decoder_name, decoder_option, device) {}
......
......@@ -28,7 +28,7 @@ class StreamProcessor {
StreamProcessor(
AVCodecParameters* codecpar,
const c10::optional<std::string>& decoder_name,
const OptionDict& decoder_option,
const c10::optional<OptionDict>& decoder_option,
const torch::Device& device);
~StreamProcessor() = default;
// Non-copyable
......
......@@ -72,10 +72,9 @@ int64_t StreamReader::num_src_streams() const {
}
namespace {
c10::Dict<std::string, std::string> parse_metadata(
const AVDictionary* metadata) {
OptionDict parse_metadata(const AVDictionary* metadata) {
AVDictionaryEntry* tag = nullptr;
c10::Dict<std::string, std::string> ret;
OptionDict ret;
while ((tag = av_dict_get(metadata, "", tag, AV_DICT_IGNORE_SUFFIX))) {
ret.insert(std::string(tag->key), std::string(tag->value));
}
......@@ -83,7 +82,7 @@ c10::Dict<std::string, std::string> parse_metadata(
}
} // namespace
c10::Dict<std::string, std::string> StreamReader::get_metadata() const {
OptionDict StreamReader::get_metadata() const {
return parse_metadata(pFormatContext->metadata);
}
......@@ -188,7 +187,7 @@ void StreamReader::add_audio_stream(
int64_t num_chunks,
const c10::optional<std::string>& filter_desc,
const c10::optional<std::string>& decoder,
const OptionDict& decoder_option) {
const c10::optional<OptionDict>& decoder_option) {
add_stream(
static_cast<int>(i),
AVMEDIA_TYPE_AUDIO,
......@@ -206,7 +205,7 @@ void StreamReader::add_video_stream(
int64_t num_chunks,
const c10::optional<std::string>& filter_desc,
const c10::optional<std::string>& decoder,
const OptionDict& decoder_option,
const c10::optional<OptionDict>& decoder_option,
const c10::optional<std::string>& hw_accel) {
const torch::Device device = [&]() {
if (!hw_accel) {
......@@ -245,7 +244,7 @@ void StreamReader::add_stream(
int num_chunks,
const c10::optional<std::string>& filter_desc,
const c10::optional<std::string>& decoder,
const OptionDict& decoder_option,
const c10::optional<OptionDict>& decoder_option,
const torch::Device& device) {
validate_src_stream_type(i, media_type);
......
......@@ -44,7 +44,7 @@ class StreamReader {
int64_t find_best_audio_stream() const;
int64_t find_best_video_stream() const;
// Fetch metadata of the source
c10::Dict<std::string, std::string> get_metadata() const;
OptionDict get_metadata() const;
// Fetch information about source streams
int64_t num_src_streams() const;
SrcStreamInfo get_src_stream_info(int i) const;
......@@ -65,14 +65,14 @@ class StreamReader {
int64_t num_chunks,
const c10::optional<std::string>& filter_desc,
const c10::optional<std::string>& decoder,
const OptionDict& decoder_option);
const c10::optional<OptionDict>& decoder_option);
void add_video_stream(
int64_t i,
int64_t frames_per_chunk,
int64_t num_chunks,
const c10::optional<std::string>& filter_desc,
const c10::optional<std::string>& decoder,
const OptionDict& decoder_option,
const c10::optional<OptionDict>& decoder_option,
const c10::optional<std::string>& hw_accel);
void remove_stream(int64_t i);
......@@ -84,7 +84,7 @@ class StreamReader {
int num_chunks,
const c10::optional<std::string>& filter_desc,
const c10::optional<std::string>& decoder,
const OptionDict& decoder_option,
const c10::optional<OptionDict>& decoder_option,
const torch::Device& device);
public:
......
......@@ -7,23 +7,12 @@ namespace ffmpeg {
namespace {
OptionDict map(const c10::optional<c10::Dict<std::string, std::string>>& dict) {
OptionDict ret;
if (!dict.has_value()) {
return ret;
}
for (const auto& it : dict.value()) {
ret.insert({it.key(), it.value()});
}
return ret;
}
c10::intrusive_ptr<StreamReaderBinding> init(
const std::string& src,
const c10::optional<std::string>& device,
const c10::optional<c10::Dict<std::string, std::string>>& option) {
const c10::optional<OptionDict>& option) {
return c10::make_intrusive<StreamReaderBinding>(
get_input_format_context(src, device, map(option)));
get_input_format_context(src, device, option));
}
using S = const c10::intrusive_ptr<StreamReaderBinding>&;
......@@ -62,15 +51,14 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
int64_t num_chunks,
const c10::optional<std::string>& filter_desc,
const c10::optional<std::string>& decoder,
const c10::optional<c10::Dict<std::string, std::string>>&
decoder_options) {
const c10::optional<OptionDict>& decoder_option) {
s->add_audio_stream(
i,
frames_per_chunk,
num_chunks,
filter_desc,
decoder,
map(decoder_options));
decoder_option);
})
.def(
"add_video_stream",
......@@ -80,8 +68,7 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
int64_t num_chunks,
const c10::optional<std::string>& filter_desc,
const c10::optional<std::string>& decoder,
const c10::optional<c10::Dict<std::string, std::string>>&
decoder_options,
const c10::optional<OptionDict>& decoder_option,
const c10::optional<std::string>& hw_accel) {
s->add_video_stream(
i,
......@@ -89,7 +76,7 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
num_chunks,
filter_desc,
decoder,
map(decoder_options),
decoder_option,
hw_accel);
})
.def("remove_stream", [](S s, int64_t i) { s->remove_stream(i); })
......
......@@ -4,17 +4,6 @@ namespace torchaudio {
namespace ffmpeg {
namespace {
// TODO:
// merge the implementation with the one from stream_reader_binding.cpp
std::map<std::string, std::string> convert_map(
const c10::Dict<std::string, std::string>& src) {
std::map<std::string, std::string> ret;
for (const auto& it : src) {
ret.insert({it.key(), it.value()});
}
return ret;
}
SrcInfo convert(SrcStreamInfo ssi) {
return SrcInfo(std::forward_as_tuple(
av_get_media_type_string(ssi.media_type),
......@@ -32,23 +21,6 @@ SrcInfo convert(SrcStreamInfo ssi) {
ssi.frame_rate));
}
SrcInfoPyBind convert_pybind(SrcStreamInfo ssi) {
return SrcInfoPyBind(std::forward_as_tuple(
av_get_media_type_string(ssi.media_type),
ssi.codec_name,
ssi.codec_long_name,
ssi.fmt_name,
ssi.bit_rate,
ssi.num_frames,
ssi.bits_per_sample,
convert_map(ssi.metadata),
ssi.sample_rate,
ssi.num_channels,
ssi.width,
ssi.height,
ssi.frame_rate));
}
OutInfo convert(OutputStreamInfo osi) {
return OutInfo(
std::forward_as_tuple(osi.source_index, osi.filter_description));
......@@ -58,7 +30,7 @@ OutInfo convert(OutputStreamInfo osi) {
AVFormatInputContextPtr get_input_format_context(
const std::string& src,
const c10::optional<std::string>& device,
const OptionDict& option,
const c10::optional<OptionDict>& option,
AVIOContext* io_ctx) {
AVFormatContext* pFormat = avformat_alloc_context();
if (!pFormat) {
......@@ -101,10 +73,6 @@ SrcInfo StreamReaderBinding::get_src_stream_info(int64_t i) {
return convert(StreamReader::get_src_stream_info(static_cast<int>(i)));
}
SrcInfoPyBind StreamReaderBinding::get_src_stream_info_pybind(int64_t i) {
return convert_pybind(StreamReader::get_src_stream_info(static_cast<int>(i)));
}
OutInfo StreamReaderBinding::get_out_stream_info(int64_t i) {
return convert(StreamReader::get_out_stream_info(static_cast<int>(i)));
}
......
......@@ -9,7 +9,7 @@ namespace ffmpeg {
AVFormatInputContextPtr get_input_format_context(
const std::string& src,
const c10::optional<std::string>& device,
const OptionDict& option,
const c10::optional<OptionDict>& option,
AVIOContext* io_ctx = nullptr);
// Because TorchScript requires c10::Dict type to pass dict,
......@@ -28,7 +28,7 @@ using SrcInfo = std::tuple<
int64_t, // bit_rate
int64_t, // num_frames
int64_t, // bits_per_sample
c10::Dict<std::string, std::string>, // metadata
OptionDict, // metadata
// Audio
double, // sample_rate
int64_t, // num_channels
......@@ -67,7 +67,6 @@ struct StreamReaderBinding : public StreamReader,
public torch::CustomClassHolder {
explicit StreamReaderBinding(AVFormatInputContextPtr&& p);
SrcInfo get_src_stream_info(int64_t i);
SrcInfoPyBind get_src_stream_info_pybind(int64_t i);
OutInfo get_out_stream_info(int64_t i);
int64_t process_packet(
......
......@@ -14,7 +14,7 @@ struct SrcStreamInfo {
int64_t bit_rate = 0;
int64_t num_frames = 0;
int bits_per_sample = 0;
c10::Dict<std::string, std::string> metadata{};
OptionDict metadata{};
// Audio
double sample_rate = 0;
int num_channels = 0;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment