"...git@developer.sourcefind.cn:OpenDAS/fastllm.git" did not exist on "44be91d3d15b485aed091f920e863545a8765489"
Commit f954ea4c authored by Bodo Kaiser's avatar Bodo Kaiser Committed by Soumith Chintala
Browse files

updated transforms.ToPILImage, see #105

parent 831ba8cf
...@@ -169,6 +169,12 @@ class Tester(unittest.TestCase): ...@@ -169,6 +169,12 @@ class Tester(unittest.TestCase):
l, = img.split() l, = img.split()
assert np.allclose(l, img_data[:, :, 0]) assert np.allclose(l, img_data[:, :, 0])
def test_ndarray16_to_pil_image(self):
trans = transforms.ToPILImage()
img_data = np.random.randint(0, 65535, [4, 4, 1], np.uint16)
img = trans(img_data)
assert img.mode == 'I;16'
assert np.allclose(img, img_data[:, :, 0])
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -55,21 +55,32 @@ class ToTensor(object): ...@@ -55,21 +55,32 @@ class ToTensor(object):
class ToPILImage(object): class ToPILImage(object):
"""Converts a torch.*Tensor of range [0, 1] and shape C x H x W """Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
or numpy ndarray of dtype=uint8, range[0, 255] and shape H x W x C H x W x C to a PIL.Image while preserving value range.
to a PIL.Image of range [0, 255]
""" """
def __call__(self, pic): def __call__(self, pic):
npimg = pic npimg = pic
mode = None mode = None
if not isinstance(npimg, np.ndarray): if isinstance(pic, torch.FloatTensor):
npimg = pic.mul(255).byte().numpy() pic = pic.mul(255).byte()
npimg = np.transpose(npimg, (1, 2, 0)) if torch.is_tensor(pic):
npimg = np.transpose(pic.numpy(), (1, 2, 0))
assert isinstance(npimg, np.ndarray), 'pic should be Tensor or ndarray'
if npimg.shape[2] == 1: if npimg.shape[2] == 1:
npimg = npimg[:, :, 0] npimg = npimg[:, :, 0]
mode = "L"
if npimg.dtype == np.uint8:
mode = 'L'
if npimg.dtype == np.uint16:
mode = 'I;16'
elif npimg.dtype == np.float32:
mode = 'F'
else:
if npimg.dtype == np.uint8:
mode = 'RGB'
assert mode is not None, '{} is not supported'.format(npimg.dtype)
return Image.fromarray(npimg, mode=mode) return Image.fromarray(npimg, mode=mode)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment