// decode_png.cpp
#include "decode_png.h"
#include "common_png.h"

namespace vision {
namespace image {
// Two build variants of decode_png follow, selected by PNG_FOUND.

#if !PNG_FOUND
// Variant compiled when PNG_FOUND is false: decoding always fails.
/**
 * Placeholder compiled when PNG_FOUND is false.
 *
 * Both arguments are ignored; the call always fails through TORCH_CHECK with
 * a message stating that libPNG support was not compiled in.
 */
torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
  TORCH_CHECK(
      false, "decode_png: torchvision not compiled with libPNG support");
}
#else

// Variant compiled when PNG_FOUND is true: libpng-backed decoding.
// Reports whether the host stores the least significant byte of a 32-bit
// integer at the lowest address.
bool is_little_endian() {
  const uint32_t probe = 1;
  // Inspect the first byte in memory: it holds the 1 only on little-endian.
  const auto* lowest_byte = reinterpret_cast<const uint8_t*>(&probe);
  return lowest_byte[0] != 0;
}

// PNG decoding entry point.
/**
 * Decodes the bytes of a PNG file held in a tensor.
 *
 * @param data Non-empty 1-dimensional uint8 tensor containing the PNG bytes.
 * @param mode Requested channel layout. IMAGE_READ_MODE_UNCHANGED keeps the
 *     channel count reported by libpng; GRAY, GRAY_ALPHA, RGB and RGB_ALPHA
 *     request 1, 2, 3 and 4 channels respectively.
 * @return The decoded image permuted to (channels, height, width). The dtype
 *     is uint8 for bit depths up to 8 and int32 for 16-bit images.
 *
 * Fails through TORCH_CHECK when the input is not a valid uint8 1-D tensor,
 * does not start with the PNG signature, uses an unsupported mode, or when
 * libpng reports an error (including a stream that ends prematurely).
 */
torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
  // Check that the input tensor dtype is uint8
  TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
  // Check that the input tensor is 1-dimensional
  TORCH_CHECK(
      data.dim() == 1 && data.numel() > 0,
      "Expected a non empty 1-dimensional tensor");

  constexpr int64_t kSignatureSize = 8;

  // The bytes are consumed through a raw pointer, so they must be laid out
  // contiguously. contiguous() returns the same tensor when that already
  // holds.
  const auto input = data.contiguous();
  const auto datap = input.accessor<unsigned char, 1>().data();
  const int64_t input_size = input.numel();

  // Validate the signature before any libpng allocation so that a failure
  // here has nothing to release, and never look at more bytes than we own.
  TORCH_CHECK(
      input_size >= kSignatureSize && !png_sig_cmp(datap, 0, kSignatureSize),
      "Content is not png!")

  auto png_ptr =
      png_create_read_struct(PNG_LIBPNG_VER_STRING, nullptr, nullptr, nullptr);
  TORCH_CHECK(png_ptr, "libpng read structure allocation failed!")
  auto info_ptr = png_create_info_struct(png_ptr);
  if (!info_ptr) {
    // Release the read struct before raising to avoid leaking it.
    png_destroy_read_struct(&png_ptr, nullptr, nullptr);
    TORCH_CHECK(false, "libpng info structure allocation failed!")
  }

  // libpng reports errors by longjmp-ing back to this point.
  if (setjmp(png_jmpbuf(png_ptr)) != 0) {
    png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
    TORCH_CHECK(false, "Internal error.");
  }

  // Cursor over the input bytes that follow the signature.
  struct Reader {
    png_const_bytep ptr;
    png_size_t remaining;
  } reader;
  reader.ptr = png_const_bytep(datap) + kSignatureSize;
  reader.remaining = static_cast<png_size_t>(input_size - kSignatureSize);

  auto read_callback =
      [](png_structp png_ptr, png_bytep output, png_size_t bytes) {
        auto reader = static_cast<Reader*>(png_get_io_ptr(png_ptr));
        // A truncated or corrupted stream may request more bytes than are
        // left; report it through libpng instead of reading out of bounds.
        if (bytes > reader->remaining) {
          png_error(png_ptr, "Read past the end of the input buffer");
        }
        std::copy(reader->ptr, reader->ptr + bytes, output);
        reader->ptr += bytes;
        reader->remaining -= bytes;
      };
  png_set_sig_bytes(png_ptr, kSignatureSize);
  png_set_read_fn(png_ptr, &reader, read_callback);
  png_read_info(png_ptr, info_ptr);

  png_uint_32 width, height;
  int bit_depth, color_type;
  int interlace_type;
  auto retval = png_get_IHDR(
      png_ptr,
      info_ptr,
      &width,
      &height,
      &bit_depth,
      &color_type,
      &interlace_type,
      nullptr,
      nullptr);

  if (retval != 1) {
    png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
    TORCH_CHECK(false, "Could not read image metadata from content.")
  }

  if (bit_depth > 16) {
    png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
    TORCH_CHECK(false, "At most 16-bit PNG images are supported currently.")
  }

  int channels = png_get_channels(png_ptr, info_ptr);

  if (color_type == PNG_COLOR_TYPE_GRAY && bit_depth < 8)
    png_set_expand_gray_1_2_4_to_8(png_ptr);

  int number_of_passes;
  if (interlace_type == PNG_INTERLACE_ADAM7) {
    number_of_passes = png_set_interlace_handling(png_ptr);
  } else {
    number_of_passes = 1;
  }

  if (mode != IMAGE_READ_MODE_UNCHANGED) {
    // TODO: consider supporting PNG_INFO_tRNS
    bool is_palette = (color_type & PNG_COLOR_MASK_PALETTE) != 0;
    bool has_color = (color_type & PNG_COLOR_MASK_COLOR) != 0;
    bool has_alpha = (color_type & PNG_COLOR_MASK_ALPHA) != 0;

    // Fully opaque alpha value. Gray images below 8 bits are expanded to
    // 8 bits above, so the filler has to be computed for the expanded depth.
    const int opaque_alpha = (1 << (bit_depth < 8 ? 8 : bit_depth)) - 1;

    switch (mode) {
      case IMAGE_READ_MODE_GRAY:
        if (color_type != PNG_COLOR_TYPE_GRAY) {
          if (is_palette) {
            png_set_palette_to_rgb(png_ptr);
            has_alpha = true;
          }

          if (has_alpha) {
            png_set_strip_alpha(png_ptr);
          }

          if (has_color) {
            png_set_rgb_to_gray(png_ptr, 1, 0.2989, 0.587);
          }
          channels = 1;
        }
        break;
      case IMAGE_READ_MODE_GRAY_ALPHA:
        if (color_type != PNG_COLOR_TYPE_GRAY_ALPHA) {
          if (is_palette) {
            png_set_palette_to_rgb(png_ptr);
            has_alpha = true;
          }

          if (!has_alpha) {
            png_set_add_alpha(png_ptr, opaque_alpha, PNG_FILLER_AFTER);
          }

          if (has_color) {
            png_set_rgb_to_gray(png_ptr, 1, 0.2989, 0.587);
          }
          channels = 2;
        }
        break;
      case IMAGE_READ_MODE_RGB:
        if (color_type != PNG_COLOR_TYPE_RGB) {
          if (is_palette) {
            png_set_palette_to_rgb(png_ptr);
            has_alpha = true;
          } else if (!has_color) {
            png_set_gray_to_rgb(png_ptr);
          }

          if (has_alpha) {
            png_set_strip_alpha(png_ptr);
          }
          channels = 3;
        }
        break;
      case IMAGE_READ_MODE_RGB_ALPHA:
        if (color_type != PNG_COLOR_TYPE_RGB_ALPHA) {
          if (is_palette) {
            png_set_palette_to_rgb(png_ptr);
            has_alpha = true;
          } else if (!has_color) {
            png_set_gray_to_rgb(png_ptr);
          }

          if (!has_alpha) {
            png_set_add_alpha(png_ptr, opaque_alpha, PNG_FILLER_AFTER);
          }
          channels = 4;
        }
        break;
      default:
        png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
        TORCH_CHECK(false, "The provided mode is not supported for PNG files");
    }

    png_read_update_info(png_ptr, info_ptr);
  }

  // Number of samples (not bytes) in one output row.
  const size_t num_pixels_per_row = static_cast<size_t>(width) * channels;
  auto tensor = torch::empty(
      {int64_t(height), int64_t(width), channels},
      bit_depth <= 8 ? torch::kU8 : torch::kI32);

  if (bit_depth <= 8) {
    // Rows are decoded straight into the output tensor. For interlaced
    // images every pass revisits the same rows, so the row pointer is
    // rewound at the start of each pass.
    uint8_t* const first_row = tensor.accessor<uint8_t, 3>().data();
    for (int pass = 0; pass < number_of_passes; pass++) {
      uint8_t* row = first_row;
      for (png_uint_32 i = 0; i < height; ++i) {
        png_read_row(png_ptr, row, nullptr);
        row += num_pixels_per_row;
      }
    }
  } else {
    // We're reading a 16bits png, but pytorch doesn't support uint16.
    // So we read each row in a 16bits tmp_buffer which we then cast into
    // a int32 tensor instead.
    if (is_little_endian()) {
      png_set_swap(png_ptr);
    }
    int32_t* const first_row = tensor.accessor<int32_t, 3>().data();

    // With interlace handling enabled, each pass only fills in some of the
    // samples of a row and expects the buffer to still hold the samples
    // produced by earlier passes. A single scratch row would therefore mix
    // data from different rows, so interlaced images get one staging row per
    // image row; non-interlaced images keep using a single scratch row.
    const bool interlaced = number_of_passes > 1;
    const int64_t row_bytes = int64_t(num_pixels_per_row * sizeof(uint16_t));

    // We create a tensor instead of malloc-ing for automatic memory management
    auto tmp_buffer_tensor = torch::empty(
        {interlaced ? row_bytes * int64_t(height) : row_bytes}, torch::kU8);
    uint16_t* tmp_buffer =
        (uint16_t*)tmp_buffer_tensor.accessor<uint8_t, 1>().data();

    for (int pass = 0; pass < number_of_passes; pass++) {
      // Staged rows are only complete once the last pass has touched them.
      const bool final_pass = pass == number_of_passes - 1;
      uint16_t* staged_row = tmp_buffer;
      int32_t* out_row = first_row;
      for (png_uint_32 i = 0; i < height; ++i) {
        png_read_row(png_ptr, (uint8_t*)staged_row, nullptr);
        if (final_pass) {
          // Now we copy the uint16 values into the int32 tensor.
          for (size_t j = 0; j < num_pixels_per_row; ++j) {
            out_row[j] = (int32_t)staged_row[j];
          }
        }
        out_row += num_pixels_per_row;
        if (interlaced) {
          staged_row += num_pixels_per_row;
        }
      }
    }
  }
  png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
  return tensor.permute({2, 0, 1});
}
// Closes the PNG_FOUND conditional opened above.
#endif

} // namespace image
} // namespace vision