Unverified Commit a51c49e4 authored by Alexey Demyanchuk's avatar Alexey Demyanchuk Committed by GitHub
Browse files

Add explicit check for number of channels (#3013)



* Add explicit check for number of channels

Example why you need to check it:
`M = torch.randint(low=0, high=2, size=(6, 64, 64), dtype = torch.float)`
When you put this input through to_pil_image without mode argument, it converts to uint8 here:
```
if pic.is_floating_point() and mode != 'F':
            pic = pic.mul(255).byte()
```
and change the mode to RGB here:
```
if mode is None and npimg.dtype == np.uint8:
            mode = 'RGB'
```
Image.fromarray doesn't raise if provided with mode RGB and just cut number of channels from what you have to 3

* Check number of channels before processing

* Add test for invalid number of channels

* Add explicit check for number of channels

Example why you need to check it:
`M = torch.randint(low=0, high=2, size=(6, 64, 64), dtype = torch.float)`
When you put this input through to_pil_image without mode argument, it converts to uint8 here:
```
if pic.is_floating_point() and mode != 'F':
            pic = pic.mul(255).byte()
```
and change the mode to RGB here:
```
if mode is None and npimg.dtype == np.uint8:
            mode = 'RGB'
```
Image.fromarray doesn't raise if provided with mode RGB and just cut number of channels from what you have to 3

* Check number of channels before processing

* Add test for invalid number of channels

* Put check after channel dim unsqueeze

* Add test if error message is matching

* Delete redundant code

* Bug fix in checking for bad types
Co-authored-by: default avatarDemyanchuk <demyanca@mh-hannover.local>
Co-authored-by: default avatarvfdev <vfdev.5@gmail.com>
parent cd0268cd
......@@ -987,19 +987,27 @@ class Tester(unittest.TestCase):
self.assertTrue(np.allclose(img_data, img))
def test_tensor_bad_types_to_pil_image(self):
with self.assertRaises(ValueError):
with self.assertRaisesRegex(ValueError, r'pic should be 2/3 dimensional. Got \d+ dimensions.'):
transforms.ToPILImage()(torch.ones(1, 3, 4, 4))
with self.assertRaisesRegex(ValueError, r'pic should not have > 4 channels. Got \d+ channels.'):
transforms.ToPILImage()(torch.ones(6, 4, 4))
def test_ndarray_bad_types_to_pil_image(self):
trans = transforms.ToPILImage()
with self.assertRaises(TypeError):
reg_msg = r'Input type \w+ is not supported'
with self.assertRaisesRegex(TypeError, reg_msg):
trans(np.ones([4, 4, 1], np.int64))
with self.assertRaisesRegex(TypeError, reg_msg):
trans(np.ones([4, 4, 1], np.uint16))
with self.assertRaisesRegex(TypeError, reg_msg):
trans(np.ones([4, 4, 1], np.uint32))
with self.assertRaisesRegex(TypeError, reg_msg):
trans(np.ones([4, 4, 1], np.float64))
with self.assertRaises(ValueError):
with self.assertRaisesRegex(ValueError, r'pic should be 2/3 dimensional. Got \d+ dimensions.'):
transforms.ToPILImage()(np.ones([1, 4, 4, 3]))
with self.assertRaisesRegex(ValueError, r'pic should not have > 4 channels. Got \d+ channels.'):
transforms.ToPILImage()(np.ones([4, 4, 6]))
@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_vertical_flip(self):
......
......@@ -183,6 +183,10 @@ def to_pil_image(pic, mode=None):
# if 2D image, add channel dimension (CHW)
pic = pic.unsqueeze(0)
# check number of channels
if pic.shape[-3] > 4:
raise ValueError('pic should not have > 4 channels. Got {} channels.'.format(pic.shape[-3]))
elif isinstance(pic, np.ndarray):
if pic.ndim not in {2, 3}:
raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim))
......@@ -191,6 +195,10 @@ def to_pil_image(pic, mode=None):
# if 2D image, add channel dimension (HWC)
pic = np.expand_dims(pic, 2)
# check number of channels
if pic.shape[-1] > 4:
raise ValueError('pic should not have > 4 channels. Got {} channels.'.format(pic.shape[-1]))
npimg = pic
if isinstance(pic, torch.Tensor):
if pic.is_floating_point() and mode != 'F':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment