Commit 683852d2 authored by Alykhan Tejani's avatar Alykhan Tejani Committed by Soumith Chintala
Browse files

add single image support to ToPILImage + tests (#54)

parent 14d4e4a7
......@@ -2,9 +2,11 @@ import torch
import torchvision.transforms as transforms
import unittest
import random
import numpy as np
class Tester(unittest.TestCase):
def test_crop(self):
height = random.randint(10, 32) * 2
width = random.randint(10, 32) * 2
......@@ -113,6 +115,45 @@ class Tester(unittest.TestCase):
y = trans(x)
assert (y.equal(x))
def test_tensor_to_pil_image(self):
trans = transforms.ToPILImage()
to_tensor = transforms.ToTensor()
img_data = torch.Tensor(3, 4, 4).uniform_()
img = trans(img_data)
assert img.getbands() == ('R', 'G', 'B')
r, g, b = img.split()
expected_output = img_data.mul(255).int().float().div(255)
assert np.allclose(expected_output[0].numpy(), to_tensor(r).numpy())
assert np.allclose(expected_output[1].numpy(), to_tensor(g).numpy())
assert np.allclose(expected_output[2].numpy(), to_tensor(b).numpy())
# single channel image
img_data = torch.Tensor(1, 4, 4).uniform_()
img = trans(img_data)
assert img.getbands() == ('L',)
l, = img.split()
expected_output = img_data.mul(255).int().float().div(255)
assert np.allclose(expected_output[0].numpy(), to_tensor(l).numpy())
def test_ndarray_to_pil_image(self):
trans = transforms.ToPILImage()
img_data = torch.ByteTensor(4, 4, 3).random_(0, 255).numpy()
img = trans(img_data)
assert img.getbands() == ('R', 'G', 'B')
r, g, b = img.split()
assert np.allclose(r, img_data[:, :, 0])
assert np.allclose(g, img_data[:, :, 1])
assert np.allclose(b, img_data[:, :, 2])
# single channel image
img_data = torch.ByteTensor(4, 4, 1).random_(0, 255).numpy()
img = trans(img_data)
assert img.getbands() == ('L',)
l, = img.split()
assert np.allclose(l, img_data[:, :, 0])
if __name__ == '__main__':
unittest.main()
......@@ -52,14 +52,18 @@ class ToPILImage(object):
to a PIL.Image of range [0, 255]
"""
def __call__(self, pic):
if isinstance(pic, np.ndarray):
# handle numpy array
img = Image.fromarray(pic)
else:
npimg = pic
mode = None
if not isinstance(npimg, np.ndarray):
npimg = pic.mul(255).byte().numpy()
npimg = np.transpose(npimg, (1,2,0))
img = Image.fromarray(npimg)
return img
if npimg.shape[2] == 1:
npimg = npimg[:, :, 0]
mode = "L"
return Image.fromarray(npimg, mode=mode)
class Normalize(object):
"""Given mean: (R, G, B) and std: (R, G, B),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment