Commit be3bd1ac authored by Moto Hira's avatar Moto Hira Committed by Facebook GitHub Bot
Browse files

Reduce code duplication in VideoOutputStream (#3108)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3108

- Introduce process_frame method
- De-dupe validation logic

Reviewed By: xiaohui-zhang

Differential Revision: D43632390

fbshipit-source-id: 76b7ca0beb725acf686269c877a62e1256921b28
parent 46fae2fe
...@@ -81,6 +81,13 @@ void validate_video_input( ...@@ -81,6 +81,13 @@ void validate_video_input(
enum AVPixelFormat fmt, enum AVPixelFormat fmt,
AVCodecContext* ctx, AVCodecContext* ctx,
const torch::Tensor& t) { const torch::Tensor& t) {
if (fmt == AV_PIX_FMT_CUDA) {
TORCH_CHECK(t.device().is_cuda(), "Input tensor has to be on CUDA.");
fmt = ctx->sw_pix_fmt;
} else {
TORCH_CHECK(t.device().is_cpu(), "Input tensor has to be on CPU.");
}
auto dtype = t.dtype().toScalarType(); auto dtype = t.dtype().toScalarType();
TORCH_CHECK(dtype == c10::ScalarType::Byte, "Expected Tensor of uint8 type."); TORCH_CHECK(dtype == c10::ScalarType::Byte, "Expected Tensor of uint8 type.");
TORCH_CHECK(t.dim() == 4, "Input Tensor has to be 4D."); TORCH_CHECK(t.dim() == 4, "Input Tensor has to be 4D.");
...@@ -136,9 +143,7 @@ void write_interlaced_video_cuda( ...@@ -136,9 +143,7 @@ void write_interlaced_video_cuda(
cudaMemcpyDeviceToDevice)) { cudaMemcpyDeviceToDevice)) {
TORCH_CHECK(false, "Failed to copy pixel data from CUDA tensor."); TORCH_CHECK(false, "Failed to copy pixel data from CUDA tensor.");
} }
os.src_frame->pts = os.num_frames; os.process_frame();
os.num_frames += 1;
os.process_frame(os.src_frame);
} }
} }
...@@ -167,9 +172,7 @@ void write_planar_video_cuda( ...@@ -167,9 +172,7 @@ void write_planar_video_cuda(
TORCH_CHECK(false, "Failed to copy pixel data from CUDA tensor."); TORCH_CHECK(false, "Failed to copy pixel data from CUDA tensor.");
} }
} }
os.src_frame->pts = os.num_frames; os.process_frame();
os.num_frames += 1;
os.process_frame(os.src_frame);
} }
} }
#endif #endif
...@@ -213,10 +216,7 @@ void write_interlaced_video( ...@@ -213,10 +216,7 @@ void write_interlaced_video(
src += width * num_channels; src += width * num_channels;
dst += os.src_frame->linesize[0]; dst += os.src_frame->linesize[0];
} }
os.src_frame->pts = os.num_frames; os.process_frame();
os.num_frames += 1;
os.process_frame(os.src_frame);
} }
} }
...@@ -268,10 +268,7 @@ void write_planar_video( ...@@ -268,10 +268,7 @@ void write_planar_video(
dst += os.src_frame->linesize[j]; dst += os.src_frame->linesize[j];
} }
} }
os.src_frame->pts = os.num_frames; os.process_frame();
os.num_frames += 1;
os.process_frame(os.src_frame);
} }
} }
...@@ -279,13 +276,12 @@ void write_planar_video( ...@@ -279,13 +276,12 @@ void write_planar_video(
void VideoOutputStream::write_chunk(const torch::Tensor& frames) { void VideoOutputStream::write_chunk(const torch::Tensor& frames) {
enum AVPixelFormat fmt = static_cast<AVPixelFormat>(src_frame->format); enum AVPixelFormat fmt = static_cast<AVPixelFormat>(src_frame->format);
validate_video_input(fmt, codec_ctx, frames);
#ifdef USE_CUDA #ifdef USE_CUDA
if (fmt == AV_PIX_FMT_CUDA) { if (fmt == AV_PIX_FMT_CUDA) {
TORCH_CHECK(frames.device().is_cuda(), "Input tensor has to be on CUDA."); fmt = codec_ctx->sw_pix_fmt;
enum AVPixelFormat sw_fmt = codec_ctx->sw_pix_fmt; switch (fmt) {
validate_video_input(sw_fmt, codec_ctx, frames);
switch (sw_fmt) {
case AV_PIX_FMT_RGB0: case AV_PIX_FMT_RGB0:
case AV_PIX_FMT_BGR0: case AV_PIX_FMT_BGR0:
write_interlaced_video_cuda(*this, frames, true); write_interlaced_video_cuda(*this, frames, true);
...@@ -294,19 +290,17 @@ void VideoOutputStream::write_chunk(const torch::Tensor& frames) { ...@@ -294,19 +290,17 @@ void VideoOutputStream::write_chunk(const torch::Tensor& frames) {
case AV_PIX_FMT_GBRP16LE: case AV_PIX_FMT_GBRP16LE:
case AV_PIX_FMT_YUV444P: case AV_PIX_FMT_YUV444P:
case AV_PIX_FMT_YUV444P16LE: case AV_PIX_FMT_YUV444P16LE:
write_planar_video_cuda(*this, frames, av_pix_fmt_count_planes(sw_fmt)); write_planar_video_cuda(*this, frames, av_pix_fmt_count_planes(fmt));
return; return;
default: default:
TORCH_CHECK( TORCH_CHECK(
false, false,
"Unexpected pixel format for CUDA: ", "Unexpected pixel format for CUDA: ",
av_get_pix_fmt_name(sw_fmt)); av_get_pix_fmt_name(fmt));
} }
} }
#endif #endif
TORCH_CHECK(frames.device().is_cpu(), "Input tensor has to be on CPU.");
validate_video_input(fmt, codec_ctx, frames);
switch (fmt) { switch (fmt) {
case AV_PIX_FMT_GRAY8: case AV_PIX_FMT_GRAY8:
case AV_PIX_FMT_RGB24: case AV_PIX_FMT_RGB24:
...@@ -321,4 +315,10 @@ void VideoOutputStream::write_chunk(const torch::Tensor& frames) { ...@@ -321,4 +315,10 @@ void VideoOutputStream::write_chunk(const torch::Tensor& frames) {
} }
} }
void VideoOutputStream::process_frame() {
src_frame->pts = num_frames;
num_frames += 1;
OutputStream::process_frame(src_frame);
}
} // namespace torchaudio::io } // namespace torchaudio::io
...@@ -19,6 +19,8 @@ struct VideoOutputStream : OutputStream { ...@@ -19,6 +19,8 @@ struct VideoOutputStream : OutputStream {
const torch::Device& device); const torch::Device& device);
void write_chunk(const torch::Tensor& frames) override; void write_chunk(const torch::Tensor& frames) override;
void process_frame();
~VideoOutputStream() override = default; ~VideoOutputStream() override = default;
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment