Commit 432aa00d authored by Sasank Chilamkurthy's avatar Sasank Chilamkurthy Committed by Francisco Massa
Browse files

Improve torchvision documentation (#179)

* Add documentation for transforms

* document and remove unused imports in mnist.py

* document lsun, mscoco datasets

* rest of the datasets documented

* Clean up the documentation in other functions

* Add links for datasets

* Add more documentation

* pep8 fix
parent fa2836c2
...@@ -11,12 +11,10 @@ def set_image_backend(backend): ...@@ -11,12 +11,10 @@ def set_image_backend(backend):
""" """
Specifies the package used to load images. Specifies the package used to load images.
Options are 'PIL' and 'accimage'. The :mod:`accimage` package uses the
Intel IPP library. It is generally faster than PIL, but does not support as
many operations.
Args: Args:
backend (string): name of the image backend backend (string): Name of the image backend. one of {'PIL', 'accimage'}.
The :mod:`accimage` package uses the Intel IPP library. It is
generally faster than PIL, but does not support as many operations.
""" """
global _image_backend global _image_backend
if backend not in ['PIL', 'accimage']: if backend not in ['PIL', 'accimage']:
......
...@@ -15,6 +15,22 @@ from .utils import download_url, check_integrity ...@@ -15,6 +15,22 @@ from .utils import download_url, check_integrity
class CIFAR10(data.Dataset): class CIFAR10(data.Dataset):
"""`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
Args:
root (string): Root directory of dataset where directory
``cifar-10-batches-py`` exists.
train (bool, optional): If True, creates dataset from training set, otherwise
creates from test set.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
"""
base_folder = 'cifar-10-batches-py' base_folder = 'cifar-10-batches-py'
url = "http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" url = "http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
filename = "cifar-10-python.tar.gz" filename = "cifar-10-python.tar.gz"
...@@ -86,6 +102,13 @@ class CIFAR10(data.Dataset): ...@@ -86,6 +102,13 @@ class CIFAR10(data.Dataset):
self.test_data = self.test_data.transpose((0, 2, 3, 1)) # convert to HWC self.test_data = self.test_data.transpose((0, 2, 3, 1)) # convert to HWC
def __getitem__(self, index): def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is index of the target class.
"""
if self.train: if self.train:
img, target = self.train_data[index], self.train_labels[index] img, target = self.train_data[index], self.train_labels[index]
else: else:
......
...@@ -5,7 +5,43 @@ import os.path ...@@ -5,7 +5,43 @@ import os.path
class CocoCaptions(data.Dataset): class CocoCaptions(data.Dataset):
"""`MS Coco Captions <http://mscoco.org/dataset/#captions-challenge2015>`_ Dataset.
Args:
root (string): Root directory where images are downloaded to.
annFile (string): Path to json annotation file.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.ToTensor``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
Example:
.. code:: python
import torchvision.datasets as dset
import torchvision.transforms as transforms
cap = dset.CocoCaptions(root = 'dir where images are',
annFile = 'json annotation file',
transform=transforms.ToTensor())
print('Number of samples: ', len(cap))
img, target = cap[3] # load 4th sample
print("Image Size: ", img.size())
print(target)
Output: ::
Number of samples: 82783
Image Size: (3L, 427L, 640L)
[u'A plane emitting smoke stream flying over a mountain.',
u'A plane darts across a bright blue sky behind a mountain covered in snow',
u'A plane leaves a contrail above the snowy mountain top.',
u'A mountain that has a plane flying overheard in the distance.',
u'A mountain view with a plume of smoke in the background']
"""
def __init__(self, root, annFile, transform=None, target_transform=None): def __init__(self, root, annFile, transform=None, target_transform=None):
from pycocotools.coco import COCO from pycocotools.coco import COCO
self.root = root self.root = root
...@@ -15,6 +51,13 @@ class CocoCaptions(data.Dataset): ...@@ -15,6 +51,13 @@ class CocoCaptions(data.Dataset):
self.target_transform = target_transform self.target_transform = target_transform
def __getitem__(self, index): def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: Tuple (image, target). target is a list of captions for the image.
"""
coco = self.coco coco = self.coco
img_id = self.ids[index] img_id = self.ids[index]
ann_ids = coco.getAnnIds(imgIds=img_id) ann_ids = coco.getAnnIds(imgIds=img_id)
...@@ -37,6 +80,16 @@ class CocoCaptions(data.Dataset): ...@@ -37,6 +80,16 @@ class CocoCaptions(data.Dataset):
class CocoDetection(data.Dataset): class CocoDetection(data.Dataset):
"""`MS Coco Captions <http://mscoco.org/dataset/#detections-challenge2016>`_ Dataset.
Args:
root (string): Root directory where images are downloaded to.
annFile (string): Path to json annotation file.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.ToTensor``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
"""
def __init__(self, root, annFile, transform=None, target_transform=None): def __init__(self, root, annFile, transform=None, target_transform=None):
from pycocotools.coco import COCO from pycocotools.coco import COCO
...@@ -47,6 +100,13 @@ class CocoDetection(data.Dataset): ...@@ -47,6 +100,13 @@ class CocoDetection(data.Dataset):
self.target_transform = target_transform self.target_transform = target_transform
def __getitem__(self, index): def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``.
"""
coco = self.coco coco = self.coco
img_id = self.ids[index] img_id = self.ids[index]
ann_ids = coco.getAnnIds(imgIds=img_id) ann_ids = coco.getAnnIds(imgIds=img_id)
......
...@@ -63,6 +63,29 @@ def default_loader(path): ...@@ -63,6 +63,29 @@ def default_loader(path):
class ImageFolder(data.Dataset): class ImageFolder(data.Dataset):
"""A generic data loader where the images are arranged in this way: ::
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
Args:
root (string): Root directory path.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
loader (callable, optional): A function to load an image given its path.
Attributes:
classes (list): List of the class names.
class_to_idx (dict): Dict with items (class_name, class_index).
imgs (list): List of (image path, class_index) tuples
"""
def __init__(self, root, transform=None, target_transform=None, def __init__(self, root, transform=None, target_transform=None,
loader=default_loader): loader=default_loader):
...@@ -81,6 +104,13 @@ class ImageFolder(data.Dataset): ...@@ -81,6 +104,13 @@ class ImageFolder(data.Dataset):
self.loader = loader self.loader = loader
def __getitem__(self, index): def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is class_index of the target class.
"""
path, target = self.imgs[index] path, target = self.imgs[index]
img = self.loader(path) img = self.loader(path)
if self.transform is not None: if self.transform is not None:
......
...@@ -12,7 +12,6 @@ else: ...@@ -12,7 +12,6 @@ else:
class LSUNClass(data.Dataset): class LSUNClass(data.Dataset):
def __init__(self, db_path, transform=None, target_transform=None): def __init__(self, db_path, transform=None, target_transform=None):
import lmdb import lmdb
self.db_path = db_path self.db_path = db_path
...@@ -58,8 +57,16 @@ class LSUNClass(data.Dataset): ...@@ -58,8 +57,16 @@ class LSUNClass(data.Dataset):
class LSUN(data.Dataset): class LSUN(data.Dataset):
""" """
db_path = root directory for the database files `LSUN <http://lsun.cs.princeton.edu>`_ dataset.
classes = 'train' | 'val' | 'test' | ['bedroom_train', 'church_train', ...]
Args:
db_path (string): Root directory for the database files.
classes (string or list): One of {'train', 'val', 'test'} or a list of
categories to load. e,g. ['bedroom_train', 'church_train'].
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
""" """
def __init__(self, db_path, classes='train', def __init__(self, db_path, classes='train',
...@@ -108,6 +115,13 @@ class LSUN(data.Dataset): ...@@ -108,6 +115,13 @@ class LSUN(data.Dataset):
self.target_transform = target_transform self.target_transform = target_transform
def __getitem__(self, index): def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: Tuple (image, target) where target is the index of the target category.
"""
target = 0 target = 0
sub = 0 sub = 0
for ind in self.indices: for ind in self.indices:
......
...@@ -5,12 +5,25 @@ import os ...@@ -5,12 +5,25 @@ import os
import os.path import os.path
import errno import errno
import torch import torch
import json
import codecs import codecs
import numpy as np
class MNIST(data.Dataset): class MNIST(data.Dataset):
"""`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.
Args:
root (string): Root directory of dataset where ``processed/training.pt``
and ``processed/test.pt`` exist.
train (bool, optional): If True, creates dataset from ``training.pt``,
otherwise from ``test.pt``.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
"""
urls = [ urls = [
'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz', 'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz', 'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
...@@ -42,6 +55,13 @@ class MNIST(data.Dataset): ...@@ -42,6 +55,13 @@ class MNIST(data.Dataset):
self.test_data, self.test_labels = torch.load(os.path.join(root, self.processed_folder, self.test_file)) self.test_data, self.test_labels = torch.load(os.path.join(root, self.processed_folder, self.test_file))
def __getitem__(self, index): def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is index of the target class.
"""
if self.train: if self.train:
img, target = self.train_data[index], self.train_labels[index] img, target = self.train_data[index], self.train_labels[index]
else: else:
...@@ -70,6 +90,7 @@ class MNIST(data.Dataset): ...@@ -70,6 +90,7 @@ class MNIST(data.Dataset):
os.path.exists(os.path.join(self.root, self.processed_folder, self.test_file)) os.path.exists(os.path.join(self.root, self.processed_folder, self.test_file))
def download(self): def download(self):
"""Download the MNIST data if it doesn't exist in processed_folder already."""
from six.moves import urllib from six.moves import urllib
import gzip import gzip
......
...@@ -10,6 +10,19 @@ from .utils import download_url, check_integrity ...@@ -10,6 +10,19 @@ from .utils import download_url, check_integrity
class PhotoTour(data.Dataset): class PhotoTour(data.Dataset):
"""`Learning Local Image Descriptors Data <http://phototour.cs.washington.edu/patches/default.htm>`_ Dataset.
Args:
root (string): Root directory where images are.
name (string): Name of the dataset to load.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
"""
urls = { urls = {
'notredame': [ 'notredame': [
'http://www.iis.ee.ic.ac.uk/~vbalnt/phototourism-patches/notredame.zip', 'http://www.iis.ee.ic.ac.uk/~vbalnt/phototourism-patches/notredame.zip',
...@@ -59,6 +72,13 @@ class PhotoTour(data.Dataset): ...@@ -59,6 +72,13 @@ class PhotoTour(data.Dataset):
self.data, self.labels, self.matches = torch.load(self.data_file) self.data, self.labels, self.matches = torch.load(self.data_file)
def __getitem__(self, index): def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (data1, data2, matches)
"""
if self.train: if self.train:
data = self.data[index] data = self.data[index]
if self.transform is not None: if self.transform is not None:
......
...@@ -10,6 +10,22 @@ from .cifar import CIFAR10 ...@@ -10,6 +10,22 @@ from .cifar import CIFAR10
class STL10(CIFAR10): class STL10(CIFAR10):
"""`STL10 <https://cs.stanford.edu/~acoates/stl10/>`_ Dataset.
Args:
root (string): Root directory of dataset where directory
``stl10_binary`` exists.
split (string): One of {'train', 'test', 'unlabeled', 'train+unlabeled'}.
Accordingly dataset is selected.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
"""
base_folder = 'stl10_binary' base_folder = 'stl10_binary'
url = "http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz" url = "http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz"
filename = "stl10_binary.tar.gz" filename = "stl10_binary.tar.gz"
...@@ -67,6 +83,13 @@ class STL10(CIFAR10): ...@@ -67,6 +83,13 @@ class STL10(CIFAR10):
self.classes = f.read().splitlines() self.classes = f.read().splitlines()
def __getitem__(self, index): def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is index of the target class.
"""
if self.labels is not None: if self.labels is not None:
img, target = self.data[index], int(self.labels[index]) img, target = self.data[index], int(self.labels[index])
else: else:
......
...@@ -3,13 +3,27 @@ import torch.utils.data as data ...@@ -3,13 +3,27 @@ import torch.utils.data as data
from PIL import Image from PIL import Image
import os import os
import os.path import os.path
import errno
import numpy as np import numpy as np
import sys
from .utils import download_url, check_integrity from .utils import download_url, check_integrity
class SVHN(data.Dataset): class SVHN(data.Dataset):
"""`SVHN <http://ufldl.stanford.edu/housenumbers/>`_ Dataset.
Args:
root (string): Root directory of dataset where directory
``SVHN`` exists.
split (string): One of {'train', 'test', 'extra'}.
Accordingly dataset is selected. 'extra' is Extra training set.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
"""
url = "" url = ""
filename = "" filename = ""
file_md5 = "" file_md5 = ""
...@@ -56,6 +70,13 @@ class SVHN(data.Dataset): ...@@ -56,6 +70,13 @@ class SVHN(data.Dataset):
self.data = np.transpose(self.data, (3, 2, 0, 1)) self.data = np.transpose(self.data, (3, 2, 0, 1))
def __getitem__(self, index): def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is index of the target class.
"""
img, target = self.data[index], self.labels[index] img, target = self.data[index], self.labels[index]
# doing this so that it is consistent with all other datasets # doing this so that it is consistent with all other datasets
......
...@@ -27,6 +27,19 @@ PyTorch :mod:`torch.utils.model_zoo`. These can constructed by passing ...@@ -27,6 +27,19 @@ PyTorch :mod:`torch.utils.model_zoo`. These can constructed by passing
resnet18 = models.resnet18(pretrained=True) resnet18 = models.resnet18(pretrained=True)
alexnet = models.alexnet(pretrained=True) alexnet = models.alexnet(pretrained=True)
All pre-trained models expect input images normalized in the same way,
i.e. mini-batches of 3-channel RGB images of shape (3 x H x W),
where H and W are expected to be atleast 224.
The images have to be loaded in to a range of [0, 1] and then normalized
using ``mean = [0.485, 0.456, 0.406]`` and ``std = [0.229, 0.224, 0.225]``.
You can use the following transform to normalize::
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
An example of such normalization can be found in the imagenet example
`here <https://github.com/pytorch/examples/blob/42e5b996718797e45c46a25c55b031e6768f8440/imagenet/main.py#L89-L101>`_
ImageNet 1-crop error rates (224x224) ImageNet 1-crop error rates (224x224)
================================ ============= ============= ================================ ============= =============
......
...@@ -17,7 +17,7 @@ class Compose(object): ...@@ -17,7 +17,7 @@ class Compose(object):
"""Composes several transforms together. """Composes several transforms together.
Args: Args:
transforms (List[Transform]): list of transforms to compose. transforms (list of ``Transform`` objects): list of transforms to compose.
Example: Example:
>>> transforms.Compose([ >>> transforms.Compose([
...@@ -36,11 +36,20 @@ class Compose(object): ...@@ -36,11 +36,20 @@ class Compose(object):
class ToTensor(object): class ToTensor(object):
"""Converts a PIL.Image or numpy.ndarray (H x W x C) in the range """Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor.
Converts a PIL.Image or numpy.ndarray (H x W x C) in the range
[0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]. [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
""" """
def __call__(self, pic): def __call__(self, pic):
"""
Args:
pic (PIL.Image or numpy.ndarray): Image to be converted to tensor.
Returns:
Tensor: Converted image.
"""
if isinstance(pic, np.ndarray): if isinstance(pic, np.ndarray):
# handle numpy array # handle numpy array
img = torch.from_numpy(pic.transpose((2, 0, 1))) img = torch.from_numpy(pic.transpose((2, 0, 1)))
...@@ -77,11 +86,21 @@ class ToTensor(object): ...@@ -77,11 +86,21 @@ class ToTensor(object):
class ToPILImage(object): class ToPILImage(object):
"""Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape """Convert a tensor to PIL Image.
H x W x C to a PIL.Image while preserving value range.
Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
H x W x C to a PIL.Image while preserving the value range.
""" """
def __call__(self, pic): def __call__(self, pic):
"""
Args:
pic (Tensor or numpy.ndarray): Image to be converted to PIL.Image.
Returns:
PIL.Image: Image converted to PIL.Image.
"""
npimg = pic npimg = pic
mode = None mode = None
if isinstance(pic, torch.FloatTensor): if isinstance(pic, torch.FloatTensor):
...@@ -108,9 +127,16 @@ class ToPILImage(object): ...@@ -108,9 +127,16 @@ class ToPILImage(object):
class Normalize(object): class Normalize(object):
"""Given mean: (R, G, B) and std: (R, G, B), """Normalize an tensor image with mean and standard deviation.
Given mean: (R, G, B) and std: (R, G, B),
will normalize each channel of the torch.*Tensor, i.e. will normalize each channel of the torch.*Tensor, i.e.
channel = (channel - mean) / std channel = (channel - mean) / std
Args:
mean (sequence): Sequence of means for R, G, B channels respecitvely.
std (sequence): Sequence of standard deviations for R, G, B channels
respecitvely.
""" """
def __init__(self, mean, std): def __init__(self, mean, std):
...@@ -118,6 +144,13 @@ class Normalize(object): ...@@ -118,6 +144,13 @@ class Normalize(object):
self.std = std self.std = std
def __call__(self, tensor): def __call__(self, tensor):
"""
Args:
tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
Returns:
Tensor: Normalized image.
"""
# TODO: make efficient # TODO: make efficient
for t, m, s in zip(tensor, self.mean, self.std): for t, m, s in zip(tensor, self.mean, self.std):
t.sub_(m).div_(s) t.sub_(m).div_(s)
...@@ -125,13 +158,16 @@ class Normalize(object): ...@@ -125,13 +158,16 @@ class Normalize(object):
class Scale(object): class Scale(object):
"""Rescales the input PIL.Image to the given 'size'. """Rescale the input PIL.Image to the given size.
If 'size' is a 2-element tuple or list in the order of (width, height), it will be the exactly size to scale.
If 'size' is a number, it will indicate the size of the smaller edge. Args:
For example, if height > width, then image will be size (sequence or int): Desired output size. If size is a sequence like
rescaled to (size * height / width, size) (w, h), output size will be matched to this. If size is an int,
size: size of the exactly size or the smaller edge smaller edge of the image will be matched to this number.
interpolation: Default: PIL.Image.BILINEAR i.e, if height > width, then image will be rescaled to
(size * height / width, size)
interpolation (int, optional): Desired interpolation. Default is
``PIL.Image.BILINEAR``
""" """
def __init__(self, size, interpolation=Image.BILINEAR): def __init__(self, size, interpolation=Image.BILINEAR):
...@@ -140,6 +176,13 @@ class Scale(object): ...@@ -140,6 +176,13 @@ class Scale(object):
self.interpolation = interpolation self.interpolation = interpolation
def __call__(self, img): def __call__(self, img):
"""
Args:
img (PIL.Image): Image to be scaled.
Returns:
PIL.Image: Rescaled image.
"""
if isinstance(self.size, int): if isinstance(self.size, int):
w, h = img.size w, h = img.size
if (w <= h and w == self.size) or (h <= w and h == self.size): if (w <= h and w == self.size) or (h <= w and h == self.size):
...@@ -157,9 +200,12 @@ class Scale(object): ...@@ -157,9 +200,12 @@ class Scale(object):
class CenterCrop(object): class CenterCrop(object):
"""Crops the given PIL.Image at the center to have a region of """Crops the given PIL.Image at the center.
the given size. size can be a tuple (target_height, target_width)
or an integer, in which case the target will be of a square shape (size, size) Args:
size (sequence or int): Desired output size of the crop. If size is an
int instead of sequence like (w, h), a square crop (size, size) is
made.
""" """
def __init__(self, size): def __init__(self, size):
...@@ -169,6 +215,13 @@ class CenterCrop(object): ...@@ -169,6 +215,13 @@ class CenterCrop(object):
self.size = size self.size = size
def __call__(self, img): def __call__(self, img):
"""
Args:
img (PIL.Image): Image to be cropped.
Returns:
PIL.Image: Cropped image.
"""
w, h = img.size w, h = img.size
th, tw = self.size th, tw = self.size
x1 = int(round((w - tw) / 2.)) x1 = int(round((w - tw) / 2.))
...@@ -177,7 +230,13 @@ class CenterCrop(object): ...@@ -177,7 +230,13 @@ class CenterCrop(object):
class Pad(object): class Pad(object):
"""Pads the given PIL.Image on all sides with the given "pad" value""" """Pad the given PIL.Image on all sides with the given "pad" value.
Args:
padding (int or sequence): Padding on each border. If a sequence of
length 4, it is used to pad left, top, right and bottom borders respectively.
fill: Pixel fill value. Default is 0.
"""
def __init__(self, padding, fill=0): def __init__(self, padding, fill=0):
assert isinstance(padding, numbers.Number) assert isinstance(padding, numbers.Number)
...@@ -186,11 +245,22 @@ class Pad(object): ...@@ -186,11 +245,22 @@ class Pad(object):
self.fill = fill self.fill = fill
def __call__(self, img): def __call__(self, img):
"""
Args:
img (PIL.Image): Image to be padded.
Returns:
PIL.Image: Padded image.
"""
return ImageOps.expand(img, border=self.padding, fill=self.fill) return ImageOps.expand(img, border=self.padding, fill=self.fill)
class Lambda(object): class Lambda(object):
"""Applies a lambda as a transform.""" """Apply a user-defined lambda as a transform.
Args:
lambd (function): Lambda/function to be used for transform.
"""
def __init__(self, lambd): def __init__(self, lambd):
assert isinstance(lambd, types.LambdaType) assert isinstance(lambd, types.LambdaType)
...@@ -201,9 +271,16 @@ class Lambda(object): ...@@ -201,9 +271,16 @@ class Lambda(object):
class RandomCrop(object): class RandomCrop(object):
"""Crops the given PIL.Image at a random location to have a region of """Crop the given PIL.Image at a random location.
the given size. size can be a tuple (target_height, target_width)
or an integer, in which case the target will be of a square shape (size, size) Args:
size (sequence or int): Desired output size of the crop. If size is an
int instead of sequence like (w, h), a square crop (size, size) is
made.
padding (int or sequence, optional): Optional padding on each border
of the image. Default is 0, i.e no padding. If a sequence of length
4 is provided, it is used to pad left, top, right, bottom borders
respectively.
""" """
def __init__(self, size, padding=0): def __init__(self, size, padding=0):
...@@ -214,6 +291,13 @@ class RandomCrop(object): ...@@ -214,6 +291,13 @@ class RandomCrop(object):
self.padding = padding self.padding = padding
def __call__(self, img): def __call__(self, img):
"""
Args:
img (PIL.Image): Image to be cropped.
Returns:
PIL.Image: Cropped image.
"""
if self.padding > 0: if self.padding > 0:
img = ImageOps.expand(img, border=self.padding, fill=0) img = ImageOps.expand(img, border=self.padding, fill=0)
...@@ -228,19 +312,30 @@ class RandomCrop(object): ...@@ -228,19 +312,30 @@ class RandomCrop(object):
class RandomHorizontalFlip(object): class RandomHorizontalFlip(object):
"""Randomly horizontally flips the given PIL.Image with a probability of 0.5 """Horizontally flip the given PIL.Image randomly with a probability of 0.5."""
"""
def __call__(self, img): def __call__(self, img):
"""
Args:
img (PIL.Image): Image to be flipped.
Returns:
PIL.Image: Randomly flipped image.
"""
if random.random() < 0.5: if random.random() < 0.5:
return img.transpose(Image.FLIP_LEFT_RIGHT) return img.transpose(Image.FLIP_LEFT_RIGHT)
return img return img
class RandomSizedCrop(object): class RandomSizedCrop(object):
"""Random crop the given PIL.Image to a random size of (0.08 to 1.0) of the original size """Crop the given PIL.Image to random size and aspect ratio.
and and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio
This is popularly used to train the Inception networks A crop of random size of (0.08 to 1.0) of the original size and a random
aspect ratio of 3/4 to 4/3 of the original aspect ratio is made. This crop
is finally resized to given size.
This is popularly used to train the Inception networks.
Args:
size: size of the smaller edge size: size of the smaller edge
interpolation: Default: PIL.Image.BILINEAR interpolation: Default: PIL.Image.BILINEAR
""" """
......
...@@ -5,23 +5,25 @@ irange = range ...@@ -5,23 +5,25 @@ irange = range
def make_grid(tensor, nrow=8, padding=2, def make_grid(tensor, nrow=8, padding=2,
normalize=False, range=None, scale_each=False, pad_value=0): normalize=False, range=None, scale_each=False, pad_value=0):
""" """Make a grid of images.
Given a 4D mini-batch Tensor of shape (B x C x H x W),
or a list of images all of the same size,
makes a grid of images of size (B / nrow, nrow).
normalize=True will shift the image to the range (0, 1), Args:
tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W)
or a list of images all of the same size.
nrows (int, optional): Number of rows in grid. Final grid size is
(B / nrow, nrow). Default is 8.
normalize (bool, optional): If True, shift the image to the range (0, 1),
by subtracting the minimum and dividing by the maximum pixel value. by subtracting the minimum and dividing by the maximum pixel value.
range (tuple, optional): tuple (min, max) where min and max are numbers,
then these numbers are used to normalize the image. By default, min and max
are computed from the tensor.
scale_each(bool, optional): If True, scale each image in the batch of
images separately rather than the (min, max) over all images.
pad_value(float, optional): Value for the padded pixels.
if range=(min, max) where min and max are numbers, then these numbers are used to Example:
normalize the image. See this notebook `here <https://gist.github.com/anonymous/bf16430f7750c023141c562f3e9f2a91>`_
scale_each=True will scale each image in the batch of images separately rather than
computing the (min, max) over all images.
pad_value=<float> sets the value for the padded pixels.
[Example usage is given in this notebook](https://gist.github.com/anonymous/bf16430f7750c023141c562f3e9f2a91)
""" """
# if list of tensors, convert to a 4D mini-batch Tensor # if list of tensors, convert to a 4D mini-batch Tensor
if isinstance(tensor, list): if isinstance(tensor, list):
...@@ -82,11 +84,12 @@ def make_grid(tensor, nrow=8, padding=2, ...@@ -82,11 +84,12 @@ def make_grid(tensor, nrow=8, padding=2,
def save_image(tensor, filename, nrow=8, padding=2, def save_image(tensor, filename, nrow=8, padding=2,
normalize=False, range=None, scale_each=False, pad_value=0): normalize=False, range=None, scale_each=False, pad_value=0):
""" """Save a given Tensor into an image file.
Saves a given Tensor into an image file.
If given a mini-batch tensor, will save the tensor as a grid of images by calling `make_grid`. Args:
All options after `filename` are passed through to `make_grid`. Refer to it's documentation for tensor (Tensor or list): Image to be saved. If given a mini-batch tensor,
more details saves the tensor as a grid of images by calling ``make_grid``.
**kwargs: Other arguments are documented in ``make_grid``.
""" """
from PIL import Image from PIL import Image
tensor = tensor.cpu() tensor = tensor.cpu()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment