"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "b2cf6045d36e400ad43931a0f47512387c6f8693"
Unverified Commit 4d6ba678 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Support specifying output channels in io.image.read_image (#2988)

* Adding output channels implementation for pngs.

* Adding tests for png.

* Adding channels in the API and documentation.

* Fixing formatting.

* Refactoring test_image.py to remove the huge grace_hopper_517x606.pth file from assets and reduce duplicate code. Moving the JPEG assets used by the encode and write unit tests into their own separate folder.

* Adding output channels implementation for jpegs. Fix asset locations.

* Add tests for JPEG, adding the channels in the API and documentation and adding checks for inputs.

* Changing folder for unit-test.

* Fixing Windows flakiness, removing duplicate test.

* Replacing components to channels.

* Adding reference for supporting CMYK.

* Minor changes: num_components to output_components, adding comments, fixing variable name etc.

* Reverting output_components to num_components.

* Replacing decoding with generic method on tests.

* Palette converted to Gray.
parent 74de51d6
...@@ -25,7 +25,8 @@ def process_model(model, tensor, func, name): ...@@ -25,7 +25,8 @@ def process_model(model, tensor, func, name):
def read_image1(): def read_image1():
image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg') image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'encode_jpeg',
'grace_hopper_517x606.jpg')
image = Image.open(image_path) image = Image.open(image_path)
image = image.resize((224, 224)) image = image.resize((224, 224))
x = F.to_tensor(image) x = F.to_tensor(image)
...@@ -33,7 +34,8 @@ def read_image1(): ...@@ -33,7 +34,8 @@ def read_image1():
def read_image2(): def read_image2():
image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg') image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'encode_jpeg',
'grace_hopper_517x606.jpg')
image = Image.open(image_path) image = Image.open(image_path)
image = image.resize((299, 299)) image = image.resize((299, 299))
x = F.to_tensor(image) x = F.to_tensor(image)
......
...@@ -14,7 +14,7 @@ from common_utils import get_tmp_dir ...@@ -14,7 +14,7 @@ from common_utils import get_tmp_dir
TEST_FILE = get_file_path_2( TEST_FILE = get_file_path_2(
os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg') os.path.dirname(os.path.abspath(__file__)), 'assets', 'encode_jpeg', 'grace_hopper_517x606.jpg')
class Tester(unittest.TestCase): class Tester(unittest.TestCase):
......
...@@ -19,6 +19,7 @@ IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets") ...@@ -19,6 +19,7 @@ IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
FAKEDATA_DIR = os.path.join(IMAGE_ROOT, "fakedata") FAKEDATA_DIR = os.path.join(IMAGE_ROOT, "fakedata")
IMAGE_DIR = os.path.join(FAKEDATA_DIR, "imagefolder") IMAGE_DIR = os.path.join(FAKEDATA_DIR, "imagefolder")
DAMAGED_JPEG = os.path.join(IMAGE_ROOT, 'damaged_jpeg') DAMAGED_JPEG = os.path.join(IMAGE_ROOT, 'damaged_jpeg')
ENCODE_JPEG = os.path.join(IMAGE_ROOT, "encode_jpeg")
def get_images(directory, img_ext): def get_images(directory, img_ext):
...@@ -33,14 +34,44 @@ def get_images(directory, img_ext): ...@@ -33,14 +34,44 @@ def get_images(directory, img_ext):
yield os.path.join(root, fl) yield os.path.join(root, fl)
def pil_read_image(img_path):
    """Open ``img_path`` with Pillow and return its pixel data as a tensor.

    The image is converted to a NumPy array while the file is still open,
    and that array is then wrapped with ``torch.from_numpy``.
    """
    with Image.open(img_path) as pil_img:
        pixels = np.array(pil_img)
    return torch.from_numpy(pixels)
def normalize_dimensions(img_pil):
    """Return ``img_pil`` with a leading channel axis.

    A 3-dimensional tensor has its last axis moved to the front
    (``permute(2, 0, 1)``); any other tensor gets a new leading axis of
    size 1 via ``unsqueeze(0)``.
    """
    if img_pil.dim() != 3:
        return img_pil.unsqueeze(0)
    return img_pil.permute(2, 0, 1)
class ImageTester(unittest.TestCase): class ImageTester(unittest.TestCase):
def test_decode_jpeg(self): def test_decode_jpeg(self):
conversion = [(None, 0), ("L", 1), ("RGB", 3)]
for img_path in get_images(IMAGE_ROOT, ".jpg"): for img_path in get_images(IMAGE_ROOT, ".jpg"):
img_pil = torch.load(img_path.replace('jpg', 'pth')) for pil_mode, channels in conversion:
img_pil = img_pil.permute(2, 0, 1) with Image.open(img_path) as img:
data = read_file(img_path) is_cmyk = img.mode == "CMYK"
img_ljpeg = decode_jpeg(data) if pil_mode is not None:
self.assertTrue(img_ljpeg.equal(img_pil)) if is_cmyk:
# libjpeg does not support the conversion
continue
img = img.convert(pil_mode)
img_pil = torch.from_numpy(np.array(img))
if is_cmyk:
# flip the colors to match libjpeg
img_pil = 255 - img_pil
img_pil = normalize_dimensions(img_pil)
data = read_file(img_path)
img_ljpeg = decode_image(data, channels=channels)
# Permit a small variation on pixel values to account for implementation
# differences between Pillow and LibJPEG.
abs_mean_diff = (img_ljpeg.type(torch.float32) - img_pil).abs().mean().item()
self.assertTrue(abs_mean_diff < 2)
with self.assertRaisesRegex(RuntimeError, "Expected a non empty 1-dimensional tensor"): with self.assertRaisesRegex(RuntimeError, "Expected a non empty 1-dimensional tensor"):
decode_jpeg(torch.empty((100, 1), dtype=torch.uint8)) decode_jpeg(torch.empty((100, 1), dtype=torch.uint8))
...@@ -68,7 +99,7 @@ class ImageTester(unittest.TestCase): ...@@ -68,7 +99,7 @@ class ImageTester(unittest.TestCase):
decode_jpeg(data) decode_jpeg(data)
def test_encode_jpeg(self): def test_encode_jpeg(self):
for img_path in get_images(IMAGE_ROOT, ".jpg"): for img_path in get_images(ENCODE_JPEG, ".jpg"):
dirname = os.path.dirname(img_path) dirname = os.path.dirname(img_path)
filename, _ = os.path.splitext(os.path.basename(img_path)) filename, _ = os.path.splitext(os.path.basename(img_path))
write_folder = os.path.join(dirname, 'jpeg_write') write_folder = os.path.join(dirname, 'jpeg_write')
...@@ -111,7 +142,7 @@ class ImageTester(unittest.TestCase): ...@@ -111,7 +142,7 @@ class ImageTester(unittest.TestCase):
encode_jpeg(torch.empty((100, 100), dtype=torch.uint8)) encode_jpeg(torch.empty((100, 100), dtype=torch.uint8))
def test_write_jpeg(self): def test_write_jpeg(self):
for img_path in get_images(IMAGE_ROOT, ".jpg"): for img_path in get_images(ENCODE_JPEG, ".jpg"):
data = read_file(img_path) data = read_file(img_path)
img = decode_jpeg(data) img = decode_jpeg(data)
...@@ -134,20 +165,25 @@ class ImageTester(unittest.TestCase): ...@@ -134,20 +165,25 @@ class ImageTester(unittest.TestCase):
self.assertEqual(torch_bytes, pil_bytes) self.assertEqual(torch_bytes, pil_bytes)
def test_decode_png(self):
    """Compare ``decode_image`` on PNG files with Pillow for each channel mode.

    Every ``.png`` found under ``FAKEDATA_DIR`` is decoded once per entry of
    ``conversion``. The reference is the Pillow image, converted to the
    paired Pillow mode unless that mode is ``None``. Afterwards, two
    malformed inputs are checked to make ``decode_png`` raise
    ``RuntimeError``.
    """
    # (Pillow mode, ``channels`` argument) pairs; (None, 0) requests no
    # conversion on either side.
    conversion = [(None, 0), ("L", 1), ("LA", 2), ("RGB", 3), ("RGBA", 4)]
    for img_path in get_images(FAKEDATA_DIR, ".png"):
        for pil_mode, channels in conversion:
            with Image.open(img_path) as img:
                if pil_mode is not None:
                    img = img.convert(pil_mode)
                img_pil = torch.from_numpy(np.array(img))

            img_pil = normalize_dimensions(img_pil)
            data = read_file(img_path)
            img_lpng = decode_image(data, channels=channels)

            # The tolerance has to follow the mode of *this* iteration. The
            # ``conversion`` list is never None, so testing it always picked
            # the loose tolerance, even when no conversion was requested.
            # NOTE(review): the 1-level slack for converted modes presumably
            # covers rounding differences between Pillow and libpng - confirm.
            tol = 0 if pil_mode is None else 1
            self.assertTrue(img_lpng.allclose(img_pil, atol=tol))

    # A 0-dimensional tensor is rejected.
    with self.assertRaises(RuntimeError):
        decode_png(torch.empty((), dtype=torch.uint8))
    # Bytes that contain only the values 3 and 4 are rejected as well.
    with self.assertRaises(RuntimeError):
        decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8))
def test_encode_png(self): def test_encode_png(self):
for img_path in get_images(IMAGE_DIR, '.png'): for img_path in get_images(IMAGE_DIR, '.png'):
...@@ -196,19 +232,6 @@ class ImageTester(unittest.TestCase): ...@@ -196,19 +232,6 @@ class ImageTester(unittest.TestCase):
self.assertTrue(img_pil.equal(saved_image)) self.assertTrue(img_pil.equal(saved_image))
def test_decode_image(self):
    """Check that ``decode_image`` reproduces the stored reference pixels.

    JPEG files under ``IMAGE_ROOT`` are compared with the tensor saved in
    the sibling ``.pth`` file; PNG files under ``IMAGE_DIR`` are compared
    with the array Pillow loads from the same file. Both references are
    permuted with ``(2, 0, 1)`` before an exact comparison.
    """
    for jpg_path in get_images(IMAGE_ROOT, ".jpg"):
        expected = torch.load(jpg_path.replace('jpg', 'pth')).permute(2, 0, 1)
        decoded = decode_image(read_file(jpg_path))
        self.assertTrue(decoded.equal(expected))

    for png_path in get_images(IMAGE_DIR, ".png"):
        pixels = np.array(Image.open(png_path))
        expected = torch.from_numpy(pixels).permute(2, 0, 1)
        decoded = decode_image(read_file(png_path))
        self.assertTrue(decoded.equal(expected))
def test_read_file(self): def test_read_file(self):
with get_tmp_dir() as d: with get_tmp_dir() as d:
fname, content = 'test1.bin', b'TorchVision\211\n' fname, content = 'test1.bin', b'TorchVision\211\n'
......
...@@ -24,7 +24,7 @@ from common_utils import cycle_over, int_dtypes, float_dtypes ...@@ -24,7 +24,7 @@ from common_utils import cycle_over, int_dtypes, float_dtypes
GRACE_HOPPER = get_file_path_2( GRACE_HOPPER = get_file_path_2(
os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg') os.path.dirname(os.path.abspath(__file__)), 'assets', 'encode_jpeg', 'grace_hopper_517x606.jpg')
class Tester(unittest.TestCase): class Tester(unittest.TestCase):
......
#include "read_image_cpu.h" #include "read_image_cpu.h"
#include <cstring> #include <cstring>
torch::Tensor decode_image(const torch::Tensor& data) { torch::Tensor decode_image(const torch::Tensor& data, int64_t channels) {
// Check that the input tensor dtype is uint8 // Check that the input tensor dtype is uint8
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
// Check that the input tensor is 1-dimensional // Check that the input tensor is 1-dimensional
TORCH_CHECK( TORCH_CHECK(
data.dim() == 1 && data.numel() > 0, data.dim() == 1 && data.numel() > 0,
"Expected a non empty 1-dimensional tensor"); "Expected a non empty 1-dimensional tensor");
TORCH_CHECK(
channels >= 0 && channels <= 4, "Number of channels not supported");
auto datap = data.data_ptr<uint8_t>(); auto datap = data.data_ptr<uint8_t>();
...@@ -15,9 +17,9 @@ torch::Tensor decode_image(const torch::Tensor& data) { ...@@ -15,9 +17,9 @@ torch::Tensor decode_image(const torch::Tensor& data) {
const uint8_t png_signature[4] = {137, 80, 78, 71}; // == "\211PNG" const uint8_t png_signature[4] = {137, 80, 78, 71}; // == "\211PNG"
if (memcmp(jpeg_signature, datap, 3) == 0) { if (memcmp(jpeg_signature, datap, 3) == 0) {
return decodeJPEG(data); return decodeJPEG(data, channels);
} else if (memcmp(png_signature, datap, 4) == 0) { } else if (memcmp(png_signature, datap, 4) == 0) {
return decodePNG(data); return decodePNG(data, channels);
} else { } else {
TORCH_CHECK( TORCH_CHECK(
false, false,
......
...@@ -3,4 +3,6 @@ ...@@ -3,4 +3,6 @@
#include "readjpeg_cpu.h" #include "readjpeg_cpu.h"
#include "readpng_cpu.h" #include "readpng_cpu.h"
C10_EXPORT torch::Tensor decode_image(const torch::Tensor& data); C10_EXPORT torch::Tensor decode_image(
const torch::Tensor& data,
int64_t channels = 0);
#include "readjpeg_cpu.h" #include "readjpeg_cpu.h"
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <setjmp.h>
#include <string> #include <string>
#if !JPEG_FOUND #if !JPEG_FOUND
torch::Tensor decodeJPEG(const torch::Tensor& data, int64_t channels) {
torch::Tensor decodeJPEG(const torch::Tensor& data) {
TORCH_CHECK( TORCH_CHECK(
false, "decodeJPEG: torchvision not compiled with libjpeg support"); false, "decodeJPEG: torchvision not compiled with libjpeg support");
} }
#else #else
#include <jpeglib.h> #include <jpeglib.h>
#include <setjmp.h>
#include "jpegcommon.h" #include "jpegcommon.h"
struct torch_jpeg_mgr { struct torch_jpeg_mgr {
...@@ -71,13 +69,16 @@ static void torch_jpeg_set_source_mgr( ...@@ -71,13 +69,16 @@ static void torch_jpeg_set_source_mgr(
src->pub.next_input_byte = src->data; src->pub.next_input_byte = src->data;
} }
torch::Tensor decodeJPEG(const torch::Tensor& data) { torch::Tensor decodeJPEG(const torch::Tensor& data, int64_t channels) {
// Check that the input tensor dtype is uint8 // Check that the input tensor dtype is uint8
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
// Check that the input tensor is 1-dimensional // Check that the input tensor is 1-dimensional
TORCH_CHECK( TORCH_CHECK(
data.dim() == 1 && data.numel() > 0, data.dim() == 1 && data.numel() > 0,
"Expected a non empty 1-dimensional tensor"); "Expected a non empty 1-dimensional tensor");
TORCH_CHECK(
channels == 0 || channels == 1 || channels == 3,
"Number of channels not supported");
struct jpeg_decompress_struct cinfo; struct jpeg_decompress_struct cinfo;
struct torch_jpeg_error_mgr jerr; struct torch_jpeg_error_mgr jerr;
...@@ -100,15 +101,41 @@ torch::Tensor decodeJPEG(const torch::Tensor& data) { ...@@ -100,15 +101,41 @@ torch::Tensor decodeJPEG(const torch::Tensor& data) {
// read info from header. // read info from header.
jpeg_read_header(&cinfo, TRUE); jpeg_read_header(&cinfo, TRUE);
int current_channels = cinfo.num_components;
if (channels > 0 && channels != current_channels) {
switch (channels) {
case 1: // Gray
cinfo.out_color_space = JCS_GRAYSCALE;
break;
case 3: // RGB
cinfo.out_color_space = JCS_RGB;
break;
/*
* Libjpeg does not support converting from CMYK to grayscale etc. There
* is a way to do this but it involves converting it manually to RGB:
* https://github.com/tensorflow/tensorflow/blob/86871065265b04e0db8ca360c046421efb2bdeb4/tensorflow/core/lib/jpeg/jpeg_mem.cc#L284-L313
*
*/
default:
jpeg_destroy_decompress(&cinfo);
TORCH_CHECK(false, "Invalid number of output channels.");
}
jpeg_calc_output_dimensions(&cinfo);
} else {
channels = current_channels;
}
jpeg_start_decompress(&cinfo); jpeg_start_decompress(&cinfo);
int height = cinfo.output_height; int height = cinfo.output_height;
int width = cinfo.output_width; int width = cinfo.output_width;
int components = cinfo.output_components;
auto stride = width * components; int stride = width * channels;
auto tensor = torch::empty( auto tensor =
{int64_t(height), int64_t(width), int64_t(components)}, torch::kU8); torch::empty({int64_t(height), int64_t(width), channels}, torch::kU8);
auto ptr = tensor.data_ptr<uint8_t>(); auto ptr = tensor.data_ptr<uint8_t>();
while (cinfo.output_scanline < cinfo.output_height) { while (cinfo.output_scanline < cinfo.output_height) {
/* jpeg_read_scanlines expects an array of pointers to scanlines. /* jpeg_read_scanlines expects an array of pointers to scanlines.
......
...@@ -2,4 +2,6 @@ ...@@ -2,4 +2,6 @@
#include <torch/torch.h> #include <torch/torch.h>
C10_EXPORT torch::Tensor decodeJPEG(const torch::Tensor& data); C10_EXPORT torch::Tensor decodeJPEG(
const torch::Tensor& data,
int64_t channels = 0);
...@@ -2,23 +2,26 @@ ...@@ -2,23 +2,26 @@
// Comment // Comment
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <setjmp.h>
#include <string> #include <string>
#define PNG_FOUND 1
#if !PNG_FOUND #if !PNG_FOUND
torch::Tensor decodePNG(const torch::Tensor& data) { torch::Tensor decodePNG(const torch::Tensor& data, int64_t channels) {
TORCH_CHECK(false, "decodePNG: torchvision not compiled with libPNG support"); TORCH_CHECK(false, "decodePNG: torchvision not compiled with libPNG support");
} }
#else #else
#include <png.h> #include <png.h>
#include <setjmp.h>
torch::Tensor decodePNG(const torch::Tensor& data) { torch::Tensor decodePNG(const torch::Tensor& data, int64_t channels) {
// Check that the input tensor dtype is uint8 // Check that the input tensor dtype is uint8
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
// Check that the input tensor is 1-dimensional // Check that the input tensor is 1-dimensional
TORCH_CHECK( TORCH_CHECK(
data.dim() == 1 && data.numel() > 0, data.dim() == 1 && data.numel() > 0,
"Expected a non empty 1-dimensional tensor"); "Expected a non empty 1-dimensional tensor");
TORCH_CHECK(
channels >= 0 && channels <= 4, "Number of channels not supported");
auto png_ptr = auto png_ptr =
png_create_read_struct(PNG_LIBPNG_VER_STRING, nullptr, nullptr, nullptr); png_create_read_struct(PNG_LIBPNG_VER_STRING, nullptr, nullptr, nullptr);
...@@ -72,30 +75,79 @@ torch::Tensor decodePNG(const torch::Tensor& data) { ...@@ -72,30 +75,79 @@ torch::Tensor decodePNG(const torch::Tensor& data) {
TORCH_CHECK(retval == 1, "Could read image metadata from content.") TORCH_CHECK(retval == 1, "Could read image metadata from content.")
} }
int channels; int current_channels = png_get_channels(png_ptr, info_ptr);
switch (color_type) {
case PNG_COLOR_TYPE_RGB: if (channels > 0) {
channels = 3; // TODO: consider supporting PNG_INFO_tRNS
break; bool is_palette = (color_type & PNG_COLOR_MASK_PALETTE) != 0;
case PNG_COLOR_TYPE_RGB_ALPHA: bool has_color = (color_type & PNG_COLOR_MASK_COLOR) != 0;
channels = 4; bool has_alpha = (color_type & PNG_COLOR_MASK_ALPHA) != 0;
break;
case PNG_COLOR_TYPE_GRAY: switch (channels) {
channels = 1; case 1: // Gray
break; if (is_palette) {
case PNG_COLOR_TYPE_GRAY_ALPHA: png_set_palette_to_rgb(png_ptr);
channels = 2; has_alpha = true;
break; }
case PNG_COLOR_TYPE_PALETTE:
channels = 1; if (has_alpha) {
break; png_set_strip_alpha(png_ptr);
default: }
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
TORCH_CHECK(false, "Image color type is not supported."); if (has_color) {
png_set_rgb_to_gray(png_ptr, 1, 0.2989, 0.587);
}
break;
case 2: // Gray + Alpha
if (is_palette) {
png_set_palette_to_rgb(png_ptr);
has_alpha = true;
}
if (!has_alpha) {
png_set_add_alpha(png_ptr, (1 << bit_depth) - 1, PNG_FILLER_AFTER);
}
if (has_color) {
png_set_rgb_to_gray(png_ptr, 1, 0.2989, 0.587);
}
break;
case 3:
if (is_palette) {
png_set_palette_to_rgb(png_ptr);
has_alpha = true;
} else if (!has_color) {
png_set_gray_to_rgb(png_ptr);
}
if (has_alpha) {
png_set_strip_alpha(png_ptr);
}
break;
case 4:
if (is_palette) {
png_set_palette_to_rgb(png_ptr);
has_alpha = true;
} else if (!has_color) {
png_set_gray_to_rgb(png_ptr);
}
if (!has_alpha) {
png_set_add_alpha(png_ptr, (1 << bit_depth) - 1, PNG_FILLER_AFTER);
}
break;
default:
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
TORCH_CHECK(false, "Invalid number of output channels.");
}
png_read_update_info(png_ptr, info_ptr);
} else {
channels = current_channels;
} }
auto tensor = torch::empty( auto tensor =
{int64_t(height), int64_t(width), int64_t(channels)}, torch::kU8); torch::empty({int64_t(height), int64_t(width), channels}, torch::kU8);
auto ptr = tensor.accessor<uint8_t, 3>().data(); auto ptr = tensor.accessor<uint8_t, 3>().data();
auto bytes = png_get_rowbytes(png_ptr, info_ptr); auto bytes = png_get_rowbytes(png_ptr, info_ptr);
for (png_uint_32 i = 0; i < height; ++i) { for (png_uint_32 i = 0; i < height; ++i) {
......
...@@ -4,4 +4,6 @@ ...@@ -4,4 +4,6 @@
#include <torch/torch.h> #include <torch/torch.h>
#include <string> #include <string>
C10_EXPORT torch::Tensor decodePNG(const torch::Tensor& data); C10_EXPORT torch::Tensor decodePNG(
const torch::Tensor& data,
int64_t channels = 0);
...@@ -74,19 +74,24 @@ def write_file(filename: str, data: torch.Tensor) -> None: ...@@ -74,19 +74,24 @@ def write_file(filename: str, data: torch.Tensor) -> None:
torch.ops.image.write_file(filename, data) torch.ops.image.write_file(filename, data)
def decode_png(input: torch.Tensor, channels: int = 0) -> torch.Tensor:
    """
    Decodes a PNG image into a 3 dimensional Tensor, optionally converting
    it to the requested number of color channels.
    The values of the output tensor are uint8 between 0 and 255.

    Arguments:
        input (Tensor[1]): a one dimensional uint8 tensor containing
            the raw bytes of the PNG image.
        channels (int): the number of output channels for the decoded
            image. 0 keeps the original number of channels, 1 converts to
            Grayscale, 2 converts to Grayscale with Alpha, 3 converts to RGB
            and 4 converts to RGB with Alpha. Default: 0

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    # Thin wrapper: the work is done by the registered ``image`` operator.
    return torch.ops.image.decode_png(input, channels)
...@@ -132,17 +137,23 @@ def write_png(input: torch.Tensor, filename: str, compression_level: int = 6): ...@@ -132,17 +137,23 @@ def write_png(input: torch.Tensor, filename: str, compression_level: int = 6):
write_file(filename, output) write_file(filename, output)
def decode_jpeg(input: torch.Tensor, channels: int = 0) -> torch.Tensor:
    """
    Decodes a JPEG image into a 3 dimensional Tensor, optionally converting
    it to the requested number of color channels.
    The values of the output tensor are uint8 between 0 and 255.

    Arguments:
        input (Tensor[1]): a one dimensional uint8 tensor containing
            the raw bytes of the JPEG image.
        channels (int): the number of output channels for the decoded
            image. 0 keeps the original number of channels, 1 converts to
            Grayscale and 3 converts to RGB. Default: 0

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    # Thin wrapper: the work is done by the registered ``image`` operator.
    return torch.ops.image.decode_jpeg(input, channels)
...@@ -191,11 +202,12 @@ def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75): ...@@ -191,11 +202,12 @@ def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75):
write_file(filename, output) write_file(filename, output)
def decode_image(input: torch.Tensor) -> torch.Tensor: def decode_image(input: torch.Tensor, channels: int = 0) -> torch.Tensor:
""" """
Detects whether an image is a JPEG or PNG and performs the appropriate Detects whether an image is a JPEG or PNG and performs the appropriate
operation to decode the image into a 3 dimensional RGB Tensor. operation to decode the image into a 3 dimensional RGB Tensor.
Optionally converts the image to the desired number of color channels.
The values of the output tensor are uint8 between 0 and 255. The values of the output tensor are uint8 between 0 and 255.
Parameters Parameters
...@@ -203,28 +215,39 @@ def decode_image(input: torch.Tensor) -> torch.Tensor: ...@@ -203,28 +215,39 @@ def decode_image(input: torch.Tensor) -> torch.Tensor:
input: Tensor input: Tensor
a one dimensional uint8 tensor containing the raw bytes of the a one dimensional uint8 tensor containing the raw bytes of the
PNG or JPEG image. PNG or JPEG image.
channels: int
the number of output channels of the decoded image. JPEG and PNG images
have different permitted values. The default value is 0 and it keeps
the original number of channels. See `decode_jpeg()` and `decode_png()`
for more information. Default: 0
Returns Returns
------- -------
output: Tensor[3, image_height, image_width] output: Tensor[image_channels, image_height, image_width]
""" """
output = torch.ops.image.decode_image(input) output = torch.ops.image.decode_image(input, channels)
return output return output
def read_image(path: str, channels: int = 0) -> torch.Tensor:
    """
    Reads a JPEG or PNG image from disk into a 3 dimensional Tensor,
    optionally converting it to the requested number of color channels.
    The values of the output tensor are uint8 between 0 and 255.

    Parameters
    ----------
    path: str
        path of the JPEG or PNG image.
    channels: int
        the number of output channels of the decoded image. JPEG and PNG
        images have different permitted values. The default value is 0 and
        it keeps the original number of channels. See `decode_jpeg()` and
        `decode_png()` for more information. Default: 0

    Returns
    -------
    output: Tensor[image_channels, image_height, image_width]
    """
    # Load the raw bytes, then let decode_image pick the matching decoder.
    return decode_image(read_file(path), channels)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment