Commit 9210cba2 authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Fix StreamWriter regression around RGB0/BGR0 (#3428)

Summary:
- Add RGB0/BGR0 support to CPU encoder
- Allow passing RGB/BGR when the expected format is RGB0/BGR0

Pull Request resolved: https://github.com/pytorch/audio/pull/3428

Differential Revision: D47274370

Pulled By: mthrok

fbshipit-source-id: d34d940e04b07673bb86f518fe895c0735912444
parent f77c3e5b
......@@ -96,33 +96,30 @@ enum AVSampleFormat get_src_sample_fmt(const std::string& src) {
".");
}
// Pixel formats accepted as encoder source input. Consulted by
// get_src_pix_fmt below, both to validate the user-supplied format
// string and to enumerate the valid values in its error message.
const std::set<AVPixelFormat> SUPPORTED_PIX_FMTS{
    AV_PIX_FMT_GRAY8,
    AV_PIX_FMT_RGB0,
    AV_PIX_FMT_BGR0,
    AV_PIX_FMT_RGB24,
    AV_PIX_FMT_BGR24,
    AV_PIX_FMT_YUV444P};
enum AVPixelFormat get_src_pix_fmt(const std::string& src) {
AVPixelFormat fmt = av_get_pix_fmt(src.c_str());
switch (fmt) {
case AV_PIX_FMT_GRAY8:
case AV_PIX_FMT_RGB24:
case AV_PIX_FMT_BGR24:
case AV_PIX_FMT_YUV444P:
return fmt;
default:;
}
TORCH_CHECK(
false,
SUPPORTED_PIX_FMTS.count(fmt),
"Unsupported pixel format (",
src,
") was provided. Valid values are ",
[]() -> std::string {
std::vector<std::string> ret;
for (const auto& fmt :
{AV_PIX_FMT_GRAY8,
AV_PIX_FMT_RGB24,
AV_PIX_FMT_BGR24,
AV_PIX_FMT_YUV444P}) {
for (const auto& fmt : SUPPORTED_PIX_FMTS) {
ret.emplace_back(av_get_pix_fmt_name(fmt));
}
return c10::Join(", ", ret);
}(),
".");
return fmt;
}
////////////////////////////////////////////////////////////////////////////////
......
......@@ -8,6 +8,8 @@ namespace torchaudio::io {
namespace {
using namespace torch::indexing;
using InitFunc = TensorConverter::InitFunc;
using ConvertFunc = TensorConverter::ConvertFunc;
......@@ -111,6 +113,28 @@ void validate_video_input(
t.sizes());
}
// Special case where encode pixel format is RGB0/BGR0 but the tensor is RGB/BGR.
//
// Validates device, dtype and shape of a packed 3-channel tensor that
// will be padded to 4 channels by the caller.
//
// t:      input video tensor, expected uint8 NCHW with C == 3.
// buffer: encoder frame; hw_frames_ctx (if set) means GPU encoding, so
//         the tensor must be on CUDA, otherwise on CPU. Its width and
//         height fix the expected spatial size.
// Throws (via TORCH_CHECK) on any mismatch.
void validate_rgb0(const torch::Tensor& t, AVFrame* buffer) {
  if (buffer->hw_frames_ctx) {
    TORCH_CHECK(t.device().is_cuda(), "Input tensor has to be on CUDA.");
  } else {
    TORCH_CHECK(t.device().is_cpu(), "Input tensor has to be on CPU.");
  }
  TORCH_CHECK(
      t.dtype().toScalarType() == c10::ScalarType::Byte,
      "Expected Tensor of uint8 type.");
  TORCH_CHECK(t.dim() == 4, "Input Tensor has to be 4D.");
  // Fix: also verify the channel dimension. The message promises
  // (N, 3, H, W), but previously only H and W were checked.
  TORCH_CHECK(
      t.size(1) == 3 && t.size(2) == buffer->height &&
          t.size(3) == buffer->width,
      "Expected tensor with shape (N, 3, ",
      buffer->height,
      ", ",
      buffer->width,
      ") (NCHW format). Found ",
      t.sizes());
}
// NCHW -> NHWC, ensure contiguous
torch::Tensor init_interlaced(const torch::Tensor& tensor) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(tensor.dim() == 4);
......@@ -276,16 +300,20 @@ std::pair<InitFunc, ConvertFunc> get_video_func(AVFrame* buffer) {
auto frames_ctx = (AVHWFramesContext*)(buffer->hw_frames_ctx->data);
auto sw_pix_fmt = frames_ctx->sw_format;
switch (sw_pix_fmt) {
// Note:
// RGB0 / BGR0 expects 4 channel, but neither
// av_pix_fmt_desc_get(pix_fmt)->nb_components
// or av_pix_fmt_count_planes(pix_fmt) returns 4.
case AV_PIX_FMT_RGB0:
case AV_PIX_FMT_BGR0: {
ConvertFunc convert_func = [](const torch::Tensor& t, AVFrame* f) {
write_interlaced_video_cuda(t, f, 4);
};
InitFunc init_func = [](const torch::Tensor& t, AVFrame* f) {
// Special treatment for the case user pass regular RGB/BGR tensor.
if (t.dim() == 4 && t.size(1) == 3) {
validate_rgb0(t, f);
auto tmp =
torch::empty({t.size(0), t.size(2), t.size(3), 4}, t.options());
tmp.index_put_({"...", Slice(0, 3)}, t.permute({0, 2, 3, 1}));
return tmp;
}
validate_video_input(t, f, 4);
return init_interlaced(t);
};
......@@ -327,6 +355,24 @@ std::pair<InitFunc, ConvertFunc> get_video_func(AVFrame* buffer) {
};
return {init_func, convert_func};
}
case AV_PIX_FMT_RGB0:
case AV_PIX_FMT_BGR0: {
InitFunc init_func = [](const torch::Tensor& t, AVFrame* f) {
if (t.dim() == 4 && t.size(1) == 3) {
validate_rgb0(t, f);
auto tmp =
torch::empty({t.size(0), t.size(2), t.size(3), 4}, t.options());
tmp.index_put_({"...", Slice(0, 3)}, t.permute({0, 2, 3, 1}));
return tmp;
}
validate_video_input(t, f, 4);
return init_interlaced(t);
};
ConvertFunc convert_func = [](const torch::Tensor& t, AVFrame* f) {
write_interlaced_video(t, f, 4);
};
return {init_func, convert_func};
}
case AV_PIX_FMT_YUV444P: {
InitFunc init_func = [](const torch::Tensor& t, AVFrame* f) {
validate_video_input(t, f, 3);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment