Unverified Commit 5e4a9f6d authored by Edgar Andrés Margffoy Tuay's avatar Edgar Andrés Margffoy Tuay Committed by GitHub
Browse files

PR: Make JPEG/PNG reading ops return images in CHW format (#2680)

* Make JPEG/PNG return images in CHW format

* Use int array
parent c4dcfb06
...@@ -30,12 +30,14 @@ class ImageTester(unittest.TestCase): ...@@ -30,12 +30,14 @@ class ImageTester(unittest.TestCase):
def test_read_jpeg(self): def test_read_jpeg(self):
for img_path in get_images(IMAGE_ROOT, ".jpg"): for img_path in get_images(IMAGE_ROOT, ".jpg"):
img_pil = torch.load(img_path.replace('jpg', 'pth')) img_pil = torch.load(img_path.replace('jpg', 'pth'))
img_pil = img_pil.permute(2, 0, 1)
img_ljpeg = read_jpeg(img_path) img_ljpeg = read_jpeg(img_path)
self.assertTrue(img_ljpeg.equal(img_pil)) self.assertTrue(img_ljpeg.equal(img_pil))
def test_decode_jpeg(self): def test_decode_jpeg(self):
for img_path in get_images(IMAGE_ROOT, ".jpg"): for img_path in get_images(IMAGE_ROOT, ".jpg"):
img_pil = torch.load(img_path.replace('jpg', 'pth')) img_pil = torch.load(img_path.replace('jpg', 'pth'))
img_pil = img_pil.permute(2, 0, 1)
size = os.path.getsize(img_path) size = os.path.getsize(img_path)
img_ljpeg = decode_jpeg(torch.from_file(img_path, dtype=torch.uint8, size=size)) img_ljpeg = decode_jpeg(torch.from_file(img_path, dtype=torch.uint8, size=size))
self.assertTrue(img_ljpeg.equal(img_pil)) self.assertTrue(img_ljpeg.equal(img_pil))
...@@ -68,12 +70,14 @@ class ImageTester(unittest.TestCase): ...@@ -68,12 +70,14 @@ class ImageTester(unittest.TestCase):
# Check across .png # Check across .png
for img_path in get_images(IMAGE_DIR, ".png"): for img_path in get_images(IMAGE_DIR, ".png"):
img_pil = torch.from_numpy(np.array(Image.open(img_path))) img_pil = torch.from_numpy(np.array(Image.open(img_path)))
img_pil = img_pil.permute(2, 0, 1)
img_lpng = read_png(img_path) img_lpng = read_png(img_path)
self.assertTrue(img_lpng.equal(img_pil)) self.assertTrue(img_lpng.equal(img_pil))
def test_decode_png(self): def test_decode_png(self):
for img_path in get_images(IMAGE_DIR, ".png"): for img_path in get_images(IMAGE_DIR, ".png"):
img_pil = torch.from_numpy(np.array(Image.open(img_path))) img_pil = torch.from_numpy(np.array(Image.open(img_path)))
img_pil = img_pil.permute(2, 0, 1)
size = os.path.getsize(img_path) size = os.path.getsize(img_path)
img_lpng = decode_png(torch.from_file(img_path, dtype=torch.uint8, size=size)) img_lpng = decode_png(torch.from_file(img_path, dtype=torch.uint8, size=size))
self.assertTrue(img_lpng.equal(img_pil)) self.assertTrue(img_lpng.equal(img_pil))
......
...@@ -137,7 +137,7 @@ torch::Tensor decodeJPEG(const torch::Tensor& data) { ...@@ -137,7 +137,7 @@ torch::Tensor decodeJPEG(const torch::Tensor& data) {
jpeg_finish_decompress(&cinfo); jpeg_finish_decompress(&cinfo);
jpeg_destroy_decompress(&cinfo); jpeg_destroy_decompress(&cinfo);
return tensor; return tensor.permute({2, 0, 1});
} }
#endif // JPEG_FOUND #endif // JPEG_FOUND
...@@ -79,6 +79,6 @@ torch::Tensor decodePNG(const torch::Tensor& data) { ...@@ -79,6 +79,6 @@ torch::Tensor decodePNG(const torch::Tensor& data) {
ptr += bytes; ptr += bytes;
} }
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
return tensor; return tensor.permute({2, 0, 1});
} }
#endif // PNG_FOUND #endif // PNG_FOUND
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment