"docs/git@developer.sourcefind.cn:one/TransferBench.git" did not exist on "ae843a6f9998eaef5d4e9604ae2127cc6cc3fecb"
Commit 63dabcaf authored by Soumith Chintala's avatar Soumith Chintala Committed by GitHub
Browse files

Merge pull request #3 from pytorch/cifar

cifar 10 and 100
parents e37323d9 754d526f
build/
dist/
torchvision.egg-info/
*/**/__pycache__
*/**/*.pyc
*/**/*~
*~
\ No newline at end of file
...@@ -29,6 +29,7 @@ The following dataset loaders are available: ...@@ -29,6 +29,7 @@ The following dataset loaders are available:
- [LSUN Classification](#lsun) - [LSUN Classification](#lsun)
- [ImageFolder](#imagefolder) - [ImageFolder](#imagefolder)
- [Imagenet-12](#imagenet-12) - [Imagenet-12](#imagenet-12)
- [CIFAR10 and CIFAR100](#cifar)
Datasets have the API: Datasets have the API:
- `__getitem__` - `__getitem__`
...@@ -97,6 +98,15 @@ u'A mountain view with a plume of smoke in the background'] ...@@ -97,6 +98,15 @@ u'A mountain view with a plume of smoke in the background']
- ['bedroom_train', 'church_train', ...] : a list of categories to load - ['bedroom_train', 'church_train', ...] : a list of categories to load
### CIFAR
`dset.CIFAR10(root, train=True, transform=None, target_transform=None, download=False)`
`dset.CIFAR100(root, train=True, transform=None, target_transform=None, download=False)`
- `root` : root directory of dataset where there is folder `cifar-10-batches-py`
- `train` : `True` = Training set, `False` = Test set
- `download` : `True` = downloads the dataset from the internet and puts it in root directory. If dataset already downloaded, does not do anything.
### ImageFolder ### ImageFolder
A generic data loader where the images are arranged in this way: A generic data loader where the images are arranged in this way:
......
import torch
import torchvision.datasets as dset
print('\n\nCifar 10')
a = dset.CIFAR10(root="abc/def/ghi", download=True)
print(a[3])
print('\n\nCifar 100')
a = dset.CIFAR100(root="abc/def/ghi", download=True)
print(a[3])
from .lsun import LSUN, LSUNClass from .lsun import LSUN, LSUNClass
from .folder import ImageFolder from .folder import ImageFolder
from .coco import CocoCaptions, CocoDetection from .coco import CocoCaptions, CocoDetection
from .cifar import CIFAR10, CIFAR100
__all__ = ('LSUN', 'LSUNClass', __all__ = ('LSUN', 'LSUNClass',
'ImageFolder', 'ImageFolder',
'CocoCaptions', 'CocoDetection') 'CocoCaptions', 'CocoDetection',
'CIFAR10', 'CIFAR100')
from __future__ import print_function
import torch.utils.data as data
from PIL import Image
import os
import os.path
import errno
import numpy as np
import sys
if sys.version_info[0] == 2:
import cPickle as pickle
else:
import pickle
class CIFAR10(data.Dataset):
base_folder = 'cifar-10-batches-py'
url = "http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
filename = "cifar-10-python.tar.gz"
tgz_mdf = 'c58f30108f718f92721af3b95e74349a'
train_list = [
['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
['data_batch_4', '634d18415352ddfa80567beed471001a'],
['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
]
test_list = [
['test_batch', '40351d587109b95175f43aff81a1287e'],
]
def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
self.root = root
self.transform = transform
self.target_transform = target_transform
self.train = train # training set or test set
if download:
self.download()
if not self._check_integrity():
raise RuntimeError('Dataset not found or corrupted.'
+ ' You can use download=True to download it')
# now load the picked numpy arrays
self.train_data = []
self.train_labels = []
for fentry in self.train_list:
f = fentry[0]
file = os.path.join(root, self.base_folder, f)
fo = open(file, 'rb')
entry = pickle.load(fo)
self.train_data.append(entry['data'])
if 'labels' in entry:
self.train_labels += entry['labels']
else:
self.train_labels += entry['fine_labels']
fo.close()
self.train_data = np.concatenate(self.train_data)
f = self.test_list[0][0]
file = os.path.join(root, self.base_folder, f)
fo = open(file, 'rb')
entry = pickle.load(fo)
self.test_data = entry['data']
if 'labels' in entry:
self.test_labels = entry['labels']
else:
self.test_labels = entry['fine_labels']
fo.close()
self.train_data = self.train_data.reshape((50000, 3, 32, 32))
self.test_data = self.test_data.reshape((10000, 3, 32, 32))
def __getitem__(self, index):
if self.train:
img, target = self.train_data[index], self.train_labels[index]
else:
img, target = self.test_data[index], self.test_labels[index]
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
def __len__(self):
if self.train:
return 50000
else:
return 10000
def _check_integrity(self):
import hashlib
root = self.root
for fentry in (self.train_list + self.test_list):
filename, md5 = fentry[0], fentry[1]
fpath = os.path.join(root, self.base_folder, filename)
if not os.path.isfile(fpath):
return False
md5c = hashlib.md5(open(fpath, 'rb').read()).hexdigest()
if md5c != md5:
return False
return True
def download(self):
from six.moves import urllib
import tarfile
import hashlib
root = self.root
fpath = os.path.join(root, self.filename)
try:
os.makedirs(root)
except OSError as e:
if e.errno == errno.EEXIST:
pass
else:
raise
if self._check_integrity():
print('Files already downloaded and verified')
return
# downloads file
if os.path.isfile(fpath) and \
hashlib.md5(open(fpath, 'rb').read()).hexdigest() == self.tgz_md5:
print('Using downloaded file: ' + fpath)
else:
print('Downloading ' + self.url + ' to ' + fpath)
urllib.request.urlretrieve(self.url, fpath)
# extract file
cwd = os.getcwd()
print('Extracting tar file')
tar = tarfile.open(fpath, "r:gz")
os.chdir(root)
tar.extractall()
tar.close()
os.chdir(cwd)
print('Done!')
class CIFAR100(CIFAR10):
base_folder = 'cifar-100-python'
url = "http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
filename = "cifar-100-python.tar.gz"
tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
train_list = [
['train', '16019d7e3df5f24257cddd939b257f8d'],
]
test_list = [
['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment