Commit 93c26d63 authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Refactor the constructors of pointer wrappers (#2373)

Summary:
This commit refactor the constructor of wrapper classes so that
wrapper classes are only responsible for deallocation of underlying
FFmpeg custom structures.

The responsibility of custom initialization is moved to helper functions.

Context:

FFmpeg API uses bunch of raw pointers, which require dedicated allocater
and deallcoator. In torchaudio we wrap these pointers with
`std::unique_ptr<>` to adopt RAII semantics.

Currently all of the customization logics required for `Streamer` are
handled by the constructor of wrapper class. Like the following;

```
AVFormatContextPtr(
      const std::string& src,
      const std::string& device,
      const std::map<std::string, std::string>& option);
```

This constructor allocates the raw `AVFormatContext*` pointer,
while initializing it with the given option, then it parses the
input media.

As we consider the write/encode features, which require different way
of initializing the `AVFormatContext*`, making it the responsibility
of constructors of `AVFormatContextPtr` reduce the flexibility.

Thus this commit moves the customization to helper factory function.

- `AVFormatContextPtr(...)` -> `get_input_format_context(...)`
- `AVCodecContextPtr(...)` -> `get_decode_context(...)`

Pull Request resolved: https://github.com/pytorch/audio/pull/2373

Reviewed By: hwangjeff

Differential Revision: D36230148

Pulled By: mthrok

fbshipit-source-id: 202d57d549223904ee958193f3b386ef5a9cda3a
parent 2c79b55a
...@@ -11,7 +11,15 @@ Decoder::Decoder( ...@@ -11,7 +11,15 @@ Decoder::Decoder(
const std::string& decoder_name, const std::string& decoder_name,
const std::map<std::string, std::string>& decoder_option, const std::map<std::string, std::string>& decoder_option,
const torch::Device& device) const torch::Device& device)
: pCodecContext(pParam, decoder_name, decoder_option, device) {} : pCodecContext(get_decode_context(pParam->codec_id, decoder_name)) {
init_codec_context(
pCodecContext,
pParam,
decoder_name,
decoder_option,
device,
pHWBufferRef);
}
int Decoder::process_packet(AVPacket* pPacket) { int Decoder::process_packet(AVPacket* pPacket) {
return avcodec_send_packet(pCodecContext, pPacket); return avcodec_send_packet(pCodecContext, pPacket);
......
...@@ -7,6 +7,7 @@ namespace ffmpeg { ...@@ -7,6 +7,7 @@ namespace ffmpeg {
class Decoder { class Decoder {
AVCodecContextPtr pCodecContext; AVCodecContextPtr pCodecContext;
AVBufferRefPtr pHWBufferRef;
public: public:
// Default constructable // Default constructable
......
...@@ -62,7 +62,9 @@ std::string join(std::vector<std::string> vars) { ...@@ -62,7 +62,9 @@ std::string join(std::vector<std::string> vars) {
#define AVINPUT_FORMAT_CONST #define AVINPUT_FORMAT_CONST
#endif #endif
AVFormatContext* get_format_context( } // namespace
AVFormatContextPtr get_input_format_context(
const std::string& src, const std::string& src,
const std::string& device, const std::string& device,
const std::map<std::string, std::string>& option) { const std::map<std::string, std::string>& option) {
...@@ -83,19 +85,11 @@ AVFormatContext* get_format_context( ...@@ -83,19 +85,11 @@ AVFormatContext* get_format_context(
throw std::runtime_error( throw std::runtime_error(
"Failed to open the input \"" + src + "\" (" + av_err2string(ret) + "Failed to open the input \"" + src + "\" (" + av_err2string(ret) +
")."); ").");
return pFormat; return AVFormatContextPtr(pFormat);
} }
} // namespace
AVFormatContextPtr::AVFormatContextPtr( AVFormatContextPtr::AVFormatContextPtr(AVFormatContext* p)
const std::string& src, : Wrapper<AVFormatContext, AVFormatContextDeleter>(p) {}
const std::string& device,
const std::map<std::string, std::string>& option)
: Wrapper<AVFormatContext, AVFormatContextDeleter>(
get_format_context(src, device, option)) {
if (avformat_find_stream_info(ptr.get(), NULL) < 0)
throw std::runtime_error("Failed to find stream information.");
}
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
// AVPacket // AVPacket
...@@ -152,7 +146,7 @@ void AVCodecContextDeleter::operator()(AVCodecContext* p) { ...@@ -152,7 +146,7 @@ void AVCodecContextDeleter::operator()(AVCodecContext* p) {
}; };
namespace { namespace {
AVCodecContext* get_codec_context( const AVCodec* get_decode_codec(
enum AVCodecID codec_id, enum AVCodecID codec_id,
const std::string& decoder_name) { const std::string& decoder_name) {
const AVCodec* pCodec = decoder_name.empty() const AVCodec* pCodec = decoder_name.empty()
...@@ -169,12 +163,21 @@ AVCodecContext* get_codec_context( ...@@ -169,12 +163,21 @@ AVCodecContext* get_codec_context(
} }
throw std::runtime_error(ss.str()); throw std::runtime_error(ss.str());
} }
return pCodec;
}
} // namespace
AVCodecContextPtr get_decode_context(
enum AVCodecID codec_id,
const std::string& decoder_name) {
const AVCodec* pCodec = get_decode_codec(codec_id, decoder_name);
AVCodecContext* pCodecContext = avcodec_alloc_context3(pCodec); AVCodecContext* pCodecContext = avcodec_alloc_context3(pCodec);
if (!pCodecContext) { if (!pCodecContext) {
throw std::runtime_error("Failed to allocate CodecContext."); throw std::runtime_error("Failed to allocate CodecContext.");
} }
return pCodecContext; return AVCodecContextPtr(pCodecContext);
} }
#ifdef USE_CUDA #ifdef USE_CUDA
...@@ -217,12 +220,7 @@ void init_codec_context( ...@@ -217,12 +220,7 @@ void init_codec_context(
const std::map<std::string, std::string>& decoder_option, const std::map<std::string, std::string>& decoder_option,
const torch::Device& device, const torch::Device& device,
AVBufferRefPtr& pHWBufferRef) { AVBufferRefPtr& pHWBufferRef) {
const AVCodec* pCodec = decoder_name.empty() const AVCodec* pCodec = get_decode_codec(pParams->codec_id, decoder_name);
? avcodec_find_decoder(pParams->codec_id)
: avcodec_find_decoder_by_name(decoder_name.c_str());
// No need to check if pCodec is null as it's been already checked in
// get_codec_context
if (avcodec_parameters_to_context(pCodecContext, pParams) < 0) { if (avcodec_parameters_to_context(pCodecContext, pParams) < 0) {
throw std::runtime_error("Failed to set CodecContext parameter."); throw std::runtime_error("Failed to set CodecContext parameter.");
...@@ -276,19 +274,9 @@ void init_codec_context( ...@@ -276,19 +274,9 @@ void init_codec_context(
pParams->channel_layout = pParams->channel_layout =
av_get_default_channel_layout(pCodecContext->channels); av_get_default_channel_layout(pCodecContext->channels);
} }
} // namespace
AVCodecContextPtr::AVCodecContextPtr( AVCodecContextPtr::AVCodecContextPtr(AVCodecContext* p)
AVCodecParameters* pParam, : Wrapper<AVCodecContext, AVCodecContextDeleter>(p) {}
const std::string& decoder_name,
const std::map<std::string, std::string>& decoder_option,
const torch::Device& device)
: Wrapper<AVCodecContext, AVCodecContextDeleter>(
get_codec_context(pParam->codec_id, decoder_name)),
pHWBufferRef() {
init_codec_context(
ptr.get(), pParam, decoder_name, decoder_option, device, pHWBufferRef);
}
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
// AVBufferRefPtr // AVBufferRefPtr
......
...@@ -65,12 +65,15 @@ struct AVFormatContextDeleter { ...@@ -65,12 +65,15 @@ struct AVFormatContextDeleter {
struct AVFormatContextPtr struct AVFormatContextPtr
: public Wrapper<AVFormatContext, AVFormatContextDeleter> { : public Wrapper<AVFormatContext, AVFormatContextDeleter> {
AVFormatContextPtr( explicit AVFormatContextPtr(AVFormatContext* p);
const std::string& src,
const std::string& device,
const std::map<std::string, std::string>& option);
}; };
// create format context for reading media
AVFormatContextPtr get_input_format_context(
const std::string& src,
const std::string& device,
const std::map<std::string, std::string>& option);
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
// AVPacket // AVPacket
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
...@@ -132,15 +135,23 @@ struct AVCodecContextDeleter { ...@@ -132,15 +135,23 @@ struct AVCodecContextDeleter {
}; };
struct AVCodecContextPtr struct AVCodecContextPtr
: public Wrapper<AVCodecContext, AVCodecContextDeleter> { : public Wrapper<AVCodecContext, AVCodecContextDeleter> {
AVBufferRefPtr pHWBufferRef; explicit AVCodecContextPtr(AVCodecContext* p);
AVCodecContextPtr(
AVCodecParameters* pParam,
const std::string& decoder,
const std::map<std::string, std::string>& decoder_option,
const torch::Device& device);
}; };
// Allocate codec context from either decoder name or ID
AVCodecContextPtr get_decode_context(
enum AVCodecID codec_id,
const std::string& decoder);
// Initialize codec context with the parameters
void init_codec_context(
AVCodecContext* pCodecContext,
AVCodecParameters* pParams,
const std::string& decoder_name,
const std::map<std::string, std::string>& decoder_option,
const torch::Device& device,
AVBufferRefPtr& pHWBufferRef);
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
// AVFilterGraph // AVFilterGraph
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
......
...@@ -46,7 +46,11 @@ Streamer::Streamer( ...@@ -46,7 +46,11 @@ Streamer::Streamer(
const std::string& src, const std::string& src,
const std::string& device, const std::string& device,
const std::map<std::string, std::string>& option) const std::map<std::string, std::string>& option)
: pFormatContext(src, device, option) { : pFormatContext(get_input_format_context(src, device, option)) {
if (avformat_find_stream_info(pFormatContext, nullptr) < 0) {
throw std::runtime_error("Failed to find stream information.");
}
processors = processors =
std::vector<std::unique_ptr<StreamProcessor>>(pFormatContext->nb_streams); std::vector<std::unique_ptr<StreamProcessor>>(pFormatContext->nb_streams);
for (int i = 0; i < pFormatContext->nb_streams; ++i) { for (int i = 0; i < pFormatContext->nb_streams; ++i) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment