Commit 44da562d authored by Soumith Chintala's avatar Soumith Chintala Committed by GitHub
Browse files

Merge pull request #4 from pytorch/numpy

fix ToTensor to handle numpy
parents 63dabcaf e659e27e
import torch import torch
import torchvision.datasets as dset import torchvision.datasets as dset
import torchvision.transforms as transforms
print('\n\nCifar 10') # print('\n\nCifar 10')
a = dset.CIFAR10(root="abc/def/ghi", download=True) # a = dset.CIFAR10(root="abc/def/ghi", download=True)
print(a[3]) # print(a[3])
print('\n\nCifar 100') # print('\n\nCifar 100')
a = dset.CIFAR100(root="abc/def/ghi", download=True) # a = dset.CIFAR100(root="abc/def/ghi", download=True)
print(a[3]) # print(a[3])
dataset = dset.CIFAR10(root='cifar', download=True, transform=transforms.ToTensor())
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32,
shuffle=True, num_workers=2)
# miter = dataloader.__iter__()
# def getBatch():
# global miter
# try:
# return miter.next()
# except StopIteration:
# miter = dataloader.__iter__()
# return miter.next()
# i=0
# while True:
# print(i)
# img, target = getBatch()
# i+=1
...@@ -2,6 +2,7 @@ import torch ...@@ -2,6 +2,7 @@ import torch
import math import math
import random import random
from PIL import Image from PIL import Image
import numpy as np
class Compose(object): class Compose(object):
...@@ -16,11 +17,16 @@ class Compose(object): ...@@ -16,11 +17,16 @@ class Compose(object):
class ToTensor(object): class ToTensor(object):
def __call__(self, pic): def __call__(self, pic):
img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) if isinstance(pic, np.ndarray):
img = img.view(pic.size[0], pic.size[1], 3) # handle numpy array
# put it in CHW format img = torch.from_numpy(pic)
# yikes, this transpose takes 80% of the loading time/CPU else:
img = img.transpose(0, 2).transpose(1, 2).contiguous() # handle PIL Image
img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
img = img.view(pic.size[0], pic.size[1], 3)
# put it in CHW format
# yikes, this transpose takes 80% of the loading time/CPU
img = img.transpose(0, 2).transpose(1, 2).contiguous()
return img.float() return img.float()
class Normalize(object): class Normalize(object):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment