Commit 6fc69a42 authored by Alykhan Tejani's avatar Alykhan Tejani Committed by Soumith Chintala
Browse files

fix for to_tensor when input is np.ndarray of shape [H,W,C]. (#55)

* fix for to_tensor when input is np.ndarray of shape [H,W,C]. Issue #48 pytorch/vision

* update cifar datasets to transpose images from CHW -> HWC

* fix flake8 issue on test_transforms.py
parent 520f35c4
......@@ -115,6 +115,20 @@ class Tester(unittest.TestCase):
y = trans(x)
assert (y.equal(x))
def test_to_tensor(self):
channels = 3
height, width = 4, 4
trans = transforms.ToTensor()
input_data = torch.ByteTensor(channels, height, width).random_(0, 255).float().div_(255)
img = transforms.ToPILImage()(input_data)
output = trans(img)
assert np.allclose(input_data.numpy(), output.numpy())
ndarray = np.random.randint(low=0, high=255, size=(height, width, channels))
output = trans(ndarray)
expected_output = ndarray.transpose((2, 0, 1)) / 255.0
assert np.allclose(output.numpy(), expected_output)
def test_tensor_to_pil_image(self):
trans = transforms.ToPILImage()
to_tensor = transforms.ToTensor()
......
......@@ -63,6 +63,7 @@ class CIFAR10(data.Dataset):
self.train_data = np.concatenate(self.train_data)
self.train_data = self.train_data.reshape((50000, 3, 32, 32))
self.train_data = self.train_data.transpose((0, 2, 3, 1)) # convert to HWC
else:
f = self.test_list[0][0]
file = os.path.join(root, self.base_folder, f)
......@@ -78,6 +79,7 @@ class CIFAR10(data.Dataset):
self.test_labels = entry['fine_labels']
fo.close()
self.test_data = self.test_data.reshape((10000, 3, 32, 32))
self.test_data = self.test_data.transpose((0, 2, 3, 1)) # convert to HWC
def __getitem__(self, index):
if self.train:
......@@ -87,7 +89,7 @@ class CIFAR10(data.Dataset):
# doing this so that it is consistent with all other datasets
# to return a PIL Image
img = Image.fromarray(np.transpose(img, (1, 2, 0)))
img = Image.fromarray(img)
if self.transform is not None:
img = self.transform(img)
......
......@@ -38,7 +38,7 @@ class ToTensor(object):
def __call__(self, pic):
if isinstance(pic, np.ndarray):
# handle numpy array
img = torch.from_numpy(pic)
img = torch.from_numpy(pic.transpose((2, 0, 1)))
else:
# handle PIL Image
img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment