Commit 3240de92 authored by moto, committed by Facebook GitHub Bot

Support YUV444P in GPU decoder (#3199)

Summary:
With the CUDA filter support added in https://github.com/pytorch/audio/issues/3183, it is now possible to change the pixel format of CUDA frames.

This commit adds conversion support for the YUV444P format.

Pull Request resolved: https://github.com/pytorch/audio/pull/3199

Reviewed By: hwangjeff

Differential Revision: D44323928

Pulled By: mthrok

fbshipit-source-id: 6d9b205e7235df5f21e7d3e06166b3a169f1ae9f
parent 68fa1d3f
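
For context, the new code path can be exercised as in the updated test below. A minimal usage sketch, assuming a torchaudio build with FFmpeg NVDEC support and a local H.264 file ("video.mp4" is a placeholder path, not part of this commit):

    import torch
    from torchaudio.io import StreamReader

    # Decode H.264 on the GPU and convert frames to yuv444p with the CUDA scale filter.
    r = StreamReader("video.mp4")
    r.add_video_stream(
        frames_per_chunk=10,
        decoder="h264_cuvid",                      # NVDEC-backed H.264 decoder
        hw_accel="cuda",                           # keep frames on the GPU as CUDA tensors
        filter_desc="scale_cuda=format=yuv444p",   # pixel-format conversion on the GPU
    )
    for (chunk,) in r.stream():
        # chunk: uint8 CUDA tensor of shape (frames, 3, height, width);
        # yuv444p keeps chroma at full resolution, hence three full-size planes.
        assert chunk.device.type == "cuda" and chunk.dtype == torch.uint8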
@@ -1147,8 +1147,14 @@ class FilterGraphWithCudaAccel(TorchaudioTestCase):
         assert num_frames == 390

     def test_scale_cuda_format(self):
-        """yuv444p format conversion does not work (yet)"""
+        """yuv444p format conversion should work"""
         src = get_asset_path("nasa_13013.mp4")
         r = StreamReader(src)
-        with self.assertRaises(RuntimeError):
-            r.add_video_stream(10, decoder="h264_cuvid", hw_accel="cuda", filter_desc="scale_cuda=format=yuv444p")
+        r.add_video_stream(10, decoder="h264_cuvid", hw_accel="cuda", filter_desc="scale_cuda=format=yuv444p")
+        num_frames = 0
+        for (chunk,) in r.stream():
+            self.assertEqual(chunk.device, torch.device("cuda:0"))
+            self.assertEqual(chunk.dtype, torch.uint8)
+            self.assertEqual(chunk.shape, torch.Size([10, 3, 270, 480]))
+            num_frames += chunk.size(0)
+        assert num_frames == 390
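
A note on the asserted values: yuv444p stores Y, U, and V at full resolution, so each chunk is (frames, 3, H, W) with no chroma subsampling. The arithmetic, using only values from the test itself:

    # Values taken from the test (nasa_13013.mp4's video stream):
    frames_per_chunk = 10
    channels = 3                      # Y, U, V planes, all at full resolution
    height, width = 270, 480
    chunk_shape = (frames_per_chunk, channels, height, width)   # (10, 3, 270, 480)
    total_frames = 390                                          # whole clip
    num_chunks = total_frames // frames_per_chunk               # 39 full chunks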
@@ -491,6 +491,59 @@ torch::Tensor P010CudaConverter::convert(const AVFrame* src) {
   return buffer;
 }
+
+////////////////////////////////////////////////////////////////////////////////
+// YUV444P CUDA
+////////////////////////////////////////////////////////////////////////////////
+YUV444PCudaConverter::YUV444PCudaConverter(
+    int h,
+    int w,
+    const torch::Device& device)
+    : ImageConverterBase(h, w, 3), device(device) {}
+
+void YUV444PCudaConverter::convert(const AVFrame* src, torch::Tensor& dst) {
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src);
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src->height == height);
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src->width == width);
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(1) == 3);
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(2) == height);
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.size(3) == width);
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dst.dtype() == torch::kUInt8);
+
+  auto fmt = (AVPixelFormat)(src->format);
+  AVHWFramesContext* hwctx = (AVHWFramesContext*)src->hw_frames_ctx->data;
+  AVPixelFormat sw_fmt = hwctx->sw_format;
+  TORCH_INTERNAL_ASSERT(
+      AV_PIX_FMT_CUDA == fmt,
+      "Expected CUDA frame. Found: ",
+      av_get_pix_fmt_name(fmt));
+  TORCH_INTERNAL_ASSERT(
+      AV_PIX_FMT_YUV444P == sw_fmt,
+      "Expected YUV444P format. Found: ",
+      av_get_pix_fmt_name(sw_fmt));
+
+  // Copy the Y, U and V planes; in YUV444P all three are full-resolution.
+  for (int i = 0; i < num_channels; ++i) {
+    auto status = cudaMemcpy2D(
+        dst.index({0, i}).data_ptr(),
+        width,
+        src->data[i],
+        src->linesize[i],
+        width,
+        height,
+        cudaMemcpyDeviceToDevice);
+    TORCH_CHECK(
+        cudaSuccess == status, "Failed to copy plane ", i, " to CUDA tensor.");
+  }
+}
+
+torch::Tensor YUV444PCudaConverter::convert(const AVFrame* src) {
+  torch::Tensor buffer =
+      get_image_buffer({1, num_channels, height, width}, device);
+  convert(src, buffer);
+  return buffer;
+}
 #endif
 } // namespace torchaudio::io
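
For readers unfamiliar with cudaMemcpy2D: each call above copies one full-resolution plane, taking `width` valid bytes from each `linesize`-byte (pitched) source row. A rough CPU-side analogue in PyTorch terms; `copy_planes` is a hypothetical helper for illustration only, where each plane buffer is a flat uint8 tensor of `height * linesize` bytes:

    import torch

    def copy_planes(planes, linesizes, height, width):
        """Illustrative stand-in for the converter's cudaMemcpy2D loop."""
        dst = torch.empty(1, 3, height, width, dtype=torch.uint8)
        for i, (plane, pitch) in enumerate(zip(planes, linesizes)):
            # Each source row is `pitch` bytes; only the first `width` are pixels.
            rows = plane.view(height, pitch)[:, :width]
            dst[0, i].copy_(rows)
        return dst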
@@ -102,5 +102,14 @@ class P010CudaConverter : ImageConverterBase {
   torch::Tensor convert(const AVFrame* src);
 };
+
+class YUV444PCudaConverter : ImageConverterBase {
+  const torch::Device device;
+
+ public:
+  YUV444PCudaConverter(int height, int width, const torch::Device& device);
+  void convert(const AVFrame* src, torch::Tensor& dst);
+  torch::Tensor convert(const AVFrame* src);
+};
 #endif
 } // namespace torchaudio::io
@@ -402,6 +402,11 @@ std::unique_ptr<IPostDecodeProcess> get_unchunked_cuda_video_process(
       return std::make_unique<ProcessImpl<C, B>>(
           std::move(filter), C{i.height, i.width, device}, B{i.time_base});
     }
+    case AV_PIX_FMT_YUV444P: {
+      using C = YUV444PCudaConverter;
+      return std::make_unique<ProcessImpl<C, B>>(
+          std::move(filter), C{i.height, i.width, device}, B{i.time_base});
+    }
     case AV_PIX_FMT_P016: {
       TORCH_CHECK(
           false,
@@ -514,6 +519,13 @@ std::unique_ptr<IPostDecodeProcess> get_chunked_cuda_video_process(
           C{i.height, i.width, device},
           B{i.time_base, frames_per_chunk, num_chunks});
     }
+    case AV_PIX_FMT_YUV444P: {
+      using C = YUV444PCudaConverter;
+      return std::make_unique<ProcessImpl<C, B>>(
+          std::move(filter),
+          C{i.height, i.width, device},
+          B{i.time_base, frames_per_chunk, num_chunks});
+    }
     case AV_PIX_FMT_P016: {
       TORCH_CHECK(
           false,