Commit 2381beec authored by Moto Hira, committed by Facebook GitHub Bot

Decouple image conversion and OutputStream class (#3113)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3113

Decouple the Tensor-to-AVFrame conversion process from the encoding process.

Reviewed By: nateanl

Differential Revision: D43628942

fbshipit-source-id: e698f3150292567dbc23e7d6795ad58265f24780
parent fd24af00
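As a rough illustration of what the decoupled conversion step does, here is a standalone sketch (not code from this patch; the helper name copy_hwc_to_avframe and the 320x240 size are made up for the example). It mirrors what the patched write_interlaced_video() boils down to: copying one HWC uint8 tensor into a pre-allocated RGB24 AVFrame row by row while honoring the frame's linesize padding.

// Minimal standalone sketch; assumes LibTorch and FFmpeg's libavutil.
#include <cstring>
#include <torch/torch.h>
extern "C" {
#include <libavutil/frame.h>
#include <libavutil/pixfmt.h>
}

void copy_hwc_to_avframe(const torch::Tensor& chunk, AVFrame* buffer) {
  const auto height = chunk.size(0);
  const auto width = chunk.size(1);
  const auto num_channels = chunk.size(2);
  const size_t stride = width * num_channels; // bytes per packed tensor row
  const uint8_t* src = chunk.data_ptr<uint8_t>();
  uint8_t* dst = buffer->data[0];
  for (int h = 0; h < height; ++h) {
    std::memcpy(dst, src, stride); // copy one row of pixels
    src += stride;                 // tensor rows are tightly packed
    dst += buffer->linesize[0];    // AVFrame rows may carry trailing padding
  }
}

int main() {
  AVFrame* frame = av_frame_alloc();
  frame->format = AV_PIX_FMT_RGB24;
  frame->width = 320;
  frame->height = 240;
  av_frame_get_buffer(frame, 0); // allocate the pixel buffer
  av_frame_make_writable(frame);

  // One HWC frame, as produced by frames.permute({0, 2, 3, 1}).contiguous()
  auto chunk = torch::zeros({240, 320, 3}, torch::kUInt8);
  copy_hwc_to_avframe(chunk, frame);

  av_frame_free(&frame);
  return 0;
}

The permute({0, 2, 3, 1}).contiguous() calls in write_chunk below exist precisely so that each chunk handed to the conversion helper has tightly packed HWC rows that a plain memcpy can walk.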
@@ -117,59 +117,46 @@ void validate_video_input(
       t.sizes());
 }
-#ifdef USE_CUDA
 void write_interlaced_video_cuda(
-    VideoOutputStream& os,
-    const torch::Tensor& frames,
+    const torch::Tensor& chunk,
+    AVFrame* buffer,
     bool pad_extra) {
-  const auto num_frames = frames.size(0);
-  const auto num_channels = frames.size(1);
-  const auto height = frames.size(2);
-  const auto width = frames.size(3);
-  const auto num_channels_buffer = num_channels + (pad_extra ? 1 : 0);
-  using namespace torch::indexing;
-  torch::Tensor buffer =
-      torch::empty({height, width, num_channels_buffer}, frames.options());
-  size_t spitch = width * num_channels_buffer;
-  for (int i = 0; i < num_frames; ++i) {
-    // Slice frame as HWC
-    auto chunk = frames.index({i}).permute({1, 2, 0});
-    buffer.index_put_({"...", Slice(0, num_channels)}, chunk);
+#ifdef USE_CUDA
+  const auto height = chunk.size(0);
+  const auto width = chunk.size(1);
+  const auto num_channels = chunk.size(2) + (pad_extra ? 1 : 0);
+  size_t spitch = width * num_channels;
   if (cudaSuccess !=
       cudaMemcpy2D(
-          (void*)(os.src_frame->data[0]),
-          os.src_frame->linesize[0],
-          (const void*)(buffer.data_ptr<uint8_t>()),
+          (void*)(buffer->data[0]),
+          buffer->linesize[0],
+          (const void*)(chunk.data_ptr<uint8_t>()),
           spitch,
           spitch,
           height,
           cudaMemcpyDeviceToDevice)) {
     TORCH_CHECK(false, "Failed to copy pixel data from CUDA tensor.");
   }
-    os.process_frame();
-  }
+#else
+  TORCH_CHECK(
+      false,
+      "torchaudio is not compiled with CUDA support. Hardware acceleration is not available.");
+#endif
 }
 void write_planar_video_cuda(
-    VideoOutputStream& os,
-    const torch::Tensor& frames,
+    const torch::Tensor& chunk,
+    AVFrame* buffer,
     int num_planes) {
-  const auto num_frames = frames.size(0);
-  const auto height = frames.size(2);
-  const auto width = frames.size(3);
-  using namespace torch::indexing;
-  torch::Tensor buffer = torch::empty({height, width}, frames.options());
-  for (int i = 0; i < num_frames; ++i) {
+#ifdef USE_CUDA
+  const auto height = chunk.size(1);
+  const auto width = chunk.size(2);
   for (int j = 0; j < num_planes; ++j) {
-      buffer.index_put_({"..."}, frames.index({i, j}));
     if (cudaSuccess !=
         cudaMemcpy2D(
-            (void*)(os.src_frame->data[j]),
-            os.src_frame->linesize[j],
-            (const void*)(buffer.data_ptr<uint8_t>()),
+            (void*)(buffer->data[j]),
+            buffer->linesize[j],
+            (const void*)(chunk.index({j}).data_ptr<uint8_t>()),
             width,
            width,
            height,
@@ -177,10 +164,12 @@ void write_planar_video_cuda(
       TORCH_CHECK(false, "Failed to copy pixel data from CUDA tensor.");
     }
   }
-    os.process_frame();
-  }
-}
+#else
+  TORCH_CHECK(
+      false,
+      "torchaudio is not compiled with CUDA support. Hardware acceleration is not available.");
 #endif
+}
 // Interlaced video
 // Each frame is composed of one plane, and color components for each pixel are
@@ -193,35 +182,22 @@ void write_planar_video_cuda(
 //    1: RGB RGB ... RGB PAD ... PAD
 //    ...
 //    H: RGB RGB ... RGB PAD ... PAD
-void write_interlaced_video(
-    VideoOutputStream& os,
-    const torch::Tensor& frames) {
-  const auto num_frames = frames.size(0);
-  const auto num_channels = frames.size(1);
-  const auto height = frames.size(2);
-  const auto width = frames.size(3);
-  using namespace torch::indexing;
+void write_interlaced_video(const torch::Tensor& chunk, AVFrame* buffer) {
+  const auto height = chunk.size(0);
+  const auto width = chunk.size(1);
+  const auto num_channels = chunk.size(2);
   size_t stride = width * num_channels;
-  for (int i = 0; i < num_frames; ++i) {
   // TODO: writable
   // https://ffmpeg.org/doxygen/4.1/muxing_8c_source.html#l00472
-    TORCH_CHECK(
-        av_frame_is_writable(os.src_frame),
-        "Internal Error: frame is not writable.");
-    // CHW -> HWC
-    auto chunk =
-        frames.index({i}).permute({1, 2, 0}).reshape({-1}).contiguous();
+  TORCH_INTERNAL_ASSERT(av_frame_is_writable(buffer), "frame is not writable.");
   uint8_t* src = chunk.data_ptr<uint8_t>();
-    uint8_t* dst = os.src_frame->data[0];
+  uint8_t* dst = buffer->data[0];
   for (int h = 0; h < height; ++h) {
     std::memcpy(dst, src, stride);
     src += width * num_channels;
-      dst += os.src_frame->linesize[0];
-    }
-    os.process_frame();
+    dst += buffer->linesize[0];
   }
 }
@@ -247,33 +223,24 @@ void write_interlaced_video(
 //   H2: UV ... UV PAD ... PAD
 //
 void write_planar_video(
-    VideoOutputStream& os,
-    const torch::Tensor& frames,
+    const torch::Tensor& chunk,
+    AVFrame* buffer,
     int num_planes) {
-  const auto num_frames = frames.size(0);
-  const auto height = frames.size(2);
-  const auto width = frames.size(3);
-  using namespace torch::indexing;
-  for (int i = 0; i < num_frames; ++i) {
+  const auto height = chunk.size(1);
+  const auto width = chunk.size(2);
   // TODO: writable
   // https://ffmpeg.org/doxygen/4.1/muxing_8c_source.html#l00472
-    TORCH_CHECK(
-        av_frame_is_writable(os.src_frame),
-        "Internal Error: frame is not writable.");
+  TORCH_INTERNAL_ASSERT(av_frame_is_writable(buffer), "frame is not writable.");
   for (int j = 0; j < num_planes; ++j) {
-      auto chunk = frames.index({i, j}).contiguous();
-      uint8_t* src = chunk.data_ptr<uint8_t>();
-      uint8_t* dst = os.src_frame->data[j];
+    uint8_t* src = chunk.index({j}).data_ptr<uint8_t>();
+    uint8_t* dst = buffer->data[j];
     for (int h = 0; h < height; ++h) {
       memcpy(dst, src, width);
       src += width;
-        dst += os.src_frame->linesize[j];
-      }
+      dst += buffer->linesize[j];
     }
-    os.process_frame();
   }
 }
@@ -282,21 +249,33 @@ void write_planar_video(
 void VideoOutputStream::write_chunk(const torch::Tensor& frames) {
   enum AVPixelFormat fmt = static_cast<AVPixelFormat>(src_frame->format);
   validate_video_input(fmt, codec_ctx, frames);
+  const auto num_frames = frames.size(0);
 #ifdef USE_CUDA
   if (fmt == AV_PIX_FMT_CUDA) {
     fmt = codec_ctx->sw_pix_fmt;
     switch (fmt) {
       case AV_PIX_FMT_RGB0:
-      case AV_PIX_FMT_BGR0:
-        write_interlaced_video_cuda(*this, frames, true);
+      case AV_PIX_FMT_BGR0: {
+        auto chunks = frames.permute({0, 2, 3, 1}).contiguous(); // to NHWC
+        for (int i = 0; i < num_frames; ++i) {
+          write_interlaced_video_cuda(chunks.index({i}), src_frame, true);
+          process_frame();
+        }
         return;
+      }
       case AV_PIX_FMT_GBRP:
       case AV_PIX_FMT_GBRP16LE:
       case AV_PIX_FMT_YUV444P:
-      case AV_PIX_FMT_YUV444P16LE:
-        write_planar_video_cuda(*this, frames, av_pix_fmt_count_planes(fmt));
+      case AV_PIX_FMT_YUV444P16LE: {
+        auto chunks = frames.contiguous();
+        for (int i = 0; i < num_frames; ++i) {
+          write_planar_video_cuda(
+              chunks.index({i}), src_frame, av_pix_fmt_count_planes(fmt));
+          process_frame();
+        }
         return;
+      }
       default:
         TORCH_CHECK(
             false,
@@ -309,12 +288,23 @@ void VideoOutputStream::write_chunk(const torch::Tensor& frames) {
   switch (fmt) {
     case AV_PIX_FMT_GRAY8:
     case AV_PIX_FMT_RGB24:
-    case AV_PIX_FMT_BGR24:
-      write_interlaced_video(*this, frames);
+    case AV_PIX_FMT_BGR24: {
+      auto chunks = frames.permute({0, 2, 3, 1}).contiguous();
+      for (int i = 0; i < num_frames; ++i) {
+        write_interlaced_video(chunks.index({i}), src_frame);
+        process_frame();
+      }
       return;
-    case AV_PIX_FMT_YUV444P:
-      write_planar_video(*this, frames, av_pix_fmt_count_planes(fmt));
+    }
+    case AV_PIX_FMT_YUV444P: {
+      auto chunks = frames.contiguous();
+      for (int i = 0; i < num_frames; ++i) {
+        write_planar_video(
+            chunks.index({i}), src_frame, av_pix_fmt_count_planes(fmt));
+        process_frame();
+      }
       return;
+    }
     default:
       TORCH_CHECK(false, "Unexpected pixel format: ", av_get_pix_fmt_name(fmt));
   }
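For the CUDA helpers above, the same padding concern is delegated to cudaMemcpy2D. The following is a hedged sketch (the function name and the toy buffer sizes are invented for illustration) of how the patch's arguments map onto that call: dpitch is the AVFrame linesize, while the packed tensor row size is passed both as spitch and as the number of bytes to copy per row.

// Sketch only; assumes the CUDA runtime and LibTorch. Error handling is
// reduced to a TORCH_CHECK, as in the patch.
// cudaMemcpy2D(dst, dpitch, src, spitch, width_in_bytes, height, kind) copies
// `height` rows of `width_in_bytes` bytes, advancing dst by dpitch and src by
// spitch after each row.
#include <cuda_runtime.h>
#include <torch/torch.h>

void copy_hwc_rows_to_cuda_frame(
    const torch::Tensor& chunk, // one HWC uint8 frame residing on the GPU
    uint8_t* dst,               // e.g. AVFrame::data[0] of a hardware frame
    size_t dst_linesize) {      // e.g. AVFrame::linesize[0]
  const size_t width_bytes = chunk.size(1) * chunk.size(2); // packed row size
  const size_t height = chunk.size(0);
  TORCH_CHECK(
      cudaSuccess ==
          cudaMemcpy2D(
              dst,
              dst_linesize,              // destination rows may be padded
              chunk.data_ptr<uint8_t>(),
              width_bytes,               // source rows are tightly packed
              width_bytes,               // bytes to copy per row
              height,
              cudaMemcpyDeviceToDevice),
      "Failed to copy pixel data from CUDA tensor.");
}

int main() {
  // Toy buffers: a 240x320x3 uint8 tensor on the GPU and a padded destination.
  auto chunk = torch::zeros(
      {240, 320, 3}, torch::dtype(torch::kUInt8).device(torch::kCUDA));
  const size_t linesize = 1024; // padded row pitch, >= 320 * 3 bytes
  uint8_t* dst = nullptr;
  cudaMalloc((void**)&dst, linesize * 240);
  copy_hwc_rows_to_cuda_frame(chunk, dst, linesize);
  cudaFree(dst);
  return 0;
}

The patch passes spitch twice because the source tensor rows are packed, so the source row pitch and the number of bytes per row coincide; only the destination AVFrame carries per-row padding.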