Commit 05bcb18e authored by Soumith Chintala's avatar Soumith Chintala
Browse files

making cifar data loader also return PIL Image

parent 98b9aa54
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -77,6 +77,10 @@ class CIFAR10(data.Dataset):
img, target = self.train_data[index], self.train_labels[index]
else:
img, target = self.test_data[index], self.test_labels[index]
# doing this so that it is consistent with all other datasets
# to return a PIL Image
img = Image.fromarray(np.transpose(img, (1,2,0)))
if self.transform is not None:
img = self.transform(img)
......
import torch
import math
def make_grid(tensor, nrow=8, padding=2):
"""
Given a 4D mini-batch Tensor of shape (B x C x H x W),
or a list of images all of the same size,
makes a grid of images
"""
import math
tensorlist = None
if isinstance(tensor, list):
tensorlist = tensor
numImages = len(tensorlist)
size = torch.Size(torch.Size([long(numImages)]) + tensorlist[0].size())
tensor = tensorlist[0].new(size)
for i in range(numImages):
tensor[i].copy_(tensorlist[i])
if tensor.dim() == 3: # single image
return tensor
# make the mini-batch of images into a grid
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment