decode_png.cpp 7.68 KB
Newer Older
1
2
#include "decode_png.h"
#include "common_png.h"
limm's avatar
limm committed
3
#include "exif.h"
4
5
6

namespace vision {
namespace image {
7

limm's avatar
limm committed
8
9
using namespace exif_private;

10
#if !PNG_FOUND
limm's avatar
limm committed
11
12
13
14
15
/// Fallback used when torchvision is built without libPNG (PNG_FOUND is
/// false). It keeps the same signature as the real decoder so callers link,
/// but every call fails unconditionally through TORCH_CHECK.
///
/// @param data Unused.
/// @param mode Unused.
/// @param allow_16_bits Unused.
/// @param apply_exif_orientation Unused.
/// @return Never returns normally; TORCH_CHECK(false, ...) always fires.
torch::Tensor decode_png(
    const torch::Tensor& data,
    ImageReadMode mode,
    bool allow_16_bits,
    bool apply_exif_orientation) {
  TORCH_CHECK(
      false, "decode_png: torchvision not compiled with libPNG support");
}
#else

limm's avatar
limm committed
21
22
23
24
25
26
27
28
29
30
31
/// Reports whether the host stores the least-significant byte of a 32-bit
/// integer at the lowest address.
///
/// @return true when the first byte in memory of the value 1 is non-zero.
bool is_little_endian() {
  const uint32_t probe = 1;
  // Inspect the byte at the lowest address of the probe value.
  const auto* lowest_byte = reinterpret_cast<const uint8_t*>(&probe);
  return lowest_byte[0] != 0;
}

/// Decodes the bytes of a PNG file held in a tensor into an image tensor.
///
/// @param data Non-empty, 1-dimensional uint8 tensor containing the raw file
///     content (including the 8-byte PNG signature).
/// @param mode Requested output layout. IMAGE_READ_MODE_UNCHANGED keeps the
///     channel count reported by libpng; the GRAY / GRAY_ALPHA / RGB /
///     RGB_ALPHA modes install libpng transformations that yield 1 / 2 / 3 / 4
///     channels respectively. Any other value is rejected.
/// @param allow_16_bits When true, bit depths up to 16 are accepted; otherwise
///     anything above 8 bits is rejected.
/// @param apply_exif_orientation When true, the orientation obtained from
///     fetch_png_exif_orientation is applied to the result through
///     exif_orientation_transform.
/// @return The decoded image permuted from (height, width, channels) to
///     (channels, height, width). The dtype is uint8 for bit depths <= 8 and
///     int32 otherwise.
///
/// All failures are reported through TORCH_CHECK.
torch::Tensor decode_png(
    const torch::Tensor& data,
    ImageReadMode mode,
    bool allow_16_bits,
    bool apply_exif_orientation) {
  C10_LOG_API_USAGE_ONCE("torchvision.csrc.io.image.cpu.decode_png.decode_png");
  // Check that the input tensor dtype is uint8
  TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
  // Check that the input tensor is 1-dimensional
  TORCH_CHECK(
      data.dim() == 1 && data.numel() > 0,
      "Expected a non empty 1-dimensional tensor");

  auto accessor = data.accessor<unsigned char, 1>();
  auto datap = accessor.data();
  auto datap_len = accessor.size(0);

  // Validate the signature before any libpng structure is allocated, so that
  // a failing check cannot leak png_ptr / info_ptr.
  TORCH_CHECK(datap_len >= 8, "Content is too small for png!");
  auto is_png = !png_sig_cmp(datap, 0, 8);
  TORCH_CHECK(is_png, "Content is not png!");

  auto png_ptr =
      png_create_read_struct(PNG_LIBPNG_VER_STRING, nullptr, nullptr, nullptr);
  TORCH_CHECK(png_ptr, "libpng read structure allocation failed!");
  auto info_ptr = png_create_info_struct(png_ptr);
  if (!info_ptr) {
    // Release the read struct before raising to avoid leaking it.
    png_destroy_read_struct(&png_ptr, nullptr, nullptr);
    TORCH_CHECK(false, "libpng info structure allocation failed!");
  }

  // libpng reports its own errors by longjmp-ing back to this point.
  if (setjmp(png_jmpbuf(png_ptr)) != 0) {
    png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
    TORCH_CHECK(false, "Internal error.");
  }

  // Cursor over the part of the input that follows the 8 signature bytes.
  struct Reader {
    png_const_bytep ptr;
    png_size_t count;
  } reader;
  reader.ptr = png_const_bytep(datap) + 8;
  reader.count = datap_len - 8;

  // NOTE(review): the TORCH_CHECK below throws through libpng's C frames and
  // bypasses png_destroy_read_struct; behaviour kept as in the original.
  auto read_callback = [](png_structp png,
                          png_bytep output,
                          png_size_t bytes) {
    auto reader = static_cast<Reader*>(png_get_io_ptr(png));
    TORCH_CHECK(
        reader->count >= bytes,
        "Out of bound read in decode_png. Probably, the input image is corrupted");
    std::copy(reader->ptr, reader->ptr + bytes, output);
    reader->ptr += bytes;
    reader->count -= bytes;
  };
  // Tell libpng the signature was already consumed, then read the header.
  png_set_sig_bytes(png_ptr, 8);
  png_set_read_fn(png_ptr, &reader, read_callback);
  png_read_info(png_ptr, info_ptr);

  png_uint_32 width, height;
  int bit_depth, color_type;
  int interlace_type;
  auto retval = png_get_IHDR(
      png_ptr,
      info_ptr,
      &width,
      &height,
      &bit_depth,
      &color_type,
      &interlace_type,
      nullptr,
      nullptr);

  if (retval != 1) {
    png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
    TORCH_CHECK(false, "Could not read image metadata from content.");
  }

  auto max_bit_depth = allow_16_bits ? 16 : 8;
  if (bit_depth > max_bit_depth) {
    png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
    auto err_msg = "At most " + std::to_string(max_bit_depth) +
        "-bit PNG images are supported currently.";
    TORCH_CHECK(false, err_msg);
  }

  int channels = png_get_channels(png_ptr, info_ptr);

  if (color_type == PNG_COLOR_TYPE_GRAY && bit_depth < 8)
    png_set_expand_gray_1_2_4_to_8(png_ptr);

  int number_of_passes;
  if (interlace_type == PNG_INTERLACE_ADAM7) {
    number_of_passes = png_set_interlace_handling(png_ptr);
  } else {
    number_of_passes = 1;
  }

  if (mode != IMAGE_READ_MODE_UNCHANGED) {
    // TODO: consider supporting PNG_INFO_tRNS
    bool is_palette = (color_type & PNG_COLOR_MASK_PALETTE) != 0;
    bool has_color = (color_type & PNG_COLOR_MASK_COLOR) != 0;
    bool has_alpha = (color_type & PNG_COLOR_MASK_ALPHA) != 0;

    switch (mode) {
      case IMAGE_READ_MODE_GRAY:
        if (color_type != PNG_COLOR_TYPE_GRAY) {
          if (is_palette) {
            png_set_palette_to_rgb(png_ptr);
            has_alpha = true;
          }

          if (has_alpha) {
            png_set_strip_alpha(png_ptr);
          }

          if (has_color) {
            png_set_rgb_to_gray(png_ptr, 1, 0.2989, 0.587);
          }
          channels = 1;
        }
        break;
      case IMAGE_READ_MODE_GRAY_ALPHA:
        if (color_type != PNG_COLOR_TYPE_GRAY_ALPHA) {
          if (is_palette) {
            png_set_palette_to_rgb(png_ptr);
            has_alpha = true;
          }

          if (!has_alpha) {
            png_set_add_alpha(png_ptr, (1 << bit_depth) - 1, PNG_FILLER_AFTER);
          }

          if (has_color) {
            png_set_rgb_to_gray(png_ptr, 1, 0.2989, 0.587);
          }
          channels = 2;
        }
        break;
      case IMAGE_READ_MODE_RGB:
        if (color_type != PNG_COLOR_TYPE_RGB) {
          if (is_palette) {
            png_set_palette_to_rgb(png_ptr);
            has_alpha = true;
          } else if (!has_color) {
            png_set_gray_to_rgb(png_ptr);
          }

          if (has_alpha) {
            png_set_strip_alpha(png_ptr);
          }
          channels = 3;
        }
        break;
      case IMAGE_READ_MODE_RGB_ALPHA:
        if (color_type != PNG_COLOR_TYPE_RGB_ALPHA) {
          if (is_palette) {
            png_set_palette_to_rgb(png_ptr);
            has_alpha = true;
          } else if (!has_color) {
            png_set_gray_to_rgb(png_ptr);
          }

          if (!has_alpha) {
            png_set_add_alpha(png_ptr, (1 << bit_depth) - 1, PNG_FILLER_AFTER);
          }
          channels = 4;
        }
        break;
      default:
        png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
        TORCH_CHECK(false, "The provided mode is not supported for PNG files");
    }

    png_read_update_info(png_ptr, info_ptr);
  }

  // Widen before multiplying so the product is not computed in 32 bits.
  const size_t num_pixels_per_row =
      static_cast<size_t>(width) * static_cast<size_t>(channels);
  auto tensor = torch::empty(
      {int64_t(height), int64_t(width), channels},
      bit_depth <= 8 ? torch::kU8 : torch::kI32);

  if (bit_depth <= 8) {
    uint8_t* const image_begin = tensor.accessor<uint8_t, 3>().data();
    // Every pass walks all rows again, writing into the same row storage.
    for (int pass = 0; pass < number_of_passes; pass++) {
      uint8_t* row_ptr = image_begin;
      for (png_uint_32 i = 0; i < height; ++i) {
        png_read_row(png_ptr, row_ptr, nullptr);
        row_ptr += num_pixels_per_row;
      }
    }
  } else {
    // We're reading a 16bits png, but pytorch doesn't support uint16.
    // So we read each row in a 16bits tmp_buffer which we then cast into
    // a int32 tensor instead.
    if (is_little_endian()) {
      png_set_swap(png_ptr);
    }
    int32_t* const image_begin = tensor.accessor<int32_t, 3>().data();

    // We create a tensor instead of malloc-ing for automatic memory management
    auto tmp_buffer_tensor = torch::empty(
        {static_cast<int64_t>(num_pixels_per_row * sizeof(uint16_t))},
        torch::kU8);
    uint16_t* tmp_buffer = reinterpret_cast<uint16_t*>(
        tmp_buffer_tensor.accessor<uint8_t, 1>().data());

    for (int pass = 0; pass < number_of_passes; pass++) {
      int32_t* row_ptr = image_begin;
      for (png_uint_32 i = 0; i < height; ++i) {
        if (pass > 0) {
          // With interlace handling libpng only writes the pixels that belong
          // to the current pass and expects the row buffer to still hold the
          // result of the earlier passes, so re-seed the shared scratch row
          // from what was stored for this row so far.
          for (size_t j = 0; j < num_pixels_per_row; ++j) {
            tmp_buffer[j] = static_cast<uint16_t>(row_ptr[j]);
          }
        }
        png_read_row(
            png_ptr, reinterpret_cast<png_bytep>(tmp_buffer), nullptr);
        // Now we copy the uint16 values into the int32 tensor.
        for (size_t j = 0; j < num_pixels_per_row; ++j) {
          row_ptr[j] = static_cast<int32_t>(tmp_buffer[j]);
        }
        row_ptr += num_pixels_per_row;
      }
    }
  }

  int exif_orientation = -1;
  if (apply_exif_orientation) {
    exif_orientation = fetch_png_exif_orientation(png_ptr, info_ptr);
  }

  png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);

  // Move the channel axis to the front: (H, W, C) -> (C, H, W).
  auto output = tensor.permute({2, 0, 1});
  if (apply_exif_orientation) {
    return exif_orientation_transform(output, exif_orientation);
  }
  return output;
}
256
257
258
259
#endif

} // namespace image
} // namespace vision