decode_jpeg.cpp 8.02 KB
Newer Older
1
2
#include "decode_jpeg.h"
#include "common_jpeg.h"
3
#include "exif.h"
4
5
6

namespace vision {
namespace image {
7
8

#if !JPEG_FOUND
9
10
11
12
// Fallback compiled when JPEG_FOUND is false: no decoder is available, so
// every call fails immediately with an explanatory error. All three arguments
// are ignored.
torch::Tensor decode_jpeg(
    const torch::Tensor& data,
    ImageReadMode mode,
    bool apply_exif_orientation) {
  TORCH_CHECK(
      false, "decode_jpeg: torchvision not compiled with libjpeg support");
}
#else
17
18

using namespace detail;
19
using namespace exif_private;
20
21

namespace {
22
23
24
25
26
27
28
29
30
31

// Source manager used to hand an in-memory encoded buffer to libjpeg.
// `pub` is the first member: code in this file stores a torch_jpeg_mgr
// allocation in `cinfo->src` (a jpeg_source_mgr*) and casts it back.
struct torch_jpeg_mgr {
  struct jpeg_source_mgr pub; // libjpeg callbacks plus the read cursor
  const JOCTET* data; // start of the encoded input buffer
  size_t len; // number of JOCTETs available at `data`
};

// libjpeg `init_source` hook; does nothing.
static void torch_jpeg_init_source(j_decompress_ptr /*cinfo*/) {}

// libjpeg `fill_input_buffer` hook.
//
// The source manager in this file exposes the entire input buffer up front,
// so there is never any further data to supply. A call to this hook is
// therefore treated as a truncated stream: an error message is stored in the
// error manager and control jumps back to the setjmp() point recorded in
// `setjmp_buffer`. This function does not return normally.
static boolean torch_jpeg_fill_input_buffer(j_decompress_ptr cinfo) {
  // No more data.  Probably an incomplete image;  Raise exception.
  torch_jpeg_error_ptr myerr = (torch_jpeg_error_ptr)cinfo->err;
  strcpy(myerr->jpegLastErrorMsg, "Image is incomplete or truncated");
  longjmp(myerr->setjmp_buffer, 1);
}

// libjpeg `skip_input_data` hook: advance the read cursor by `num_bytes`.
//
// A zero or negative count is a no-op, as required by libjpeg's source
// manager contract. Without that guard a negative value would wrap to a huge
// size_t and be mistaken for "skip past the end of the buffer".
//
// If the skip would run past the remaining input, the cursor is redirected to
// EOI_BUFFER (declared elsewhere; presumably a one-byte end-of-image marker —
// TODO confirm) so that decoding terminates.
static void torch_jpeg_skip_input_data(j_decompress_ptr cinfo, long num_bytes) {
  if (num_bytes <= 0) {
    return;
  }

  torch_jpeg_mgr* src = (torch_jpeg_mgr*)cinfo->src;
  const size_t skip = static_cast<size_t>(num_bytes);

  if (src->pub.bytes_in_buffer < skip) {
    // Skipping over all of remaining data;  output EOI.
    src->pub.next_input_byte = EOI_BUFFER;
    src->pub.bytes_in_buffer = 1;
  } else {
    // Skipping over only some of the remaining data.
    src->pub.next_input_byte += skip;
    src->pub.bytes_in_buffer -= skip;
  }
}

// libjpeg `term_source` hook; does nothing.
static void torch_jpeg_term_source(j_decompress_ptr /*cinfo*/) {}

// Installs an in-memory source manager on `cinfo` so libjpeg reads the
// encoded image straight from `data` (`len` bytes).
//
// The manager object is allocated from libjpeg's permanent pool the first
// time this is called for a given `cinfo`; later calls reuse it and only
// reset the callbacks and the buffer cursor.
static void torch_jpeg_set_source_mgr(
    j_decompress_ptr cinfo,
    const unsigned char* data,
    size_t len) {
  if (cinfo->src == nullptr) { // if this is first time;  allocate memory
    cinfo->src = (struct jpeg_source_mgr*)(*cinfo->mem->alloc_small)(
        (j_common_ptr)cinfo, JPOOL_PERMANENT, sizeof(torch_jpeg_mgr));
  }
  torch_jpeg_mgr* src = (torch_jpeg_mgr*)cinfo->src;

  // Callback table.
  src->pub.init_source = torch_jpeg_init_source;
  src->pub.fill_input_buffer = torch_jpeg_fill_input_buffer;
  src->pub.skip_input_data = torch_jpeg_skip_input_data;
  src->pub.resync_to_restart = jpeg_resync_to_restart; // default
  src->pub.term_source = torch_jpeg_term_source;

  // fill the buffers: the whole input is made available at once.
  src->data = (const JOCTET*)data;
  src->len = len;
  src->pub.bytes_in_buffer = len;
  src->pub.next_input_byte = src->data;

  // Keep APP1 markers while the header is parsed (presumably so the EXIF
  // orientation can be looked up later — verify against exif.h).
  jpeg_save_markers(cinfo, APP1, 0xffff);
}

77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
// Computes `k - (k * cmy) / 255` using an integer-only approximation of the
// division, and limits the result to the range [0, 255].
inline unsigned char clamped_cmyk_rgb_convert(
    unsigned char k,
    unsigned char cmy) {
  // Inspired from Pillow:
  // https://github.com/python-pillow/Pillow/blob/07623d1a7cc65206a5355fba2ae256550bfcaba6/src/libImaging/Convert.c#L568-L569
  const int black = static_cast<int>(k);
  const int product = black * static_cast<int>(cmy) + 128;
  // (x + (x >> 8)) >> 8 stands in for a division of x by 255.
  const int scaled = (product + (product >> 8)) >> 8;
  const int difference = black - scaled;
  return static_cast<unsigned char>(std::min(255, std::max(0, difference)));
}

// Converts one scanline of packed 4-byte CMYK pixels into packed 3-byte RGB
// pixels. The number of pixels is taken from `cinfo->output_width`.
// `cmyk_line` is read 4 bytes per pixel; `rgb_line` is written 3 bytes per
// pixel.
void convert_line_cmyk_to_rgb(
    j_decompress_ptr cinfo,
    const unsigned char* cmyk_line,
    unsigned char* rgb_line) {
  const int pixel_count = cinfo->output_width;
  const unsigned char* in = cmyk_line;
  unsigned char* out = rgb_line;

  for (int px = 0; px < pixel_count; ++px, in += 4, out += 3) {
    const int cyan = in[0];
    const int magenta = in[1];
    const int yellow = in[2];
    const int black = in[3];

    out[0] = clamped_cmyk_rgb_convert(black, 255 - cyan);
    out[1] = clamped_cmyk_rgb_convert(black, 255 - magenta);
    out[2] = clamped_cmyk_rgb_convert(black, 255 - yellow);
  }
}

// Weighted sum of the three channels in 16.16 fixed point (the weights add up
// to 65536), rounded by adding 0x8000 before the shift.
inline unsigned char rgb_to_gray(int r, int g, int b) {
  // Inspired from Pillow:
  // https://github.com/python-pillow/Pillow/blob/07623d1a7cc65206a5355fba2ae256550bfcaba6/src/libImaging/Convert.c#L226
  constexpr int kRedWeight = 19595;
  constexpr int kGreenWeight = 38470;
  constexpr int kBlueWeight = 7471;
  constexpr int kRoundingBias = 0x8000;

  const int weighted =
      r * kRedWeight + g * kGreenWeight + b * kBlueWeight + kRoundingBias;
  return static_cast<unsigned char>(weighted >> 16);
}

// Converts one scanline of packed 4-byte CMYK pixels into one gray byte per
// pixel: each pixel is first turned into RGB, then reduced to a single value
// with rgb_to_gray(). The number of pixels is `cinfo->output_width`.
void convert_line_cmyk_to_gray(
    j_decompress_ptr cinfo,
    const unsigned char* cmyk_line,
    unsigned char* gray_line) {
  const int pixel_count = cinfo->output_width;
  const unsigned char* in = cmyk_line;

  for (int px = 0; px < pixel_count; ++px, in += 4) {
    const int cyan = in[0];
    const int magenta = in[1];
    const int yellow = in[2];
    const int black = in[3];

    const int red = clamped_cmyk_rgb_convert(black, 255 - cyan);
    const int green = clamped_cmyk_rgb_convert(black, 255 - magenta);
    const int blue = clamped_cmyk_rgb_convert(black, 255 - yellow);

    gray_line[px] = rgb_to_gray(red, green, blue);
  }
}

129
130
} // namespace

131
132
133
134
// Decodes an in-memory JPEG into a uint8 tensor.
//
// Parameters:
//   data  - 1-D, non-empty uint8 tensor holding the encoded bytes.
//   mode  - IMAGE_READ_MODE_UNCHANGED keeps the channel count reported by the
//           JPEG header; IMAGE_READ_MODE_GRAY yields 1 channel;
//           IMAGE_READ_MODE_RGB yields 3 channels. Any other mode is
//           rejected.
//   apply_exif_orientation - when true, the orientation obtained from
//           fetch_jpeg_exif_orientation() is applied to the result through
//           exif_orientation_transform().
//
// Returns the decoded image: a (height, width, channels) buffer permuted to
// (channels, height, width).
//
// Fails through TORCH_CHECK when the input tensor is invalid, when the mode
// is unsupported, or when libjpeg reports an error (including truncated
// input).
torch::Tensor decode_jpeg(
    const torch::Tensor& data,
    ImageReadMode mode,
    bool apply_exif_orientation) {
  C10_LOG_API_USAGE_ONCE(
      "torchvision.csrc.io.image.cpu.decode_jpeg.decode_jpeg");
  // Check that the input tensor dtype is uint8
  TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
  // Check that the input tensor is 1-dimensional
  TORCH_CHECK(
      data.dim() == 1 && data.numel() > 0,
      "Expected a non empty 1-dimensional tensor");

  struct jpeg_decompress_struct cinfo;
  struct torch_jpeg_error_mgr jerr;

  auto datap = data.data_ptr<uint8_t>();
  // Setup decompression structure
  cinfo.err = jpeg_std_error(&jerr.pub);
  jerr.pub.error_exit = torch_jpeg_error_exit;
  /* Establish the setjmp return context for my_error_exit to use. */
  // NOTE(review): a longjmp back to this point skips the destructors of the
  // tensors created further down — confirm this is acceptable.
  if (setjmp(jerr.setjmp_buffer)) {
    /* If we get here, the JPEG code has signaled an error.
     * We need to clean up the JPEG object.
     */
    jpeg_destroy_decompress(&cinfo);
    TORCH_CHECK(false, jerr.jpegLastErrorMsg);
  }

  jpeg_create_decompress(&cinfo);
  torch_jpeg_set_source_mgr(&cinfo, datap, data.numel());

  // read info from header.
  jpeg_read_header(&cinfo, TRUE);

  int channels = cinfo.num_components;
  // Set when the file is CMYK/YCCK and the caller asked for GRAY or RGB: the
  // scanlines are then decoded as CMYK and converted line by line below.
  bool cmyk_to_rgb_or_gray = false;

  if (mode != IMAGE_READ_MODE_UNCHANGED) {
    switch (mode) {
      case IMAGE_READ_MODE_GRAY:
        if (cinfo.jpeg_color_space == JCS_CMYK ||
            cinfo.jpeg_color_space == JCS_YCCK) {
          cinfo.out_color_space = JCS_CMYK;
          cmyk_to_rgb_or_gray = true;
        } else {
          cinfo.out_color_space = JCS_GRAYSCALE;
        }
        channels = 1;
        break;
      case IMAGE_READ_MODE_RGB:
        if (cinfo.jpeg_color_space == JCS_CMYK ||
            cinfo.jpeg_color_space == JCS_YCCK) {
          cinfo.out_color_space = JCS_CMYK;
          cmyk_to_rgb_or_gray = true;
        } else {
          cinfo.out_color_space = JCS_RGB;
        }
        channels = 3;
        break;
      /*
       * Libjpeg does not support converting from CMYK to grayscale etc. There
       * is a way to do this but it involves converting it manually to RGB:
       * https://github.com/tensorflow/tensorflow/blob/86871065265b04e0db8ca360c046421efb2bdeb4/tensorflow/core/lib/jpeg/jpeg_mem.cc#L284-L313
       */
      default:
        jpeg_destroy_decompress(&cinfo);
        TORCH_CHECK(false, "The provided mode is not supported for JPEG files");
    }

    jpeg_calc_output_dimensions(&cinfo);
  }

  // -1 is the placeholder used when no orientation lookup was requested.
  int exif_orientation = -1;
  if (apply_exif_orientation) {
    exif_orientation = fetch_jpeg_exif_orientation(&cinfo);
  }

  jpeg_start_decompress(&cinfo);

  int height = cinfo.output_height;
  int width = cinfo.output_width;

  // Number of output bytes per scanline.
  int stride = width * channels;
  auto tensor =
      torch::empty({int64_t(height), int64_t(width), channels}, torch::kU8);
  auto ptr = tensor.data_ptr<uint8_t>();

  // Scratch buffer for a single CMYK scanline; only allocated when needed.
  torch::Tensor cmyk_line_tensor;
  if (cmyk_to_rgb_or_gray) {
    cmyk_line_tensor = torch::empty({int64_t(width), 4}, torch::kU8);
  }

  while (cinfo.output_scanline < cinfo.output_height) {
    /* jpeg_read_scanlines expects an array of pointers to scanlines.
     * Here the array is only one element long, but you could ask for
     * more than one scanline at a time if that's more convenient.
     */
    if (cmyk_to_rgb_or_gray) {
      auto cmyk_line_ptr = cmyk_line_tensor.data_ptr<uint8_t>();
      jpeg_read_scanlines(&cinfo, &cmyk_line_ptr, 1);

      if (channels == 3) {
        convert_line_cmyk_to_rgb(&cinfo, cmyk_line_ptr, ptr);
      } else if (channels == 1) {
        convert_line_cmyk_to_gray(&cinfo, cmyk_line_ptr, ptr);
      }
    } else {
      jpeg_read_scanlines(&cinfo, &ptr, 1);
    }
    ptr += stride;
  }

  jpeg_finish_decompress(&cinfo);
  jpeg_destroy_decompress(&cinfo);
  // (height, width, channels) -> (channels, height, width)
  auto output = tensor.permute({2, 0, 1});

  if (apply_exif_orientation) {
    return exif_orientation_transform(output, exif_orientation);
  }
  return output;
}
252
#endif // #if !JPEG_FOUND
253

254
// Returns JPEG_LIB_VERSION when the build has JPEG_FOUND set, and -1
// otherwise.
int64_t _jpeg_version() {
#if JPEG_FOUND
  return JPEG_LIB_VERSION;
#else
  return -1;
#endif
}

// Returns true when LIBJPEG_TURBO_VERSION is defined at compile time, false
// otherwise.
bool _is_compiled_against_turbo() {
#ifdef LIBJPEG_TURBO_VERSION
  return true;
#else
  return false;
#endif
}
269
270
271

} // namespace image
} // namespace vision