Commit c1fdfa68 authored by Elad Hoffer's avatar Elad Hoffer Committed by Soumith Chintala
Browse files

add stl10 dataset (#83)

parent 66183f50
...@@ -49,6 +49,7 @@ The following dataset loaders are available: ...@@ -49,6 +49,7 @@ The following dataset loaders are available:
- `ImageFolder <#imagefolder>`__ - `ImageFolder <#imagefolder>`__
- `Imagenet-12 <#imagenet-12>`__ - `Imagenet-12 <#imagenet-12>`__
- `CIFAR10 and CIFAR100 <#cifar>`__ - `CIFAR10 and CIFAR100 <#cifar>`__
- `STL10 <#stl10>`__
Datasets have the API: - ``__getitem__`` - ``__len__`` They all subclass Datasets have the API: - ``__getitem__`` - ``__len__`` They all subclass
from ``torch.utils.data.Dataset`` Hence, they can all be multi-threaded from ``torch.utils.data.Dataset`` Hence, they can all be multi-threaded
...@@ -156,6 +157,18 @@ CIFAR ...@@ -156,6 +157,18 @@ CIFAR
puts it in root directory. If dataset already downloaded, does not do puts it in root directory. If dataset already downloaded, does not do
anything. anything.
STL10
~~~~~
``dset.STL10(root, split='train', transform=None, target_transform=None, download=False)``
- ``root`` : root directory of the dataset, i.e. the directory that contains the ``stl10_binary`` folder
- ``split`` : ``'train'`` = Training set, ``'test'`` = Test set, ``'unlabeled'`` = Unlabeled set,
``'train+unlabeled'`` = Training + Unlabeled set (missing label marked as ``-1``)
- ``download`` : ``True`` = downloads the dataset from the internet and
puts it in the root directory. If the dataset has already been downloaded,
it is not downloaded again.
ImageFolder ImageFolder
~~~~~~~~~~~ ~~~~~~~~~~~
......
...@@ -2,10 +2,11 @@ from .lsun import LSUN, LSUNClass ...@@ -2,10 +2,11 @@ from .lsun import LSUN, LSUNClass
from .folder import ImageFolder from .folder import ImageFolder
from .coco import CocoCaptions, CocoDetection from .coco import CocoCaptions, CocoDetection
from .cifar import CIFAR10, CIFAR100 from .cifar import CIFAR10, CIFAR100
from .stl10 import STL10
from .mnist import MNIST from .mnist import MNIST
__all__ = ('LSUN', 'LSUNClass', __all__ = ('LSUN', 'LSUNClass',
'ImageFolder', 'ImageFolder',
'CocoCaptions', 'CocoDetection', 'CocoCaptions', 'CocoDetection',
'CIFAR10', 'CIFAR100', 'CIFAR10', 'CIFAR100',
'MNIST') 'MNIST', 'STL10')
from __future__ import print_function
import torch.utils.data as data
from PIL import Image
import os
import os.path
import errno
import numpy as np
import sys
from .cifar import CIFAR10
class STL10(CIFAR10):
    """Loader for the STL10 image dataset stored in its binary format.

    The class derives from ``CIFAR10`` only to reuse its ``download`` and
    ``_check_integrity`` helpers (defined in ``.cifar``, not shown here);
    ``CIFAR10.__init__`` is intentionally not called.

    Args:
        root: directory that contains (or will contain) the
            ``stl10_binary`` folder.
        split: one of ``'train'``, ``'test'``, ``'unlabeled'`` or
            ``'train+unlabeled'``.
        transform: optional callable applied to the PIL image of a sample.
        target_transform: optional callable applied to the target of a
            sample.
        download: if true, ``self.download()`` is invoked before the
            integrity check.

    Raises:
        ValueError: if ``split`` is not one of the supported values.
        RuntimeError: if the integrity check fails.
    """
    base_folder = 'stl10_binary'
    url = "http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz"
    filename = "stl10_binary.tar.gz"
    tgz_md5 = '91f7769df0f17e558f3565bffb0c7dfb'
    class_names_file = 'class_names.txt'
    # Each entry is [file name, md5 checksum].
    train_list = [
        ['train_X.bin', '918c2871b30a85fa023e0c44e0bee87f'],
        ['train_y.bin', '5a34089d4802c674881badbb80307741'],
        ['unlabeled_X.bin', '5242ba1fed5e4be9e1e742405eb56ca4']
    ]
    test_list = [
        ['test_X.bin', '7f263ba9f9e0b06b93213547f721ac82'],
        ['test_y.bin', '36f9794fa4beb8a2c72628de14fa638e']
    ]
    # Every value accepted for the ``split`` argument.
    splits = ('train', 'train+unlabeled', 'unlabeled', 'test')

    def __init__(self, root, split='train', transform=None, target_transform=None, download=False):
        # Fail fast on an unknown split: previously any unrecognised value
        # silently fell through to the test set.
        if split not in self.splits:
            raise ValueError(
                'Split "{}" not found. Valid splits are: {}'.format(
                    split, ', '.join(self.splits)))

        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.split = split  # train/test/unlabeled set

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError(
                'Dataset not found or corrupted. You can use download=True to download it')

        # now load the picked numpy arrays
        if self.split == 'train':
            self.data, self.labels = self.__loadfile(
                self.train_list[0][0], self.train_list[1][0])
        elif self.split == 'train+unlabeled':
            self.data, self.labels = self.__loadfile(
                self.train_list[0][0], self.train_list[1][0])
            unlabeled_data, _ = self.__loadfile(self.train_list[2][0])
            self.data = np.concatenate((self.data, unlabeled_data))
            # Unlabeled samples are marked with the label -1.
            self.labels = np.concatenate(
                (self.labels, np.asarray([-1] * unlabeled_data.shape[0])))
        elif self.split == 'unlabeled':
            self.data, _ = self.__loadfile(self.train_list[2][0])
            self.labels = None
        else:  # self.split == 'test' (guaranteed by the validation above)
            self.data, self.labels = self.__loadfile(
                self.test_list[0][0], self.test_list[1][0])

        # ``classes`` is only set when the class-names file is present.
        class_file = os.path.join(
            root, self.base_folder, self.class_names_file)
        if os.path.isfile(class_file):
            with open(class_file) as f:
                self.classes = f.read().splitlines()

    def __getitem__(self, index):
        """Return ``(image, target)`` for the sample at ``index``.

        ``target`` is an ``int`` label, or ``None`` when the dataset has no
        labels (the ``'unlabeled'`` split). ``target_transform`` is applied
        to the target in both cases.
        """
        if self.labels is not None:
            img, target = self.data[index], int(self.labels[index])
        else:
            img, target = self.data[index], None

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image (stored as CHW, PIL expects HWC)
        img = Image.fromarray(np.transpose(img, (1, 2, 0)))

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the number of samples in the loaded split."""
        return self.data.shape[0]

    def __loadfile(self, data_file, labels_file=None):
        """Read an image file and, optionally, its label file.

        Both file names are resolved relative to ``root/base_folder``.

        Returns:
            ``(images, labels)`` where ``images`` is a uint8 array of shape
            ``(N, 3, 96, 96)`` and ``labels`` is a 0-based uint8 array, or
            ``None`` when no ``labels_file`` is given.
        """
        labels = None
        if labels_file:
            path_to_labels = os.path.join(
                self.root, self.base_folder, labels_file)
            with open(path_to_labels, 'rb') as f:
                labels = np.fromfile(f, dtype=np.uint8) - 1  # 0-based

        path_to_data = os.path.join(self.root, self.base_folder, data_file)
        with open(path_to_data, 'rb') as f:
            # read whole file in uint8 chunks
            everything = np.fromfile(f, dtype=np.uint8)
            images = np.reshape(everything, (-1, 3, 96, 96))
            # Swap the two spatial axes of every image.
            # NOTE(review): presumably because the binary files store pixels
            # in column-major order -- confirm against the STL10 format docs.
            images = np.transpose(images, (0, 1, 3, 2))

        return images, labels
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment