encode_jpeg.cpp 3.31 KB
Newer Older
1
2
3
4
5
6
#include "encode_jpeg.h"

#include "common_jpeg.h"

namespace vision {
namespace image {
7
8
9

#if !JPEG_FOUND

10
/**
 * Fallback definition compiled when JPEG_FOUND is false.
 *
 * Both arguments are ignored; the call always fails through TORCH_CHECK with a
 * message explaining that torchvision was built without libjpeg support.
 */
torch::Tensor encode_jpeg(const torch::Tensor& data, int64_t quality) {
  TORCH_CHECK(
      false, "encode_jpeg: torchvision not compiled with libjpeg support");
}

#else
16
17
18
19
20
21
22
23
24
25
26
27
// Type of the size variable whose address is handed to jpeg_mem_dest().
//
// For libjpeg versions <= 9b, the out_size parameter of jpeg_mem_dest() is
// declared as unsigned long, whereas later versions declare it as size_t.
// JpegSizeType is therefore selected according to the libjpeg version in use,
// in order to prevent compilation errors. unsigned long is also selected
// unconditionally on Windows (for backward compatibility) and whenever
// JPEG_LIB_VERSION_MAJOR is not defined.
#if defined(_WIN32) || !defined(JPEG_LIB_VERSION_MAJOR) || \
    (JPEG_LIB_VERSION_MAJOR < 9) ||                        \
    (JPEG_LIB_VERSION_MAJOR == 9 && JPEG_LIB_VERSION_MINOR <= 2)
using JpegSizeType = unsigned long;
#else
using JpegSizeType = size_t;
#endif
28

29
30
31
using namespace detail;

/**
 * Encodes an image tensor as JPEG using libjpeg's in-memory destination.
 *
 * @param data    3-dimensional uint8 CPU tensor indexed as
 *                (channels, height, width). channels must be 1 (encoded with
 *                JCS_GRAYSCALE) or 3 (encoded with JCS_RGB).
 * @param quality Value forwarded to jpeg_set_quality().
 * @return 1-dimensional uint8 tensor wrapping the buffer libjpeg produced; the
 *         buffer is handed to the tensor together with ::free as its deleter.
 *
 * Fails through TORCH_CHECK when the input is invalid or when libjpeg signals
 * an error.
 */
torch::Tensor encode_jpeg(const torch::Tensor& data, int64_t quality) {
  // Validate the input and build every C++ object with a non-trivial
  // destructor *before* setjmp() is called. If libjpeg later longjmp()s back
  // into this function, no such object has been constructed in between, so
  // the exception raised from the error branch unwinds them normally instead
  // of skipping (and leaking) them.

  // Check that the input tensor is on CPU
  TORCH_CHECK(data.device() == torch::kCPU, "Input tensor should be on CPU");

  // Check that the input tensor dtype is uint8
  TORCH_CHECK(data.dtype() == torch::kU8, "Input tensor dtype should be uint8");

  // Check that the input tensor is 3-dimensional
  TORCH_CHECK(data.dim() == 3, "Input data should be a 3-dimensional tensor");

  // Get image info, kept at full width until it has been range-checked
  const int64_t channels = data.size(0);
  const int64_t height = data.size(1);
  const int64_t width = data.size(2);

  TORCH_CHECK(
      channels == 1 || channels == 3,
      "The number of channels should be 1 or 3, got: ",
      channels);

  // JPEG stores each image dimension in 16 bits. Rejecting larger values here
  // keeps the narrowing conversions below from wrapping around; libjpeg still
  // applies its own (stricter) limits and empty-image checks afterwards.
  const int64_t maxJpegDimension = 65535;
  TORCH_CHECK(
      height <= maxJpegDimension && width <= maxJpegDimension,
      "Image dimensions are too large to be encoded as JPEG, got height=",
      height,
      " and width=",
      width);

  // libjpeg consumes interleaved scanlines, so rearrange to
  // (height, width, channels) and make the memory contiguous.
  const auto input = data.permute({1, 2, 0}).contiguous();
  uint8_t* const pixels = input.data_ptr<uint8_t>();
  const int64_t stride = width * channels; // bytes per scanline

  // Define compression structures and error handling
  struct jpeg_compress_struct cinfo {};
  struct torch_jpeg_error_mgr jerr {};

  // Define buffer to write JPEG information to and its size
  JpegSizeType jpegSize = 0;
  uint8_t* jpegBuf = nullptr;

  cinfo.err = jpeg_std_error(&jerr.pub);
  jerr.pub.error_exit = torch_jpeg_error_exit;

  /* Establish the setjmp return context for torch_jpeg_error_exit to use. */
  if (setjmp(jerr.setjmp_buffer)) {
    /* If we get here, the JPEG code has signaled an error.
     * We need to clean up the JPEG object and the buffer.
     */
    jpeg_destroy_compress(&cinfo);
    if (jpegBuf != nullptr) {
      free(jpegBuf);
    }

    TORCH_CHECK(false, (const char*)jerr.jpegLastErrorMsg);
  }

  // Initialize JPEG structure
  jpeg_create_compress(&cinfo);

  // Set output image information
  cinfo.image_width = static_cast<decltype(cinfo.image_width)>(width);
  cinfo.image_height = static_cast<decltype(cinfo.image_height)>(height);
  cinfo.input_components = static_cast<int>(channels);
  cinfo.in_color_space = channels == 1 ? JCS_GRAYSCALE : JCS_RGB;

  jpeg_set_defaults(&cinfo);
  jpeg_set_quality(&cinfo, static_cast<int>(quality), TRUE);

  // Save JPEG output to a buffer
  jpeg_mem_dest(&cinfo, &jpegBuf, &jpegSize);

  // Start JPEG compression
  jpeg_start_compress(&cinfo, TRUE);

  // Encode JPEG file, one scanline per call
  uint8_t* row = pixels;
  while (cinfo.next_scanline < cinfo.image_height) {
    jpeg_write_scanlines(&cinfo, &row, 1);
    row += stride;
  }

  jpeg_finish_compress(&cinfo);
  jpeg_destroy_compress(&cinfo);

  // Wrap the encoded bytes without copying; ::free is registered as the
  // deleter so the tensor releases the buffer. The size is converted with an
  // explicit 64-bit cast because long is only 32 bits wide on some platforms.
  torch::TensorOptions options = torch::TensorOptions{torch::kU8};
  auto out_tensor = torch::from_blob(
      jpegBuf, {static_cast<int64_t>(jpegSize)}, ::free, options);
  jpegBuf = nullptr;
  return out_tensor;
}
#endif
113
114
115

} // namespace image
} // namespace vision