Commit af493e4e authored by Moto Hira's avatar Moto Hira Committed by Facebook GitHub Bot
Browse files

Remove redundant device arg from VideoOutputStream constructor (#3121)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3121

After careful review, it turned out device arg in VideoOutputStream
constructor and related helper functions can be replaced with
AVCodecContext::pix_fmt == AV_PIX_FMT_CUDA.

Reviewed By: xiaohui-zhang

Differential Revision: D43677801

fbshipit-source-id: f8f34f1aed46e223b44250d39cccc4cd26ecb458
parent 2381beec
...@@ -348,14 +348,24 @@ AVCodecContextPtr get_video_codec( ...@@ -348,14 +348,24 @@ AVCodecContextPtr get_video_codec(
const c10::optional<std::string>& encoder, const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option, const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format, const c10::optional<std::string>& encoder_format,
const torch::Device device, const c10::optional<std::string>& hw_accel,
AVBufferRefPtr& hw_device_ctx, AVBufferRefPtr& hw_device_ctx,
AVBufferRefPtr& hw_frame_ctx) { AVBufferRefPtr& hw_frame_ctx) {
AVCodecContextPtr ctx = get_codec_ctx(AVMEDIA_TYPE_VIDEO, oformat, encoder); AVCodecContextPtr ctx = get_codec_ctx(AVMEDIA_TYPE_VIDEO, oformat, encoder);
configure_video_codec(ctx, frame_rate, width, height, encoder_format); configure_video_codec(ctx, frame_rate, width, height, encoder_format);
#ifdef USE_CUDA if (hw_accel) {
if (device.type() == c10::DeviceType::CUDA) { #ifndef USE_CUDA
TORCH_CHECK(
false,
"torchaudio is not compiled with CUDA support. Hardware acceleration is not available.");
#else
torch::Device device{hw_accel.value()};
TORCH_CHECK(
device.type() == c10::DeviceType::CUDA,
"Only CUDA is supported for hardware acceleration. Found: ",
device.str());
AVBufferRef* device_ctx = nullptr; AVBufferRef* device_ctx = nullptr;
int ret = av_hwdevice_ctx_create( int ret = av_hwdevice_ctx_create(
&device_ctx, &device_ctx,
...@@ -391,8 +401,8 @@ AVCodecContextPtr get_video_codec( ...@@ -391,8 +401,8 @@ AVCodecContextPtr get_video_codec(
ctx->hw_frames_ctx, ctx->hw_frames_ctx,
"Failed to attach CUDA frames to encoding context: ", "Failed to attach CUDA frames to encoding context: ",
av_err2string(ret)); av_err2string(ret));
}
#endif #endif
}
open_codec(ctx, encoder_option); open_codec(ctx, encoder_option);
return ctx; return ctx;
...@@ -466,24 +476,6 @@ void StreamWriter::add_video_stream( ...@@ -466,24 +476,6 @@ void StreamWriter::add_video_stream(
const c10::optional<OptionDict>& encoder_option, const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format, const c10::optional<std::string>& encoder_format,
const c10::optional<std::string>& hw_accel) { const c10::optional<std::string>& hw_accel) {
const torch::Device device = [&]() {
if (!hw_accel) {
return torch::Device{c10::DeviceType::CPU};
}
#ifdef USE_CUDA
torch::Device d{hw_accel.value()};
TORCH_CHECK(
d.type() == c10::DeviceType::CUDA,
"Only CUDA is supported for hardware acceleration. Found:",
device.str());
return d;
#else
TORCH_CHECK(
false,
"torchaudio is not compiled with CUDA support. Hardware acceleration is not available.");
#endif
}();
AVBufferRefPtr hw_device_ctx{}; AVBufferRefPtr hw_device_ctx{};
AVBufferRefPtr hw_frame_ctx{}; AVBufferRefPtr hw_frame_ctx{};
...@@ -495,7 +487,7 @@ void StreamWriter::add_video_stream( ...@@ -495,7 +487,7 @@ void StreamWriter::add_video_stream(
encoder, encoder,
encoder_option, encoder_option,
encoder_format, encoder_format,
device, hw_accel,
hw_device_ctx, hw_device_ctx,
hw_frame_ctx); hw_frame_ctx);
...@@ -504,8 +496,7 @@ void StreamWriter::add_video_stream( ...@@ -504,8 +496,7 @@ void StreamWriter::add_video_stream(
get_src_pixel_fmt(format), get_src_pixel_fmt(format),
std::move(ctx), std::move(ctx),
std::move(hw_device_ctx), std::move(hw_device_ctx),
std::move(hw_frame_ctx), std::move(hw_frame_ctx)));
device));
} }
void StreamWriter::set_metadata(const OptionDict& metadata) { void StreamWriter::set_metadata(const OptionDict& metadata) {
......
...@@ -8,13 +8,10 @@ namespace torchaudio::io { ...@@ -8,13 +8,10 @@ namespace torchaudio::io {
namespace { namespace {
FilterGraph get_video_filter( FilterGraph get_video_filter(AVPixelFormat src_fmt, AVCodecContext* codec_ctx) {
AVPixelFormat src_fmt,
AVCodecContext* codec_ctx,
const torch::Device& device) {
auto desc = [&]() -> std::string { auto desc = [&]() -> std::string {
if (src_fmt == codec_ctx->pix_fmt || if (src_fmt == codec_ctx->pix_fmt ||
device.type() != c10::DeviceType::CPU) { codec_ctx->pix_fmt == AV_PIX_FMT_CUDA) {
return "null"; return "null";
} else { } else {
std::stringstream ss; std::stringstream ss;
...@@ -36,29 +33,23 @@ FilterGraph get_video_filter( ...@@ -36,29 +33,23 @@ FilterGraph get_video_filter(
return p; return p;
} }
AVFramePtr get_hw_video_frame(AVCodecContext* codec_ctx) { AVFramePtr get_video_frame(AVPixelFormat src_fmt, AVCodecContext* codec_ctx) {
AVFramePtr frame{}; AVFramePtr frame{};
if (codec_ctx->pix_fmt == AV_PIX_FMT_CUDA) {
int ret = av_hwframe_get_buffer(codec_ctx->hw_frames_ctx, frame, 0); int ret = av_hwframe_get_buffer(codec_ctx->hw_frames_ctx, frame, 0);
TORCH_CHECK(ret >= 0, "Failed to fetch CUDA frame: ", av_err2string(ret)); TORCH_CHECK(ret >= 0, "Failed to fetch CUDA frame: ", av_err2string(ret));
return frame; } else {
}
AVFramePtr get_video_frame(
AVPixelFormat src_fmt,
AVCodecContext* codec_ctx,
const torch::Device& device) {
if (device.type() == c10::DeviceType::CUDA) {
return get_hw_video_frame(codec_ctx);
}
AVFramePtr frame{};
frame->format = src_fmt; frame->format = src_fmt;
frame->width = codec_ctx->width; frame->width = codec_ctx->width;
frame->height = codec_ctx->height; frame->height = codec_ctx->height;
int ret = av_frame_get_buffer(frame, 0); int ret = av_frame_get_buffer(frame, 0);
TORCH_CHECK( TORCH_CHECK(
ret >= 0, "Error allocating a video buffer (", av_err2string(ret), ")."); ret >= 0,
"Error allocating a video buffer (",
av_err2string(ret),
").");
}
return frame; return frame;
} }
...@@ -69,13 +60,12 @@ VideoOutputStream::VideoOutputStream( ...@@ -69,13 +60,12 @@ VideoOutputStream::VideoOutputStream(
AVPixelFormat src_fmt, AVPixelFormat src_fmt,
AVCodecContextPtr&& codec_ctx_, AVCodecContextPtr&& codec_ctx_,
AVBufferRefPtr&& hw_device_ctx_, AVBufferRefPtr&& hw_device_ctx_,
AVBufferRefPtr&& hw_frame_ctx_, AVBufferRefPtr&& hw_frame_ctx_)
const torch::Device& device)
: OutputStream( : OutputStream(
format_ctx, format_ctx,
codec_ctx_, codec_ctx_,
get_video_filter(src_fmt, codec_ctx_, device)), get_video_filter(src_fmt, codec_ctx_)),
src_frame(get_video_frame(src_fmt, codec_ctx_, device)), src_frame(get_video_frame(src_fmt, codec_ctx_)),
hw_device_ctx(std::move(hw_device_ctx_)), hw_device_ctx(std::move(hw_device_ctx_)),
hw_frame_ctx(std::move(hw_frame_ctx_)), hw_frame_ctx(std::move(hw_frame_ctx_)),
codec_ctx(std::move(codec_ctx_)) {} codec_ctx(std::move(codec_ctx_)) {}
......
...@@ -15,8 +15,7 @@ struct VideoOutputStream : OutputStream { ...@@ -15,8 +15,7 @@ struct VideoOutputStream : OutputStream {
AVPixelFormat src_fmt, AVPixelFormat src_fmt,
AVCodecContextPtr&& codec_ctx, AVCodecContextPtr&& codec_ctx,
AVBufferRefPtr&& hw_device_ctx, AVBufferRefPtr&& hw_device_ctx,
AVBufferRefPtr&& hw_frame_ctx, AVBufferRefPtr&& hw_frame_ctx);
const torch::Device& device);
void write_chunk(const torch::Tensor& frames) override; void write_chunk(const torch::Tensor& frames) override;
void process_frame(); void process_frame();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment