Commit b7e173fa authored by Tristan Rice, committed by Facebook GitHub Bot

Add rgb48le and CUDA p010 support (HDR/10bit) to StreamReader (#3023)

Summary:
This adds two 10-bit pixel formats, one for CPU (rgb48le) and one for CUDA (p010). This allows training on HDR/10-bit video datasets.

Pull Request resolved: https://github.com/pytorch/audio/pull/3023

Test Plan:
```py
r = StreamReader(
    reader, format='hevc',
)
stream = r.add_video_stream(
    frames_per_chunk=-1,
    decoder="hevc_cuvid",
    hw_accel="cuda",
)
frame = next(r.stream())
```
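This first snippet exercises the CUDA path (10-bit HEVC decoded to p010 on the GPU). Not part of the original test plan, but a minimal sketch of sanity-checking the result, assuming a single video stream was added and that 10-bit sources come back as the int16 YUV tensors introduced in this diff (`chunk` is just the unpacked video chunk):

```py
import torch

# `frame` is the tuple yielded by r.stream() above, one entry per added stream.
(chunk,) = frame
assert chunk.device.type == "cuda"  # frames stay on the GPU with hw_accel="cuda"
assert chunk.dtype == torch.int16   # 10-bit sources use a 16-bit buffer
# chunk is expected to be (num_frames, 3, height, width): Y, U, V planes.

# The 16-bit samples are stored with a -32768 offset (see the `+= 32768`
# correction in the diff below); map them to floats in [0, 1] like this:
yuv = (chunk.to(torch.float32) + 32768) / 65535
```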

```py
r = StreamReader(
    reader, format='hevc',
)
stream = r.add_video_stream(
    frames_per_chunk=-1,
    filter_desc="format=rgb48le",
)
frame = next(r.stream())
```
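The second snippet exercises the new CPU format (rgb48le via `filter_desc`). Again not part of the original test plan, but a sketch of turning the resulting chunk into normalized floats, assuming the signed-offset int16 representation used by the new conversion code:

```py
import torch

# `frame` comes from the snippet above; unpack the single video chunk.
(chunk,) = frame  # expected (num_frames, 3, height, width), torch.int16, RGB order
# Undo the -32768 offset applied by write_interlaced_image16, then scale to [0, 1].
rgb = (chunk.to(torch.float32) + 32768) / 65535
```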

![audio-example](https://user-images.githubusercontent.com/909104/215696543-ed3dc5a3-3013-4a57-8b98-05aa4a5a9a7c.png)

Reviewed By: xiaohui-zhang

Differential Revision: D43019191

Pulled By: mthrok

fbshipit-source-id: fe4359e525b24c8b856dfdf3d2f8596871566350
parent 4f201054
@@ -895,6 +895,8 @@ class StreamReaderImageTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestCase)
         path = self.get_temp_path(f"ref_{i}.png")
         save_image(path, rgb[0], mode="RGB")

+        rgb16 = ((rgb.to(torch.int32) - 128) << 8).to(torch.int16)
+
         yuv = rgb_to_yuv_ccir(rgb)
         bgr = rgb[:, [2, 1, 0], :, :]
         gray = rgb_to_gray(rgb)
@@ -906,11 +908,13 @@ class StreamReaderImageTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestCase)
         s.add_basic_video_stream(frames_per_chunk=-1, format="rgb24")
         s.add_basic_video_stream(frames_per_chunk=-1, format="bgr24")
         s.add_basic_video_stream(frames_per_chunk=-1, format="gray8")
+        s.add_basic_video_stream(frames_per_chunk=-1, format="rgb48le")
         s.process_all_packets()
-        yuv444, yuv420, nv12, rgb24, bgr24, gray8 = s.pop_chunks()
+        yuv444, yuv420, nv12, rgb24, bgr24, gray8, rgb48le = s.pop_chunks()
         self.assertEqual(yuv, yuv444, atol=1, rtol=0)
         self.assertEqual(yuv, yuv420, atol=1, rtol=0)
         self.assertEqual(yuv, nv12, atol=1, rtol=0)
         self.assertEqual(rgb, rgb24, atol=0, rtol=0)
         self.assertEqual(bgr, bgr24, atol=0, rtol=0)
         self.assertEqual(gray, gray8, atol=1, rtol=0)
+        self.assertEqual(rgb16, rgb48le, atol=256, rtol=0)
@@ -88,9 +88,10 @@ torch::Tensor convert_audio(AVFrame* pFrame) {
 namespace {
 torch::Tensor get_buffer(
     at::IntArrayRef shape,
-    const torch::Device& device = torch::Device(torch::kCPU)) {
+    const torch::Device& device = torch::Device(torch::kCPU),
+    const torch::Dtype dtype = torch::kUInt8) {
   auto options = torch::TensorOptions()
-                     .dtype(torch::kUInt8)
+                     .dtype(dtype)
                      .layout(torch::kStrided)
                      .device(device.type(), device.index());
   return torch::empty(shape, options);
@@ -128,11 +129,17 @@ std::tuple<torch::Tensor, bool> get_image_buffer(
   int height = frame->height;
   int width = frame->width;

+  int depth = desc->comp[0].depth;
+  auto dtype = (depth > 8) ? torch::kInt16 : torch::kUInt8;
+
   if (desc->flags & AV_PIX_FMT_FLAG_PLANAR) {
-    auto buffer = get_buffer({num_frames, channels, height, width}, device);
+    auto buffer =
+        get_buffer({num_frames, channels, height, width}, device, dtype);
     return std::make_tuple(buffer, true);
   }
-  auto buffer = get_buffer({num_frames, height, width, channels}, device);
+  auto buffer =
+      get_buffer({num_frames, height, width, channels}, device, dtype);
   return std::make_tuple(buffer, false);
 }
@@ -148,6 +155,20 @@ void write_interlaced_image(AVFrame* pFrame, torch::Tensor& frame) {
   }
 }

+void write_interlaced_image16(AVFrame* pFrame, torch::Tensor& frame) {
+  auto ptr = frame.data_ptr<int16_t>();
+  uint8_t* buf = pFrame->data[0];
+  size_t height = frame.size(1);
+  size_t stride = frame.size(2) * frame.size(3);
+  for (int i = 0; i < height; ++i) {
+    memcpy(ptr, buf, stride * 2);
+    buf += pFrame->linesize[0];
+    ptr += stride;
+  }
+  // correct for int16
+  frame += 32768;
+}
+
 void write_planar_image(AVFrame* pFrame, torch::Tensor& frame) {
   int num_planes = static_cast<int>(frame.size(1));
   int height = static_cast<int>(frame.size(2));
@@ -307,6 +328,63 @@ void write_nv12_cuda(AVFrame* pFrame, torch::Tensor& yuv) {
   // yuv[:, 1:] = uv
   yuv.index_put_({Slice(), Slice(1)}, uv);
 }
+
+void write_p010_cuda(AVFrame* pFrame, torch::Tensor& yuv) {
+  int height = static_cast<int>(yuv.size(2));
+  int width = static_cast<int>(yuv.size(3));
+  // Write Y plane directly
+  {
+    int16_t* tgt = yuv.data_ptr<int16_t>();
+    CUdeviceptr src = (CUdeviceptr)pFrame->data[0];
+    int linesize = pFrame->linesize[0];
+    TORCH_CHECK(
+        cudaSuccess ==
+            cudaMemcpy2D(
+                (void*)tgt,
+                width * 2,
+                (const void*)src,
+                linesize,
+                width * 2,
+                height,
+                cudaMemcpyDeviceToDevice),
+        "Failed to copy Y plane to Cuda tensor.");
+  }
+  // Prepare intermediate UV planes
+  torch::Tensor uv =
+      get_buffer({1, height / 2, width / 2, 2}, yuv.device(), torch::kInt16);
+  {
+    int16_t* tgt = uv.data_ptr<int16_t>();
+    CUdeviceptr src = (CUdeviceptr)pFrame->data[1];
+    int linesize = pFrame->linesize[1];
+    TORCH_CHECK(
+        cudaSuccess ==
+            cudaMemcpy2D(
+                (void*)tgt,
+                width * 2,
+                (const void*)src,
+                linesize,
+                width * 2,
+                height / 2,
+                cudaMemcpyDeviceToDevice),
+        "Failed to copy UV plane to Cuda tensor.");
+  }
+  uv = uv.permute({0, 3, 1, 2});
+  using namespace torch::indexing;
+  // Write to the UV plane
+  // very simplistic upscale using indexing since interpolate doesn't support
+  // shorts
+  yuv.index_put_(
+      {Slice(), Slice(1, 3), Slice(None, None, 2), Slice(None, None, 2)}, uv);
+  yuv.index_put_(
+      {Slice(), Slice(1, 3), Slice(1, None, 2), Slice(None, None, 2)}, uv);
+  yuv.index_put_(
+      {Slice(), Slice(1, 3), Slice(None, None, 2), Slice(1, None, 2)}, uv);
+  yuv.index_put_(
+      {Slice(), Slice(1, 3), Slice(1, None, 2), Slice(1, None, 2)}, uv);
+  // correct for int16
+  yuv += 32768;
+}
+
 #endif

 void write_image(AVFrame* frame, torch::Tensor& buf) {
@@ -337,6 +415,10 @@ void write_image(AVFrame* frame, torch::Tensor& buf) {
       write_nv12_cpu(frame, buf);
       return;
     }
+    case AV_PIX_FMT_RGB48LE: {
+      write_interlaced_image16(frame, buf);
+      return;
+    }
 #ifdef USE_CUDA
     case AV_PIX_FMT_CUDA: {
       AVHWFramesContext* hwctx = (AVHWFramesContext*)frame->hw_frames_ctx->data;
@@ -349,7 +431,10 @@ void write_image(AVFrame* frame, torch::Tensor& buf) {
         write_nv12_cuda(frame, buf);
         return;
       }
-      case AV_PIX_FMT_P010:
+      case AV_PIX_FMT_P010: {
+        write_p010_cuda(frame, buf);
+        return;
+      }
       case AV_PIX_FMT_P016:
         TORCH_CHECK(
             false,