Commit 432aa00d authored by Sasank Chilamkurthy, committed by Francisco Massa

Improve torchvision documentation (#179)

* Add documentation for transforms

* document and remove unused imports in mnist.py

* document lsun, mscoco datasets

* rest of the datasets documented

* Clean up the documentation in other functions

* Add links for datasets

* Add more documentation

* pep8 fix
parent fa2836c2
@@ -11,12 +11,10 @@ def set_image_backend(backend):
"""
Specifies the package used to load images.
Options are 'PIL' and 'accimage'. The :mod:`accimage` package uses the
Intel IPP library. It is generally faster than PIL, but does not support as
many operations.
Args:
backend (string): name of the image backend
backend (string): Name of the image backend. One of {'PIL', 'accimage'}.
The :mod:`accimage` package uses the Intel IPP library. It is
generally faster than PIL, but does not support as many operations.
"""
global _image_backend
if backend not in ['PIL', 'accimage']:
......
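For reference, a minimal usage sketch of this function (``'accimage'`` only works if the optional accimage package is installed)::

    import torchvision

    torchvision.set_image_backend('PIL')  # the default backend
    # torchvision.set_image_backend('accimage')  # faster, requires accimage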
@@ -15,6 +15,22 @@ from .utils import download_url, check_integrity
class CIFAR10(data.Dataset):
"""`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
Args:
root (string): Root directory of dataset where directory
``cifar-10-batches-py`` exists.
train (bool, optional): If True, creates dataset from training set, otherwise
creates from test set.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g., ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
"""
base_folder = 'cifar-10-batches-py'
url = "http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
filename = "cifar-10-python.tar.gz"
@@ -86,6 +102,13 @@ class CIFAR10(data.Dataset):
self.test_data = self.test_data.transpose((0, 2, 3, 1)) # convert to HWC
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is index of the target class.
"""
if self.train:
img, target = self.train_data[index], self.train_labels[index]
else:
......
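A minimal usage sketch for this class; the ``root`` path is a placeholder::

    import torchvision.datasets as dset
    import torchvision.transforms as transforms

    # './data' is a placeholder root directory
    trainset = dset.CIFAR10(root='./data', train=True, download=True,
                            transform=transforms.ToTensor())
    img, target = trainset[0]  # (image tensor, class index), per __getitem__ above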
@@ -5,7 +5,43 @@ import os.path
class CocoCaptions(data.Dataset):
"""`MS Coco Captions <http://mscoco.org/dataset/#captions-challenge2015>`_ Dataset.
Args:
root (string): Root directory where images are downloaded to.
annFile (string): Path to json annotation file.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g., ``transforms.ToTensor``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
Example:
.. code:: python
import torchvision.datasets as dset
import torchvision.transforms as transforms
cap = dset.CocoCaptions(root='dir where images are',
                        annFile='json annotation file',
                        transform=transforms.ToTensor())
print('Number of samples: ', len(cap))
img, target = cap[3] # load 4th sample
print("Image Size: ", img.size())
print(target)
Output: ::
Number of samples: 82783
Image Size: (3L, 427L, 640L)
[u'A plane emitting smoke stream flying over a mountain.',
u'A plane darts across a bright blue sky behind a mountain covered in snow',
u'A plane leaves a contrail above the snowy mountain top.',
u'A mountain that has a plane flying overheard in the distance.',
u'A mountain view with a plume of smoke in the background']
"""
def __init__(self, root, annFile, transform=None, target_transform=None):
from pycocotools.coco import COCO
self.root = root
@@ -15,6 +51,13 @@ class CocoCaptions(data.Dataset):
self.target_transform = target_transform
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: Tuple (image, target). target is a list of captions for the image.
"""
coco = self.coco
img_id = self.ids[index]
ann_ids = coco.getAnnIds(imgIds=img_id)
@@ -37,6 +80,16 @@ class CocoCaptions(data.Dataset):
class CocoDetection(data.Dataset):
"""`MS Coco Captions <http://mscoco.org/dataset/#detections-challenge2016>`_ Dataset.
Args:
root (string): Root directory where images are downloaded to.
annFile (string): Path to json annotation file.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g., ``transforms.ToTensor``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
"""
def __init__(self, root, annFile, transform=None, target_transform=None):
from pycocotools.coco import COCO
@@ -47,6 +100,13 @@ class CocoDetection(data.Dataset):
self.target_transform = target_transform
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``.
"""
coco = self.coco
img_id = self.ids[index]
ann_ids = coco.getAnnIds(imgIds=img_id)
......
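As with ``CocoCaptions``, a usage sketch for ``CocoDetection``; the paths are placeholders and ``pycocotools`` must be installed::

    import torchvision.datasets as dset
    import torchvision.transforms as transforms

    det = dset.CocoDetection(root='dir where images are',
                             annFile='json annotation file',
                             transform=transforms.ToTensor())
    img, anns = det[0]  # anns is the object returned by coco.loadAnns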
@@ -63,6 +63,29 @@ def default_loader(path):
class ImageFolder(data.Dataset):
"""A generic data loader where the images are arranged in this way: ::
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
Args:
root (string): Root directory path.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g., ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
loader (callable, optional): A function to load an image given its path.
Attributes:
classes (list): List of the class names.
class_to_idx (dict): Dict with items (class_name, class_index).
imgs (list): List of (image path, class_index) tuples.
"""
def __init__(self, root, transform=None, target_transform=None,
loader=default_loader):
@@ -81,6 +104,13 @@ class ImageFolder(data.Dataset):
self.loader = loader
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is class_index of the target class.
"""
path, target = self.imgs[index]
img = self.loader(path)
if self.transform is not None:
......
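A sketch of how this class is typically combined with a ``DataLoader``; the ``root`` path is a placeholder::

    import torch.utils.data as data
    import torchvision.datasets as dset
    import torchvision.transforms as transforms

    # root must contain one sub-directory per class, as in the layout above
    folder = dset.ImageFolder(root='path/to/root',
                              transform=transforms.ToTensor())
    print(folder.classes)       # class names, taken from the sub-directory names
    print(folder.class_to_idx)  # class name -> class_index mapping
    loader = data.DataLoader(folder, batch_size=4, shuffle=True)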
@@ -12,7 +12,6 @@ else:
class LSUNClass(data.Dataset):
def __init__(self, db_path, transform=None, target_transform=None):
import lmdb
self.db_path = db_path
@@ -58,8 +57,16 @@ class LSUNClass(data.Dataset):
class LSUN(data.Dataset):
"""
db_path = root directory for the database files
classes = 'train' | 'val' | 'test' | ['bedroom_train', 'church_train', ...]
`LSUN <http://lsun.cs.princeton.edu>`_ dataset.
Args:
db_path (string): Root directory for the database files.
classes (string or list): One of {'train', 'val', 'test'} or a list of
categories to load. e.g. ['bedroom_train', 'church_train'].
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g., ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
"""
def __init__(self, db_path, classes='train',
@@ -108,6 +115,13 @@ class LSUN(data.Dataset):
self.target_transform = target_transform
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: Tuple (image, target) where target is the index of the target category.
"""
target = 0
sub = 0
for ind in self.indices:
......
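A usage sketch, assuming the LSUN lmdb files are already present under ``db_path`` (a placeholder here)::

    import torchvision.datasets as dset
    import torchvision.transforms as transforms

    # 'path/to/lsun' stands in for the directory holding the lmdb files
    lsun = dset.LSUN(db_path='path/to/lsun', classes=['bedroom_train'],
                     transform=transforms.ToTensor())
    img, target = lsun[0]  # target is the index of the category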
@@ -5,12 +5,25 @@ import os
import os.path
import errno
import torch
import json
import codecs
import numpy as np
class MNIST(data.Dataset):
"""`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.
Args:
root (string): Root directory of dataset where ``processed/training.pt``
and ``processed/test.pt`` exist.
train (bool, optional): If True, creates dataset from ``training.pt``,
otherwise from ``test.pt``.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g., ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
"""
urls = [
'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
@@ -42,6 +55,13 @@ class MNIST(data.Dataset):
self.test_data, self.test_labels = torch.load(os.path.join(root, self.processed_folder, self.test_file))
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is index of the target class.
"""
if self.train:
img, target = self.train_data[index], self.train_labels[index]
else:
@@ -70,6 +90,7 @@ class MNIST(data.Dataset):
os.path.exists(os.path.join(self.root, self.processed_folder, self.test_file))
def download(self):
"""Download the MNIST data if it doesn't exist in processed_folder already."""
from six.moves import urllib
import gzip
......
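A minimal usage sketch; ``root`` is a placeholder directory into which the files are downloaded::

    import torchvision.datasets as dset
    import torchvision.transforms as transforms

    mnist = dset.MNIST(root='./data', train=True, download=True,
                       transform=transforms.ToTensor())
    img, target = mnist[0]  # target is the digit's class index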
@@ -10,6 +10,19 @@ from .utils import download_url, check_integrity
class PhotoTour(data.Dataset):
"""`Learning Local Image Descriptors Data <http://phototour.cs.washington.edu/patches/default.htm>`_ Dataset.
Args:
root (string): Root directory where images are.
name (string): Name of the dataset to load.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
"""
urls = {
'notredame': [
'http://www.iis.ee.ic.ac.uk/~vbalnt/phototourism-patches/notredame.zip',
@@ -59,6 +72,13 @@ class PhotoTour(data.Dataset):
self.data, self.labels, self.matches = torch.load(self.data_file)
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (data1, data2, matches)
"""
if self.train:
data = self.data[index]
if self.transform is not None:
......
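A usage sketch with a placeholder ``root``; ``name`` selects one of the patch sets listed in ``urls``, e.g. ``'notredame'``::

    import torchvision.datasets as dset

    patches = dset.PhotoTour(root='./data', name='notredame', download=True)
    patch = patches[0]  # a single patch tensor in train mode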
@@ -10,6 +10,22 @@ from .cifar import CIFAR10
class STL10(CIFAR10):
"""`STL10 <https://cs.stanford.edu/~acoates/stl10/>`_ Dataset.
Args:
root (string): Root directory of dataset where directory
``stl10_binary`` exists.
split (string): One of {'train', 'test', 'unlabeled', 'train+unlabeled'}.
Accordingly, the dataset is selected.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g., ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
"""
base_folder = 'stl10_binary'
url = "http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz"
filename = "stl10_binary.tar.gz"
@@ -67,6 +83,13 @@ class STL10(CIFAR10):
self.classes = f.read().splitlines()
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is index of the target class.
"""
if self.labels is not None:
img, target = self.data[index], int(self.labels[index])
else:
......
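A usage sketch mirroring the CIFAR10 example above; ``root`` is a placeholder::

    import torchvision.datasets as dset
    import torchvision.transforms as transforms

    stl = dset.STL10(root='./data', split='train', download=True,
                     transform=transforms.ToTensor())
    img, target = stl[0]  # target is the index of the target class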
@@ -3,13 +3,27 @@ import torch.utils.data as data
from PIL import Image
import os
import os.path
import errno
import numpy as np
import sys
from .utils import download_url, check_integrity
class SVHN(data.Dataset):
"""`SVHN <http://ufldl.stanford.edu/housenumbers/>`_ Dataset.
Args:
root (string): Root directory of dataset where directory
``SVHN`` exists.
split (string): One of {'train', 'test', 'extra'}.
Accordingly, the dataset is selected. 'extra' is the extra training set.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g., ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
"""
url = ""
filename = ""
file_md5 = ""
@@ -56,6 +70,13 @@ class SVHN(data.Dataset):
self.data = np.transpose(self.data, (3, 2, 0, 1))
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is index of the target class.
"""
img, target = self.data[index], self.labels[index]
# doing this so that it is consistent with all other datasets
......
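A usage sketch with a placeholder ``root``::

    import torchvision.datasets as dset
    import torchvision.transforms as transforms

    svhn = dset.SVHN(root='./data', split='train', download=True,
                     transform=transforms.ToTensor())
    img, target = svhn[0]  # target is the index of the target class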
@@ -27,6 +27,19 @@ PyTorch :mod:`torch.utils.model_zoo`. These can be constructed by passing
resnet18 = models.resnet18(pretrained=True)
alexnet = models.alexnet(pretrained=True)
All pre-trained models expect input images normalized in the same way,
i.e. mini-batches of 3-channel RGB images of shape (3 x H x W),
where H and W are expected to be at least 224.
The images have to be loaded into a range of [0, 1] and then normalized
using ``mean = [0.485, 0.456, 0.406]`` and ``std = [0.229, 0.224, 0.225]``.
You can use the following transform to normalize::
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
An example of such normalization can be found in the imagenet example
`here <https://github.com/pytorch/examples/blob/42e5b996718797e45c46a25c55b031e6768f8440/imagenet/main.py#L89-L101>`_
ImageNet 1-crop error rates (224x224)
================================ ============= =============
......
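Putting the pieces together, a sketch of the full evaluation-time preprocessing pipeline for the pre-trained models; the 256/224 scale and crop sizes follow the imagenet example linked above::

    import torchvision.models as models
    import torchvision.transforms as transforms

    resnet18 = models.resnet18(pretrained=True)
    preprocess = transforms.Compose([
        transforms.Scale(256),       # resize the smaller edge to 256
        transforms.CenterCrop(224),  # models expect 224 x 224 inputs
        transforms.ToTensor(),       # PIL image -> [0, 1] tensor
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])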
@@ -17,7 +17,7 @@ class Compose(object):
"""Composes several transforms together.
Args:
transforms (List[Transform]): list of transforms to compose.
transforms (list of ``Transform`` objects): list of transforms to compose.
Example:
>>> transforms.Compose([
@@ -36,11 +36,20 @@ class Compose(object):
class ToTensor(object):
"""Converts a PIL.Image or numpy.ndarray (H x W x C) in the range
"""Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor.
Converts a PIL.Image or numpy.ndarray (H x W x C) in the range
[0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
"""
def __call__(self, pic):
"""
Args:
pic (PIL.Image or numpy.ndarray): Image to be converted to tensor.
Returns:
Tensor: Converted image.
"""
if isinstance(pic, np.ndarray):
# handle numpy array
img = torch.from_numpy(pic.transpose((2, 0, 1)))
@@ -77,11 +86,21 @@ class ToTensor(object):
class ToPILImage(object):
"""Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
H x W x C to a PIL.Image while preserving value range.
"""Convert a tensor to PIL Image.
Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
H x W x C to a PIL.Image while preserving the value range.
"""
def __call__(self, pic):
"""
Args:
pic (Tensor or numpy.ndarray): Image to be converted to PIL.Image.
Returns:
PIL.Image: Image converted to PIL.Image.
"""
npimg = pic
mode = None
if isinstance(pic, torch.FloatTensor):
@@ -108,9 +127,16 @@ class ToPILImage(object):
class Normalize(object):
"""Given mean: (R, G, B) and std: (R, G, B),
"""Normalize an tensor image with mean and standard deviation.
Given mean: (R, G, B) and std: (R, G, B),
will normalize each channel of the torch.*Tensor, i.e.
channel = (channel - mean) / std
Args:
mean (sequence): Sequence of means for R, G, B channels respectively.
std (sequence): Sequence of standard deviations for R, G, B channels
respectively.
"""
def __init__(self, mean, std):
@@ -118,6 +144,13 @@ class Normalize(object):
self.std = std
def __call__(self, tensor):
"""
Args:
tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
Returns:
Tensor: Normalized image.
"""
# TODO: make efficient
for t, m, s in zip(tensor, self.mean, self.std):
t.sub_(m).div_(s)
@@ -125,13 +158,16 @@ class Normalize(object):
class Scale(object):
"""Rescales the input PIL.Image to the given 'size'.
If 'size' is a 2-element tuple or list in the order of (width, height), it will be the exactly size to scale.
If 'size' is a number, it will indicate the size of the smaller edge.
For example, if height > width, then image will be
rescaled to (size * height / width, size)
size: size of the exactly size or the smaller edge
interpolation: Default: PIL.Image.BILINEAR
"""Rescale the input PIL.Image to the given size.
Args:
size (sequence or int): Desired output size. If size is a sequence like
(w, h), output size will be matched to this. If size is an int,
smaller edge of the image will be matched to this number.
i.e., if height > width, then image will be rescaled to
(size * height / width, size)
interpolation (int, optional): Desired interpolation. Default is
``PIL.Image.BILINEAR``
"""
def __init__(self, size, interpolation=Image.BILINEAR):
@@ -140,6 +176,13 @@ class Scale(object):
self.interpolation = interpolation
def __call__(self, img):
"""
Args:
img (PIL.Image): Image to be scaled.
Returns:
PIL.Image: Rescaled image.
"""
if isinstance(self.size, int):
w, h = img.size
if (w <= h and w == self.size) or (h <= w and h == self.size):
@@ -157,9 +200,12 @@ class Scale(object):
class CenterCrop(object):
"""Crops the given PIL.Image at the center to have a region of
the given size. size can be a tuple (target_height, target_width)
or an integer, in which case the target will be of a square shape (size, size)
"""Crops the given PIL.Image at the center.
Args:
size (sequence or int): Desired output size of the crop. If size is an
int instead of sequence like (w, h), a square crop (size, size) is
made.
"""
def __init__(self, size):
@@ -169,6 +215,13 @@ class CenterCrop(object):
self.size = size
def __call__(self, img):
"""
Args:
img (PIL.Image): Image to be cropped.
Returns:
PIL.Image: Cropped image.
"""
w, h = img.size
th, tw = self.size
x1 = int(round((w - tw) / 2.))
@@ -177,7 +230,13 @@ class CenterCrop(object):
class Pad(object):
"""Pads the given PIL.Image on all sides with the given "pad" value"""
"""Pad the given PIL.Image on all sides with the given "pad" value.
Args:
padding (int or sequence): Padding on each border. If a sequence of
length 4, it is used to pad left, top, right and bottom borders respectively.
fill: Pixel fill value. Default is 0.
"""
def __init__(self, padding, fill=0):
assert isinstance(padding, numbers.Number)
@@ -186,11 +245,22 @@ class Pad(object):
self.fill = fill
def __call__(self, img):
"""
Args:
img (PIL.Image): Image to be padded.
Returns:
PIL.Image: Padded image.
"""
return ImageOps.expand(img, border=self.padding, fill=self.fill)
class Lambda(object):
"""Applies a lambda as a transform."""
"""Apply a user-defined lambda as a transform.
Args:
lambd (function): Lambda/function to be used for transform.
"""
def __init__(self, lambd):
assert isinstance(lambd, types.LambdaType)
@@ -201,9 +271,16 @@ class Lambda(object):
class RandomCrop(object):
"""Crops the given PIL.Image at a random location to have a region of
the given size. size can be a tuple (target_height, target_width)
or an integer, in which case the target will be of a square shape (size, size)
"""Crop the given PIL.Image at a random location.
Args:
size (sequence or int): Desired output size of the crop. If size is an
int instead of sequence like (w, h), a square crop (size, size) is
made.
padding (int or sequence, optional): Optional padding on each border
of the image. Default is 0, i.e. no padding. If a sequence of length
4 is provided, it is used to pad left, top, right, bottom borders
respectively.
"""
def __init__(self, size, padding=0):
@@ -214,6 +291,13 @@ class RandomCrop(object):
self.padding = padding
def __call__(self, img):
"""
Args:
img (PIL.Image): Image to be cropped.
Returns:
PIL.Image: Cropped image.
"""
if self.padding > 0:
img = ImageOps.expand(img, border=self.padding, fill=0)
@@ -228,19 +312,30 @@ class RandomCrop(object):
class RandomHorizontalFlip(object):
"""Randomly horizontally flips the given PIL.Image with a probability of 0.5
"""
"""Horizontally flip the given PIL.Image randomly with a probability of 0.5."""
def __call__(self, img):
"""
Args:
img (PIL.Image): Image to be flipped.
Returns:
PIL.Image: Randomly flipped image.
"""
if random.random() < 0.5:
return img.transpose(Image.FLIP_LEFT_RIGHT)
return img
class RandomSizedCrop(object):
"""Random crop the given PIL.Image to a random size of (0.08 to 1.0) of the original size
and and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio
This is popularly used to train the Inception networks
"""Crop the given PIL.Image to random size and aspect ratio.
A crop of random size of (0.08 to 1.0) of the original size and a random
aspect ratio of 3/4 to 4/3 of the original aspect ratio is made. This crop
is finally resized to given size.
This is popularly used to train the Inception networks.
Args:
size: size of the smaller edge
interpolation: Default: PIL.Image.BILINEAR
"""
......
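A sketch of how the transforms above are commonly composed for training; the 224 crop size is the usual choice for the ImageNet models, not a requirement::

    import torchvision.transforms as transforms

    train_transform = transforms.Compose([
        transforms.RandomSizedCrop(224),    # random crop, resized to 224 x 224
        transforms.RandomHorizontalFlip(),  # flip with probability 0.5
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])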
@@ -5,23 +5,25 @@ irange = range
def make_grid(tensor, nrow=8, padding=2,
normalize=False, range=None, scale_each=False, pad_value=0):
"""
Given a 4D mini-batch Tensor of shape (B x C x H x W),
or a list of images all of the same size,
makes a grid of images of size (B / nrow, nrow).
"""Make a grid of images.
normalize=True will shift the image to the range (0, 1),
Args:
tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W)
or a list of images all of the same size.
nrow (int, optional): Number of images in each row of the grid. The final
grid size is (B / nrow, nrow). Default is 8.
normalize (bool, optional): If True, shift the image to the range (0, 1),
by subtracting the minimum and dividing by the maximum pixel value.
range (tuple, optional): Tuple (min, max) where min and max are numbers;
these are used to normalize the image. By default, min and max
are computed from the tensor.
scale_each (bool, optional): If True, scale each image in the batch of
images separately rather than the (min, max) over all images.
pad_value (float, optional): Value for the padded pixels.
if range=(min, max) where min and max are numbers, then these numbers are used to
normalize the image.
scale_each=True will scale each image in the batch of images separately rather than
computing the (min, max) over all images.
pad_value=<float> sets the value for the padded pixels.
Example:
See this notebook `here <https://gist.github.com/anonymous/bf16430f7750c023141c562f3e9f2a91>`_
[Example usage is given in this notebook](https://gist.github.com/anonymous/bf16430f7750c023141c562f3e9f2a91)
"""
# if list of tensors, convert to a 4D mini-batch Tensor
if isinstance(tensor, list):
@@ -82,11 +84,12 @@ def make_grid(tensor, nrow=8, padding=2,
def save_image(tensor, filename, nrow=8, padding=2,
normalize=False, range=None, scale_each=False, pad_value=0):
"""
Saves a given Tensor into an image file.
If given a mini-batch tensor, will save the tensor as a grid of images by calling `make_grid`.
All options after `filename` are passed through to `make_grid`. Refer to it's documentation for
more details
"""Save a given Tensor into an image file.
Args:
tensor (Tensor or list): Image to be saved. If given a mini-batch tensor,
saves the tensor as a grid of images by calling ``make_grid``.
**kwargs: Other arguments are documented in ``make_grid``.
"""
from PIL import Image
tensor = tensor.cpu()
......
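A usage sketch for both utilities; the batch here is random noise, purely for illustration::

    import torch
    import torchvision.utils as vutils

    batch = torch.randn(16, 3, 64, 64)  # fake mini-batch of 16 RGB images
    grid = vutils.make_grid(batch, nrow=4, normalize=True)
    vutils.save_image(batch, 'fake_samples.png', nrow=4, normalize=True)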