"...git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "9f10306b3fd8168a100e749716e99b75b769e3ef"
Commit 683852d2 authored by Alykhan Tejani, committed by Soumith Chintala

add single image support to ToPILImage + tests (#54)

parent 14d4e4a7
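For orientation, a small usage sketch of what the change below enables: ToPILImage previously only handled the 3-channel tensor/ndarray case reliably, and after this commit a single-channel torch tensor (1 x H x W) or ndarray (H x W x 1) is converted to a grayscale ("L" mode) PIL image. The snippet mirrors the new tests and assumes the torchvision API as of this commit:

import torch
import torchvision.transforms as transforms

trans = transforms.ToPILImage()

# 3-channel float tensor (C x H x W) -> RGB image, unchanged behaviour
rgb = trans(torch.Tensor(3, 4, 4).uniform_())
assert rgb.getbands() == ('R', 'G', 'B')

# new in this commit: 1-channel tensor -> grayscale ("L" mode) image
gray = trans(torch.Tensor(1, 4, 4).uniform_())
assert gray.getbands() == ('L',)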
@@ -2,9 +2,11 @@ import torch
 import torchvision.transforms as transforms
 import unittest
 import random
+import numpy as np


 class Tester(unittest.TestCase):

     def test_crop(self):
         height = random.randint(10, 32) * 2
         width = random.randint(10, 32) * 2
@@ -113,6 +115,45 @@ class Tester(unittest.TestCase):
         y = trans(x)
         assert (y.equal(x))

+    def test_tensor_to_pil_image(self):
+        trans = transforms.ToPILImage()
+        to_tensor = transforms.ToTensor()
+
+        img_data = torch.Tensor(3, 4, 4).uniform_()
+        img = trans(img_data)
+        assert img.getbands() == ('R', 'G', 'B')
+        r, g, b = img.split()
+
+        expected_output = img_data.mul(255).int().float().div(255)
+        assert np.allclose(expected_output[0].numpy(), to_tensor(r).numpy())
+        assert np.allclose(expected_output[1].numpy(), to_tensor(g).numpy())
+        assert np.allclose(expected_output[2].numpy(), to_tensor(b).numpy())
+
+        # single channel image
+        img_data = torch.Tensor(1, 4, 4).uniform_()
+        img = trans(img_data)
+        assert img.getbands() == ('L',)
+        l, = img.split()
+        expected_output = img_data.mul(255).int().float().div(255)
+        assert np.allclose(expected_output[0].numpy(), to_tensor(l).numpy())
+
+    def test_ndarray_to_pil_image(self):
+        trans = transforms.ToPILImage()
+        img_data = torch.ByteTensor(4, 4, 3).random_(0, 255).numpy()
+        img = trans(img_data)
+        assert img.getbands() == ('R', 'G', 'B')
+        r, g, b = img.split()
+        assert np.allclose(r, img_data[:, :, 0])
+        assert np.allclose(g, img_data[:, :, 1])
+        assert np.allclose(b, img_data[:, :, 2])
+
+        # single channel image
+        img_data = torch.ByteTensor(4, 4, 1).random_(0, 255).numpy()
+        img = trans(img_data)
+        assert img.getbands() == ('L',)
+        l, = img.split()
+        assert np.allclose(l, img_data[:, :, 0])
+
 if __name__ == '__main__':
     unittest.main()
@@ -52,14 +52,18 @@ class ToPILImage(object):
     to a PIL.Image of range [0, 255]
     """
     def __call__(self, pic):
-        if isinstance(pic, np.ndarray):
-            # handle numpy array
-            img = Image.fromarray(pic)
-        else:
+        npimg = pic
+        mode = None
+        if not isinstance(npimg, np.ndarray):
             npimg = pic.mul(255).byte().numpy()
             npimg = np.transpose(npimg, (1,2,0))
-            img = Image.fromarray(npimg)
-        return img
+
+        if npimg.shape[2] == 1:
+            npimg = npimg[:, :, 0]
+            mode = "L"
+        return Image.fromarray(npimg, mode=mode)


 class Normalize(object):
     """Given mean: (R, G, B) and std: (R, G, B),
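And a matching sketch for the numpy path of the rewritten __call__ above, mirroring test_ndarray_to_pil_image. Dropping the trailing channel axis and passing an explicit mode="L" is what lets an (H, W, 1) uint8 array through, since Image.fromarray generally cannot infer a mode for a 3-D single-channel array:

import numpy as np
import torchvision.transforms as transforms

trans = transforms.ToPILImage()

# (H, W, 1) uint8 ndarray -> grayscale image; the channel axis is dropped
# inside ToPILImage and the image is built with an explicit mode of "L"
arr = np.random.randint(0, 255, size=(4, 4, 1)).astype(np.uint8)
gray = trans(arr)
assert gray.getbands() == ('L',)
assert np.allclose(gray, arr[:, :, 0])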