Commit eed57534 authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Refactor Streamer implementation (#2402)

Summary:
* Move the helper wrapping code in the TorchBind layer to a proper wrapper class so that it can be re-used in PyBind11.
* Move `add_basic_[audio|video]_stream` methods from C++ to Python, as they are just string manipulation. This will make the PyBind11-based binding simpler, as it does not need to deal with dtype.
* Move `add_[audio|video]_stream` wrapper signature to Streamer core, so that Streamer directly deals with `c10::optional`.†

† Related to this, there is a slight change in how the empty filter expression is stored. Originally, if an empty filter expression was given to `add_[audio|video]_stream` method, the `StreamReaderOutputStream` was showing it as empty string `""`, even though internally it was using `"anull"` or `"null"`. Now `StreamReaderOutputStream` shows the corresponding filter expression that is actually being used.

Ref https://github.com/pytorch/audio/issues/2400

Pull Request resolved: https://github.com/pytorch/audio/pull/2402

Reviewed By: nateanl

Differential Revision: D36488808

Pulled By: mthrok

fbshipit-source-id: 877ca731364d10fc0cb9d97e75d55df9180f2047
parent 647f28e4
......@@ -165,7 +165,7 @@ class StreamReaderInterfaceTest(TempDirMixin, TorchaudioTestCase):
sinfo = s.get_out_stream_info(0)
assert sinfo.source_index == s.default_audio_stream
assert sinfo.filter_description == ""
assert sinfo.filter_description == "anull"
sinfo = s.get_out_stream_info(1)
assert sinfo.source_index == s.default_audio_stream
......@@ -185,7 +185,7 @@ class StreamReaderInterfaceTest(TempDirMixin, TorchaudioTestCase):
sinfo = s.get_out_stream_info(0)
assert sinfo.source_index == s.default_video_stream
assert sinfo.filter_description == ""
assert sinfo.filter_description == "null"
sinfo = s.get_out_stream_info(1)
assert sinfo.source_index == s.default_video_stream
......
......@@ -181,6 +181,7 @@ if(USE_FFMPEG)
ffmpeg/sink.cpp
ffmpeg/stream_processor.cpp
ffmpeg/streamer.cpp
ffmpeg/stream_reader_wrapper.cpp
)
message(STATUS "FFMPEG_ROOT=$ENV{FFMPEG_ROOT}")
find_package(FFMPEG 4.1 REQUIRED COMPONENTS avdevice avfilter avformat avcodec avutil)
......
......@@ -8,8 +8,8 @@ namespace ffmpeg {
////////////////////////////////////////////////////////////////////////////////
Decoder::Decoder(
AVCodecParameters* pParam,
const std::string& decoder_name,
const std::map<std::string, std::string>& decoder_option,
const c10::optional<std::string>& decoder_name,
const OptionDict& decoder_option,
const torch::Device& device)
: pCodecContext(get_decode_context(pParam->codec_id, decoder_name)) {
init_codec_context(
......
......@@ -13,8 +13,8 @@ class Decoder {
// Default constructable
Decoder(
AVCodecParameters* pParam,
const std::string& decoder_name,
const std::map<std::string, std::string>& decoder_option,
const c10::optional<std::string>& decoder_name,
const OptionDict& decoder_option,
const torch::Device& device);
// Custom destructor to clean up the resources
~Decoder() = default;
......
......@@ -17,10 +17,9 @@ void AVFormatContextDeleter::operator()(AVFormatContext* p) {
namespace {
AVDictionary* get_option_dict(
const std::map<std::string, std::string>& option) {
AVDictionary* get_option_dict(const OptionDict& option) {
AVDictionary* opt = nullptr;
for (auto& it : option) {
for (const auto& it : option) {
av_dict_set(&opt, it.first.c_str(), it.second.c_str(), 0);
}
return opt;
......@@ -66,12 +65,25 @@ std::string join(std::vector<std::string> vars) {
AVFormatContextPtr get_input_format_context(
const std::string& src,
const std::string& device,
const std::map<std::string, std::string>& option) {
const c10::optional<std::string>& device,
const OptionDict& option) {
AVFormatContext* pFormat = NULL;
AVINPUT_FORMAT_CONST AVInputFormat* pInput =
device.empty() ? NULL : av_find_input_format(device.c_str());
AVINPUT_FORMAT_CONST AVInputFormat* pInput = [&]() -> AVInputFormat* {
if (device.has_value()) {
std::string device_str = device.value();
AVINPUT_FORMAT_CONST AVInputFormat* p =
av_find_input_format(device_str.c_str());
if (!p) {
std::ostringstream msg;
msg << "Unsupported device: \"" << device_str << "\"";
throw std::runtime_error(msg.str());
}
return p;
}
return nullptr;
}();
AVDictionary* opt = get_option_dict(option);
int ret = avformat_open_input(&pFormat, src.c_str(), pInput, &opt);
......@@ -148,18 +160,18 @@ void AVCodecContextDeleter::operator()(AVCodecContext* p) {
namespace {
const AVCodec* get_decode_codec(
enum AVCodecID codec_id,
const std::string& decoder_name) {
const AVCodec* pCodec = decoder_name.empty()
const c10::optional<std::string>& decoder_name) {
const AVCodec* pCodec = !decoder_name.has_value()
? avcodec_find_decoder(codec_id)
: avcodec_find_decoder_by_name(decoder_name.c_str());
: avcodec_find_decoder_by_name(decoder_name.value().c_str());
if (!pCodec) {
std::stringstream ss;
if (decoder_name.empty()) {
if (!decoder_name.has_value()) {
ss << "Unsupported codec: \"" << avcodec_get_name(codec_id) << "\", ("
<< codec_id << ").";
} else {
ss << "Unsupported codec: \"" << decoder_name << "\".";
ss << "Unsupported codec: \"" << decoder_name.value() << "\".";
}
throw std::runtime_error(ss.str());
}
......@@ -170,7 +182,7 @@ const AVCodec* get_decode_codec(
AVCodecContextPtr get_decode_context(
enum AVCodecID codec_id,
const std::string& decoder_name) {
const c10::optional<std::string>& decoder_name) {
const AVCodec* pCodec = get_decode_codec(codec_id, decoder_name);
AVCodecContext* pCodecContext = avcodec_alloc_context3(pCodec);
......@@ -216,8 +228,8 @@ const AVCodecHWConfig* get_cuda_config(const AVCodec* pCodec) {
void init_codec_context(
AVCodecContext* pCodecContext,
AVCodecParameters* pParams,
const std::string& decoder_name,
const std::map<std::string, std::string>& decoder_option,
const c10::optional<std::string>& decoder_name,
const OptionDict& decoder_option,
const torch::Device& device,
AVBufferRefPtr& pHWBufferRef) {
const AVCodec* pCodec = get_decode_codec(pParams->codec_id, decoder_name);
......
......@@ -23,6 +23,8 @@ extern "C" {
namespace torchaudio {
namespace ffmpeg {
using OptionDict = std::map<std::string, std::string>;
// Replacement of av_err2str, which causes
// `error: taking address of temporary array`
// https://github.com/joncampbell123/composite-video-simulator/issues/5
......@@ -71,8 +73,8 @@ struct AVFormatContextPtr
// create format context for reading media
AVFormatContextPtr get_input_format_context(
const std::string& src,
const std::string& device,
const std::map<std::string, std::string>& option);
const c10::optional<std::string>& device,
const OptionDict& option);
////////////////////////////////////////////////////////////////////////////////
// AVPacket
......@@ -141,14 +143,14 @@ struct AVCodecContextPtr
// Allocate codec context from either decoder name or ID
AVCodecContextPtr get_decode_context(
enum AVCodecID codec_id,
const std::string& decoder);
const c10::optional<std::string>& decoder);
// Initialize codec context with the parameters
void init_codec_context(
AVCodecContext* pCodecContext,
AVCodecParameters* pParams,
const std::string& decoder_name,
const std::map<std::string, std::string>& decoder_option,
const c10::optional<std::string>& decoder_name,
const OptionDict& decoder_option,
const torch::Device& device,
AVBufferRefPtr& pHWBufferRef);
......
......@@ -7,10 +7,11 @@ namespace ffmpeg {
FilterGraph::FilterGraph(
AVRational time_base,
AVCodecParameters* codecpar,
std::string filter_description)
const c10::optional<std::string>& filter_description)
: input_time_base(time_base),
codecpar(codecpar),
filter_description(std::move(filter_description)),
filter_description(filter_description.value_or(
codecpar->codec_type == AVMEDIA_TYPE_AUDIO ? "anull" : "null")),
media_type(codecpar->codec_type) {
init();
}
......@@ -49,10 +50,10 @@ std::string get_video_src_args(
std::snprintf(
args,
sizeof(args),
"video_size=%dx%d:pix_fmt=%d:time_base=%d/%d:pixel_aspect=%d/%d",
"video_size=%dx%d:pix_fmt=%s:time_base=%d/%d:pixel_aspect=%d/%d",
codecpar->width,
codecpar->height,
static_cast<AVPixelFormat>(codecpar->format),
av_get_pix_fmt_name(static_cast<AVPixelFormat>(codecpar->format)),
time_base.num,
time_base.den,
codecpar->sample_aspect_ratio.num,
......@@ -165,16 +166,12 @@ void FilterGraph::add_process() {
// If you are debugging this part of the code, you might get confused.
InOuts in{"in", buffersrc_ctx}, out{"out", buffersink_ctx};
std::string desc = filter_description.empty()
? (media_type == AVMEDIA_TYPE_AUDIO) ? "anull" : "null"
: filter_description;
int ret =
avfilter_graph_parse_ptr(pFilterGraph, desc.c_str(), out, in, nullptr);
int ret = avfilter_graph_parse_ptr(
pFilterGraph, filter_description.c_str(), out, in, nullptr);
if (ret < 0) {
throw std::runtime_error(
"Failed to create the filter from \"" + desc + "\" (" +
"Failed to create the filter from \"" + filter_description + "\" (" +
av_err2string(ret) + ".)");
}
}
......
......@@ -24,7 +24,7 @@ class FilterGraph {
FilterGraph(
AVRational time_base,
AVCodecParameters* codecpar,
std::string filter_desc);
const c10::optional<std::string>& filter_desc);
// Custom destructor to release AVFilterGraph*
~FilterGraph() = default;
// Non-copyable
......
#include <torch/script.h>
#include <torchaudio/csrc/ffmpeg/streamer.h>
#include <torchaudio/csrc/ffmpeg/stream_reader_wrapper.h>
#include <stdexcept>
namespace torchaudio {
......@@ -7,357 +7,38 @@ namespace ffmpeg {
namespace {
using OptionDict = c10::Dict<std::string, std::string>;
std::map<std::string, std::string> convert_dict(
const c10::optional<OptionDict>& option) {
std::map<std::string, std::string> opts;
if (option) {
for (auto& it : option.value()) {
opts[it.key()] = it.value();
}
OptionDict map(const c10::optional<c10::Dict<std::string, std::string>>& dict) {
OptionDict ret;
if (!dict.has_value()) {
return ret;
}
for (const auto& it : dict.value()) {
ret.insert({it.key(), it.value()});
}
return opts;
return ret;
}
struct StreamerHolder : torch::CustomClassHolder {
Streamer s;
StreamerHolder(
const std::string& src,
const c10::optional<std::string>& device,
const c10::optional<OptionDict>& option)
: s(src, device.value_or(""), convert_dict(option)) {}
};
using S = c10::intrusive_ptr<StreamerHolder>;
S init(
c10::intrusive_ptr<StreamReaderBinding> init(
const std::string& src,
const c10::optional<std::string>& device,
const c10::optional<OptionDict>& option) {
return c10::make_intrusive<StreamerHolder>(src, device, option);
}
using SrcInfo = std::tuple<
std::string, // media_type
std::string, // codec name
std::string, // codec long name
std::string, // format name
int64_t, // bit_rate
// Audio
double, // sample_rate
int64_t, // num_channels
// Video
int64_t, // width
int64_t, // height
double // frame_rate
>;
SrcInfo convert(SrcStreamInfo ssi) {
return SrcInfo(std::forward_as_tuple(
av_get_media_type_string(ssi.media_type),
ssi.codec_name,
ssi.codec_long_name,
ssi.fmt_name,
ssi.bit_rate,
ssi.sample_rate,
ssi.num_channels,
ssi.width,
ssi.height,
ssi.frame_rate));
}
SrcInfo get_src_stream_info(S s, int64_t i) {
return convert(s->s.get_src_stream_info(i));
}
using OutInfo = std::tuple<
int64_t, // source index
std::string // filter description
>;
OutInfo convert(OutputStreamInfo osi) {
return OutInfo(
std::forward_as_tuple(osi.source_index, osi.filter_description));
}
OutInfo get_out_stream_info(S s, int64_t i) {
return convert(s->s.get_out_stream_info(i));
}
int64_t num_src_streams(S s) {
return s->s.num_src_streams();
}
int64_t num_out_streams(S s) {
return s->s.num_out_streams();
}
int64_t find_best_audio_stream(S s) {
return s->s.find_best_audio_stream();
}
int64_t find_best_video_stream(S s) {
return s->s.find_best_video_stream();
}
void seek(S s, double timestamp) {
s->s.seek(timestamp);
}
template <typename... Args>
std::string string_format(const std::string& format, Args... args) {
char buffer[512];
std::snprintf(buffer, sizeof(buffer), format.c_str(), args...);
return std::string(buffer);
}
std::string join(
const std::vector<std::string>& components,
const std::string& delim) {
std::ostringstream s;
for (int i = 0; i < components.size(); ++i) {
if (i)
s << delim;
s << components[i];
}
return s.str();
}
std::string get_afilter_desc(
const c10::optional<int64_t>& sample_rate,
const c10::optional<c10::ScalarType>& dtype) {
std::vector<std::string> components;
if (sample_rate) {
// TODO: test float sample rate
components.emplace_back(
string_format("aresample=%d", static_cast<int>(sample_rate.value())));
}
if (dtype) {
AVSampleFormat fmt = [&]() {
switch (dtype.value()) {
case c10::ScalarType::Byte:
return AV_SAMPLE_FMT_U8P;
case c10::ScalarType::Short:
return AV_SAMPLE_FMT_S16P;
case c10::ScalarType::Int:
return AV_SAMPLE_FMT_S32P;
case c10::ScalarType::Long:
return AV_SAMPLE_FMT_S64P;
case c10::ScalarType::Float:
return AV_SAMPLE_FMT_FLTP;
case c10::ScalarType::Double:
return AV_SAMPLE_FMT_DBLP;
default:
throw std::runtime_error("Unexpected dtype.");
}
}();
components.emplace_back(
string_format("aformat=sample_fmts=%s", av_get_sample_fmt_name(fmt)));
}
return join(components, ",");
}
std::string get_vfilter_desc(
const c10::optional<double>& frame_rate,
const c10::optional<int64_t>& width,
const c10::optional<int64_t>& height,
const c10::optional<std::string>& format) {
// TODO:
// - Add `flags` for different scale algorithm
// https://ffmpeg.org/ffmpeg-filters.html#scale
// - Consider `framerate` as well
// https://ffmpeg.org/ffmpeg-filters.html#framerate
// - scale
// https://ffmpeg.org/ffmpeg-filters.html#scale-1
// https://ffmpeg.org/ffmpeg-scaler.html#toc-Scaler-Options
// - framerate
// https://ffmpeg.org/ffmpeg-filters.html#framerate
// TODO:
// - format
// https://ffmpeg.org/ffmpeg-filters.html#toc-format-1
// - fps
// https://ffmpeg.org/ffmpeg-filters.html#fps-1
std::vector<std::string> components;
if (frame_rate)
components.emplace_back(string_format("fps=%lf", frame_rate.value()));
std::vector<std::string> scale_components;
if (width)
scale_components.emplace_back(string_format("width=%d", width.value()));
if (height)
scale_components.emplace_back(string_format("height=%d", height.value()));
if (scale_components.size())
components.emplace_back(
string_format("scale=%s", join(scale_components, ":").c_str()));
if (format) {
// TODO:
// Check other useful formats
// https://pillow.readthedocs.io/en/stable/handbook/concepts.html#modes
AVPixelFormat fmt = [&]() {
const std::map<const std::string, enum AVPixelFormat> valid_choices {
{"RGB", AV_PIX_FMT_RGB24},
{"BGR", AV_PIX_FMT_BGR24},
{"YUV", AV_PIX_FMT_YUV420P},
{"GRAY", AV_PIX_FMT_GRAY8},
};
const std::string val = format.value();
if (valid_choices.find(val) == valid_choices.end()) {
std::stringstream ss;
ss << "Unexpected output video format: \"" << val << "\"."
<< "Valid choices are; ";
int i = 0;
for (const auto& p : valid_choices) {
if (i == 0) {
ss << "\"" << p.first << "\"";
} else {
ss << ", \"" << p.first << "\"";
}
}
throw std::runtime_error(ss.str());
}
return valid_choices.at(val);
}();
components.emplace_back(
string_format("format=pix_fmts=%s", av_get_pix_fmt_name(fmt)));
}
return join(components, ",");
};
void add_basic_audio_stream(
S s,
int64_t i,
int64_t frames_per_chunk,
int64_t num_chunks,
const c10::optional<int64_t>& sample_rate,
const c10::optional<c10::ScalarType>& dtype) {
std::string filter_desc = get_afilter_desc(sample_rate, dtype);
s->s.add_audio_stream(i, frames_per_chunk, num_chunks, filter_desc, "", {});
}
void add_basic_video_stream(
S s,
int64_t i,
int64_t frames_per_chunk,
int64_t num_chunks,
const c10::optional<double>& frame_rate,
const c10::optional<int64_t>& width,
const c10::optional<int64_t>& height,
const c10::optional<std::string>& format) {
std::string filter_desc = get_vfilter_desc(frame_rate, width, height, format);
s->s.add_video_stream(
static_cast<int>(i),
static_cast<int>(frames_per_chunk),
static_cast<int>(num_chunks),
std::move(filter_desc),
"",
{},
torch::Device(c10::DeviceType::CPU));
}
void add_audio_stream(
S s,
int64_t i,
int64_t frames_per_chunk,
int64_t num_chunks,
const c10::optional<std::string>& filter_desc,
const c10::optional<std::string>& decoder,
const c10::optional<OptionDict>& decoder_options) {
s->s.add_audio_stream(
i,
frames_per_chunk,
num_chunks,
filter_desc.value_or(""),
decoder.value_or(""),
convert_dict(decoder_options));
}
void add_video_stream(
S s,
int64_t i,
int64_t frames_per_chunk,
int64_t num_chunks,
const c10::optional<std::string>& filter_desc,
const c10::optional<std::string>& decoder,
const c10::optional<OptionDict>& decoder_options,
const c10::optional<std::string>& hw_accel) {
const torch::Device device = [&]() {
if (!hw_accel) {
return torch::Device{c10::DeviceType::CPU};
}
#ifdef USE_CUDA
torch::Device d{hw_accel.value()};
if (d.type() != c10::DeviceType::CUDA) {
std::stringstream ss;
ss << "Only CUDA is supported for hardware acceleration. Found: "
<< device.str();
throw std::runtime_error(ss.str());
}
return d;
#else
throw std::runtime_error(
"torchaudio is not compiled with CUDA support. Hardware acceleration is not available.");
#endif
}();
s->s.add_video_stream(
i,
frames_per_chunk,
num_chunks,
filter_desc.value_or(""),
decoder.value_or(""),
convert_dict(decoder_options),
device);
}
void remove_stream(S s, int64_t i) {
s->s.remove_stream(i);
}
int64_t process_packet(
Streamer& s,
const c10::optional<double>& timeout = c10::optional<double>(),
const double backoff = 10.) {
int64_t code = [&]() {
if (timeout.has_value()) {
return s.process_packet_block(timeout.value(), backoff);
}
return s.process_packet();
}();
if (code < 0) {
throw std::runtime_error(
"Failed to process a packet. (" + av_err2string(code) + "). ");
}
return code;
}
void process_all_packets(Streamer& s) {
int ret = 0;
do {
ret = process_packet(s);
} while (!ret);
}
bool is_buffer_ready(S s) {
return s->s.is_buffer_ready();
}
std::vector<c10::optional<torch::Tensor>> pop_chunks(S s) {
return s->s.pop_chunks();
const c10::optional<c10::Dict<std::string, std::string>>& option) {
return c10::make_intrusive<StreamReaderBinding>(
get_input_format_context(src, device, map(option)));
}
std::tuple<c10::optional<torch::Tensor>, int64_t> load(const std::string& src) {
Streamer s{src, "", {}};
StreamReaderBinding s{get_input_format_context(src, {}, {})};
int i = s.find_best_audio_stream();
auto sinfo = s.get_src_stream_info(i);
auto sinfo = s.Streamer::get_src_stream_info(i);
int64_t sample_rate = static_cast<int64_t>(sinfo.sample_rate);
s.add_audio_stream(i, -1, -1, "", "", {});
process_all_packets(s);
s.add_audio_stream(i, -1, -1, {}, {}, {});
s.process_all_packets();
auto tensors = s.pop_chunks();
return std::make_tuple<>(tensors[0], sample_rate);
}
using S = const c10::intrusive_ptr<StreamReaderBinding>&;
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def("torchaudio::ffmpeg_init", []() {
avdevice_register_all();
......@@ -365,38 +46,84 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
av_log_set_level(AV_LOG_ERROR);
});
m.def("torchaudio::ffmpeg_load", load);
m.class_<StreamerHolder>("ffmpeg_Streamer");
m.class_<StreamReaderBinding>("ffmpeg_Streamer");
m.def("torchaudio::ffmpeg_streamer_init", init);
m.def("torchaudio::ffmpeg_streamer_num_src_streams", num_src_streams);
m.def("torchaudio::ffmpeg_streamer_num_out_streams", num_out_streams);
m.def("torchaudio::ffmpeg_streamer_get_src_stream_info", get_src_stream_info);
m.def("torchaudio::ffmpeg_streamer_get_out_stream_info", get_out_stream_info);
m.def(
"torchaudio::ffmpeg_streamer_find_best_audio_stream",
find_best_audio_stream);
m.def(
"torchaudio::ffmpeg_streamer_find_best_video_stream",
find_best_video_stream);
m.def("torchaudio::ffmpeg_streamer_seek", seek);
m.def("torchaudio::ffmpeg_streamer_num_src_streams", [](S s) {
return s->num_src_streams();
});
m.def("torchaudio::ffmpeg_streamer_num_out_streams", [](S s) {
return s->num_out_streams();
});
m.def("torchaudio::ffmpeg_streamer_get_src_stream_info", [](S s, int64_t i) {
return s->get_src_stream_info(i);
});
m.def("torchaudio::ffmpeg_streamer_get_out_stream_info", [](S s, int64_t i) {
return s->get_out_stream_info(i);
});
m.def("torchaudio::ffmpeg_streamer_find_best_audio_stream", [](S s) {
return s->find_best_audio_stream();
});
m.def("torchaudio::ffmpeg_streamer_find_best_video_stream", [](S s) {
return s->find_best_video_stream();
});
m.def("torchaudio::ffmpeg_streamer_seek", [](S s, double t) {
return s->seek(t);
});
m.def(
"torchaudio::ffmpeg_streamer_add_basic_audio_stream",
add_basic_audio_stream);
"torchaudio::ffmpeg_streamer_add_audio_stream",
[](S s,
int64_t i,
int64_t frames_per_chunk,
int64_t num_chunks,
const c10::optional<std::string>& filter_desc,
const c10::optional<std::string>& decoder,
const c10::optional<c10::Dict<std::string, std::string>>&
decoder_options) {
s->add_audio_stream(
i,
frames_per_chunk,
num_chunks,
filter_desc,
decoder,
map(decoder_options));
});
m.def(
"torchaudio::ffmpeg_streamer_add_basic_video_stream",
add_basic_video_stream);
m.def("torchaudio::ffmpeg_streamer_add_audio_stream", add_audio_stream);
m.def("torchaudio::ffmpeg_streamer_add_video_stream", add_video_stream);
m.def("torchaudio::ffmpeg_streamer_remove_stream", remove_stream);
"torchaudio::ffmpeg_streamer_add_video_stream",
[](S s,
int64_t i,
int64_t frames_per_chunk,
int64_t num_chunks,
const c10::optional<std::string>& filter_desc,
const c10::optional<std::string>& decoder,
const c10::optional<c10::Dict<std::string, std::string>>&
decoder_options,
const c10::optional<std::string>& hw_accel) {
s->add_video_stream(
i,
frames_per_chunk,
num_chunks,
filter_desc,
decoder,
map(decoder_options),
hw_accel);
});
m.def("torchaudio::ffmpeg_streamer_remove_stream", [](S s, int64_t i) {
s->remove_stream(i);
});
m.def(
"torchaudio::ffmpeg_streamer_process_packet",
[](S s, const c10::optional<double>& timeout, double backoff) {
return process_packet(s->s, timeout, backoff);
[](S s, const c10::optional<double>& timeout, const double backoff) {
return s->process_packet(timeout, backoff);
});
m.def("torchaudio::ffmpeg_streamer_process_all_packets", [](S s) {
return process_all_packets(s->s);
s->process_all_packets();
});
m.def("torchaudio::ffmpeg_streamer_is_buffer_ready", [](S s) {
return s->is_buffer_ready();
});
m.def("torchaudio::ffmpeg_streamer_pop_chunks", [](S s) {
return s->pop_chunks();
});
m.def("torchaudio::ffmpeg_streamer_is_buffer_ready", is_buffer_ready);
m.def("torchaudio::ffmpeg_streamer_pop_chunks", pop_chunks);
}
} // namespace
......
......@@ -30,9 +30,9 @@ Sink::Sink(
AVCodecParameters* codecpar,
int frames_per_chunk,
int num_chunks,
std::string filter_description,
const c10::optional<std::string>& filter_description,
const torch::Device& device)
: filter(input_time_base, codecpar, std::move(filter_description)),
: filter(input_time_base, codecpar, filter_description),
buffer(get_buffer(
codecpar->codec_type,
frames_per_chunk,
......
......@@ -18,7 +18,7 @@ class Sink {
AVCodecParameters* codecpar,
int frames_per_chunk,
int num_chunks,
std::string filter_description,
const c10::optional<std::string>& filter_description,
const torch::Device& device);
int process_frame(AVFrame* frame);
......
......@@ -8,8 +8,8 @@ using KeyType = StreamProcessor::KeyType;
StreamProcessor::StreamProcessor(
AVCodecParameters* codecpar,
const std::string& decoder_name,
const std::map<std::string, std::string>& decoder_option,
const c10::optional<std::string>& decoder_name,
const OptionDict& decoder_option,
const torch::Device& device)
: decoder(codecpar, decoder_name, decoder_option, device) {}
......@@ -21,7 +21,7 @@ KeyType StreamProcessor::add_stream(
AVCodecParameters* codecpar,
int frames_per_chunk,
int num_chunks,
std::string filter_description,
const c10::optional<std::string>& filter_description,
const torch::Device& device) {
switch (codecpar->codec_type) {
case AVMEDIA_TYPE_AUDIO:
......@@ -39,7 +39,7 @@ KeyType StreamProcessor::add_stream(
codecpar,
frames_per_chunk,
num_chunks,
std::move(filter_description),
filter_description,
device));
decoder_time_base = av_q2d(input_time_base);
return key;
......
......@@ -27,8 +27,8 @@ class StreamProcessor {
public:
StreamProcessor(
AVCodecParameters* codecpar,
const std::string& decoder_name,
const std::map<std::string, std::string>& decoder_option,
const c10::optional<std::string>& decoder_name,
const OptionDict& decoder_option,
const torch::Device& device);
~StreamProcessor() = default;
// Non-copyable
......@@ -52,7 +52,7 @@ class StreamProcessor {
AVCodecParameters* codecpar,
int frames_per_chunk,
int num_chunks,
std::string filter_description,
const c10::optional<std::string>& filter_description,
const torch::Device& device);
// 1. Remove the stream
......
#include <torchaudio/csrc/ffmpeg/stream_reader_wrapper.h>
namespace torchaudio {
namespace ffmpeg {
namespace {
// Convert SrcStreamInfo into the plain tuple (SrcInfo) consumed by the
// binding layer.
// Takes the struct by const reference: it carries several std::strings, so
// the original by-value parameter forced an extra copy on every call.
SrcInfo convert(const SrcStreamInfo& ssi) {
  return SrcInfo{
      av_get_media_type_string(ssi.media_type),
      ssi.codec_name,
      ssi.codec_long_name,
      ssi.fmt_name,
      ssi.bit_rate,
      ssi.sample_rate,
      ssi.num_channels,
      ssi.width,
      ssi.height,
      ssi.frame_rate};
}
// Convert OutputStreamInfo into the plain tuple (OutInfo) consumed by the
// binding layer.
// Takes the struct by const reference to avoid copying the contained
// filter-description string on every call.
OutInfo convert(const OutputStreamInfo& osi) {
  return OutInfo{osi.source_index, osi.filter_description};
}
} // namespace
// Take ownership of an already-opened AVFormatContext; the base Streamer
// constructor performs the stream discovery / processor initialization.
StreamReaderBinding::StreamReaderBinding(AVFormatContextPtr&& p)
    : Streamer(std::move(p)) {}
// Fetch the source stream info from the base class, then repack it into the
// tuple type that TorchBind can marshal.
SrcInfo StreamReaderBinding::get_src_stream_info(int64_t i) {
  auto info = Streamer::get_src_stream_info(i);
  return convert(info);
}
// Fetch the output stream info from the base class, then repack it into the
// tuple type that TorchBind can marshal.
OutInfo StreamReaderBinding::get_out_stream_info(int64_t i) {
  auto info = Streamer::get_out_stream_info(i);
  return convert(info);
}
// Process one packet, optionally blocking up to `timeout` (with the given
// retry backoff) when no packet is available yet.
// Returns the code from the underlying Streamer call (0: success, 1: EOF).
// Throws std::runtime_error when the underlying call reports an error
// (negative return value).
int64_t StreamReaderBinding::process_packet(
    const c10::optional<double>& timeout,
    const double backoff) {
  int64_t code;
  if (timeout.has_value()) {
    code = Streamer::process_packet_block(timeout.value(), backoff);
  } else {
    code = Streamer::process_packet();
  }
  if (code < 0) {
    throw std::runtime_error(
        "Failed to process a packet. (" + av_err2string(code) + "). ");
  }
  return code;
}
// Drain the input: keep processing packets while process_packet reports
// success (0). The loop stops on EOF (non-zero return); process_packet
// itself throws on error.
void StreamReaderBinding::process_all_packets() {
  while (!process_packet()) {
  }
}
} // namespace ffmpeg
} // namespace torchaudio
#pragma once
#include <torch/script.h>
#include <torchaudio/csrc/ffmpeg/streamer.h>
namespace torchaudio {
namespace ffmpeg {
using SrcInfo = std::tuple<
std::string, // media_type
std::string, // codec name
std::string, // codec long name
std::string, // format name
int64_t, // bit_rate
// Audio
double, // sample_rate
int64_t, // num_channels
// Video
int64_t, // width
int64_t, // height
double // frame_rate
>;
using OutInfo = std::tuple<
int64_t, // source index
std::string // filter description
>;
// Wrapper around Streamer that exposes an API more suitable for binding
// code: it receives and returns primitives / tuples that TorchBind can
// marshal directly.
struct StreamReaderBinding : public Streamer, public torch::CustomClassHolder {
  // Takes ownership of an already-opened format context.
  explicit StreamReaderBinding(AVFormatContextPtr&& p);
  // Tuple-returning variants of the base-class query methods.
  SrcInfo get_src_stream_info(int64_t i);
  OutInfo get_out_stream_info(int64_t i);
  // Process one packet; blocks up to `timeout` seconds when provided,
  // retrying with the given backoff (milliseconds-style constant, default 10.).
  // Throws on decoding error instead of returning a negative code.
  int64_t process_packet(
      const c10::optional<double>& timeout = c10::optional<double>(),
      const double backoff = 10.);
  // Repeatedly call process_packet until the input is exhausted.
  void process_all_packets();
};
} // namespace ffmpeg
} // namespace torchaudio
......@@ -42,11 +42,7 @@ void Streamer::validate_src_stream_type(int i, AVMediaType type) {
//////////////////////////////////////////////////////////////////////////////
// Initialization / resource allocations
//////////////////////////////////////////////////////////////////////////////
Streamer::Streamer(
const std::string& src,
const std::string& device,
const std::map<std::string, std::string>& option)
: pFormatContext(get_input_format_context(src, device, option)) {
Streamer::Streamer(AVFormatContextPtr&& p) : pFormatContext(std::move(p)) {
if (avformat_find_stream_info(pFormatContext, nullptr) < 0) {
throw std::runtime_error("Failed to find stream information.");
}
......@@ -67,7 +63,7 @@ Streamer::Streamer(
////////////////////////////////////////////////////////////////////////////////
// Query methods
////////////////////////////////////////////////////////////////////////////////
int Streamer::num_src_streams() const {
int64_t Streamer::num_src_streams() const {
return pFormatContext->nb_streams;
}
......@@ -103,7 +99,7 @@ SrcStreamInfo Streamer::get_src_stream_info(int i) const {
return ret;
}
int Streamer::num_out_streams() const {
int64_t Streamer::num_out_streams() const {
return stream_indices.size();
}
......@@ -117,12 +113,12 @@ OutputStreamInfo Streamer::get_out_stream_info(int i) const {
return ret;
}
int Streamer::find_best_audio_stream() const {
int64_t Streamer::find_best_audio_stream() const {
return av_find_best_stream(
pFormatContext, AVMEDIA_TYPE_AUDIO, -1, -1, NULL, 0);
}
int Streamer::find_best_video_stream() const {
int64_t Streamer::find_best_video_stream() const {
return av_find_best_stream(
pFormatContext, AVMEDIA_TYPE_VIDEO, -1, -1, NULL, 0);
}
......@@ -157,37 +153,56 @@ void Streamer::seek(double timestamp) {
}
void Streamer::add_audio_stream(
int i,
int frames_per_chunk,
int num_chunks,
std::string filter_desc,
const std::string& decoder,
const std::map<std::string, std::string>& decoder_option) {
int64_t i,
int64_t frames_per_chunk,
int64_t num_chunks,
const c10::optional<std::string>& filter_desc,
const c10::optional<std::string>& decoder,
const OptionDict& decoder_option) {
add_stream(
i,
AVMEDIA_TYPE_AUDIO,
frames_per_chunk,
num_chunks,
std::move(filter_desc),
filter_desc,
decoder,
decoder_option,
torch::Device(torch::DeviceType::CPU));
}
void Streamer::add_video_stream(
int i,
int frames_per_chunk,
int num_chunks,
std::string filter_desc,
const std::string& decoder,
const std::map<std::string, std::string>& decoder_option,
const torch::Device& device) {
int64_t i,
int64_t frames_per_chunk,
int64_t num_chunks,
const c10::optional<std::string>& filter_desc,
const c10::optional<std::string>& decoder,
const OptionDict& decoder_option,
const c10::optional<std::string>& hw_accel) {
const torch::Device device = [&]() {
if (!hw_accel) {
return torch::Device{c10::DeviceType::CPU};
}
#ifdef USE_CUDA
torch::Device d{hw_accel.value()};
if (d.type() != c10::DeviceType::CUDA) {
std::stringstream ss;
ss << "Only CUDA is supported for hardware acceleration. Found: "
<< device.str();
throw std::runtime_error(ss.str());
}
return d;
#else
throw std::runtime_error(
"torchaudio is not compiled with CUDA support. Hardware acceleration is not available.");
#endif
}();
add_stream(
i,
AVMEDIA_TYPE_VIDEO,
frames_per_chunk,
num_chunks,
std::move(filter_desc),
filter_desc,
decoder,
decoder_option,
device);
......@@ -198,9 +213,9 @@ void Streamer::add_stream(
AVMediaType media_type,
int frames_per_chunk,
int num_chunks,
std::string filter_desc,
const std::string& decoder,
const std::map<std::string, std::string>& decoder_option,
const c10::optional<std::string>& filter_desc,
const c10::optional<std::string>& decoder,
const OptionDict& decoder_option,
const torch::Device& device) {
validate_src_stream_type(i, media_type);
......@@ -214,12 +229,12 @@ void Streamer::add_stream(
stream->codecpar,
frames_per_chunk,
num_chunks,
std::move(filter_desc),
filter_desc,
device);
stream_indices.push_back(std::make_pair<>(i, key));
}
void Streamer::remove_stream(int i) {
void Streamer::remove_stream(int64_t i) {
validate_output_stream_index(i);
auto it = stream_indices.begin() + i;
int iP = it->first;
......
......@@ -19,11 +19,7 @@ class Streamer {
std::vector<std::pair<int, int>> stream_indices;
public:
// Open the input and allocate the resource
Streamer(
const std::string& src,
const std::string& device,
const std::map<std::string, std::string>& option);
explicit Streamer(AVFormatContextPtr&& p);
~Streamer() = default;
// Non-copyable
Streamer(const Streamer&) = delete;
......@@ -46,13 +42,13 @@ class Streamer {
//////////////////////////////////////////////////////////////////////////////
public:
// Find a suitable audio/video streams using heuristics from ffmpeg
int find_best_audio_stream() const;
int find_best_video_stream() const;
int64_t find_best_audio_stream() const;
int64_t find_best_video_stream() const;
// Fetch information about source streams
int num_src_streams() const;
int64_t num_src_streams() const;
SrcStreamInfo get_src_stream_info(int i) const;
// Fetch information about output streams
int num_out_streams() const;
int64_t num_out_streams() const;
OutputStreamInfo get_out_stream_info(int i) const;
// Check if all the buffers of the output streams are ready.
bool is_buffer_ready() const;
......@@ -63,21 +59,21 @@ class Streamer {
void seek(double timestamp);
void add_audio_stream(
int i,
int frames_per_chunk,
int num_chunks,
std::string filter_desc,
const std::string& decoder,
const std::map<std::string, std::string>& decoder_option);
int64_t i,
int64_t frames_per_chunk,
int64_t num_chunks,
const c10::optional<std::string>& filter_desc,
const c10::optional<std::string>& decoder,
const OptionDict& decoder_option);
void add_video_stream(
int i,
int frames_per_chunk,
int num_chunks,
std::string filter_desc,
const std::string& decoder,
const std::map<std::string, std::string>& decoder_option,
const torch::Device& device);
void remove_stream(int i);
int64_t i,
int64_t frames_per_chunk,
int64_t num_chunks,
const c10::optional<std::string>& filter_desc,
const c10::optional<std::string>& decoder,
const OptionDict& decoder_option,
const c10::optional<std::string>& hw_accel);
void remove_stream(int64_t i);
private:
void add_stream(
......@@ -85,9 +81,9 @@ class Streamer {
AVMediaType media_type,
int frames_per_chunk,
int num_chunks,
std::string filter_desc,
const std::string& decoder,
const std::map<std::string, std::string>& decoder_option,
const c10::optional<std::string>& filter_desc,
const c10::optional<std::string>& decoder,
const OptionDict& decoder_option,
const torch::Device& device);
public:
......
......@@ -154,6 +154,45 @@ def _parse_oi(i):
return StreamReaderOutputStream(i[0], i[1])
def _get_afilter_desc(sample_rate: Optional[int], dtype: torch.dtype):
descs = []
if sample_rate is not None:
descs.append(f"aresample={sample_rate}")
if dtype is not None:
fmt = {
torch.uint8: "u8p",
torch.int16: "s16p",
torch.int32: "s32p",
torch.long: "s64p",
torch.float32: "fltp",
torch.float64: "dblp",
}[dtype]
descs.append(f"aformat=sample_fmts={fmt}")
return ",".join(descs) if descs else None
def _get_vfilter_desc(frame_rate: Optional[float], width: Optional[int], height: Optional[int], format: Optional[str]):
descs = []
if frame_rate is not None:
descs.append(f"fps={frame_rate}")
scales = []
if width is not None:
scales.append(f"width={width}")
if height is not None:
scales.append(f"height={height}")
if scales:
descs.append(f"scale={':'.join(scales)}")
if format is not None:
fmt = {
"RGB": "rgb24",
"BGR": "bgr24",
"YUV": "yuv420p",
"GRAY": "gray",
}[format]
descs.append(f"format=pix_fmts={fmt}")
return ",".join(descs) if descs else None
class StreamReader:
"""Fetch and decode audio/video streams chunk by chunk.
......@@ -297,8 +336,14 @@ class StreamReader:
`[-1, 1]`.
"""
i = self.default_audio_stream if stream_index is None else stream_index
torch.ops.torchaudio.ffmpeg_streamer_add_basic_audio_stream(
self._s, i, frames_per_chunk, buffer_chunk_size, sample_rate, dtype
torch.ops.torchaudio.ffmpeg_streamer_add_audio_stream(
self._s,
i,
frames_per_chunk,
buffer_chunk_size,
_get_afilter_desc(sample_rate, dtype),
None,
None,
)
def add_basic_video_stream(
......@@ -338,15 +383,15 @@ class StreamReader:
- `GRAY`: 8 bits * 1 channels
"""
i = self.default_video_stream if stream_index is None else stream_index
torch.ops.torchaudio.ffmpeg_streamer_add_basic_video_stream(
torch.ops.torchaudio.ffmpeg_streamer_add_video_stream(
self._s,
i,
frames_per_chunk,
buffer_chunk_size,
frame_rate,
width,
height,
format,
_get_vfilter_desc(frame_rate, width, height, format),
None,
None,
None,
)
def add_audio_stream(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment