Commit be3bd1ac authored by Moto Hira's avatar Moto Hira Committed by Facebook GitHub Bot
Browse files

Reduce code duplication in VideoOutputStream (#3108)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3108

- Introduce process_frame method
- De-dupe validation logic

Reviewed By: xiaohui-zhang

Differential Revision: D43632390

fbshipit-source-id: 76b7ca0beb725acf686269c877a62e1256921b28
parent 46fae2fe
......@@ -81,6 +81,13 @@ void validate_video_input(
enum AVPixelFormat fmt,
AVCodecContext* ctx,
const torch::Tensor& t) {
if (fmt == AV_PIX_FMT_CUDA) {
TORCH_CHECK(t.device().is_cuda(), "Input tensor has to be on CUDA.");
fmt = ctx->sw_pix_fmt;
} else {
TORCH_CHECK(t.device().is_cpu(), "Input tensor has to be on CPU.");
}
auto dtype = t.dtype().toScalarType();
TORCH_CHECK(dtype == c10::ScalarType::Byte, "Expected Tensor of uint8 type.");
TORCH_CHECK(t.dim() == 4, "Input Tensor has to be 4D.");
......@@ -136,9 +143,7 @@ void write_interlaced_video_cuda(
cudaMemcpyDeviceToDevice)) {
TORCH_CHECK(false, "Failed to copy pixel data from CUDA tensor.");
}
os.src_frame->pts = os.num_frames;
os.num_frames += 1;
os.process_frame(os.src_frame);
os.process_frame();
}
}
......@@ -167,9 +172,7 @@ void write_planar_video_cuda(
TORCH_CHECK(false, "Failed to copy pixel data from CUDA tensor.");
}
}
os.src_frame->pts = os.num_frames;
os.num_frames += 1;
os.process_frame(os.src_frame);
os.process_frame();
}
}
#endif
......@@ -213,10 +216,7 @@ void write_interlaced_video(
src += width * num_channels;
dst += os.src_frame->linesize[0];
}
os.src_frame->pts = os.num_frames;
os.num_frames += 1;
os.process_frame(os.src_frame);
os.process_frame();
}
}
......@@ -268,10 +268,7 @@ void write_planar_video(
dst += os.src_frame->linesize[j];
}
}
os.src_frame->pts = os.num_frames;
os.num_frames += 1;
os.process_frame(os.src_frame);
os.process_frame();
}
}
......@@ -279,13 +276,12 @@ void write_planar_video(
void VideoOutputStream::write_chunk(const torch::Tensor& frames) {
enum AVPixelFormat fmt = static_cast<AVPixelFormat>(src_frame->format);
validate_video_input(fmt, codec_ctx, frames);
#ifdef USE_CUDA
if (fmt == AV_PIX_FMT_CUDA) {
TORCH_CHECK(frames.device().is_cuda(), "Input tensor has to be on CUDA.");
enum AVPixelFormat sw_fmt = codec_ctx->sw_pix_fmt;
validate_video_input(sw_fmt, codec_ctx, frames);
switch (sw_fmt) {
fmt = codec_ctx->sw_pix_fmt;
switch (fmt) {
case AV_PIX_FMT_RGB0:
case AV_PIX_FMT_BGR0:
write_interlaced_video_cuda(*this, frames, true);
......@@ -294,19 +290,17 @@ void VideoOutputStream::write_chunk(const torch::Tensor& frames) {
case AV_PIX_FMT_GBRP16LE:
case AV_PIX_FMT_YUV444P:
case AV_PIX_FMT_YUV444P16LE:
write_planar_video_cuda(*this, frames, av_pix_fmt_count_planes(sw_fmt));
write_planar_video_cuda(*this, frames, av_pix_fmt_count_planes(fmt));
return;
default:
TORCH_CHECK(
false,
"Unexpected pixel format for CUDA: ",
av_get_pix_fmt_name(sw_fmt));
av_get_pix_fmt_name(fmt));
}
}
#endif
TORCH_CHECK(frames.device().is_cpu(), "Input tensor has to be on CPU.");
validate_video_input(fmt, codec_ctx, frames);
switch (fmt) {
case AV_PIX_FMT_GRAY8:
case AV_PIX_FMT_RGB24:
......@@ -321,4 +315,10 @@ void VideoOutputStream::write_chunk(const torch::Tensor& frames) {
}
}
void VideoOutputStream::process_frame() {
src_frame->pts = num_frames;
num_frames += 1;
OutputStream::process_frame(src_frame);
}
} // namespace torchaudio::io
......@@ -19,6 +19,8 @@ struct VideoOutputStream : OutputStream {
const torch::Device& device);
void write_chunk(const torch::Tensor& frames) override;
void process_frame();
~VideoOutputStream() override = default;
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment