Commit eed57534 authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Refactor Streamer implementation (#2402)

Summary:
* Move the helper wrapping code in the TorchBind layer into a proper wrapper class so that it can be re-used in PyBind11.
* Move `add_basic_[audio|video]_stream` methods from C++ to Python, as they are just string manipulation. This will make PyBind11-based binding simpler, as it does not need to deal with dtype.
* Move `add_[audio|video]_stream` wrapper signature to Streamer core, so that Streamer directly deals with `c10::optional`.†

† Related to this, there is a slight change in how the empty filter expression is stored. Originally, if an empty filter expression was given to `add_[audio|video]_stream` method, the `StreamReaderOutputStream` was showing it as empty string `""`, even though internally it was using `"anull"` or `"null"`. Now `StreamReaderOutputStream` shows the corresponding filter expression that is actually being used.

Ref https://github.com/pytorch/audio/issues/2400

Pull Request resolved: https://github.com/pytorch/audio/pull/2402

Reviewed By: nateanl

Differential Revision: D36488808

Pulled By: mthrok

fbshipit-source-id: 877ca731364d10fc0cb9d97e75d55df9180f2047
parent 647f28e4
...@@ -165,7 +165,7 @@ class StreamReaderInterfaceTest(TempDirMixin, TorchaudioTestCase): ...@@ -165,7 +165,7 @@ class StreamReaderInterfaceTest(TempDirMixin, TorchaudioTestCase):
sinfo = s.get_out_stream_info(0) sinfo = s.get_out_stream_info(0)
assert sinfo.source_index == s.default_audio_stream assert sinfo.source_index == s.default_audio_stream
assert sinfo.filter_description == "" assert sinfo.filter_description == "anull"
sinfo = s.get_out_stream_info(1) sinfo = s.get_out_stream_info(1)
assert sinfo.source_index == s.default_audio_stream assert sinfo.source_index == s.default_audio_stream
...@@ -185,7 +185,7 @@ class StreamReaderInterfaceTest(TempDirMixin, TorchaudioTestCase): ...@@ -185,7 +185,7 @@ class StreamReaderInterfaceTest(TempDirMixin, TorchaudioTestCase):
sinfo = s.get_out_stream_info(0) sinfo = s.get_out_stream_info(0)
assert sinfo.source_index == s.default_video_stream assert sinfo.source_index == s.default_video_stream
assert sinfo.filter_description == "" assert sinfo.filter_description == "null"
sinfo = s.get_out_stream_info(1) sinfo = s.get_out_stream_info(1)
assert sinfo.source_index == s.default_video_stream assert sinfo.source_index == s.default_video_stream
......
...@@ -181,6 +181,7 @@ if(USE_FFMPEG) ...@@ -181,6 +181,7 @@ if(USE_FFMPEG)
ffmpeg/sink.cpp ffmpeg/sink.cpp
ffmpeg/stream_processor.cpp ffmpeg/stream_processor.cpp
ffmpeg/streamer.cpp ffmpeg/streamer.cpp
ffmpeg/stream_reader_wrapper.cpp
) )
message(STATUS "FFMPEG_ROOT=$ENV{FFMPEG_ROOT}") message(STATUS "FFMPEG_ROOT=$ENV{FFMPEG_ROOT}")
find_package(FFMPEG 4.1 REQUIRED COMPONENTS avdevice avfilter avformat avcodec avutil) find_package(FFMPEG 4.1 REQUIRED COMPONENTS avdevice avfilter avformat avcodec avutil)
......
...@@ -8,8 +8,8 @@ namespace ffmpeg { ...@@ -8,8 +8,8 @@ namespace ffmpeg {
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
Decoder::Decoder( Decoder::Decoder(
AVCodecParameters* pParam, AVCodecParameters* pParam,
const std::string& decoder_name, const c10::optional<std::string>& decoder_name,
const std::map<std::string, std::string>& decoder_option, const OptionDict& decoder_option,
const torch::Device& device) const torch::Device& device)
: pCodecContext(get_decode_context(pParam->codec_id, decoder_name)) { : pCodecContext(get_decode_context(pParam->codec_id, decoder_name)) {
init_codec_context( init_codec_context(
......
...@@ -13,8 +13,8 @@ class Decoder { ...@@ -13,8 +13,8 @@ class Decoder {
// Default constructable // Default constructable
Decoder( Decoder(
AVCodecParameters* pParam, AVCodecParameters* pParam,
const std::string& decoder_name, const c10::optional<std::string>& decoder_name,
const std::map<std::string, std::string>& decoder_option, const OptionDict& decoder_option,
const torch::Device& device); const torch::Device& device);
// Custom destructor to clean up the resources // Custom destructor to clean up the resources
~Decoder() = default; ~Decoder() = default;
......
...@@ -17,10 +17,9 @@ void AVFormatContextDeleter::operator()(AVFormatContext* p) { ...@@ -17,10 +17,9 @@ void AVFormatContextDeleter::operator()(AVFormatContext* p) {
namespace { namespace {
AVDictionary* get_option_dict( AVDictionary* get_option_dict(const OptionDict& option) {
const std::map<std::string, std::string>& option) {
AVDictionary* opt = nullptr; AVDictionary* opt = nullptr;
for (auto& it : option) { for (const auto& it : option) {
av_dict_set(&opt, it.first.c_str(), it.second.c_str(), 0); av_dict_set(&opt, it.first.c_str(), it.second.c_str(), 0);
} }
return opt; return opt;
...@@ -66,12 +65,25 @@ std::string join(std::vector<std::string> vars) { ...@@ -66,12 +65,25 @@ std::string join(std::vector<std::string> vars) {
AVFormatContextPtr get_input_format_context( AVFormatContextPtr get_input_format_context(
const std::string& src, const std::string& src,
const std::string& device, const c10::optional<std::string>& device,
const std::map<std::string, std::string>& option) { const OptionDict& option) {
AVFormatContext* pFormat = NULL; AVFormatContext* pFormat = NULL;
AVINPUT_FORMAT_CONST AVInputFormat* pInput = AVINPUT_FORMAT_CONST AVInputFormat* pInput = [&]() -> AVInputFormat* {
device.empty() ? NULL : av_find_input_format(device.c_str()); if (device.has_value()) {
std::string device_str = device.value();
AVINPUT_FORMAT_CONST AVInputFormat* p =
av_find_input_format(device_str.c_str());
if (!p) {
std::ostringstream msg;
msg << "Unsupported device: \"" << device_str << "\"";
throw std::runtime_error(msg.str());
}
return p;
}
return nullptr;
}();
AVDictionary* opt = get_option_dict(option); AVDictionary* opt = get_option_dict(option);
int ret = avformat_open_input(&pFormat, src.c_str(), pInput, &opt); int ret = avformat_open_input(&pFormat, src.c_str(), pInput, &opt);
...@@ -148,18 +160,18 @@ void AVCodecContextDeleter::operator()(AVCodecContext* p) { ...@@ -148,18 +160,18 @@ void AVCodecContextDeleter::operator()(AVCodecContext* p) {
namespace { namespace {
const AVCodec* get_decode_codec( const AVCodec* get_decode_codec(
enum AVCodecID codec_id, enum AVCodecID codec_id,
const std::string& decoder_name) { const c10::optional<std::string>& decoder_name) {
const AVCodec* pCodec = decoder_name.empty() const AVCodec* pCodec = !decoder_name.has_value()
? avcodec_find_decoder(codec_id) ? avcodec_find_decoder(codec_id)
: avcodec_find_decoder_by_name(decoder_name.c_str()); : avcodec_find_decoder_by_name(decoder_name.value().c_str());
if (!pCodec) { if (!pCodec) {
std::stringstream ss; std::stringstream ss;
if (decoder_name.empty()) { if (!decoder_name.has_value()) {
ss << "Unsupported codec: \"" << avcodec_get_name(codec_id) << "\", (" ss << "Unsupported codec: \"" << avcodec_get_name(codec_id) << "\", ("
<< codec_id << ")."; << codec_id << ").";
} else { } else {
ss << "Unsupported codec: \"" << decoder_name << "\"."; ss << "Unsupported codec: \"" << decoder_name.value() << "\".";
} }
throw std::runtime_error(ss.str()); throw std::runtime_error(ss.str());
} }
...@@ -170,7 +182,7 @@ const AVCodec* get_decode_codec( ...@@ -170,7 +182,7 @@ const AVCodec* get_decode_codec(
AVCodecContextPtr get_decode_context( AVCodecContextPtr get_decode_context(
enum AVCodecID codec_id, enum AVCodecID codec_id,
const std::string& decoder_name) { const c10::optional<std::string>& decoder_name) {
const AVCodec* pCodec = get_decode_codec(codec_id, decoder_name); const AVCodec* pCodec = get_decode_codec(codec_id, decoder_name);
AVCodecContext* pCodecContext = avcodec_alloc_context3(pCodec); AVCodecContext* pCodecContext = avcodec_alloc_context3(pCodec);
...@@ -216,8 +228,8 @@ const AVCodecHWConfig* get_cuda_config(const AVCodec* pCodec) { ...@@ -216,8 +228,8 @@ const AVCodecHWConfig* get_cuda_config(const AVCodec* pCodec) {
void init_codec_context( void init_codec_context(
AVCodecContext* pCodecContext, AVCodecContext* pCodecContext,
AVCodecParameters* pParams, AVCodecParameters* pParams,
const std::string& decoder_name, const c10::optional<std::string>& decoder_name,
const std::map<std::string, std::string>& decoder_option, const OptionDict& decoder_option,
const torch::Device& device, const torch::Device& device,
AVBufferRefPtr& pHWBufferRef) { AVBufferRefPtr& pHWBufferRef) {
const AVCodec* pCodec = get_decode_codec(pParams->codec_id, decoder_name); const AVCodec* pCodec = get_decode_codec(pParams->codec_id, decoder_name);
......
...@@ -23,6 +23,8 @@ extern "C" { ...@@ -23,6 +23,8 @@ extern "C" {
namespace torchaudio { namespace torchaudio {
namespace ffmpeg { namespace ffmpeg {
using OptionDict = std::map<std::string, std::string>;
// Replacement of av_err2str, which causes // Replacement of av_err2str, which causes
// `error: taking address of temporary array` // `error: taking address of temporary array`
// https://github.com/joncampbell123/composite-video-simulator/issues/5 // https://github.com/joncampbell123/composite-video-simulator/issues/5
...@@ -71,8 +73,8 @@ struct AVFormatContextPtr ...@@ -71,8 +73,8 @@ struct AVFormatContextPtr
// create format context for reading media // create format context for reading media
AVFormatContextPtr get_input_format_context( AVFormatContextPtr get_input_format_context(
const std::string& src, const std::string& src,
const std::string& device, const c10::optional<std::string>& device,
const std::map<std::string, std::string>& option); const OptionDict& option);
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
// AVPacket // AVPacket
...@@ -141,14 +143,14 @@ struct AVCodecContextPtr ...@@ -141,14 +143,14 @@ struct AVCodecContextPtr
// Allocate codec context from either decoder name or ID // Allocate codec context from either decoder name or ID
AVCodecContextPtr get_decode_context( AVCodecContextPtr get_decode_context(
enum AVCodecID codec_id, enum AVCodecID codec_id,
const std::string& decoder); const c10::optional<std::string>& decoder);
// Initialize codec context with the parameters // Initialize codec context with the parameters
void init_codec_context( void init_codec_context(
AVCodecContext* pCodecContext, AVCodecContext* pCodecContext,
AVCodecParameters* pParams, AVCodecParameters* pParams,
const std::string& decoder_name, const c10::optional<std::string>& decoder_name,
const std::map<std::string, std::string>& decoder_option, const OptionDict& decoder_option,
const torch::Device& device, const torch::Device& device,
AVBufferRefPtr& pHWBufferRef); AVBufferRefPtr& pHWBufferRef);
......
...@@ -7,10 +7,11 @@ namespace ffmpeg { ...@@ -7,10 +7,11 @@ namespace ffmpeg {
FilterGraph::FilterGraph( FilterGraph::FilterGraph(
AVRational time_base, AVRational time_base,
AVCodecParameters* codecpar, AVCodecParameters* codecpar,
std::string filter_description) const c10::optional<std::string>& filter_description)
: input_time_base(time_base), : input_time_base(time_base),
codecpar(codecpar), codecpar(codecpar),
filter_description(std::move(filter_description)), filter_description(filter_description.value_or(
codecpar->codec_type == AVMEDIA_TYPE_AUDIO ? "anull" : "null")),
media_type(codecpar->codec_type) { media_type(codecpar->codec_type) {
init(); init();
} }
...@@ -49,10 +50,10 @@ std::string get_video_src_args( ...@@ -49,10 +50,10 @@ std::string get_video_src_args(
std::snprintf( std::snprintf(
args, args,
sizeof(args), sizeof(args),
"video_size=%dx%d:pix_fmt=%d:time_base=%d/%d:pixel_aspect=%d/%d", "video_size=%dx%d:pix_fmt=%s:time_base=%d/%d:pixel_aspect=%d/%d",
codecpar->width, codecpar->width,
codecpar->height, codecpar->height,
static_cast<AVPixelFormat>(codecpar->format), av_get_pix_fmt_name(static_cast<AVPixelFormat>(codecpar->format)),
time_base.num, time_base.num,
time_base.den, time_base.den,
codecpar->sample_aspect_ratio.num, codecpar->sample_aspect_ratio.num,
...@@ -165,16 +166,12 @@ void FilterGraph::add_process() { ...@@ -165,16 +166,12 @@ void FilterGraph::add_process() {
// If you are debugging this part of the code, you might get confused. // If you are debugging this part of the code, you might get confused.
InOuts in{"in", buffersrc_ctx}, out{"out", buffersink_ctx}; InOuts in{"in", buffersrc_ctx}, out{"out", buffersink_ctx};
std::string desc = filter_description.empty() int ret = avfilter_graph_parse_ptr(
? (media_type == AVMEDIA_TYPE_AUDIO) ? "anull" : "null" pFilterGraph, filter_description.c_str(), out, in, nullptr);
: filter_description;
int ret =
avfilter_graph_parse_ptr(pFilterGraph, desc.c_str(), out, in, nullptr);
if (ret < 0) { if (ret < 0) {
throw std::runtime_error( throw std::runtime_error(
"Failed to create the filter from \"" + desc + "\" (" + "Failed to create the filter from \"" + filter_description + "\" (" +
av_err2string(ret) + ".)"); av_err2string(ret) + ".)");
} }
} }
......
...@@ -24,7 +24,7 @@ class FilterGraph { ...@@ -24,7 +24,7 @@ class FilterGraph {
FilterGraph( FilterGraph(
AVRational time_base, AVRational time_base,
AVCodecParameters* codecpar, AVCodecParameters* codecpar,
std::string filter_desc); const c10::optional<std::string>& filter_desc);
// Custom destructor to release AVFilterGraph* // Custom destructor to release AVFilterGraph*
~FilterGraph() = default; ~FilterGraph() = default;
// Non-copyable // Non-copyable
......
#include <torch/script.h> #include <torch/script.h>
#include <torchaudio/csrc/ffmpeg/streamer.h> #include <torchaudio/csrc/ffmpeg/stream_reader_wrapper.h>
#include <stdexcept> #include <stdexcept>
namespace torchaudio { namespace torchaudio {
...@@ -7,357 +7,38 @@ namespace ffmpeg { ...@@ -7,357 +7,38 @@ namespace ffmpeg {
namespace { namespace {
using OptionDict = c10::Dict<std::string, std::string>; OptionDict map(const c10::optional<c10::Dict<std::string, std::string>>& dict) {
OptionDict ret;
std::map<std::string, std::string> convert_dict( if (!dict.has_value()) {
const c10::optional<OptionDict>& option) { return ret;
std::map<std::string, std::string> opts; }
if (option) { for (const auto& it : dict.value()) {
for (auto& it : option.value()) { ret.insert({it.key(), it.value()});
opts[it.key()] = it.value();
}
} }
return opts; return ret;
} }
struct StreamerHolder : torch::CustomClassHolder { c10::intrusive_ptr<StreamReaderBinding> init(
Streamer s;
StreamerHolder(
const std::string& src,
const c10::optional<std::string>& device,
const c10::optional<OptionDict>& option)
: s(src, device.value_or(""), convert_dict(option)) {}
};
using S = c10::intrusive_ptr<StreamerHolder>;
S init(
const std::string& src, const std::string& src,
const c10::optional<std::string>& device, const c10::optional<std::string>& device,
const c10::optional<OptionDict>& option) { const c10::optional<c10::Dict<std::string, std::string>>& option) {
return c10::make_intrusive<StreamerHolder>(src, device, option); return c10::make_intrusive<StreamReaderBinding>(
} get_input_format_context(src, device, map(option)));
using SrcInfo = std::tuple<
std::string, // media_type
std::string, // codec name
std::string, // codec long name
std::string, // format name
int64_t, // bit_rate
// Audio
double, // sample_rate
int64_t, // num_channels
// Video
int64_t, // width
int64_t, // height
double // frame_rate
>;
SrcInfo convert(SrcStreamInfo ssi) {
return SrcInfo(std::forward_as_tuple(
av_get_media_type_string(ssi.media_type),
ssi.codec_name,
ssi.codec_long_name,
ssi.fmt_name,
ssi.bit_rate,
ssi.sample_rate,
ssi.num_channels,
ssi.width,
ssi.height,
ssi.frame_rate));
}
SrcInfo get_src_stream_info(S s, int64_t i) {
return convert(s->s.get_src_stream_info(i));
}
using OutInfo = std::tuple<
int64_t, // source index
std::string // filter description
>;
OutInfo convert(OutputStreamInfo osi) {
return OutInfo(
std::forward_as_tuple(osi.source_index, osi.filter_description));
}
OutInfo get_out_stream_info(S s, int64_t i) {
return convert(s->s.get_out_stream_info(i));
}
int64_t num_src_streams(S s) {
return s->s.num_src_streams();
}
int64_t num_out_streams(S s) {
return s->s.num_out_streams();
}
int64_t find_best_audio_stream(S s) {
return s->s.find_best_audio_stream();
}
int64_t find_best_video_stream(S s) {
return s->s.find_best_video_stream();
}
void seek(S s, double timestamp) {
s->s.seek(timestamp);
}
template <typename... Args>
std::string string_format(const std::string& format, Args... args) {
char buffer[512];
std::snprintf(buffer, sizeof(buffer), format.c_str(), args...);
return std::string(buffer);
}
std::string join(
const std::vector<std::string>& components,
const std::string& delim) {
std::ostringstream s;
for (int i = 0; i < components.size(); ++i) {
if (i)
s << delim;
s << components[i];
}
return s.str();
}
std::string get_afilter_desc(
const c10::optional<int64_t>& sample_rate,
const c10::optional<c10::ScalarType>& dtype) {
std::vector<std::string> components;
if (sample_rate) {
// TODO: test float sample rate
components.emplace_back(
string_format("aresample=%d", static_cast<int>(sample_rate.value())));
}
if (dtype) {
AVSampleFormat fmt = [&]() {
switch (dtype.value()) {
case c10::ScalarType::Byte:
return AV_SAMPLE_FMT_U8P;
case c10::ScalarType::Short:
return AV_SAMPLE_FMT_S16P;
case c10::ScalarType::Int:
return AV_SAMPLE_FMT_S32P;
case c10::ScalarType::Long:
return AV_SAMPLE_FMT_S64P;
case c10::ScalarType::Float:
return AV_SAMPLE_FMT_FLTP;
case c10::ScalarType::Double:
return AV_SAMPLE_FMT_DBLP;
default:
throw std::runtime_error("Unexpected dtype.");
}
}();
components.emplace_back(
string_format("aformat=sample_fmts=%s", av_get_sample_fmt_name(fmt)));
}
return join(components, ",");
}
std::string get_vfilter_desc(
const c10::optional<double>& frame_rate,
const c10::optional<int64_t>& width,
const c10::optional<int64_t>& height,
const c10::optional<std::string>& format) {
// TODO:
// - Add `flags` for different scale algorithm
// https://ffmpeg.org/ffmpeg-filters.html#scale
// - Consider `framerate` as well
// https://ffmpeg.org/ffmpeg-filters.html#framerate
// - scale
// https://ffmpeg.org/ffmpeg-filters.html#scale-1
// https://ffmpeg.org/ffmpeg-scaler.html#toc-Scaler-Options
// - framerate
// https://ffmpeg.org/ffmpeg-filters.html#framerate
// TODO:
// - format
// https://ffmpeg.org/ffmpeg-filters.html#toc-format-1
// - fps
// https://ffmpeg.org/ffmpeg-filters.html#fps-1
std::vector<std::string> components;
if (frame_rate)
components.emplace_back(string_format("fps=%lf", frame_rate.value()));
std::vector<std::string> scale_components;
if (width)
scale_components.emplace_back(string_format("width=%d", width.value()));
if (height)
scale_components.emplace_back(string_format("height=%d", height.value()));
if (scale_components.size())
components.emplace_back(
string_format("scale=%s", join(scale_components, ":").c_str()));
if (format) {
// TODO:
// Check other useful formats
// https://pillow.readthedocs.io/en/stable/handbook/concepts.html#modes
AVPixelFormat fmt = [&]() {
const std::map<const std::string, enum AVPixelFormat> valid_choices {
{"RGB", AV_PIX_FMT_RGB24},
{"BGR", AV_PIX_FMT_BGR24},
{"YUV", AV_PIX_FMT_YUV420P},
{"GRAY", AV_PIX_FMT_GRAY8},
};
const std::string val = format.value();
if (valid_choices.find(val) == valid_choices.end()) {
std::stringstream ss;
ss << "Unexpected output video format: \"" << val << "\"."
<< "Valid choices are; ";
int i = 0;
for (const auto& p : valid_choices) {
if (i == 0) {
ss << "\"" << p.first << "\"";
} else {
ss << ", \"" << p.first << "\"";
}
}
throw std::runtime_error(ss.str());
}
return valid_choices.at(val);
}();
components.emplace_back(
string_format("format=pix_fmts=%s", av_get_pix_fmt_name(fmt)));
}
return join(components, ",");
};
void add_basic_audio_stream(
S s,
int64_t i,
int64_t frames_per_chunk,
int64_t num_chunks,
const c10::optional<int64_t>& sample_rate,
const c10::optional<c10::ScalarType>& dtype) {
std::string filter_desc = get_afilter_desc(sample_rate, dtype);
s->s.add_audio_stream(i, frames_per_chunk, num_chunks, filter_desc, "", {});
}
void add_basic_video_stream(
S s,
int64_t i,
int64_t frames_per_chunk,
int64_t num_chunks,
const c10::optional<double>& frame_rate,
const c10::optional<int64_t>& width,
const c10::optional<int64_t>& height,
const c10::optional<std::string>& format) {
std::string filter_desc = get_vfilter_desc(frame_rate, width, height, format);
s->s.add_video_stream(
static_cast<int>(i),
static_cast<int>(frames_per_chunk),
static_cast<int>(num_chunks),
std::move(filter_desc),
"",
{},
torch::Device(c10::DeviceType::CPU));
}
void add_audio_stream(
S s,
int64_t i,
int64_t frames_per_chunk,
int64_t num_chunks,
const c10::optional<std::string>& filter_desc,
const c10::optional<std::string>& decoder,
const c10::optional<OptionDict>& decoder_options) {
s->s.add_audio_stream(
i,
frames_per_chunk,
num_chunks,
filter_desc.value_or(""),
decoder.value_or(""),
convert_dict(decoder_options));
}
void add_video_stream(
S s,
int64_t i,
int64_t frames_per_chunk,
int64_t num_chunks,
const c10::optional<std::string>& filter_desc,
const c10::optional<std::string>& decoder,
const c10::optional<OptionDict>& decoder_options,
const c10::optional<std::string>& hw_accel) {
const torch::Device device = [&]() {
if (!hw_accel) {
return torch::Device{c10::DeviceType::CPU};
}
#ifdef USE_CUDA
torch::Device d{hw_accel.value()};
if (d.type() != c10::DeviceType::CUDA) {
std::stringstream ss;
ss << "Only CUDA is supported for hardware acceleration. Found: "
<< device.str();
throw std::runtime_error(ss.str());
}
return d;
#else
throw std::runtime_error(
"torchaudio is not compiled with CUDA support. Hardware acceleration is not available.");
#endif
}();
s->s.add_video_stream(
i,
frames_per_chunk,
num_chunks,
filter_desc.value_or(""),
decoder.value_or(""),
convert_dict(decoder_options),
device);
}
void remove_stream(S s, int64_t i) {
s->s.remove_stream(i);
}
int64_t process_packet(
Streamer& s,
const c10::optional<double>& timeout = c10::optional<double>(),
const double backoff = 10.) {
int64_t code = [&]() {
if (timeout.has_value()) {
return s.process_packet_block(timeout.value(), backoff);
}
return s.process_packet();
}();
if (code < 0) {
throw std::runtime_error(
"Failed to process a packet. (" + av_err2string(code) + "). ");
}
return code;
}
void process_all_packets(Streamer& s) {
int ret = 0;
do {
ret = process_packet(s);
} while (!ret);
}
bool is_buffer_ready(S s) {
return s->s.is_buffer_ready();
}
std::vector<c10::optional<torch::Tensor>> pop_chunks(S s) {
return s->s.pop_chunks();
} }
std::tuple<c10::optional<torch::Tensor>, int64_t> load(const std::string& src) { std::tuple<c10::optional<torch::Tensor>, int64_t> load(const std::string& src) {
Streamer s{src, "", {}}; StreamReaderBinding s{get_input_format_context(src, {}, {})};
int i = s.find_best_audio_stream(); int i = s.find_best_audio_stream();
auto sinfo = s.get_src_stream_info(i); auto sinfo = s.Streamer::get_src_stream_info(i);
int64_t sample_rate = static_cast<int64_t>(sinfo.sample_rate); int64_t sample_rate = static_cast<int64_t>(sinfo.sample_rate);
s.add_audio_stream(i, -1, -1, "", "", {}); s.add_audio_stream(i, -1, -1, {}, {}, {});
process_all_packets(s); s.process_all_packets();
auto tensors = s.pop_chunks(); auto tensors = s.pop_chunks();
return std::make_tuple<>(tensors[0], sample_rate); return std::make_tuple<>(tensors[0], sample_rate);
} }
using S = const c10::intrusive_ptr<StreamReaderBinding>&;
TORCH_LIBRARY_FRAGMENT(torchaudio, m) { TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def("torchaudio::ffmpeg_init", []() { m.def("torchaudio::ffmpeg_init", []() {
avdevice_register_all(); avdevice_register_all();
...@@ -365,38 +46,84 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) { ...@@ -365,38 +46,84 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
av_log_set_level(AV_LOG_ERROR); av_log_set_level(AV_LOG_ERROR);
}); });
m.def("torchaudio::ffmpeg_load", load); m.def("torchaudio::ffmpeg_load", load);
m.class_<StreamerHolder>("ffmpeg_Streamer"); m.class_<StreamReaderBinding>("ffmpeg_Streamer");
m.def("torchaudio::ffmpeg_streamer_init", init); m.def("torchaudio::ffmpeg_streamer_init", init);
m.def("torchaudio::ffmpeg_streamer_num_src_streams", num_src_streams); m.def("torchaudio::ffmpeg_streamer_num_src_streams", [](S s) {
m.def("torchaudio::ffmpeg_streamer_num_out_streams", num_out_streams); return s->num_src_streams();
m.def("torchaudio::ffmpeg_streamer_get_src_stream_info", get_src_stream_info); });
m.def("torchaudio::ffmpeg_streamer_get_out_stream_info", get_out_stream_info); m.def("torchaudio::ffmpeg_streamer_num_out_streams", [](S s) {
m.def( return s->num_out_streams();
"torchaudio::ffmpeg_streamer_find_best_audio_stream", });
find_best_audio_stream); m.def("torchaudio::ffmpeg_streamer_get_src_stream_info", [](S s, int64_t i) {
m.def( return s->get_src_stream_info(i);
"torchaudio::ffmpeg_streamer_find_best_video_stream", });
find_best_video_stream); m.def("torchaudio::ffmpeg_streamer_get_out_stream_info", [](S s, int64_t i) {
m.def("torchaudio::ffmpeg_streamer_seek", seek); return s->get_out_stream_info(i);
});
m.def("torchaudio::ffmpeg_streamer_find_best_audio_stream", [](S s) {
return s->find_best_audio_stream();
});
m.def("torchaudio::ffmpeg_streamer_find_best_video_stream", [](S s) {
return s->find_best_video_stream();
});
m.def("torchaudio::ffmpeg_streamer_seek", [](S s, double t) {
return s->seek(t);
});
m.def( m.def(
"torchaudio::ffmpeg_streamer_add_basic_audio_stream", "torchaudio::ffmpeg_streamer_add_audio_stream",
add_basic_audio_stream); [](S s,
int64_t i,
int64_t frames_per_chunk,
int64_t num_chunks,
const c10::optional<std::string>& filter_desc,
const c10::optional<std::string>& decoder,
const c10::optional<c10::Dict<std::string, std::string>>&
decoder_options) {
s->add_audio_stream(
i,
frames_per_chunk,
num_chunks,
filter_desc,
decoder,
map(decoder_options));
});
m.def( m.def(
"torchaudio::ffmpeg_streamer_add_basic_video_stream", "torchaudio::ffmpeg_streamer_add_video_stream",
add_basic_video_stream); [](S s,
m.def("torchaudio::ffmpeg_streamer_add_audio_stream", add_audio_stream); int64_t i,
m.def("torchaudio::ffmpeg_streamer_add_video_stream", add_video_stream); int64_t frames_per_chunk,
m.def("torchaudio::ffmpeg_streamer_remove_stream", remove_stream); int64_t num_chunks,
const c10::optional<std::string>& filter_desc,
const c10::optional<std::string>& decoder,
const c10::optional<c10::Dict<std::string, std::string>>&
decoder_options,
const c10::optional<std::string>& hw_accel) {
s->add_video_stream(
i,
frames_per_chunk,
num_chunks,
filter_desc,
decoder,
map(decoder_options),
hw_accel);
});
m.def("torchaudio::ffmpeg_streamer_remove_stream", [](S s, int64_t i) {
s->remove_stream(i);
});
m.def( m.def(
"torchaudio::ffmpeg_streamer_process_packet", "torchaudio::ffmpeg_streamer_process_packet",
[](S s, const c10::optional<double>& timeout, double backoff) { [](S s, const c10::optional<double>& timeout, const double backoff) {
return process_packet(s->s, timeout, backoff); return s->process_packet(timeout, backoff);
}); });
m.def("torchaudio::ffmpeg_streamer_process_all_packets", [](S s) { m.def("torchaudio::ffmpeg_streamer_process_all_packets", [](S s) {
return process_all_packets(s->s); s->process_all_packets();
});
m.def("torchaudio::ffmpeg_streamer_is_buffer_ready", [](S s) {
return s->is_buffer_ready();
});
m.def("torchaudio::ffmpeg_streamer_pop_chunks", [](S s) {
return s->pop_chunks();
}); });
m.def("torchaudio::ffmpeg_streamer_is_buffer_ready", is_buffer_ready);
m.def("torchaudio::ffmpeg_streamer_pop_chunks", pop_chunks);
} }
} // namespace } // namespace
......
...@@ -30,9 +30,9 @@ Sink::Sink( ...@@ -30,9 +30,9 @@ Sink::Sink(
AVCodecParameters* codecpar, AVCodecParameters* codecpar,
int frames_per_chunk, int frames_per_chunk,
int num_chunks, int num_chunks,
std::string filter_description, const c10::optional<std::string>& filter_description,
const torch::Device& device) const torch::Device& device)
: filter(input_time_base, codecpar, std::move(filter_description)), : filter(input_time_base, codecpar, filter_description),
buffer(get_buffer( buffer(get_buffer(
codecpar->codec_type, codecpar->codec_type,
frames_per_chunk, frames_per_chunk,
......
...@@ -18,7 +18,7 @@ class Sink { ...@@ -18,7 +18,7 @@ class Sink {
AVCodecParameters* codecpar, AVCodecParameters* codecpar,
int frames_per_chunk, int frames_per_chunk,
int num_chunks, int num_chunks,
std::string filter_description, const c10::optional<std::string>& filter_description,
const torch::Device& device); const torch::Device& device);
int process_frame(AVFrame* frame); int process_frame(AVFrame* frame);
......
...@@ -8,8 +8,8 @@ using KeyType = StreamProcessor::KeyType; ...@@ -8,8 +8,8 @@ using KeyType = StreamProcessor::KeyType;
StreamProcessor::StreamProcessor( StreamProcessor::StreamProcessor(
AVCodecParameters* codecpar, AVCodecParameters* codecpar,
const std::string& decoder_name, const c10::optional<std::string>& decoder_name,
const std::map<std::string, std::string>& decoder_option, const OptionDict& decoder_option,
const torch::Device& device) const torch::Device& device)
: decoder(codecpar, decoder_name, decoder_option, device) {} : decoder(codecpar, decoder_name, decoder_option, device) {}
...@@ -21,7 +21,7 @@ KeyType StreamProcessor::add_stream( ...@@ -21,7 +21,7 @@ KeyType StreamProcessor::add_stream(
AVCodecParameters* codecpar, AVCodecParameters* codecpar,
int frames_per_chunk, int frames_per_chunk,
int num_chunks, int num_chunks,
std::string filter_description, const c10::optional<std::string>& filter_description,
const torch::Device& device) { const torch::Device& device) {
switch (codecpar->codec_type) { switch (codecpar->codec_type) {
case AVMEDIA_TYPE_AUDIO: case AVMEDIA_TYPE_AUDIO:
...@@ -39,7 +39,7 @@ KeyType StreamProcessor::add_stream( ...@@ -39,7 +39,7 @@ KeyType StreamProcessor::add_stream(
codecpar, codecpar,
frames_per_chunk, frames_per_chunk,
num_chunks, num_chunks,
std::move(filter_description), filter_description,
device)); device));
decoder_time_base = av_q2d(input_time_base); decoder_time_base = av_q2d(input_time_base);
return key; return key;
......
...@@ -27,8 +27,8 @@ class StreamProcessor { ...@@ -27,8 +27,8 @@ class StreamProcessor {
public: public:
StreamProcessor( StreamProcessor(
AVCodecParameters* codecpar, AVCodecParameters* codecpar,
const std::string& decoder_name, const c10::optional<std::string>& decoder_name,
const std::map<std::string, std::string>& decoder_option, const OptionDict& decoder_option,
const torch::Device& device); const torch::Device& device);
~StreamProcessor() = default; ~StreamProcessor() = default;
// Non-copyable // Non-copyable
...@@ -52,7 +52,7 @@ class StreamProcessor { ...@@ -52,7 +52,7 @@ class StreamProcessor {
AVCodecParameters* codecpar, AVCodecParameters* codecpar,
int frames_per_chunk, int frames_per_chunk,
int num_chunks, int num_chunks,
std::string filter_description, const c10::optional<std::string>& filter_description,
const torch::Device& device); const torch::Device& device);
// 1. Remove the stream // 1. Remove the stream
......
#include <torchaudio/csrc/ffmpeg/stream_reader_wrapper.h>
namespace torchaudio {
namespace ffmpeg {
namespace {
// Repackage SrcStreamInfo (project struct) as a tuple of primitives so it
// can cross the TorchBind boundary.
// Takes by const-ref: the struct holds several std::strings, and the
// original by-value signature copied all of them on every call.
SrcInfo convert(const SrcStreamInfo& ssi) {
  // av_get_media_type_string() returns NULL for AVMEDIA_TYPE_UNKNOWN;
  // guard against constructing std::string from a null pointer (UB).
  const char* media_type = av_get_media_type_string(ssi.media_type);
  return SrcInfo(
      media_type ? media_type : "unknown",
      ssi.codec_name,
      ssi.codec_long_name,
      ssi.fmt_name,
      ssi.bit_rate,
      ssi.sample_rate,
      ssi.num_channels,
      ssi.width,
      ssi.height,
      ssi.frame_rate);
}
// Repackage OutputStreamInfo as a tuple of primitives for the TorchBind
// boundary. Takes by const-ref to avoid copying the contained
// filter_description string that the original by-value signature incurred.
OutInfo convert(const OutputStreamInfo& osi) {
  return OutInfo(osi.source_index, osi.filter_description);
}
} // namespace
// Takes ownership of an already-opened AVFormatContext and hands it to the
// underlying Streamer.
StreamReaderBinding::StreamReaderBinding(AVFormatContextPtr&& p)
: Streamer(std::move(p)) {}
// Fetch source stream `i` info from Streamer and repackage it as a tuple of
// primitives for the binding layer.
SrcInfo StreamReaderBinding::get_src_stream_info(int64_t i) {
  SrcStreamInfo info = Streamer::get_src_stream_info(i);
  return convert(info);
}
// Fetch output stream `i` info from Streamer and repackage it as a tuple of
// primitives for the binding layer.
OutInfo StreamReaderBinding::get_out_stream_info(int64_t i) {
  OutputStreamInfo info = Streamer::get_out_stream_info(i);
  return convert(info);
}
// Process a single packet.
// When `timeout` is given, delegates to process_packet_block with the given
// timeout/backoff; otherwise processes one packet without blocking.
// Throws std::runtime_error when FFmpeg reports an error (negative code).
int64_t StreamReaderBinding::process_packet(
    const c10::optional<double>& timeout,
    const double backoff) {
  int64_t code;
  if (timeout.has_value()) {
    code = Streamer::process_packet_block(timeout.value(), backoff);
  } else {
    code = Streamer::process_packet();
  }
  if (code < 0) {
    throw std::runtime_error(
        "Failed to process a packet. (" + av_err2string(code) + "). ");
  }
  return code;
}
// Drain the input: keep processing packets while process_packet reports 0.
// A non-zero code stops the loop; failures are thrown inside process_packet.
void StreamReaderBinding::process_all_packets() {
  while (process_packet() == 0) {
  }
}
} // namespace ffmpeg
} // namespace torchaudio
#pragma once
#include <torch/script.h>
#include <torchaudio/csrc/ffmpeg/streamer.h>
namespace torchaudio {
namespace ffmpeg {
// Primitive-tuple mirror of SrcStreamInfo, suitable for TorchBind/PyBind
// return values (bindings can only exchange primitive types).
using SrcInfo = std::tuple<
    std::string, // media_type
    std::string, // codec name
    std::string, // codec long name
    std::string, // format name
    int64_t, // bit_rate
    // Audio
    double, // sample_rate
    int64_t, // num_channels
    // Video
    int64_t, // width
    int64_t, // height
    double // frame_rate
    >;
// Primitive-tuple mirror of OutputStreamInfo.
using OutInfo = std::tuple<
    int64_t, // source index
    std::string // filter description
    >;
// Structure to implement wrapper API around Streamer, which is more suitable
// for binding the code (i.e. it receives/returns primitives)
struct StreamReaderBinding : public Streamer, public torch::CustomClassHolder {
  // Takes ownership of an already-opened AVFormatContext.
  explicit StreamReaderBinding(AVFormatContextPtr&& p);
  // Metadata of the i-th source stream, as a primitive tuple.
  SrcInfo get_src_stream_info(int64_t i);
  // Metadata of the i-th output stream, as a primitive tuple.
  OutInfo get_out_stream_info(int64_t i);
  // Process one packet. With `timeout` set, blocks via
  // process_packet_block with the given `backoff` (unit defined by
  // Streamer — TODO confirm). Throws std::runtime_error on failure.
  int64_t process_packet(
      const c10::optional<double>& timeout = c10::optional<double>(),
      const double backoff = 10.);
  // Process packets until process_packet returns a non-zero code.
  void process_all_packets();
};
} // namespace ffmpeg
} // namespace torchaudio
...@@ -42,11 +42,7 @@ void Streamer::validate_src_stream_type(int i, AVMediaType type) { ...@@ -42,11 +42,7 @@ void Streamer::validate_src_stream_type(int i, AVMediaType type) {
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
// Initialization / resource allocations // Initialization / resource allocations
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
Streamer::Streamer( Streamer::Streamer(AVFormatContextPtr&& p) : pFormatContext(std::move(p)) {
const std::string& src,
const std::string& device,
const std::map<std::string, std::string>& option)
: pFormatContext(get_input_format_context(src, device, option)) {
if (avformat_find_stream_info(pFormatContext, nullptr) < 0) { if (avformat_find_stream_info(pFormatContext, nullptr) < 0) {
throw std::runtime_error("Failed to find stream information."); throw std::runtime_error("Failed to find stream information.");
} }
...@@ -67,7 +63,7 @@ Streamer::Streamer( ...@@ -67,7 +63,7 @@ Streamer::Streamer(
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
// Query methods // Query methods
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
int Streamer::num_src_streams() const { int64_t Streamer::num_src_streams() const {
return pFormatContext->nb_streams; return pFormatContext->nb_streams;
} }
...@@ -103,7 +99,7 @@ SrcStreamInfo Streamer::get_src_stream_info(int i) const { ...@@ -103,7 +99,7 @@ SrcStreamInfo Streamer::get_src_stream_info(int i) const {
return ret; return ret;
} }
int Streamer::num_out_streams() const { int64_t Streamer::num_out_streams() const {
return stream_indices.size(); return stream_indices.size();
} }
...@@ -117,12 +113,12 @@ OutputStreamInfo Streamer::get_out_stream_info(int i) const { ...@@ -117,12 +113,12 @@ OutputStreamInfo Streamer::get_out_stream_info(int i) const {
return ret; return ret;
} }
int Streamer::find_best_audio_stream() const { int64_t Streamer::find_best_audio_stream() const {
return av_find_best_stream( return av_find_best_stream(
pFormatContext, AVMEDIA_TYPE_AUDIO, -1, -1, NULL, 0); pFormatContext, AVMEDIA_TYPE_AUDIO, -1, -1, NULL, 0);
} }
int Streamer::find_best_video_stream() const { int64_t Streamer::find_best_video_stream() const {
return av_find_best_stream( return av_find_best_stream(
pFormatContext, AVMEDIA_TYPE_VIDEO, -1, -1, NULL, 0); pFormatContext, AVMEDIA_TYPE_VIDEO, -1, -1, NULL, 0);
} }
...@@ -157,37 +153,56 @@ void Streamer::seek(double timestamp) { ...@@ -157,37 +153,56 @@ void Streamer::seek(double timestamp) {
} }
void Streamer::add_audio_stream( void Streamer::add_audio_stream(
int i, int64_t i,
int frames_per_chunk, int64_t frames_per_chunk,
int num_chunks, int64_t num_chunks,
std::string filter_desc, const c10::optional<std::string>& filter_desc,
const std::string& decoder, const c10::optional<std::string>& decoder,
const std::map<std::string, std::string>& decoder_option) { const OptionDict& decoder_option) {
add_stream( add_stream(
i, i,
AVMEDIA_TYPE_AUDIO, AVMEDIA_TYPE_AUDIO,
frames_per_chunk, frames_per_chunk,
num_chunks, num_chunks,
std::move(filter_desc), filter_desc,
decoder, decoder,
decoder_option, decoder_option,
torch::Device(torch::DeviceType::CPU)); torch::Device(torch::DeviceType::CPU));
} }
void Streamer::add_video_stream( void Streamer::add_video_stream(
int i, int64_t i,
int frames_per_chunk, int64_t frames_per_chunk,
int num_chunks, int64_t num_chunks,
std::string filter_desc, const c10::optional<std::string>& filter_desc,
const std::string& decoder, const c10::optional<std::string>& decoder,
const std::map<std::string, std::string>& decoder_option, const OptionDict& decoder_option,
const torch::Device& device) { const c10::optional<std::string>& hw_accel) {
// Resolve the target device for decoded frames. Without `hw_accel` frames
// stay on CPU; otherwise the string must name a CUDA device and requires a
// CUDA-enabled build.
const torch::Device device = [&]() {
  if (!hw_accel) {
    return torch::Device{c10::DeviceType::CPU};
  }
#ifdef USE_CUDA
  torch::Device d{hw_accel.value()};
  if (d.type() != c10::DeviceType::CUDA) {
    std::stringstream ss;
    ss << "Only CUDA is supported for hardware acceleration. Found: "
       // Fix: was `device.str()` — reading `device` inside its own
       // initializer uses an uninitialized object (undefined behavior).
       << d.str();
    throw std::runtime_error(ss.str());
  }
  return d;
#else
  throw std::runtime_error(
      "torchaudio is not compiled with CUDA support. Hardware acceleration is not available.");
#endif
}();
add_stream( add_stream(
i, i,
AVMEDIA_TYPE_VIDEO, AVMEDIA_TYPE_VIDEO,
frames_per_chunk, frames_per_chunk,
num_chunks, num_chunks,
std::move(filter_desc), filter_desc,
decoder, decoder,
decoder_option, decoder_option,
device); device);
...@@ -198,9 +213,9 @@ void Streamer::add_stream( ...@@ -198,9 +213,9 @@ void Streamer::add_stream(
AVMediaType media_type, AVMediaType media_type,
int frames_per_chunk, int frames_per_chunk,
int num_chunks, int num_chunks,
std::string filter_desc, const c10::optional<std::string>& filter_desc,
const std::string& decoder, const c10::optional<std::string>& decoder,
const std::map<std::string, std::string>& decoder_option, const OptionDict& decoder_option,
const torch::Device& device) { const torch::Device& device) {
validate_src_stream_type(i, media_type); validate_src_stream_type(i, media_type);
...@@ -214,12 +229,12 @@ void Streamer::add_stream( ...@@ -214,12 +229,12 @@ void Streamer::add_stream(
stream->codecpar, stream->codecpar,
frames_per_chunk, frames_per_chunk,
num_chunks, num_chunks,
std::move(filter_desc), filter_desc,
device); device);
stream_indices.push_back(std::make_pair<>(i, key)); stream_indices.push_back(std::make_pair<>(i, key));
} }
void Streamer::remove_stream(int i) { void Streamer::remove_stream(int64_t i) {
validate_output_stream_index(i); validate_output_stream_index(i);
auto it = stream_indices.begin() + i; auto it = stream_indices.begin() + i;
int iP = it->first; int iP = it->first;
......
...@@ -19,11 +19,7 @@ class Streamer { ...@@ -19,11 +19,7 @@ class Streamer {
std::vector<std::pair<int, int>> stream_indices; std::vector<std::pair<int, int>> stream_indices;
public: public:
// Open the input and allocate the resource explicit Streamer(AVFormatContextPtr&& p);
Streamer(
const std::string& src,
const std::string& device,
const std::map<std::string, std::string>& option);
~Streamer() = default; ~Streamer() = default;
// Non-copyable // Non-copyable
Streamer(const Streamer&) = delete; Streamer(const Streamer&) = delete;
...@@ -46,13 +42,13 @@ class Streamer { ...@@ -46,13 +42,13 @@ class Streamer {
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
public: public:
// Find a suitable audio/video streams using heuristics from ffmpeg // Find a suitable audio/video streams using heuristics from ffmpeg
int find_best_audio_stream() const; int64_t find_best_audio_stream() const;
int find_best_video_stream() const; int64_t find_best_video_stream() const;
// Fetch information about source streams // Fetch information about source streams
int num_src_streams() const; int64_t num_src_streams() const;
SrcStreamInfo get_src_stream_info(int i) const; SrcStreamInfo get_src_stream_info(int i) const;
// Fetch information about output streams // Fetch information about output streams
int num_out_streams() const; int64_t num_out_streams() const;
OutputStreamInfo get_out_stream_info(int i) const; OutputStreamInfo get_out_stream_info(int i) const;
// Check if all the buffers of the output streams are ready. // Check if all the buffers of the output streams are ready.
bool is_buffer_ready() const; bool is_buffer_ready() const;
...@@ -63,21 +59,21 @@ class Streamer { ...@@ -63,21 +59,21 @@ class Streamer {
void seek(double timestamp); void seek(double timestamp);
void add_audio_stream( void add_audio_stream(
int i, int64_t i,
int frames_per_chunk, int64_t frames_per_chunk,
int num_chunks, int64_t num_chunks,
std::string filter_desc, const c10::optional<std::string>& filter_desc,
const std::string& decoder, const c10::optional<std::string>& decoder,
const std::map<std::string, std::string>& decoder_option); const OptionDict& decoder_option);
void add_video_stream( void add_video_stream(
int i, int64_t i,
int frames_per_chunk, int64_t frames_per_chunk,
int num_chunks, int64_t num_chunks,
std::string filter_desc, const c10::optional<std::string>& filter_desc,
const std::string& decoder, const c10::optional<std::string>& decoder,
const std::map<std::string, std::string>& decoder_option, const OptionDict& decoder_option,
const torch::Device& device); const c10::optional<std::string>& hw_accel);
void remove_stream(int i); void remove_stream(int64_t i);
private: private:
void add_stream( void add_stream(
...@@ -85,9 +81,9 @@ class Streamer { ...@@ -85,9 +81,9 @@ class Streamer {
AVMediaType media_type, AVMediaType media_type,
int frames_per_chunk, int frames_per_chunk,
int num_chunks, int num_chunks,
std::string filter_desc, const c10::optional<std::string>& filter_desc,
const std::string& decoder, const c10::optional<std::string>& decoder,
const std::map<std::string, std::string>& decoder_option, const OptionDict& decoder_option,
const torch::Device& device); const torch::Device& device);
public: public:
......
...@@ -154,6 +154,45 @@ def _parse_oi(i): ...@@ -154,6 +154,45 @@ def _parse_oi(i):
return StreamReaderOutputStream(i[0], i[1]) return StreamReaderOutputStream(i[0], i[1])
def _get_afilter_desc(sample_rate: Optional[int], dtype: torch.dtype):
descs = []
if sample_rate is not None:
descs.append(f"aresample={sample_rate}")
if dtype is not None:
fmt = {
torch.uint8: "u8p",
torch.int16: "s16p",
torch.int32: "s32p",
torch.long: "s64p",
torch.float32: "fltp",
torch.float64: "dblp",
}[dtype]
descs.append(f"aformat=sample_fmts={fmt}")
return ",".join(descs) if descs else None
def _get_vfilter_desc(frame_rate: Optional[float], width: Optional[int], height: Optional[int], format: Optional[str]):
descs = []
if frame_rate is not None:
descs.append(f"fps={frame_rate}")
scales = []
if width is not None:
scales.append(f"width={width}")
if height is not None:
scales.append(f"height={height}")
if scales:
descs.append(f"scale={':'.join(scales)}")
if format is not None:
fmt = {
"RGB": "rgb24",
"BGR": "bgr24",
"YUV": "yuv420p",
"GRAY": "gray",
}[format]
descs.append(f"format=pix_fmts={fmt}")
return ",".join(descs) if descs else None
class StreamReader: class StreamReader:
"""Fetch and decode audio/video streams chunk by chunk. """Fetch and decode audio/video streams chunk by chunk.
...@@ -297,8 +336,14 @@ class StreamReader: ...@@ -297,8 +336,14 @@ class StreamReader:
`[-1, 1]`. `[-1, 1]`.
""" """
i = self.default_audio_stream if stream_index is None else stream_index i = self.default_audio_stream if stream_index is None else stream_index
torch.ops.torchaudio.ffmpeg_streamer_add_basic_audio_stream( torch.ops.torchaudio.ffmpeg_streamer_add_audio_stream(
self._s, i, frames_per_chunk, buffer_chunk_size, sample_rate, dtype self._s,
i,
frames_per_chunk,
buffer_chunk_size,
_get_afilter_desc(sample_rate, dtype),
None,
None,
) )
def add_basic_video_stream( def add_basic_video_stream(
...@@ -338,15 +383,15 @@ class StreamReader: ...@@ -338,15 +383,15 @@ class StreamReader:
- `GRAY`: 8 bits * 1 channels - `GRAY`: 8 bits * 1 channels
""" """
i = self.default_video_stream if stream_index is None else stream_index i = self.default_video_stream if stream_index is None else stream_index
torch.ops.torchaudio.ffmpeg_streamer_add_basic_video_stream( torch.ops.torchaudio.ffmpeg_streamer_add_video_stream(
self._s, self._s,
i, i,
frames_per_chunk, frames_per_chunk,
buffer_chunk_size, buffer_chunk_size,
frame_rate, _get_vfilter_desc(frame_rate, width, height, format),
width, None,
height, None,
format, None,
) )
def add_audio_stream( def add_audio_stream(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment