// encode_png.cpp
#include "encode_jpeg.h"

#include "common_png.h"

namespace vision {
namespace image {
7
8
9

#if !PNG_FOUND

10
11
12
// Fallback compiled when PNG_FOUND is false: there is no libpng to encode
// with, so every call fails with an explanatory error.
torch::Tensor encode_png(const torch::Tensor& data, int64_t compression_level) {
  TORCH_CHECK(
      false, "encode_png: torchvision not compiled with libpng support");
}

#else
16
17

namespace {
18
19
20
21
22
23
24
25
26
27
28

// Growable in-memory sink that accumulates the bytes libpng writes.
// Default-initialized to an empty sink so a freshly declared instance never
// carries an indeterminate pointer or size.
struct torch_mem_encode {
  char* buffer = nullptr; // heap block (malloc/realloc) with the bytes so far
  size_t size = 0; // number of valid bytes stored in `buffer`
};

// Error-handling state shared between the encoder and the libpng error
// callback: the last reported message plus the jump target used to return
// control to the caller.
struct torch_png_error_mgr {
  // Last error message reported through the callback; null until one is set.
  const char* pngLastErrorMsg = nullptr;
  // Saved execution context for returning to the caller via longjmp.
  jmp_buf setjmp_buffer;
};

29
// Pointer alias for torch_png_error_mgr, used when casting the opaque error
// pointer handed back by libpng.
using torch_png_error_mgr_ptr = torch_png_error_mgr*;
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62

// libpng error callback. Stores `error_msg` on the torch_png_error_mgr that
// libpng hands back as its error pointer, then jumps to the context saved in
// that manager's setjmp_buffer. Control does not return to libpng.
void torch_png_error(png_structp png_ptr, png_const_charp error_msg) {
  // The opaque error pointer is really a torch_png_error_mgr.
  auto* mgr = static_cast<torch_png_error_mgr*>(png_get_error_ptr(png_ptr));

  // Remember the message so the setjmp site can report it.
  mgr->pngLastErrorMsg = error_msg;

  // Unwind back to the setjmp point with a non-zero value.
  longjmp(mgr->setjmp_buffer, 1);
}

// libpng write callback: appends `length` bytes from `data` to the
// torch_mem_encode sink registered as the I/O pointer, growing its heap
// buffer as needed.
//
// On allocation failure the previously accumulated buffer is left intact on
// the sink (so the caller's error path can still free it) and the failure is
// reported through png_error with the message "Write Error".
void torch_png_write_data(
    png_structp png_ptr,
    png_bytep data,
    png_size_t length) {
  auto* sink = static_cast<torch_mem_encode*>(png_get_io_ptr(png_ptr));

  // Nothing to append; also avoids a zero-byte allocation whose null result
  // would be mistaken for an out-of-memory condition.
  if (length == 0) {
    return;
  }

  const size_t new_size = sink->size + length;

  // realloc(nullptr, n) behaves like malloc(n), so one call covers both the
  // first allocation and later growth. The result goes into a temporary:
  // on failure realloc leaves the old block allocated, and overwriting
  // sink->buffer with null would leak it.
  auto* grown = static_cast<char*>(realloc(sink->buffer, new_size));
  if (grown == nullptr) {
    png_error(png_ptr, "Write Error");
    return; // not reached when the error handler jumps away; defensive only
  }
  sink->buffer = grown;

  // Copy the new bytes to the end of the buffer.
  memcpy(sink->buffer + sink->size, data, length);
  sink->size = new_size;
}

63
64
65
} // namespace

// Encode an image tensor as PNG and return the encoded bytes.
//
// data:              3-dimensional uint8 CPU tensor laid out as
//                    (channels, height, width) with 1 or 3 channels.
// compression_level: compression level passed to libpng, in [0, 9].
//
// Returns a 1-dimensional uint8 tensor holding the PNG byte stream.
// Fails through TORCH_CHECK when an argument is invalid, when the libpng
// structures cannot be created, or when libpng reports an error (in which
// case the libpng message is used as the error text).
torch::Tensor encode_png(const torch::Tensor& data, int64_t compression_level) {
  C10_LOG_API_USAGE_ONCE("torchvision.csrc.io.image.cpu.encode_png.encode_png");

  // ---- Argument validation -------------------------------------------------
  // Done before setjmp so that no C++ object with a non-trivial destructor is
  // constructed between setjmp and a potential longjmp.

  // Check that the compression level is between 0 and 9
  TORCH_CHECK(
      compression_level >= 0 && compression_level <= 9,
      "Compression level should be between 0 and 9");

  // Check that the input tensor is on CPU
  TORCH_CHECK(data.device() == torch::kCPU, "Input tensor should be on CPU");

  // Check that the input tensor dtype is uint8
  TORCH_CHECK(data.dtype() == torch::kU8, "Input tensor dtype should be uint8");

  // Check that the input tensor is 3-dimensional
  TORCH_CHECK(data.dim() == 3, "Input data should be a 3-dimensional tensor");

  // Get image info
  const int channels = static_cast<int>(data.size(0));
  const int height = static_cast<int>(data.size(1));
  const int width = static_cast<int>(data.size(2));

  TORCH_CHECK(
      channels == 1 || channels == 3,
      "The number of channels should be 1 or 3, got: ",
      channels);

  // libpng consumes interleaved rows, so convert CHW -> HWC and make the
  // memory contiguous.
  auto input = data.permute({1, 2, 0}).contiguous();

  // ---- libpng state --------------------------------------------------------
  // Initialised to null so the error path below can tell which structures
  // were actually created, even if libpng reports an error while they are
  // still being set up.
  png_structp png_write = nullptr;
  png_infop info_ptr = nullptr;
  struct torch_png_error_mgr err_ptr;
  err_ptr.pngLastErrorMsg = nullptr;

  // Output buffer filled by torch_png_write_data
  struct torch_mem_encode buf_info;
  buf_info.buffer = nullptr;
  buf_info.size = 0;

  /* Establish the setjmp return context for torch_png_error to use. */
  if (setjmp(err_ptr.setjmp_buffer)) {
    /* If we get here, the PNG code has signaled an error.
     * We need to clean up the PNG objects and the buffer.
     */
    if (png_write != nullptr) {
      // Releases the info struct too when one was created.
      png_destroy_write_struct(&png_write, &info_ptr);
    }

    // free(nullptr) is a no-op, so no guard is needed.
    free(buf_info.buffer);
    buf_info.buffer = nullptr;

    TORCH_CHECK(false, err_ptr.pngLastErrorMsg);
  }

  // Initialize PNG structures
  png_write = png_create_write_struct(
      PNG_LIBPNG_VER_STRING, &err_ptr, torch_png_error, nullptr);
  TORCH_CHECK(
      png_write != nullptr,
      "encode_png: failed to create the libpng write struct");

  info_ptr = png_create_info_struct(png_write);
  if (info_ptr == nullptr) {
    png_destroy_write_struct(&png_write, nullptr);
    TORCH_CHECK(false, "encode_png: failed to create the libpng info struct");
  }

  // Define custom buffer output
  png_set_write_fn(png_write, &buf_info, torch_png_write_data, nullptr);

  // Set output image information
  auto color_type = channels == 1 ? PNG_COLOR_TYPE_GRAY : PNG_COLOR_TYPE_RGB;
  png_set_IHDR(
      png_write,
      info_ptr,
      width,
      height,
      8,
      color_type,
      PNG_INTERLACE_NONE,
      PNG_COMPRESSION_TYPE_DEFAULT,
      PNG_FILTER_TYPE_DEFAULT);

  // Set image compression level
  png_set_compression_level(png_write, compression_level);

  // Write file header
  png_write_info(png_write, info_ptr);

  // Number of bytes per image row in the HWC buffer
  auto stride = width * channels;
  auto ptr = input.data_ptr<uint8_t>();

  // Encode PNG file, one row at a time
  for (int y = 0; y < height; ++y) {
    png_write_row(png_write, ptr);
    ptr += stride;
  }

  // Write EOF
  png_write_end(png_write, info_ptr);

  // Destroy structures
  png_destroy_write_struct(&png_write, &info_ptr);

  // int64_t (not long) so the size is not truncated where long is 32 bits.
  torch::TensorOptions options = torch::TensorOptions{torch::kU8};
  auto outTensor =
      torch::empty({static_cast<int64_t>(buf_info.size)}, options);

  // Copy memory from png buffer, since torch cannot get ownership of it via
  // `from_blob`
  auto outPtr = outTensor.data_ptr<uint8_t>();
  std::memcpy(outPtr, buf_info.buffer, sizeof(uint8_t) * outTensor.numel());
  free(buf_info.buffer);

  return outTensor;
}

#endif
178
179
180

} // namespace image
} // namespace vision