Commit e259f156 authored by moto, committed by Facebook GitHub Bot

Refactor buffer common utils (#2988)

Summary:
Split `convert_video` into a memory allocation function and a write function.

Also put all the buffer implementations into the `detail` namespace.
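
The diff below splits the video path into `get_image_buffer`, which allocates a tensor for `num_frames` frames and reports whether the pixel layout is planar, and `write_image`, which fills a pre-allocated tensor from a decoded `AVFrame`; `convert_image` then becomes a thin composition of the two. A minimal sketch of how the split could be used to fill a multi-frame chunk with a single allocation (the `convert_frames` helper and the `frames` vector are illustrative only, not part of this change):

// Sketch only: assumes it lives in the same translation unit as the
// file-local get_image_buffer() / write_image() introduced in this diff.
// convert_frames and `frames` are hypothetical names for illustration.
torch::Tensor convert_frames(
    const std::vector<AVFrame*>& frames,
    const torch::Device& device) {
  const int num_frames = static_cast<int>(frames.size());
  // Allocate the whole chunk once: NCHW for planar formats, NHWC otherwise.
  auto [buffer, is_planar] = get_image_buffer(frames[0], num_frames, device);
  // Write each decoded frame into its own slice of the pre-allocated chunk.
  for (int i = 0; i < num_frames; ++i) {
    torch::Tensor slot = buffer.slice(/*dim=*/0, i, i + 1);
    write_image(frames[i], slot);
  }
  // Mirror convert_image: interleaved (NHWC) buffers are permuted to NCHW.
  return is_planar ? buffer : buffer.permute({0, 3, 1, 2});
}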

Pull Request resolved: https://github.com/pytorch/audio/pull/2988

Reviewed By: xiaohui-zhang

Differential Revision: D42536769

Pulled By: mthrok

fbshipit-source-id: 36fbf437d4bfd521322846161ae08a48c782c540
parent f9d38796
@@ -3,6 +3,7 @@
 namespace torchaudio {
 namespace ffmpeg {
+namespace detail {
 ChunkedBuffer::ChunkedBuffer(int frames_per_chunk, int num_chunks)
     : frames_per_chunk(frames_per_chunk), num_chunks(num_chunks) {}
@@ -120,16 +121,17 @@ c10::optional<torch::Tensor> ChunkedBuffer::pop_chunk() {
 }
 void ChunkedAudioBuffer::push_frame(AVFrame* frame) {
-  push_tensor(detail::convert_audio(frame));
+  push_tensor(convert_audio(frame));
 }
 void ChunkedVideoBuffer::push_frame(AVFrame* frame) {
-  push_tensor(detail::convert_image(frame, device));
+  push_tensor(convert_image(frame, device));
 }
 void ChunkedBuffer::flush() {
   chunks.clear();
 }
+} // namespace detail
 } // namespace ffmpeg
 } // namespace torchaudio
@@ -4,6 +4,7 @@
 namespace torchaudio {
 namespace ffmpeg {
+namespace detail {
 //////////////////////////////////////////////////////////////////////////////
 // Chunked Buffer Implementation
@@ -53,5 +54,6 @@ class ChunkedVideoBuffer : public ChunkedBuffer {
   void push_frame(AVFrame* frame) override;
 };
+} // namespace detail
 } // namespace ffmpeg
 } // namespace torchaudio
@@ -95,9 +95,11 @@ torch::Tensor get_buffer(
       .device(device.type(), device.index());
   return torch::empty(shape, options);
 }
-} // namespace
-torch::Tensor get_image_buffer(AVFrame* frame, const torch::Device& device) {
+std::tuple<torch::Tensor, bool> get_image_buffer(
+    AVFrame* frame,
+    int num_frames,
+    const torch::Device& device) {
   auto fmt = static_cast<AVPixelFormat>(frame->format);
   const AVPixFmtDescriptor* desc = [&]() {
     if (fmt == AV_PIX_FMT_CUDA) {
@@ -124,13 +126,16 @@ torch::Tensor get_image_buffer(AVFrame* frame, const torch::Device& device) {
   // The actual number of planes can be retrieved with
   // av_pix_fmt_count_planes.
+  int height = frame->height;
+  int width = frame->width;
   if (desc->flags & AV_PIX_FMT_FLAG_PLANAR) {
-    return get_buffer({1, channels, frame->height, frame->width}, device);
+    auto buffer = get_buffer({num_frames, channels, height, width}, device);
+    return std::make_tuple(buffer, true);
   }
-  return get_buffer({1, frame->height, frame->width, channels}, device);
+  auto buffer = get_buffer({num_frames, height, width, channels}, device);
+  return std::make_tuple(buffer, false);
 }
-namespace {
 void write_interlaced_image(AVFrame* pFrame, torch::Tensor& frame) {
   auto ptr = frame.data_ptr<uint8_t>();
   uint8_t* buf = pFrame->data[0];
@@ -304,14 +309,11 @@ void write_nv12_cuda(AVFrame* pFrame, torch::Tensor& yuv) {
 }
 #endif
-} // namespace
-torch::Tensor convert_image(AVFrame* frame, const torch::Device& device) {
+void write_image(AVFrame* frame, torch::Tensor& buf) {
   // ref:
   // https://ffmpeg.org/doxygen/4.1/filtering__video_8c_source.html#l00179
   // https://ffmpeg.org/doxygen/4.1/decode__video_8c_source.html#l00038
   AVPixelFormat format = static_cast<AVPixelFormat>(frame->format);
-  torch::Tensor buf = get_image_buffer(frame, device);
   switch (format) {
     case AV_PIX_FMT_RGB24:
     case AV_PIX_FMT_BGR24:
@@ -321,19 +323,19 @@ torch::Tensor convert_image(AVFrame* frame, const torch::Device& device) {
     case AV_PIX_FMT_BGRA:
     case AV_PIX_FMT_GRAY8: {
       write_interlaced_image(frame, buf);
-      return buf.permute({0, 3, 1, 2});
+      return;
     }
     case AV_PIX_FMT_YUV444P: {
       write_planar_image(frame, buf);
-      return buf;
+      return;
     }
     case AV_PIX_FMT_YUV420P: {
       write_yuv420p(frame, buf);
-      return buf;
+      return;
     }
     case AV_PIX_FMT_NV12: {
       write_nv12_cpu(frame, buf);
-      return buf;
+      return;
     }
 #ifdef USE_CUDA
     case AV_PIX_FMT_CUDA: {
@@ -345,7 +347,7 @@ torch::Tensor convert_image(AVFrame* frame, const torch::Device& device) {
       switch (sw_format) {
         case AV_PIX_FMT_NV12: {
           write_nv12_cuda(frame, buf);
-          return buf;
+          return;
         }
         case AV_PIX_FMT_P010:
         case AV_PIX_FMT_P016:
@@ -369,6 +371,14 @@ torch::Tensor convert_image(AVFrame* frame, const torch::Device& device) {
   }
 }
+} // namespace
+torch::Tensor convert_image(AVFrame* frame, const torch::Device& device) {
+  auto [buffer, is_planar] = get_image_buffer(frame, 1, device);
+  write_image(frame, buffer);
+  return is_planar ? buffer : buffer.permute({0, 3, 1, 2});
+}
 } // namespace detail
 } // namespace ffmpeg
 } // namespace torchaudio
@@ -11,7 +11,6 @@ namespace detail {
 //////////////////////////////////////////////////////////////////////////////
 torch::Tensor convert_audio(AVFrame* frame);
-torch::Tensor get_image_buffer(AVFrame* pFrame, const torch::Device& device);
 torch::Tensor convert_image(AVFrame* frame, const torch::Device& device);
 } // namespace detail
@@ -3,6 +3,7 @@
 namespace torchaudio {
 namespace ffmpeg {
+namespace detail {
 UnchunkedVideoBuffer::UnchunkedVideoBuffer(const torch::Device& device)
     : device(device) {}
@@ -16,11 +17,11 @@ void UnchunkedBuffer::push_tensor(const torch::Tensor& t) {
 }
 void UnchunkedAudioBuffer::push_frame(AVFrame* frame) {
-  push_tensor(detail::convert_audio(frame));
+  push_tensor(convert_audio(frame));
 }
 void UnchunkedVideoBuffer::push_frame(AVFrame* frame) {
-  push_tensor(detail::convert_image(frame, device));
+  push_tensor(convert_image(frame, device));
 }
 c10::optional<torch::Tensor> UnchunkedBuffer::pop_chunk() {
@@ -38,5 +39,6 @@ void UnchunkedBuffer::flush() {
   chunks.clear();
 }
+} // namespace detail
 } // namespace ffmpeg
 } // namespace torchaudio
@@ -6,6 +6,7 @@
 namespace torchaudio {
 namespace ffmpeg {
+namespace detail {
 //////////////////////////////////////////////////////////////////////////////
 // Unchunked Buffer Interface
@@ -39,5 +40,6 @@ class UnchunkedVideoBuffer : public UnchunkedBuffer {
   void push_frame(AVFrame* frame) override;
 };
+} // namespace detail
 } // namespace ffmpeg
 } // namespace torchaudio
@@ -32,16 +32,16 @@ std::unique_ptr<Buffer> get_buffer(
   if (frames_per_chunk > 0) {
     if (type == AVMEDIA_TYPE_AUDIO) {
       return std::unique_ptr<Buffer>(
-          new ChunkedAudioBuffer(frames_per_chunk, num_chunks));
+          new detail::ChunkedAudioBuffer(frames_per_chunk, num_chunks));
     } else {
       return std::unique_ptr<Buffer>(
-          new ChunkedVideoBuffer(frames_per_chunk, num_chunks, device));
+          new detail::ChunkedVideoBuffer(frames_per_chunk, num_chunks, device));
     }
   } else { // unchunked mode
     if (type == AVMEDIA_TYPE_AUDIO) {
-      return std::unique_ptr<Buffer>(new UnchunkedAudioBuffer());
+      return std::unique_ptr<Buffer>(new detail::UnchunkedAudioBuffer());
     } else {
-      return std::unique_ptr<Buffer>(new UnchunkedVideoBuffer(device));
+      return std::unique_ptr<Buffer>(new detail::UnchunkedVideoBuffer(device));
     }
   }
 }