Commit 9fc7edfb authored by Soumith Chintala, committed by GitHub

Merge pull request #285 from chsasank/sphinx-docs

Sphinx docs
parents cee28367 c6738109
@@ -4,4 +4,5 @@ torchvision.egg-info/
 */**/__pycache__
 */**/*.pyc
 */**/*~
-*~
\ No newline at end of file
+*~
+docs/build
\ No newline at end of file
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line.
SPHINXOPTS =
SPHINXBUILD = sphinx-build
SPHINXPROJ = torchvision
SOURCEDIR = source
BUILDDIR = build
# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

docset: html
	doc2dash --name $(SPHINXPROJ) --icon $(SOURCEDIR)/_static/img/pytorch-logo-flame.png --enable-js --online-redirect-url http://pytorch.org/vision/ --force $(BUILDDIR)/html/

	# Manually fix because Zeal doesn't deal well with `icon.png`-only at 2x resolution.
	cp $(SPHINXPROJ).docset/icon.png $(SPHINXPROJ).docset/icon@2x.png
	convert $(SPHINXPROJ).docset/icon@2x.png -resize 16x16 $(SPHINXPROJ).docset/icon.png

.PHONY: help Makefile docset

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
@ECHO OFF
pushd %~dp0
REM Command file for Sphinx documentation
if "%SPHINXBUILD%" == "" (
set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=source
set BUILDDIR=build
set SPHINXPROJ=torchvision
if "%1" == "" goto help
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
echo.
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
echo.installed, then set the SPHINXBUILD environment variable to point
echo.to the full path of the 'sphinx-build' executable. Alternatively you
echo.may add the Sphinx directory to PATH.
echo.
echo.If you don't have Sphinx installed, grab it from
echo.http://sphinx-doc.org/
exit /b 1
)
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
goto end
:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
:end
popd
sphinx
sphinxcontrib-googleanalytics
-e git://github.com/snide/sphinx_rtd_theme.git#egg=sphinx_rtd_theme
body {
font-family: "Lato","proxima-nova","Helvetica Neue",Arial,sans-serif;
}
/* Default header fonts are ugly */
h1, h2, .rst-content .toctree-wrapper p.caption, h3, h4, h5, h6, legend, p.caption {
font-family: "Lato","proxima-nova","Helvetica Neue",Arial,sans-serif;
}
/* Use white for docs background */
.wy-side-nav-search {
background-color: #fff;
}
.wy-nav-content-wrap, .wy-menu li.current > a {
background-color: #fff;
}
@media screen and (min-width: 1400px) {
.wy-nav-content-wrap {
background-color: rgba(0, 0, 0, 0.0470588);
}
.wy-nav-content {
background-color: #fff;
}
}
/* Fixes for mobile */
.wy-nav-top {
background-color: #fff;
background-image: url('../img/pytorch-logo-dark.svg');
background-repeat: no-repeat;
background-position: center;
padding: 0;
margin: 0.4045em 0.809em;
color: #333;
}
.wy-nav-top > a {
display: none;
}
@media screen and (max-width: 768px) {
.wy-side-nav-search>a img.logo {
height: 60px;
}
}
/* This is needed to ensure that logo above search scales properly */
.wy-side-nav-search a {
display: block;
}
/* This ensures that multiple constructors will remain in separate lines. */
.rst-content dl:not(.docutils) dt {
display: table;
}
/* Use our red for literals (it's very similar to the original color) */
.rst-content tt.literal, .rst-content code.literal {
color: #F05732;
}
.rst-content tt.xref, .rst-content code.xref,
a .rst-content tt, a .rst-content code {
color: #404040;
}
/* Change link colors (except for the menu) */
a {
color: #F05732;
}
a:hover {
color: #F05732;
}
a:visited {
color: #D44D2C;
}
.wy-menu a {
color: #b3b3b3;
}
.wy-menu a:hover {
color: #b3b3b3;
}
/* Default footer text is quite big */
footer {
font-size: 80%;
}
footer .rst-footer-buttons {
font-size: 125%; /* revert footer settings - 1/80% = 125% */
}
footer p {
font-size: 100%;
}
/* For hidden headers that appear in TOC tree */
/* see http://stackoverflow.com/a/32363545/3343043 */
.rst-content .hidden-section {
display: none;
}
nav .hidden-section {
display: inherit;
}
.wy-side-nav-search>div.version {
color: #000;
}
<?xml version="1.0" encoding="utf-8"?>
<!-- Generator: Adobe Illustrator 21.0.0, SVG Export Plug-In . SVG Version: 6.00 Build 0) -->
<svg version="1.1" id="Layer_1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" x="0px" y="0px"
viewBox="0 0 199.7 40.2" style="enable-background:new 0 0 199.7 40.2;" xml:space="preserve">
<style type="text/css">
.st0{fill:#F05732;}
.st1{fill:#9E529F;}
.st2{fill:#333333;}
</style>
<path class="st0" d="M102.7,12.2c-1.3-1-1.8,3.9-4.4,3.9c-3,0-4-13-6.3-13c-0.7,0-0.8-0.4-7.9,21.3c-2.9,9,4.4,15.8,11.8,15.8
c4.6,0,12.3-3,12.3-12.6C108.2,20.5,104.7,13.7,102.7,12.2z M95.8,35.3c-3.7,0-6.7-3.1-6.7-7c0-3.9,3-7,6.7-7s6.7,3.1,6.7,7
C102.5,32.1,99.5,35.3,95.8,35.3z"/>
<path class="st1" d="M99.8,0c-0.5,0-1.8,2.5-1.8,3.6c0,1.5,1,2,1.8,2c0.8,0,1.8-0.5,1.8-2C101.5,2.5,100.2,0,99.8,0z"/>
<path class="st2" d="M0,39.5V14.9h11.5c5.3,0,8.3,3.6,8.3,7.9c0,4.3-3,7.9-8.3,7.9H5.2v8.8H0z M14.4,22.8c0-2.1-1.6-3.3-3.7-3.3H5.2
v6.6h5.5C12.8,26.1,14.4,24.8,14.4,22.8z"/>
<path class="st2" d="M35.2,39.5V29.4l-9.4-14.5h6l6.1,9.8l6.1-9.8h5.9l-9.4,14.5v10.1H35.2z"/>
<path class="st2" d="M63.3,39.5v-20h-7.2v-4.6h19.6v4.6h-7.2v20H63.3z"/>
<path class="st2" d="M131.4,39.5l-4.8-8.7h-3.8v8.7h-5.2V14.9H129c5.1,0,8.3,3.4,8.3,7.9c0,4.3-2.8,6.7-5.4,7.3l5.6,9.4H131.4z
M131.9,22.8c0-2-1.6-3.3-3.7-3.3h-5.5v6.6h5.5C130.3,26.1,131.9,24.9,131.9,22.8z"/>
<path class="st2" d="M145.6,27.2c0-7.6,5.7-12.7,13.1-12.7c5.4,0,8.5,2.9,10.3,6l-4.5,2.2c-1-2-3.2-3.6-5.8-3.6
c-4.5,0-7.7,3.4-7.7,8.1c0,4.6,3.2,8.1,7.7,8.1c2.5,0,4.7-1.6,5.8-3.6l4.5,2.2c-1.7,3.1-4.9,6-10.3,6
C151.3,39.9,145.6,34.7,145.6,27.2z"/>
<path class="st2" d="M194.5,39.5V29.1h-11.6v10.4h-5.2V14.9h5.2v9.7h11.6v-9.7h5.3v24.6H194.5z"/>
</svg>
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg
xmlns:dc="http://purl.org/dc/elements/1.1/"
xmlns:cc="http://creativecommons.org/ns#"
xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
xmlns:svg="http://www.w3.org/2000/svg"
xmlns="http://www.w3.org/2000/svg"
height="40.200001"
width="40.200001"
xml:space="preserve"
viewBox="0 0 40.200002 40.2"
y="0px"
x="0px"
id="Layer_1"
version="1.1"><metadata
id="metadata4717"><rdf:RDF><cc:Work
rdf:about=""><dc:format>image/svg+xml</dc:format><dc:type
rdf:resource="http://purl.org/dc/dcmitype/StillImage" /><dc:title></dc:title></cc:Work></rdf:RDF></metadata><defs
id="defs4715" /><style
id="style4694"
type="text/css">
.st0{fill:#F05732;}
.st1{fill:#9E529F;}
.st2{fill:#333333;}
</style><path
style="fill:#f05732"
id="path4696"
d="m 26.975479,12.199999 c -1.3,-1 -1.8,3.9 -4.4,3.9 -3,0 -4,-12.9999998 -6.3,-12.9999998 -0.7,0 -0.8,-0.4 -7.9000003,21.2999998 -2.9000001,9 4.4000003,15.8 11.8000003,15.8 4.6,0 12.3,-3 12.3,-12.6 0,-7.1 -3.5,-13.9 -5.5,-15.4 z m -6.9,23.1 c -3.7,0 -6.7,-3.1 -6.7,-7 0,-3.9 3,-7 6.7,-7 3.7,0 6.7,3.1 6.7,7 0,3.8 -3,7 -6.7,7 z"
class="st0" /><path
style="fill:#9e529f"
id="path4698"
d="m 24.075479,-7.6293945e-7 c -0.5,0 -1.8,2.49999996293945 -1.8,3.59999996293945 0,1.5 1,2 1.8,2 0.8,0 1.8,-0.5 1.8,-2 -0.1,-1.1 -1.4,-3.59999996293945 -1.8,-3.59999996293945 z"
class="st1" /></svg>
\ No newline at end of file
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# PyTorch documentation build configuration file, created by
# sphinx-quickstart on Fri Dec 23 13:31:47 2016.
#
# This file is execfile()d with the current directory set to its
# containing dir.
#
# Note that not all possible configuration values are present in this
# autogenerated file.
#
# All configuration values have a default; values that are commented out
# serve to show the default.
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
# import os
# import sys
# sys.path.insert(0, os.path.abspath('.'))
import torch
import torchvision
import sphinx_rtd_theme
# -- General configuration ------------------------------------------------
# If your documentation needs a minimal Sphinx version, state it here.
#
# needs_sphinx = '1.0'
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.autosummary',
'sphinx.ext.doctest',
'sphinx.ext.intersphinx',
'sphinx.ext.todo',
'sphinx.ext.coverage',
'sphinx.ext.mathjax',
'sphinx.ext.napoleon',
'sphinx.ext.viewcode',
'sphinxcontrib.googleanalytics',
]
napoleon_use_ivar = True
googleanalytics_id = 'UA-90545585-1'
googleanalytics_enabled = True
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
# source_suffix = ['.rst', '.md']
source_suffix = '.rst'
# The master toctree document.
master_doc = 'index'
# General information about the project.
project = 'Torchvision'
copyright = '2017, Torch Contributors'
author = 'Torch Contributors'
# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
# built documents.
#
# The short X.Y version.
# TODO: change to [:2] at v1.0
version = 'master (' + torchvision.__version__ + ')'
# The full version, including alpha/beta/rc tags.
# TODO: verify this works as expected
release = 'master'
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
language = None
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This patterns also effect to html_static_path and html_extra_path
exclude_patterns = []
# The name of the Pygments (syntax highlighting) style to use.
pygments_style = 'sphinx'
# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = True
# -- Options for HTML output ----------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'sphinx_rtd_theme'
html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
#
html_theme_options = {
'collapse_navigation': False,
'display_version': True,
'logo_only': True,
}
html_logo = '_static/img/pytorch-logo-dark.svg'
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
# html_style_path = 'css/pytorch_theme.css'
html_context = {
'css_files': [
'https://fonts.googleapis.com/css?family=Lato',
'_static/css/pytorch_theme.css'
],
}
# -- Options for HTMLHelp output ------------------------------------------
# Output file base name for HTML help builder.
htmlhelp_basename = 'PyTorchdoc'
# -- Options for LaTeX output ---------------------------------------------
latex_elements = {
# The paper size ('letterpaper' or 'a4paper').
#
# 'papersize': 'letterpaper',
# The font size ('10pt', '11pt' or '12pt').
#
# 'pointsize': '10pt',
# Additional stuff for the LaTeX preamble.
#
# 'preamble': '',
# Latex figure (float) alignment
#
# 'figure_align': 'htbp',
}
# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
(master_doc, 'pytorch.tex', 'torchvision Documentation',
'Torch Contributors', 'manual'),
]
# -- Options for manual page output ---------------------------------------
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
(master_doc, 'torchvision', 'torchvision Documentation',
[author], 1)
]
# -- Options for Texinfo output -------------------------------------------
# Grouping the document tree into Texinfo files. List of tuples
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
(master_doc, 'torchvision', 'torchvision Documentation',
author, 'torchvision', 'One line description of project.',
'Miscellaneous'),
]
# Example configuration for intersphinx: refer to the Python standard library.
intersphinx_mapping = {
'python': ('https://docs.python.org/', None),
'numpy': ('http://docs.scipy.org/doc/numpy/', None),
}
# -- A patch that prevents Sphinx from cross-referencing ivar tags -------
# See http://stackoverflow.com/a/41184353/3343043
from docutils import nodes
from sphinx.util.docfields import TypedField
from sphinx import addnodes
def patched_make_field(self, types, domain, items, **kw):
    # `kw` catches `env=None` needed for newer sphinx while maintaining
    # backwards compatibility when passed along further down!

    # type: (List, unicode, Tuple) -> nodes.field
    def handle_item(fieldarg, content):
        par = nodes.paragraph()
        par += addnodes.literal_strong('', fieldarg)  # Patch: this line added
        # par.extend(self.make_xrefs(self.rolename, domain, fieldarg,
        #                            addnodes.literal_strong))
        if fieldarg in types:
            par += nodes.Text(' (')
            # NOTE: using .pop() here to prevent a single type node to be
            # inserted twice into the doctree, which leads to
            # inconsistencies later when references are resolved
            fieldtype = types.pop(fieldarg)
            if len(fieldtype) == 1 and isinstance(fieldtype[0], nodes.Text):
                typename = u''.join(n.astext() for n in fieldtype)
                typename = typename.replace('int', 'python:int')
                typename = typename.replace('long', 'python:long')
                typename = typename.replace('float', 'python:float')
                typename = typename.replace('type', 'python:type')
                par.extend(self.make_xrefs(self.typerolename, domain, typename,
                                           addnodes.literal_emphasis, **kw))
            else:
                par += fieldtype
            par += nodes.Text(')')
        par += nodes.Text(' -- ')
        par += content
        return par

    fieldname = nodes.field_name('', self.label)
    if len(items) == 1 and self.can_collapse:
        fieldarg, content = items[0]
        bodynode = handle_item(fieldarg, content)
    else:
        bodynode = self.list_type()
        for fieldarg, content in items:
            bodynode += nodes.list_item('', handle_item(fieldarg, content))
    fieldbody = nodes.field_body('', bodynode)
    return nodes.field('', fieldname, fieldbody)


TypedField.make_field = patched_make_field
torchvision.datasets
====================
All datasets are subclasses of :class:`torch.utils.data.Dataset`,
i.e., they have ``__getitem__`` and ``__len__`` methods implemented.
Hence, they can all be passed to a :class:`torch.utils.data.DataLoader`,
which can load multiple samples in parallel using ``torch.multiprocessing`` workers.
For example::

    imagenet_data = torchvision.datasets.ImageFolder('path/to/imagenet_root/')
    data_loader = torch.utils.data.DataLoader(imagenet_data,
                                              batch_size=4,
                                              shuffle=True,
                                              num_workers=args.nThreads)
The following datasets are available:
.. contents:: Datasets
    :local:
All the datasets have a similar API. They all have two common arguments:
``transform`` and ``target_transform``, to transform the input and the target respectively.
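As a quick illustration (a minimal sketch: the dataset, the root path and the
toy ``target_transform`` below are placeholders, not a prescribed recipe)::

    import torchvision.datasets as datasets
    import torchvision.transforms as transforms

    # transform is applied to the image, target_transform to the label
    mnist = datasets.MNIST('path/to/mnist_root/',
                           transform=transforms.ToTensor(),
                           target_transform=lambda target: target % 2)
    image, parity = mnist[0]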
.. currentmodule:: torchvision.datasets
MNIST
~~~~~
.. autoclass:: MNIST
COCO
~~~~
.. note::

    These require the `COCO API to be installed`_

.. _COCO API to be installed: https://github.com/pdollar/coco/tree/master/PythonAPI
Captions
^^^^^^^^
.. autoclass:: CocoCaptions
    :members: __getitem__
    :special-members:

Detection
^^^^^^^^^
.. autoclass:: CocoDetection
    :members: __getitem__
    :special-members:

LSUN
~~~~
.. autoclass:: LSUN
    :members: __getitem__
    :special-members:

ImageFolder
~~~~~~~~~~~
.. autoclass:: ImageFolder
    :members: __getitem__
    :special-members:
Imagenet-12
~~~~~~~~~~~
This should simply be implemented with an ``ImageFolder`` dataset.
The data is preprocessed `as described
here <https://github.com/facebook/fb.resnet.torch/blob/master/INSTALL.md#download-the-imagenet-dataset>`__.
`Here is an
example <https://github.com/pytorch/examples/blob/27e2a46c1d1505324032b1d94fc6ce24d5b67e97/imagenet/main.py#L48-L62>`__.
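A minimal sketch of that approach (the directory path and the transform choices
here are illustrative, not taken from the linked example)::

    import torchvision.datasets as datasets
    import torchvision.transforms as transforms

    # ImageFolder expects one sub-directory per class, e.g. train/n01440764/xxx.JPEG
    train_dataset = datasets.ImageFolder(
        'path/to/imagenet_root/train',
        transform=transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]))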
CIFAR
~~~~~
.. autoclass:: CIFAR10
    :members: __getitem__
    :special-members:

STL10
~~~~~
.. autoclass:: STL10
    :members: __getitem__
    :special-members:

SVHN
~~~~
.. autoclass:: SVHN
    :members: __getitem__
    :special-members:

PhotoTour
~~~~~~~~~
.. autoclass:: PhotoTour
    :members: __getitem__
    :special-members:
torchvision
===========
The :mod:`torchvision` package consists of popular datasets, model
architectures, and common image transformations for computer vision.
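The subpackages can be imported individually; for instance (purely illustrative)::

    import torchvision.datasets as datasets
    import torchvision.models as models
    import torchvision.transforms as transforms
    import torchvision.utils as utils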
.. toctree::
   :maxdepth: 2
   :caption: Package Reference

   datasets
   models
   transforms
   utils

.. automodule:: torchvision
   :members:
torchvision.models
==================
The models subpackage contains definitions for the following model
architectures:
- `AlexNet`_
- `VGG`_
- `ResNet`_
- `SqueezeNet`_
- `DenseNet`_
- `Inception`_ v3
You can construct a model with random weights by calling its constructor:
.. code:: python

    import torchvision.models as models
    resnet18 = models.resnet18()
    alexnet = models.alexnet()
    vgg16 = models.vgg16()
    squeezenet = models.squeezenet1_0()
    densenet = models.densenet161()
    inception = models.inception_v3()
We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`.
These can be constructed by passing ``pretrained=True``:
.. code:: python

    import torchvision.models as models
    resnet18 = models.resnet18(pretrained=True)
    alexnet = models.alexnet(pretrained=True)
    squeezenet = models.squeezenet1_0(pretrained=True)
    vgg16 = models.vgg16(pretrained=True)
    densenet = models.densenet161(pretrained=True)
    inception = models.inception_v3(pretrained=True)
All pre-trained models expect input images normalized in the same way,
i.e. mini-batches of 3-channel RGB images of shape (3 x H x W),
where H and W are expected to be at least 224.
The images have to be loaded into a range of [0, 1] and then normalized
using ``mean = [0.485, 0.456, 0.406]`` and ``std = [0.229, 0.224, 0.225]``.
You can use the following transform to normalize::

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

An example of such normalization can be found in the imagenet example
`here <https://github.com/pytorch/examples/blob/42e5b996718797e45c46a25c55b031e6768f8440/imagenet/main.py#L89-L101>`_.
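For a self-contained picture of how the pieces fit together, here is a minimal
sketch of preprocessing an image and running it through a pre-trained network
(the 256/224 sizes and the file name are illustrative assumptions, not
requirements of the API):

.. code:: python

    import torch
    import torchvision.models as models
    import torchvision.transforms as transforms
    from PIL import Image
    from torch.autograd import Variable

    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    model = models.resnet18(pretrained=True)
    model.eval()  # switch to inference mode

    img = preprocess(Image.open('dog.jpg'))     # 3 x 224 x 224 tensor
    scores = model(Variable(img.unsqueeze(0)))  # 1 x 1000 class scores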
ImageNet 1-crop error rates (224x224)

================================ ============= =============
Network                          Top-1 error   Top-5 error
================================ ============= =============
AlexNet                          43.45         20.91
VGG-11                           30.98         11.37
VGG-13                           30.07         10.75
VGG-16                           28.41         9.62
VGG-19                           27.62         9.12
VGG-11 with batch normalization  29.62         10.19
VGG-13 with batch normalization  28.45         9.63
VGG-16 with batch normalization  26.63         8.50
VGG-19 with batch normalization  25.76         8.15
ResNet-18                        30.24         10.92
ResNet-34                        26.70         8.58
ResNet-50                        23.85         7.13
ResNet-101                       22.63         6.44
ResNet-152                       21.69         5.94
SqueezeNet 1.0                   41.90         19.58
SqueezeNet 1.1                   41.81         19.38
Densenet-121                     25.35         7.83
Densenet-169                     24.00         7.00
Densenet-201                     22.80         6.43
Densenet-161                     22.35         6.20
Inception v3                     22.55         6.44
================================ ============= =============
.. _AlexNet: https://arxiv.org/abs/1404.5997
.. _VGG: https://arxiv.org/abs/1409.1556
.. _ResNet: https://arxiv.org/abs/1512.03385
.. _SqueezeNet: https://arxiv.org/abs/1602.07360
.. _DenseNet: https://arxiv.org/abs/1608.06993
.. _Inception: https://arxiv.org/abs/1512.00567
.. currentmodule:: torchvision.models
Alexnet
-------
.. autofunction:: alexnet
VGG
---
.. autofunction:: vgg11
.. autofunction:: vgg11_bn
.. autofunction:: vgg13
.. autofunction:: vgg13_bn
.. autofunction:: vgg16
.. autofunction:: vgg16_bn
.. autofunction:: vgg19
.. autofunction:: vgg19_bn
ResNet
------
.. autofunction:: resnet18
.. autofunction:: resnet34
.. autofunction:: resnet50
.. autofunction:: resnet101
.. autofunction:: resnet152
SqueezeNet
----------
.. autofunction:: squeezenet1_0
.. autofunction:: squeezenet1_1
DenseNet
---------
.. autofunction:: densenet121
.. autofunction:: densenet169
.. autofunction:: densenet161
.. autofunction:: densenet201
Inception v3
------------
.. autofunction:: inception_v3
torchvision.transforms
======================
.. currentmodule:: torchvision.transforms
Transforms are common image transformations. They can be chained together using :class:`Compose`.
.. autoclass:: Compose
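For example, a minimal chain that resizes, crops and converts an image to a
tensor might look like the following (the sizes are illustrative)::

    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])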
Transforms on PIL Image
-----------------------
.. autoclass:: Resize
.. autoclass:: Scale
.. autoclass:: CenterCrop
.. autoclass:: RandomCrop
.. autoclass:: RandomHorizontalFlip
.. autoclass:: RandomVerticalFlip
.. autoclass:: RandomResizedCrop
.. autoclass:: RandomSizedCrop
.. autoclass:: FiveCrop
.. autoclass:: TenCrop
.. autoclass:: Pad
.. autoclass:: ColorJitter
Transforms on torch.\*Tensor
----------------------------
.. autoclass:: Normalize
    :members: __call__
    :special-members:

Conversion Transforms
---------------------
.. autoclass:: ToTensor
    :members: __call__
    :special-members:

.. autoclass:: ToPILImage
    :members: __call__
    :special-members:
Generic Transforms
------------------
.. autoclass:: Lambda
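``Lambda`` simply wraps a user-defined function, for example (a toy sketch that
assumes a PIL image input)::

    rotate90 = transforms.Lambda(lambda img: img.rotate(90))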
torchvision.utils
=================
.. currentmodule:: torchvision.utils
.. autofunction:: make_grid
.. autofunction:: save_image
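Both functions operate on a mini-batch of images given as a 4D tensor
(B x C x H x W). A minimal sketch (the random tensor and the file name are
placeholders)::

    import torch
    import torchvision.utils as vutils

    batch = torch.rand(16, 3, 64, 64)        # 16 random RGB images
    grid = vutils.make_grid(batch, nrow=4)   # single 3 x H x W image tiling the batch
    vutils.save_image(batch, 'samples.png', nrow=4)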
 #!/usr/bin/env python
 import os
+import io
+import re
 import shutil
 import sys
 from setuptools import setup, find_packages
+
+
+def read(*names, **kwargs):
+    with io.open(
+        os.path.join(os.path.dirname(__file__), *names),
+        encoding=kwargs.get("encoding", "utf8")
+    ) as fp:
+        return fp.read()
+
+
+def find_version(*file_paths):
+    version_file = read(*file_paths)
+    version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]",
+                              version_file, re.M)
+    if version_match:
+        return version_match.group(1)
+    raise RuntimeError("Unable to find version string.")
+
+
 readme = open('README.rst').read()
-VERSION = '0.1.9'
+VERSION = find_version('torchvision', '__init__.py')
 requirements = [
     'numpy',
...
@@ -3,6 +3,7 @@ from torchvision import datasets
 from torchvision import transforms
 from torchvision import utils
+__version__ = '0.1.9'
 _image_backend = 'PIL'
...
"""The models subpackage contains definitions for the following model
architectures:
- `AlexNet`_
- `VGG`_
- `ResNet`_
- `SqueezeNet`_
- `DenseNet`_
- `Inception`_ v3
You can construct a model with random weights by calling its constructor:
.. code:: python

    import torchvision.models as models
    resnet18 = models.resnet18()
    alexnet = models.alexnet()
    vgg16 = models.vgg16()
    squeezenet = models.squeezenet1_0()
    densenet = models.densenet161()
    inception = models.inception_v3()
We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`.
These can be constructed by passing ``pretrained=True``:
.. code:: python

    import torchvision.models as models
    resnet18 = models.resnet18(pretrained=True)
    alexnet = models.alexnet(pretrained=True)
    squeezenet = models.squeezenet1_0(pretrained=True)
    vgg16 = models.vgg16(pretrained=True)
    densenet = models.densenet161(pretrained=True)
    inception = models.inception_v3(pretrained=True)
All pre-trained models expect input images normalized in the same way,
i.e. mini-batches of 3-channel RGB images of shape (3 x H x W),
where H and W are expected to be at least 224.
The images have to be loaded into a range of [0, 1] and then normalized
using ``mean = [0.485, 0.456, 0.406]`` and ``std = [0.229, 0.224, 0.225]``.
You can use the following transform to normalize::

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

An example of such normalization can be found in the imagenet example
`here <https://github.com/pytorch/examples/blob/42e5b996718797e45c46a25c55b031e6768f8440/imagenet/main.py#L89-L101>`_.
ImageNet 1-crop error rates (224x224)

================================ ============= =============
Network                          Top-1 error   Top-5 error
================================ ============= =============
ResNet-18                        30.24         10.92
ResNet-34                        26.70         8.58
ResNet-50                        23.85         7.13
ResNet-101                       22.63         6.44
ResNet-152                       21.69         5.94
Inception v3                     22.55         6.44
AlexNet                          43.45         20.91
VGG-11                           30.98         11.37
VGG-13                           30.07         10.75
VGG-16                           28.41         9.62
VGG-19                           27.62         9.12
VGG-11 with batch normalization  29.62         10.19
VGG-13 with batch normalization  28.45         9.63
VGG-16 with batch normalization  26.63         8.50
VGG-19 with batch normalization  25.76         8.15
SqueezeNet 1.0                   41.90         19.58
SqueezeNet 1.1                   41.81         19.38
Densenet-121                     25.35         7.83
Densenet-169                     24.00         7.00
Densenet-201                     22.80         6.43
Densenet-161                     22.35         6.20
================================ ============= =============
.. _AlexNet: https://arxiv.org/abs/1404.5997
.. _VGG: https://arxiv.org/abs/1409.1556
.. _ResNet: https://arxiv.org/abs/1512.03385
.. _SqueezeNet: https://arxiv.org/abs/1602.07360
.. _DenseNet: https://arxiv.org/abs/1608.06993
.. _Inception: https://arxiv.org/abs/1512.00567
"""
from .alexnet import *
from .resnet import *
from .vgg import *
...
@@ -17,7 +17,7 @@ model_urls = {
 def densenet121(pretrained=False, **kwargs):
     r"""Densenet-121 model from
-    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
+    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
     Args:
         pretrained (bool): If True, returns a model pre-trained on ImageNet
@@ -31,7 +31,7 @@ def densenet121(pretrained=False, **kwargs):
 def densenet169(pretrained=False, **kwargs):
     r"""Densenet-169 model from
-    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
+    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
     Args:
         pretrained (bool): If True, returns a model pre-trained on ImageNet
@@ -45,7 +45,7 @@ def densenet169(pretrained=False, **kwargs):
 def densenet201(pretrained=False, **kwargs):
     r"""Densenet-201 model from
-    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
+    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
     Args:
         pretrained (bool): If True, returns a model pre-trained on ImageNet
@@ -59,7 +59,7 @@ def densenet201(pretrained=False, **kwargs):
 def densenet161(pretrained=False, **kwargs):
     r"""Densenet-161 model from
-    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
+    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
     Args:
         pretrained (bool): If True, returns a model pre-trained on ImageNet
@@ -111,7 +111,7 @@ class _Transition(nn.Sequential):
 class DenseNet(nn.Module):
     r"""Densenet-BC model class, based on
-    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
+    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
     Args:
         growth_rate (int) - how many filters to add each layer (`k` in paper)
...
...@@ -30,12 +30,12 @@ def _is_numpy_image(img): ...@@ -30,12 +30,12 @@ def _is_numpy_image(img):
def to_tensor(pic): def to_tensor(pic):
"""Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor. """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
See ``ToTensor`` for more details. See ``ToTensor`` for more details.
Args: Args:
pic (PIL.Image or numpy.ndarray): Image to be converted to tensor. pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
Returns: Returns:
Tensor: Converted image. Tensor: Converted image.
...@@ -84,10 +84,10 @@ def to_pil_image(pic): ...@@ -84,10 +84,10 @@ def to_pil_image(pic):
See ``ToPILImage`` for more details. See ``ToPILImage`` for more details.
Args: Args:
pic (Tensor or numpy.ndarray): Image to be converted to PIL.Image. pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.
Returns: Returns:
PIL.Image: Image converted to PIL.Image. PIL Image: Image converted to PIL Image.
""" """
if not(_is_numpy_image(pic) or _is_tensor_image(pic)): if not(_is_numpy_image(pic) or _is_tensor_image(pic)):
raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic))) raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic)))
...@@ -143,10 +143,10 @@ def normalize(tensor, mean, std): ...@@ -143,10 +143,10 @@ def normalize(tensor, mean, std):
def resize(img, size, interpolation=Image.BILINEAR): def resize(img, size, interpolation=Image.BILINEAR):
"""Resize the input PIL.Image to the given size. """Resize the input PIL Image to the given size.
Args: Args:
img (PIL.Image): Image to be resized. img (PIL Image): Image to be resized.
size (sequence or int): Desired output size. If size is a sequence like size (sequence or int): Desired output size. If size is a sequence like
(h, w), the output size will be matched to this. If size is an int, (h, w), the output size will be matched to this. If size is an int,
the smaller edge of the image will be matched to this number maintaining the smaller edge of the image will be matched to this number maintaining
...@@ -156,7 +156,7 @@ def resize(img, size, interpolation=Image.BILINEAR): ...@@ -156,7 +156,7 @@ def resize(img, size, interpolation=Image.BILINEAR):
``PIL.Image.BILINEAR`` ``PIL.Image.BILINEAR``
Returns: Returns:
PIL.Image: Resized image. PIL Image: Resized image.
""" """
if not _is_pil_image(img): if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
...@@ -186,10 +186,10 @@ def scale(*args, **kwargs): ...@@ -186,10 +186,10 @@ def scale(*args, **kwargs):
def pad(img, padding, fill=0): def pad(img, padding, fill=0):
"""Pad the given PIL.Image on all sides with the given "pad" value. """Pad the given PIL Image on all sides with the given "pad" value.
Args: Args:
img (PIL.Image): Image to be padded. img (PIL Image): Image to be padded.
padding (int or tuple): Padding on each border. If a single int is provided this padding (int or tuple): Padding on each border. If a single int is provided this
is used to pad all borders. If tuple of length 2 is provided this is the padding is used to pad all borders. If tuple of length 2 is provided this is the padding
on left/right and top/bottom respectively. If a tuple of length 4 is provided on left/right and top/bottom respectively. If a tuple of length 4 is provided
...@@ -199,7 +199,7 @@ def pad(img, padding, fill=0): ...@@ -199,7 +199,7 @@ def pad(img, padding, fill=0):
length 3, it is used to fill R, G, B channels respectively. length 3, it is used to fill R, G, B channels respectively.
Returns: Returns:
PIL.Image: Padded image. PIL Image: Padded image.
""" """
if not _is_pil_image(img): if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
...@@ -217,17 +217,17 @@ def pad(img, padding, fill=0): ...@@ -217,17 +217,17 @@ def pad(img, padding, fill=0):
def crop(img, i, j, h, w): def crop(img, i, j, h, w):
"""Crop the given PIL.Image. """Crop the given PIL Image.
Args: Args:
img (PIL.Image): Image to be cropped. img (PIL Image): Image to be cropped.
i: Upper pixel coordinate. i: Upper pixel coordinate.
j: Left pixel coordinate. j: Left pixel coordinate.
h: Height of the cropped image. h: Height of the cropped image.
w: Width of the cropped image. w: Width of the cropped image.
Returns: Returns:
PIL.Image: Cropped image. PIL Image: Cropped image.
""" """
if not _is_pil_image(img): if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
...@@ -236,12 +236,12 @@ def crop(img, i, j, h, w): ...@@ -236,12 +236,12 @@ def crop(img, i, j, h, w):
def resized_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR): def resized_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR):
"""Crop the given PIL.Image and resize it to desired size. """Crop the given PIL Image and resize it to desired size.
Notably used in RandomResizedCrop. Notably used in RandomResizedCrop.
Args: Args:
img (PIL.Image): Image to be cropped. img (PIL Image): Image to be cropped.
i: Upper pixel coordinate. i: Upper pixel coordinate.
j: Left pixel coordinate. j: Left pixel coordinate.
h: Height of the cropped image. h: Height of the cropped image.
...@@ -250,7 +250,7 @@ def resized_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR): ...@@ -250,7 +250,7 @@ def resized_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR):
interpolation (int, optional): Desired interpolation. Default is interpolation (int, optional): Desired interpolation. Default is
``PIL.Image.BILINEAR``. ``PIL.Image.BILINEAR``.
Returns: Returns:
PIL.Image: Cropped image. PIL Image: Cropped image.
""" """
assert _is_pil_image(img), 'img should be PIL Image' assert _is_pil_image(img), 'img should be PIL Image'
img = crop(img, i, j, h, w) img = crop(img, i, j, h, w)
...@@ -259,13 +259,13 @@ def resized_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR): ...@@ -259,13 +259,13 @@ def resized_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR):
def hflip(img): def hflip(img):
"""Horizontally flip the given PIL.Image. """Horizontally flip the given PIL Image.
Args: Args:
img (PIL.Image): Image to be flipped. img (PIL Image): Image to be flipped.
Returns: Returns:
PIL.Image: Horizontally flipped image. PIL Image: Horizontally flipped image.
""" """
if not _is_pil_image(img): if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
...@@ -274,13 +274,13 @@ def hflip(img): ...@@ -274,13 +274,13 @@ def hflip(img):
def vflip(img): def vflip(img):
"""Vertically flip the given PIL.Image. """Vertically flip the given PIL Image.
Args: Args:
img (PIL.Image): Image to be flipped. img (PIL Image): Image to be flipped.
Returns: Returns:
PIL.Image: Vertically flipped image. PIL Image: Vertically flipped image.
""" """
if not _is_pil_image(img): if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
...@@ -289,10 +289,11 @@ def vflip(img): ...@@ -289,10 +289,11 @@ def vflip(img):
def five_crop(img, size): def five_crop(img, size):
"""Crop the given PIL.Image into four corners and the central crop. """Crop the given PIL Image into four corners and the central crop.
Note: this transform returns a tuple of images and there may be a mismatch in the number of .. Note::
inputs and targets your `Dataset` returns. This transform returns a tuple of images and there may be a
mismatch in the number of inputs and targets your ``Dataset`` returns.
Args: Args:
size (sequence or int): Desired output size of the crop. If size is an size (sequence or int): Desired output size of the crop. If size is an
...@@ -321,11 +322,12 @@ def five_crop(img, size): ...@@ -321,11 +322,12 @@ def five_crop(img, size):
def ten_crop(img, size, vertical_flip=False): def ten_crop(img, size, vertical_flip=False):
"""Crop the given PIL.Image into four corners and the central crop plus the """Crop the given PIL Image into four corners and the central crop plus the
flipped version of these (horizontal flipping is used by default). flipped version of these (horizontal flipping is used by default).
Note: this transform returns a tuple of images and there may be a mismatch in the number of .. Note::
inputs and targets your `Dataset` returns. This transform returns a tuple of images and there may be a
mismatch in the number of inputs and targets your ``Dataset`` returns.
Args: Args:
size (sequence or int): Desired output size of the crop. If size is an size (sequence or int): Desired output size of the crop. If size is an
...@@ -359,13 +361,13 @@ def adjust_brightness(img, brightness_factor): ...@@ -359,13 +361,13 @@ def adjust_brightness(img, brightness_factor):
"""Adjust brightness of an Image. """Adjust brightness of an Image.
Args: Args:
img (PIL.Image): PIL Image to be adjusted. img (PIL Image): PIL Image to be adjusted.
brightness_factor (float): How much to adjust the brightness. Can be brightness_factor (float): How much to adjust the brightness. Can be
any non negative number. 0 gives a black image, 1 gives the any non negative number. 0 gives a black image, 1 gives the
original image while 2 increases the brightness by a factor of 2. original image while 2 increases the brightness by a factor of 2.
Returns: Returns:
PIL.Image: Brightness adjusted image. PIL Image: Brightness adjusted image.
""" """
if not _is_pil_image(img): if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
...@@ -379,13 +381,13 @@ def adjust_contrast(img, contrast_factor): ...@@ -379,13 +381,13 @@ def adjust_contrast(img, contrast_factor):
"""Adjust contrast of an Image. """Adjust contrast of an Image.
Args: Args:
img (PIL.Image): PIL Image to be adjusted. img (PIL Image): PIL Image to be adjusted.
contrast_factor (float): How much to adjust the contrast. Can be any contrast_factor (float): How much to adjust the contrast. Can be any
non negative number. 0 gives a solid gray image, 1 gives the non negative number. 0 gives a solid gray image, 1 gives the
original image while 2 increases the contrast by a factor of 2. original image while 2 increases the contrast by a factor of 2.
Returns: Returns:
PIL.Image: Contrast adjusted image. PIL Image: Contrast adjusted image.
""" """
if not _is_pil_image(img): if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
...@@ -399,13 +401,13 @@ def adjust_saturation(img, saturation_factor): ...@@ -399,13 +401,13 @@ def adjust_saturation(img, saturation_factor):
"""Adjust color saturation of an image. """Adjust color saturation of an image.
Args: Args:
img (PIL.Image): PIL Image to be adjusted. img (PIL Image): PIL Image to be adjusted.
saturation_factor (float): How much to adjust the saturation. 0 will saturation_factor (float): How much to adjust the saturation. 0 will
give a black and white image, 1 will give the original image while give a black and white image, 1 will give the original image while
2 will enhance the saturation by a factor of 2. 2 will enhance the saturation by a factor of 2.
Returns: Returns:
PIL.Image: Saturation adjusted image. PIL Image: Saturation adjusted image.
""" """
if not _is_pil_image(img): if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
...@@ -428,7 +430,7 @@ def adjust_hue(img, hue_factor): ...@@ -428,7 +430,7 @@ def adjust_hue(img, hue_factor):
See https://en.wikipedia.org/wiki/Hue for more details on Hue. See https://en.wikipedia.org/wiki/Hue for more details on Hue.
Args: Args:
img (PIL.Image): PIL Image to be adjusted. img (PIL Image): PIL Image to be adjusted.
hue_factor (float): How much to shift the hue channel. Should be in hue_factor (float): How much to shift the hue channel. Should be in
[-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
HSV space in positive and negative direction respectively. HSV space in positive and negative direction respectively.
...@@ -436,7 +438,7 @@ def adjust_hue(img, hue_factor): ...@@ -436,7 +438,7 @@ def adjust_hue(img, hue_factor):
with complementary colors while 0 gives the original image. with complementary colors while 0 gives the original image.
Returns: Returns:
PIL.Image: Hue adjusted image. PIL Image: Hue adjusted image.
""" """
if not(-0.5 <= hue_factor <= 0.5): if not(-0.5 <= hue_factor <= 0.5):
raise ValueError('hue_factor is not in [-0.5, 0.5].'.format(hue_factor)) raise ValueError('hue_factor is not in [-0.5, 0.5].'.format(hue_factor))
...@@ -471,7 +473,7 @@ def adjust_gamma(img, gamma, gain=1): ...@@ -471,7 +473,7 @@ def adjust_gamma(img, gamma, gain=1):
See https://en.wikipedia.org/wiki/Gamma_correction for more details. See https://en.wikipedia.org/wiki/Gamma_correction for more details.
Args: Args:
img (PIL.Image): PIL Image to be adjusted. img (PIL Image): PIL Image to be adjusted.
gamma (float): Non negative real number. gamma larger than 1 make the gamma (float): Non negative real number. gamma larger than 1 make the
shadows darker, while gamma smaller than 1 make dark regions shadows darker, while gamma smaller than 1 make dark regions
lighter. lighter.
...@@ -517,16 +519,16 @@ class Compose(object): ...@@ -517,16 +519,16 @@ class Compose(object):
class ToTensor(object): class ToTensor(object):
"""Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor. """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
Converts a PIL.Image or numpy.ndarray (H x W x C) in the range Converts a PIL Image or numpy.ndarray (H x W x C) in the range
[0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]. [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
""" """
def __call__(self, pic): def __call__(self, pic):
""" """
Args: Args:
pic (PIL.Image or numpy.ndarray): Image to be converted to tensor. pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
Returns: Returns:
Tensor: Converted image. Tensor: Converted image.
...@@ -538,16 +540,16 @@ class ToPILImage(object): ...@@ -538,16 +540,16 @@ class ToPILImage(object):
"""Convert a tensor or an ndarray to PIL Image. """Convert a tensor or an ndarray to PIL Image.
Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
H x W x C to a PIL.Image while preserving the value range. H x W x C to a PIL Image while preserving the value range.
""" """
def __call__(self, pic): def __call__(self, pic):
""" """
Args: Args:
pic (Tensor or numpy.ndarray): Image to be converted to PIL.Image. pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.
Returns: Returns:
PIL.Image: Image converted to PIL.Image. PIL Image: Image converted to PIL Image.
""" """
return to_pil_image(pic) return to_pil_image(pic)
...@@ -582,7 +584,7 @@ class Normalize(object): ...@@ -582,7 +584,7 @@ class Normalize(object):
class Resize(object): class Resize(object):
"""Resize the input PIL.Image to the given size. """Resize the input PIL Image to the given size.
Args: Args:
size (sequence or int): Desired output size. If size is a sequence like size (sequence or int): Desired output size. If size is a sequence like
...@@ -602,15 +604,18 @@ class Resize(object): ...@@ -602,15 +604,18 @@ class Resize(object):
def __call__(self, img): def __call__(self, img):
""" """
Args: Args:
img (PIL.Image): Image to be scaled. img (PIL Image): Image to be scaled.
Returns: Returns:
PIL.Image: Rescaled image. PIL Image: Rescaled image.
""" """
return resize(img, self.size, self.interpolation) return resize(img, self.size, self.interpolation)
class Scale(Resize): class Scale(Resize):
"""
Note: This transform is deprecated in favor of Resize.
"""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
warnings.warn("The use of the transforms.Scale transform is deprecated, " + warnings.warn("The use of the transforms.Scale transform is deprecated, " +
"please use transforms.Resize instead.") "please use transforms.Resize instead.")
...@@ -618,7 +623,7 @@ class Scale(Resize): ...@@ -618,7 +623,7 @@ class Scale(Resize):
class CenterCrop(object): class CenterCrop(object):
"""Crops the given PIL.Image at the center. """Crops the given PIL Image at the center.
Args: Args:
size (sequence or int): Desired output size of the crop. If size is an size (sequence or int): Desired output size of the crop. If size is an
...@@ -637,7 +642,7 @@ class CenterCrop(object): ...@@ -637,7 +642,7 @@ class CenterCrop(object):
"""Get parameters for ``crop`` for center crop. """Get parameters for ``crop`` for center crop.
Args: Args:
img (PIL.Image): Image to be cropped. img (PIL Image): Image to be cropped.
output_size (tuple): Expected output size of the crop. output_size (tuple): Expected output size of the crop.
Returns: Returns:
...@@ -652,17 +657,17 @@ class CenterCrop(object): ...@@ -652,17 +657,17 @@ class CenterCrop(object):
def __call__(self, img): def __call__(self, img):
""" """
Args: Args:
img (PIL.Image): Image to be cropped. img (PIL Image): Image to be cropped.
Returns: Returns:
PIL.Image: Cropped image. PIL Image: Cropped image.
""" """
i, j, h, w = self.get_params(img, self.size) i, j, h, w = self.get_params(img, self.size)
return crop(img, i, j, h, w) return crop(img, i, j, h, w)
class Pad(object): class Pad(object):
"""Pad the given PIL.Image on all sides with the given "pad" value. """Pad the given PIL Image on all sides with the given "pad" value.
Args: Args:
padding (int or tuple): Padding on each border. If a single int is provided this padding (int or tuple): Padding on each border. If a single int is provided this
...@@ -687,10 +692,10 @@ class Pad(object): ...@@ -687,10 +692,10 @@ class Pad(object):
def __call__(self, img): def __call__(self, img):
""" """
Args: Args:
img (PIL.Image): Image to be padded. img (PIL Image): Image to be padded.
Returns: Returns:
PIL.Image: Padded image. PIL Image: Padded image.
""" """
return pad(img, self.padding, self.fill) return pad(img, self.padding, self.fill)
...@@ -711,7 +716,7 @@ class Lambda(object): ...@@ -711,7 +716,7 @@ class Lambda(object):
class RandomCrop(object): class RandomCrop(object):
"""Crop the given PIL.Image at a random location. """Crop the given PIL Image at a random location.
Args: Args:
size (sequence or int): Desired output size of the crop. If size is an size (sequence or int): Desired output size of the crop. If size is an
...@@ -735,7 +740,7 @@ class RandomCrop(object): ...@@ -735,7 +740,7 @@ class RandomCrop(object):
"""Get parameters for ``crop`` for a random crop. """Get parameters for ``crop`` for a random crop.
Args: Args:
img (PIL.Image): Image to be cropped. img (PIL Image): Image to be cropped.
output_size (tuple): Expected output size of the crop. output_size (tuple): Expected output size of the crop.
Returns: Returns:
...@@ -753,10 +758,10 @@ class RandomCrop(object): ...@@ -753,10 +758,10 @@ class RandomCrop(object):
def __call__(self, img): def __call__(self, img):
""" """
Args: Args:
img (PIL.Image): Image to be cropped. img (PIL Image): Image to be cropped.
Returns: Returns:
PIL.Image: Cropped image. PIL Image: Cropped image.
""" """
if self.padding > 0: if self.padding > 0:
img = pad(img, self.padding) img = pad(img, self.padding)
...@@ -767,15 +772,15 @@ class RandomCrop(object): ...@@ -767,15 +772,15 @@ class RandomCrop(object):
class RandomHorizontalFlip(object): class RandomHorizontalFlip(object):
"""Horizontally flip the given PIL.Image randomly with a probability of 0.5.""" """Horizontally flip the given PIL Image randomly with a probability of 0.5."""
def __call__(self, img): def __call__(self, img):
""" """
Args: Args:
img (PIL.Image): Image to be flipped. img (PIL Image): Image to be flipped.
Returns: Returns:
PIL.Image: Randomly flipped image. PIL Image: Randomly flipped image.
""" """
if random.random() < 0.5: if random.random() < 0.5:
return hflip(img) return hflip(img)
...@@ -783,15 +788,15 @@ class RandomHorizontalFlip(object): ...@@ -783,15 +788,15 @@ class RandomHorizontalFlip(object):
class RandomVerticalFlip(object): class RandomVerticalFlip(object):
"""Vertically flip the given PIL.Image randomly with a probability of 0.5.""" """Vertically flip the given PIL Image randomly with a probability of 0.5."""
def __call__(self, img): def __call__(self, img):
""" """
Args: Args:
img (PIL.Image): Image to be flipped. img (PIL Image): Image to be flipped.
Returns: Returns:
PIL.Image: Randomly flipped image. PIL Image: Randomly flipped image.
""" """
if random.random() < 0.5: if random.random() < 0.5:
return vflip(img) return vflip(img)
...@@ -799,7 +804,7 @@ class RandomVerticalFlip(object): ...@@ -799,7 +804,7 @@ class RandomVerticalFlip(object):
class RandomResizedCrop(object): class RandomResizedCrop(object):
"""Crop the given PIL.Image to random size and aspect ratio. """Crop the given PIL Image to random size and aspect ratio.
A crop of random size of (0.08 to 1.0) of the original size and a random A crop of random size of (0.08 to 1.0) of the original size and a random
aspect ratio of 3/4 to 4/3 of the original aspect ratio is made. This crop aspect ratio of 3/4 to 4/3 of the original aspect ratio is made. This crop
...@@ -820,7 +825,7 @@ class RandomResizedCrop(object): ...@@ -820,7 +825,7 @@ class RandomResizedCrop(object):
"""Get parameters for ``crop`` for a random sized crop. """Get parameters for ``crop`` for a random sized crop.
Args: Args:
img (PIL.Image): Image to be cropped. img (PIL Image): Image to be cropped.
Returns: Returns:
tuple: params (i, j, h, w) to be passed to ``crop`` for a random tuple: params (i, j, h, w) to be passed to ``crop`` for a random
...@@ -851,16 +856,19 @@ class RandomResizedCrop(object): ...@@ -851,16 +856,19 @@ class RandomResizedCrop(object):
def __call__(self, img): def __call__(self, img):
""" """
Args: Args:
img (PIL.Image): Image to be flipped. img (PIL Image): Image to be flipped.
Returns: Returns:
PIL.Image: Randomly cropped and resized image. PIL Image: Randomly cropped and resized image.
""" """
i, j, h, w = self.get_params(img) i, j, h, w = self.get_params(img)
return resized_crop(img, i, j, h, w, self.size, self.interpolation) return resized_crop(img, i, j, h, w, self.size, self.interpolation)
class RandomSizedCrop(RandomResizedCrop): class RandomSizedCrop(RandomResizedCrop):
"""
Note: This transform is deprecated in favor of RandomResizedCrop.
"""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
warnings.warn("The use of the transforms.RandomSizedCrop transform is deprecated, " + warnings.warn("The use of the transforms.RandomSizedCrop transform is deprecated, " +
"please use transforms.RandomResizedCrop instead.") "please use transforms.RandomResizedCrop instead.")
...@@ -868,7 +876,7 @@ class RandomSizedCrop(RandomResizedCrop): ...@@ -868,7 +876,7 @@ class RandomSizedCrop(RandomResizedCrop):
class FiveCrop(object): class FiveCrop(object):
"""Crop the given PIL.Image into four corners and the central crop.abs """Crop the given PIL Image into four corners and the central crop.abs
Note: this transform returns a tuple of images and there may be a mismatch in the number of Note: this transform returns a tuple of images and there may be a mismatch in the number of
inputs and targets your `Dataset` returns. inputs and targets your `Dataset` returns.
...@@ -892,7 +900,7 @@ class FiveCrop(object): ...@@ -892,7 +900,7 @@ class FiveCrop(object):
class TenCrop(object): class TenCrop(object):
"""Crop the given PIL.Image into four corners and the central crop plus the """Crop the given PIL Image into four corners and the central crop plus the
flipped version of these (horizontal flipping is used by default) flipped version of these (horizontal flipping is used by default)
Note: this transform returns a tuple of images and there may be a mismatch in the number of Note: this transform returns a tuple of images and there may be a mismatch in the number of
...@@ -1013,10 +1021,10 @@ class ColorJitter(object): ...@@ -1013,10 +1021,10 @@ class ColorJitter(object):
def __call__(self, img): def __call__(self, img):
""" """
Args: Args:
img (PIL.Image): Input image. img (PIL Image): Input image.
Returns: Returns:
PIL.Image: Color jittered image. PIL Image: Color jittered image.
""" """
transform = self.get_params(self.brightness, self.contrast, transform = self.get_params(self.brightness, self.contrast,
self.saturation, self.hue) self.saturation, self.hue)
......