Unverified Commit f8780e2e authored by Edgar Andrés Margffoy Tuay's avatar Edgar Andrés Margffoy Tuay Committed by GitHub
Browse files

Add encoding and writing PNG ops (#2726)



* Add encode/write_png functions

* Do not redefine

* Style issues correction

* Comply with low-level interface

* Minor comment correction

* Add python frontend functions

* Add encode_png test

* Pass compression level to encode_png

* Do not compare output buffers

* Convert to bytes

* Compare pil image instead of buffer

* Add error tests

* Add test_write_png

* Remove png test assets

* Register writePNG correctly

* Update write_png docstring

* Do not preserve PIL image beyond the scope
Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
parent 6f758832
...@@ -8,7 +8,8 @@ import torch ...@@ -8,7 +8,8 @@ import torch
import torchvision import torchvision
from PIL import Image from PIL import Image
from torchvision.io.image import ( from torchvision.io.image import (
read_png, decode_png, read_jpeg, decode_jpeg, encode_jpeg, write_jpeg, decode_image, _read_file) read_png, decode_png, read_jpeg, decode_jpeg, encode_jpeg, write_jpeg, decode_image, _read_file,
encode_png, write_png)
import numpy as np import numpy as np
IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets") IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
...@@ -154,6 +155,53 @@ class ImageTester(unittest.TestCase): ...@@ -154,6 +155,53 @@ class ImageTester(unittest.TestCase):
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8)) decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8))
def test_encode_png(self):
    """Round-trip every PNG asset through ``encode_png`` and check its input validation."""
    for img_path in get_images(IMAGE_DIR, '.png'):
        pil_image = Image.open(img_path)
        # PIL gives HWC data; encode_png takes CHW.
        expected = torch.from_numpy(np.array(pil_image)).permute(2, 0, 1)
        encoded = encode_png(expected, compression_level=6)

        # Compare decoded pixels rather than raw bytes: the encoded buffer
        # itself is not compared against PIL's output.
        decoded = Image.open(io.BytesIO(bytes(encoded.tolist())))
        decoded = torch.from_numpy(np.array(decoded)).permute(2, 0, 1)
        self.assertTrue(expected.equal(decoded))

    # Non-uint8 input is rejected.
    with self.assertRaisesRegex(RuntimeError, "Input tensor dtype should be uint8"):
        encode_png(torch.empty((3, 100, 100), dtype=torch.float32))

    # Compression levels just outside both ends of the valid range are rejected.
    for bad_level in (-1, 10):
        with self.assertRaisesRegex(RuntimeError, "Compression level should be between 0 and 9"):
            encode_png(torch.empty((3, 100, 100), dtype=torch.uint8), compression_level=bad_level)

    # Only 1- or 3-channel images are accepted.
    with self.assertRaisesRegex(RuntimeError, "The number of channels should be 1 or 3, got: 5"):
        encode_png(torch.empty((5, 100, 100), dtype=torch.uint8))
def test_write_png(self):
    """Write every PNG asset with ``write_png`` and check the file decodes to the same pixels."""
    for img_path in get_images(IMAGE_DIR, '.png'):
        # Close the source file as soon as the pixels are copied out.
        with Image.open(img_path) as pil_image:
            img_pil = torch.from_numpy(np.array(pil_image))
        # PIL gives HWC data; write_png takes CHW.
        img_pil = img_pil.permute(2, 0, 1)

        basedir = os.path.dirname(img_path)
        filename, _ = os.path.splitext(os.path.basename(img_path))
        torch_png = os.path.join(basedir, '{0}_torch.png'.format(filename))

        try:
            write_png(img_pil, torch_png, compression_level=6)
            # Release the file handle before the file is removed below.
            with Image.open(torch_png) as saved_pil:
                saved_image = torch.from_numpy(np.array(saved_pil))
        finally:
            # Never leave the generated file behind in the assets directory,
            # even when writing or reading it back fails.
            if os.path.exists(torch_png):
                os.remove(torch_png)

        saved_image = saved_image.permute(2, 0, 1)
        self.assertTrue(img_pil.equal(saved_image))
def test_decode_image(self): def test_decode_image(self):
for img_path in get_images(IMAGE_ROOT, ".jpg"): for img_path in get_images(IMAGE_ROOT, ".jpg"):
img_pil = torch.load(img_path.replace('jpg', 'pth')) img_pil = torch.load(img_path.replace('jpg', 'pth'))
......
...@@ -14,6 +14,8 @@ PyMODINIT_FUNC PyInit_image(void) { ...@@ -14,6 +14,8 @@ PyMODINIT_FUNC PyInit_image(void) {
static auto registry = torch::RegisterOperators() static auto registry = torch::RegisterOperators()
.op("image::decode_png", &decodePNG) .op("image::decode_png", &decodePNG)
.op("image::encode_png", &encodePNG)
.op("image::write_png", &writePNG)
.op("image::decode_jpeg", &decodeJPEG) .op("image::decode_jpeg", &decodeJPEG)
.op("image::encode_jpeg", &encodeJPEG) .op("image::encode_jpeg", &encodeJPEG)
.op("image::write_jpeg", &writeJPEG) .op("image::write_jpeg", &writeJPEG)
......
...@@ -7,3 +7,4 @@ ...@@ -7,3 +7,4 @@
#include "readjpeg_cpu.h" #include "readjpeg_cpu.h"
#include "readpng_cpu.h" #include "readpng_cpu.h"
#include "writejpeg_cpu.h" #include "writejpeg_cpu.h"
#include "writepng_cpu.h"
#include "writejpeg_cpu.h"
#include <setjmp.h>
#include <string>
#if !PNG_FOUND
/* Fallback compiled when PNG_FOUND is false: the arguments are ignored and
 * the call always fails with an error explaining that libpng support is
 * missing. */
torch::Tensor encodePNG(const torch::Tensor& data, int64_t compression_level) {
TORCH_CHECK(false, "encodePNG: torchvision not compiled with libpng support");
}
/* Fallback compiled when PNG_FOUND is false: the arguments are ignored, no
 * file is touched, and the call always fails with an error explaining that
 * libpng support is missing. */
void writePNG(
const torch::Tensor& data,
std::string filename,
int64_t compression_level) {
TORCH_CHECK(false, "writePNG: torchvision not compiled with libpng support");
}
#else
#include <png.h>
/* Growable in-memory sink for the encoded PNG stream. torch_png_write_data
 * appends to it; encodePNG copies the result into a tensor and frees it. */
struct torch_mem_encode {
char* buffer; /* heap block managed with malloc/realloc/free; NULL until the first write */
size_t size; /* number of valid bytes currently stored in buffer */
};
/* Error context registered with libpng as the error pointer. torch_png_error
 * fills in the message and jumps back to the setjmp point stored here. */
struct torch_png_error_mgr {
const char* pngLastErrorMsg; /* last error message passed to torch_png_error */
jmp_buf setjmp_buffer; /* for return to caller */
};
/* Pointer alias used when recovering the error context from a png_structp. */
typedef torch_png_error_mgr* torch_png_error_mgr_ptr;
/* libpng-style warning callback that forwards the message to torch's warning
 * machinery; png_ptr is unused.
 * NOTE(review): encodePNG in this file passes NULL as the last argument of
 * png_create_write_struct, so this callback does not appear to be registered
 * anywhere in view -- confirm whether it should be.
 * NOTE(review): TORCH_WARN_ONCE presumably emits only the first message seen
 * at this call site; later, different warnings may be dropped -- confirm. */
void torch_png_warn(png_structp png_ptr, png_const_charp warn_msg) {
/* Display warning to user */
TORCH_WARN_ONCE(warn_msg);
}
/* libpng error callback: records the message in the torch_png_error_mgr that
 * was registered as the error pointer, then unwinds to the setjmp point
 * established by the caller. Never returns. */
void torch_png_error(png_structp png_ptr, png_const_charp error_msg) {
  // The error pointer registered for this png_structp is a
  // torch_png_error_mgr; recover it from the opaque void*.
  auto* mgr = static_cast<torch_png_error_mgr_ptr>(png_get_error_ptr(png_ptr));

  // Remember the message so the setjmp handler can report it.
  mgr->pngLastErrorMsg = error_msg;

  // Hand control back to the setjmp site with a non-zero status.
  longjmp(mgr->setjmp_buffer, 1);
}
/* libpng write callback: appends `length` bytes from `data` to the
 * torch_mem_encode sink registered as the I/O pointer, growing its heap
 * buffer as needed. On allocation failure it raises a libpng error
 * ("Write Error"); the previously written bytes stay owned by the sink so
 * the caller's error path can still free them. */
void torch_png_write_data(
    png_structp png_ptr,
    png_bytep data,
    png_size_t length) {
  struct torch_mem_encode* p =
      (struct torch_mem_encode*)png_get_io_ptr(png_ptr);

  /* Nothing to append; also avoids a zero-sized allocation, which may
   * legitimately return NULL and would be misreported as a failure. */
  if (length == 0) {
    return;
  }

  size_t nsize = p->size + length;

  /* Grow through a temporary: realloc leaves the old block untouched on
   * failure, so overwriting p->buffer directly would leak it.
   * realloc(NULL, n) behaves like malloc(n), so no special first-call case. */
  char* grown = (char*)realloc(p->buffer, nsize);
  if (!grown) {
    /* Does not return: the registered error handler longjmps away. */
    png_error(png_ptr, "Write Error");
  }
  p->buffer = grown;

  /* copy new bytes to end of buffer */
  memcpy(p->buffer + p->size, data, length);
  p->size = nsize;
}
/* Encode a CHW uint8 CPU tensor as a PNG file held in memory.
 *
 * data              3-dimensional uint8 CPU tensor laid out as
 *                   (channels, height, width); channels must be 1
 *                   (grayscale) or 3 (RGB).
 * compression_level compression level handed to libpng, in [0, 9].
 *
 * Returns a 1-dimensional uint8 tensor holding the bytes of the PNG file.
 * Fails via TORCH_CHECK on invalid arguments or when libpng reports an
 * error. */
torch::Tensor encodePNG(const torch::Tensor& data, int64_t compression_level) {
  // Validate everything up front, before any libpng state exists, so a
  // failing check has nothing to clean up.
  // Check that the compression level is between 0 and 9
  TORCH_CHECK(
      compression_level >= 0 && compression_level <= 9,
      "Compression level should be between 0 and 9");
  // Check that the input tensor is on CPU
  TORCH_CHECK(data.device() == torch::kCPU, "Input tensor should be on CPU");
  // Check that the input tensor dtype is uint8
  TORCH_CHECK(data.dtype() == torch::kU8, "Input tensor dtype should be uint8");
  // Check that the input tensor is 3-dimensional
  TORCH_CHECK(data.dim() == 3, "Input data should be a 3-dimensional tensor");

  // Get image info
  const int64_t channels = data.size(0);
  const int64_t height = data.size(1);
  const int64_t width = data.size(2);
  TORCH_CHECK(
      channels == 1 || channels == 3,
      "The number of channels should be 1 or 3, got: ",
      channels);

  // libpng consumes interleaved rows (HWC). The tensor is created before
  // setjmp so that a longjmp back into this frame does not skip its
  // destructor.
  auto input = data.permute({1, 2, 0}).contiguous();

  // Compression structures and error handling. The handles start out NULL so
  // the error path below never inspects an indeterminate pointer.
  png_structp png_write = NULL;
  png_infop info_ptr = NULL;
  struct torch_png_error_mgr err_ptr;

  // Define output buffer
  struct torch_mem_encode buf_info;
  buf_info.buffer = NULL;
  buf_info.size = 0;

  /* Establish the setjmp return context for torch_png_error to use. */
  if (setjmp(err_ptr.setjmp_buffer)) {
    /* If we get here, the PNG code has signaled an error.
     * We need to clean up the PNG object and the buffer.
     */
    if (png_write != NULL) {
      // Also releases info_ptr when it is non-NULL.
      png_destroy_write_struct(&png_write, &info_ptr);
    }
    if (buf_info.buffer != NULL) {
      free(buf_info.buffer);
    }
    TORCH_CHECK(false, err_ptr.pngLastErrorMsg);
  }

  // Initialize PNG structures
  png_write = png_create_write_struct(
      PNG_LIBPNG_VER_STRING, &err_ptr, torch_png_error, NULL);
  TORCH_CHECK(
      png_write != NULL, "encodePNG: could not create libpng write struct");

  info_ptr = png_create_info_struct(png_write);
  if (info_ptr == NULL) {
    png_destroy_write_struct(&png_write, NULL);
    TORCH_CHECK(false, "encodePNG: could not create libpng info struct");
  }

  // Define custom buffer output
  png_set_write_fn(png_write, &buf_info, torch_png_write_data, NULL);

  // Set output image information. The condition must come first in the
  // ternary: PNG_COLOR_TYPE_GRAY is itself a constant, not a predicate.
  const int color_type =
      channels == 1 ? PNG_COLOR_TYPE_GRAY : PNG_COLOR_TYPE_RGB;
  png_set_IHDR(
      png_write,
      info_ptr,
      static_cast<png_uint_32>(width),
      static_cast<png_uint_32>(height),
      8,
      color_type,
      PNG_INTERLACE_NONE,
      PNG_COMPRESSION_TYPE_DEFAULT,
      PNG_FILTER_TYPE_DEFAULT);

  // Set image compression level
  png_set_compression_level(png_write, static_cast<int>(compression_level));

  // Write file header
  png_write_info(png_write, info_ptr);

  // Bytes per interleaved row; 64-bit so wide images cannot overflow.
  const int64_t stride = width * channels;
  auto ptr = input.data_ptr<uint8_t>();

  // Encode PNG file
  for (int64_t y = 0; y < height; ++y) {
    png_write_row(png_write, ptr);
    ptr += stride;
  }

  // Write EOF
  png_write_end(png_write, info_ptr);

  // Destroy structures
  png_destroy_write_struct(&png_write, &info_ptr);

  torch::TensorOptions options = torch::TensorOptions{torch::kU8};
  auto outTensor =
      torch::empty({static_cast<int64_t>(buf_info.size)}, options);

  // Copy memory from png buffer, since torch cannot get ownership of it via
  // `from_blob`
  auto outPtr = outTensor.data_ptr<uint8_t>();
  std::memcpy(outPtr, buf_info.buffer, sizeof(uint8_t) * outTensor.numel());
  free(buf_info.buffer);

  return outTensor;
}
/* Encode `data` as PNG (see encodePNG for the accepted tensor layout and
 * compression_level range) and write the result to `filename`, replacing any
 * existing file. Fails via TORCH_CHECK when encoding fails, when the file
 * cannot be opened, or when the bytes cannot be fully written and flushed. */
void writePNG(
    const torch::Tensor& data,
    std::string filename,
    int64_t compression_level) {
  // Encode first so no file is created when the input is invalid.
  auto pngBuf = encodePNG(data, compression_level);
  auto fileBytes = pngBuf.data_ptr<uint8_t>();
  const size_t numBytes = static_cast<size_t>(pngBuf.numel());

  FILE* outfile = fopen(filename.c_str(), "wb");
  TORCH_CHECK(outfile != NULL, "Error opening output png file");

  const size_t written =
      fwrite(fileBytes, sizeof(uint8_t), numBytes, outfile);
  // Always close before reporting, so the handle is not leaked on a short
  // write; fclose can itself fail when flushing buffered data.
  const int closeStatus = fclose(outfile);

  TORCH_CHECK(written == numBytes, "Error writing output png file");
  TORCH_CHECK(closeStatus == 0, "Error closing output png file");
}
#endif
#pragma once
#include <torch/torch.h>
// Encodes an image tensor as PNG and returns the encoded bytes as a tensor.
// `compression_level` is the PNG compression level requested by the caller.
C10_EXPORT torch::Tensor encodePNG(
const torch::Tensor& data,
int64_t compression_level);
// Encodes an image tensor as PNG and writes it to the file at `filename`.
// `compression_level` is the PNG compression level requested by the caller.
C10_EXPORT void writePNG(
const torch::Tensor& data,
std::string filename,
int64_t compression_level);
...@@ -65,6 +65,37 @@ def read_png(path: str) -> torch.Tensor: ...@@ -65,6 +65,37 @@ def read_png(path: str) -> torch.Tensor:
return decode_png(data) return decode_png(data)
def encode_png(input: torch.Tensor, compression_level: int = 6) -> torch.Tensor:
    """
    Takes an input tensor in CHW layout and returns a buffer with the contents
    of its corresponding PNG file.

    Arguments:
        input (Tensor[channels, image_height, image_width]): uint8 image tensor
            of `c` channels, where `c` must be 1 or 3.
        compression_level (int): Compression factor for the resulting file, it
            must be a number between 0 and 9. Default: 6

    Returns:
        output (Tensor[num_bytes]): A one dimensional uint8 tensor that
            contains the raw bytes of the PNG file.
    """
    return torch.ops.image.encode_png(input, compression_level)
def write_png(input: torch.Tensor, filename: str, compression_level: int = 6):
    """
    Takes an input tensor in CHW layout and saves it in a PNG file.

    Arguments:
        input (Tensor[channels, image_height, image_width]): uint8 image tensor
            of `c` channels, where `c` must be 1 or 3. The tensor must be
            3-dimensional, so a grayscale image is passed as `1 x H x W`.
        filename (str): Path to save the image.
        compression_level (int): Compression factor for the resulting file, it
            must be a number between 0 and 9. Default: 6
    """
    torch.ops.image.write_png(input, filename, compression_level)
def decode_jpeg(input: torch.Tensor) -> torch.Tensor: def decode_jpeg(input: torch.Tensor) -> torch.Tensor:
""" """
Decodes a JPEG image into a 3 dimensional RGB Tensor. Decodes a JPEG image into a 3 dimensional RGB Tensor.
...@@ -94,8 +125,8 @@ def read_jpeg(path: str) -> torch.Tensor: ...@@ -94,8 +125,8 @@ def read_jpeg(path: str) -> torch.Tensor:
def encode_jpeg(input: torch.Tensor, quality: int = 75) -> torch.Tensor: def encode_jpeg(input: torch.Tensor, quality: int = 75) -> torch.Tensor:
""" """
Takes an input tensor in CHW layout (or HW in the case of grayscale images) Takes an input tensor in CHW layout and returns a buffer with the contents
and returns a buffer with the contents of its corresponding JPEG file. of its corresponding JPEG file.
Arguments: Arguments:
input (Tensor[channels, image_height, image_width]): int8 image tensor input (Tensor[channels, image_height, image_width]): int8 image tensor
of `c` channels, where `c` must be 1 or 3. of `c` channels, where `c` must be 1 or 3.
...@@ -115,8 +146,7 @@ def encode_jpeg(input: torch.Tensor, quality: int = 75) -> torch.Tensor: ...@@ -115,8 +146,7 @@ def encode_jpeg(input: torch.Tensor, quality: int = 75) -> torch.Tensor:
def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75): def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75):
""" """
Takes an input tensor in CHW layout (or HW in the case of grayscale images) Takes an input tensor in CHW layout and saves it in a JPEG file.
and saves it in a JPEG file.
Arguments: Arguments:
input (Tensor[channels, image_height, image_width]): int8 image tensor input (Tensor[channels, image_height, image_width]): int8 image tensor
of `c` channels, where `c` must be 1 or 3. of `c` channels, where `c` must be 1 or 3.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment