Unverified Commit 78159d61 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Extend the supported types of decodePNG (#2984)

* Add support of different color types in readpng.

* Adding test images and unit-tests.

* Use closest possible type.

* Fix formatting.
parent 481ef519
...@@ -16,7 +16,8 @@ from common_utils import get_tmp_dir ...@@ -16,7 +16,8 @@ from common_utils import get_tmp_dir
IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets") IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
IMAGE_DIR = os.path.join(IMAGE_ROOT, "fakedata", "imagefolder") FAKEDATA_DIR = os.path.join(IMAGE_ROOT, "fakedata")
IMAGE_DIR = os.path.join(FAKEDATA_DIR, "imagefolder")
DAMAGED_JPEG = os.path.join(IMAGE_ROOT, 'damaged_jpeg') DAMAGED_JPEG = os.path.join(IMAGE_ROOT, 'damaged_jpeg')
...@@ -133,9 +134,12 @@ class ImageTester(unittest.TestCase): ...@@ -133,9 +134,12 @@ class ImageTester(unittest.TestCase):
self.assertEqual(torch_bytes, pil_bytes) self.assertEqual(torch_bytes, pil_bytes)
def test_decode_png(self): def test_decode_png(self):
for img_path in get_images(IMAGE_DIR, ".png"): for img_path in get_images(FAKEDATA_DIR, ".png"):
img_pil = torch.from_numpy(np.array(Image.open(img_path))) img_pil = torch.from_numpy(np.array(Image.open(img_path)))
img_pil = img_pil.permute(2, 0, 1) if len(img_pil.shape) == 3:
img_pil = img_pil.permute(2, 0, 1)
else:
img_pil = img_pil.unsqueeze(0)
data = read_file(img_path) data = read_file(img_path)
img_lpng = decode_png(data) img_lpng = decode_png(data)
self.assertTrue(img_lpng.equal(img_pil)) self.assertTrue(img_lpng.equal(img_pil))
......
...@@ -71,17 +71,34 @@ torch::Tensor decodePNG(const torch::Tensor& data) { ...@@ -71,17 +71,34 @@ torch::Tensor decodePNG(const torch::Tensor& data) {
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
TORCH_CHECK(retval == 1, "Could read image metadata from content.") TORCH_CHECK(retval == 1, "Could read image metadata from content.")
} }
if (color_type != PNG_COLOR_TYPE_RGB) {
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); int channels;
TORCH_CHECK( switch (color_type) {
color_type == PNG_COLOR_TYPE_RGB, "Non RGB images are not supported.") case PNG_COLOR_TYPE_RGB:
channels = 3;
break;
case PNG_COLOR_TYPE_RGB_ALPHA:
channels = 4;
break;
case PNG_COLOR_TYPE_GRAY:
channels = 1;
break;
case PNG_COLOR_TYPE_GRAY_ALPHA:
channels = 2;
break;
case PNG_COLOR_TYPE_PALETTE:
channels = 1;
break;
default:
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
TORCH_CHECK(false, "Image color type is not supported.");
} }
auto tensor = auto tensor = torch::empty(
torch::empty({int64_t(height), int64_t(width), int64_t(3)}, torch::kU8); {int64_t(height), int64_t(width), int64_t(channels)}, torch::kU8);
auto ptr = tensor.accessor<uint8_t, 3>().data(); auto ptr = tensor.accessor<uint8_t, 3>().data();
auto bytes = png_get_rowbytes(png_ptr, info_ptr); auto bytes = png_get_rowbytes(png_ptr, info_ptr);
for (decltype(height) i = 0; i < height; ++i) { for (png_uint_32 i = 0; i < height; ++i) {
png_read_row(png_ptr, ptr, nullptr); png_read_row(png_ptr, ptr, nullptr);
ptr += bytes; ptr += bytes;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment