Unverified Commit a884cb7b authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Add support of mode and remove channels (#3024)

* Add support of mode and remove channels.

* Replacing integer mode with define constants.
parent 1706921b
...@@ -2,14 +2,12 @@ import os ...@@ -2,14 +2,12 @@ import os
import io import io
import glob import glob
import unittest import unittest
import sys
import torch import torch
import torchvision
from PIL import Image from PIL import Image
from torchvision.io.image import ( from torchvision.io.image import (
decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file, decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file,
encode_png, write_png, write_file) encode_png, write_png, write_file, ImageReadMode)
import numpy as np import numpy as np
from common_utils import get_tmp_dir from common_utils import get_tmp_dir
...@@ -49,9 +47,9 @@ def normalize_dimensions(img_pil): ...@@ -49,9 +47,9 @@ def normalize_dimensions(img_pil):
class ImageTester(unittest.TestCase): class ImageTester(unittest.TestCase):
def test_decode_jpeg(self): def test_decode_jpeg(self):
conversion = [(None, 0), ("L", 1), ("RGB", 3)] conversion = [(None, ImageReadMode.UNCHANGED), ("L", ImageReadMode.GRAY), ("RGB", ImageReadMode.RGB)]
for img_path in get_images(IMAGE_ROOT, ".jpg"): for img_path in get_images(IMAGE_ROOT, ".jpg"):
for pil_mode, channels in conversion: for pil_mode, mode in conversion:
with Image.open(img_path) as img: with Image.open(img_path) as img:
is_cmyk = img.mode == "CMYK" is_cmyk = img.mode == "CMYK"
if pil_mode is not None: if pil_mode is not None:
...@@ -66,7 +64,7 @@ class ImageTester(unittest.TestCase): ...@@ -66,7 +64,7 @@ class ImageTester(unittest.TestCase):
img_pil = normalize_dimensions(img_pil) img_pil = normalize_dimensions(img_pil)
data = read_file(img_path) data = read_file(img_path)
img_ljpeg = decode_image(data, channels=channels) img_ljpeg = decode_image(data, mode=mode)
# Permit a small variation on pixel values to account for implementation # Permit a small variation on pixel values to account for implementation
# differences between Pillow and LibJPEG. # differences between Pillow and LibJPEG.
...@@ -165,9 +163,10 @@ class ImageTester(unittest.TestCase): ...@@ -165,9 +163,10 @@ class ImageTester(unittest.TestCase):
self.assertEqual(torch_bytes, pil_bytes) self.assertEqual(torch_bytes, pil_bytes)
def test_decode_png(self): def test_decode_png(self):
conversion = [(None, 0), ("L", 1), ("LA", 2), ("RGB", 3), ("RGBA", 4)] conversion = [(None, ImageReadMode.UNCHANGED), ("L", ImageReadMode.GRAY), ("LA", ImageReadMode.GRAY_ALPHA),
("RGB", ImageReadMode.RGB), ("RGBA", ImageReadMode.RGB_ALPHA)]
for img_path in get_images(FAKEDATA_DIR, ".png"): for img_path in get_images(FAKEDATA_DIR, ".png"):
for pil_mode, channels in conversion: for pil_mode, mode in conversion:
with Image.open(img_path) as img: with Image.open(img_path) as img:
if pil_mode is not None: if pil_mode is not None:
img = img.convert(pil_mode) img = img.convert(pil_mode)
...@@ -175,7 +174,7 @@ class ImageTester(unittest.TestCase): ...@@ -175,7 +174,7 @@ class ImageTester(unittest.TestCase):
img_pil = normalize_dimensions(img_pil) img_pil = normalize_dimensions(img_pil)
data = read_file(img_path) data = read_file(img_path)
img_lpng = decode_image(data, channels=channels) img_lpng = decode_image(data, mode=mode)
tol = 0 if conversion is None else 1 tol = 0 if conversion is None else 1
self.assertTrue(img_lpng.allclose(img_pil, atol=tol)) self.assertTrue(img_lpng.allclose(img_pil, atol=tol))
......
#pragma once
/* Should be kept in-sync with Python ImageReadMode enum */
using ImageReadMode = int64_t;
#define IMAGE_READ_MODE_UNCHANGED 0
#define IMAGE_READ_MODE_GRAY 1
#define IMAGE_READ_MODE_GRAY_ALPHA 2
#define IMAGE_READ_MODE_RGB 3
#define IMAGE_READ_MODE_RGB_ALPHA 4
\ No newline at end of file
#include "read_image_cpu.h" #include "read_image_cpu.h"
#include <cstring> #include "readjpeg_cpu.h"
#include "readpng_cpu.h"
torch::Tensor decode_image(const torch::Tensor& data, int64_t channels) { torch::Tensor decode_image(const torch::Tensor& data, ImageReadMode mode) {
// Check that the input tensor dtype is uint8 // Check that the input tensor dtype is uint8
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
// Check that the input tensor is 1-dimensional // Check that the input tensor is 1-dimensional
TORCH_CHECK( TORCH_CHECK(
data.dim() == 1 && data.numel() > 0, data.dim() == 1 && data.numel() > 0,
"Expected a non empty 1-dimensional tensor"); "Expected a non empty 1-dimensional tensor");
TORCH_CHECK(
channels >= 0 && channels <= 4, "Number of channels not supported");
auto datap = data.data_ptr<uint8_t>(); auto datap = data.data_ptr<uint8_t>();
...@@ -17,9 +16,9 @@ torch::Tensor decode_image(const torch::Tensor& data, int64_t channels) { ...@@ -17,9 +16,9 @@ torch::Tensor decode_image(const torch::Tensor& data, int64_t channels) {
const uint8_t png_signature[4] = {137, 80, 78, 71}; // == "\211PNG" const uint8_t png_signature[4] = {137, 80, 78, 71}; // == "\211PNG"
if (memcmp(jpeg_signature, datap, 3) == 0) { if (memcmp(jpeg_signature, datap, 3) == 0) {
return decodeJPEG(data, channels); return decodeJPEG(data, mode);
} else if (memcmp(png_signature, datap, 4) == 0) { } else if (memcmp(png_signature, datap, 4) == 0) {
return decodePNG(data, channels); return decodePNG(data, mode);
} else { } else {
TORCH_CHECK( TORCH_CHECK(
false, false,
......
#pragma once #pragma once
#include "readjpeg_cpu.h" #include <torch/torch.h>
#include "readpng_cpu.h" #include "image_read_mode.h"
C10_EXPORT torch::Tensor decode_image( C10_EXPORT torch::Tensor decode_image(
const torch::Tensor& data, const torch::Tensor& data,
int64_t channels = 0); ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED);
#include "readjpeg_cpu.h" #include "readjpeg_cpu.h"
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <string>
#if !JPEG_FOUND #if !JPEG_FOUND
torch::Tensor decodeJPEG(const torch::Tensor& data, int64_t channels) { torch::Tensor decodeJPEG(const torch::Tensor& data, ImageReadMode mode) {
TORCH_CHECK( TORCH_CHECK(
false, "decodeJPEG: torchvision not compiled with libjpeg support"); false, "decodeJPEG: torchvision not compiled with libjpeg support");
} }
...@@ -69,16 +68,13 @@ static void torch_jpeg_set_source_mgr( ...@@ -69,16 +68,13 @@ static void torch_jpeg_set_source_mgr(
src->pub.next_input_byte = src->data; src->pub.next_input_byte = src->data;
} }
torch::Tensor decodeJPEG(const torch::Tensor& data, int64_t channels) { torch::Tensor decodeJPEG(const torch::Tensor& data, ImageReadMode mode) {
// Check that the input tensor dtype is uint8 // Check that the input tensor dtype is uint8
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
// Check that the input tensor is 1-dimensional // Check that the input tensor is 1-dimensional
TORCH_CHECK( TORCH_CHECK(
data.dim() == 1 && data.numel() > 0, data.dim() == 1 && data.numel() > 0,
"Expected a non empty 1-dimensional tensor"); "Expected a non empty 1-dimensional tensor");
TORCH_CHECK(
channels == 0 || channels == 1 || channels == 3,
"Number of channels not supported");
struct jpeg_decompress_struct cinfo; struct jpeg_decompress_struct cinfo;
struct torch_jpeg_error_mgr jerr; struct torch_jpeg_error_mgr jerr;
...@@ -102,30 +98,33 @@ torch::Tensor decodeJPEG(const torch::Tensor& data, int64_t channels) { ...@@ -102,30 +98,33 @@ torch::Tensor decodeJPEG(const torch::Tensor& data, int64_t channels) {
// read info from header. // read info from header.
jpeg_read_header(&cinfo, TRUE); jpeg_read_header(&cinfo, TRUE);
int current_channels = cinfo.num_components; int channels = cinfo.num_components;
if (channels > 0 && channels != current_channels) { if (mode != IMAGE_READ_MODE_UNCHANGED) {
switch (channels) { switch (mode) {
case 1: // Gray case IMAGE_READ_MODE_GRAY:
if (cinfo.jpeg_color_space != JCS_GRAYSCALE) {
cinfo.out_color_space = JCS_GRAYSCALE; cinfo.out_color_space = JCS_GRAYSCALE;
channels = 1;
}
break; break;
case 3: // RGB case IMAGE_READ_MODE_RGB:
if (cinfo.jpeg_color_space != JCS_RGB) {
cinfo.out_color_space = JCS_RGB; cinfo.out_color_space = JCS_RGB;
channels = 3;
}
break; break;
/* /*
* Libjpeg does not support converting from CMYK to grayscale etc. There * Libjpeg does not support converting from CMYK to grayscale etc. There
* is a way to do this but it involves converting it manually to RGB: * is a way to do this but it involves converting it manually to RGB:
* https://github.com/tensorflow/tensorflow/blob/86871065265b04e0db8ca360c046421efb2bdeb4/tensorflow/core/lib/jpeg/jpeg_mem.cc#L284-L313 * https://github.com/tensorflow/tensorflow/blob/86871065265b04e0db8ca360c046421efb2bdeb4/tensorflow/core/lib/jpeg/jpeg_mem.cc#L284-L313
*
*/ */
default: default:
jpeg_destroy_decompress(&cinfo); jpeg_destroy_decompress(&cinfo);
TORCH_CHECK(false, "Invalid number of output channels."); TORCH_CHECK(false, "Provided mode not supported");
} }
jpeg_calc_output_dimensions(&cinfo); jpeg_calc_output_dimensions(&cinfo);
} else {
channels = current_channels;
} }
jpeg_start_decompress(&cinfo); jpeg_start_decompress(&cinfo);
......
#pragma once #pragma once
#include <torch/torch.h> #include <torch/torch.h>
#include "image_read_mode.h"
C10_EXPORT torch::Tensor decodeJPEG( C10_EXPORT torch::Tensor decodeJPEG(
const torch::Tensor& data, const torch::Tensor& data,
int64_t channels = 0); ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED);
#include "readpng_cpu.h" #include "readpng_cpu.h"
// Comment
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <string>
#if !PNG_FOUND #if !PNG_FOUND
torch::Tensor decodePNG(const torch::Tensor& data, int64_t channels) { torch::Tensor decodePNG(const torch::Tensor& data, ImageReadMode mode) {
TORCH_CHECK(false, "decodePNG: torchvision not compiled with libPNG support"); TORCH_CHECK(false, "decodePNG: torchvision not compiled with libPNG support");
} }
#else #else
#include <png.h> #include <png.h>
#include <setjmp.h> #include <setjmp.h>
torch::Tensor decodePNG(const torch::Tensor& data, int64_t channels) { torch::Tensor decodePNG(const torch::Tensor& data, ImageReadMode mode) {
// Check that the input tensor dtype is uint8 // Check that the input tensor dtype is uint8
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
// Check that the input tensor is 1-dimensional // Check that the input tensor is 1-dimensional
TORCH_CHECK( TORCH_CHECK(
data.dim() == 1 && data.numel() > 0, data.dim() == 1 && data.numel() > 0,
"Expected a non empty 1-dimensional tensor"); "Expected a non empty 1-dimensional tensor");
TORCH_CHECK(
channels >= 0 && channels <= 4, "Number of channels not supported");
auto png_ptr = auto png_ptr =
png_create_read_struct(PNG_LIBPNG_VER_STRING, nullptr, nullptr, nullptr); png_create_read_struct(PNG_LIBPNG_VER_STRING, nullptr, nullptr, nullptr);
...@@ -74,16 +70,17 @@ torch::Tensor decodePNG(const torch::Tensor& data, int64_t channels) { ...@@ -74,16 +70,17 @@ torch::Tensor decodePNG(const torch::Tensor& data, int64_t channels) {
TORCH_CHECK(retval == 1, "Could read image metadata from content.") TORCH_CHECK(retval == 1, "Could read image metadata from content.")
} }
int current_channels = png_get_channels(png_ptr, info_ptr); int channels = png_get_channels(png_ptr, info_ptr);
if (channels > 0) { if (mode != IMAGE_READ_MODE_UNCHANGED) {
// TODO: consider supporting PNG_INFO_tRNS // TODO: consider supporting PNG_INFO_tRNS
bool is_palette = (color_type & PNG_COLOR_MASK_PALETTE) != 0; bool is_palette = (color_type & PNG_COLOR_MASK_PALETTE) != 0;
bool has_color = (color_type & PNG_COLOR_MASK_COLOR) != 0; bool has_color = (color_type & PNG_COLOR_MASK_COLOR) != 0;
bool has_alpha = (color_type & PNG_COLOR_MASK_ALPHA) != 0; bool has_alpha = (color_type & PNG_COLOR_MASK_ALPHA) != 0;
switch (channels) { switch (mode) {
case 1: // Gray case IMAGE_READ_MODE_GRAY:
if (color_type != PNG_COLOR_TYPE_GRAY) {
if (is_palette) { if (is_palette) {
png_set_palette_to_rgb(png_ptr); png_set_palette_to_rgb(png_ptr);
has_alpha = true; has_alpha = true;
...@@ -96,8 +93,11 @@ torch::Tensor decodePNG(const torch::Tensor& data, int64_t channels) { ...@@ -96,8 +93,11 @@ torch::Tensor decodePNG(const torch::Tensor& data, int64_t channels) {
if (has_color) { if (has_color) {
png_set_rgb_to_gray(png_ptr, 1, 0.2989, 0.587); png_set_rgb_to_gray(png_ptr, 1, 0.2989, 0.587);
} }
channels = 1;
}
break; break;
case 2: // Gray + Alpha case IMAGE_READ_MODE_GRAY_ALPHA:
if (color_type != PNG_COLOR_TYPE_GRAY_ALPHA) {
if (is_palette) { if (is_palette) {
png_set_palette_to_rgb(png_ptr); png_set_palette_to_rgb(png_ptr);
has_alpha = true; has_alpha = true;
...@@ -110,8 +110,11 @@ torch::Tensor decodePNG(const torch::Tensor& data, int64_t channels) { ...@@ -110,8 +110,11 @@ torch::Tensor decodePNG(const torch::Tensor& data, int64_t channels) {
if (has_color) { if (has_color) {
png_set_rgb_to_gray(png_ptr, 1, 0.2989, 0.587); png_set_rgb_to_gray(png_ptr, 1, 0.2989, 0.587);
} }
channels = 2;
}
break; break;
case 3: case IMAGE_READ_MODE_RGB:
if (color_type != PNG_COLOR_TYPE_RGB) {
if (is_palette) { if (is_palette) {
png_set_palette_to_rgb(png_ptr); png_set_palette_to_rgb(png_ptr);
has_alpha = true; has_alpha = true;
...@@ -122,8 +125,11 @@ torch::Tensor decodePNG(const torch::Tensor& data, int64_t channels) { ...@@ -122,8 +125,11 @@ torch::Tensor decodePNG(const torch::Tensor& data, int64_t channels) {
if (has_alpha) { if (has_alpha) {
png_set_strip_alpha(png_ptr); png_set_strip_alpha(png_ptr);
} }
channels = 3;
}
break; break;
case 4: case IMAGE_READ_MODE_RGB_ALPHA:
if (color_type != PNG_COLOR_TYPE_RGB_ALPHA) {
if (is_palette) { if (is_palette) {
png_set_palette_to_rgb(png_ptr); png_set_palette_to_rgb(png_ptr);
has_alpha = true; has_alpha = true;
...@@ -134,15 +140,15 @@ torch::Tensor decodePNG(const torch::Tensor& data, int64_t channels) { ...@@ -134,15 +140,15 @@ torch::Tensor decodePNG(const torch::Tensor& data, int64_t channels) {
if (!has_alpha) { if (!has_alpha) {
png_set_add_alpha(png_ptr, (1 << bit_depth) - 1, PNG_FILLER_AFTER); png_set_add_alpha(png_ptr, (1 << bit_depth) - 1, PNG_FILLER_AFTER);
} }
channels = 4;
}
break; break;
default: default:
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
TORCH_CHECK(false, "Invalid number of output channels."); TORCH_CHECK(false, "Provided mode not supported");
} }
png_read_update_info(png_ptr, info_ptr); png_read_update_info(png_ptr, info_ptr);
} else {
channels = current_channels;
} }
auto tensor = auto tensor =
......
#pragma once #pragma once
// Comment
#include <torch/torch.h> #include <torch/torch.h>
#include <string> #include "image_read_mode.h"
C10_EXPORT torch::Tensor decodePNG( C10_EXPORT torch::Tensor decodePNG(
const torch::Tensor& data, const torch::Tensor& data,
int64_t channels = 0); ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED);
...@@ -4,6 +4,8 @@ import os ...@@ -4,6 +4,8 @@ import os
import os.path as osp import os.path as osp
import importlib.machinery import importlib.machinery
from enum import Enum
_HAS_IMAGE_OPT = False _HAS_IMAGE_OPT = False
try: try:
...@@ -47,6 +49,14 @@ except (ImportError, OSError): ...@@ -47,6 +49,14 @@ except (ImportError, OSError):
pass pass
class ImageReadMode(Enum):
UNCHANGED = 0
GRAY = 1
GRAY_ALPHA = 2
RGB = 3
RGB_ALPHA = 4
def read_file(path: str) -> torch.Tensor: def read_file(path: str) -> torch.Tensor:
""" """
Reads and outputs the bytes contents of a file as a uint8 Tensor Reads and outputs the bytes contents of a file as a uint8 Tensor
...@@ -74,24 +84,26 @@ def write_file(filename: str, data: torch.Tensor) -> None: ...@@ -74,24 +84,26 @@ def write_file(filename: str, data: torch.Tensor) -> None:
torch.ops.image.write_file(filename, data) torch.ops.image.write_file(filename, data)
def decode_png(input: torch.Tensor, channels: int = 0) -> torch.Tensor: def decode_png(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
""" """
Decodes a PNG image into a 3 dimensional RGB Tensor. Decodes a PNG image into a 3 dimensional RGB Tensor.
Optionally converts the image to the desired number of color channels. Optionally converts the image to the desired format.
The values of the output tensor are uint8 between 0 and 255. The values of the output tensor are uint8 between 0 and 255.
Arguments: Arguments:
input (Tensor[1]): a one dimensional uint8 tensor containing input (Tensor[1]): a one dimensional uint8 tensor containing
the raw bytes of the PNG image. the raw bytes of the PNG image.
channels (int): the number of output channels for the decoded mode (ImageReadMode): the read mode used for optionally
image. 0 keeps the original number of channels, 1 converts to Grayscale converting the image. Use `ImageReadMode.UNCHANGED` for loading
2 converts to Grayscale with Alpha, 3 converts to RGB and 4 coverts to the image as-is, `ImageReadMode.GRAY` for converting to grayscale,
RGB with Alpha. Default: 0 `ImageReadMode.GRAY_ALPHA` for grayscale with transparency,
`ImageReadMode.RGB` for RGB and `ImageReadMode.RGB_ALPHA` for
RGB with transparency. Default: `ImageReadMode.UNCHANGED`
Returns: Returns:
output (Tensor[image_channels, image_height, image_width]) output (Tensor[image_channels, image_height, image_width])
""" """
output = torch.ops.image.decode_png(input, channels) output = torch.ops.image.decode_png(input, mode.value)
return output return output
...@@ -137,23 +149,24 @@ def write_png(input: torch.Tensor, filename: str, compression_level: int = 6): ...@@ -137,23 +149,24 @@ def write_png(input: torch.Tensor, filename: str, compression_level: int = 6):
write_file(filename, output) write_file(filename, output)
def decode_jpeg(input: torch.Tensor, channels: int = 0) -> torch.Tensor: def decode_jpeg(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
""" """
Decodes a JPEG image into a 3 dimensional RGB Tensor. Decodes a JPEG image into a 3 dimensional RGB Tensor.
Optionally converts the image to the desired number of color channels. Optionally converts the image to the desired format.
The values of the output tensor are uint8 between 0 and 255. The values of the output tensor are uint8 between 0 and 255.
Arguments: Arguments:
input (Tensor[1]): a one dimensional uint8 tensor containing input (Tensor[1]): a one dimensional uint8 tensor containing
the raw bytes of the JPEG image. the raw bytes of the JPEG image.
channels (int): the number of output channels for the decoded mode (ImageReadMode): the read mode used for optionally
image. 0 keeps the original number of channels, 1 converts to Grayscale converting the image. Use `ImageReadMode.UNCHANGED` for loading
and 3 converts to RGB. Default: 0 the image as-is, `ImageReadMode.GRAY` for converting to grayscale
and `ImageReadMode.RGB` for RGB. Default: `ImageReadMode.UNCHANGED`
Returns: Returns:
output (Tensor[image_channels, image_height, image_width]) output (Tensor[image_channels, image_height, image_width])
""" """
output = torch.ops.image.decode_jpeg(input, channels) output = torch.ops.image.decode_jpeg(input, mode.value)
return output return output
...@@ -202,12 +215,12 @@ def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75): ...@@ -202,12 +215,12 @@ def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75):
write_file(filename, output) write_file(filename, output)
def decode_image(input: torch.Tensor, channels: int = 0) -> torch.Tensor: def decode_image(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
""" """
Detects whether an image is a JPEG or PNG and performs the appropriate Detects whether an image is a JPEG or PNG and performs the appropriate
operation to decode the image into a 3 dimensional RGB Tensor. operation to decode the image into a 3 dimensional RGB Tensor.
Optionally converts the image to the desired number of color channels. Optionally converts the image to the desired format.
The values of the output tensor are uint8 between 0 and 255. The values of the output tensor are uint8 between 0 and 255.
Parameters Parameters
...@@ -215,39 +228,41 @@ def decode_image(input: torch.Tensor, channels: int = 0) -> torch.Tensor: ...@@ -215,39 +228,41 @@ def decode_image(input: torch.Tensor, channels: int = 0) -> torch.Tensor:
input: Tensor input: Tensor
a one dimensional uint8 tensor containing the raw bytes of the a one dimensional uint8 tensor containing the raw bytes of the
PNG or JPEG image. PNG or JPEG image.
channels: int mode: ImageReadMode
the number of output channels of the decoded image. JPEG and PNG images the read mode used for optionally converting the image. JPEG
have different permitted values. The default value is 0 and it keeps and PNG images have different permitted values. The default
the original number of channels. See `decode_jpeg()` and `decode_png()` value is `ImageReadMode.UNCHANGED` and it keeps the image as-is.
for more information. Default: 0 See `decode_jpeg()` and `decode_png()` for more information.
Default: `ImageReadMode.UNCHANGED`
Returns Returns
------- -------
output: Tensor[image_channels, image_height, image_width] output: Tensor[image_channels, image_height, image_width]
""" """
output = torch.ops.image.decode_image(input, channels) output = torch.ops.image.decode_image(input, mode.value)
return output return output
def read_image(path: str, channels: int = 0) -> torch.Tensor: def read_image(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
""" """
Reads a JPEG or PNG image into a 3 dimensional RGB Tensor. Reads a JPEG or PNG image into a 3 dimensional RGB Tensor.
Optionally converts the image to the desired number of color channels. Optionally converts the image to the desired format.
The values of the output tensor are uint8 between 0 and 255. The values of the output tensor are uint8 between 0 and 255.
Parameters Parameters
---------- ----------
path: str path: str
path of the JPEG or PNG image. path of the JPEG or PNG image.
channels: int mode: ImageReadMode
the number of output channels of the decoded image. JPEG and PNG images the read mode used for optionally converting the image. JPEG
have different permitted values. The default value is 0 and it keeps and PNG images have different permitted values. The default
the original number of channels. See `decode_jpeg()` and `decode_png()` value is `ImageReadMode.UNCHANGED` and it keeps the image as-is.
for more information. Default: 0 See `decode_jpeg()` and `decode_png()` for more information.
Default: `ImageReadMode.UNCHANGED`
Returns Returns
------- -------
output: Tensor[image_channels, image_height, image_width] output: Tensor[image_channels, image_height, image_width]
""" """
data = read_file(path) data = read_file(path)
return decode_image(data, channels) return decode_image(data, mode)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment