Add explicit check for number of channels (#3013)
* Add explicit check for number of channels
Example why you need to check it:
`M = torch.randint(low=0, high=2, size=(6, 64, 64), dtype = torch.float)`
When you put this input through to_pil_image without mode argument, it converts to uint8 here:
```
if pic.is_floating_point() and mode != 'F':
pic = pic.mul(255).byte()
```
and change the mode to RGB here:
```
if mode is None and npimg.dtype == np.uint8:
mode = 'RGB'
```
Image.fromarray doesn't raise if provided with mode RGB and just cut number of channels from what you have to 3
* Check number of channels before processing
* Add test for invalid number of channels
* Add explicit check for number of channels
Example why you need to check it:
`M = torch.randint(low=0, high=2, size=(6, 64, 64), dtype = torch.float)`
When you put this input through to_pil_image without mode argument, it converts to uint8 here:
```
if pic.is_floating_point() and mode != 'F':
pic = pic.mul(255).byte()
```
and change the mode to RGB here:
```
if mode is None and npimg.dtype == np.uint8:
mode = 'RGB'
```
Image.fromarray doesn't raise if provided with mode RGB and just cut number of channels from what you have to 3
* Check number of channels before processing
* Add test for invalid number of channels
* Put check after channel dim unsqueeze
* Add test if error message is matching
* Delete redundant code
* Bug fix in checking for bad types
Co-authored-by:
Demyanchuk <demyanca@mh-hannover.local>
Co-authored-by:
vfdev <vfdev.5@gmail.com>
Showing
Please register or sign in to comment