Commit ea78478e authored by Moto Hira's avatar Moto Hira Committed by Facebook GitHub Bot
Browse files

Tweak managed pointer interface (#3249)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3249

- Put ptr member private so that it's more secure and subclasses won't mess with it
- Remove unused `reset` method
- Do not default construct the managed object
  - Introduce helper function for default allocation.
    (for AVFrame and AVPacket as they are allocated in both reader and writer)
  - for others, allocation logics are moved to where it is used.
- Remove unused `pHWBufferRef` attribute from `StreamWriter`.

Reviewed By: hwangjeff

Differential Revision: D44775297

fbshipit-source-id: ff6db528152cd54c1ae398191110c30b9c1e238c
parent f7c8a7d3
......@@ -73,16 +73,13 @@ void AVPacketDeleter::operator()(AVPacket* p) {
av_packet_free(&p);
};
namespace {
AVPacket* get_av_packet() {
AVPacket* pPacket = av_packet_alloc();
TORCH_CHECK(pPacket, "Failed to allocate AVPacket object.");
return pPacket;
}
} // namespace
AVPacketPtr::AVPacketPtr(AVPacket* p) : Wrapper<AVPacket, AVPacketDeleter>(p) {}
AVPacketPtr::AVPacketPtr()
: Wrapper<AVPacket, AVPacketDeleter>(get_av_packet()) {}
AVPacketPtr alloc_avpacket() {
AVPacket* p = av_packet_alloc();
TORCH_CHECK(p, "Failed to allocate AVPacket object.");
return AVPacketPtr{p};
}
////////////////////////////////////////////////////////////////////////////////
// AVPacket - buffer unref
......@@ -101,15 +98,14 @@ AutoPacketUnref::operator AVPacket*() const {
void AVFrameDeleter::operator()(AVFrame* p) {
av_frame_free(&p);
};
namespace {
AVFrame* get_av_frame() {
AVFrame* pFrame = av_frame_alloc();
TORCH_CHECK(pFrame, "Failed to allocate AVFrame object.");
return pFrame;
}
} // namespace
AVFramePtr::AVFramePtr() : Wrapper<AVFrame, AVFrameDeleter>(get_av_frame()) {}
AVFramePtr::AVFramePtr(AVFrame* p) : Wrapper<AVFrame, AVFrameDeleter>(p) {}
AVFramePtr alloc_avframe() {
AVFrame* p = av_frame_alloc();
TORCH_CHECK(p, "Failed to allocate AVFrame object.");
return AVFramePtr{p};
};
////////////////////////////////////////////////////////////////////////////////
// AVCodecContext
......@@ -131,13 +127,6 @@ void AutoBufferUnref::operator()(AVBufferRef* p) {
AVBufferRefPtr::AVBufferRefPtr(AVBufferRef* p)
: Wrapper<AVBufferRef, AutoBufferUnref>(p) {}
void AVBufferRefPtr::reset(AVBufferRef* p) {
TORCH_CHECK(
!ptr,
"InternalError: A valid AVBufferRefPtr is being reset. Please file an issue.");
ptr.reset(p);
}
////////////////////////////////////////////////////////////////////////////////
// AVFilterGraph
////////////////////////////////////////////////////////////////////////////////
......@@ -145,19 +134,8 @@ void AVFilterGraphDeleter::operator()(AVFilterGraph* p) {
avfilter_graph_free(&p);
};
namespace {
AVFilterGraph* get_filter_graph() {
AVFilterGraph* ptr = avfilter_graph_alloc();
TORCH_CHECK(ptr, "Failed to allocate resouce.");
return ptr;
}
} // namespace
AVFilterGraphPtr::AVFilterGraphPtr()
: Wrapper<AVFilterGraph, AVFilterGraphDeleter>(get_filter_graph()) {}
void AVFilterGraphPtr::reset() {
ptr.reset(get_filter_graph());
}
AVFilterGraphPtr::AVFilterGraphPtr(AVFilterGraph* p)
: Wrapper<AVFilterGraph, AVFilterGraphDeleter>(p) {}
////////////////////////////////////////////////////////////////////////////////
// AVCodecParameters
......@@ -166,15 +144,8 @@ void AVCodecParametersDeleter::operator()(AVCodecParameters* codecpar) {
avcodec_parameters_free(&codecpar);
}
namespace {
AVCodecParameters* get_codecpar() {
AVCodecParameters* ptr = avcodec_parameters_alloc();
TORCH_CHECK(ptr, "Failed to allocate resource.");
return ptr;
}
} // namespace
AVCodecParametersPtr::AVCodecParametersPtr(AVCodecParameters* p)
: Wrapper<AVCodecParameters, AVCodecParametersDeleter>(p) {}
AVCodecParametersPtr::AVCodecParametersPtr()
: Wrapper<AVCodecParameters, AVCodecParametersDeleter>(get_codecpar()) {}
} // namespace io
} // namespace torchaudio
......@@ -54,7 +54,6 @@ av_always_inline std::string av_err2string(int errnum) {
// The resource allocation will be provided by custom constructors.
template <typename T, typename Deleter>
class Wrapper {
protected:
std::unique_ptr<T, Deleter> ptr;
public:
......@@ -123,9 +122,11 @@ struct AVPacketDeleter {
};
struct AVPacketPtr : public Wrapper<AVPacket, AVPacketDeleter> {
AVPacketPtr();
explicit AVPacketPtr(AVPacket* p);
};
AVPacketPtr alloc_avpacket();
////////////////////////////////////////////////////////////////////////////////
// AVPacket - buffer unref
////////////////////////////////////////////////////////////////////////////////
......@@ -152,9 +153,11 @@ struct AVFrameDeleter {
};
struct AVFramePtr : public Wrapper<AVFrame, AVFrameDeleter> {
AVFramePtr();
explicit AVFramePtr(AVFrame* p);
};
AVFramePtr alloc_avframe();
////////////////////////////////////////////////////////////////////////////////
// AutoBufferUnrer is responsible for performing unref at the end of lifetime
// of AVBufferRefPtr.
......@@ -164,8 +167,7 @@ struct AutoBufferUnref {
};
struct AVBufferRefPtr : public Wrapper<AVBufferRef, AutoBufferUnref> {
AVBufferRefPtr(AVBufferRef* p = nullptr);
void reset(AVBufferRef* p);
explicit AVBufferRefPtr(AVBufferRef* p);
};
////////////////////////////////////////////////////////////////////////////////
......@@ -186,8 +188,7 @@ struct AVFilterGraphDeleter {
void operator()(AVFilterGraph* p);
};
struct AVFilterGraphPtr : public Wrapper<AVFilterGraph, AVFilterGraphDeleter> {
AVFilterGraphPtr();
void reset();
explicit AVFilterGraphPtr(AVFilterGraph* p);
};
////////////////////////////////////////////////////////////////////////////////
......@@ -199,11 +200,11 @@ struct AVCodecParametersDeleter {
struct AVCodecParametersPtr
: public Wrapper<AVCodecParameters, AVCodecParametersDeleter> {
AVCodecParametersPtr();
explicit AVCodecParametersPtr(AVCodecParameters* p);
};
struct StreamParams {
AVCodecParametersPtr codec_params;
AVCodecParametersPtr codec_params{nullptr};
AVRational time_base{};
int stream_index{};
};
......
......@@ -4,7 +4,17 @@
namespace torchaudio {
namespace io {
FilterGraph::FilterGraph(AVMediaType media_type) : media_type(media_type) {
namespace {
AVFilterGraph* get_filter_graph() {
AVFilterGraph* ptr = avfilter_graph_alloc();
TORCH_CHECK(ptr, "Failed to allocate resouce.");
ptr->nb_threads = 1;
return ptr;
}
} // namespace
FilterGraph::FilterGraph(AVMediaType media_type)
: media_type(media_type), pFilterGraph(get_filter_graph()) {
switch (media_type) {
case AVMEDIA_TYPE_AUDIO:
case AVMEDIA_TYPE_VIDEO:
......@@ -12,8 +22,6 @@ FilterGraph::FilterGraph(AVMediaType media_type) : media_type(media_type) {
default:
TORCH_CHECK(false, "Only audio and video type is supported.");
}
pFilterGraph->nb_threads = 1;
}
////////////////////////////////////////////////////////////////////////////////
......
......@@ -4,9 +4,9 @@ namespace torchaudio {
namespace io {
void PacketBuffer::push_packet(AVPacket* packet) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(packet, "Packet is null.");
AVPacketPtr pPacket;
av_packet_ref(pPacket, packet);
packets.push_back(std::move(pPacket));
AVPacket* p = av_packet_clone(packet);
TORCH_INTERNAL_ASSERT(p, "Failed to clone packet.");
packets.emplace_back(p);
}
std::vector<AVPacketPtr> PacketBuffer::pop_packets() {
std::vector<AVPacketPtr> ret{
......
......@@ -95,7 +95,7 @@ struct FilterGraphWrapper {
template <typename Converter, typename Buffer>
struct ProcessImpl : public IPostDecodeProcess {
private:
AVFramePtr frame{};
AVFramePtr frame{alloc_avframe()};
FilterGraphWrapper filter_wrapper;
public:
......
......@@ -19,7 +19,7 @@ class StreamProcessor {
// Components for decoding source media
AVCodecContextPtr codec_ctx{nullptr};
AVFramePtr frame;
AVFramePtr frame{alloc_avframe()};
KeyType current_key = 0;
std::map<KeyType, std::unique_ptr<IPostDecodeProcess>> post_processes;
......
......@@ -179,19 +179,26 @@ SrcStreamInfo StreamReader::get_src_stream_info(int i) const {
return ret;
}
namespace {
AVCodecParameters* get_codecpar() {
AVCodecParameters* ptr = avcodec_parameters_alloc();
TORCH_CHECK(ptr, "Failed to allocate resource.");
return ptr;
}
} // namespace
StreamParams StreamReader::get_src_stream_params(int i) {
StreamParams params;
validate_src_stream_index(pFormatContext, i);
AVStream* stream = pFormatContext->streams[i];
int ret = avcodec_parameters_copy(params.codec_params, stream->codecpar);
AVCodecParametersPtr codec_params(get_codecpar());
int ret = avcodec_parameters_copy(codec_params, stream->codecpar);
TORCH_CHECK(
ret >= 0,
"Failed to copy the stream's codec parameters. (",
av_err2string(ret),
")");
params.time_base = stream->time_base;
params.stream_index = i;
return params;
return {std::move(codec_params), stream->time_base, i};
}
int64_t StreamReader::num_out_streams() const {
......
......@@ -13,7 +13,7 @@ namespace io {
///
class StreamReader {
AVFormatInputContextPtr pFormatContext;
AVPacketPtr pPacket;
AVPacketPtr pPacket{alloc_avpacket()};
std::vector<std::unique_ptr<StreamProcessor>> processors;
// Mapping from user-facing stream index to internal index.
......
......@@ -676,7 +676,7 @@ AVFramePtr get_audio_frame(
int num_channels,
uint64_t channel_layout,
int nb_samples) {
AVFramePtr frame{};
AVFramePtr frame{alloc_avframe()};
frame->format = format;
frame->channel_layout = channel_layout;
frame->sample_rate = sample_rate;
......@@ -693,7 +693,7 @@ AVFramePtr get_audio_frame(
}
AVFramePtr get_video_frame(AVPixelFormat src_fmt, int width, int height) {
AVFramePtr frame{};
AVFramePtr frame{alloc_avframe()};
frame->format = src_fmt;
frame->width = width;
frame->height = height;
......@@ -921,7 +921,7 @@ EncodeProcess get_video_encode_process(
// 6. Instantiate source frame
AVFramePtr src_frame = [&]() {
if (codec_ctx->hw_frames_ctx) {
AVFramePtr frame{};
AVFramePtr frame{alloc_avframe()};
int ret = av_hwframe_get_buffer(codec_ctx->hw_frames_ctx, frame, 0);
TORCH_CHECK(ret >= 0, "Failed to fetch CUDA frame: ", av_err2string(ret));
return frame;
......
......@@ -12,7 +12,7 @@ class EncodeProcess {
TensorConverter converter;
AVFramePtr src_frame;
FilterGraph filter;
AVFramePtr dst_frame{};
AVFramePtr dst_frame{alloc_avframe()};
Encoder encoder;
AVCodecContextPtr codec_ctx;
......
......@@ -16,7 +16,7 @@ class Encoder {
AVStream* stream;
// Temporary object used during the encoding
// Encoder owns it.
AVPacketPtr packet{};
AVPacketPtr packet{alloc_avpacket()};
public:
Encoder(
......
......@@ -15,11 +15,10 @@ namespace io {
///
class StreamWriter {
AVFormatOutputContextPtr pFormatContext;
AVBufferRefPtr pHWBufferRef;
std::map<int, EncodeProcess> processes;
std::map<int, PacketWriter> packet_writers;
AVPacketPtr pkt;
AVPacketPtr pkt{alloc_avpacket()};
bool is_open = false;
int current_key = 0;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment