Commit cc0d1e0b authored by moto, committed by Facebook GitHub Bot

Refactor and optimize yuv420p and nv12 processing (#2945)

Summary:
This commit refactors and optimizes the functions that convert AVFrames of `yuv420p` and `nv12` formats into PyTorch Tensors.
The performance is improved by about 30%.

1. Reduce the number of intermediate Tensors allocated.
2. Replace the two calls to `repeat_interleave` with a single call to `F::interpolate`.

 * `F::interpolate` is about 5x faster than `repeat_interleave`; the script below verifies that both produce identical results and times them. (A Python sketch of the new conversion path follows the results.)
    <details><summary>code</summary>

    ```bash
    #!/usr/bin/env bash

    set -e

    python3 -c """
    import torch
    import torch.nn.functional as F

    a = torch.arange(49, dtype=torch.uint8).reshape(7, 7).clone()
    val1 = a.repeat_interleave(2, -1).repeat_interleave(2, -2)
    val2 = F.interpolate(a.view((1, 1, 7, 7, 1)), size=[14, 14, 1], mode=\"nearest\")
    print(torch.sum(torch.abs(val1 - val2[0, 0, :, :, 0])))
    """

    python3 -m timeit \
            --setup """
    import torch

    a = torch.arange(49, dtype=torch.uint8).reshape(7, 7).clone()
    """ \
            """
    a.repeat_interleave(2, -1).repeat_interleave(2, -2)
    """

    python3 -m timeit \
            --setup """
    import torch
    import torch.nn.functional as F

    a = torch.arange(49, dtype=torch.uint8).reshape(7, 7).clone()
    """ \
            """
    F.interpolate(a.view((1, 1, 7, 7, 1)), size=[14, 14, 1], mode=\"nearest\")
    """
    ```

    </details>

    Output (`tensor(0)` confirms the two methods produce identical results; the two timings are for `repeat_interleave` and `F::interpolate`, respectively):

    ```
    tensor(0)
    10000 loops, best of 5: 38.3 usec per loop
    50000 loops, best of 5: 7.1 usec per loop
    ```
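As a rough illustration of the new conversion path, here is a minimal Python sketch with synthetic planes (the actual change is in C++; tensor names and sizes here are illustrative, not the library code). The Y plane now lands directly in a `(1, 1, H, W)` tensor and both chroma planes share a single `(1, 2, H/2, W/2)` tensor, so one `interpolate` call and one `cat` produce the final NCHW frame:

```python
import torch
import torch.nn.functional as F

height, width = 240, 320

# Stand-ins for the planes copied out of the AVFrame in the C++ code.
y = torch.randint(256, (1, 1, height, width), dtype=torch.uint8)
uv = torch.randint(256, (1, 2, height // 2, width // 2), dtype=torch.uint8)

# One nearest-neighbor interpolate upsamples both chroma channels at once,
# replacing the two repeat_interleave calls of the old implementation.
uv = F.interpolate(uv, size=(height, width), mode="nearest")

# Concatenating along the channel dim yields NCHW directly, with no
# intermediate cat of separate U/V tensors and no final permute.
frame = torch.cat([y, uv], dim=1)
print(frame.shape)  # torch.Size([1, 3, 240, 320])
```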

## Benchmark Results

<details><summary>code</summary>

```bash
#!/usr/bin/env bash

set -e

mkdir -p tmp

for ext in avi mp4; do
    for duration in 1 5 10 30 60; do
        printf "Testing ${ext} ${duration} [sec]\n"

        test_data="tmp/test_${duration}.${ext}"
        if [ ! -f "${test_data}" ]; then
            printf "Generating test data\n"
            ffmpeg -hide_banner -f lavfi -t ${duration} -i testsrc "${test_data}" > /dev/null 2>&1
        fi

        python -m timeit \
               --setup="from torchaudio.io import StreamReader" \
               """
r = StreamReader(\"${test_data}\")
r.add_basic_video_stream(frames_per_chunk=-1, format=\"yuv420p\")
r.process_all_packets()
r.pop_chunks()
"""
    done
done
```

</details>

![Time to decode AVI file](https://user-images.githubusercontent.com/855818/210008881-8cc83f18-0e51-46e3-afe9-a5ff5dff041e.png)

<details><summary>raw data</summary>

Video Type - AVI
Duration [sec] | Before | After
-- | -- | --
1 | 10.3 | 6.29
5 | 44.3 | 28.3
10 | 89.3 | 56.9
30 | 265 | 185
60 | 555 | 353
</details>

![Time to decode MP4 file](https://user-images.githubusercontent.com/855818/210008891-c4546c52-43d7-49d0-8eff-d866ad627129.png)

<details><summary>raw data</summary>

Video Type - MP4
Duration [sec] | Before | After
-- | -- | --
1 | 15.3 | 10.5
5 | 62.1 | 43.2
10 | 124 | 83.8
30 | 380 | 252
60 | 721 | 511
</details>
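For reference, the relative improvement implied by the raw data above can be computed directly (the speedup is unit-independent); the 60-second rows bracket the ~30% figure quoted in the summary:

```python
# Relative reduction in decoding time, using the 60-second rows of the
# raw-data tables above (values are the reported per-loop times).
results = {"AVI 60s": (555, 353), "MP4 60s": (721, 511)}
for name, (before, after) in results.items():
    print(f"{name}: {100 * (before - after) / before:.1f}% reduction")
# AVI 60s: 36.4% reduction
# MP4 60s: 29.1% reduction
```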

Pull Request resolved: https://github.com/pytorch/audio/pull/2945

Reviewed By: carolineechen

Differential Revision: D42283269

Pulled By: mthrok

fbshipit-source-id: 59840f943ff516b69ab8ad35fed7104c48a0bf0c
```diff
@@ -721,12 +721,16 @@ class StreamReaderImageTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestCase)
         s = StreamReader(path)
         s.add_basic_video_stream(frames_per_chunk=-1, format="yuv444p")
+        s.add_basic_video_stream(frames_per_chunk=-1, format="yuv420p")
+        s.add_basic_video_stream(frames_per_chunk=-1, format="nv12")
         s.add_basic_video_stream(frames_per_chunk=-1, format="rgb24")
         s.add_basic_video_stream(frames_per_chunk=-1, format="bgr24")
         s.add_basic_video_stream(frames_per_chunk=-1, format="gray8")
         s.process_all_packets()
-        output_yuv, output_rgb, output_bgr, output_gray = s.pop_chunks()
-        self.assertEqual(yuv, output_yuv, atol=1, rtol=0)
-        self.assertEqual(rgb, output_rgb, atol=0, rtol=0)
-        self.assertEqual(bgr, output_bgr, atol=0, rtol=0)
-        self.assertEqual(gray, output_gray, atol=1, rtol=0)
+        yuv444, yuv420, nv12, rgb24, bgr24, gray8 = s.pop_chunks()
+        self.assertEqual(yuv, yuv444, atol=1, rtol=0)
+        self.assertEqual(yuv, yuv420, atol=1, rtol=0)
+        self.assertEqual(yuv, nv12, atol=1, rtol=0)
+        self.assertEqual(rgb, rgb24, atol=0, rtol=0)
+        self.assertEqual(bgr, bgr24, atol=0, rtol=0)
+        self.assertEqual(gray, gray8, atol=1, rtol=0)
```
```diff
@@ -164,7 +164,7 @@ torch::Tensor convert_yuv420p(AVFrame* pFrame) {
           .layout(torch::kStrided)
           .device(torch::kCPU);
-  torch::Tensor y = torch::empty({1, height, width, 1}, options);
+  torch::Tensor y = torch::empty({1, 1, height, width}, options);
   {
     uint8_t* tgt = y.data_ptr<uint8_t>();
     uint8_t* src = pFrame->data[0];
@@ -175,9 +175,9 @@ torch::Tensor convert_yuv420p(AVFrame* pFrame) {
       src += linesize;
     }
   }
-  torch::Tensor u = torch::empty({1, height / 2, width / 2, 1}, options);
+  torch::Tensor uv = torch::empty({1, 2, height / 2, width / 2}, options);
   {
-    uint8_t* tgt = u.data_ptr<uint8_t>();
+    uint8_t* tgt = uv.data_ptr<uint8_t>();
     uint8_t* src = pFrame->data[1];
     int linesize = pFrame->linesize[1];
     for (int h = 0; h < height / 2; ++h) {
@@ -185,23 +185,22 @@ torch::Tensor convert_yuv420p(AVFrame* pFrame) {
       tgt += width / 2;
       src += linesize;
     }
-  }
-  torch::Tensor v = torch::empty({1, height / 2, width / 2, 1}, options);
-  {
-    uint8_t* tgt = v.data_ptr<uint8_t>();
-    uint8_t* src = pFrame->data[2];
-    int linesize = pFrame->linesize[2];
+    src = pFrame->data[2];
+    linesize = pFrame->linesize[2];
     for (int h = 0; h < height / 2; ++h) {
       memcpy(tgt, src, width / 2);
       tgt += width / 2;
       src += linesize;
     }
   }
-  torch::Tensor uv = torch::cat({u, v}, -1);
   // Upsample width and height
-  uv = uv.repeat_interleave(2, -2).repeat_interleave(2, -3);
-  torch::Tensor t = torch::cat({y, uv}, -1);
-  return t.permute({0, 3, 1, 2}); // NCHW
+  namespace F = torch::nn::functional;
+  uv = F::interpolate(
+      uv,
+      F::InterpolateFuncOptions()
+          .mode(torch::kNearest)
+          .size(std::vector<int64_t>({height, width})));
+  return torch::cat({y, uv}, 1);
 }
```
```diff
@@ -236,8 +235,13 @@ torch::Tensor convert_nv12_cpu(AVFrame* pFrame) {
     }
   }
   // Upsample width and height
-  uv = uv.repeat_interleave(2, -2).repeat_interleave(2, -3);
-  torch::Tensor t = torch::cat({y, uv}, -1);
+  namespace F = torch::nn::functional;
+  uv = F::interpolate(
+      uv.view({1, 1, height / 2, width / 2, 2}),
+      F::InterpolateFuncOptions()
+          .mode(torch::kNearest)
+          .size(std::vector<int64_t>({height, width, 2})));
+  torch::Tensor t = torch::cat({y, uv[0]}, -1);
   return t.permute({0, 3, 1, 2}); // NCHW
 }
```
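The nv12 hunks keep the chroma in its native interleaved layout, a single `(H/2, W/2, 2)` plane, and view it as a 5-D `(N, C, D, H, W)` tensor so that one 3-D nearest interpolate doubles the two spatial axes while leaving the interleave axis at size 2; the CUDA path below applies the identical change. A minimal Python sketch of the same trick (synthetic data, illustrative names only):

```python
import torch
import torch.nn.functional as F

height, width = 240, 320

# nv12 stores chroma as one half-resolution plane with U and V interleaved.
uv = torch.randint(256, (1, height // 2, width // 2, 2), dtype=torch.uint8)

# Viewed as (N=1, C=1, D=H/2, H=W/2, W=2), a single 3-D nearest interpolate
# upsamples the spatial axes while the interleave axis stays at size 2.
up = F.interpolate(
    uv.view(1, 1, height // 2, width // 2, 2),
    size=(height, width, 2),
    mode="nearest",
)
print(up[0].shape)  # torch.Size([1, 240, 320, 2]) -- full-res NHWC chroma
```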
```diff
@@ -287,8 +291,13 @@ torch::Tensor convert_nv12_cuda(AVFrame* pFrame, const torch::Device& device) {
         "Failed to copy UV plane to Cuda tensor.");
   }
   // Upsample width and height
-  uv = uv.repeat_interleave(2, -2).repeat_interleave(2, -3);
-  torch::Tensor t = torch::cat({y, uv}, -1);
+  namespace F = torch::nn::functional;
+  uv = F::interpolate(
+      uv.view({1, 1, height / 2, width / 2, 2}),
+      F::InterpolateFuncOptions()
+          .mode(torch::kNearest)
+          .size(std::vector<int64_t>({height, width, 2})));
+  torch::Tensor t = torch::cat({y, uv[0]}, -1);
   return t.permute({0, 3, 1, 2}); // NCHW
 }
 #endif
```