Commit ea78478e authored by Moto Hira's avatar Moto Hira Committed by Facebook GitHub Bot
Browse files

Tweak managed pointer interface (#3249)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3249

- Put ptr member private so that it's more secure and subclasses won't mess with it
- Remove unused `reset` method
- Do not default construct the managed object
  - Introduce helper function for default allocation.
    (for AVFrame and AVPacket as they are allocated in both reader and writer)
  - for others, allocation logics are moved to where it is used.
- Remove unused `pHWBufferRef` attribute from `StreamWriter`.

Reviewed By: hwangjeff

Differential Revision: D44775297

fbshipit-source-id: ff6db528152cd54c1ae398191110c30b9c1e238c
parent f7c8a7d3
...@@ -73,16 +73,13 @@ void AVPacketDeleter::operator()(AVPacket* p) { ...@@ -73,16 +73,13 @@ void AVPacketDeleter::operator()(AVPacket* p) {
av_packet_free(&p); av_packet_free(&p);
}; };
namespace { AVPacketPtr::AVPacketPtr(AVPacket* p) : Wrapper<AVPacket, AVPacketDeleter>(p) {}
AVPacket* get_av_packet() {
AVPacket* pPacket = av_packet_alloc();
TORCH_CHECK(pPacket, "Failed to allocate AVPacket object.");
return pPacket;
}
} // namespace
AVPacketPtr::AVPacketPtr() AVPacketPtr alloc_avpacket() {
: Wrapper<AVPacket, AVPacketDeleter>(get_av_packet()) {} AVPacket* p = av_packet_alloc();
TORCH_CHECK(p, "Failed to allocate AVPacket object.");
return AVPacketPtr{p};
}
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
// AVPacket - buffer unref // AVPacket - buffer unref
...@@ -101,15 +98,14 @@ AutoPacketUnref::operator AVPacket*() const { ...@@ -101,15 +98,14 @@ AutoPacketUnref::operator AVPacket*() const {
void AVFrameDeleter::operator()(AVFrame* p) { void AVFrameDeleter::operator()(AVFrame* p) {
av_frame_free(&p); av_frame_free(&p);
}; };
namespace {
AVFrame* get_av_frame() {
AVFrame* pFrame = av_frame_alloc();
TORCH_CHECK(pFrame, "Failed to allocate AVFrame object.");
return pFrame;
}
} // namespace
AVFramePtr::AVFramePtr() : Wrapper<AVFrame, AVFrameDeleter>(get_av_frame()) {} AVFramePtr::AVFramePtr(AVFrame* p) : Wrapper<AVFrame, AVFrameDeleter>(p) {}
AVFramePtr alloc_avframe() {
AVFrame* p = av_frame_alloc();
TORCH_CHECK(p, "Failed to allocate AVFrame object.");
return AVFramePtr{p};
};
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
// AVCodecContext // AVCodecContext
...@@ -131,13 +127,6 @@ void AutoBufferUnref::operator()(AVBufferRef* p) { ...@@ -131,13 +127,6 @@ void AutoBufferUnref::operator()(AVBufferRef* p) {
AVBufferRefPtr::AVBufferRefPtr(AVBufferRef* p) AVBufferRefPtr::AVBufferRefPtr(AVBufferRef* p)
: Wrapper<AVBufferRef, AutoBufferUnref>(p) {} : Wrapper<AVBufferRef, AutoBufferUnref>(p) {}
void AVBufferRefPtr::reset(AVBufferRef* p) {
TORCH_CHECK(
!ptr,
"InternalError: A valid AVBufferRefPtr is being reset. Please file an issue.");
ptr.reset(p);
}
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
// AVFilterGraph // AVFilterGraph
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
...@@ -145,19 +134,8 @@ void AVFilterGraphDeleter::operator()(AVFilterGraph* p) { ...@@ -145,19 +134,8 @@ void AVFilterGraphDeleter::operator()(AVFilterGraph* p) {
avfilter_graph_free(&p); avfilter_graph_free(&p);
}; };
namespace { AVFilterGraphPtr::AVFilterGraphPtr(AVFilterGraph* p)
AVFilterGraph* get_filter_graph() { : Wrapper<AVFilterGraph, AVFilterGraphDeleter>(p) {}
AVFilterGraph* ptr = avfilter_graph_alloc();
TORCH_CHECK(ptr, "Failed to allocate resouce.");
return ptr;
}
} // namespace
AVFilterGraphPtr::AVFilterGraphPtr()
: Wrapper<AVFilterGraph, AVFilterGraphDeleter>(get_filter_graph()) {}
void AVFilterGraphPtr::reset() {
ptr.reset(get_filter_graph());
}
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
// AVCodecParameters // AVCodecParameters
...@@ -166,15 +144,8 @@ void AVCodecParametersDeleter::operator()(AVCodecParameters* codecpar) { ...@@ -166,15 +144,8 @@ void AVCodecParametersDeleter::operator()(AVCodecParameters* codecpar) {
avcodec_parameters_free(&codecpar); avcodec_parameters_free(&codecpar);
} }
namespace { AVCodecParametersPtr::AVCodecParametersPtr(AVCodecParameters* p)
AVCodecParameters* get_codecpar() { : Wrapper<AVCodecParameters, AVCodecParametersDeleter>(p) {}
AVCodecParameters* ptr = avcodec_parameters_alloc();
TORCH_CHECK(ptr, "Failed to allocate resource.");
return ptr;
}
} // namespace
AVCodecParametersPtr::AVCodecParametersPtr()
: Wrapper<AVCodecParameters, AVCodecParametersDeleter>(get_codecpar()) {}
} // namespace io } // namespace io
} // namespace torchaudio } // namespace torchaudio
...@@ -54,7 +54,6 @@ av_always_inline std::string av_err2string(int errnum) { ...@@ -54,7 +54,6 @@ av_always_inline std::string av_err2string(int errnum) {
// The resource allocation will be provided by custom constructors. // The resource allocation will be provided by custom constructors.
template <typename T, typename Deleter> template <typename T, typename Deleter>
class Wrapper { class Wrapper {
protected:
std::unique_ptr<T, Deleter> ptr; std::unique_ptr<T, Deleter> ptr;
public: public:
...@@ -123,9 +122,11 @@ struct AVPacketDeleter { ...@@ -123,9 +122,11 @@ struct AVPacketDeleter {
}; };
struct AVPacketPtr : public Wrapper<AVPacket, AVPacketDeleter> { struct AVPacketPtr : public Wrapper<AVPacket, AVPacketDeleter> {
AVPacketPtr(); explicit AVPacketPtr(AVPacket* p);
}; };
AVPacketPtr alloc_avpacket();
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
// AVPacket - buffer unref // AVPacket - buffer unref
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
...@@ -152,9 +153,11 @@ struct AVFrameDeleter { ...@@ -152,9 +153,11 @@ struct AVFrameDeleter {
}; };
struct AVFramePtr : public Wrapper<AVFrame, AVFrameDeleter> { struct AVFramePtr : public Wrapper<AVFrame, AVFrameDeleter> {
AVFramePtr(); explicit AVFramePtr(AVFrame* p);
}; };
AVFramePtr alloc_avframe();
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
// AutoBufferUnrer is responsible for performing unref at the end of lifetime // AutoBufferUnrer is responsible for performing unref at the end of lifetime
// of AVBufferRefPtr. // of AVBufferRefPtr.
...@@ -164,8 +167,7 @@ struct AutoBufferUnref { ...@@ -164,8 +167,7 @@ struct AutoBufferUnref {
}; };
struct AVBufferRefPtr : public Wrapper<AVBufferRef, AutoBufferUnref> { struct AVBufferRefPtr : public Wrapper<AVBufferRef, AutoBufferUnref> {
AVBufferRefPtr(AVBufferRef* p = nullptr); explicit AVBufferRefPtr(AVBufferRef* p);
void reset(AVBufferRef* p);
}; };
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
...@@ -186,8 +188,7 @@ struct AVFilterGraphDeleter { ...@@ -186,8 +188,7 @@ struct AVFilterGraphDeleter {
void operator()(AVFilterGraph* p); void operator()(AVFilterGraph* p);
}; };
struct AVFilterGraphPtr : public Wrapper<AVFilterGraph, AVFilterGraphDeleter> { struct AVFilterGraphPtr : public Wrapper<AVFilterGraph, AVFilterGraphDeleter> {
AVFilterGraphPtr(); explicit AVFilterGraphPtr(AVFilterGraph* p);
void reset();
}; };
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
...@@ -199,11 +200,11 @@ struct AVCodecParametersDeleter { ...@@ -199,11 +200,11 @@ struct AVCodecParametersDeleter {
struct AVCodecParametersPtr struct AVCodecParametersPtr
: public Wrapper<AVCodecParameters, AVCodecParametersDeleter> { : public Wrapper<AVCodecParameters, AVCodecParametersDeleter> {
AVCodecParametersPtr(); explicit AVCodecParametersPtr(AVCodecParameters* p);
}; };
struct StreamParams { struct StreamParams {
AVCodecParametersPtr codec_params; AVCodecParametersPtr codec_params{nullptr};
AVRational time_base{}; AVRational time_base{};
int stream_index{}; int stream_index{};
}; };
......
...@@ -4,7 +4,17 @@ ...@@ -4,7 +4,17 @@
namespace torchaudio { namespace torchaudio {
namespace io { namespace io {
FilterGraph::FilterGraph(AVMediaType media_type) : media_type(media_type) { namespace {
AVFilterGraph* get_filter_graph() {
AVFilterGraph* ptr = avfilter_graph_alloc();
TORCH_CHECK(ptr, "Failed to allocate resouce.");
ptr->nb_threads = 1;
return ptr;
}
} // namespace
FilterGraph::FilterGraph(AVMediaType media_type)
: media_type(media_type), pFilterGraph(get_filter_graph()) {
switch (media_type) { switch (media_type) {
case AVMEDIA_TYPE_AUDIO: case AVMEDIA_TYPE_AUDIO:
case AVMEDIA_TYPE_VIDEO: case AVMEDIA_TYPE_VIDEO:
...@@ -12,8 +22,6 @@ FilterGraph::FilterGraph(AVMediaType media_type) : media_type(media_type) { ...@@ -12,8 +22,6 @@ FilterGraph::FilterGraph(AVMediaType media_type) : media_type(media_type) {
default: default:
TORCH_CHECK(false, "Only audio and video type is supported."); TORCH_CHECK(false, "Only audio and video type is supported.");
} }
pFilterGraph->nb_threads = 1;
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
......
...@@ -4,9 +4,9 @@ namespace torchaudio { ...@@ -4,9 +4,9 @@ namespace torchaudio {
namespace io { namespace io {
void PacketBuffer::push_packet(AVPacket* packet) { void PacketBuffer::push_packet(AVPacket* packet) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(packet, "Packet is null."); TORCH_INTERNAL_ASSERT_DEBUG_ONLY(packet, "Packet is null.");
AVPacketPtr pPacket; AVPacket* p = av_packet_clone(packet);
av_packet_ref(pPacket, packet); TORCH_INTERNAL_ASSERT(p, "Failed to clone packet.");
packets.push_back(std::move(pPacket)); packets.emplace_back(p);
} }
std::vector<AVPacketPtr> PacketBuffer::pop_packets() { std::vector<AVPacketPtr> PacketBuffer::pop_packets() {
std::vector<AVPacketPtr> ret{ std::vector<AVPacketPtr> ret{
......
...@@ -95,7 +95,7 @@ struct FilterGraphWrapper { ...@@ -95,7 +95,7 @@ struct FilterGraphWrapper {
template <typename Converter, typename Buffer> template <typename Converter, typename Buffer>
struct ProcessImpl : public IPostDecodeProcess { struct ProcessImpl : public IPostDecodeProcess {
private: private:
AVFramePtr frame{}; AVFramePtr frame{alloc_avframe()};
FilterGraphWrapper filter_wrapper; FilterGraphWrapper filter_wrapper;
public: public:
......
...@@ -19,7 +19,7 @@ class StreamProcessor { ...@@ -19,7 +19,7 @@ class StreamProcessor {
// Components for decoding source media // Components for decoding source media
AVCodecContextPtr codec_ctx{nullptr}; AVCodecContextPtr codec_ctx{nullptr};
AVFramePtr frame; AVFramePtr frame{alloc_avframe()};
KeyType current_key = 0; KeyType current_key = 0;
std::map<KeyType, std::unique_ptr<IPostDecodeProcess>> post_processes; std::map<KeyType, std::unique_ptr<IPostDecodeProcess>> post_processes;
......
...@@ -179,19 +179,26 @@ SrcStreamInfo StreamReader::get_src_stream_info(int i) const { ...@@ -179,19 +179,26 @@ SrcStreamInfo StreamReader::get_src_stream_info(int i) const {
return ret; return ret;
} }
namespace {
AVCodecParameters* get_codecpar() {
AVCodecParameters* ptr = avcodec_parameters_alloc();
TORCH_CHECK(ptr, "Failed to allocate resource.");
return ptr;
}
} // namespace
StreamParams StreamReader::get_src_stream_params(int i) { StreamParams StreamReader::get_src_stream_params(int i) {
StreamParams params;
validate_src_stream_index(pFormatContext, i); validate_src_stream_index(pFormatContext, i);
AVStream* stream = pFormatContext->streams[i]; AVStream* stream = pFormatContext->streams[i];
int ret = avcodec_parameters_copy(params.codec_params, stream->codecpar);
AVCodecParametersPtr codec_params(get_codecpar());
int ret = avcodec_parameters_copy(codec_params, stream->codecpar);
TORCH_CHECK( TORCH_CHECK(
ret >= 0, ret >= 0,
"Failed to copy the stream's codec parameters. (", "Failed to copy the stream's codec parameters. (",
av_err2string(ret), av_err2string(ret),
")"); ")");
params.time_base = stream->time_base; return {std::move(codec_params), stream->time_base, i};
params.stream_index = i;
return params;
} }
int64_t StreamReader::num_out_streams() const { int64_t StreamReader::num_out_streams() const {
......
...@@ -13,7 +13,7 @@ namespace io { ...@@ -13,7 +13,7 @@ namespace io {
/// ///
class StreamReader { class StreamReader {
AVFormatInputContextPtr pFormatContext; AVFormatInputContextPtr pFormatContext;
AVPacketPtr pPacket; AVPacketPtr pPacket{alloc_avpacket()};
std::vector<std::unique_ptr<StreamProcessor>> processors; std::vector<std::unique_ptr<StreamProcessor>> processors;
// Mapping from user-facing stream index to internal index. // Mapping from user-facing stream index to internal index.
......
...@@ -676,7 +676,7 @@ AVFramePtr get_audio_frame( ...@@ -676,7 +676,7 @@ AVFramePtr get_audio_frame(
int num_channels, int num_channels,
uint64_t channel_layout, uint64_t channel_layout,
int nb_samples) { int nb_samples) {
AVFramePtr frame{}; AVFramePtr frame{alloc_avframe()};
frame->format = format; frame->format = format;
frame->channel_layout = channel_layout; frame->channel_layout = channel_layout;
frame->sample_rate = sample_rate; frame->sample_rate = sample_rate;
...@@ -693,7 +693,7 @@ AVFramePtr get_audio_frame( ...@@ -693,7 +693,7 @@ AVFramePtr get_audio_frame(
} }
AVFramePtr get_video_frame(AVPixelFormat src_fmt, int width, int height) { AVFramePtr get_video_frame(AVPixelFormat src_fmt, int width, int height) {
AVFramePtr frame{}; AVFramePtr frame{alloc_avframe()};
frame->format = src_fmt; frame->format = src_fmt;
frame->width = width; frame->width = width;
frame->height = height; frame->height = height;
...@@ -921,7 +921,7 @@ EncodeProcess get_video_encode_process( ...@@ -921,7 +921,7 @@ EncodeProcess get_video_encode_process(
// 6. Instantiate source frame // 6. Instantiate source frame
AVFramePtr src_frame = [&]() { AVFramePtr src_frame = [&]() {
if (codec_ctx->hw_frames_ctx) { if (codec_ctx->hw_frames_ctx) {
AVFramePtr frame{}; AVFramePtr frame{alloc_avframe()};
int ret = av_hwframe_get_buffer(codec_ctx->hw_frames_ctx, frame, 0); int ret = av_hwframe_get_buffer(codec_ctx->hw_frames_ctx, frame, 0);
TORCH_CHECK(ret >= 0, "Failed to fetch CUDA frame: ", av_err2string(ret)); TORCH_CHECK(ret >= 0, "Failed to fetch CUDA frame: ", av_err2string(ret));
return frame; return frame;
......
...@@ -12,7 +12,7 @@ class EncodeProcess { ...@@ -12,7 +12,7 @@ class EncodeProcess {
TensorConverter converter; TensorConverter converter;
AVFramePtr src_frame; AVFramePtr src_frame;
FilterGraph filter; FilterGraph filter;
AVFramePtr dst_frame{}; AVFramePtr dst_frame{alloc_avframe()};
Encoder encoder; Encoder encoder;
AVCodecContextPtr codec_ctx; AVCodecContextPtr codec_ctx;
......
...@@ -16,7 +16,7 @@ class Encoder { ...@@ -16,7 +16,7 @@ class Encoder {
AVStream* stream; AVStream* stream;
// Temporary object used during the encoding // Temporary object used during the encoding
// Encoder owns it. // Encoder owns it.
AVPacketPtr packet{}; AVPacketPtr packet{alloc_avpacket()};
public: public:
Encoder( Encoder(
......
...@@ -15,11 +15,10 @@ namespace io { ...@@ -15,11 +15,10 @@ namespace io {
/// ///
class StreamWriter { class StreamWriter {
AVFormatOutputContextPtr pFormatContext; AVFormatOutputContextPtr pFormatContext;
AVBufferRefPtr pHWBufferRef;
std::map<int, EncodeProcess> processes; std::map<int, EncodeProcess> processes;
std::map<int, PacketWriter> packet_writers; std::map<int, PacketWriter> packet_writers;
AVPacketPtr pkt; AVPacketPtr pkt{alloc_avpacket()};
bool is_open = false; bool is_open = false;
int current_key = 0; int current_key = 0;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment