Unverified Commit 6afb3496 authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Add decode_image op (#2718)

* Add decode_image op

* Fix lint

* More lint

* Add C10_EXPORT
parent 898802fe
...@@ -8,7 +8,7 @@ import torch ...@@ -8,7 +8,7 @@ import torch
import torchvision import torchvision
from PIL import Image from PIL import Image
from torchvision.io.image import ( from torchvision.io.image import (
read_png, decode_png, read_jpeg, decode_jpeg, encode_jpeg, write_jpeg) read_png, decode_png, read_jpeg, decode_jpeg, encode_jpeg, write_jpeg, decode_image, _read_file)
import numpy as np import numpy as np
IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets") IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
...@@ -44,10 +44,10 @@ class ImageTester(unittest.TestCase): ...@@ -44,10 +44,10 @@ class ImageTester(unittest.TestCase):
img_ljpeg = decode_jpeg(torch.from_file(img_path, dtype=torch.uint8, size=size)) img_ljpeg = decode_jpeg(torch.from_file(img_path, dtype=torch.uint8, size=size))
self.assertTrue(img_ljpeg.equal(img_pil)) self.assertTrue(img_ljpeg.equal(img_pil))
with self.assertRaisesRegex(ValueError, "Expected a non empty 1-dimensional tensor."): with self.assertRaisesRegex(RuntimeError, "Expected a non empty 1-dimensional tensor"):
decode_jpeg(torch.empty((100, 1), dtype=torch.uint8)) decode_jpeg(torch.empty((100, 1), dtype=torch.uint8))
with self.assertRaisesRegex(ValueError, "Expected a torch.uint8 tensor."): with self.assertRaisesRegex(RuntimeError, "Expected a torch.uint8 tensor"):
decode_jpeg(torch.empty((100, ), dtype=torch.float16)) decode_jpeg(torch.empty((100, ), dtype=torch.float16))
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
...@@ -149,11 +149,24 @@ class ImageTester(unittest.TestCase): ...@@ -149,11 +149,24 @@ class ImageTester(unittest.TestCase):
img_lpng = decode_png(torch.from_file(img_path, dtype=torch.uint8, size=size)) img_lpng = decode_png(torch.from_file(img_path, dtype=torch.uint8, size=size))
self.assertTrue(img_lpng.equal(img_pil)) self.assertTrue(img_lpng.equal(img_pil))
with self.assertRaises(ValueError): with self.assertRaises(RuntimeError):
decode_png(torch.empty((), dtype=torch.uint8)) decode_png(torch.empty((), dtype=torch.uint8))
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8)) decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8))
def test_decode_image(self):
for img_path in get_images(IMAGE_ROOT, ".jpg"):
img_pil = torch.load(img_path.replace('jpg', 'pth'))
img_pil = img_pil.permute(2, 0, 1)
img_ljpeg = decode_image(_read_file(img_path))
self.assertTrue(img_ljpeg.equal(img_pil))
for img_path in get_images(IMAGE_DIR, ".png"):
img_pil = torch.from_numpy(np.array(Image.open(img_path)))
img_pil = img_pil.permute(2, 0, 1)
img_lpng = decode_image(_read_file(img_path))
self.assertTrue(img_lpng.equal(img_pil))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -16,4 +16,5 @@ static auto registry = torch::RegisterOperators() ...@@ -16,4 +16,5 @@ static auto registry = torch::RegisterOperators()
.op("image::decode_png", &decodePNG) .op("image::decode_png", &decodePNG)
.op("image::decode_jpeg", &decodeJPEG) .op("image::decode_jpeg", &decodeJPEG)
.op("image::encode_jpeg", &encodeJPEG) .op("image::encode_jpeg", &encodeJPEG)
.op("image::write_jpeg", &writeJPEG); .op("image::write_jpeg", &writeJPEG)
.op("image::decode_image", &decode_image);
#pragma once #pragma once
// Comment // Comment
#include <torch/script.h> #include <torch/script.h>
#include <torch/torch.h> #include <torch/torch.h>
#include "read_image_cpu.h"
#include "readjpeg_cpu.h" #include "readjpeg_cpu.h"
#include "readpng_cpu.h" #include "readpng_cpu.h"
#include "writejpeg_cpu.h" #include "writejpeg_cpu.h"
#include "read_image_cpu.h"
#include <string.h>
torch::Tensor decode_image(const torch::Tensor& data) {
// Check that the input tensor dtype is uint8
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
// Check that the input tensor is 1-dimensional
TORCH_CHECK(
data.dim() == 1 && data.numel() > 0,
"Expected a non empty 1-dimensional tensor");
auto datap = data.data_ptr<uint8_t>();
const uint8_t jpeg_signature[3] = {255, 216, 255}; // == "\xFF\xD8\xFF"
const uint8_t png_signature[4] = {137, 80, 78, 71}; // == "\211PNG"
if (memcmp(jpeg_signature, datap, 3) == 0) {
return decodeJPEG(data);
} else if (memcmp(png_signature, datap, 4) == 0) {
return decodePNG(data);
} else {
TORCH_CHECK(
false,
"Unsupported image file. Only jpeg and png ",
"are currently supported.");
}
}
#pragma once
#include "readjpeg_cpu.h"
#include "readpng_cpu.h"
C10_EXPORT torch::Tensor decode_image(const torch::Tensor& data);
...@@ -72,6 +72,13 @@ static void torch_jpeg_set_source_mgr( ...@@ -72,6 +72,13 @@ static void torch_jpeg_set_source_mgr(
} }
torch::Tensor decodeJPEG(const torch::Tensor& data) { torch::Tensor decodeJPEG(const torch::Tensor& data) {
// Check that the input tensor dtype is uint8
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
// Check that the input tensor is 1-dimensional
TORCH_CHECK(
data.dim() == 1 && data.numel() > 0,
"Expected a non empty 1-dimensional tensor");
struct jpeg_decompress_struct cinfo; struct jpeg_decompress_struct cinfo;
struct torch_jpeg_error_mgr jerr; struct torch_jpeg_error_mgr jerr;
......
...@@ -13,6 +13,13 @@ torch::Tensor decodePNG(const torch::Tensor& data) { ...@@ -13,6 +13,13 @@ torch::Tensor decodePNG(const torch::Tensor& data) {
#include <png.h> #include <png.h>
torch::Tensor decodePNG(const torch::Tensor& data) { torch::Tensor decodePNG(const torch::Tensor& data) {
// Check that the input tensor dtype is uint8
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
// Check that the input tensor is 1-dimensional
TORCH_CHECK(
data.dim() == 1 && data.numel() > 0,
"Expected a non empty 1-dimensional tensor");
auto png_ptr = auto png_ptr =
png_create_read_struct(PNG_LIBPNG_VER_STRING, nullptr, nullptr, nullptr); png_create_read_struct(PNG_LIBPNG_VER_STRING, nullptr, nullptr, nullptr);
TORCH_CHECK(png_ptr, "libpng read structure allocation failed!") TORCH_CHECK(png_ptr, "libpng read structure allocation failed!")
......
...@@ -23,23 +23,29 @@ except (ImportError, OSError): ...@@ -23,23 +23,29 @@ except (ImportError, OSError):
pass pass
def _read_file(path: str) -> torch.Tensor:
if not os.path.isfile(path):
raise ValueError("Expected a valid file path.")
size = os.path.getsize(path)
if size == 0:
raise ValueError("Expected a non empty file.")
data = torch.from_file(path, dtype=torch.uint8, size=size)
return data
def decode_png(input: torch.Tensor) -> torch.Tensor: def decode_png(input: torch.Tensor) -> torch.Tensor:
""" """
Decodes a PNG image into a 3 dimensional RGB Tensor. Decodes a PNG image into a 3 dimensional RGB Tensor.
The values of the output tensor are uint8 between 0 and 255. The values of the output tensor are uint8 between 0 and 255.
Arguments: Arguments:
input (Tensor[1]): a one dimensional int8 tensor containing input (Tensor[1]): a one dimensional uint8 tensor containing
the raw bytes of the PNG image. the raw bytes of the PNG image.
Returns: Returns:
output (Tensor[3, image_height, image_width]) output (Tensor[3, image_height, image_width])
""" """
if not isinstance(input, torch.Tensor) or input.numel() == 0 or input.ndim != 1: # type: ignore[attr-defined]
raise ValueError("Expected a non empty 1-dimensional tensor.")
if not input.dtype == torch.uint8:
raise ValueError("Expected a torch.uint8 tensor.")
output = torch.ops.image.decode_png(input) output = torch.ops.image.decode_png(input)
return output return output
...@@ -55,13 +61,7 @@ def read_png(path: str) -> torch.Tensor: ...@@ -55,13 +61,7 @@ def read_png(path: str) -> torch.Tensor:
Returns: Returns:
output (Tensor[3, image_height, image_width]) output (Tensor[3, image_height, image_width])
""" """
if not os.path.isfile(path): data = _read_file(path)
raise ValueError("Expected a valid file path.")
size = os.path.getsize(path)
if size == 0:
raise ValueError("Expected a non empty file.")
data = torch.from_file(path, dtype=torch.uint8, size=size)
return decode_png(data) return decode_png(data)
...@@ -70,17 +70,11 @@ def decode_jpeg(input: torch.Tensor) -> torch.Tensor: ...@@ -70,17 +70,11 @@ def decode_jpeg(input: torch.Tensor) -> torch.Tensor:
Decodes a JPEG image into a 3 dimensional RGB Tensor. Decodes a JPEG image into a 3 dimensional RGB Tensor.
The values of the output tensor are uint8 between 0 and 255. The values of the output tensor are uint8 between 0 and 255.
Arguments: Arguments:
input (Tensor[1]): a one dimensional int8 tensor containing input (Tensor[1]): a one dimensional uint8 tensor containing
the raw bytes of the JPEG image. the raw bytes of the JPEG image.
Returns: Returns:
output (Tensor[3, image_height, image_width]) output (Tensor[3, image_height, image_width])
""" """
if not isinstance(input, torch.Tensor) or len(input) == 0 or input.ndim != 1: # type: ignore[attr-defined]
raise ValueError("Expected a non empty 1-dimensional tensor.")
if not input.dtype == torch.uint8:
raise ValueError("Expected a torch.uint8 tensor.")
output = torch.ops.image.decode_jpeg(input) output = torch.ops.image.decode_jpeg(input)
return output return output
...@@ -94,13 +88,7 @@ def read_jpeg(path: str) -> torch.Tensor: ...@@ -94,13 +88,7 @@ def read_jpeg(path: str) -> torch.Tensor:
Returns: Returns:
output (Tensor[3, image_height, image_width]) output (Tensor[3, image_height, image_width])
""" """
if not os.path.isfile(path): data = _read_file(path)
raise ValueError("Expected a valid file path.")
size = os.path.getsize(path)
if size == 0:
raise ValueError("Expected a non empty file.")
data = torch.from_file(path, dtype=torch.uint8, size=size)
return decode_jpeg(data) return decode_jpeg(data)
...@@ -141,3 +129,33 @@ def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75): ...@@ -141,3 +129,33 @@ def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75):
'between 1 and 100') 'between 1 and 100')
torch.ops.image.write_jpeg(input, filename, quality) torch.ops.image.write_jpeg(input, filename, quality)
def decode_image(input: torch.Tensor) -> torch.Tensor:
"""
Detects whether an image is a JPEG or PNG and performs the appropriate
operation to decode the image into a 3 dimensional RGB Tensor.
The values of the output tensor are uint8 between 0 and 255.
Arguments:
input (Tensor): a one dimensional uint8 tensor containing
the raw bytes of the PNG or JPEG image.
Returns:
output (Tensor[3, image_height, image_width])
"""
output = torch.ops.image.decode_image(input)
return output
def read_image(path: str) -> torch.Tensor:
"""
Reads a JPEG or PNG image into a 3 dimensional RGB Tensor.
The values of the output tensor are uint8 between 0 and 255.
Arguments:
path (str): path of the JPEG or PNG image.
Returns:
output (Tensor[3, image_height, image_width])
"""
data = _read_file(path)
return decode_image(data)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment