Unverified commit 662373f6, authored by Edgar Andrés Margffoy Tuay, committed by GitHub
Browse files

Add encoding and writing JPEG ops (#2696)



* Add decode and write JPEG ops

* Fix styling issues

* Use int64_t instead of int

* Use std::string

* Use jpegcommon.h for read_jpeg

* Minor updates to error handling in read

* Include header only once

* Reverse header inclusion

* Update common header

* Add common definitions

* Include string

* Include header?

* Include header?

* Add Python frontend calls

* Use unsigned long directly

* Fix style issues

* Include cstddef

* Ignore clang-format on cstddef

* Also include stdio

* Add JPEG and PNG include dirs

* Use C10_EXPORT

* Add JPEG encoding test

* Set quality to 75 by default and add write jpeg test

* Minor error correction

* Replace assertEquals with assertEqual

* Remove test results

* Use pre-saved PIL output

* Remove extra PIL call

* Use read_jpeg instead of PIL

* Add error tests

* Address review comments

* Fix style issues

* Set test case to uint8

* Update test error check

* Apply suggestions from code review

* Fix clang-format

* Fix lint

* Fix test

* Remove unused file

* Fix regex error message

* Fix tests
Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
parent 6e10e3f8
......@@ -33,7 +33,7 @@ add_library(${PROJECT_NAME} SHARED ${MODELS_SOURCES} ${OPERATOR_SOURCES} ${IMAGE
target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_LIBRARIES} ${PNG_LIBRARY} ${JPEG_LIBRARIES} Python3::Python)
set_target_properties(${PROJECT_NAME} PROPERTIES EXPORT_NAME TorchVision)
include_directories(torchvision/csrc)
include_directories(torchvision/csrc ${JPEG_INCLUDE_DIRS} ${PNG_INCLUDE_DIRS})
include(GNUInstallDirs)
include(CMakePackageConfigHelpers)
......
import os
import io
import glob
import unittest
import sys
......@@ -6,7 +7,8 @@ import sys
import torch
import torchvision
from PIL import Image
from torchvision.io.image import read_png, decode_png, read_jpeg, decode_jpeg
from torchvision.io.image import (
read_png, decode_png, read_jpeg, decode_jpeg, encode_jpeg, write_jpeg)
import numpy as np
IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
......@@ -17,7 +19,7 @@ DAMAGED_JPEG = os.path.join(IMAGE_ROOT, 'damaged_jpeg')
def get_images(directory, img_ext):
assert os.path.isdir(directory)
for root, _, files in os.walk(directory):
if os.path.basename(root) == 'damaged_jpeg':
if os.path.basename(root) in {'damaged_jpeg', 'jpeg_write'}:
continue
for fl in files:
......@@ -66,6 +68,71 @@ class ImageTester(unittest.TestCase):
with self.assertRaises(RuntimeError):
read_jpeg(image_path)
def test_encode_jpeg(self):
    """Check ``encode_jpeg`` against stored PIL output and invalid inputs.

    Every ``.jpg`` asset is decoded, re-encoded at quality 75 and compared
    byte-for-byte with the matching ``jpeg_write/<name>_pil.jpg`` reference.
    Afterwards a set of malformed calls is checked for the expected errors.
    """
    for img_path in get_images(IMAGE_ROOT, ".jpg"):
        parent_dir, base_name = os.path.split(img_path)
        stem = os.path.splitext(base_name)[0]
        reference_path = os.path.join(
            os.path.join(parent_dir, 'jpeg_write'),
            '{0}_pil.jpg'.format(stem))

        decoded = read_jpeg(img_path)
        with open(reference_path, 'rb') as handle:
            raw = handle.read()
        expected = torch.as_tensor(list(raw), dtype=torch.uint8)

        # Exercise both the tensor as decoded and an explicitly contiguous
        # version of it.
        for candidate in (decoded, decoded.contiguous()):
            # PIL sets jpeg quality to 75 by default
            encoded = encode_jpeg(candidate, quality=75)
            self.assertTrue(encoded.equal(expected))

    quality_pattern = ("Image quality should be a positive number "
                       "between 1 and 100")
    # (exception type, message regex, tensor shape, dtype, extra kwargs)
    rejected_calls = [
        (RuntimeError, "Input tensor dtype should be uint8",
         (3, 100, 100), torch.float32, {}),
        (ValueError, quality_pattern,
         (3, 100, 100), torch.uint8, {'quality': -1}),
        (ValueError, quality_pattern,
         (3, 100, 100), torch.uint8, {'quality': 101}),
        (RuntimeError, "The number of channels should be 1 or 3, got: 5",
         (5, 100, 100), torch.uint8, {}),
        (RuntimeError, "Input data should be a 3-dimensional tensor",
         (1, 3, 100, 100), torch.uint8, {}),
        (RuntimeError, "Input data should be a 3-dimensional tensor",
         (100, 100), torch.uint8, {}),
    ]
    for exc_type, pattern, shape, dtype, extra in rejected_calls:
        with self.assertRaisesRegex(exc_type, pattern):
            encode_jpeg(torch.empty(shape, dtype=dtype), **extra)
def test_write_jpeg(self):
    """Check that ``write_jpeg`` reproduces the stored PIL reference files.

    Each ``.jpg`` asset is decoded and written back at quality 75 to a
    temporary ``<name>_torch.jpg`` next to the source image; the bytes on
    disk must equal those of ``jpeg_write/<name>_pil.jpg``.
    """
    for img_path in get_images(IMAGE_ROOT, ".jpg"):
        img = read_jpeg(img_path)

        basedir = os.path.dirname(img_path)
        filename, _ = os.path.splitext(os.path.basename(img_path))
        torch_jpeg = os.path.join(
            basedir, '{0}_torch.jpg'.format(filename))
        pil_jpeg = os.path.join(
            basedir, 'jpeg_write', '{0}_pil.jpg'.format(filename))

        # The output lands inside the assets tree, so always clean it up,
        # even when writing or reading back fails; otherwise a stray
        # '*_torch.jpg' would be left among the test images.
        try:
            write_jpeg(img, torch_jpeg, quality=75)
            with open(torch_jpeg, 'rb') as f:
                torch_bytes = f.read()
        finally:
            if os.path.exists(torch_jpeg):
                os.remove(torch_jpeg)

        with open(pil_jpeg, 'rb') as f:
            pil_bytes = f.read()

        self.assertEqual(torch_bytes, pil_bytes)
def test_read_png(self):
# Check across .png
for img_path in get_images(IMAGE_DIR, ".png"):
......
......@@ -14,4 +14,6 @@ PyMODINIT_FUNC PyInit_image(void) {
static auto registry = torch::RegisterOperators()
.op("image::decode_png", &decodePNG)
.op("image::decode_jpeg", &decodeJPEG);
.op("image::decode_jpeg", &decodeJPEG)
.op("image::encode_jpeg", &encodeJPEG)
.op("image::write_jpeg", &writeJPEG);
......@@ -6,3 +6,4 @@
#include <torch/torch.h>
#include "readjpeg_cpu.h"
#include "readpng_cpu.h"
#include "writejpeg_cpu.h"
#include "jpegcommon.h"
#include <string>
void torch_jpeg_error_exit(j_common_ptr cinfo) {
  // libjpeg only knows about the embedded jpeg_error_mgr; recover the
  // enclosing torch_jpeg_error_mgr so its message buffer and jump target
  // can be reached.
  auto* errorMgr = (torch_jpeg_error_ptr)cinfo->err;

  // Format the pending libjpeg error into the per-call message buffer
  // instead of printing it (the default output_message hook is not used).
  cinfo->err->format_message(cinfo, errorMgr->jpegLastErrorMsg);

  // Hand control back to the setjmp site of the caller.
  longjmp(errorMgr->setjmp_buffer, 1);
}
#pragma once
// clang-format off
#include <cstdio>
#include <cstddef>
// clang-format on
#include <jpeglib.h>
#include <setjmp.h>
#include <string>
static const JOCTET EOI_BUFFER[1] = {JPEG_EOI};
// Error manager passed to libjpeg. It extends the stock jpeg_error_mgr with a
// buffer for the formatted error text and the jump target that
// torch_jpeg_error_exit returns to.
struct torch_jpeg_error_mgr {
  // Must remain the first member: callers register `&pub` with libjpeg and
  // torch_jpeg_error_exit casts that pointer back to torch_jpeg_error_mgr*.
  struct jpeg_error_mgr pub; /* "public" fields */
  // Filled by torch_jpeg_error_exit via format_message before it jumps.
  char jpegLastErrorMsg[JMSG_LENGTH_MAX]; /* error messages */
  // Set with setjmp by the encode/decode entry point; longjmp target on error.
  jmp_buf setjmp_buffer; /* for return to caller */
};
typedef struct torch_jpeg_error_mgr* torch_jpeg_error_ptr;
void torch_jpeg_error_exit(j_common_ptr cinfo);
......@@ -7,36 +7,13 @@
#if !JPEG_FOUND
torch::Tensor decodeJPEG(const torch::Tensor& data) {
AT_ERROR("decodeJPEG: torchvision not compiled with libjpeg support");
TORCH_CHECK(
false, "decodeJPEG: torchvision not compiled with libjpeg support");
}
#else
#include <jpeglib.h>
const static JOCTET EOI_BUFFER[1] = {JPEG_EOI};
char jpegLastErrorMsg[JMSG_LENGTH_MAX];
struct torch_jpeg_error_mgr {
struct jpeg_error_mgr pub; /* "public" fields */
jmp_buf setjmp_buffer; /* for return to caller */
};
typedef struct torch_jpeg_error_mgr* torch_jpeg_error_ptr;
void torch_jpeg_error_exit(j_common_ptr cinfo) {
/* cinfo->err really points to a torch_jpeg_error_mgr struct, so coerce
* pointer */
torch_jpeg_error_ptr myerr = (torch_jpeg_error_ptr)cinfo->err;
/* Always display the message. */
/* We could postpone this until after returning, if we chose. */
// (*cinfo->err->output_message)(cinfo);
/* Create the message */
(*(cinfo->err->format_message))(cinfo, jpegLastErrorMsg);
/* Return control to the setjmp point */
longjmp(myerr->setjmp_buffer, 1);
}
#include "jpegcommon.h"
struct torch_jpeg_mgr {
struct jpeg_source_mgr pub;
......@@ -50,7 +27,7 @@ static boolean torch_jpeg_fill_input_buffer(j_decompress_ptr cinfo) {
torch_jpeg_mgr* src = (torch_jpeg_mgr*)cinfo->src;
// No more data. Probably an incomplete image; Raise exception.
torch_jpeg_error_ptr myerr = (torch_jpeg_error_ptr)cinfo->err;
strcpy(jpegLastErrorMsg, "Image is incomplete or truncated");
strcpy(myerr->jpegLastErrorMsg, "Image is incomplete or truncated");
longjmp(myerr->setjmp_buffer, 1);
src->pub.next_input_byte = EOI_BUFFER;
src->pub.bytes_in_buffer = 1;
......@@ -108,7 +85,7 @@ torch::Tensor decodeJPEG(const torch::Tensor& data) {
* We need to clean up the JPEG object.
*/
jpeg_destroy_decompress(&cinfo);
AT_ERROR(jpegLastErrorMsg);
TORCH_CHECK(false, jerr.jpegLastErrorMsg);
}
jpeg_create_decompress(&cinfo);
......
......@@ -2,4 +2,4 @@
#include <torch/torch.h>
torch::Tensor decodeJPEG(const torch::Tensor& data);
C10_EXPORT torch::Tensor decodeJPEG(const torch::Tensor& data);
......@@ -4,4 +4,4 @@
#include <torch/torch.h>
#include <string>
torch::Tensor decodePNG(const torch::Tensor& data);
C10_EXPORT torch::Tensor decodePNG(const torch::Tensor& data);
#include "writejpeg_cpu.h"
#include <setjmp.h>
#include <string>
#if !JPEG_FOUND
// Fallback used when the build has no libjpeg (JPEG_FOUND is false): the op is
// still defined so it can be registered, but every call fails with an error.
torch::Tensor encodeJPEG(const torch::Tensor& data, int64_t quality) {
  TORCH_CHECK(
      false, "encodeJPEG: torchvision not compiled with libjpeg support");
}
// Fallback used when the build has no libjpeg (JPEG_FOUND is false): the op is
// still defined so it can be registered, but every call fails with an error.
void writeJPEG(
    const torch::Tensor& data,
    std::string filename,
    int64_t quality) {
  TORCH_CHECK(
      false, "writeJPEG: torchvision not compiled with libjpeg support");
}
#else
#include <jpeglib.h>
#include "jpegcommon.h"
// Encode a CHW uint8 CPU tensor (1 or 3 channels) as a JPEG file held in
// memory.
//
// data:    3-dimensional uint8 tensor laid out as [channels, height, width].
// quality: JPEG quality; values outside 1..100 are clamped to that range,
//          which is what jpeg_set_quality does with out-of-range values.
//
// Returns a 1-dimensional uint8 tensor with the raw bytes of the JPEG file.
// Raises (via TORCH_CHECK) on invalid input or on any error reported by
// libjpeg.
torch::Tensor encodeJPEG(const torch::Tensor& data, int64_t quality) {
  // All validation and every object with a destructor is set up BEFORE
  // setjmp. libjpeg reports errors with longjmp, and jumping over the
  // construction of a C++ object (such as the `input` tensor) would skip its
  // destructor.

  // Check that the input tensor is on CPU
  TORCH_CHECK(data.device() == torch::kCPU, "Input tensor should be on CPU");

  // Check that the input tensor dtype is uint8
  TORCH_CHECK(data.dtype() == torch::kU8, "Input tensor dtype should be uint8");

  // Check that the input tensor is 3-dimensional
  TORCH_CHECK(data.dim() == 3, "Input data should be a 3-dimensional tensor");

  // Get image info (kept as int64_t, the type returned by Tensor::size)
  const int64_t channels = data.size(0);
  const int64_t height = data.size(1);
  const int64_t width = data.size(2);

  TORCH_CHECK(
      channels == 1 || channels == 3,
      "The number of channels should be 1 or 3, got: ",
      channels);

  // libjpeg consumes interleaved rows, so convert CHW -> HWC.
  auto input = data.permute({1, 2, 0}).contiguous();
  const int64_t stride = width * channels;
  auto ptr = input.data_ptr<uint8_t>();

  // Clamp explicitly so the narrowing to int cannot wrap around.
  const int jpegQuality =
      quality < 1 ? 1 : (quality > 100 ? 100 : static_cast<int>(quality));

  // Define compression structures and error handling. Both are
  // zero-initialized so that jpeg_destroy_compress in the error path is safe
  // even if jpeg_create_compress itself fails.
  struct jpeg_compress_struct cinfo {};
  struct torch_jpeg_error_mgr jerr {};

  // Define buffer to write JPEG information to and its size. libjpeg
  // allocates the buffer with malloc; we own it afterwards.
  unsigned long jpegSize = 0;
  uint8_t* jpegBuf = NULL;

  cinfo.err = jpeg_std_error(&jerr.pub);
  jerr.pub.error_exit = torch_jpeg_error_exit;

  /* Establish the setjmp return context for torch_jpeg_error_exit to use. */
  if (setjmp(jerr.setjmp_buffer)) {
    /* If we get here, the JPEG code has signaled an error.
     * We need to clean up the JPEG object and the buffer.
     */
    jpeg_destroy_compress(&cinfo);
    if (jpegBuf != NULL) {
      free(jpegBuf);
    }

    TORCH_CHECK(false, (const char*)jerr.jpegLastErrorMsg);
  }

  // Initialize JPEG structure
  jpeg_create_compress(&cinfo);

  // Set output image information
  cinfo.image_width = static_cast<JDIMENSION>(width);
  cinfo.image_height = static_cast<JDIMENSION>(height);
  cinfo.input_components = static_cast<int>(channels);
  cinfo.in_color_space = channels == 1 ? JCS_GRAYSCALE : JCS_RGB;

  jpeg_set_defaults(&cinfo);
  jpeg_set_quality(&cinfo, jpegQuality, TRUE);

  // Save JPEG output to a buffer
  jpeg_mem_dest(&cinfo, &jpegBuf, &jpegSize);

  // Start JPEG compression
  jpeg_start_compress(&cinfo, TRUE);

  // Encode JPEG file, one scanline per call
  while (cinfo.next_scanline < cinfo.image_height) {
    jpeg_write_scanlines(&cinfo, &ptr, 1);
    ptr += stride;
  }

  jpeg_finish_compress(&cinfo);
  jpeg_destroy_compress(&cinfo);

  torch::TensorOptions options = torch::TensorOptions{torch::kU8};
  torch::Tensor outTensor;
  try {
    // int64_t rather than long: long is only 32 bits wide on LLP64 platforms.
    outTensor = torch::empty({static_cast<int64_t>(jpegSize)}, options);
  } catch (...) {
    // Do not leak the libjpeg-allocated buffer if the allocation fails.
    free(jpegBuf);
    throw;
  }

  // Copy memory from jpeg buffer, since torch cannot get ownership of it via
  // `from_blob`
  auto outPtr = outTensor.data_ptr<uint8_t>();
  std::memcpy(outPtr, jpegBuf, sizeof(uint8_t) * outTensor.numel());

  free(jpegBuf);

  return outTensor;
}
// Encode `data` as JPEG (see encodeJPEG) and write the bytes to `filename`.
//
// data:     3-dimensional uint8 CPU tensor in [channels, height, width] layout.
// filename: path of the file to create or overwrite.
// quality:  JPEG quality forwarded to encodeJPEG.
//
// Raises (via TORCH_CHECK) if encoding fails, if the file cannot be opened,
// or if the bytes cannot be written completely.
void writeJPEG(
    const torch::Tensor& data,
    std::string filename,
    int64_t quality) {
  auto jpegBuf = encodeJPEG(data, quality);
  auto fileBytes = jpegBuf.data_ptr<uint8_t>();
  auto fileCStr = filename.c_str();
  FILE* outfile = fopen(fileCStr, "wb");

  TORCH_CHECK(outfile != NULL, "Error opening output jpeg file");

  const size_t numBytes = static_cast<size_t>(jpegBuf.numel());
  const size_t written = fwrite(fileBytes, sizeof(uint8_t), numBytes, outfile);
  // Close before checking so the handle is released on the error path too;
  // fclose can also report a failed flush of buffered data.
  const int closeStatus = fclose(outfile);

  TORCH_CHECK(
      written == numBytes && closeStatus == 0,
      "Error writing output jpeg file");
}
#endif
#pragma once
#include <torch/torch.h>
C10_EXPORT torch::Tensor encodeJPEG(const torch::Tensor& data, int64_t quality);
C10_EXPORT void writeJPEG(
const torch::Tensor& data,
std::string filename,
int64_t quality);
......@@ -102,3 +102,42 @@ def read_jpeg(path: str) -> torch.Tensor:
raise ValueError("Expected a non empty file.")
data = torch.from_file(path, dtype=torch.uint8, size=size)
return decode_jpeg(data)
def encode_jpeg(input: torch.Tensor, quality: int = 75) -> torch.Tensor:
    """
    Takes an input tensor in CHW layout and returns a buffer with the contents
    of its corresponding JPEG file.

    Arguments:
        input (Tensor[channels, image_height, image_width]): uint8 image
            tensor of `c` channels, where `c` must be 1 or 3. The tensor must
            be 3-dimensional, so a grayscale image is passed with shape
            [1, image_height, image_width].
        quality (int): Quality of the resulting JPEG file, it must be a number
            between 1 and 100. Default: 75

    Returns:
        output (Tensor[N]): A one dimensional uint8 tensor that contains the
            raw bytes of the JPEG file.

    Raises:
        ValueError: If `quality` is not between 1 and 100.
        RuntimeError: If `input` is not a 3-dimensional uint8 tensor with 1 or
            3 channels.
    """
    if quality < 1 or quality > 100:
        raise ValueError('Image quality should be a positive number '
                         'between 1 and 100')

    output = torch.ops.image.encode_jpeg(input, quality)
    return output
def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75):
    """
    Takes an input tensor in CHW layout and saves it in a JPEG file.

    Arguments:
        input (Tensor[channels, image_height, image_width]): uint8 image
            tensor of `c` channels, where `c` must be 1 or 3. The tensor must
            be 3-dimensional, so a grayscale image is passed with shape
            [1, image_height, image_width].
        filename (str): Path to save the image.
        quality (int): Quality of the resulting JPEG file, it must be a number
            between 1 and 100. Default: 75

    Raises:
        ValueError: If `quality` is not between 1 and 100.
    """
    if quality < 1 or quality > 100:
        raise ValueError('Image quality should be a positive number '
                         'between 1 and 100')

    torch.ops.image.write_jpeg(input, filename, quality)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment