    Add explicit check for number of channels (#3013) · a51c49e4
    Alexey Demyanchuk authored
    
    
    * Add explicit check for number of channels
    
    Example of why the check is needed:
    `M = torch.randint(low=0, high=2, size=(6, 64, 64), dtype=torch.float)`
    When this input is passed to `to_pil_image` without a `mode` argument, it is converted to uint8 here:
    ```
    if pic.is_floating_point() and mode != 'F':
        pic = pic.mul(255).byte()
    ```
    and the mode is then set to RGB here:
    ```
    if mode is None and npimg.dtype == np.uint8:
        mode = 'RGB'
    ```
    `Image.fromarray` does not raise when given mode `'RGB'`; it silently reduces the input to 3 channels instead.
    
    * Check number of channels before processing
    
    * Add test for invalid number of channels
    
    * Put check after channel dim unsqueeze
    
    * Add test that the error message matches
    
    * Delete redundant code
    
    * Bug fix in checking for bad types
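
A minimal sketch of the check described above, written over plain shape tuples so it runs without PyTorch or NumPy. The helper `check_num_channels`, its `channels_first` flag, and the exact error wording are hypothetical and are not torchvision's API. The limit of 4 channels assumes that the largest PIL modes used here (such as `RGBA` or `CMYK`) have four bands.

```python
from typing import Sequence, Tuple

# Assumption: PIL modes used for multi-channel images (e.g. RGBA, CMYK) have at most 4 bands
MAX_CHANNELS = 4


def check_num_channels(shape: Sequence[int], channels_first: bool) -> Tuple[int, ...]:
    """Hypothetical helper: validate an image shape before conversion to PIL.

    channels_first=True means a CHW (tensor-style) shape; False means an
    HWC (ndarray-style) shape. Returns the shape, with a channel axis
    added for 2-D input.
    """
    shape = tuple(shape)
    if len(shape) not in (2, 3):
        raise ValueError(f"pic should be 2/3 dimensional. Got {len(shape)} dimensions.")
    if len(shape) == 2:
        # Add the channel dimension first so the check below sees a real channel axis
        shape = (1,) + shape if channels_first else shape + (1,)
    num_channels = shape[0] if channels_first else shape[-1]
    if num_channels > MAX_CHANNELS:
        raise ValueError(
            f"pic should not have > {MAX_CHANNELS} channels. Got {num_channels} channels."
        )
    return shape


# The 6-channel example from the commit message, as a CHW shape
try:
    check_num_channels((6, 64, 64), channels_first=True)
except ValueError as err:
    print(err)  # pic should not have > 4 channels. Got 6 channels.

print(check_num_channels((64, 64), channels_first=True))  # (1, 64, 64)
```

In this sketch the 2-D input gets its channel dimension before the check runs. Checking first would read the height of an `(H, W)` shape as the channel count, which is why the check belongs after the unsqueeze.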

    Co-authored-by: Demyanchuk <demyanca@mh-hannover.local>
    Co-authored-by: vfdev <vfdev.5@gmail.com>