Unverified Commit 2b2dedc3 authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Remove read_jpeg/read_png in favor of read_image (#2764)

parent de908627
......@@ -8,7 +8,7 @@ import torch
import torchvision
from PIL import Image
from torchvision.io.image import (
read_png, decode_png, read_jpeg, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file,
decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file,
encode_png, write_png)
import numpy as np
......@@ -33,19 +33,12 @@ def get_images(directory, img_ext):
class ImageTester(unittest.TestCase):
def test_read_jpeg(self):
for img_path in get_images(IMAGE_ROOT, ".jpg"):
img_pil = torch.load(img_path.replace('jpg', 'pth'))
img_pil = img_pil.permute(2, 0, 1)
img_ljpeg = read_jpeg(img_path)
self.assertTrue(img_ljpeg.equal(img_pil))
def test_decode_jpeg(self):
for img_path in get_images(IMAGE_ROOT, ".jpg"):
img_pil = torch.load(img_path.replace('jpg', 'pth'))
img_pil = img_pil.permute(2, 0, 1)
size = os.path.getsize(img_path)
img_ljpeg = decode_jpeg(torch.from_file(img_path, dtype=torch.uint8, size=size))
data = read_file(img_path)
img_ljpeg = decode_jpeg(data)
self.assertTrue(img_ljpeg.equal(img_pil))
with self.assertRaisesRegex(RuntimeError, "Expected a non empty 1-dimensional tensor"):
......@@ -59,9 +52,9 @@ class ImageTester(unittest.TestCase):
def test_damaged_images(self):
# Test image with bad Huffman encoding (should not raise)
bad_huff = os.path.join(DAMAGED_JPEG, 'bad_huffman.jpg')
bad_huff = read_file(os.path.join(DAMAGED_JPEG, 'bad_huffman.jpg'))
try:
_ = read_jpeg(bad_huff)
_ = decode_jpeg(bad_huff)
except RuntimeError:
self.assertTrue(False)
......@@ -69,8 +62,9 @@ class ImageTester(unittest.TestCase):
truncated_images = glob.glob(
os.path.join(DAMAGED_JPEG, 'corrupt*.jpg'))
for image_path in truncated_images:
data = read_file(image_path)
with self.assertRaises(RuntimeError):
read_jpeg(image_path)
decode_jpeg(data)
def test_encode_jpeg(self):
for img_path in get_images(IMAGE_ROOT, ".jpg"):
......@@ -79,7 +73,7 @@ class ImageTester(unittest.TestCase):
write_folder = os.path.join(dirname, 'jpeg_write')
expected_file = os.path.join(
write_folder, '{0}_pil.jpg'.format(filename))
img = read_jpeg(img_path)
img = decode_jpeg(read_file(img_path))
with open(expected_file, 'rb') as f:
pil_bytes = f.read()
......@@ -117,7 +111,8 @@ class ImageTester(unittest.TestCase):
def test_write_jpeg(self):
for img_path in get_images(IMAGE_ROOT, ".jpg"):
img = read_jpeg(img_path)
data = read_file(img_path)
img = decode_jpeg(data)
basedir = os.path.dirname(img_path)
filename, _ = os.path.splitext(os.path.basename(img_path))
......@@ -137,20 +132,12 @@ class ImageTester(unittest.TestCase):
os.remove(torch_jpeg)
self.assertEqual(torch_bytes, pil_bytes)
def test_read_png(self):
# Check across .png
for img_path in get_images(IMAGE_DIR, ".png"):
img_pil = torch.from_numpy(np.array(Image.open(img_path)))
img_pil = img_pil.permute(2, 0, 1)
img_lpng = read_png(img_path)
self.assertTrue(img_lpng.equal(img_pil))
def test_decode_png(self):
for img_path in get_images(IMAGE_DIR, ".png"):
img_pil = torch.from_numpy(np.array(Image.open(img_path)))
img_pil = img_pil.permute(2, 0, 1)
size = os.path.getsize(img_path)
img_lpng = decode_png(torch.from_file(img_path, dtype=torch.uint8, size=size))
data = read_file(img_path)
img_lpng = decode_png(data)
self.assertTrue(img_lpng.equal(img_pil))
with self.assertRaises(RuntimeError):
......
......@@ -54,21 +54,6 @@ def decode_png(input: torch.Tensor) -> torch.Tensor:
return output
def read_png(path: str) -> torch.Tensor:
"""
Reads a PNG image into a 3 dimensional RGB Tensor.
The values of the output tensor are uint8 between 0 and 255.
Arguments:
path (str): path of the PNG image.
Returns:
output (Tensor[3, image_height, image_width])
"""
data = read_file(path)
return decode_png(data)
def encode_png(input: torch.Tensor, compression_level: int = 6) -> torch.Tensor:
"""
Takes an input tensor in CHW layout and returns a buffer with the contents
......@@ -124,19 +109,6 @@ def decode_jpeg(input: torch.Tensor) -> torch.Tensor:
return output
def read_jpeg(path: str) -> torch.Tensor:
"""
Reads a JPEG image into a 3 dimensional RGB Tensor.
The values of the output tensor are uint8 between 0 and 255.
Arguments:
path (str): path of the JPEG image.
Returns:
output (Tensor[3, image_height, image_width])
"""
data = read_file(path)
return decode_jpeg(data)
def encode_jpeg(input: torch.Tensor, quality: int = 75) -> torch.Tensor:
"""
Takes an input tensor in CHW layout and returns a buffer with the contents
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment