"vscode:/vscode.git/clone" did not exist on "5a98f7e4218882f6932f18ff7cb87fd33a38bb31"
Unverified Commit 78159d61 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Extend the supported types of decodePNG (#2984)

* Add support of different color types in readpng.

* Adding test images and unit-tests.

* Use closest possible type.

* Fix formatting.
parent 481ef519
......@@ -16,7 +16,8 @@ from common_utils import get_tmp_dir
IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
IMAGE_DIR = os.path.join(IMAGE_ROOT, "fakedata", "imagefolder")
FAKEDATA_DIR = os.path.join(IMAGE_ROOT, "fakedata")
IMAGE_DIR = os.path.join(FAKEDATA_DIR, "imagefolder")
DAMAGED_JPEG = os.path.join(IMAGE_ROOT, 'damaged_jpeg')
......@@ -133,9 +134,12 @@ class ImageTester(unittest.TestCase):
self.assertEqual(torch_bytes, pil_bytes)
def test_decode_png(self):
for img_path in get_images(IMAGE_DIR, ".png"):
for img_path in get_images(FAKEDATA_DIR, ".png"):
img_pil = torch.from_numpy(np.array(Image.open(img_path)))
if len(img_pil.shape) == 3:
img_pil = img_pil.permute(2, 0, 1)
else:
img_pil = img_pil.unsqueeze(0)
data = read_file(img_path)
img_lpng = decode_png(data)
self.assertTrue(img_lpng.equal(img_pil))
......
......@@ -71,17 +71,34 @@ torch::Tensor decodePNG(const torch::Tensor& data) {
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
TORCH_CHECK(retval == 1, "Could read image metadata from content.")
}
if (color_type != PNG_COLOR_TYPE_RGB) {
int channels;
switch (color_type) {
case PNG_COLOR_TYPE_RGB:
channels = 3;
break;
case PNG_COLOR_TYPE_RGB_ALPHA:
channels = 4;
break;
case PNG_COLOR_TYPE_GRAY:
channels = 1;
break;
case PNG_COLOR_TYPE_GRAY_ALPHA:
channels = 2;
break;
case PNG_COLOR_TYPE_PALETTE:
channels = 1;
break;
default:
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
TORCH_CHECK(
color_type == PNG_COLOR_TYPE_RGB, "Non RGB images are not supported.")
TORCH_CHECK(false, "Image color type is not supported.");
}
auto tensor =
torch::empty({int64_t(height), int64_t(width), int64_t(3)}, torch::kU8);
auto tensor = torch::empty(
{int64_t(height), int64_t(width), int64_t(channels)}, torch::kU8);
auto ptr = tensor.accessor<uint8_t, 3>().data();
auto bytes = png_get_rowbytes(png_ptr, info_ptr);
for (decltype(height) i = 0; i < height; ++i) {
for (png_uint_32 i = 0; i < height; ++i) {
png_read_row(png_ptr, ptr, nullptr);
ptr += bytes;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment