Commit d9b8d003 authored by soumith's avatar soumith
Browse files

readme and improvements

parent 62cfcb70
...@@ -20,3 +20,97 @@ From Source: ...@@ -20,3 +20,97 @@ From Source:
pip install -r requirements.txt pip install -r requirements.txt
pip install . pip install .
``` ```
# Datasets
Datasets have the API:
- `__getitem__`
- `__len__`
They all subclass from `torch.utils.data.Dataset`
Hence, their samples can all be loaded in parallel (via python multiprocessing workers) using a standard torch.utils.data.DataLoader.
For example:
`torch.utils.data.DataLoader(coco_cap, batch_size=args.batchSize, shuffle=True, num_workers=args.nThreads)`
In the constructor, each dataset has a slightly different API as needed, but they all take the keyword args:
- `transform` - a function that takes in an image and returns a transformed version
- common stuff like `ToTensor`, `RandomCrop`, etc. These can be composed together with `transforms.Compose` (see transforms section below)
- `target_transform` - a function that takes in the target and transforms it. For example, take in the caption string and return a tensor of word indices.
The following datasets are available:
- COCO (Captioning and Detection)
- LSUN Classification
- Imagenet-12
- ImageFolder
### COCO
This requires the [COCO API to be installed](https://github.com/pdollar/coco/tree/master/PythonAPI)
#### Captions:
`dset.CocoCaptions(root="dir where images are", annFile="json annotation file", [transform, target_transform])`
Example:
```python
import torchvision.datasets as dset
import torchvision.transforms as transforms
cap = dset.CocoCaptions(root = 'dir where images are', annFile = 'json annotation file', transform=transforms.ToTensor())
print('Number of samples:', len(cap))
img, target = cap[3] # load 4th sample
print(img.size())
print(target)
```
Output:
```
```
#### Detection:
`dset.CocoDetection(root="dir where images are", annFile="json annotation file", [transform, target_transform])`
### LSUN
`dset.LSUN(db_path, classes='train', [transform, target_transform])`
- db_path = root directory for the database files
- classes =
- 'train' - all categories, training set
- 'val' - all categories, validation set
- 'test' - all categories, test set
- ['bedroom_train', 'church_train', ...] : a list of categories to load
### ImageFolder
A generic data loader where the images are arranged in this way:
```
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
```
`dset.ImageFolder(root="root folder path", [transform, target_transform])`
It has the members:
- `self.classes` - The class names as a list
- `self.class_to_idx` - Corresponding class indices
- `self.imgs` - The list of (image path, class-index) tuples
### Imagenet-12
This is simply implemented with an ImageFolder dataset, after the data is preprocessed [as described here](https://github.com/facebook/fb.resnet.torch/blob/master/INSTALL.md#download-the-imagenet-dataset)
...@@ -9,21 +9,6 @@ VERSION = '0.1.5' ...@@ -9,21 +9,6 @@ VERSION = '0.1.5'
long_description = '''torch-vision provides DataLoaders, Pre-trained models long_description = '''torch-vision provides DataLoaders, Pre-trained models
and common transforms for torch for images and videos''' and common transforms for torch for images and videos'''
excluded = ['test']
def exclude_package(pkg):
for exclude in excluded:
if pkg.startswith(exclude):
return True
return False
def create_package_list(base_package):
return ([base_package] +
[base_package + '.' + pkg
for pkg
in find_packages(base_package)
if not exclude_package(pkg)])
setup_info = dict( setup_info = dict(
# Metadata # Metadata
name='torchvision', name='torchvision',
...@@ -36,7 +21,7 @@ setup_info = dict( ...@@ -36,7 +21,7 @@ setup_info = dict(
license='BSD', license='BSD',
# Package info # Package info
packages=find_packages(exclude=('test',)), #create_package_list('torchvision'), packages=find_packages(exclude=('test',)),
zip_safe=True, zip_safe=True,
) )
......
import torch import torch
import torchvision import torchvision
import torchvision.datasets as dset
import torchvision.transforms
...@@ -2,4 +2,6 @@ from .lsun import LSUNDataset, LSUNClassDataset ...@@ -2,4 +2,6 @@ from .lsun import LSUNDataset, LSUNClassDataset
from .folder import ImageFolderDataset from .folder import ImageFolderDataset
from .coco import CocoCaptionsDataset, CocoDetectionDataset from .coco import CocoCaptionsDataset, CocoDetectionDataset
__all__ = ('LSUNDataset', 'LSUNClassDataset') __all__ = ('LSUNDataset', 'LSUNClassDataset',
'ImageFolderDataset',
'CocoCaptionsDataset', 'CocoDetectionDataset')
...@@ -3,7 +3,7 @@ from PIL import Image ...@@ -3,7 +3,7 @@ from PIL import Image
import os import os
import os.path import os.path
class CocoCaptionsDataset(data.Dataset): class CocoCaptions(data.Dataset):
def __init__(self, root, annFile, transform=None, target_transform=None): def __init__(self, root, annFile, transform=None, target_transform=None):
from pycocotools.coco import COCO from pycocotools.coco import COCO
self.root = root self.root = root
...@@ -33,7 +33,7 @@ class CocoCaptionsDataset(data.Dataset): ...@@ -33,7 +33,7 @@ class CocoCaptionsDataset(data.Dataset):
def __len__(self): def __len__(self):
return len(self.ids) return len(self.ids)
class CocoDetectionDataset(data.Dataset): class CocoDetection(data.Dataset):
def __init__(self, root, annFile, transform=None, target_transform=None): def __init__(self, root, annFile, transform=None, target_transform=None):
from pycocotools.coco import COCO from pycocotools.coco import COCO
self.root = root self.root = root
......
...@@ -33,7 +33,7 @@ def make_dataset(dir, class_to_idx): ...@@ -33,7 +33,7 @@ def make_dataset(dir, class_to_idx):
return images return images
class ImageFolderDataset(data.Dataset): class ImageFolder(data.Dataset):
def __init__(self, root, transform=None, target_transform=None): def __init__(self, root, transform=None, target_transform=None):
classes, class_to_idx = find_classes(root) classes, class_to_idx = find_classes(root)
imgs = make_dataset(root, class_to_idx) imgs = make_dataset(root, class_to_idx)
......
...@@ -10,7 +10,7 @@ if sys.version_info[0] == 2: ...@@ -10,7 +10,7 @@ if sys.version_info[0] == 2:
else: else:
import pickle import pickle
class LSUNClassDataset(data.Dataset): class LSUNClass(data.Dataset):
def __init__(self, db_path, transform=None, target_transform=None): def __init__(self, db_path, transform=None, target_transform=None):
import lmdb import lmdb
self.db_path = db_path self.db_path = db_path
...@@ -53,7 +53,7 @@ class LSUNClassDataset(data.Dataset): ...@@ -53,7 +53,7 @@ class LSUNClassDataset(data.Dataset):
def __repr__(self): def __repr__(self):
return self.__class__.__name__ + ' (' + self.db_path + ')' return self.__class__.__name__ + ' (' + self.db_path + ')'
class LSUNDataset(data.Dataset): class LSUN(data.Dataset):
""" """
db_path = root directory for the database files db_path = root directory for the database files
classes = 'train' | 'val' | 'test' | ['bedroom_train', 'church_train', ...] classes = 'train' | 'val' | 'test' | ['bedroom_train', 'church_train', ...]
......
import torch
import math
import random
from PIL import Image
class Compose(object):
    """Chain several transforms together.

    Each transform is applied to the output of the previous one,
    in the order given.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        result = img
        for transform in self.transforms:
            result = transform(result)
        return result
class ToTensor(object):
    """Convert a 3-channel PIL image to a torch.FloatTensor of shape (C, H, W).

    Pixel values are kept in the 0-255 range (no rescaling is performed).
    """

    def __call__(self, pic):
        img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
        # The raw PIL buffer is row-major H x W x C, but pic.size is
        # (width, height) — so height is size[1] and width is size[0].
        # (The previous view used size[0], size[1], which produced a
        # (C, W, H) tensor and scrambled non-square images.)
        img = img.view(pic.size[1], pic.size[0], 3)
        # put it in CHW format
        # yikes, this transpose takes 80% of the loading time/CPU
        img = img.transpose(0, 2).transpose(1, 2).contiguous()
        return img.float()
class Normalize(object):
    """Normalize a tensor in place: subtract the per-channel mean and
    divide by the per-channel std."""

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        # Iterate channels alongside their statistics; sub_/div_ mutate
        # the input tensor, which is also returned for chaining.
        for channel, channel_mean, channel_std in zip(tensor, self.mean, self.std):
            channel.sub_(channel_mean).div_(channel_std)
        return tensor
class Scale(object):
    """Rescale the image so that its smaller edge equals ``size``,
    preserving aspect ratio.

    If the smaller edge already equals ``size``, the image is returned
    unchanged.
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        w, h = img.size
        # Nothing to do when the smaller edge is already the target size.
        if (w <= h and w == self.size) or (h <= w and h == self.size):
            return img
        # Scale the SMALLER edge to self.size and the other edge
        # proportionally.  (The previous code kept the smaller edge at
        # its original length, so the output never matched `size` and
        # the aspect ratio was distorted.)
        if w < h:
            ow = self.size
            oh = int(round(self.size * h / w))
        else:
            oh = self.size
            ow = int(round(self.size * w / h))
        return img.resize((ow, oh), self.interpolation)
class CenterCrop(object):
    """Crop a centered square of side ``size`` from the image."""

    def __init__(self, size):
        self.size = size

    def __call__(self, img):
        width, height = img.size
        left = int(round((width - self.size) / 2))
        top = int(round((height - self.size) / 2))
        right = left + self.size
        bottom = top + self.size
        return img.crop((left, top, right, bottom))
class RandomCrop(object):
    """Crop a random square of side ``size`` from a larger image.

    Zero padding is declared in the interface but not implemented yet.
    """

    def __init__(self, size, padding=0):
        self.size = size
        self.padding = padding

    def __call__(self, img):
        if self.padding > 0:
            raise NotImplementedError()
        width, height = img.size
        target = self.size
        # Exact fit: return the image untouched.
        if (width, height) == (target, target):
            return img
        left = random.randint(0, width - target)
        top = random.randint(0, height - target)
        return img.crop((left, top, left + target, top + target))
class RandomHorizontalFlip(object):
    """Mirror the image left-to-right with probability 0.5."""

    def __call__(self, img):
        flip = random.random() < 0.5
        if not flip:
            return img
        return img.transpose(Image.FLIP_LEFT_RIGHT)
class RandomSizedCrop(object):
    """Inception-style augmentation: crop a region covering 8%-100% of the
    image area with aspect ratio between 3/4 and 4/3, then resize it to a
    ``size`` x ``size`` square.  Falls back to scale + center crop when no
    valid region is found after 10 attempts."""

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        # NOTE: the sequence of random-number draws below matches the
        # reference implementation exactly, so seeded runs are reproducible.
        for _ in range(10):
            src_w, src_h = img.size[0], img.size[1]
            target_area = random.uniform(0.08, 1.0) * (src_w * src_h)
            aspect = random.uniform(3 / 4, 4 / 3)
            crop_w = int(round(math.sqrt(target_area * aspect)))
            crop_h = int(round(math.sqrt(target_area / aspect)))
            if random.random() < 0.5:
                crop_w, crop_h = crop_h, crop_w
            if crop_w > src_w or crop_h > src_h:
                continue
            left = random.randint(0, src_w - crop_w)
            top = random.randint(0, src_h - crop_h)
            img = img.crop((left, top, left + crop_w, top + crop_h))
            assert(img.size == (crop_w, crop_h))
            return img.resize((self.size, self.size), self.interpolation)
        # Fallback: scale the smaller edge, then take a center crop.
        rescale = Scale(self.size, interpolation=self.interpolation)
        return CenterCrop(self.size)(rescale(img))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment