Commit dd53c974 authored by Aiden Nibali's avatar Aiden Nibali Committed by Francisco Massa
Browse files

Add RGBA support to ToPILImage (#189)

parent 7a54c6be
...@@ -264,6 +264,22 @@ class Tester(unittest.TestCase): ...@@ -264,6 +264,22 @@ class Tester(unittest.TestCase):
assert np.allclose(img_data_short.numpy(), to_tensor(img_short).numpy()) assert np.allclose(img_data_short.numpy(), to_tensor(img_short).numpy())
assert np.allclose(img_data_int.numpy(), to_tensor(img_int).numpy()) assert np.allclose(img_data_int.numpy(), to_tensor(img_int).numpy())
def test_tensor_rgba_to_pil_image(self):
trans = transforms.ToPILImage()
to_tensor = transforms.ToTensor()
img_data = torch.Tensor(4, 4, 4).uniform_()
img = trans(img_data)
assert img.mode == 'RGBA'
assert img.getbands() == ('R', 'G', 'B', 'A')
r, g, b, a = img.split()
expected_output = img_data.mul(255).int().float().div(255)
assert np.allclose(expected_output[0].numpy(), to_tensor(r).numpy())
assert np.allclose(expected_output[1].numpy(), to_tensor(g).numpy())
assert np.allclose(expected_output[2].numpy(), to_tensor(b).numpy())
assert np.allclose(expected_output[3].numpy(), to_tensor(a).numpy())
def test_ndarray_to_pil_image(self): def test_ndarray_to_pil_image(self):
trans = transforms.ToPILImage() trans = transforms.ToPILImage()
img_data = torch.ByteTensor(4, 4, 3).random_(0, 255).numpy() img_data = torch.ByteTensor(4, 4, 3).random_(0, 255).numpy()
......
...@@ -119,6 +119,9 @@ class ToPILImage(object): ...@@ -119,6 +119,9 @@ class ToPILImage(object):
mode = 'I' mode = 'I'
elif npimg.dtype == np.float32: elif npimg.dtype == np.float32:
mode = 'F' mode = 'F'
elif npimg.shape[2] == 4:
if npimg.dtype == np.uint8:
mode = 'RGBA'
else: else:
if npimg.dtype == np.uint8: if npimg.dtype == np.uint8:
mode = 'RGB' mode = 'RGB'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment