Commit 62cfcb70 authored by soumith's avatar soumith
Browse files

first commit

parents
# torch-vision
This repository consists of:
- `vision.datasets` : Data loaders for popular vision datasets
- `vision.transforms` : Common image transformations such as random crop, rotations etc.
- `[WIP] vision.models` : Model definitions and pre-trained weights for popular architectures such as AlexNet, VGG, ResNet etc.
# Installation
Binaries:
```bash
conda install pytorch-vision -c https://conda.anaconda.org/t/6N-MsQ4WZ7jo/soumith
```
From Source:
```bash
pip install -r requirements.txt
pip install .
```
#!/usr/bin/env python
import os
import shutil
import sys
from setuptools import setup, find_packages
# Release version of the package; keep in sync with release tags.
VERSION = '0.1.5'

# Long description shown on the package index page.
long_description = '''torch-vision provides DataLoaders, Pre-trained models
and common transforms for torch for images and videos'''
# Package-name prefixes that must not ship in the distribution.
excluded = ['test']


def exclude_package(pkg):
    """Return True when *pkg* starts with any excluded prefix."""
    return any(pkg.startswith(prefix) for prefix in excluded)
def create_package_list(base_package):
    """Return *base_package* plus its non-excluded sub-packages.

    Sub-packages are discovered with setuptools' ``find_packages`` and
    qualified with the base package name.
    """
    packages = [base_package]
    for pkg in find_packages(base_package):
        if not exclude_package(pkg):
            packages.append(base_package + '.' + pkg)
    return packages
# Distribution metadata handed straight to setuptools.setup() below.
setup_info = dict(
    # Metadata
    name='torchvision',
    version=VERSION,
    author='PyTorch Core Team',
    author_email='soumith@pytorch.org',
    url='https://github.com/pytorch/vision',
    description='image and video datasets and models for torch deep learning',
    long_description=long_description,
    license='BSD',

    # Package info
    # NOTE(review): create_package_list appears superseded by
    # find_packages(exclude=...); the old call is kept commented out.
    packages=find_packages(exclude=('test',)),  # create_package_list('torchvision')
    zip_safe=True,
)

setup(**setup_info)
import torch
import torchvision
from .lsun import LSUNDataset, LSUNClassDataset
from .folder import ImageFolderDataset
from .coco import CocoCaptionsDataset, CocoDetectionDataset
__all__ = ('LSUNDataset', 'LSUNClassDataset')
import torch.utils.data as data
from PIL import Image
import os
import os.path
class CocoCaptionsDataset(data.Dataset):
    """MS COCO captions dataset.

    Loads images from ``root`` and pairs each with the list of caption
    strings attached to it in the COCO annotation file.

    Args:
        root: directory that contains the images.
        annFile: path to the COCO annotation json file.
        transform: optional callable applied to the PIL image.
        target_transform: optional callable applied to the caption list.
    """

    def __init__(self, root, annFile, transform=None, target_transform=None):
        # Imported lazily so the module works without pycocotools when
        # this dataset is not used.
        from pycocotools.coco import COCO
        self.root = root
        self.coco = COCO(annFile)
        # BUGFIX: wrap in list() — dict views are not indexable on
        # Python 3, which broke self.ids[index] in __getitem__.
        self.ids = list(self.coco.imgs.keys())
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Return ``(image, captions)`` for the sample at *index*.

        ``captions`` is a list of caption strings (possibly transformed).
        """
        coco = self.coco
        img_id = self.ids[index]
        ann_ids = coco.getAnnIds(imgIds=img_id)
        anns = coco.loadAnns(ann_ids)
        target = [ann['caption'] for ann in anns]

        path = coco.loadImgs(img_id)[0]['file_name']
        img = Image.open(os.path.join(self.root, path)).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        return len(self.ids)
class CocoDetectionDataset(data.Dataset):
    """MS COCO detection dataset.

    Loads images from ``root`` and pairs each with its list of COCO
    annotation dicts (bounding boxes, segmentations, ...).

    Args:
        root: directory that contains the images.
        annFile: path to the COCO annotation json file.
        transform: optional callable applied to the PIL image.
        target_transform: optional callable applied to the annotation list.
    """

    def __init__(self, root, annFile, transform=None, target_transform=None):
        # Imported lazily so the module works without pycocotools when
        # this dataset is not used.
        from pycocotools.coco import COCO
        self.root = root
        self.coco = COCO(annFile)
        # BUGFIX: wrap in list() — dict views are not indexable on
        # Python 3, which broke self.ids[index] in __getitem__.
        self.ids = list(self.coco.imgs.keys())
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Return ``(image, annotations)`` for the sample at *index*."""
        coco = self.coco
        img_id = self.ids[index]
        ann_ids = coco.getAnnIds(imgIds=img_id)
        target = coco.loadAnns(ann_ids)

        path = coco.loadImgs(img_id)[0]['file_name']
        img = Image.open(os.path.join(self.root, path)).convert('RGB')
        # BUGFIX: transform was accepted and stored but never applied;
        # apply it like CocoCaptionsDataset does.
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        return len(self.ids)
import torch.utils.data as data
from PIL import Image
import os
import os.path
# Recognised image file suffixes; both lower- and upper-case variants are
# listed explicitly, so matching stays case-sensitive.
IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    """Return True when *filename* ends with a known image extension."""
    # str.endswith accepts a tuple of suffixes, replacing the any() scan.
    return filename.endswith(tuple(IMG_EXTENSIONS))
def find_classes(dir):
    """Return ``(classes, class_to_idx)`` for the class folders of *dir*.

    Only sub-directories count as classes; stray files in *dir* (e.g. a
    README or .DS_Store) are ignored.  Classes are sorted alphabetically so
    the index assignment is deterministic across filesystems.
    """
    # BUGFIX: the original listed every entry, so a plain file in the root
    # became a bogus class and shifted all class indices.
    classes = [d for d in os.listdir(dir)
               if os.path.isdir(os.path.join(dir, d))]
    classes.sort()
    class_to_idx = {classes[i]: i for i in range(len(classes))}
    return classes, class_to_idx
def make_dataset(dir, class_to_idx):
    """Collect ``(relative_path, class_index)`` samples under *dir*.

    Walks each class sub-directory of *dir* and records every file with a
    recognised image extension, paired with its class index from
    *class_to_idx*.  Non-directory entries in *dir* are skipped.
    """
    images = []
    for target in os.listdir(dir):
        d = os.path.join(dir, target)
        if not os.path.isdir(d):
            continue
        for filename in os.listdir(d):
            if is_image_file(filename):
                # BUGFIX: build the relative path with os.path.join
                # instead of a hard-coded '/' so it is portable and
                # matches the os.path.join(root, path) in __getitem__.
                path = os.path.join(target, filename)
                images.append((path, class_to_idx[target]))
    return images
class ImageFolderDataset(data.Dataset):
    """Dataset over an image folder laid out as ``root/<class>/<image>``.

    Each sample is ``(image, class_index)``; class indices come from the
    alphabetically sorted sub-directory names of *root*.

    Args:
        root: dataset root directory.
        transform: optional callable applied to the PIL image.
        target_transform: optional callable applied to the class index.
    """

    def __init__(self, root, transform=None, target_transform=None):
        classes, class_to_idx = find_classes(root)
        self.root = root
        self.imgs = make_dataset(root, class_to_idx)
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Return ``(image, class_index)`` for the sample at *index*."""
        path, target = self.imgs[index]
        sample = Image.open(os.path.join(self.root, path)).convert('RGB')
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

    def __len__(self):
        return len(self.imgs)
import torch.utils.data as data
from PIL import Image
import os
import os.path
import StringIO
import string
import sys
if sys.version_info[0] == 2:
import cPickle as pickle
else:
import pickle
class LSUNClassDataset(data.Dataset):
    """One LSUN category stored as an LMDB database.

    Args:
        db_path: path of the ``*_lmdb`` directory for one category/split.
        transform: optional callable applied to the decoded PIL image.
        target_transform: optional callable applied to the target
            (which is None for this dataset).
    """

    def __init__(self, db_path, transform=None, target_transform=None):
        import lmdb  # lazy: only needed when LSUN data is actually used
        self.db_path = db_path
        self.env = lmdb.open(db_path, map_size=1099511627776,
                             max_readers=100, readonly=True)
        with self.env.begin(write=False) as txn:
            self.length = txn.stat()['entries']
        # Enumerating all LMDB keys is slow, so the key list is cached on
        # disk, one cache file per database path.
        cache_file = '_cache_' + db_path.replace('/', '_')
        if os.path.isfile(cache_file):
            # BUGFIX: use context managers so the cache file handles are
            # closed deterministically (the originals were never closed).
            with open(cache_file, 'rb') as f:
                self.keys = pickle.load(f)
        else:
            with self.env.begin(write=False) as txn:
                self.keys = [key for key, _ in txn.cursor()]
            with open(cache_file, 'wb') as f:
                pickle.dump(self.keys, f)
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Return ``(image, target)``; target stays None unless transformed."""
        img, target = None, None
        env = self.env
        with env.begin(write=False) as txn:
            imgbuf = txn.get(self.keys[index])

        # BUGFIX: io.BytesIO works on both Python 2 and 3; the module-level
        # StringIO import is Python-2 only and cannot hold bytes on 3.
        import io
        buf = io.BytesIO()
        buf.write(imgbuf)
        buf.seek(0)
        img = Image.open(buf).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        return self.length

    def __repr__(self):
        return self.__class__.__name__ + ' (' + self.db_path + ')'
class LSUNDataset(data.Dataset):
    """Concatenation of several LSUN per-category LMDB databases.

    db_path = root directory for the database files
    classes = 'train' | 'val' | 'test' | ['bedroom_train', 'church_train', ...]

    The target of a sample is the index of the category (database) it
    came from, in the order the classes were given.
    """

    def __init__(self, db_path, classes='train',
                 transform=None, target_transform=None):
        categories = ['bedroom', 'bridge', 'church_outdoor', 'classroom',
                      'conference_room', 'dining_room', 'kitchen',
                      'living_room', 'restaurant', 'tower']
        dset_opts = ['train', 'val', 'test']
        self.db_path = db_path
        # A bare split name expands to that split of every category.
        if type(classes) == str and classes in dset_opts:
            classes = [c + '_' + classes for c in categories]
        if type(classes) == list:
            # Validate the '<category>_<split>' naming of each entry.
            for c in classes:
                parts = c.split('_')
                # BUGFIX: '_'.join replaces string.join, which is
                # Python-2 only and crashes on Python 3.
                suffix = parts.pop(len(parts) - 1)
                category = '_'.join(parts)
                if category not in categories:
                    raise ValueError('Unknown LSUN class: ' + category +
                                     '. Options are: ' + str(categories))
                if suffix not in dset_opts:
                    raise ValueError('Unknown postfix: ' + suffix +
                                     '. Options are: ' + str(dset_opts))
        else:
            raise ValueError('Unknown option for classes')
        self.classes = classes

        # for each class, create an LSUNClassDataset
        self.dbs = []
        for c in self.classes:
            self.dbs.append(LSUNClassDataset(
                db_path=db_path + '/' + c + '_lmdb',
                transform=transform))

        # self.indices[i] = total number of samples in dbs[0..i]
        # (cumulative sizes, used to map a global index to a database).
        self.indices = []
        count = 0
        for db in self.dbs:
            count += len(db)
            self.indices.append(count)
        self.length = count
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Return ``(sample, category_index)`` for the global *index*."""
        target = 0
        sub = 0
        for ind in self.indices:
            if index < ind:
                break
            # BUGFIX: offset by the cumulative count of the *previous*
            # database only.  The old code did `sub += ind`, summing all
            # cumulative counts, which returned wrong samples for every
            # database from the third one on.
            sub = ind
            target += 1
        db = self.dbs[target]
        index = index - sub
        if self.target_transform is not None:
            target = self.target_transform(target)
        return db[index], target

    def __len__(self):
        return self.length

    def __repr__(self):
        return self.__class__.__name__ + ' (' + self.db_path + ')'
if __name__ == '__main__':
    # Ad-hoc smoke test: requires LSUN LMDB databases at the hard-coded
    # local path below, so it only runs on a machine with that data.
    #lsun = LSUNClassDataset(db_path='/home/soumith/local/lsun/train/bedroom_train_lmdb')
    #a = lsun[0]
    lsun = LSUNDataset(db_path='/home/soumith/local/lsun/train',
                       classes=['bedroom_train', 'church_outdoor_train'])
    print(lsun.classes)
    print(lsun.dbs)
    # Fetch the last sample to exercise the global-index -> database mapping.
    a, t = lsun[len(lsun)-1]
    print(a)
    print(t)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment