Unverified Commit 2b2dedc3 authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Remove read_jpeg/read_png in favor of read_image (#2764)

parent de908627
...@@ -8,7 +8,7 @@ import torch ...@@ -8,7 +8,7 @@ import torch
import torchvision import torchvision
from PIL import Image from PIL import Image
from torchvision.io.image import ( from torchvision.io.image import (
read_png, decode_png, read_jpeg, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file, decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file,
encode_png, write_png) encode_png, write_png)
import numpy as np import numpy as np
...@@ -33,19 +33,12 @@ def get_images(directory, img_ext): ...@@ -33,19 +33,12 @@ def get_images(directory, img_ext):
class ImageTester(unittest.TestCase): class ImageTester(unittest.TestCase):
def test_read_jpeg(self):
for img_path in get_images(IMAGE_ROOT, ".jpg"):
img_pil = torch.load(img_path.replace('jpg', 'pth'))
img_pil = img_pil.permute(2, 0, 1)
img_ljpeg = read_jpeg(img_path)
self.assertTrue(img_ljpeg.equal(img_pil))
def test_decode_jpeg(self): def test_decode_jpeg(self):
for img_path in get_images(IMAGE_ROOT, ".jpg"): for img_path in get_images(IMAGE_ROOT, ".jpg"):
img_pil = torch.load(img_path.replace('jpg', 'pth')) img_pil = torch.load(img_path.replace('jpg', 'pth'))
img_pil = img_pil.permute(2, 0, 1) img_pil = img_pil.permute(2, 0, 1)
size = os.path.getsize(img_path) data = read_file(img_path)
img_ljpeg = decode_jpeg(torch.from_file(img_path, dtype=torch.uint8, size=size)) img_ljpeg = decode_jpeg(data)
self.assertTrue(img_ljpeg.equal(img_pil)) self.assertTrue(img_ljpeg.equal(img_pil))
with self.assertRaisesRegex(RuntimeError, "Expected a non empty 1-dimensional tensor"): with self.assertRaisesRegex(RuntimeError, "Expected a non empty 1-dimensional tensor"):
...@@ -59,9 +52,9 @@ class ImageTester(unittest.TestCase): ...@@ -59,9 +52,9 @@ class ImageTester(unittest.TestCase):
def test_damaged_images(self): def test_damaged_images(self):
# Test image with bad Huffman encoding (should not raise) # Test image with bad Huffman encoding (should not raise)
bad_huff = os.path.join(DAMAGED_JPEG, 'bad_huffman.jpg') bad_huff = read_file(os.path.join(DAMAGED_JPEG, 'bad_huffman.jpg'))
try: try:
_ = read_jpeg(bad_huff) _ = decode_jpeg(bad_huff)
except RuntimeError: except RuntimeError:
self.assertTrue(False) self.assertTrue(False)
...@@ -69,8 +62,9 @@ class ImageTester(unittest.TestCase): ...@@ -69,8 +62,9 @@ class ImageTester(unittest.TestCase):
truncated_images = glob.glob( truncated_images = glob.glob(
os.path.join(DAMAGED_JPEG, 'corrupt*.jpg')) os.path.join(DAMAGED_JPEG, 'corrupt*.jpg'))
for image_path in truncated_images: for image_path in truncated_images:
data = read_file(image_path)
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
read_jpeg(image_path) decode_jpeg(data)
def test_encode_jpeg(self): def test_encode_jpeg(self):
for img_path in get_images(IMAGE_ROOT, ".jpg"): for img_path in get_images(IMAGE_ROOT, ".jpg"):
...@@ -79,7 +73,7 @@ class ImageTester(unittest.TestCase): ...@@ -79,7 +73,7 @@ class ImageTester(unittest.TestCase):
write_folder = os.path.join(dirname, 'jpeg_write') write_folder = os.path.join(dirname, 'jpeg_write')
expected_file = os.path.join( expected_file = os.path.join(
write_folder, '{0}_pil.jpg'.format(filename)) write_folder, '{0}_pil.jpg'.format(filename))
img = read_jpeg(img_path) img = decode_jpeg(read_file(img_path))
with open(expected_file, 'rb') as f: with open(expected_file, 'rb') as f:
pil_bytes = f.read() pil_bytes = f.read()
...@@ -117,7 +111,8 @@ class ImageTester(unittest.TestCase): ...@@ -117,7 +111,8 @@ class ImageTester(unittest.TestCase):
def test_write_jpeg(self): def test_write_jpeg(self):
for img_path in get_images(IMAGE_ROOT, ".jpg"): for img_path in get_images(IMAGE_ROOT, ".jpg"):
img = read_jpeg(img_path) data = read_file(img_path)
img = decode_jpeg(data)
basedir = os.path.dirname(img_path) basedir = os.path.dirname(img_path)
filename, _ = os.path.splitext(os.path.basename(img_path)) filename, _ = os.path.splitext(os.path.basename(img_path))
...@@ -137,20 +132,12 @@ class ImageTester(unittest.TestCase): ...@@ -137,20 +132,12 @@ class ImageTester(unittest.TestCase):
os.remove(torch_jpeg) os.remove(torch_jpeg)
self.assertEqual(torch_bytes, pil_bytes) self.assertEqual(torch_bytes, pil_bytes)
def test_read_png(self):
# Check across .png
for img_path in get_images(IMAGE_DIR, ".png"):
img_pil = torch.from_numpy(np.array(Image.open(img_path)))
img_pil = img_pil.permute(2, 0, 1)
img_lpng = read_png(img_path)
self.assertTrue(img_lpng.equal(img_pil))
def test_decode_png(self): def test_decode_png(self):
for img_path in get_images(IMAGE_DIR, ".png"): for img_path in get_images(IMAGE_DIR, ".png"):
img_pil = torch.from_numpy(np.array(Image.open(img_path))) img_pil = torch.from_numpy(np.array(Image.open(img_path)))
img_pil = img_pil.permute(2, 0, 1) img_pil = img_pil.permute(2, 0, 1)
size = os.path.getsize(img_path) data = read_file(img_path)
img_lpng = decode_png(torch.from_file(img_path, dtype=torch.uint8, size=size)) img_lpng = decode_png(data)
self.assertTrue(img_lpng.equal(img_pil)) self.assertTrue(img_lpng.equal(img_pil))
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
......
...@@ -54,21 +54,6 @@ def decode_png(input: torch.Tensor) -> torch.Tensor: ...@@ -54,21 +54,6 @@ def decode_png(input: torch.Tensor) -> torch.Tensor:
return output return output
def read_png(path: str) -> torch.Tensor:
"""
Reads a PNG image into a 3 dimensional RGB Tensor.
The values of the output tensor are uint8 between 0 and 255.
Arguments:
path (str): path of the PNG image.
Returns:
output (Tensor[3, image_height, image_width])
"""
data = read_file(path)
return decode_png(data)
def encode_png(input: torch.Tensor, compression_level: int = 6) -> torch.Tensor: def encode_png(input: torch.Tensor, compression_level: int = 6) -> torch.Tensor:
""" """
Takes an input tensor in CHW layout and returns a buffer with the contents Takes an input tensor in CHW layout and returns a buffer with the contents
...@@ -124,19 +109,6 @@ def decode_jpeg(input: torch.Tensor) -> torch.Tensor: ...@@ -124,19 +109,6 @@ def decode_jpeg(input: torch.Tensor) -> torch.Tensor:
return output return output
def read_jpeg(path: str) -> torch.Tensor:
"""
Reads a JPEG image into a 3 dimensional RGB Tensor.
The values of the output tensor are uint8 between 0 and 255.
Arguments:
path (str): path of the JPEG image.
Returns:
output (Tensor[3, image_height, image_width])
"""
data = read_file(path)
return decode_jpeg(data)
def encode_jpeg(input: torch.Tensor, quality: int = 75) -> torch.Tensor: def encode_jpeg(input: torch.Tensor, quality: int = 75) -> torch.Tensor:
""" """
Takes an input tensor in CHW layout and returns a buffer with the contents Takes an input tensor in CHW layout and returns a buffer with the contents
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment