Unverified Commit 195bb86e authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Use torch.testing.assert_close in test_image.py (#3877)


Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
parent 05e061f5
...@@ -8,6 +8,7 @@ import numpy as np ...@@ -8,6 +8,7 @@ import numpy as np
import torch import torch
from PIL import Image from PIL import Image
from common_utils import get_tmp_dir, needs_cuda from common_utils import get_tmp_dir, needs_cuda
from _assert_utils import assert_equal
from torchvision.io.image import ( from torchvision.io.image import (
decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file, decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file,
...@@ -107,7 +108,7 @@ class ImageTester(unittest.TestCase): ...@@ -107,7 +108,7 @@ class ImageTester(unittest.TestCase):
for src_img in [img, img.contiguous()]: for src_img in [img, img.contiguous()]:
# PIL sets jpeg quality to 75 by default # PIL sets jpeg quality to 75 by default
jpeg_bytes = encode_jpeg(src_img, quality=75) jpeg_bytes = encode_jpeg(src_img, quality=75)
self.assertTrue(jpeg_bytes.equal(pil_bytes)) assert_equal(jpeg_bytes, pil_bytes)
with self.assertRaisesRegex( with self.assertRaisesRegex(
RuntimeError, "Input tensor dtype should be uint8"): RuntimeError, "Input tensor dtype should be uint8"):
...@@ -191,7 +192,7 @@ class ImageTester(unittest.TestCase): ...@@ -191,7 +192,7 @@ class ImageTester(unittest.TestCase):
rec_img = torch.from_numpy(np.array(rec_img)) rec_img = torch.from_numpy(np.array(rec_img))
rec_img = rec_img.permute(2, 0, 1) rec_img = rec_img.permute(2, 0, 1)
self.assertTrue(img_pil.equal(rec_img)) assert_equal(img_pil, rec_img)
with self.assertRaisesRegex( with self.assertRaisesRegex(
RuntimeError, "Input tensor dtype should be uint8"): RuntimeError, "Input tensor dtype should be uint8"):
...@@ -224,7 +225,7 @@ class ImageTester(unittest.TestCase): ...@@ -224,7 +225,7 @@ class ImageTester(unittest.TestCase):
saved_image = torch.from_numpy(np.array(Image.open(torch_png))) saved_image = torch.from_numpy(np.array(Image.open(torch_png)))
saved_image = saved_image.permute(2, 0, 1) saved_image = saved_image.permute(2, 0, 1)
self.assertTrue(img_pil.equal(saved_image)) assert_equal(img_pil, saved_image)
def test_read_file(self): def test_read_file(self):
with get_tmp_dir() as d: with get_tmp_dir() as d:
...@@ -235,7 +236,7 @@ class ImageTester(unittest.TestCase): ...@@ -235,7 +236,7 @@ class ImageTester(unittest.TestCase):
data = read_file(fpath) data = read_file(fpath)
expected = torch.tensor(list(content), dtype=torch.uint8) expected = torch.tensor(list(content), dtype=torch.uint8)
self.assertTrue(data.equal(expected)) assert_equal(data, expected)
os.unlink(fpath) os.unlink(fpath)
with self.assertRaisesRegex( with self.assertRaisesRegex(
...@@ -251,7 +252,7 @@ class ImageTester(unittest.TestCase): ...@@ -251,7 +252,7 @@ class ImageTester(unittest.TestCase):
data = read_file(fpath) data = read_file(fpath)
expected = torch.tensor(list(content), dtype=torch.uint8) expected = torch.tensor(list(content), dtype=torch.uint8)
self.assertTrue(data.equal(expected)) assert_equal(data, expected)
os.unlink(fpath) os.unlink(fpath)
def test_write_file(self): def test_write_file(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment