Unverified Commit 963d432c authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Use torch.testing.assert_close in test_transforms.py (#3884)


Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
parent 88358528
...@@ -22,6 +22,7 @@ except ImportError: ...@@ -22,6 +22,7 @@ except ImportError:
stats = None stats = None
from common_utils import cycle_over, int_dtypes, float_dtypes from common_utils import cycle_over, int_dtypes, float_dtypes
from _assert_utils import assert_equal
GRACE_HOPPER = get_file_path_2( GRACE_HOPPER = get_file_path_2(
...@@ -102,8 +103,10 @@ class Tester(unittest.TestCase): ...@@ -102,8 +103,10 @@ class Tester(unittest.TestCase):
"image_size: {} crop_size: {}".format(input_image_size, crop_size)) "image_size: {} crop_size: {}".format(input_image_size, crop_size))
# Ensure output for PIL and Tensor are equal # Ensure output for PIL and Tensor are equal
self.assertEqual((output_tensor - output_pil).sum(), 0, assert_equal(
"image_size: {} crop_size: {}".format(input_image_size, crop_size)) output_tensor, output_pil, check_stride=False,
msg="image_size: {} crop_size: {}".format(input_image_size, crop_size)
)
# Check if content in center of both image and cropped output is same. # Check if content in center of both image and cropped output is same.
center_size = (min(crop_size[0], input_image_size[0]), min(crop_size[1], input_image_size[1])) center_size = (min(crop_size[0], input_image_size[0]), min(crop_size[1], input_image_size[1]))
...@@ -126,8 +129,10 @@ class Tester(unittest.TestCase): ...@@ -126,8 +129,10 @@ class Tester(unittest.TestCase):
input_center_tl[1]:input_center_tl[1] + center_size[1] input_center_tl[1]:input_center_tl[1] + center_size[1]
] ]
self.assertEqual((output_center - img_center).sum(), 0, assert_equal(
"image_size: {} crop_size: {}".format(input_image_size, crop_size)) output_center, img_center, check_stride=False,
msg="image_size: {} crop_size: {}".format(input_image_size, crop_size)
)
def test_five_crop(self): def test_five_crop(self):
to_pil_image = transforms.ToPILImage() to_pil_image = transforms.ToPILImage()
...@@ -382,7 +387,7 @@ class Tester(unittest.TestCase): ...@@ -382,7 +387,7 @@ class Tester(unittest.TestCase):
])(img) ])(img)
self.assertEqual(result.size(1), height) self.assertEqual(result.size(1), height)
self.assertEqual(result.size(2), width) self.assertEqual(result.size(2), width)
self.assertTrue(np.allclose(img.numpy(), result.numpy())) torch.testing.assert_close(result, img)
result = transforms.Compose([ result = transforms.Compose([
transforms.ToPILImage(), transforms.ToPILImage(),
...@@ -414,8 +419,14 @@ class Tester(unittest.TestCase): ...@@ -414,8 +419,14 @@ class Tester(unittest.TestCase):
# to the pad value # to the pad value
fill_v = fill / 255 fill_v = fill / 255
eps = 1e-5 eps = 1e-5
self.assertTrue((result[:, :padding, :] - fill_v).abs().max() < eps) h_padded = result[:, :padding, :]
self.assertTrue((result[:, :, :padding] - fill_v).abs().max() < eps) w_padded = result[:, :, :padding]
torch.testing.assert_close(
h_padded, torch.full_like(h_padded, fill_value=fill_v), check_stride=False, rtol=0.0, atol=eps
)
torch.testing.assert_close(
w_padded, torch.full_like(w_padded, fill_value=fill_v), check_stride=False, rtol=0.0, atol=eps
)
self.assertRaises(ValueError, transforms.Pad(padding, fill=(1, 2)), self.assertRaises(ValueError, transforms.Pad(padding, fill=(1, 2)),
transforms.ToPILImage()(img)) transforms.ToPILImage()(img))
...@@ -448,7 +459,7 @@ class Tester(unittest.TestCase): ...@@ -448,7 +459,7 @@ class Tester(unittest.TestCase):
# First 6 elements of leftmost edge in the middle of the image, values are in order: # First 6 elements of leftmost edge in the middle of the image, values are in order:
# edge_pad, edge_pad, edge_pad, constant_pad, constant value added to leftmost edge, 0 # edge_pad, edge_pad, edge_pad, constant_pad, constant value added to leftmost edge, 0
edge_middle_slice = np.asarray(edge_padded_img).transpose(2, 0, 1)[0][17][:6] edge_middle_slice = np.asarray(edge_padded_img).transpose(2, 0, 1)[0][17][:6]
self.assertTrue(np.all(edge_middle_slice == np.asarray([200, 200, 200, 200, 1, 0]))) assert_equal(edge_middle_slice, np.asarray([200, 200, 200, 200, 1, 0], dtype=np.uint8), check_stride=False)
self.assertEqual(transforms.ToTensor()(edge_padded_img).size(), (3, 35, 35)) self.assertEqual(transforms.ToTensor()(edge_padded_img).size(), (3, 35, 35))
# Pad 3 to left/right, 2 to top/bottom # Pad 3 to left/right, 2 to top/bottom
...@@ -456,7 +467,7 @@ class Tester(unittest.TestCase): ...@@ -456,7 +467,7 @@ class Tester(unittest.TestCase):
# First 6 elements of leftmost edge in the middle of the image, values are in order: # First 6 elements of leftmost edge in the middle of the image, values are in order:
# reflect_pad, reflect_pad, reflect_pad, constant_pad, constant value added to leftmost edge, 0 # reflect_pad, reflect_pad, reflect_pad, constant_pad, constant value added to leftmost edge, 0
reflect_middle_slice = np.asarray(reflect_padded_img).transpose(2, 0, 1)[0][17][:6] reflect_middle_slice = np.asarray(reflect_padded_img).transpose(2, 0, 1)[0][17][:6]
self.assertTrue(np.all(reflect_middle_slice == np.asarray([0, 0, 1, 200, 1, 0]))) assert_equal(reflect_middle_slice, np.asarray([0, 0, 1, 200, 1, 0], dtype=np.uint8), check_stride=False)
self.assertEqual(transforms.ToTensor()(reflect_padded_img).size(), (3, 33, 35)) self.assertEqual(transforms.ToTensor()(reflect_padded_img).size(), (3, 33, 35))
# Pad 3 to left, 2 to top, 2 to right, 1 to bottom # Pad 3 to left, 2 to top, 2 to right, 1 to bottom
...@@ -464,7 +475,7 @@ class Tester(unittest.TestCase): ...@@ -464,7 +475,7 @@ class Tester(unittest.TestCase):
# First 6 elements of leftmost edge in the middle of the image, values are in order: # First 6 elements of leftmost edge in the middle of the image, values are in order:
# sym_pad, sym_pad, sym_pad, constant_pad, constant value added to leftmost edge, 0 # sym_pad, sym_pad, sym_pad, constant_pad, constant value added to leftmost edge, 0
symmetric_middle_slice = np.asarray(symmetric_padded_img).transpose(2, 0, 1)[0][17][:6] symmetric_middle_slice = np.asarray(symmetric_padded_img).transpose(2, 0, 1)[0][17][:6]
self.assertTrue(np.all(symmetric_middle_slice == np.asarray([0, 1, 200, 200, 1, 0]))) assert_equal(symmetric_middle_slice, np.asarray([0, 1, 200, 200, 1, 0], dtype=np.uint8), check_stride=False)
self.assertEqual(transforms.ToTensor()(symmetric_padded_img).size(), (3, 32, 34)) self.assertEqual(transforms.ToTensor()(symmetric_padded_img).size(), (3, 32, 34))
# Check negative padding explicitly for symmetric case, since it is not # Check negative padding explicitly for symmetric case, since it is not
...@@ -473,8 +484,8 @@ class Tester(unittest.TestCase): ...@@ -473,8 +484,8 @@ class Tester(unittest.TestCase):
symmetric_padded_img_neg = F.pad(img, (-1, 2, 3, -3), padding_mode='symmetric') symmetric_padded_img_neg = F.pad(img, (-1, 2, 3, -3), padding_mode='symmetric')
symmetric_neg_middle_left = np.asarray(symmetric_padded_img_neg).transpose(2, 0, 1)[0][17][:3] symmetric_neg_middle_left = np.asarray(symmetric_padded_img_neg).transpose(2, 0, 1)[0][17][:3]
symmetric_neg_middle_right = np.asarray(symmetric_padded_img_neg).transpose(2, 0, 1)[0][17][-4:] symmetric_neg_middle_right = np.asarray(symmetric_padded_img_neg).transpose(2, 0, 1)[0][17][-4:]
self.assertTrue(np.all(symmetric_neg_middle_left == np.asarray([1, 0, 0]))) assert_equal(symmetric_neg_middle_left, np.asarray([1, 0, 0], dtype=np.uint8), check_stride=False)
self.assertTrue(np.all(symmetric_neg_middle_right == np.asarray([200, 200, 0, 0]))) assert_equal(symmetric_neg_middle_right, np.asarray([200, 200, 0, 0], dtype=np.uint8), check_stride=False)
self.assertEqual(transforms.ToTensor()(symmetric_padded_img_neg).size(), (3, 28, 31)) self.assertEqual(transforms.ToTensor()(symmetric_padded_img_neg).size(), (3, 28, 31))
def test_pad_raises_with_invalid_pad_sequence_len(self): def test_pad_raises_with_invalid_pad_sequence_len(self):
...@@ -499,12 +510,12 @@ class Tester(unittest.TestCase): ...@@ -499,12 +510,12 @@ class Tester(unittest.TestCase):
trans = transforms.Lambda(lambda x: x.add(10)) trans = transforms.Lambda(lambda x: x.add(10))
x = torch.randn(10) x = torch.randn(10)
y = trans(x) y = trans(x)
self.assertTrue(y.equal(torch.add(x, 10))) assert_equal(y, torch.add(x, 10))
trans = transforms.Lambda(lambda x: x.add_(10)) trans = transforms.Lambda(lambda x: x.add_(10))
x = torch.randn(10) x = torch.randn(10)
y = trans(x) y = trans(x)
self.assertTrue(y.equal(x)) assert_equal(y, x)
# Checking if Lambda can be printed as string # Checking if Lambda can be printed as string
trans.__repr__() trans.__repr__()
...@@ -613,23 +624,23 @@ class Tester(unittest.TestCase): ...@@ -613,23 +624,23 @@ class Tester(unittest.TestCase):
input_data = torch.ByteTensor(channels, height, width).random_(0, 255).float().div_(255) input_data = torch.ByteTensor(channels, height, width).random_(0, 255).float().div_(255)
img = transforms.ToPILImage()(input_data) img = transforms.ToPILImage()(input_data)
output = trans(img) output = trans(img)
self.assertTrue(np.allclose(input_data.numpy(), output.numpy())) torch.testing.assert_close(output, input_data, check_stride=False)
ndarray = np.random.randint(low=0, high=255, size=(height, width, channels)).astype(np.uint8) ndarray = np.random.randint(low=0, high=255, size=(height, width, channels)).astype(np.uint8)
output = trans(ndarray) output = trans(ndarray)
expected_output = ndarray.transpose((2, 0, 1)) / 255.0 expected_output = ndarray.transpose((2, 0, 1)) / 255.0
self.assertTrue(np.allclose(output.numpy(), expected_output)) torch.testing.assert_close(output.numpy(), expected_output, check_stride=False, check_dtype=False)
ndarray = np.random.rand(height, width, channels).astype(np.float32) ndarray = np.random.rand(height, width, channels).astype(np.float32)
output = trans(ndarray) output = trans(ndarray)
expected_output = ndarray.transpose((2, 0, 1)) expected_output = ndarray.transpose((2, 0, 1))
self.assertTrue(np.allclose(output.numpy(), expected_output)) torch.testing.assert_close(output.numpy(), expected_output, check_stride=False, check_dtype=False)
# separate test for mode '1' PIL images # separate test for mode '1' PIL images
input_data = torch.ByteTensor(1, height, width).bernoulli_() input_data = torch.ByteTensor(1, height, width).bernoulli_()
img = transforms.ToPILImage()(input_data.mul(255)).convert('1') img = transforms.ToPILImage()(input_data.mul(255)).convert('1')
output = trans(img) output = trans(img)
self.assertTrue(np.allclose(input_data.numpy(), output.numpy())) torch.testing.assert_close(input_data, output, check_dtype=False, check_stride=False)
def test_to_tensor_with_other_default_dtypes(self): def test_to_tensor_with_other_default_dtypes(self):
current_def_dtype = torch.get_default_dtype() current_def_dtype = torch.get_default_dtype()
...@@ -665,8 +676,7 @@ class Tester(unittest.TestCase): ...@@ -665,8 +676,7 @@ class Tester(unittest.TestCase):
output_image = transform(input_image) output_image = transform(input_image)
output_image_script = transform_script(input_image, output_dtype) output_image_script = transform_script(input_image, output_dtype)
script_diff = output_image_script - output_image torch.testing.assert_close(output_image_script, output_image, rtol=0.0, atol=1e-6)
self.assertLess(script_diff.abs().max(), 1e-6)
actual_min, actual_max = output_image.tolist() actual_min, actual_max = output_image.tolist()
desired_min, desired_max = 0.0, 1.0 desired_min, desired_max = 0.0, 1.0
...@@ -691,8 +701,7 @@ class Tester(unittest.TestCase): ...@@ -691,8 +701,7 @@ class Tester(unittest.TestCase):
output_image = transform(input_image) output_image = transform(input_image)
output_image_script = transform_script(input_image, output_dtype) output_image_script = transform_script(input_image, output_dtype)
script_diff = output_image_script - output_image torch.testing.assert_close(output_image_script, output_image, rtol=0.0, atol=1e-6)
self.assertLess(script_diff.abs().max(), 1e-6)
actual_min, actual_max = output_image.tolist() actual_min, actual_max = output_image.tolist()
desired_min, desired_max = 0, torch.iinfo(output_dtype).max desired_min, desired_max = 0, torch.iinfo(output_dtype).max
...@@ -711,8 +720,7 @@ class Tester(unittest.TestCase): ...@@ -711,8 +720,7 @@ class Tester(unittest.TestCase):
output_image = transform(input_image) output_image = transform(input_image)
output_image_script = transform_script(input_image, output_dtype) output_image_script = transform_script(input_image, output_dtype)
script_diff = output_image_script - output_image torch.testing.assert_close(output_image_script, output_image, rtol=0.0, atol=1e-6)
self.assertLess(script_diff.abs().max(), 1e-6)
actual_min, actual_max = output_image.tolist() actual_min, actual_max = output_image.tolist()
desired_min, desired_max = 0.0, 1.0 desired_min, desired_max = 0.0, 1.0
...@@ -736,9 +744,12 @@ class Tester(unittest.TestCase): ...@@ -736,9 +744,12 @@ class Tester(unittest.TestCase):
output_image = transform(input_image) output_image = transform(input_image)
output_image_script = transform_script(input_image, output_dtype) output_image_script = transform_script(input_image, output_dtype)
script_diff = output_image_script.float() - output_image.float() torch.testing.assert_close(
self.assertLess( output_image_script,
script_diff.abs().max(), 1e-6, msg="{} vs {}".format(output_image_script, output_image) output_image,
rtol=0.0,
atol=1e-6,
msg="{} vs {}".format(output_image_script, output_image),
) )
actual_min, actual_max = output_image.tolist() actual_min, actual_max = output_image.tolist()
...@@ -780,8 +791,7 @@ class Tester(unittest.TestCase): ...@@ -780,8 +791,7 @@ class Tester(unittest.TestCase):
expected_output = trans(Image.open(GRACE_HOPPER).convert('RGB')) expected_output = trans(Image.open(GRACE_HOPPER).convert('RGB'))
output = trans(accimage.Image(GRACE_HOPPER)) output = trans(accimage.Image(GRACE_HOPPER))
self.assertEqual(expected_output.size(), output.size()) torch.testing.assert_close(output, expected_output)
self.assertTrue(np.allclose(output.numpy(), expected_output.numpy()))
def test_pil_to_tensor(self): def test_pil_to_tensor(self):
test_channels = [1, 3, 4] test_channels = [1, 3, 4]
...@@ -796,25 +806,25 @@ class Tester(unittest.TestCase): ...@@ -796,25 +806,25 @@ class Tester(unittest.TestCase):
input_data = torch.ByteTensor(channels, height, width).random_(0, 255) input_data = torch.ByteTensor(channels, height, width).random_(0, 255)
img = transforms.ToPILImage()(input_data) img = transforms.ToPILImage()(input_data)
output = trans(img) output = trans(img)
self.assertTrue(np.allclose(input_data.numpy(), output.numpy())) torch.testing.assert_close(input_data, output, check_stride=False)
input_data = np.random.randint(low=0, high=255, size=(height, width, channels)).astype(np.uint8) input_data = np.random.randint(low=0, high=255, size=(height, width, channels)).astype(np.uint8)
img = transforms.ToPILImage()(input_data) img = transforms.ToPILImage()(input_data)
output = trans(img) output = trans(img)
expected_output = input_data.transpose((2, 0, 1)) expected_output = input_data.transpose((2, 0, 1))
self.assertTrue(np.allclose(output.numpy(), expected_output)) torch.testing.assert_close(output.numpy(), expected_output)
input_data = torch.as_tensor(np.random.rand(channels, height, width).astype(np.float32)) input_data = torch.as_tensor(np.random.rand(channels, height, width).astype(np.float32))
img = transforms.ToPILImage()(input_data) # CHW -> HWC and (* 255).byte() img = transforms.ToPILImage()(input_data) # CHW -> HWC and (* 255).byte()
output = trans(img) # HWC -> CHW output = trans(img) # HWC -> CHW
expected_output = (input_data * 255).byte() expected_output = (input_data * 255).byte()
self.assertTrue(np.allclose(output.numpy(), expected_output.numpy())) torch.testing.assert_close(output, expected_output, check_stride=False)
# separate test for mode '1' PIL images # separate test for mode '1' PIL images
input_data = torch.ByteTensor(1, height, width).bernoulli_() input_data = torch.ByteTensor(1, height, width).bernoulli_()
img = transforms.ToPILImage()(input_data.mul(255)).convert('1') img = transforms.ToPILImage()(input_data.mul(255)).convert('1')
output = trans(img) output = trans(img).view(torch.uint8).bool().to(torch.uint8)
self.assertTrue(np.allclose(input_data.numpy(), output.numpy())) torch.testing.assert_close(input_data, output, check_stride=False)
@unittest.skipIf(accimage is None, 'accimage not available') @unittest.skipIf(accimage is None, 'accimage not available')
def test_accimage_pil_to_tensor(self): def test_accimage_pil_to_tensor(self):
...@@ -824,7 +834,7 @@ class Tester(unittest.TestCase): ...@@ -824,7 +834,7 @@ class Tester(unittest.TestCase):
output = trans(accimage.Image(GRACE_HOPPER)) output = trans(accimage.Image(GRACE_HOPPER))
self.assertEqual(expected_output.size(), output.size()) self.assertEqual(expected_output.size(), output.size())
self.assertTrue(np.allclose(output.numpy(), expected_output.numpy())) torch.testing.assert_close(output, expected_output)
@unittest.skipIf(accimage is None, 'accimage not available') @unittest.skipIf(accimage is None, 'accimage not available')
def test_accimage_resize(self): def test_accimage_resize(self):
...@@ -859,7 +869,7 @@ class Tester(unittest.TestCase): ...@@ -859,7 +869,7 @@ class Tester(unittest.TestCase):
output = trans(accimage.Image(GRACE_HOPPER)) output = trans(accimage.Image(GRACE_HOPPER))
self.assertEqual(expected_output.size(), output.size()) self.assertEqual(expected_output.size(), output.size())
self.assertTrue(np.allclose(output.numpy(), expected_output.numpy())) torch.testing.assert_close(output, expected_output)
def test_1_channel_tensor_to_pil_image(self): def test_1_channel_tensor_to_pil_image(self):
to_tensor = transforms.ToTensor() to_tensor = transforms.ToTensor()
...@@ -880,12 +890,13 @@ class Tester(unittest.TestCase): ...@@ -880,12 +890,13 @@ class Tester(unittest.TestCase):
for transform in [transforms.ToPILImage(), transforms.ToPILImage(mode=mode)]: for transform in [transforms.ToPILImage(), transforms.ToPILImage(mode=mode)]:
img = transform(img_data) img = transform(img_data)
self.assertEqual(img.mode, mode) self.assertEqual(img.mode, mode)
self.assertTrue(np.allclose(expected_output, to_tensor(img).numpy())) torch.testing.assert_close(expected_output, to_tensor(img).numpy(), check_stride=False)
# 'F' mode for torch.FloatTensor # 'F' mode for torch.FloatTensor
img_F_mode = transforms.ToPILImage(mode='F')(img_data_float) img_F_mode = transforms.ToPILImage(mode='F')(img_data_float)
self.assertEqual(img_F_mode.mode, 'F') self.assertEqual(img_F_mode.mode, 'F')
self.assertTrue(np.allclose(np.array(Image.fromarray(img_data_float.squeeze(0).numpy(), mode='F')), torch.testing.assert_close(
np.array(img_F_mode))) np.array(Image.fromarray(img_data_float.squeeze(0).numpy(), mode='F')), np.array(img_F_mode)
)
def test_1_channel_ndarray_to_pil_image(self): def test_1_channel_ndarray_to_pil_image(self):
img_data_float = torch.Tensor(4, 4, 1).uniform_().numpy() img_data_float = torch.Tensor(4, 4, 1).uniform_().numpy()
...@@ -899,7 +910,9 @@ class Tester(unittest.TestCase): ...@@ -899,7 +910,9 @@ class Tester(unittest.TestCase):
for transform in [transforms.ToPILImage(), transforms.ToPILImage(mode=mode)]: for transform in [transforms.ToPILImage(), transforms.ToPILImage(mode=mode)]:
img = transform(img_data) img = transform(img_data)
self.assertEqual(img.mode, mode) self.assertEqual(img.mode, mode)
self.assertTrue(np.allclose(img_data[:, :, 0], img)) # note: we explicitly convert img's dtype because pytorch doesn't support uint16
# and otherwise assert_close wouldn't be able to construct a tensor from the uint16 array
torch.testing.assert_close(img_data[:, :, 0], np.asarray(img).astype(img_data.dtype))
def test_2_channel_ndarray_to_pil_image(self): def test_2_channel_ndarray_to_pil_image(self):
def verify_img_data(img_data, mode): def verify_img_data(img_data, mode):
...@@ -911,7 +924,7 @@ class Tester(unittest.TestCase): ...@@ -911,7 +924,7 @@ class Tester(unittest.TestCase):
self.assertEqual(img.mode, mode) self.assertEqual(img.mode, mode)
split = img.split() split = img.split()
for i in range(2): for i in range(2):
self.assertTrue(np.allclose(img_data[:, :, i], split[i])) torch.testing.assert_close(img_data[:, :, i], np.asarray(split[i]), check_stride=False)
img_data = torch.ByteTensor(4, 4, 2).random_(0, 255).numpy() img_data = torch.ByteTensor(4, 4, 2).random_(0, 255).numpy()
for mode in [None, 'LA']: for mode in [None, 'LA']:
...@@ -984,7 +997,7 @@ class Tester(unittest.TestCase): ...@@ -984,7 +997,7 @@ class Tester(unittest.TestCase):
self.assertEqual(img.mode, mode) self.assertEqual(img.mode, mode)
split = img.split() split = img.split()
for i in range(3): for i in range(3):
self.assertTrue(np.allclose(img_data[:, :, i], split[i])) torch.testing.assert_close(img_data[:, :, i], np.asarray(split[i]), check_stride=False)
img_data = torch.ByteTensor(4, 4, 3).random_(0, 255).numpy() img_data = torch.ByteTensor(4, 4, 3).random_(0, 255).numpy()
for mode in [None, 'RGB', 'HSV', 'YCbCr']: for mode in [None, 'RGB', 'HSV', 'YCbCr']:
...@@ -1033,7 +1046,7 @@ class Tester(unittest.TestCase): ...@@ -1033,7 +1046,7 @@ class Tester(unittest.TestCase):
self.assertEqual(img.mode, mode) self.assertEqual(img.mode, mode)
split = img.split() split = img.split()
for i in range(4): for i in range(4):
self.assertTrue(np.allclose(img_data[:, :, i], split[i])) torch.testing.assert_close(img_data[:, :, i], np.asarray(split[i]), check_stride=False)
img_data = torch.ByteTensor(4, 4, 4).random_(0, 255).numpy() img_data = torch.ByteTensor(4, 4, 4).random_(0, 255).numpy()
for mode in [None, 'RGBA', 'CMYK', 'RGBX']: for mode in [None, 'RGBA', 'CMYK', 'RGBX']:
...@@ -1064,7 +1077,7 @@ class Tester(unittest.TestCase): ...@@ -1064,7 +1077,7 @@ class Tester(unittest.TestCase):
for transform in [transforms.ToPILImage(), transforms.ToPILImage(mode=mode)]: for transform in [transforms.ToPILImage(), transforms.ToPILImage(mode=mode)]:
img = transform(img_data) img = transform(img_data)
self.assertEqual(img.mode, mode) self.assertEqual(img.mode, mode)
self.assertTrue(np.allclose(expected_output, to_tensor(img).numpy())) np.testing.assert_allclose(expected_output, to_tensor(img).numpy()[0])
def test_2d_ndarray_to_pil_image(self): def test_2d_ndarray_to_pil_image(self):
img_data_float = torch.Tensor(4, 4).uniform_().numpy() img_data_float = torch.Tensor(4, 4).uniform_().numpy()
...@@ -1078,7 +1091,7 @@ class Tester(unittest.TestCase): ...@@ -1078,7 +1091,7 @@ class Tester(unittest.TestCase):
for transform in [transforms.ToPILImage(), transforms.ToPILImage(mode=mode)]: for transform in [transforms.ToPILImage(), transforms.ToPILImage(mode=mode)]:
img = transform(img_data) img = transform(img_data)
self.assertEqual(img.mode, mode) self.assertEqual(img.mode, mode)
self.assertTrue(np.allclose(img_data, img)) np.testing.assert_allclose(img_data, img)
def test_tensor_bad_types_to_pil_image(self): def test_tensor_bad_types_to_pil_image(self):
with self.assertRaisesRegex(ValueError, r'pic should be 2/3 dimensional. Got \d+ dimensions.'): with self.assertRaisesRegex(ValueError, r'pic should be 2/3 dimensional. Got \d+ dimensions.'):
...@@ -1189,7 +1202,7 @@ class Tester(unittest.TestCase): ...@@ -1189,7 +1202,7 @@ class Tester(unittest.TestCase):
# Checking the optional in-place behaviour # Checking the optional in-place behaviour
tensor = torch.rand((1, 16, 16)) tensor = torch.rand((1, 16, 16))
tensor_inplace = transforms.Normalize((0.5,), (0.5,), inplace=True)(tensor) tensor_inplace = transforms.Normalize((0.5,), (0.5,), inplace=True)(tensor)
self.assertTrue(torch.equal(tensor, tensor_inplace)) assert_equal(tensor, tensor_inplace)
def test_normalize_different_dtype(self): def test_normalize_different_dtype(self):
for dtype1 in [torch.float32, torch.float64]: for dtype1 in [torch.float32, torch.float64]:
...@@ -1207,7 +1220,7 @@ class Tester(unittest.TestCase): ...@@ -1207,7 +1220,7 @@ class Tester(unittest.TestCase):
mean = torch.rand(n_channels) mean = torch.rand(n_channels)
std = torch.rand(n_channels) std = torch.rand(n_channels)
img = torch.rand(n_channels, img_size, img_size) img = torch.rand(n_channels, img_size, img_size)
target = F.normalize(img, mean, std).numpy() target = F.normalize(img, mean, std)
mean_unsqueezed = mean.view(-1, 1, 1) mean_unsqueezed = mean.view(-1, 1, 1)
std_unsqueezed = std.view(-1, 1, 1) std_unsqueezed = std.view(-1, 1, 1)
...@@ -1215,8 +1228,8 @@ class Tester(unittest.TestCase): ...@@ -1215,8 +1228,8 @@ class Tester(unittest.TestCase):
result2 = F.normalize(img, result2 = F.normalize(img,
mean_unsqueezed.repeat(1, img_size, img_size), mean_unsqueezed.repeat(1, img_size, img_size),
std_unsqueezed.repeat(1, img_size, img_size)) std_unsqueezed.repeat(1, img_size, img_size))
assert_array_almost_equal(target, result1.numpy()) torch.testing.assert_close(target, result1)
assert_array_almost_equal(target, result2.numpy()) torch.testing.assert_close(target, result2)
def test_adjust_brightness(self): def test_adjust_brightness(self):
x_shape = [2, 2, 3] x_shape = [2, 2, 3]
...@@ -1227,21 +1240,21 @@ class Tester(unittest.TestCase): ...@@ -1227,21 +1240,21 @@ class Tester(unittest.TestCase):
# test 0 # test 0
y_pil = F.adjust_brightness(x_pil, 1) y_pil = F.adjust_brightness(x_pil, 1)
y_np = np.array(y_pil) y_np = np.array(y_pil)
self.assertTrue(np.allclose(y_np, x_np)) torch.testing.assert_close(y_np, x_np)
# test 1 # test 1
y_pil = F.adjust_brightness(x_pil, 0.5) y_pil = F.adjust_brightness(x_pil, 0.5)
y_np = np.array(y_pil) y_np = np.array(y_pil)
y_ans = [0, 2, 6, 27, 67, 113, 18, 4, 117, 45, 127, 0] y_ans = [0, 2, 6, 27, 67, 113, 18, 4, 117, 45, 127, 0]
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
self.assertTrue(np.allclose(y_np, y_ans)) torch.testing.assert_close(y_np, y_ans)
# test 2 # test 2
y_pil = F.adjust_brightness(x_pil, 2) y_pil = F.adjust_brightness(x_pil, 2)
y_np = np.array(y_pil) y_np = np.array(y_pil)
y_ans = [0, 10, 26, 108, 255, 255, 74, 16, 255, 180, 255, 2] y_ans = [0, 10, 26, 108, 255, 255, 74, 16, 255, 180, 255, 2]
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
self.assertTrue(np.allclose(y_np, y_ans)) torch.testing.assert_close(y_np, y_ans)
def test_adjust_contrast(self): def test_adjust_contrast(self):
x_shape = [2, 2, 3] x_shape = [2, 2, 3]
...@@ -1252,21 +1265,21 @@ class Tester(unittest.TestCase): ...@@ -1252,21 +1265,21 @@ class Tester(unittest.TestCase):
# test 0 # test 0
y_pil = F.adjust_contrast(x_pil, 1) y_pil = F.adjust_contrast(x_pil, 1)
y_np = np.array(y_pil) y_np = np.array(y_pil)
self.assertTrue(np.allclose(y_np, x_np)) torch.testing.assert_close(y_np, x_np)
# test 1 # test 1
y_pil = F.adjust_contrast(x_pil, 0.5) y_pil = F.adjust_contrast(x_pil, 0.5)
y_np = np.array(y_pil) y_np = np.array(y_pil)
y_ans = [43, 45, 49, 70, 110, 156, 61, 47, 160, 88, 170, 43] y_ans = [43, 45, 49, 70, 110, 156, 61, 47, 160, 88, 170, 43]
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
self.assertTrue(np.allclose(y_np, y_ans)) torch.testing.assert_close(y_np, y_ans)
# test 2 # test 2
y_pil = F.adjust_contrast(x_pil, 2) y_pil = F.adjust_contrast(x_pil, 2)
y_np = np.array(y_pil) y_np = np.array(y_pil)
y_ans = [0, 0, 0, 22, 184, 255, 0, 0, 255, 94, 255, 0] y_ans = [0, 0, 0, 22, 184, 255, 0, 0, 255, 94, 255, 0]
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
self.assertTrue(np.allclose(y_np, y_ans)) torch.testing.assert_close(y_np, y_ans)
@unittest.skipIf(Image.__version__ >= '7', "Temporarily disabled") @unittest.skipIf(Image.__version__ >= '7', "Temporarily disabled")
def test_adjust_saturation(self): def test_adjust_saturation(self):
...@@ -1278,21 +1291,21 @@ class Tester(unittest.TestCase): ...@@ -1278,21 +1291,21 @@ class Tester(unittest.TestCase):
# test 0 # test 0
y_pil = F.adjust_saturation(x_pil, 1) y_pil = F.adjust_saturation(x_pil, 1)
y_np = np.array(y_pil) y_np = np.array(y_pil)
self.assertTrue(np.allclose(y_np, x_np)) torch.testing.assert_close(y_np, x_np)
# test 1 # test 1
y_pil = F.adjust_saturation(x_pil, 0.5) y_pil = F.adjust_saturation(x_pil, 0.5)
y_np = np.array(y_pil) y_np = np.array(y_pil)
y_ans = [2, 4, 8, 87, 128, 173, 39, 25, 138, 133, 215, 88] y_ans = [2, 4, 8, 87, 128, 173, 39, 25, 138, 133, 215, 88]
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
self.assertTrue(np.allclose(y_np, y_ans)) torch.testing.assert_close(y_np, y_ans)
# test 2 # test 2
y_pil = F.adjust_saturation(x_pil, 2) y_pil = F.adjust_saturation(x_pil, 2)
y_np = np.array(y_pil) y_np = np.array(y_pil)
y_ans = [0, 6, 22, 0, 149, 255, 32, 0, 255, 4, 255, 0] y_ans = [0, 6, 22, 0, 149, 255, 32, 0, 255, 4, 255, 0]
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
self.assertTrue(np.allclose(y_np, y_ans)) torch.testing.assert_close(y_np, y_ans)
def test_adjust_hue(self): def test_adjust_hue(self):
x_shape = [2, 2, 3] x_shape = [2, 2, 3]
...@@ -1310,21 +1323,21 @@ class Tester(unittest.TestCase): ...@@ -1310,21 +1323,21 @@ class Tester(unittest.TestCase):
y_np = np.array(y_pil) y_np = np.array(y_pil)
y_ans = [0, 5, 13, 54, 139, 226, 35, 8, 234, 91, 255, 1] y_ans = [0, 5, 13, 54, 139, 226, 35, 8, 234, 91, 255, 1]
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
self.assertTrue(np.allclose(y_np, y_ans)) torch.testing.assert_close(y_np, y_ans)
# test 1 # test 1
y_pil = F.adjust_hue(x_pil, 0.25) y_pil = F.adjust_hue(x_pil, 0.25)
y_np = np.array(y_pil) y_np = np.array(y_pil)
y_ans = [13, 0, 12, 224, 54, 226, 234, 8, 99, 1, 222, 255] y_ans = [13, 0, 12, 224, 54, 226, 234, 8, 99, 1, 222, 255]
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
self.assertTrue(np.allclose(y_np, y_ans)) torch.testing.assert_close(y_np, y_ans)
# test 2 # test 2
y_pil = F.adjust_hue(x_pil, -0.25) y_pil = F.adjust_hue(x_pil, -0.25)
y_np = np.array(y_pil) y_np = np.array(y_pil)
y_ans = [0, 13, 2, 54, 226, 58, 8, 234, 152, 255, 43, 1] y_ans = [0, 13, 2, 54, 226, 58, 8, 234, 152, 255, 43, 1]
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
self.assertTrue(np.allclose(y_np, y_ans)) torch.testing.assert_close(y_np, y_ans)
def test_adjust_sharpness(self): def test_adjust_sharpness(self):
x_shape = [4, 4, 3] x_shape = [4, 4, 3]
...@@ -1337,7 +1350,7 @@ class Tester(unittest.TestCase): ...@@ -1337,7 +1350,7 @@ class Tester(unittest.TestCase):
# test 0 # test 0
y_pil = F.adjust_sharpness(x_pil, 1) y_pil = F.adjust_sharpness(x_pil, 1)
y_np = np.array(y_pil) y_np = np.array(y_pil)
self.assertTrue(np.allclose(y_np, x_np)) torch.testing.assert_close(y_np, x_np)
# test 1 # test 1
y_pil = F.adjust_sharpness(x_pil, 0.5) y_pil = F.adjust_sharpness(x_pil, 0.5)
...@@ -1346,7 +1359,7 @@ class Tester(unittest.TestCase): ...@@ -1346,7 +1359,7 @@ class Tester(unittest.TestCase):
30, 74, 103, 96, 114, 97, 110, 100, 101, 114, 32, 81, 103, 108, 102, 101, 30, 74, 103, 96, 114, 97, 110, 100, 101, 114, 32, 81, 103, 108, 102, 101,
107, 116, 105, 115, 0, 0, 73, 32, 108, 111, 118, 101, 32, 121, 111, 117] 107, 116, 105, 115, 0, 0, 73, 32, 108, 111, 118, 101, 32, 121, 111, 117]
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
self.assertTrue(np.allclose(y_np, y_ans)) torch.testing.assert_close(y_np, y_ans)
# test 2 # test 2
y_pil = F.adjust_sharpness(x_pil, 2) y_pil = F.adjust_sharpness(x_pil, 2)
...@@ -1355,7 +1368,7 @@ class Tester(unittest.TestCase): ...@@ -1355,7 +1368,7 @@ class Tester(unittest.TestCase):
0, 46, 118, 111, 132, 97, 110, 100, 101, 114, 32, 95, 135, 146, 126, 112, 0, 46, 118, 111, 132, 97, 110, 100, 101, 114, 32, 95, 135, 146, 126, 112,
119, 116, 105, 115, 0, 0, 73, 32, 108, 111, 118, 101, 32, 121, 111, 117] 119, 116, 105, 115, 0, 0, 73, 32, 108, 111, 118, 101, 32, 121, 111, 117]
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
self.assertTrue(np.allclose(y_np, y_ans)) torch.testing.assert_close(y_np, y_ans)
# test 3 # test 3
x_shape = [2, 2, 3] x_shape = [2, 2, 3]
...@@ -1366,7 +1379,7 @@ class Tester(unittest.TestCase): ...@@ -1366,7 +1379,7 @@ class Tester(unittest.TestCase):
y_pil = F.adjust_sharpness(x_pil, 2) y_pil = F.adjust_sharpness(x_pil, 2)
y_np = np.array(y_pil).transpose(2, 0, 1) y_np = np.array(y_pil).transpose(2, 0, 1)
y_th = F.adjust_sharpness(x_th, 2) y_th = F.adjust_sharpness(x_th, 2)
self.assertTrue(np.allclose(y_np, y_th.numpy())) torch.testing.assert_close(y_np, y_th.numpy())
def test_adjust_gamma(self): def test_adjust_gamma(self):
x_shape = [2, 2, 3] x_shape = [2, 2, 3]
...@@ -1377,21 +1390,21 @@ class Tester(unittest.TestCase): ...@@ -1377,21 +1390,21 @@ class Tester(unittest.TestCase):
# test 0 # test 0
y_pil = F.adjust_gamma(x_pil, 1) y_pil = F.adjust_gamma(x_pil, 1)
y_np = np.array(y_pil) y_np = np.array(y_pil)
self.assertTrue(np.allclose(y_np, x_np)) torch.testing.assert_close(y_np, x_np)
# test 1 # test 1
y_pil = F.adjust_gamma(x_pil, 0.5) y_pil = F.adjust_gamma(x_pil, 0.5)
y_np = np.array(y_pil) y_np = np.array(y_pil)
y_ans = [0, 35, 57, 117, 186, 241, 97, 45, 245, 152, 255, 16] y_ans = [0, 35, 57, 117, 186, 241, 97, 45, 245, 152, 255, 16]
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
self.assertTrue(np.allclose(y_np, y_ans)) torch.testing.assert_close(y_np, y_ans)
# test 2 # test 2
y_pil = F.adjust_gamma(x_pil, 2) y_pil = F.adjust_gamma(x_pil, 2)
y_np = np.array(y_pil) y_np = np.array(y_pil)
y_ans = [0, 0, 0, 11, 71, 201, 5, 0, 215, 31, 255, 0] y_ans = [0, 0, 0, 11, 71, 201, 5, 0, 215, 31, 255, 0]
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
self.assertTrue(np.allclose(y_np, y_ans)) torch.testing.assert_close(y_np, y_ans)
def test_adjusts_L_mode(self): def test_adjusts_L_mode(self):
x_shape = [2, 2, 3] x_shape = [2, 2, 3]
...@@ -1450,10 +1463,10 @@ class Tester(unittest.TestCase): ...@@ -1450,10 +1463,10 @@ class Tester(unittest.TestCase):
cov += np.dot(xwhite, xwhite.T) / num_features cov += np.dot(xwhite, xwhite.T) / num_features
mean += np.sum(xwhite) / num_features mean += np.sum(xwhite) / num_features
# if rtol for std = 1e-3 then rtol for cov = 2e-3 as std**2 = cov # if rtol for std = 1e-3 then rtol for cov = 2e-3 as std**2 = cov
self.assertTrue(np.allclose(cov / num_samples, np.identity(1), rtol=2e-3), torch.testing.assert_close(cov / num_samples, np.identity(1), rtol=2e-3, atol=1e-8, check_dtype=False,
"cov not close to 1") msg="cov not close to 1")
self.assertTrue(np.allclose(mean / num_samples, 0, rtol=1e-3), torch.testing.assert_close(mean / num_samples, 0, rtol=1e-3, atol=1e-8, check_dtype=False,
"mean not close to 0") msg="mean not close to 0")
# Checking if LinearTransformation can be printed as string # Checking if LinearTransformation can be printed as string
whitening.__repr__() whitening.__repr__()
...@@ -1491,7 +1504,7 @@ class Tester(unittest.TestCase): ...@@ -1491,7 +1504,7 @@ class Tester(unittest.TestCase):
result_a = F.rotate(img, 90) result_a = F.rotate(img, 90)
result_b = F.rotate(img, -270) result_b = F.rotate(img, -270)
self.assertTrue(np.all(np.array(result_a) == np.array(result_b))) assert_equal(np.array(result_a), np.array(result_b))
def test_rotate_fill(self): def test_rotate_fill(self):
img = F.to_pil_image(np.ones((100, 100, 3), dtype=np.uint8) * 255, "RGB") img = F.to_pil_image(np.ones((100, 100, 3), dtype=np.uint8) * 255, "RGB")
...@@ -1732,7 +1745,7 @@ class Tester(unittest.TestCase): ...@@ -1732,7 +1745,7 @@ class Tester(unittest.TestCase):
gray_np_1 = np.array(gray_pil_1) gray_np_1 = np.array(gray_pil_1)
self.assertEqual(gray_pil_1.mode, 'L', 'mode should be L') self.assertEqual(gray_pil_1.mode, 'L', 'mode should be L')
self.assertEqual(gray_np_1.shape, tuple(x_shape[0:2]), 'should be 1 channel') self.assertEqual(gray_np_1.shape, tuple(x_shape[0:2]), 'should be 1 channel')
np.testing.assert_equal(gray_np, gray_np_1) assert_equal(gray_np, gray_np_1)
# Case 2: RGB -> 3 channel grayscale # Case 2: RGB -> 3 channel grayscale
trans2 = transforms.Grayscale(num_output_channels=3) trans2 = transforms.Grayscale(num_output_channels=3)
...@@ -1740,9 +1753,9 @@ class Tester(unittest.TestCase): ...@@ -1740,9 +1753,9 @@ class Tester(unittest.TestCase):
gray_np_2 = np.array(gray_pil_2) gray_np_2 = np.array(gray_pil_2)
self.assertEqual(gray_pil_2.mode, 'RGB', 'mode should be RGB') self.assertEqual(gray_pil_2.mode, 'RGB', 'mode should be RGB')
self.assertEqual(gray_np_2.shape, tuple(x_shape), 'should be 3 channel') self.assertEqual(gray_np_2.shape, tuple(x_shape), 'should be 3 channel')
np.testing.assert_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1]) assert_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1])
np.testing.assert_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2]) assert_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2])
np.testing.assert_equal(gray_np, gray_np_2[:, :, 0]) assert_equal(gray_np, gray_np_2[:, :, 0], check_stride=False)
# Case 3: 1 channel grayscale -> 1 channel grayscale # Case 3: 1 channel grayscale -> 1 channel grayscale
trans3 = transforms.Grayscale(num_output_channels=1) trans3 = transforms.Grayscale(num_output_channels=1)
...@@ -1750,7 +1763,7 @@ class Tester(unittest.TestCase): ...@@ -1750,7 +1763,7 @@ class Tester(unittest.TestCase):
gray_np_3 = np.array(gray_pil_3) gray_np_3 = np.array(gray_pil_3)
self.assertEqual(gray_pil_3.mode, 'L', 'mode should be L') self.assertEqual(gray_pil_3.mode, 'L', 'mode should be L')
self.assertEqual(gray_np_3.shape, tuple(x_shape[0:2]), 'should be 1 channel') self.assertEqual(gray_np_3.shape, tuple(x_shape[0:2]), 'should be 1 channel')
np.testing.assert_equal(gray_np, gray_np_3) assert_equal(gray_np, gray_np_3)
# Case 4: 1 channel grayscale -> 3 channel grayscale # Case 4: 1 channel grayscale -> 3 channel grayscale
trans4 = transforms.Grayscale(num_output_channels=3) trans4 = transforms.Grayscale(num_output_channels=3)
...@@ -1758,9 +1771,9 @@ class Tester(unittest.TestCase): ...@@ -1758,9 +1771,9 @@ class Tester(unittest.TestCase):
gray_np_4 = np.array(gray_pil_4) gray_np_4 = np.array(gray_pil_4)
self.assertEqual(gray_pil_4.mode, 'RGB', 'mode should be RGB') self.assertEqual(gray_pil_4.mode, 'RGB', 'mode should be RGB')
self.assertEqual(gray_np_4.shape, tuple(x_shape), 'should be 3 channel') self.assertEqual(gray_np_4.shape, tuple(x_shape), 'should be 3 channel')
np.testing.assert_equal(gray_np_4[:, :, 0], gray_np_4[:, :, 1]) assert_equal(gray_np_4[:, :, 0], gray_np_4[:, :, 1])
np.testing.assert_equal(gray_np_4[:, :, 1], gray_np_4[:, :, 2]) assert_equal(gray_np_4[:, :, 1], gray_np_4[:, :, 2])
np.testing.assert_equal(gray_np, gray_np_4[:, :, 0]) assert_equal(gray_np, gray_np_4[:, :, 0], check_stride=False)
# Checking if Grayscale can be printed as string # Checking if Grayscale can be printed as string
trans4.__repr__() trans4.__repr__()
...@@ -1827,9 +1840,9 @@ class Tester(unittest.TestCase): ...@@ -1827,9 +1840,9 @@ class Tester(unittest.TestCase):
gray_np_2 = np.array(gray_pil_2) gray_np_2 = np.array(gray_pil_2)
self.assertEqual(gray_pil_2.mode, 'RGB', 'mode should be RGB') self.assertEqual(gray_pil_2.mode, 'RGB', 'mode should be RGB')
self.assertEqual(gray_np_2.shape, tuple(x_shape), 'should be 3 channel') self.assertEqual(gray_np_2.shape, tuple(x_shape), 'should be 3 channel')
np.testing.assert_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1]) assert_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1])
np.testing.assert_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2]) assert_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2])
np.testing.assert_equal(gray_np, gray_np_2[:, :, 0]) assert_equal(gray_np, gray_np_2[:, :, 0], check_stride=False)
# Case 3b: RGB -> 3 channel grayscale (unchanged) # Case 3b: RGB -> 3 channel grayscale (unchanged)
trans2 = transforms.RandomGrayscale(p=0.0) trans2 = transforms.RandomGrayscale(p=0.0)
...@@ -1837,7 +1850,7 @@ class Tester(unittest.TestCase): ...@@ -1837,7 +1850,7 @@ class Tester(unittest.TestCase):
gray_np_2 = np.array(gray_pil_2) gray_np_2 = np.array(gray_pil_2)
self.assertEqual(gray_pil_2.mode, 'RGB', 'mode should be RGB') self.assertEqual(gray_pil_2.mode, 'RGB', 'mode should be RGB')
self.assertEqual(gray_np_2.shape, tuple(x_shape), 'should be 3 channel') self.assertEqual(gray_np_2.shape, tuple(x_shape), 'should be 3 channel')
np.testing.assert_equal(x_np, gray_np_2) assert_equal(x_np, gray_np_2)
# Case 3c: 1 channel grayscale -> 1 channel grayscale (grayscaled) # Case 3c: 1 channel grayscale -> 1 channel grayscale (grayscaled)
trans3 = transforms.RandomGrayscale(p=1.0) trans3 = transforms.RandomGrayscale(p=1.0)
...@@ -1845,7 +1858,7 @@ class Tester(unittest.TestCase): ...@@ -1845,7 +1858,7 @@ class Tester(unittest.TestCase):
gray_np_3 = np.array(gray_pil_3) gray_np_3 = np.array(gray_pil_3)
self.assertEqual(gray_pil_3.mode, 'L', 'mode should be L') self.assertEqual(gray_pil_3.mode, 'L', 'mode should be L')
self.assertEqual(gray_np_3.shape, tuple(x_shape[0:2]), 'should be 1 channel') self.assertEqual(gray_np_3.shape, tuple(x_shape[0:2]), 'should be 1 channel')
np.testing.assert_equal(gray_np, gray_np_3) assert_equal(gray_np, gray_np_3)
# Case 3d: 1 channel grayscale -> 1 channel grayscale (unchanged) # Case 3d: 1 channel grayscale -> 1 channel grayscale (unchanged)
trans3 = transforms.RandomGrayscale(p=0.0) trans3 = transforms.RandomGrayscale(p=0.0)
...@@ -1853,7 +1866,7 @@ class Tester(unittest.TestCase): ...@@ -1853,7 +1866,7 @@ class Tester(unittest.TestCase):
gray_np_3 = np.array(gray_pil_3) gray_np_3 = np.array(gray_pil_3)
self.assertEqual(gray_pil_3.mode, 'L', 'mode should be L') self.assertEqual(gray_pil_3.mode, 'L', 'mode should be L')
self.assertEqual(gray_np_3.shape, tuple(x_shape[0:2]), 'should be 1 channel') self.assertEqual(gray_np_3.shape, tuple(x_shape[0:2]), 'should be 1 channel')
np.testing.assert_equal(gray_np, gray_np_3) assert_equal(gray_np, gray_np_3)
# Checking if RandomGrayscale can be printed as string # Checking if RandomGrayscale can be printed as string
trans3.__repr__() trans3.__repr__()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment