Commit be6f6c29 authored by surgan12's avatar surgan12 Committed by Francisco Massa
Browse files

modes added (#688)

* modes added

* tests_added

* Update test_transforms.py

* Update test_transforms.py

* Update test_transforms.py
parent 6ee98fc6
...@@ -491,6 +491,53 @@ class Tester(unittest.TestCase): ...@@ -491,6 +491,53 @@ class Tester(unittest.TestCase):
assert img.mode == mode assert img.mode == mode
assert np.allclose(img_data[:, :, 0], img) assert np.allclose(img_data[:, :, 0], img)
def test_2_channel_ndarray_to_pil_image(self):
def verify_img_data(img_data, mode):
if mode is None:
img = transforms.ToPILImage()(img_data)
assert img.mode == 'LA' # default should assume LA
else:
img = transforms.ToPILImage(mode=mode)(img_data)
assert img.mode == mode
split = img.split()
for i in range(2):
assert np.allclose(img_data[:, :, i], split[i])
img_data = torch.ByteTensor(4, 4, 2).random_(0, 255).numpy()
for mode in [None, 'LA']:
verify_img_data(img_data, mode)
transforms.ToPILImage().__repr__()
with self.assertRaises(ValueError):
# should raise if we try a mode for 4 or 1 or 3 channel images
transforms.ToPILImage(mode='RGBA')(img_data)
transforms.ToPILImage(mode='P')(img_data)
transforms.ToPILImage(mode='RGB')(img_data)
def test_2_channel_tensor_to_pil_image(self):
def verify_img_data(img_data, expected_output, mode):
if mode is None:
img = transforms.ToPILImage()(img_data)
assert img.mode == 'LA' # default should assume LA
else:
img = transforms.ToPILImage(mode=mode)(img_data)
assert img.mode == mode
split = img.split()
for i in range(2):
assert np.allclose(expected_output[i].numpy(), F.to_tensor(split[i]).numpy())
img_data = torch.Tensor(2, 4, 4).uniform_()
expected_output = img_data.mul(255).int().float().div(255)
for mode in [None, 'LA']:
verify_img_data(img_data, expected_output, mode=mode)
with self.assertRaises(ValueError):
# should raise if we try a mode for 4 or 1 or 3 channel images
transforms.ToPILImage(mode='RGBA')(img_data)
transforms.ToPILImage(mode='P')(img_data)
transforms.ToPILImage(mode='RGB')(img_data)
def test_3_channel_tensor_to_pil_image(self): def test_3_channel_tensor_to_pil_image(self):
def verify_img_data(img_data, expected_output, mode): def verify_img_data(img_data, expected_output, mode):
if mode is None: if mode is None:
...@@ -509,9 +556,10 @@ class Tester(unittest.TestCase): ...@@ -509,9 +556,10 @@ class Tester(unittest.TestCase):
verify_img_data(img_data, expected_output, mode=mode) verify_img_data(img_data, expected_output, mode=mode)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
# should raise if we try a mode for 4 or 1 channel images # should raise if we try a mode for 4 or 1 or 2 channel images
transforms.ToPILImage(mode='RGBA')(img_data) transforms.ToPILImage(mode='RGBA')(img_data)
transforms.ToPILImage(mode='P')(img_data) transforms.ToPILImage(mode='P')(img_data)
transforms.ToPILImage(mode='LA')(img_data)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
transforms.ToPILImage()(torch.Tensor(1, 3, 4, 4).uniform_()) transforms.ToPILImage()(torch.Tensor(1, 3, 4, 4).uniform_())
...@@ -536,9 +584,10 @@ class Tester(unittest.TestCase): ...@@ -536,9 +584,10 @@ class Tester(unittest.TestCase):
transforms.ToPILImage().__repr__() transforms.ToPILImage().__repr__()
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
# should raise if we try a mode for 4 or 1 channel images # should raise if we try a mode for 4 or 1 or 2 channel images
transforms.ToPILImage(mode='RGBA')(img_data) transforms.ToPILImage(mode='RGBA')(img_data)
transforms.ToPILImage(mode='P')(img_data) transforms.ToPILImage(mode='P')(img_data)
transforms.ToPILImage(mode='LA')(img_data)
def test_4_channel_tensor_to_pil_image(self): def test_4_channel_tensor_to_pil_image(self):
def verify_img_data(img_data, expected_output, mode): def verify_img_data(img_data, expected_output, mode):
...@@ -555,13 +604,14 @@ class Tester(unittest.TestCase): ...@@ -555,13 +604,14 @@ class Tester(unittest.TestCase):
img_data = torch.Tensor(4, 4, 4).uniform_() img_data = torch.Tensor(4, 4, 4).uniform_()
expected_output = img_data.mul(255).int().float().div(255) expected_output = img_data.mul(255).int().float().div(255)
for mode in [None, 'RGBA', 'CMYK']: for mode in [None, 'RGBA', 'CMYK', 'RGBX']:
verify_img_data(img_data, expected_output, mode) verify_img_data(img_data, expected_output, mode)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
# should raise if we try a mode for 3 or 1 channel images # should raise if we try a mode for 3 or 1 or 2 channel images
transforms.ToPILImage(mode='RGB')(img_data) transforms.ToPILImage(mode='RGB')(img_data)
transforms.ToPILImage(mode='P')(img_data) transforms.ToPILImage(mode='P')(img_data)
transforms.ToPILImage(mode='LA')(img_data)
def test_4_channel_ndarray_to_pil_image(self): def test_4_channel_ndarray_to_pil_image(self):
def verify_img_data(img_data, mode): def verify_img_data(img_data, mode):
...@@ -576,13 +626,14 @@ class Tester(unittest.TestCase): ...@@ -576,13 +626,14 @@ class Tester(unittest.TestCase):
assert np.allclose(img_data[:, :, i], split[i]) assert np.allclose(img_data[:, :, i], split[i])
img_data = torch.ByteTensor(4, 4, 4).random_(0, 255).numpy() img_data = torch.ByteTensor(4, 4, 4).random_(0, 255).numpy()
for mode in [None, 'RGBA', 'CMYK']: for mode in [None, 'RGBA', 'CMYK', 'RGBX']:
verify_img_data(img_data, mode) verify_img_data(img_data, mode)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
# should raise if we try a mode for 3 or 1 channel images # should raise if we try a mode for 3 or 1 or 2 channel images
transforms.ToPILImage(mode='RGB')(img_data) transforms.ToPILImage(mode='RGB')(img_data)
transforms.ToPILImage(mode='P')(img_data) transforms.ToPILImage(mode='P')(img_data)
transforms.ToPILImage(mode='LA')(img_data)
def test_2d_tensor_to_pil_image(self): def test_2d_tensor_to_pil_image(self):
to_tensor = transforms.ToTensor() to_tensor = transforms.ToTensor()
......
...@@ -153,8 +153,16 @@ def to_pil_image(pic, mode=None): ...@@ -153,8 +153,16 @@ def to_pil_image(pic, mode=None):
.format(mode, np.dtype, expected_mode)) .format(mode, np.dtype, expected_mode))
mode = expected_mode mode = expected_mode
elif npimg.shape[2] == 2:
permitted_2_channel_modes = ['LA']
if mode is not None and mode not in permitted_2_channel_modes:
raise ValueError("Only modes {} are supported for 2D inputs".format(permitted_2_channel_modes))
if mode is None and npimg.dtype == np.uint8:
mode = 'LA'
elif npimg.shape[2] == 4: elif npimg.shape[2] == 4:
permitted_4_channel_modes = ['RGBA', 'CMYK'] permitted_4_channel_modes = ['RGBA', 'CMYK', 'RGBX']
if mode is not None and mode not in permitted_4_channel_modes: if mode is not None and mode not in permitted_4_channel_modes:
raise ValueError("Only modes {} are supported for 4D inputs".format(permitted_4_channel_modes)) raise ValueError("Only modes {} are supported for 4D inputs".format(permitted_4_channel_modes))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment