Commit af493e4e authored by Moto Hira's avatar Moto Hira Committed by Facebook GitHub Bot
Browse files

Remove redundant device arg from VideoOutputStream constructor (#3121)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3121

After careful review, it turned out device arg in VideoOutputStream
constructor and related helper functions can be replaced with
AVCodecContext::pix_fmt == AV_PIX_FMT_CUDA.

Reviewed By: xiaohui-zhang

Differential Revision: D43677801

fbshipit-source-id: f8f34f1aed46e223b44250d39cccc4cd26ecb458
parent 2381beec
......@@ -348,14 +348,24 @@ AVCodecContextPtr get_video_codec(
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const torch::Device device,
const c10::optional<std::string>& hw_accel,
AVBufferRefPtr& hw_device_ctx,
AVBufferRefPtr& hw_frame_ctx) {
AVCodecContextPtr ctx = get_codec_ctx(AVMEDIA_TYPE_VIDEO, oformat, encoder);
configure_video_codec(ctx, frame_rate, width, height, encoder_format);
#ifdef USE_CUDA
if (device.type() == c10::DeviceType::CUDA) {
if (hw_accel) {
#ifndef USE_CUDA
TORCH_CHECK(
false,
"torchaudio is not compiled with CUDA support. Hardware acceleration is not available.");
#else
torch::Device device{hw_accel.value()};
TORCH_CHECK(
device.type() == c10::DeviceType::CUDA,
"Only CUDA is supported for hardware acceleration. Found: ",
device.str());
AVBufferRef* device_ctx = nullptr;
int ret = av_hwdevice_ctx_create(
&device_ctx,
......@@ -391,8 +401,8 @@ AVCodecContextPtr get_video_codec(
ctx->hw_frames_ctx,
"Failed to attach CUDA frames to encoding context: ",
av_err2string(ret));
}
#endif
}
open_codec(ctx, encoder_option);
return ctx;
......@@ -466,24 +476,6 @@ void StreamWriter::add_video_stream(
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<std::string>& hw_accel) {
const torch::Device device = [&]() {
if (!hw_accel) {
return torch::Device{c10::DeviceType::CPU};
}
#ifdef USE_CUDA
torch::Device d{hw_accel.value()};
TORCH_CHECK(
d.type() == c10::DeviceType::CUDA,
"Only CUDA is supported for hardware acceleration. Found:",
device.str());
return d;
#else
TORCH_CHECK(
false,
"torchaudio is not compiled with CUDA support. Hardware acceleration is not available.");
#endif
}();
AVBufferRefPtr hw_device_ctx{};
AVBufferRefPtr hw_frame_ctx{};
......@@ -495,7 +487,7 @@ void StreamWriter::add_video_stream(
encoder,
encoder_option,
encoder_format,
device,
hw_accel,
hw_device_ctx,
hw_frame_ctx);
......@@ -504,8 +496,7 @@ void StreamWriter::add_video_stream(
get_src_pixel_fmt(format),
std::move(ctx),
std::move(hw_device_ctx),
std::move(hw_frame_ctx),
device));
std::move(hw_frame_ctx)));
}
void StreamWriter::set_metadata(const OptionDict& metadata) {
......
......@@ -8,13 +8,10 @@ namespace torchaudio::io {
namespace {
FilterGraph get_video_filter(
AVPixelFormat src_fmt,
AVCodecContext* codec_ctx,
const torch::Device& device) {
FilterGraph get_video_filter(AVPixelFormat src_fmt, AVCodecContext* codec_ctx) {
auto desc = [&]() -> std::string {
if (src_fmt == codec_ctx->pix_fmt ||
device.type() != c10::DeviceType::CPU) {
codec_ctx->pix_fmt == AV_PIX_FMT_CUDA) {
return "null";
} else {
std::stringstream ss;
......@@ -36,29 +33,23 @@ FilterGraph get_video_filter(
return p;
}
AVFramePtr get_hw_video_frame(AVCodecContext* codec_ctx) {
AVFramePtr get_video_frame(AVPixelFormat src_fmt, AVCodecContext* codec_ctx) {
AVFramePtr frame{};
if (codec_ctx->pix_fmt == AV_PIX_FMT_CUDA) {
int ret = av_hwframe_get_buffer(codec_ctx->hw_frames_ctx, frame, 0);
TORCH_CHECK(ret >= 0, "Failed to fetch CUDA frame: ", av_err2string(ret));
return frame;
}
AVFramePtr get_video_frame(
AVPixelFormat src_fmt,
AVCodecContext* codec_ctx,
const torch::Device& device) {
if (device.type() == c10::DeviceType::CUDA) {
return get_hw_video_frame(codec_ctx);
}
AVFramePtr frame{};
} else {
frame->format = src_fmt;
frame->width = codec_ctx->width;
frame->height = codec_ctx->height;
int ret = av_frame_get_buffer(frame, 0);
TORCH_CHECK(
ret >= 0, "Error allocating a video buffer (", av_err2string(ret), ").");
ret >= 0,
"Error allocating a video buffer (",
av_err2string(ret),
").");
}
return frame;
}
......@@ -69,13 +60,12 @@ VideoOutputStream::VideoOutputStream(
AVPixelFormat src_fmt,
AVCodecContextPtr&& codec_ctx_,
AVBufferRefPtr&& hw_device_ctx_,
AVBufferRefPtr&& hw_frame_ctx_,
const torch::Device& device)
AVBufferRefPtr&& hw_frame_ctx_)
: OutputStream(
format_ctx,
codec_ctx_,
get_video_filter(src_fmt, codec_ctx_, device)),
src_frame(get_video_frame(src_fmt, codec_ctx_, device)),
get_video_filter(src_fmt, codec_ctx_)),
src_frame(get_video_frame(src_fmt, codec_ctx_)),
hw_device_ctx(std::move(hw_device_ctx_)),
hw_frame_ctx(std::move(hw_frame_ctx_)),
codec_ctx(std::move(codec_ctx_)) {}
......
......@@ -15,8 +15,7 @@ struct VideoOutputStream : OutputStream {
AVPixelFormat src_fmt,
AVCodecContextPtr&& codec_ctx,
AVBufferRefPtr&& hw_device_ctx,
AVBufferRefPtr&& hw_frame_ctx,
const torch::Device& device);
AVBufferRefPtr&& hw_frame_ctx);
void write_chunk(const torch::Tensor& frames) override;
void process_frame();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment