".github/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "25e96f4246115a3deeec0afbd4ac52c47c0fa934"
Commit e259f156 authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Refactor buffer common utils (#2988)

Summary:
Split `convert_video` into memory allocation function and write function.

Also put all the buffer implementations into detail namespace.

Pull Request resolved: https://github.com/pytorch/audio/pull/2988

Reviewed By: xiaohui-zhang

Differential Revision: D42536769

Pulled By: mthrok

fbshipit-source-id: 36fbf437d4bfd521322846161ae08a48c782c540
parent f9d38796
......@@ -3,6 +3,7 @@
namespace torchaudio {
namespace ffmpeg {
namespace detail {
ChunkedBuffer::ChunkedBuffer(int frames_per_chunk, int num_chunks)
: frames_per_chunk(frames_per_chunk), num_chunks(num_chunks) {}
......@@ -120,16 +121,17 @@ c10::optional<torch::Tensor> ChunkedBuffer::pop_chunk() {
}
void ChunkedAudioBuffer::push_frame(AVFrame* frame) {
push_tensor(detail::convert_audio(frame));
push_tensor(convert_audio(frame));
}
void ChunkedVideoBuffer::push_frame(AVFrame* frame) {
push_tensor(detail::convert_image(frame, device));
push_tensor(convert_image(frame, device));
}
void ChunkedBuffer::flush() {
chunks.clear();
}
} // namespace detail
} // namespace ffmpeg
} // namespace torchaudio
......@@ -4,6 +4,7 @@
namespace torchaudio {
namespace ffmpeg {
namespace detail {
//////////////////////////////////////////////////////////////////////////////
// Chunked Buffer Implementation
......@@ -53,5 +54,6 @@ class ChunkedVideoBuffer : public ChunkedBuffer {
void push_frame(AVFrame* frame) override;
};
} // namespace detail
} // namespace ffmpeg
} // namespace torchaudio
......@@ -95,9 +95,11 @@ torch::Tensor get_buffer(
.device(device.type(), device.index());
return torch::empty(shape, options);
}
} // namespace
torch::Tensor get_image_buffer(AVFrame* frame, const torch::Device& device) {
std::tuple<torch::Tensor, bool> get_image_buffer(
AVFrame* frame,
int num_frames,
const torch::Device& device) {
auto fmt = static_cast<AVPixelFormat>(frame->format);
const AVPixFmtDescriptor* desc = [&]() {
if (fmt == AV_PIX_FMT_CUDA) {
......@@ -124,13 +126,16 @@ torch::Tensor get_image_buffer(AVFrame* frame, const torch::Device& device) {
// The actual number of planes can be retrieved with
// av_pix_fmt_count_planes.
int height = frame->height;
int width = frame->width;
if (desc->flags & AV_PIX_FMT_FLAG_PLANAR) {
return get_buffer({1, channels, frame->height, frame->width}, device);
auto buffer = get_buffer({num_frames, channels, height, width}, device);
return std::make_tuple(buffer, true);
}
return get_buffer({1, frame->height, frame->width, channels}, device);
auto buffer = get_buffer({num_frames, height, width, channels}, device);
return std::make_tuple(buffer, false);
}
namespace {
void write_interlaced_image(AVFrame* pFrame, torch::Tensor& frame) {
auto ptr = frame.data_ptr<uint8_t>();
uint8_t* buf = pFrame->data[0];
......@@ -304,14 +309,11 @@ void write_nv12_cuda(AVFrame* pFrame, torch::Tensor& yuv) {
}
#endif
} // namespace
torch::Tensor convert_image(AVFrame* frame, const torch::Device& device) {
void write_image(AVFrame* frame, torch::Tensor& buf) {
// ref:
// https://ffmpeg.org/doxygen/4.1/filtering__video_8c_source.html#l00179
// https://ffmpeg.org/doxygen/4.1/decode__video_8c_source.html#l00038
AVPixelFormat format = static_cast<AVPixelFormat>(frame->format);
torch::Tensor buf = get_image_buffer(frame, device);
switch (format) {
case AV_PIX_FMT_RGB24:
case AV_PIX_FMT_BGR24:
......@@ -321,19 +323,19 @@ torch::Tensor convert_image(AVFrame* frame, const torch::Device& device) {
case AV_PIX_FMT_BGRA:
case AV_PIX_FMT_GRAY8: {
write_interlaced_image(frame, buf);
return buf.permute({0, 3, 1, 2});
return;
}
case AV_PIX_FMT_YUV444P: {
write_planar_image(frame, buf);
return buf;
return;
}
case AV_PIX_FMT_YUV420P: {
write_yuv420p(frame, buf);
return buf;
return;
}
case AV_PIX_FMT_NV12: {
write_nv12_cpu(frame, buf);
return buf;
return;
}
#ifdef USE_CUDA
case AV_PIX_FMT_CUDA: {
......@@ -345,7 +347,7 @@ torch::Tensor convert_image(AVFrame* frame, const torch::Device& device) {
switch (sw_format) {
case AV_PIX_FMT_NV12: {
write_nv12_cuda(frame, buf);
return buf;
return;
}
case AV_PIX_FMT_P010:
case AV_PIX_FMT_P016:
......@@ -369,6 +371,14 @@ torch::Tensor convert_image(AVFrame* frame, const torch::Device& device) {
}
}
} // namespace
torch::Tensor convert_image(AVFrame* frame, const torch::Device& device) {
auto [buffer, is_planar] = get_image_buffer(frame, 1, device);
write_image(frame, buffer);
return is_planar ? buffer : buffer.permute({0, 3, 1, 2});
}
} // namespace detail
} // namespace ffmpeg
} // namespace torchaudio
......@@ -11,7 +11,6 @@ namespace detail {
//////////////////////////////////////////////////////////////////////////////
torch::Tensor convert_audio(AVFrame* frame);
torch::Tensor get_image_buffer(AVFrame* pFrame, const torch::Device& device);
torch::Tensor convert_image(AVFrame* frame, const torch::Device& device);
} // namespace detail
......
......@@ -3,6 +3,7 @@
namespace torchaudio {
namespace ffmpeg {
namespace detail {
UnchunkedVideoBuffer::UnchunkedVideoBuffer(const torch::Device& device)
: device(device) {}
......@@ -16,11 +17,11 @@ void UnchunkedBuffer::push_tensor(const torch::Tensor& t) {
}
void UnchunkedAudioBuffer::push_frame(AVFrame* frame) {
push_tensor(detail::convert_audio(frame));
push_tensor(convert_audio(frame));
}
void UnchunkedVideoBuffer::push_frame(AVFrame* frame) {
push_tensor(detail::convert_image(frame, device));
push_tensor(convert_image(frame, device));
}
c10::optional<torch::Tensor> UnchunkedBuffer::pop_chunk() {
......@@ -38,5 +39,6 @@ void UnchunkedBuffer::flush() {
chunks.clear();
}
} // namespace detail
} // namespace ffmpeg
} // namespace torchaudio
......@@ -6,6 +6,7 @@
namespace torchaudio {
namespace ffmpeg {
namespace detail {
//////////////////////////////////////////////////////////////////////////////
// Unchunked Buffer Interface
......@@ -39,5 +40,6 @@ class UnchunkedVideoBuffer : public UnchunkedBuffer {
void push_frame(AVFrame* frame) override;
};
} // namespace detail
} // namespace ffmpeg
} // namespace torchaudio
......@@ -32,16 +32,16 @@ std::unique_ptr<Buffer> get_buffer(
if (frames_per_chunk > 0) {
if (type == AVMEDIA_TYPE_AUDIO) {
return std::unique_ptr<Buffer>(
new ChunkedAudioBuffer(frames_per_chunk, num_chunks));
new detail::ChunkedAudioBuffer(frames_per_chunk, num_chunks));
} else {
return std::unique_ptr<Buffer>(
new ChunkedVideoBuffer(frames_per_chunk, num_chunks, device));
new detail::ChunkedVideoBuffer(frames_per_chunk, num_chunks, device));
}
} else { // unchunked mode
if (type == AVMEDIA_TYPE_AUDIO) {
return std::unique_ptr<Buffer>(new UnchunkedAudioBuffer());
return std::unique_ptr<Buffer>(new detail::UnchunkedAudioBuffer());
} else {
return std::unique_ptr<Buffer>(new UnchunkedVideoBuffer(device));
return std::unique_ptr<Buffer>(new detail::UnchunkedVideoBuffer(device));
}
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment