Commit b2a90f91 authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Add YUV444P support to StreamReader (#2516)

Summary:
This commit adds support for the `"yuv444p"` type as an output format of StreamReader.

Pull Request resolved: https://github.com/pytorch/audio/pull/2516

Reviewed By: hwangjeff

Differential Revision: D37659715

Pulled By: mthrok

fbshipit-source-id: eae9b5590d8f138a6ebf3808c08adfe068f11a2b
parent 10ac6d2b
...@@ -19,7 +19,7 @@ from .case_utils import ( ...@@ -19,7 +19,7 @@ from .case_utils import (
) )
from .data_utils import get_asset_path, get_sinusoid, get_spectrogram, get_whitenoise from .data_utils import get_asset_path, get_sinusoid, get_spectrogram, get_whitenoise
from .func_utils import torch_script from .func_utils import torch_script
from .image_utils import get_image, save_image from .image_utils import get_image, rgb_to_gray, rgb_to_yuv_ccir, save_image
from .parameterized_utils import load_params, nested_params from .parameterized_utils import load_params, nested_params
from .wav_utils import get_wav_data, load_wav, normalize_wav, save_wav from .wav_utils import get_wav_data, load_wav, normalize_wav, save_wav
...@@ -55,4 +55,6 @@ __all__ = [ ...@@ -55,4 +55,6 @@ __all__ = [
"torch_script", "torch_script",
"save_image", "save_image",
"get_image", "get_image",
"rgb_to_gray",
"rgb_to_yuv_ccir",
] ]
...@@ -27,3 +27,46 @@ def get_image(width, height, grayscale=False): ...@@ -27,3 +27,46 @@ def get_image(width, height, grayscale=False):
img = torch.arange(numel, dtype=torch.int64) % 256 img = torch.arange(numel, dtype=torch.int64) % 256
img = img.reshape(channels, height, width).to(torch.uint8) img = img.reshape(channels, height, width).to(torch.uint8)
return img return img
def rgb_to_yuv_ccir(img):
    """Convert an RGB image to CCIR-range YUV, mirroring FFmpeg's integer math.

    The input image is expected to be (..., channel, height, width) with
    dtype uint8; the result has the same shape and dtype, with channels
    ordered (Y, U, V).
    """
    assert img.dtype == torch.uint8
    rgb = img.to(torch.float32)
    red, green, blue = torch.split(rgb, 1, dim=-3)
    # Coefficients come from FFmpeg's RGB_TO_{Y,U,V}_CCIR macros (shift == 0):
    # https://github.com/FFmpeg/FFmpeg/blob/870bfe16a12bf09dca3a4ae27ef6f81a2de80c40/libavutil/colorspace.h#L98
    luma = (263 * red + 516 * green + 100 * blue + 512 + 16384) / 1024
    # https://github.com/FFmpeg/FFmpeg/blob/870bfe16a12bf09dca3a4ae27ef6f81a2de80c40/libavutil/colorspace.h#L102
    cb = (-152 * red - 298 * green + 450 * blue + 512 - 1) / 1024 + 128
    # https://github.com/FFmpeg/FFmpeg/blob/870bfe16a12bf09dca3a4ae27ef6f81a2de80c40/libavutil/colorspace.h#L106
    cr = (450 * red - 377 * green - 73 * blue + 512 - 1) / 1024 + 128
    return torch.cat([luma, cb, cr], dim=-3).to(torch.uint8)
def rgb_to_gray(img):
    """Convert an RGB image to single-channel grayscale (BT.601 luma weights).

    The input image is expected to be (..., channel, height, width) with
    dtype uint8; the returned tensor is uint8 with a single channel.
    """
    assert img.dtype == torch.uint8
    channels = torch.split(img.to(torch.float32), 1, dim=-3)
    weights = (0.299, 0.587, 0.114)
    luminance = sum(w * c for w, c in zip(weights, channels))
    return luminance.to(torch.uint8)
...@@ -7,6 +7,8 @@ from torchaudio_unittest.common_utils import ( ...@@ -7,6 +7,8 @@ from torchaudio_unittest.common_utils import (
get_wav_data, get_wav_data,
is_ffmpeg_available, is_ffmpeg_available,
nested_params, nested_params,
rgb_to_gray,
rgb_to_yuv_ccir,
save_image, save_image,
save_wav, save_wav,
skipIfNoFFmpeg, skipIfNoFFmpeg,
...@@ -614,3 +616,29 @@ class StreamReaderImageTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestCase) ...@@ -614,3 +616,29 @@ class StreamReaderImageTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestCase)
print("expected", expected) print("expected", expected)
print("output", output) print("output", output)
self.assertEqual(expected, output) self.assertEqual(expected, output)
def test_png_yuv_read_out(self):
    """Providing format properly changes the color space"""
    # Build a synthetic RGB image: R varies along width, G along height,
    # B is a constant swept over all 256 values by the loop below.
    rgb = torch.empty(1, 3, 256, 256, dtype=torch.uint8)
    rgb[0, 0] = torch.arange(256, dtype=torch.uint8).reshape([1, -1])
    rgb[0, 1] = torch.arange(256, dtype=torch.uint8).reshape([-1, 1])
    for i in range(256):
        rgb[0, 2] = i
        path = self.get_temp_path(f"ref_{i}.png")
        save_image(path, rgb[0], mode="RGB")

        # Reference conversions computed in Python (see image_utils helpers).
        yuv = rgb_to_yuv_ccir(rgb)
        bgr = rgb[:, [2, 1, 0], :, :]
        gray = rgb_to_gray(rgb)

        # Decode the same PNG four times, once per requested output format.
        s = StreamReader(path)
        s.add_basic_video_stream(frames_per_chunk=-1, format="yuv444p")
        s.add_basic_video_stream(frames_per_chunk=-1, format="rgb24")
        s.add_basic_video_stream(frames_per_chunk=-1, format="bgr24")
        s.add_basic_video_stream(frames_per_chunk=-1, format="gray8")
        s.process_all_packets()
        output_yuv, output_rgb, output_bgr, output_gray = s.pop_chunks()

        # YUV and gray allow atol=1: FFmpeg's integer-rounded conversion may
        # differ from the float reference by one. RGB/BGR must be exact.
        self.assertEqual(yuv, output_yuv, atol=1, rtol=0)
        self.assertEqual(rgb, output_rgb, atol=0, rtol=0)
        self.assertEqual(bgr, output_bgr, atol=0, rtol=0)
        self.assertEqual(gray, output_gray, atol=1, rtol=0)
...@@ -171,6 +171,56 @@ void AudioBuffer::push_frame(AVFrame* frame) { ...@@ -171,6 +171,56 @@ void AudioBuffer::push_frame(AVFrame* frame) {
// Modifiers - Push Video // Modifiers - Push Video
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
namespace { namespace {
// Copy a packed (interleaved-channel) video frame into a uint8 tensor of
// shape (1, channel, height, width). The frame's rows may be padded
// (linesize >= width * channel), so each row is copied individually.
torch::Tensor convert_interlaced_video(AVFrame* pFrame) {
  const int width = pFrame->width;
  const int height = pFrame->height;
  const int num_channels =
      av_pix_fmt_desc_get(static_cast<AVPixelFormat>(pFrame->format))
          ->nb_components;
  const auto options = torch::TensorOptions()
                           .dtype(torch::kUInt8)
                           .layout(torch::kStrided)
                           .device(torch::kCPU);
  // Stage the data as NHWC so rows can be copied contiguously.
  torch::Tensor frame = torch::empty({1, height, width, num_channels}, options);
  const int row_bytes = width * num_channels;
  const uint8_t* src = pFrame->data[0];
  uint8_t* dst = frame.data_ptr<uint8_t>();
  for (int row = 0; row < height; ++row) {
    memcpy(dst, src, row_bytes);
    src += pFrame->linesize[0];  // skip any per-row padding in the source
    dst += row_bytes;
  }
  // NHWC -> NCHW view for the caller.
  return frame.permute({0, 3, 1, 2});
}
// Copy a planar video frame (one data plane per channel, e.g. yuv444p)
// into a uint8 tensor of shape (1, num_planes, height, width). Each plane
// is copied row by row because its linesize may include padding.
torch::Tensor convert_planar_video(AVFrame* pFrame) {
  const int width = pFrame->width;
  const int height = pFrame->height;
  const int num_planes =
      av_pix_fmt_count_planes(static_cast<AVPixelFormat>(pFrame->format));
  const auto options = torch::TensorOptions()
                           .dtype(torch::kUInt8)
                           .layout(torch::kStrided)
                           .device(torch::kCPU);
  torch::Tensor frame = torch::empty({1, num_planes, height, width}, options);
  for (int plane = 0; plane < num_planes; ++plane) {
    torch::Tensor channel = frame.index({0, plane});
    uint8_t* dst = channel.data_ptr<uint8_t>();
    const uint8_t* src = pFrame->data[plane];
    const int linesize = pFrame->linesize[plane];
    for (int row = 0; row < height; ++row) {
      memcpy(dst, src, width);
      dst += width;      // destination rows are tightly packed
      src += linesize;   // source rows may carry padding
    }
  }
  return frame;
}
torch::Tensor convert_yuv420p(AVFrame* pFrame) { torch::Tensor convert_yuv420p(AVFrame* pFrame) {
int width = pFrame->width; int width = pFrame->width;
int height = pFrame->height; int height = pFrame->height;
...@@ -316,26 +366,17 @@ torch::Tensor convert_image_tensor( ...@@ -316,26 +366,17 @@ torch::Tensor convert_image_tensor(
// https://ffmpeg.org/doxygen/4.1/filtering__video_8c_source.html#l00179 // https://ffmpeg.org/doxygen/4.1/filtering__video_8c_source.html#l00179
// https://ffmpeg.org/doxygen/4.1/decode__video_8c_source.html#l00038 // https://ffmpeg.org/doxygen/4.1/decode__video_8c_source.html#l00038
AVPixelFormat format = static_cast<AVPixelFormat>(pFrame->format); AVPixelFormat format = static_cast<AVPixelFormat>(pFrame->format);
int width = pFrame->width;
int height = pFrame->height;
uint8_t* buf = pFrame->data[0];
int linesize = pFrame->linesize[0];
int channel;
switch (format) { switch (format) {
case AV_PIX_FMT_RGB24: case AV_PIX_FMT_RGB24:
case AV_PIX_FMT_BGR24: case AV_PIX_FMT_BGR24:
channel = 3;
break;
case AV_PIX_FMT_ARGB: case AV_PIX_FMT_ARGB:
case AV_PIX_FMT_RGBA: case AV_PIX_FMT_RGBA:
case AV_PIX_FMT_ABGR: case AV_PIX_FMT_ABGR:
case AV_PIX_FMT_BGRA: case AV_PIX_FMT_BGRA:
channel = 4;
break;
case AV_PIX_FMT_GRAY8: case AV_PIX_FMT_GRAY8:
channel = 1; return convert_interlaced_video(pFrame);
break; case AV_PIX_FMT_YUV444P:
return convert_planar_video(pFrame);
case AV_PIX_FMT_YUV420P: case AV_PIX_FMT_YUV420P:
return convert_yuv420p(pFrame); return convert_yuv420p(pFrame);
case AV_PIX_FMT_NV12: case AV_PIX_FMT_NV12:
...@@ -368,17 +409,6 @@ torch::Tensor convert_image_tensor( ...@@ -368,17 +409,6 @@ torch::Tensor convert_image_tensor(
"Unexpected video format: " + "Unexpected video format: " +
std::string(av_get_pix_fmt_name(format))); std::string(av_get_pix_fmt_name(format)));
} }
torch::Tensor t;
t = torch::empty({1, height, width, channel}, torch::kUInt8);
auto ptr = t.data_ptr<uint8_t>();
int stride = width * channel;
for (int i = 0; i < height; ++i) {
memcpy(ptr, buf, stride);
buf += linesize;
ptr += stride;
}
return t.permute({0, 3, 1, 2});
} }
} // namespace } // namespace
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment