"git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "a2c5e6cd58728190e1183fd02fdabbbc57e35f0f"
Commit 9210cba2 authored by moto, committed by Facebook GitHub Bot

Fix StreamWriter regression around RGB0/BGR0 (#3428)

Summary:
- Add RGB0/BGR0 support to CPU encoder
- Allow passing RGB/BGR tensors when the expected format is RGB0/BGR0 (usage sketch below)
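
A rough sketch of the behavior this enables, from the Python side (illustrative only: the output path, frame rate, and dimensions are arbitrary; `StreamWriter`, `add_video_stream`, `open`, and `write_video_chunk` are the existing `torchaudio.io` APIs):

```python
import torch
from torchaudio.io import StreamWriter

# With this fix, "rgb0" is accepted as a source pixel format on the CPU encoder
# path, and a plain 3-channel RGB tensor is also accepted and padded internally.
writer = StreamWriter("out.mp4")
writer.add_video_stream(frame_rate=30, width=320, height=240, format="rgb0")

frames = torch.randint(0, 256, (10, 3, 240, 320), dtype=torch.uint8)  # (N, C, H, W)
with writer.open():
    writer.write_video_chunk(0, frames)
```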

Pull Request resolved: https://github.com/pytorch/audio/pull/3428

Differential Revision: D47274370

Pulled By: mthrok

fbshipit-source-id: d34d940e04b07673bb86f518fe895c0735912444
parent f77c3e5b
@@ -96,33 +96,30 @@ enum AVSampleFormat get_src_sample_fmt(const std::string& src) {
       ".");
 }
 
+const std::set<AVPixelFormat> SUPPORTED_PIX_FMTS{
+    AV_PIX_FMT_GRAY8,
+    AV_PIX_FMT_RGB0,
+    AV_PIX_FMT_BGR0,
+    AV_PIX_FMT_RGB24,
+    AV_PIX_FMT_BGR24,
+    AV_PIX_FMT_YUV444P};
+
 enum AVPixelFormat get_src_pix_fmt(const std::string& src) {
   AVPixelFormat fmt = av_get_pix_fmt(src.c_str());
-  switch (fmt) {
-    case AV_PIX_FMT_GRAY8:
-    case AV_PIX_FMT_RGB24:
-    case AV_PIX_FMT_BGR24:
-    case AV_PIX_FMT_YUV444P:
-      return fmt;
-    default:;
-  }
   TORCH_CHECK(
-      false,
+      SUPPORTED_PIX_FMTS.count(fmt),
       "Unsupported pixel format (",
       src,
       ") was provided. Valid values are ",
       []() -> std::string {
         std::vector<std::string> ret;
-        for (const auto& fmt :
-             {AV_PIX_FMT_GRAY8,
-              AV_PIX_FMT_RGB24,
-              AV_PIX_FMT_BGR24,
-              AV_PIX_FMT_YUV444P}) {
+        for (const auto& fmt : SUPPORTED_PIX_FMTS) {
           ret.emplace_back(av_get_pix_fmt_name(fmt));
         }
         return c10::Join(", ", ret);
       }(),
       ".");
+  return fmt;
 }
 
 ////////////////////////////////////////////////////////////////////////////////
......
@@ -8,6 +8,8 @@ namespace torchaudio::io {
 namespace {
 
+using namespace torch::indexing;
+
 using InitFunc = TensorConverter::InitFunc;
 using ConvertFunc = TensorConverter::ConvertFunc;
@@ -111,6 +113,28 @@ void validate_video_input(
       t.sizes());
 }
 
+// Special case where encode pixel format is RGB0/BGR0 but the tensor is RGB/BGR
+void validate_rgb0(const torch::Tensor& t, AVFrame* buffer) {
+  if (buffer->hw_frames_ctx) {
+    TORCH_CHECK(t.device().is_cuda(), "Input tensor has to be on CUDA.");
+  } else {
+    TORCH_CHECK(t.device().is_cpu(), "Input tensor has to be on CPU.");
+  }
+  TORCH_CHECK(
+      t.dtype().toScalarType() == c10::ScalarType::Byte,
+      "Expected Tensor of uint8 type.");
+  TORCH_CHECK(t.dim() == 4, "Input Tensor has to be 4D.");
+  TORCH_CHECK(
+      t.size(2) == buffer->height && t.size(3) == buffer->width,
+      "Expected tensor with shape (N, 3, ",
+      buffer->height,
+      ", ",
+      buffer->width,
+      ") (NCHW format). Found ",
+      t.sizes());
+}
+
 // NCHW ->NHWC, ensure contiguous
 torch::Tensor init_interlaced(const torch::Tensor& tensor) {
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(tensor.dim() == 4);
@@ -276,16 +300,20 @@ std::pair<InitFunc, ConvertFunc> get_video_func(AVFrame* buffer) {
     auto frames_ctx = (AVHWFramesContext*)(buffer->hw_frames_ctx->data);
     auto sw_pix_fmt = frames_ctx->sw_format;
     switch (sw_pix_fmt) {
-      // Note:
-      // RGB0 / BGR0 expects 4 channel, but neither
-      // av_pix_fmt_desc_get(pix_fmt)->nb_components
-      // or av_pix_fmt_count_planes(pix_fmt) returns 4.
       case AV_PIX_FMT_RGB0:
       case AV_PIX_FMT_BGR0: {
         ConvertFunc convert_func = [](const torch::Tensor& t, AVFrame* f) {
           write_interlaced_video_cuda(t, f, 4);
         };
         InitFunc init_func = [](const torch::Tensor& t, AVFrame* f) {
+          // Special treatment for the case user pass regular RGB/BGR tensor.
+          if (t.dim() == 4 && t.size(1) == 3) {
+            validate_rgb0(t, f);
+            auto tmp =
+                torch::empty({t.size(0), t.size(2), t.size(3), 4}, t.options());
+            tmp.index_put_({"...", Slice(0, 3)}, t.permute({0, 2, 3, 1}));
+            return tmp;
+          }
           validate_video_input(t, f, 4);
           return init_interlaced(t);
         };
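
For reference, the channel-padding done in the new `init_func` corresponds roughly to the following tensor manipulation (a minimal Python sketch of the same indexing; the sizes and variable names are illustrative, the actual code uses the ATen C++ indexing API):

```python
import torch

n, h, w = 2, 240, 320
t = torch.randint(0, 256, (n, 3, h, w), dtype=torch.uint8)  # NCHW RGB input

# Allocate an NHWC buffer with 4 channels (RGB0) and copy RGB into channels 0-2,
# mirroring the torch::empty + index_put_ + permute calls in the C++ lambda above.
tmp = torch.empty((n, h, w, 4), dtype=torch.uint8)
tmp[..., :3] = t.permute(0, 2, 3, 1)  # NCHW -> NHWC
```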
@@ -327,6 +355,24 @@ std::pair<InitFunc, ConvertFunc> get_video_func(AVFrame* buffer) {
       };
       return {init_func, convert_func};
     }
+    case AV_PIX_FMT_RGB0:
+    case AV_PIX_FMT_BGR0: {
+      InitFunc init_func = [](const torch::Tensor& t, AVFrame* f) {
+        if (t.dim() == 4 && t.size(1) == 3) {
+          validate_rgb0(t, f);
+          auto tmp =
+              torch::empty({t.size(0), t.size(2), t.size(3), 4}, t.options());
+          tmp.index_put_({"...", Slice(0, 3)}, t.permute({0, 2, 3, 1}));
+          return tmp;
+        }
+        validate_video_input(t, f, 4);
+        return init_interlaced(t);
+      };
+      ConvertFunc convert_func = [](const torch::Tensor& t, AVFrame* f) {
+        write_interlaced_video(t, f, 4);
+      };
+      return {init_func, convert_func};
+    }
     case AV_PIX_FMT_YUV444P: {
       InitFunc init_func = [](const torch::Tensor& t, AVFrame* f) {
         validate_video_input(t, f, 3);
......