Commit 2381beec authored by Moto Hira's avatar Moto Hira Committed by Facebook GitHub Bot
Browse files

Decouple image conversion and OutputStream class (#3113)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3113

Decouple the Tensor to AVFrame conversion process from encoding process.

Reviewed By: nateanl

Differential Revision: D43628942

fbshipit-source-id: e698f3150292567dbc23e7d6795ad58265f24780
parent fd24af00
......@@ -117,70 +117,59 @@ void validate_video_input(
t.sizes());
}
#ifdef USE_CUDA
void write_interlaced_video_cuda(
VideoOutputStream& os,
const torch::Tensor& frames,
const torch::Tensor& chunk,
AVFrame* buffer,
bool pad_extra) {
const auto num_frames = frames.size(0);
const auto num_channels = frames.size(1);
const auto height = frames.size(2);
const auto width = frames.size(3);
const auto num_channels_buffer = num_channels + (pad_extra ? 1 : 0);
using namespace torch::indexing;
torch::Tensor buffer =
torch::empty({height, width, num_channels_buffer}, frames.options());
size_t spitch = width * num_channels_buffer;
for (int i = 0; i < num_frames; ++i) {
// Slice frame as HWC
auto chunk = frames.index({i}).permute({1, 2, 0});
buffer.index_put_({"...", Slice(0, num_channels)}, chunk);
#ifdef USE_CUDA
const auto height = chunk.size(0);
const auto width = chunk.size(1);
const auto num_channels = chunk.size(2) + (pad_extra ? 1 : 0);
size_t spitch = width * num_channels;
if (cudaSuccess !=
cudaMemcpy2D(
(void*)(buffer->data[0]),
buffer->linesize[0],
(const void*)(chunk.data_ptr<uint8_t>()),
spitch,
spitch,
height,
cudaMemcpyDeviceToDevice)) {
TORCH_CHECK(false, "Failed to copy pixel data from CUDA tensor.");
}
#else
TORCH_CHECK(
false,
"torchaudio is not compiled with CUDA support. Hardware acceleration is not available.");
#endif
}
// Copy one planar video frame (e.g. YUV444P, GBRP) from a CUDA tensor into
// the destination hardware AVFrame, one 2D device-to-device copy per plane.
//
// chunk:      uint8 CUDA tensor viewed as planes x height x width
//             (one frame, channel-first).
// buffer:     destination AVFrame; data[j]/linesize[j] describe plane j.
// num_planes: number of planes to copy, as reported by
//             av_pix_fmt_count_planes() for the pixel format.
void write_planar_video_cuda(
    const torch::Tensor& chunk,
    AVFrame* buffer,
    int num_planes) {
#ifdef USE_CUDA
  const auto height = chunk.size(1);
  const auto width = chunk.size(2);
  for (int j = 0; j < num_planes; ++j) {
    // Each plane is a tightly-packed height x width byte matrix on the
    // source side; the destination may have alignment padding (linesize).
    if (cudaSuccess !=
        cudaMemcpy2D(
            (void*)(buffer->data[j]),
            buffer->linesize[j],
            (const void*)(chunk.index({j}).data_ptr<uint8_t>()),
            width,
            width,
            height,
            cudaMemcpyDeviceToDevice)) {
      TORCH_CHECK(false, "Failed to copy pixel data from CUDA tensor.");
    }
  }
#else
  TORCH_CHECK(
      false,
      "torchaudio is not compiled with CUDA support. Hardware acceleration is not available.");
#endif
}
// Interlaced video
// Each frame is composed of one plane, and color components for each pixel are
......@@ -193,35 +182,22 @@ void write_planar_video_cuda(
// 1: RGB RGB ... RGB PAD ... PAD
// ...
// H: RGB RGB ... RGB PAD ... PAD
void write_interlaced_video(
VideoOutputStream& os,
const torch::Tensor& frames) {
const auto num_frames = frames.size(0);
const auto num_channels = frames.size(1);
const auto height = frames.size(2);
const auto width = frames.size(3);
void write_interlaced_video(const torch::Tensor& chunk, AVFrame* buffer) {
const auto height = chunk.size(0);
const auto width = chunk.size(1);
const auto num_channels = chunk.size(2);
using namespace torch::indexing;
size_t stride = width * num_channels;
for (int i = 0; i < num_frames; ++i) {
// TODO: writable
// https://ffmpeg.org/doxygen/4.1/muxing_8c_source.html#l00472
TORCH_CHECK(
av_frame_is_writable(os.src_frame),
"Internal Error: frame is not writable.");
// CHW -> HWC
auto chunk =
frames.index({i}).permute({1, 2, 0}).reshape({-1}).contiguous();
uint8_t* src = chunk.data_ptr<uint8_t>();
uint8_t* dst = os.src_frame->data[0];
for (int h = 0; h < height; ++h) {
std::memcpy(dst, src, stride);
src += width * num_channels;
dst += os.src_frame->linesize[0];
}
os.process_frame();
// TODO: writable
// https://ffmpeg.org/doxygen/4.1/muxing_8c_source.html#l00472
TORCH_INTERNAL_ASSERT(av_frame_is_writable(buffer), "frame is not writable.");
uint8_t* src = chunk.data_ptr<uint8_t>();
uint8_t* dst = buffer->data[0];
for (int h = 0; h < height; ++h) {
std::memcpy(dst, src, stride);
src += width * num_channels;
dst += buffer->linesize[0];
}
}
......@@ -247,33 +223,24 @@ void write_interlaced_video(
// H2: UV ... UV PAD ... PAD
//
void write_planar_video(
VideoOutputStream& os,
const torch::Tensor& frames,
const torch::Tensor& chunk,
AVFrame* buffer,
int num_planes) {
const auto num_frames = frames.size(0);
const auto height = frames.size(2);
const auto width = frames.size(3);
using namespace torch::indexing;
for (int i = 0; i < num_frames; ++i) {
// TODO: writable
// https://ffmpeg.org/doxygen/4.1/muxing_8c_source.html#l00472
TORCH_CHECK(
av_frame_is_writable(os.src_frame),
"Internal Error: frame is not writable.");
const auto height = chunk.size(1);
const auto width = chunk.size(2);
for (int j = 0; j < num_planes; ++j) {
auto chunk = frames.index({i, j}).contiguous();
// TODO: writable
// https://ffmpeg.org/doxygen/4.1/muxing_8c_source.html#l00472
TORCH_INTERNAL_ASSERT(av_frame_is_writable(buffer), "frame is not writable.");
uint8_t* src = chunk.data_ptr<uint8_t>();
uint8_t* dst = os.src_frame->data[j];
for (int h = 0; h < height; ++h) {
memcpy(dst, src, width);
src += width;
dst += os.src_frame->linesize[j];
}
for (int j = 0; j < num_planes; ++j) {
uint8_t* src = chunk.index({j}).data_ptr<uint8_t>();
uint8_t* dst = buffer->data[j];
for (int h = 0; h < height; ++h) {
memcpy(dst, src, width);
src += width;
dst += buffer->linesize[j];
}
os.process_frame();
}
}
......@@ -282,21 +249,33 @@ void write_planar_video(
// Encode a batch of video frames (NCHW uint8 tensor) by converting each
// frame into src_frame's pixel layout and pushing it through the encoder.
//
// Dispatches on the frame's pixel format: interlaced formats get an
// NHWC-permuted copy, planar formats a channel-first copy; on CUDA the
// copies are device-to-device into the hardware frame.
//
// NOTE(review): the source diff for this function was cut by a hunk
// boundary; the CUDA-branch default error message and the closing
// braces/#endif below are reconstructed — verify against the repository.
void VideoOutputStream::write_chunk(const torch::Tensor& frames) {
  enum AVPixelFormat fmt = static_cast<AVPixelFormat>(src_frame->format);
  validate_video_input(fmt, codec_ctx, frames);
  const auto num_frames = frames.size(0);
#ifdef USE_CUDA
  if (fmt == AV_PIX_FMT_CUDA) {
    // For hardware frames, dispatch on the underlying software format.
    fmt = codec_ctx->sw_pix_fmt;
    switch (fmt) {
      case AV_PIX_FMT_RGB0:
      case AV_PIX_FMT_BGR0: {
        auto chunks = frames.permute({0, 2, 3, 1}).contiguous(); // to NHWC
        for (int i = 0; i < num_frames; ++i) {
          // pad_extra=true: RGB0/BGR0 carry one pad byte per pixel.
          write_interlaced_video_cuda(chunks.index({i}), src_frame, true);
          process_frame();
        }
        return;
      }
      case AV_PIX_FMT_GBRP:
      case AV_PIX_FMT_GBRP16LE:
      case AV_PIX_FMT_YUV444P:
      case AV_PIX_FMT_YUV444P16LE: {
        auto chunks = frames.contiguous();
        for (int i = 0; i < num_frames; ++i) {
          write_planar_video_cuda(
              chunks.index({i}), src_frame, av_pix_fmt_count_planes(fmt));
          process_frame();
        }
        return;
      }
      default:
        TORCH_CHECK(
            false, "Unexpected pixel format: ", av_get_pix_fmt_name(fmt));
    }
  }
#endif
  switch (fmt) {
    case AV_PIX_FMT_GRAY8:
    case AV_PIX_FMT_RGB24:
    case AV_PIX_FMT_BGR24: {
      auto chunks = frames.permute({0, 2, 3, 1}).contiguous(); // to NHWC
      for (int i = 0; i < num_frames; ++i) {
        write_interlaced_video(chunks.index({i}), src_frame);
        process_frame();
      }
      return;
    }
    case AV_PIX_FMT_YUV444P: {
      auto chunks = frames.contiguous();
      for (int i = 0; i < num_frames; ++i) {
        write_planar_video(
            chunks.index({i}), src_frame, av_pix_fmt_count_planes(fmt));
        process_frame();
      }
      return;
    }
    default:
      TORCH_CHECK(false, "Unexpected pixel format: ", av_get_pix_fmt_name(fmt));
  }
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment