"vscode:/vscode.git/clone" did not exist on "0acac5cb7a9426614998802dbb1c4fcb797f2598"
Commit 991bad2f authored by Bodo Kaiser's avatar Bodo Kaiser Committed by Soumith Chintala
Browse files

updated ToTensor to support more types

parent 6cbb22bb
......@@ -151,6 +151,30 @@ class Tester(unittest.TestCase):
expected_output = img_data.mul(255).int().float().div(255)
assert np.allclose(expected_output[0].numpy(), to_tensor(l).numpy())
def test_tensor_gray_to_pil_image(self):
trans = transforms.ToPILImage()
to_tensor = transforms.ToTensor()
img_data_byte = torch.ByteTensor(1, 4, 4).random_(0, 255)
img_data_short = torch.ShortTensor(1, 4, 4).random_()
img_data_int = torch.IntTensor(1, 4, 4).random_()
img_data_float = torch.FloatTensor(1, 4, 4).uniform_()
img_byte = trans(img_data_byte)
img_short = trans(img_data_short)
img_int = trans(img_data_int)
img_float = trans(img_data_float)
assert img_byte.mode == 'L'
assert img_short.mode == 'I;16'
assert img_int.mode == 'I'
#assert img_float.mode == 'F'
assert np.allclose(img_data_short.numpy(), to_tensor(img_short).numpy())
assert np.allclose(img_data_int.numpy(), to_tensor(img_int).numpy())
# would cause breaking changes as ToTensor converts to range [0, 1]
#assert np.allclose(img_data_byte.numpy(), to_tensor(img_byte).numpy())
#assert np.allclose(img_data_float.numpy(), to_tensor(img_float).numpy())
def test_ndarray_to_pil_image(self):
trans = transforms.ToPILImage()
img_data = torch.ByteTensor(4, 4, 3).random_(0, 255).numpy()
......
......@@ -39,19 +39,30 @@ class ToTensor(object):
if isinstance(pic, np.ndarray):
# handle numpy array
img = torch.from_numpy(pic.transpose((2, 0, 1)))
else:
# backard compability
return img.float().div(255)
# handle PIL Image
if pic.mode == 'I':
img = torch.from_numpy(np.array(pic, np.int32))
elif pic.mode == 'I;16':
img = torch.from_numpy(np.array(pic, np.int16))
else:
img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
# PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
if pic.mode == 'YCbCr':
nchannel = 3
elif pic.mode == 'I;16':
nchannel = 1
else:
nchannel = len(pic.mode)
img = img.view(pic.size[1], pic.size[0], nchannel)
# put it from HWC to CHW format
# yikes, this transpose takes 80% of the loading time/CPU
img = img.transpose(0, 1).transpose(0, 2).contiguous()
if isinstance(img, torch.ByteTensor):
return img.float().div(255)
else:
return img
class ToPILImage(object):
......@@ -67,7 +78,6 @@ class ToPILImage(object):
if torch.is_tensor(pic):
npimg = np.transpose(pic.numpy(), (1, 2, 0))
assert isinstance(npimg, np.ndarray), 'pic should be Tensor or ndarray'
if npimg.shape[2] == 1:
npimg = npimg[:, :, 0]
......@@ -83,7 +93,6 @@ class ToPILImage(object):
if npimg.dtype == np.uint8:
mode = 'RGB'
assert mode is not None, '{} is not supported'.format(npimg.dtype)
return Image.fromarray(npimg, mode=mode)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment