Commit 0c75d99d authored by Surgan Jandial's avatar Surgan Jandial Committed by Francisco Massa
Browse files

Allowing 'F' mode for 1 channel FloatTensor (#1100)

* to_pil_image updates

* lint

* Update test_transforms.py

* Update test_transforms.py
parent 487c9bf4
...@@ -528,6 +528,11 @@ class Tester(unittest.TestCase): ...@@ -528,6 +528,11 @@ class Tester(unittest.TestCase):
img = transform(img_data) img = transform(img_data)
assert img.mode == mode assert img.mode == mode
assert np.allclose(expected_output, to_tensor(img).numpy()) assert np.allclose(expected_output, to_tensor(img).numpy())
# 'F' mode for torch.FloatTensor
img_F_mode = transforms.ToPILImage(mode='F')(img_data_float)
assert img_F_mode.mode == 'F'
assert np.allclose(np.array(Image.fromarray(img_data_float.squeeze(0).numpy(), mode='F')),
np.array(img_F_mode))
def test_1_channel_ndarray_to_pil_image(self): def test_1_channel_ndarray_to_pil_image(self):
img_data_float = torch.Tensor(4, 4, 1).uniform_().numpy() img_data_float = torch.Tensor(4, 4, 1).uniform_().numpy()
......
...@@ -135,7 +135,7 @@ def to_pil_image(pic, mode=None): ...@@ -135,7 +135,7 @@ def to_pil_image(pic, mode=None):
pic = np.expand_dims(pic, 2) pic = np.expand_dims(pic, 2)
npimg = pic npimg = pic
if isinstance(pic, torch.FloatTensor): if isinstance(pic, torch.FloatTensor) and mode != 'F':
pic = pic.mul(255).byte() pic = pic.mul(255).byte()
if isinstance(pic, torch.Tensor): if isinstance(pic, torch.Tensor):
npimg = np.transpose(pic.numpy(), (1, 2, 0)) npimg = np.transpose(pic.numpy(), (1, 2, 0))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment