Unverified Commit a884cb7b authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Add support of mode and remove channels (#3024)

* Add support of mode and remove channels.

* Replacing integer mode with define constants.
parent 1706921b
......@@ -2,14 +2,12 @@ import os
import io
import glob
import unittest
import sys
import torch
import torchvision
from PIL import Image
from torchvision.io.image import (
decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file,
encode_png, write_png, write_file)
encode_png, write_png, write_file, ImageReadMode)
import numpy as np
from common_utils import get_tmp_dir
......@@ -49,9 +47,9 @@ def normalize_dimensions(img_pil):
class ImageTester(unittest.TestCase):
def test_decode_jpeg(self):
conversion = [(None, 0), ("L", 1), ("RGB", 3)]
conversion = [(None, ImageReadMode.UNCHANGED), ("L", ImageReadMode.GRAY), ("RGB", ImageReadMode.RGB)]
for img_path in get_images(IMAGE_ROOT, ".jpg"):
for pil_mode, channels in conversion:
for pil_mode, mode in conversion:
with Image.open(img_path) as img:
is_cmyk = img.mode == "CMYK"
if pil_mode is not None:
......@@ -66,7 +64,7 @@ class ImageTester(unittest.TestCase):
img_pil = normalize_dimensions(img_pil)
data = read_file(img_path)
img_ljpeg = decode_image(data, channels=channels)
img_ljpeg = decode_image(data, mode=mode)
# Permit a small variation on pixel values to account for implementation
# differences between Pillow and LibJPEG.
......@@ -165,9 +163,10 @@ class ImageTester(unittest.TestCase):
self.assertEqual(torch_bytes, pil_bytes)
def test_decode_png(self):
conversion = [(None, 0), ("L", 1), ("LA", 2), ("RGB", 3), ("RGBA", 4)]
conversion = [(None, ImageReadMode.UNCHANGED), ("L", ImageReadMode.GRAY), ("LA", ImageReadMode.GRAY_ALPHA),
("RGB", ImageReadMode.RGB), ("RGBA", ImageReadMode.RGB_ALPHA)]
for img_path in get_images(FAKEDATA_DIR, ".png"):
for pil_mode, channels in conversion:
for pil_mode, mode in conversion:
with Image.open(img_path) as img:
if pil_mode is not None:
img = img.convert(pil_mode)
......@@ -175,7 +174,7 @@ class ImageTester(unittest.TestCase):
img_pil = normalize_dimensions(img_pil)
data = read_file(img_path)
img_lpng = decode_image(data, channels=channels)
img_lpng = decode_image(data, mode=mode)
tol = 0 if conversion is None else 1
self.assertTrue(img_lpng.allclose(img_pil, atol=tol))
......
#pragma once

// int64_t is used below; include its declaring header so this file is
// self-contained and does not depend on the includer's include order.
#include <stdint.h>

/*
 * Read mode requested from the image decoders (decode_image, decodeJPEG,
 * decodePNG). It is a plain 64-bit integer so that it can be passed through
 * the operator interface from Python.
 *
 * Should be kept in-sync with Python ImageReadMode enum
 */
using ImageReadMode = int64_t;

/* Keep the image as stored in the file; no conversion is requested. */
#define IMAGE_READ_MODE_UNCHANGED 0
/* Convert to grayscale (1 output channel). */
#define IMAGE_READ_MODE_GRAY 1
/* Convert to grayscale + alpha (2 output channels); handled by the PNG decoder only. */
#define IMAGE_READ_MODE_GRAY_ALPHA 2
/* Convert to RGB (3 output channels). */
#define IMAGE_READ_MODE_RGB 3
/* Convert to RGB + alpha (4 output channels); handled by the PNG decoder only. */
#define IMAGE_READ_MODE_RGB_ALPHA 4
\ No newline at end of file
#include "read_image_cpu.h"
#include <cstring>
#include "readjpeg_cpu.h"
#include "readpng_cpu.h"
torch::Tensor decode_image(const torch::Tensor& data, int64_t channels) {
torch::Tensor decode_image(const torch::Tensor& data, ImageReadMode mode) {
// Check that the input tensor dtype is uint8
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
// Check that the input tensor is 1-dimensional
TORCH_CHECK(
data.dim() == 1 && data.numel() > 0,
"Expected a non empty 1-dimensional tensor");
TORCH_CHECK(
channels >= 0 && channels <= 4, "Number of channels not supported");
auto datap = data.data_ptr<uint8_t>();
......@@ -17,9 +16,9 @@ torch::Tensor decode_image(const torch::Tensor& data, int64_t channels) {
const uint8_t png_signature[4] = {137, 80, 78, 71}; // == "\211PNG"
if (memcmp(jpeg_signature, datap, 3) == 0) {
return decodeJPEG(data, channels);
return decodeJPEG(data, mode);
} else if (memcmp(png_signature, datap, 4) == 0) {
return decodePNG(data, channels);
return decodePNG(data, mode);
} else {
TORCH_CHECK(
false,
......
#pragma once
#include "readjpeg_cpu.h"
#include "readpng_cpu.h"
#include <torch/torch.h>
#include "image_read_mode.h"
C10_EXPORT torch::Tensor decode_image(
const torch::Tensor& data,
int64_t channels = 0);
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED);
#include "readjpeg_cpu.h"
#include <ATen/ATen.h>
#include <string>
#if !JPEG_FOUND
torch::Tensor decodeJPEG(const torch::Tensor& data, int64_t channels) {
torch::Tensor decodeJPEG(const torch::Tensor& data, ImageReadMode mode) {
TORCH_CHECK(
false, "decodeJPEG: torchvision not compiled with libjpeg support");
}
......@@ -69,16 +68,13 @@ static void torch_jpeg_set_source_mgr(
src->pub.next_input_byte = src->data;
}
torch::Tensor decodeJPEG(const torch::Tensor& data, int64_t channels) {
torch::Tensor decodeJPEG(const torch::Tensor& data, ImageReadMode mode) {
// Check that the input tensor dtype is uint8
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
// Check that the input tensor is 1-dimensional
TORCH_CHECK(
data.dim() == 1 && data.numel() > 0,
"Expected a non empty 1-dimensional tensor");
TORCH_CHECK(
channels == 0 || channels == 1 || channels == 3,
"Number of channels not supported");
struct jpeg_decompress_struct cinfo;
struct torch_jpeg_error_mgr jerr;
......@@ -102,30 +98,33 @@ torch::Tensor decodeJPEG(const torch::Tensor& data, int64_t channels) {
// read info from header.
jpeg_read_header(&cinfo, TRUE);
int current_channels = cinfo.num_components;
int channels = cinfo.num_components;
if (channels > 0 && channels != current_channels) {
switch (channels) {
case 1: // Gray
cinfo.out_color_space = JCS_GRAYSCALE;
if (mode != IMAGE_READ_MODE_UNCHANGED) {
switch (mode) {
case IMAGE_READ_MODE_GRAY:
if (cinfo.jpeg_color_space != JCS_GRAYSCALE) {
cinfo.out_color_space = JCS_GRAYSCALE;
channels = 1;
}
break;
case 3: // RGB
cinfo.out_color_space = JCS_RGB;
case IMAGE_READ_MODE_RGB:
if (cinfo.jpeg_color_space != JCS_RGB) {
cinfo.out_color_space = JCS_RGB;
channels = 3;
}
break;
/*
* Libjpeg does not support converting from CMYK to grayscale etc. There
* is a way to do this but it involves converting it manually to RGB:
* https://github.com/tensorflow/tensorflow/blob/86871065265b04e0db8ca360c046421efb2bdeb4/tensorflow/core/lib/jpeg/jpeg_mem.cc#L284-L313
*
*/
default:
jpeg_destroy_decompress(&cinfo);
TORCH_CHECK(false, "Invalid number of output channels.");
TORCH_CHECK(false, "Provided mode not supported");
}
jpeg_calc_output_dimensions(&cinfo);
} else {
channels = current_channels;
}
jpeg_start_decompress(&cinfo);
......
#pragma once
#include <torch/torch.h>
#include "image_read_mode.h"
C10_EXPORT torch::Tensor decodeJPEG(
const torch::Tensor& data,
int64_t channels = 0);
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED);
#include "readpng_cpu.h"
// Comment
#include <ATen/ATen.h>
#include <string>
#if !PNG_FOUND
torch::Tensor decodePNG(const torch::Tensor& data, int64_t channels) {
torch::Tensor decodePNG(const torch::Tensor& data, ImageReadMode mode) {
TORCH_CHECK(false, "decodePNG: torchvision not compiled with libPNG support");
}
#else
#include <png.h>
#include <setjmp.h>
torch::Tensor decodePNG(const torch::Tensor& data, int64_t channels) {
torch::Tensor decodePNG(const torch::Tensor& data, ImageReadMode mode) {
// Check that the input tensor dtype is uint8
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
// Check that the input tensor is 1-dimensional
TORCH_CHECK(
data.dim() == 1 && data.numel() > 0,
"Expected a non empty 1-dimensional tensor");
TORCH_CHECK(
channels >= 0 && channels <= 4, "Number of channels not supported");
auto png_ptr =
png_create_read_struct(PNG_LIBPNG_VER_STRING, nullptr, nullptr, nullptr);
......@@ -74,75 +70,85 @@ torch::Tensor decodePNG(const torch::Tensor& data, int64_t channels) {
TORCH_CHECK(retval == 1, "Could read image metadata from content.")
}
int current_channels = png_get_channels(png_ptr, info_ptr);
int channels = png_get_channels(png_ptr, info_ptr);
if (channels > 0) {
if (mode != IMAGE_READ_MODE_UNCHANGED) {
// TODO: consider supporting PNG_INFO_tRNS
bool is_palette = (color_type & PNG_COLOR_MASK_PALETTE) != 0;
bool has_color = (color_type & PNG_COLOR_MASK_COLOR) != 0;
bool has_alpha = (color_type & PNG_COLOR_MASK_ALPHA) != 0;
switch (channels) {
case 1: // Gray
if (is_palette) {
png_set_palette_to_rgb(png_ptr);
has_alpha = true;
}
if (has_alpha) {
png_set_strip_alpha(png_ptr);
}
if (has_color) {
png_set_rgb_to_gray(png_ptr, 1, 0.2989, 0.587);
switch (mode) {
case IMAGE_READ_MODE_GRAY:
if (color_type != PNG_COLOR_TYPE_GRAY) {
if (is_palette) {
png_set_palette_to_rgb(png_ptr);
has_alpha = true;
}
if (has_alpha) {
png_set_strip_alpha(png_ptr);
}
if (has_color) {
png_set_rgb_to_gray(png_ptr, 1, 0.2989, 0.587);
}
channels = 1;
}
break;
case 2: // Gray + Alpha
if (is_palette) {
png_set_palette_to_rgb(png_ptr);
has_alpha = true;
}
if (!has_alpha) {
png_set_add_alpha(png_ptr, (1 << bit_depth) - 1, PNG_FILLER_AFTER);
}
if (has_color) {
png_set_rgb_to_gray(png_ptr, 1, 0.2989, 0.587);
case IMAGE_READ_MODE_GRAY_ALPHA:
if (color_type != PNG_COLOR_TYPE_GRAY_ALPHA) {
if (is_palette) {
png_set_palette_to_rgb(png_ptr);
has_alpha = true;
}
if (!has_alpha) {
png_set_add_alpha(png_ptr, (1 << bit_depth) - 1, PNG_FILLER_AFTER);
}
if (has_color) {
png_set_rgb_to_gray(png_ptr, 1, 0.2989, 0.587);
}
channels = 2;
}
break;
case 3:
if (is_palette) {
png_set_palette_to_rgb(png_ptr);
has_alpha = true;
} else if (!has_color) {
png_set_gray_to_rgb(png_ptr);
}
if (has_alpha) {
png_set_strip_alpha(png_ptr);
case IMAGE_READ_MODE_RGB:
if (color_type != PNG_COLOR_TYPE_RGB) {
if (is_palette) {
png_set_palette_to_rgb(png_ptr);
has_alpha = true;
} else if (!has_color) {
png_set_gray_to_rgb(png_ptr);
}
if (has_alpha) {
png_set_strip_alpha(png_ptr);
}
channels = 3;
}
break;
case 4:
if (is_palette) {
png_set_palette_to_rgb(png_ptr);
has_alpha = true;
} else if (!has_color) {
png_set_gray_to_rgb(png_ptr);
}
if (!has_alpha) {
png_set_add_alpha(png_ptr, (1 << bit_depth) - 1, PNG_FILLER_AFTER);
case IMAGE_READ_MODE_RGB_ALPHA:
if (color_type != PNG_COLOR_TYPE_RGB_ALPHA) {
if (is_palette) {
png_set_palette_to_rgb(png_ptr);
has_alpha = true;
} else if (!has_color) {
png_set_gray_to_rgb(png_ptr);
}
if (!has_alpha) {
png_set_add_alpha(png_ptr, (1 << bit_depth) - 1, PNG_FILLER_AFTER);
}
channels = 4;
}
break;
default:
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
TORCH_CHECK(false, "Invalid number of output channels.");
TORCH_CHECK(false, "Provided mode not supported");
}
png_read_update_info(png_ptr, info_ptr);
} else {
channels = current_channels;
}
auto tensor =
......
#pragma once
// Comment
#include <torch/torch.h>
#include <string>
#include "image_read_mode.h"
C10_EXPORT torch::Tensor decodePNG(
const torch::Tensor& data,
int64_t channels = 0);
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED);
......@@ -4,6 +4,8 @@ import os
import os.path as osp
import importlib.machinery
from enum import Enum
_HAS_IMAGE_OPT = False
try:
......@@ -47,6 +49,14 @@ except (ImportError, OSError):
pass
class ImageReadMode(Enum):
    """Read mode requested when decoding an image.

    The decode functions pass ``mode.value`` to the underlying operators, so
    the integer values must stay in sync with the ``IMAGE_READ_MODE_*``
    constants defined on the C++ side.
    """

    # Load the image as-is, without any conversion.
    UNCHANGED = 0
    # Convert to grayscale.
    GRAY = 1
    # Convert to grayscale with transparency.
    GRAY_ALPHA = 2
    # Convert to RGB.
    RGB = 3
    # Convert to RGB with transparency.
    RGB_ALPHA = 4
def read_file(path: str) -> torch.Tensor:
"""
Reads and outputs the bytes contents of a file as a uint8 Tensor
......@@ -74,24 +84,26 @@ def write_file(filename: str, data: torch.Tensor) -> None:
torch.ops.image.write_file(filename, data)
def decode_png(input: torch.Tensor, channels: int = 0) -> torch.Tensor:
def decode_png(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
"""
Decodes a PNG image into a 3 dimensional RGB Tensor.
Optionally converts the image to the desired number of color channels.
Optionally converts the image to the desired format.
The values of the output tensor are uint8 between 0 and 255.
Arguments:
input (Tensor[1]): a one dimensional uint8 tensor containing
the raw bytes of the PNG image.
channels (int): the number of output channels for the decoded
image. 0 keeps the original number of channels, 1 converts to Grayscale
2 converts to Grayscale with Alpha, 3 converts to RGB and 4 coverts to
RGB with Alpha. Default: 0
mode (ImageReadMode): the read mode used for optionally
converting the image. Use `ImageReadMode.UNCHANGED` for loading
the image as-is, `ImageReadMode.GRAY` for converting to grayscale,
`ImageReadMode.GRAY_ALPHA` for grayscale with transparency,
`ImageReadMode.RGB` for RGB and `ImageReadMode.RGB_ALPHA` for
RGB with transparency. Default: `ImageReadMode.UNCHANGED`
Returns:
output (Tensor[image_channels, image_height, image_width])
"""
output = torch.ops.image.decode_png(input, channels)
output = torch.ops.image.decode_png(input, mode.value)
return output
......@@ -137,23 +149,24 @@ def write_png(input: torch.Tensor, filename: str, compression_level: int = 6):
write_file(filename, output)
def decode_jpeg(input: torch.Tensor, channels: int = 0) -> torch.Tensor:
def decode_jpeg(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
"""
Decodes a JPEG image into a 3 dimensional RGB Tensor.
Optionally converts the image to the desired number of color channels.
Optionally converts the image to the desired format.
The values of the output tensor are uint8 between 0 and 255.
Arguments:
input (Tensor[1]): a one dimensional uint8 tensor containing
the raw bytes of the JPEG image.
channels (int): the number of output channels for the decoded
image. 0 keeps the original number of channels, 1 converts to Grayscale
and 3 converts to RGB. Default: 0
mode (ImageReadMode): the read mode used for optionally
converting the image. Use `ImageReadMode.UNCHANGED` for loading
the image as-is, `ImageReadMode.GRAY` for converting to grayscale
and `ImageReadMode.RGB` for RGB. Default: `ImageReadMode.UNCHANGED`
Returns:
output (Tensor[image_channels, image_height, image_width])
"""
output = torch.ops.image.decode_jpeg(input, channels)
output = torch.ops.image.decode_jpeg(input, mode.value)
return output
......@@ -202,12 +215,12 @@ def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75):
write_file(filename, output)
def decode_image(input: torch.Tensor, channels: int = 0) -> torch.Tensor:
def decode_image(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
"""
Detects whether an image is a JPEG or PNG and performs the appropriate
operation to decode the image into a 3 dimensional RGB Tensor.
Optionally converts the image to the desired number of color channels.
Optionally converts the image to the desired format.
The values of the output tensor are uint8 between 0 and 255.
Parameters
......@@ -215,39 +228,41 @@ def decode_image(input: torch.Tensor, channels: int = 0) -> torch.Tensor:
input: Tensor
a one dimensional uint8 tensor containing the raw bytes of the
PNG or JPEG image.
channels: int
the number of output channels of the decoded image. JPEG and PNG images
have different permitted values. The default value is 0 and it keeps
the original number of channels. See `decode_jpeg()` and `decode_png()`
for more information. Default: 0
mode: ImageReadMode
the read mode used for optionally converting the image. JPEG
and PNG images have different permitted values. The default
value is `ImageReadMode.UNCHANGED` and it keeps the image as-is.
See `decode_jpeg()` and `decode_png()` for more information.
Default: `ImageReadMode.UNCHANGED`
Returns
-------
output: Tensor[image_channels, image_height, image_width]
"""
output = torch.ops.image.decode_image(input, channels)
output = torch.ops.image.decode_image(input, mode.value)
return output
def read_image(path: str, channels: int = 0) -> torch.Tensor:
def read_image(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
"""
Reads a JPEG or PNG image into a 3 dimensional RGB Tensor.
Optionally converts the image to the desired number of color channels.
Optionally converts the image to the desired format.
The values of the output tensor are uint8 between 0 and 255.
Parameters
----------
path: str
path of the JPEG or PNG image.
channels: int
the number of output channels of the decoded image. JPEG and PNG images
have different permitted values. The default value is 0 and it keeps
the original number of channels. See `decode_jpeg()` and `decode_png()`
for more information. Default: 0
mode: ImageReadMode
the read mode used for optionally converting the image. JPEG
and PNG images have different permitted values. The default
value is `ImageReadMode.UNCHANGED` and it keeps the image as-is.
See `decode_jpeg()` and `decode_png()` for more information.
Default: `ImageReadMode.UNCHANGED`
Returns
-------
output: Tensor[image_channels, image_height, image_width]
"""
data = read_file(path)
return decode_image(data, channels)
return decode_image(data, mode)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment