Commit be243c59 authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Support specifying decoder and its options (#2327)

Summary:
This commit adds support to specify decoder to Streamer's add stream method.
This is roughly equivalent to `ffmpeg`'s `-c:v foo` and `-c:a foo` options.

This allows to override the decoder codec and/or specify the option of
the decoder.

This change allows to specify Nvidia NVDEC codec for supported formats,
which uses dedicated hardware for decoding the video.

 ---

Note: The CL might look overwhelming, but it's essentially, add new parameters in Python, and pass them down all the way to  `AVCodecContextPtr`, which initializes the actual decoder implementation (`AVCodecContext`.)

Pull Request resolved: https://github.com/pytorch/audio/pull/2327

Reviewed By: carolineechen

Differential Revision: D35626904

Pulled By: mthrok

fbshipit-source-id: a115ed548624e53c16bacfecff5aa6c9d4e8bede
parent 7972be99
...@@ -6,7 +6,11 @@ namespace ffmpeg { ...@@ -6,7 +6,11 @@ namespace ffmpeg {
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
// Decoder // Decoder
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
Decoder::Decoder(AVCodecParameters* pParam) : pCodecContext(pParam) {} Decoder::Decoder(
AVCodecParameters* pParam,
const std::string& decoder_name,
const std::map<std::string, std::string>& decoder_option)
: pCodecContext(pParam, decoder_name, decoder_option) {}
int Decoder::process_packet(AVPacket* pPacket) { int Decoder::process_packet(AVPacket* pPacket) {
return avcodec_send_packet(pCodecContext, pPacket); return avcodec_send_packet(pCodecContext, pPacket);
......
...@@ -10,7 +10,10 @@ class Decoder { ...@@ -10,7 +10,10 @@ class Decoder {
public: public:
// Default constructable // Default constructable
Decoder(AVCodecParameters* pParam); Decoder(
AVCodecParameters* pParam,
const std::string& decoder_name,
const std::map<std::string, std::string>& decoder_option);
// Custom destructor to clean up the resources // Custom destructor to clean up the resources
~Decoder() = default; ~Decoder() = default;
// Non-copyable // Non-copyable
......
...@@ -151,11 +151,22 @@ void AVCodecContextDeleter::operator()(AVCodecContext* p) { ...@@ -151,11 +151,22 @@ void AVCodecContextDeleter::operator()(AVCodecContext* p) {
}; };
namespace { namespace {
AVCodecContext* get_codec_context(AVCodecParameters* pParams) { AVCodecContext* get_codec_context(
const AVCodec* pCodec = avcodec_find_decoder(pParams->codec_id); enum AVCodecID codec_id,
const std::string& decoder_name) {
const AVCodec* pCodec = decoder_name.empty()
? avcodec_find_decoder(codec_id)
: avcodec_find_decoder_by_name(decoder_name.c_str());
if (!pCodec) { if (!pCodec) {
throw std::runtime_error("Unknown codec."); std::stringstream ss;
if (decoder_name.empty()) {
ss << "Unsupported codec: \"" << avcodec_get_name(codec_id) << "\", ("
<< codec_id << ").";
} else {
ss << "Unsupported codec: \"" << decoder_name << "\".";
}
throw std::runtime_error(ss.str());
} }
AVCodecContext* pCodecContext = avcodec_alloc_context3(pCodec); AVCodecContext* pCodecContext = avcodec_alloc_context3(pCodec);
...@@ -167,16 +178,29 @@ AVCodecContext* get_codec_context(AVCodecParameters* pParams) { ...@@ -167,16 +178,29 @@ AVCodecContext* get_codec_context(AVCodecParameters* pParams) {
void init_codec_context( void init_codec_context(
AVCodecContext* pCodecContext, AVCodecContext* pCodecContext,
AVCodecParameters* pParams) { AVCodecParameters* pParams,
const AVCodec* pCodec = avcodec_find_decoder(pParams->codec_id); const std::string& decoder_name,
const std::map<std::string, std::string>& decoder_option) {
const AVCodec* pCodec = decoder_name.empty()
? avcodec_find_decoder(pParams->codec_id)
: avcodec_find_decoder_by_name(decoder_name.c_str());
// No need to check if pCodec is null as it's been already checked in
// get_codec_context
if (avcodec_parameters_to_context(pCodecContext, pParams) < 0) { if (avcodec_parameters_to_context(pCodecContext, pParams) < 0) {
throw std::runtime_error("Failed to set CodecContext parameter."); throw std::runtime_error("Failed to set CodecContext parameter.");
} }
if (avcodec_open2(pCodecContext, pCodec, NULL) < 0) { AVDictionary* opts = get_option_dict(decoder_option);
if (avcodec_open2(pCodecContext, pCodec, &opts) < 0) {
throw std::runtime_error("Failed to initialize CodecContext."); throw std::runtime_error("Failed to initialize CodecContext.");
} }
auto unused_keys = clean_up_dict(opts);
if (unused_keys.size()) {
throw std::runtime_error(
"Unexpected decoder options: " + join(unused_keys));
}
if (pParams->codec_type == AVMEDIA_TYPE_AUDIO && !pParams->channel_layout) if (pParams->codec_type == AVMEDIA_TYPE_AUDIO && !pParams->channel_layout)
pParams->channel_layout = pParams->channel_layout =
...@@ -184,10 +208,13 @@ void init_codec_context( ...@@ -184,10 +208,13 @@ void init_codec_context(
} }
} // namespace } // namespace
AVCodecContextPtr::AVCodecContextPtr(AVCodecParameters* pParam) AVCodecContextPtr::AVCodecContextPtr(
AVCodecParameters* pParam,
const std::string& decoder_name,
const std::map<std::string, std::string>& decoder_option)
: Wrapper<AVCodecContext, AVCodecContextDeleter>( : Wrapper<AVCodecContext, AVCodecContextDeleter>(
get_codec_context(pParam)) { get_codec_context(pParam->codec_id, decoder_name)) {
init_codec_context(ptr.get(), pParam); init_codec_context(ptr.get(), pParam, decoder_name, decoder_option);
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
// AVFilterGraph // AVFilterGraph
......
...@@ -118,7 +118,10 @@ struct AVCodecContextDeleter { ...@@ -118,7 +118,10 @@ struct AVCodecContextDeleter {
}; };
struct AVCodecContextPtr struct AVCodecContextPtr
: public Wrapper<AVCodecContext, AVCodecContextDeleter> { : public Wrapper<AVCodecContext, AVCodecContextDeleter> {
AVCodecContextPtr(AVCodecParameters* pParam); AVCodecContextPtr(
AVCodecParameters* pParam,
const std::string& decoder,
const std::map<std::string, std::string>& decoder_option);
}; };
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
......
...@@ -7,8 +7,10 @@ namespace ffmpeg { ...@@ -7,8 +7,10 @@ namespace ffmpeg {
namespace { namespace {
using OptionDict = c10::Dict<std::string, std::string>;
std::map<std::string, std::string> convert_dict( std::map<std::string, std::string> convert_dict(
const c10::optional<c10::Dict<std::string, std::string>>& option) { const c10::optional<OptionDict>& option) {
std::map<std::string, std::string> opts; std::map<std::string, std::string> opts;
if (option) { if (option) {
for (auto& it : option.value()) { for (auto& it : option.value()) {
...@@ -22,8 +24,8 @@ struct StreamerHolder : torch::CustomClassHolder { ...@@ -22,8 +24,8 @@ struct StreamerHolder : torch::CustomClassHolder {
Streamer s; Streamer s;
StreamerHolder( StreamerHolder(
const std::string& src, const std::string& src,
c10::optional<std::string> device, const c10::optional<std::string>& device,
c10::optional<c10::Dict<std::string, std::string>> option) const c10::optional<OptionDict>& option)
: s(src, device.value_or(""), convert_dict(option)) {} : s(src, device.value_or(""), convert_dict(option)) {}
}; };
...@@ -31,8 +33,8 @@ using S = c10::intrusive_ptr<StreamerHolder>; ...@@ -31,8 +33,8 @@ using S = c10::intrusive_ptr<StreamerHolder>;
S init( S init(
const std::string& src, const std::string& src,
c10::optional<std::string> device, const c10::optional<std::string>& device,
c10::optional<c10::Dict<std::string, std::string>> option) { const c10::optional<OptionDict>& option) {
return c10::make_intrusive<StreamerHolder>(src, device, option); return c10::make_intrusive<StreamerHolder>(src, device, option);
} }
...@@ -231,7 +233,7 @@ void add_basic_audio_stream( ...@@ -231,7 +233,7 @@ void add_basic_audio_stream(
const c10::optional<int64_t>& sample_rate, const c10::optional<int64_t>& sample_rate,
const c10::optional<c10::ScalarType>& dtype) { const c10::optional<c10::ScalarType>& dtype) {
std::string filter_desc = get_afilter_desc(sample_rate, dtype); std::string filter_desc = get_afilter_desc(sample_rate, dtype);
s->s.add_audio_stream(i, frames_per_chunk, num_chunks, filter_desc); s->s.add_audio_stream(i, frames_per_chunk, num_chunks, filter_desc, "", {});
} }
void add_basic_video_stream( void add_basic_video_stream(
...@@ -244,7 +246,7 @@ void add_basic_video_stream( ...@@ -244,7 +246,7 @@ void add_basic_video_stream(
const c10::optional<int64_t>& height, const c10::optional<int64_t>& height,
const c10::optional<std::string>& format) { const c10::optional<std::string>& format) {
std::string filter_desc = get_vfilter_desc(frame_rate, width, height, format); std::string filter_desc = get_vfilter_desc(frame_rate, width, height, format);
s->s.add_video_stream(i, frames_per_chunk, num_chunks, filter_desc); s->s.add_video_stream(i, frames_per_chunk, num_chunks, filter_desc, "", {});
} }
void add_audio_stream( void add_audio_stream(
...@@ -252,9 +254,16 @@ void add_audio_stream( ...@@ -252,9 +254,16 @@ void add_audio_stream(
int64_t i, int64_t i,
int64_t frames_per_chunk, int64_t frames_per_chunk,
int64_t num_chunks, int64_t num_chunks,
const c10::optional<std::string>& filter_desc) { const c10::optional<std::string>& filter_desc,
const c10::optional<std::string>& decoder,
const c10::optional<OptionDict>& decoder_options) {
s->s.add_audio_stream( s->s.add_audio_stream(
i, frames_per_chunk, num_chunks, filter_desc.value_or("")); i,
frames_per_chunk,
num_chunks,
filter_desc.value_or(""),
decoder.value_or(""),
convert_dict(decoder_options));
} }
void add_video_stream( void add_video_stream(
...@@ -262,9 +271,16 @@ void add_video_stream( ...@@ -262,9 +271,16 @@ void add_video_stream(
int64_t i, int64_t i,
int64_t frames_per_chunk, int64_t frames_per_chunk,
int64_t num_chunks, int64_t num_chunks,
const c10::optional<std::string>& filter_desc) { const c10::optional<std::string>& filter_desc,
const c10::optional<std::string>& decoder,
const c10::optional<OptionDict>& decoder_options) {
s->s.add_video_stream( s->s.add_video_stream(
i, frames_per_chunk, num_chunks, filter_desc.value_or("")); i,
frames_per_chunk,
num_chunks,
filter_desc.value_or(""),
decoder.value_or(""),
convert_dict(decoder_options));
} }
void remove_stream(S s, int64_t i) { void remove_stream(S s, int64_t i) {
...@@ -308,7 +324,7 @@ std::tuple<c10::optional<torch::Tensor>, int64_t> load(const std::string& src) { ...@@ -308,7 +324,7 @@ std::tuple<c10::optional<torch::Tensor>, int64_t> load(const std::string& src) {
int i = s.find_best_audio_stream(); int i = s.find_best_audio_stream();
auto sinfo = s.get_src_stream_info(i); auto sinfo = s.get_src_stream_info(i);
int64_t sample_rate = static_cast<int64_t>(sinfo.sample_rate); int64_t sample_rate = static_cast<int64_t>(sinfo.sample_rate);
s.add_audio_stream(i, -1, -1, ""); s.add_audio_stream(i, -1, -1, "", "", {});
process_all_packets(s); process_all_packets(s);
auto tensors = s.pop_chunks(); auto tensors = s.pop_chunks();
return std::make_tuple<>(tensors[0], sample_rate); return std::make_tuple<>(tensors[0], sample_rate);
......
...@@ -6,8 +6,11 @@ namespace ffmpeg { ...@@ -6,8 +6,11 @@ namespace ffmpeg {
using KeyType = StreamProcessor::KeyType; using KeyType = StreamProcessor::KeyType;
StreamProcessor::StreamProcessor(AVCodecParameters* codecpar) StreamProcessor::StreamProcessor(
: decoder(codecpar) {} AVCodecParameters* codecpar,
const std::string& decoder_name,
const std::map<std::string, std::string>& decoder_option)
: decoder(codecpar, decoder_name, decoder_option) {}
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
// Configurations // Configurations
......
...@@ -25,7 +25,10 @@ class StreamProcessor { ...@@ -25,7 +25,10 @@ class StreamProcessor {
std::map<KeyType, Sink> sinks; std::map<KeyType, Sink> sinks;
public: public:
StreamProcessor(AVCodecParameters* codecpar); StreamProcessor(
AVCodecParameters* codecpar,
const std::string& decoder_name,
const std::map<std::string, std::string>& decoder_option);
~StreamProcessor() = default; ~StreamProcessor() = default;
// Non-copyable // Non-copyable
StreamProcessor(const StreamProcessor&) = delete; StreamProcessor(const StreamProcessor&) = delete;
......
...@@ -156,26 +156,34 @@ void Streamer::add_audio_stream( ...@@ -156,26 +156,34 @@ void Streamer::add_audio_stream(
int i, int i,
int frames_per_chunk, int frames_per_chunk,
int num_chunks, int num_chunks,
std::string filter_desc) { std::string filter_desc,
const std::string& decoder,
const std::map<std::string, std::string>& decoder_option) {
add_stream( add_stream(
i, i,
AVMEDIA_TYPE_AUDIO, AVMEDIA_TYPE_AUDIO,
frames_per_chunk, frames_per_chunk,
num_chunks, num_chunks,
std::move(filter_desc)); std::move(filter_desc),
decoder,
decoder_option);
} }
void Streamer::add_video_stream( void Streamer::add_video_stream(
int i, int i,
int frames_per_chunk, int frames_per_chunk,
int num_chunks, int num_chunks,
std::string filter_desc) { std::string filter_desc,
const std::string& decoder,
const std::map<std::string, std::string>& decoder_option) {
add_stream( add_stream(
i, i,
AVMEDIA_TYPE_VIDEO, AVMEDIA_TYPE_VIDEO,
frames_per_chunk, frames_per_chunk,
num_chunks, num_chunks,
std::move(filter_desc)); std::move(filter_desc),
decoder,
decoder_option);
} }
void Streamer::add_stream( void Streamer::add_stream(
...@@ -183,12 +191,15 @@ void Streamer::add_stream( ...@@ -183,12 +191,15 @@ void Streamer::add_stream(
AVMediaType media_type, AVMediaType media_type,
int frames_per_chunk, int frames_per_chunk,
int num_chunks, int num_chunks,
std::string filter_desc) { std::string filter_desc,
const std::string& decoder,
const std::map<std::string, std::string>& decoder_option) {
validate_src_stream_type(i, media_type); validate_src_stream_type(i, media_type);
AVStream* stream = pFormatContext->streams[i]; AVStream* stream = pFormatContext->streams[i];
stream->discard = AVDISCARD_DEFAULT; stream->discard = AVDISCARD_DEFAULT;
if (!processors[i]) if (!processors[i])
processors[i] = std::make_unique<StreamProcessor>(stream->codecpar); processors[i] = std::make_unique<StreamProcessor>(
stream->codecpar, decoder, decoder_option);
int key = processors[i]->add_stream( int key = processors[i]->add_stream(
stream->time_base, stream->time_base,
stream->codecpar, stream->codecpar,
......
...@@ -66,12 +66,16 @@ class Streamer { ...@@ -66,12 +66,16 @@ class Streamer {
int i, int i,
int frames_per_chunk, int frames_per_chunk,
int num_chunks, int num_chunks,
std::string filter_desc); std::string filter_desc,
const std::string& decoder,
const std::map<std::string, std::string>& decoder_option);
void add_video_stream( void add_video_stream(
int i, int i,
int frames_per_chunk, int frames_per_chunk,
int num_chunks, int num_chunks,
std::string filter_desc); std::string filter_desc,
const std::string& decoder,
const std::map<std::string, std::string>& decoder_option);
void remove_stream(int i); void remove_stream(int i);
private: private:
...@@ -80,7 +84,9 @@ class Streamer { ...@@ -80,7 +84,9 @@ class Streamer {
AVMediaType media_type, AVMediaType media_type,
int frames_per_chunk, int frames_per_chunk,
int num_chunks, int num_chunks,
std::string filter_desc); std::string filter_desc,
const std::string& decoder,
const std::map<std::string, std::string>& decoder_option);
public: public:
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
......
...@@ -355,6 +355,8 @@ class Streamer: ...@@ -355,6 +355,8 @@ class Streamer:
buffer_chunk_size: int = 3, buffer_chunk_size: int = 3,
stream_index: Optional[int] = None, stream_index: Optional[int] = None,
filter_desc: Optional[str] = None, filter_desc: Optional[str] = None,
decoder: Optional[str] = None,
decoder_options: Optional[Dict[str, str]] = None,
): ):
"""Add output audio stream """Add output audio stream
...@@ -375,10 +377,22 @@ class Streamer: ...@@ -375,10 +377,22 @@ class Streamer:
The list of available filters can be found at The list of available filters can be found at
https://ffmpeg.org/ffmpeg-filters.html https://ffmpeg.org/ffmpeg-filters.html
Note that complex filters are not supported. Note that complex filters are not supported.
decoder (str or None, optional): The name of the decoder to be used.
When provided, use the specified decoder instead of the default one.
decoder_options (dict or None, optional): Options passed to decoder.
Mapping from str to str.
""" """
i = self.default_audio_stream if stream_index is None else stream_index i = self.default_audio_stream if stream_index is None else stream_index
torch.ops.torchaudio.ffmpeg_streamer_add_audio_stream( torch.ops.torchaudio.ffmpeg_streamer_add_audio_stream(
self._s, i, frames_per_chunk, buffer_chunk_size, filter_desc self._s,
i,
frames_per_chunk,
buffer_chunk_size,
filter_desc,
decoder,
decoder_options,
) )
def add_video_stream( def add_video_stream(
...@@ -387,6 +401,8 @@ class Streamer: ...@@ -387,6 +401,8 @@ class Streamer:
buffer_chunk_size: int = 3, buffer_chunk_size: int = 3,
stream_index: Optional[int] = None, stream_index: Optional[int] = None,
filter_desc: Optional[str] = None, filter_desc: Optional[str] = None,
decoder: Optional[str] = None,
decoder_options: Optional[Dict[str, str]] = None,
): ):
"""Add output video stream """Add output video stream
...@@ -407,10 +423,22 @@ class Streamer: ...@@ -407,10 +423,22 @@ class Streamer:
The list of available filters can be found at The list of available filters can be found at
https://ffmpeg.org/ffmpeg-filters.html https://ffmpeg.org/ffmpeg-filters.html
Note that complex filters are not supported. Note that complex filters are not supported.
decoder (str or None, optional): The name of the decoder to be used.
When provided, use the specified decoder instead of the default one.
decoder_options (dict or None, optional): Options passed to decoder.
Mapping from str to str.
""" """
i = self.default_video_stream if stream_index is None else stream_index i = self.default_video_stream if stream_index is None else stream_index
torch.ops.torchaudio.ffmpeg_streamer_add_video_stream( torch.ops.torchaudio.ffmpeg_streamer_add_video_stream(
self._s, i, frames_per_chunk, buffer_chunk_size, filter_desc self._s,
i,
frames_per_chunk,
buffer_chunk_size,
filter_desc,
decoder,
decoder_options,
) )
def remove_stream(self, i: int): def remove_stream(self, i: int):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment