Commit 2d493d6a authored by Varun Agrawal's avatar Varun Agrawal Committed by Francisco Massa
Browse files

Fix for #409 (#673)

* added separate checks for dimensionality in to_pil_image and added tests

* updated to_pil_image to accept both 2D ndarrays and 2D tensors, and refactored the tests
parent e8e04f06
......@@ -513,6 +513,9 @@ class Tester(unittest.TestCase):
transforms.ToPILImage(mode='RGBA')(img_data)
transforms.ToPILImage(mode='P')(img_data)
with self.assertRaises(ValueError):
transforms.ToPILImage()(torch.Tensor(1, 3, 4, 4).uniform_())
def test_3_channel_ndarray_to_pil_image(self):
def verify_img_data(img_data, mode):
if mode is None:
......@@ -581,6 +584,45 @@ class Tester(unittest.TestCase):
transforms.ToPILImage(mode='RGB')(img_data)
transforms.ToPILImage(mode='P')(img_data)
def test_2d_tensor_to_pil_image(self):
    """2D tensors of several dtypes convert via ToPILImage and round-trip.

    For each dtype the resulting PIL mode is checked, and the image is
    converted back with ToTensor and compared against the expected values.
    """
    to_tensor = transforms.ToTensor()

    float_data = torch.Tensor(4, 4).uniform_()
    byte_data = torch.ByteTensor(4, 4).random_(0, 255)
    short_data = torch.ShortTensor(4, 4).random_()
    int_data = torch.IntTensor(4, 4).random_()

    # (input, expected round-trip values, expected PIL mode)
    cases = [
        # float goes through a 255-scale/byte quantization, hence the
        # mul/int/div dance to reproduce the quantized values
        (float_data, float_data.mul(255).int().float().div(255).numpy(), 'L'),
        (byte_data, byte_data.float().div(255.0).numpy(), 'L'),
        (short_data, short_data.numpy(), 'I;16'),
        (int_data, int_data.numpy(), 'I'),
    ]
    for data, expected, mode in cases:
        # both auto-detected mode and explicitly requested mode must work
        for transform in (transforms.ToPILImage(), transforms.ToPILImage(mode=mode)):
            pil_img = transform(data)
            assert pil_img.mode == mode
            assert np.allclose(expected, to_tensor(pil_img).numpy())
def test_2d_ndarray_to_pil_image(self):
    """2D ndarrays of several dtypes convert to PIL images losslessly.

    Checks the PIL mode chosen for each dtype and that the pixel values
    survive the conversion unchanged.
    """
    arrays = [
        torch.Tensor(4, 4).uniform_().numpy(),
        torch.ByteTensor(4, 4).random_(0, 255).numpy(),
        torch.ShortTensor(4, 4).random_().numpy(),
        torch.IntTensor(4, 4).random_().numpy(),
    ]
    # expected PIL modes, dtype-for-dtype with `arrays` above
    modes = ['F', 'L', 'I;16', 'I']
    for arr, mode in zip(arrays, modes):
        # both auto-detected mode and explicitly requested mode must work
        for transform in (transforms.ToPILImage(), transforms.ToPILImage(mode=mode)):
            pil_img = transform(arr)
            assert pil_img.mode == mode
            assert np.allclose(arr, pil_img)
def test_tensor_bad_types_to_pil_image(self):
    """A 4D tensor is not an image: ToPILImage must raise ValueError."""
    trans = transforms.ToPILImage()
    with self.assertRaises(ValueError):
        trans(torch.ones(1, 3, 4, 4))
def test_ndarray_bad_types_to_pil_image(self):
trans = transforms.ToPILImage()
with self.assertRaises(TypeError):
......@@ -589,6 +631,9 @@ class Tester(unittest.TestCase):
trans(np.ones([4, 4, 1], np.uint32))
trans(np.ones([4, 4, 1], np.float64))
with self.assertRaises(ValueError):
transforms.ToPILImage()(np.ones([1, 4, 4, 3]))
@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_vertical_flip(self):
random_state = random.getstate()
......
......@@ -105,13 +105,29 @@ def to_pil_image(pic, mode=None):
Returns:
PIL Image: Image converted to PIL Image.
"""
# NOTE(review): this span is scraped from a unified diff, so a removed line and
# its replacement appear back to back: the `_is_numpy_image(...)` check and the
# `torch.is_tensor(...)` check below are the pre-change versions of the line
# immediately following each of them; only one of each pair exists in the file.
# Type gate: only torch tensors and numpy ndarrays are convertible.
if not(_is_numpy_image(pic) or _is_tensor_image(pic)):
if not(isinstance(pic, torch.Tensor) or isinstance(pic, np.ndarray)):
raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic)))
# Tensor inputs must be 2D or 3D; a 2D tensor is promoted to 3D (CHW).
elif isinstance(pic, torch.Tensor):
if pic.ndimension() not in {2, 3}:
raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndimension()))
elif pic.ndimension() == 2:
# if 2D image, add channel dimension (CHW)
# NOTE(review): unsqueeze_ (trailing underscore) is the in-place variant,
# so this mutates the caller's tensor as a visible side effect;
# `pic = pic.unsqueeze(0)` would avoid that — confirm intent.
pic.unsqueeze_(0)
# ndarray inputs must likewise be 2D or 3D; a 2D array is promoted to 3D (HWC).
elif isinstance(pic, np.ndarray):
if pic.ndim not in {2, 3}:
raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim))
elif pic.ndim == 2:
# if 2D image, add channel dimension (HWC)
pic = np.expand_dims(pic, 2)
npimg = pic
# Float tensors are scaled to byte range before conversion
# (presumably assuming values in [0, 1] — TODO confirm against callers).
if isinstance(pic, torch.FloatTensor):
pic = pic.mul(255).byte()
if torch.is_tensor(pic):
if isinstance(pic, torch.Tensor):
# CHW tensor -> HWC ndarray, the layout PIL expects.
npimg = np.transpose(pic.numpy(), (1, 2, 0))
if not isinstance(npimg, np.ndarray):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment