readjpeg_cpu.cpp 1.19 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
#include "readjpeg_cpu.h"

#include <ATen/ATen.h>
#include <setjmp.h>
#include <string>

#if !JPEG_FOUND

// Fallback compiled when JPEG_FOUND is false: no decoder is available, so
// every call fails with an explanatory error raised through TORCH_CHECK.
// The argument is never inspected.
torch::Tensor decodeJPEG(const torch::Tensor& /*data*/) {
  // TORCH_CHECK(false, ...) replaces the deprecated AT_ERROR alias and matches
  // the error-reporting style of the TurboJPEG-backed implementation.
  TORCH_CHECK(
      false, "decodeJPEG: torchvision not compiled with turboJPEG support");
}

#else
#include <turbojpeg.h>

// Decodes an in-memory JPEG byte stream into an RGB image tensor using the
// TurboJPEG API.
//
// data:    1-D tensor of unsigned 8-bit values holding the encoded JPEG bytes
//          (the accessor below enforces the rank and element type).
// returns: kU8 tensor of shape {height, width, 3}, filled by tjDecompress2 in
//          TJPF_RGB pixel format.
// raises:  an error through TORCH_CHECK when the decompressor cannot be
//          created, the header cannot be read, or decompression fails.
torch::Tensor decodeJPEG(const torch::Tensor& data) {
  // Scope guard for the TurboJPEG handle: guarantees tjDestroy runs on every
  // exit path, including the successful return and any exception thrown by
  // TORCH_CHECK, accessor() or torch::empty().
  struct DecompressorGuard {
    tjhandle handle;
    explicit DecompressorGuard(tjhandle h) : handle(h) {}
    DecompressorGuard(const DecompressorGuard&) = delete;
    DecompressorGuard& operator=(const DecompressorGuard&) = delete;
    ~DecompressorGuard() {
      if (handle != nullptr) {
        tjDestroy(handle);
      }
    }
  };

  const DecompressorGuard decompressor(tjInitDecompress());
  TORCH_CHECK(
      decompressor.handle != nullptr,
      "libjpeg-turbo decompression initialization failed.");

  // The decoder reads the bytes through a raw pointer that ignores strides,
  // so make sure the buffer is densely packed (no copy if it already is).
  const torch::Tensor encoded = data.contiguous();
  auto encodedBytes = encoded.accessor<unsigned char, 1>().data();
  const auto encodedSize = static_cast<unsigned long>(encoded.numel());

  int width = 0;
  int height = 0;
  const int headerStatus = tjDecompressHeader(
      decompressor.handle, encodedBytes, encodedSize, &width, &height);
  TORCH_CHECK(headerStatus >= 0, "Error while reading jpeg headers");

  auto tensor =
      torch::empty({int64_t(height), int64_t(width), int64_t(3)}, torch::kU8);
  auto pixels = tensor.accessor<uint8_t, 3>().data();

  const int decodeStatus = tjDecompress2(
      decompressor.handle,
      encodedBytes,
      encodedSize,
      pixels,
      width,
      /*pitch=*/0, // 0 lets TurboJPEG use a tightly packed row stride
      height,
      TJPF_RGB,
      /*flags=*/0); // integer flag bitmask, not a pointer
  TORCH_CHECK(decodeStatus == 0, "decompressing JPEG image");

  return tensor;
}

#endif // JPEG_FOUND