Commit 44d69d40 authored by Soumith Chintala's avatar Soumith Chintala
Browse files

add a fakedata generator for easy debugging

parent 3211b7ee
...@@ -6,9 +6,10 @@ from .stl10 import STL10 ...@@ -6,9 +6,10 @@ from .stl10 import STL10
from .mnist import MNIST from .mnist import MNIST
from .svhn import SVHN from .svhn import SVHN
from .phototour import PhotoTour from .phototour import PhotoTour
from .fakedata import FakeData
__all__ = ('LSUN', 'LSUNClass', __all__ = ('LSUN', 'LSUNClass',
'ImageFolder', 'ImageFolder', 'FakeData',
'CocoCaptions', 'CocoDetection', 'CocoCaptions', 'CocoDetection',
'CIFAR10', 'CIFAR100', 'CIFAR10', 'CIFAR100',
'MNIST', 'STL10', 'SVHN', 'PhotoTour') 'MNIST', 'STL10', 'SVHN', 'PhotoTour')
import torch
import torch.utils.data as data
from .. import transforms
class FakeData(data.Dataset):
"""A fake dataset that returns randomly generated images and returns them as PIL images
Args:
size (int, optional): Size of the dataset. Default: 1000 images
image_size(tuple, optional): Size if the returned images. Default: (3, 224, 224)
num_classes(int, optional): Number of classes in the datset. Default: 10
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
"""
def __init__(self, size=1000, image_size=(3, 224, 224), num_classes=10, transform=None, target_transform=None):
self.size = size
self.num_classes = num_classes
self.image_size = image_size
self.transform = transform
self.target_transform = target_transform
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is class_index of the target class.
"""
# create random image that is consistent with the index id
rng_state = torch.get_rng_state()
torch.manual_seed(index)
img = torch.randn(*self.image_size)
target = torch.Tensor(1).random_(0, self.num_classes)[0]
torch.set_rng_state(rng_state)
# convert to PIL Image
img = transforms.ToPILImage()(img)
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
def __len__(self):
return self.size
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment