Commit 6fc69a42 authored by Alykhan Tejani's avatar Alykhan Tejani Committed by Soumith Chintala
Browse files

fix for to_tensor when input is np.ndarray of shape [H,W,C]. (#55)

* fix for to_tensor when input is np.ndarray of shape [H,W,C]. Issue #48 pytorch/vision

* update cifar datasets to transpose images from CHW -> HWC

* fix flake8 issue on test_transforms.py
parent 520f35c4
...@@ -115,6 +115,20 @@ class Tester(unittest.TestCase): ...@@ -115,6 +115,20 @@ class Tester(unittest.TestCase):
y = trans(x) y = trans(x)
assert (y.equal(x)) assert (y.equal(x))
def test_to_tensor(self):
channels = 3
height, width = 4, 4
trans = transforms.ToTensor()
input_data = torch.ByteTensor(channels, height, width).random_(0, 255).float().div_(255)
img = transforms.ToPILImage()(input_data)
output = trans(img)
assert np.allclose(input_data.numpy(), output.numpy())
ndarray = np.random.randint(low=0, high=255, size=(height, width, channels))
output = trans(ndarray)
expected_output = ndarray.transpose((2, 0, 1)) / 255.0
assert np.allclose(output.numpy(), expected_output)
def test_tensor_to_pil_image(self): def test_tensor_to_pil_image(self):
trans = transforms.ToPILImage() trans = transforms.ToPILImage()
to_tensor = transforms.ToTensor() to_tensor = transforms.ToTensor()
......
...@@ -63,6 +63,7 @@ class CIFAR10(data.Dataset): ...@@ -63,6 +63,7 @@ class CIFAR10(data.Dataset):
self.train_data = np.concatenate(self.train_data) self.train_data = np.concatenate(self.train_data)
self.train_data = self.train_data.reshape((50000, 3, 32, 32)) self.train_data = self.train_data.reshape((50000, 3, 32, 32))
self.train_data = self.train_data.transpose((0, 2, 3, 1)) # convert to HWC
else: else:
f = self.test_list[0][0] f = self.test_list[0][0]
file = os.path.join(root, self.base_folder, f) file = os.path.join(root, self.base_folder, f)
...@@ -78,6 +79,7 @@ class CIFAR10(data.Dataset): ...@@ -78,6 +79,7 @@ class CIFAR10(data.Dataset):
self.test_labels = entry['fine_labels'] self.test_labels = entry['fine_labels']
fo.close() fo.close()
self.test_data = self.test_data.reshape((10000, 3, 32, 32)) self.test_data = self.test_data.reshape((10000, 3, 32, 32))
self.test_data = self.test_data.transpose((0, 2, 3, 1)) # convert to HWC
def __getitem__(self, index): def __getitem__(self, index):
if self.train: if self.train:
...@@ -87,7 +89,7 @@ class CIFAR10(data.Dataset): ...@@ -87,7 +89,7 @@ class CIFAR10(data.Dataset):
# doing this so that it is consistent with all other datasets # doing this so that it is consistent with all other datasets
# to return a PIL Image # to return a PIL Image
img = Image.fromarray(np.transpose(img, (1, 2, 0))) img = Image.fromarray(img)
if self.transform is not None: if self.transform is not None:
img = self.transform(img) img = self.transform(img)
......
...@@ -38,7 +38,7 @@ class ToTensor(object): ...@@ -38,7 +38,7 @@ class ToTensor(object):
def __call__(self, pic): def __call__(self, pic):
if isinstance(pic, np.ndarray): if isinstance(pic, np.ndarray):
# handle numpy array # handle numpy array
img = torch.from_numpy(pic) img = torch.from_numpy(pic.transpose((2, 0, 1)))
else: else:
# handle PIL Image # handle PIL Image
img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment