Commit 05bcb18e authored by Soumith Chintala's avatar Soumith Chintala
Browse files

making cifar data loader also return PIL Image

parent 98b9aa54
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -77,6 +77,10 @@ class CIFAR10(data.Dataset): ...@@ -77,6 +77,10 @@ class CIFAR10(data.Dataset):
img, target = self.train_data[index], self.train_labels[index] img, target = self.train_data[index], self.train_labels[index]
else: else:
img, target = self.test_data[index], self.test_labels[index] img, target = self.test_data[index], self.test_labels[index]
# doing this so that it is consistent with all other datasets
# to return a PIL Image
img = Image.fromarray(np.transpose(img, (1,2,0)))
if self.transform is not None: if self.transform is not None:
img = self.transform(img) img = self.transform(img)
......
import torch
import math
def make_grid(tensor, nrow=8, padding=2): def make_grid(tensor, nrow=8, padding=2):
""" """
Given a 4D mini-batch Tensor of shape (B x C x H x W), Given a 4D mini-batch Tensor of shape (B x C x H x W),
or a list of images all of the same size,
makes a grid of images makes a grid of images
""" """
import math tensorlist = None
if isinstance(tensor, list):
tensorlist = tensor
numImages = len(tensorlist)
size = torch.Size(torch.Size([long(numImages)]) + tensorlist[0].size())
tensor = tensorlist[0].new(size)
for i in range(numImages):
tensor[i].copy_(tensorlist[i])
if tensor.dim() == 3: # single image if tensor.dim() == 3: # single image
return tensor return tensor
# make the mini-batch of images into a grid # make the mini-batch of images into a grid
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment