decode_png.cpp 5.45 KB
Newer Older
1
2
3
4
5
#include "decode_png.h"
#include "common_png.h"

namespace vision {
namespace image {
6
7

#if !PNG_FOUND
8
9
10
// Fallback compiled when PNG_FOUND is not set: PNG decoding is unavailable,
// so every call fails immediately with an explanatory error.
torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
  TORCH_CHECK(
      false,
      "decode_png: torchvision not compiled with libPNG support");
}
#else

14
/// Decodes an in-memory PNG file into a uint8 image tensor using libpng.
///
/// @param data 1-dimensional, non-empty torch.uint8 tensor holding the raw
///             bytes of a PNG file.
/// @param mode Requested output layout. IMAGE_READ_MODE_UNCHANGED keeps the
///             channel count reported by libpng; the GRAY, GRAY_ALPHA, RGB
///             and RGB_ALPHA modes convert to 1, 2, 3 and 4 channels.
/// @return A uint8 tensor of shape (channels, height, width). It is a
///         permuted view of a (height, width, channels) buffer.
///
/// Fails through TORCH_CHECK when the input has the wrong dtype or shape, is
/// not a PNG, has a bit depth above 8, uses an unsupported mode, or when
/// libpng reports an error (including a truncated or corrupted stream).
torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
  // Check that the input tensor dtype is uint8
  TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
  // Check that the input tensor is 1-dimensional
  TORCH_CHECK(
      data.dim() == 1 && data.numel() > 0,
      "Expected a non empty 1-dimensional tensor");

  // Length of the PNG file signature that precedes the first chunk.
  constexpr int kSignatureBytes = 8;

  // The bytes are consumed through a raw pointer below, so they must be laid
  // out densely. contiguous() returns the same tensor when that is already
  // the case and a packed copy otherwise (e.g. for a strided slice).
  const auto input = data.contiguous();
  auto datap = input.accessor<unsigned char, 1>().data();

  // Validate the signature before allocating any libpng state. The length
  // test comes first so png_sig_cmp never reads past a too-short buffer.
  TORCH_CHECK(
      input.numel() >= kSignatureBytes &&
          !png_sig_cmp(datap, 0, kSignatureBytes),
      "Content is not png!");

  // Releases the libpng read/info structures on every exit path, including
  // exceptions raised by TORCH_CHECK or by tensor allocation.
  struct PngReadGuard {
    png_structp* png;
    png_infop* info;
    ~PngReadGuard() {
      png_destroy_read_struct(png, info, nullptr);
    }
  };

  png_structp png_ptr =
      png_create_read_struct(PNG_LIBPNG_VER_STRING, nullptr, nullptr, nullptr);
  TORCH_CHECK(png_ptr, "libpng read structure allocation failed!");
  png_infop info_ptr = nullptr;
  const PngReadGuard cleanup{&png_ptr, &info_ptr};
  info_ptr = png_create_info_struct(png_ptr);
  TORCH_CHECK(info_ptr, "libpng info structure allocation failed!");

  // libpng reports errors by longjmp-ing back to this point.
  // NOTE(review): a longjmp out of png_read_row below bypasses the destructor
  // of the output tensor created after this point — consider isolating the
  // row-reading loop if that matters.
  if (setjmp(png_jmpbuf(png_ptr)) != 0) {
    TORCH_CHECK(false, "Internal error.");
  }

  // Cursor over the not-yet-consumed part of the input buffer.
  struct Reader {
    png_const_bytep ptr;
    png_size_t remaining;
  } reader;
  reader.ptr = png_const_bytep(datap) + kSignatureBytes;
  reader.remaining = static_cast<png_size_t>(input.numel() - kSignatureBytes);

  auto read_callback =
      [](png_structp png, png_bytep output, png_size_t bytes) {
        auto* source = static_cast<Reader*>(png_get_io_ptr(png));
        // A truncated or corrupted stream can make libpng ask for more bytes
        // than the input holds. Report that through libpng's own error path
        // (which longjmps to the handler above) instead of reading past the
        // end of the buffer.
        if (bytes > source->remaining) {
          png_error(png, "Read past the end of the input buffer");
        }
        std::copy(source->ptr, source->ptr + bytes, output);
        source->ptr += bytes;
        source->remaining -= bytes;
      };
  // The signature was already consumed by the check above.
  png_set_sig_bytes(png_ptr, kSignatureBytes);
  png_set_read_fn(png_ptr, &reader, read_callback);
  png_read_info(png_ptr, info_ptr);

  png_uint_32 width, height;
  int bit_depth, color_type;
  int interlace_type;
  auto retval = png_get_IHDR(
      png_ptr,
      info_ptr,
      &width,
      &height,
      &bit_depth,
      &color_type,
      &interlace_type,
      nullptr,
      nullptr);
  TORCH_CHECK(retval == 1, "Could not read image metadata from content.");

  TORCH_CHECK(
      bit_depth <= 8, "At most 8-bit PNG images are supported currently.");

  int channels = png_get_channels(png_ptr, info_ptr);

  // Sub-byte grayscale samples are widened to one byte per sample.
  if (color_type == PNG_COLOR_TYPE_GRAY && bit_depth < 8)
    png_set_expand_gray_1_2_4_to_8(png_ptr);

  // Interlaced images need several passes over all rows.
  int number_of_passes;
  if (interlace_type == PNG_INTERLACE_ADAM7) {
    number_of_passes = png_set_interlace_handling(png_ptr);
  } else {
    number_of_passes = 1;
  }

  if (mode != IMAGE_READ_MODE_UNCHANGED) {
    // TODO: consider supporting PNG_INFO_tRNS
    bool is_palette = (color_type & PNG_COLOR_MASK_PALETTE) != 0;
    bool has_color = (color_type & PNG_COLOR_MASK_COLOR) != 0;
    bool has_alpha = (color_type & PNG_COLOR_MASK_ALPHA) != 0;

    // Register the libpng transformations that turn the stored color type
    // into the requested one, and record the resulting channel count.
    switch (mode) {
      case IMAGE_READ_MODE_GRAY:
        if (color_type != PNG_COLOR_TYPE_GRAY) {
          if (is_palette) {
            png_set_palette_to_rgb(png_ptr);
            has_alpha = true;
          }

          if (has_alpha) {
            png_set_strip_alpha(png_ptr);
          }

          if (has_color) {
            png_set_rgb_to_gray(png_ptr, 1, 0.2989, 0.587);
          }
          channels = 1;
        }
        break;
      case IMAGE_READ_MODE_GRAY_ALPHA:
        if (color_type != PNG_COLOR_TYPE_GRAY_ALPHA) {
          if (is_palette) {
            png_set_palette_to_rgb(png_ptr);
            has_alpha = true;
          }

          if (!has_alpha) {
            png_set_add_alpha(png_ptr, (1 << bit_depth) - 1, PNG_FILLER_AFTER);
          }

          if (has_color) {
            png_set_rgb_to_gray(png_ptr, 1, 0.2989, 0.587);
          }
          channels = 2;
        }
        break;
      case IMAGE_READ_MODE_RGB:
        if (color_type != PNG_COLOR_TYPE_RGB) {
          if (is_palette) {
            png_set_palette_to_rgb(png_ptr);
            has_alpha = true;
          } else if (!has_color) {
            png_set_gray_to_rgb(png_ptr);
          }

          if (has_alpha) {
            png_set_strip_alpha(png_ptr);
          }
          channels = 3;
        }
        break;
      case IMAGE_READ_MODE_RGB_ALPHA:
        if (color_type != PNG_COLOR_TYPE_RGB_ALPHA) {
          if (is_palette) {
            png_set_palette_to_rgb(png_ptr);
            has_alpha = true;
          } else if (!has_color) {
            png_set_gray_to_rgb(png_ptr);
          }

          if (!has_alpha) {
            png_set_add_alpha(png_ptr, (1 << bit_depth) - 1, PNG_FILLER_AFTER);
          }
          channels = 4;
        }
        break;
      default:
        TORCH_CHECK(false, "The provided mode is not supported for PNG files");
    }

    png_read_update_info(png_ptr, info_ptr);
  }

  // Rows are decoded into a (height, width, channels) buffer.
  auto tensor = torch::empty(
      {static_cast<int64_t>(height),
       static_cast<int64_t>(width),
       static_cast<int64_t>(channels)},
      torch::kU8);
  // Computed in size_t so the product cannot wrap in 32-bit arithmetic.
  const size_t row_stride =
      static_cast<size_t>(width) * static_cast<size_t>(channels);
  for (int pass = 0; pass < number_of_passes; pass++) {
    // Every pass starts again from the first row of the buffer.
    auto row = tensor.accessor<uint8_t, 3>().data();
    for (png_uint_32 i = 0; i < height; ++i) {
      png_read_row(png_ptr, row, nullptr);
      row += row_stride;
    }
  }
  // Reorder the dimensions to (channels, height, width) for the caller.
  return tensor.permute({2, 0, 1});
}
184
185
186
187
#endif

} // namespace image
} // namespace vision