"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "9d49b45b190bc953eb965abd3d70ec30a799f505"
Commit f7c78114 authored by Tzu-Wei Huang, committed by Soumith Chintala
Browse files

make ToTensor() cope with all PIL image types (#67)

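Images in `YCbCr` mode break the code without warning: the channel count was taken from `len(pic.mode)`, which is 5 for `'YCbCr'` even though such images have 3 channels. This change special-cases that mode so that `ToTensor()` copes with all PIL image types.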
parent d359dfdf
...@@ -29,7 +29,7 @@ class Compose(object): ...@@ -29,7 +29,7 @@ class Compose(object):
class ToTensor(object): class ToTensor(object):
"""Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range """Converts a PIL.Image or numpy.ndarray (H x W x C) in the range
[0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]. [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
""" """
def __call__(self, pic): def __call__(self, pic):
...@@ -39,7 +39,12 @@ class ToTensor(object): ...@@ -39,7 +39,12 @@ class ToTensor(object):
else: else:
# handle PIL Image # handle PIL Image
img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
img = img.view(pic.size[1], pic.size[0], len(pic.mode)) # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
if pic.mode == 'YCbCr':
nchannel = 3
else:
nchannel = len(pic.mode)
img = img.view(pic.size[1], pic.size[0], nchannel)
# put it from HWC to CHW format # put it from HWC to CHW format
# yikes, this transpose takes 80% of the loading time/CPU # yikes, this transpose takes 80% of the loading time/CPU
img = img.transpose(0, 1).transpose(0, 2).contiguous() img = img.transpose(0, 1).transpose(0, 2).contiguous()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment