"torchvision/csrc/io/image/cpu/decode_jpeg.cpp" did not exist on "74de51d6d478e289135d9274e6af550a9bfba137"
Commit bf491463 authored by limm's avatar limm
Browse files

add v0.19.1 release

parent e17f5ea2
@@ -24,7 +24,7 @@ docset: html
convert $(SPHINXPROJ).docset/icon@2x.png -resize 16x16 $(SPHINXPROJ).docset/icon.png
html-noplot: # Avoids running the gallery examples, which may take time
$(SPHINXBUILD) -D plot_gallery=0 -b html $(ASPHINXOPTS) "${SOURCEDIR}" "$(BUILDDIR)"/html
$(SPHINXBUILD) -D plot_gallery=0 -b html "${SOURCEDIR}" "$(BUILDDIR)"/html
@echo
@echo "Build finished. The HTML pages are in $(BUILDDIR)/html."
@@ -32,6 +32,8 @@ clean:
rm -rf $(BUILDDIR)/*
rm -rf $(SOURCEDIR)/auto_examples/ # sphinx-gallery
rm -rf $(SOURCEDIR)/gen_modules/ # sphinx-gallery
rm -rf $(SOURCEDIR)/generated/ # autosummary
rm -rf $(SOURCEDIR)/models/generated # autosummary
.PHONY: help Makefile docset
......
sphinx==2.4.4
sphinx-gallery>=0.9.0
sphinx-copybutton>=0.3.1
matplotlib
numpy
-e git+git://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
sphinx-copybutton>=0.3.1
sphinx-gallery>=0.11.1
sphinx==5.0.0
tabulate
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
pycocotools
/* This rule (and possibly this entire file) should be removed once
/* This rule should be removed once
https://github.com/pytorch/pytorch_sphinx_theme/issues/125 is fixed.
We override the rule so that the links to the notebooks aren't hidden in the
@@ -9,4 +9,27 @@ torchvision it just hides the links. So we have to put them back here */
article.pytorch-article .sphx-glr-download-link-note.admonition.note,
article.pytorch-article .reference.download.internal, article.pytorch-article .sphx-glr-signature {
display: block;
}
\ No newline at end of file
}
/* These 2 rules below are for the weight tables (generated in conf.py) to look
* better. In particular we make their row height shorter */
.table-weights td, .table-weights th {
margin-bottom: 0.2rem;
padding: 0 !important;
line-height: 1 !important;
}
.table-weights p {
margin-bottom: 0.2rem !important;
}
/* Fix for Sphinx gallery 0.11
See https://github.com/sphinx-gallery/sphinx-gallery/issues/990
*/
article.pytorch-article .sphx-glr-thumbnails .sphx-glr-thumbcontainer {
width: unset;
margin-right: 0;
margin-left: 0;
}
article.pytorch-article div.section div.wy-table-responsive tbody td {
width: 50%;
}
@@ -30,4 +30,4 @@
style="fill:#9e529f"
id="path4698"
d="m 24.075479,-7.6293945e-7 c -0.5,0 -1.8,2.49999996293945 -1.8,3.59999996293945 0,1.5 1,2 1.8,2 0.8,0 1.8,-0.5 1.8,-2 -0.1,-1.1 -1.4,-3.59999996293945 -1.8,-3.59999996293945 z"
class="st1" /></svg>
\ No newline at end of file
class="st1" /></svg>
.. role:: hidden
:class: hidden-section
.. currentmodule:: {{ module }}
{{ name | underline}}
.. autoclass:: {{ name }}
:members:
.. role:: hidden
:class: hidden-section
.. currentmodule:: {{ module }}
{{ name | underline}}
.. autoclass:: {{ name }}
:members:
__getitem__,
{% if "category_name" in methods %} category_name {% endif %}
:special-members:
.. role:: hidden
:class: hidden-section
.. currentmodule:: {{ module }}
{{ name | underline}}
.. autofunction:: {{ name }}
from docutils import nodes
from docutils.parsers.rst import Directive
class BetaStatus(Directive):
has_content = True
text = "The {api_name} is in Beta stage, and backward compatibility is not guaranteed."
node = nodes.warning
def run(self):
text = self.text.format(api_name=" ".join(self.content))
return [self.node("", nodes.paragraph("", "", nodes.Text(text)))]
def setup(app):
app.add_directive("betastatus", BetaStatus)
return {
"version": "0.1",
"parallel_read_safe": True,
"parallel_write_safe": True,
}
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# PyTorch documentation build configuration file, created by
# sphinx-quickstart on Fri Dec 23 13:31:47 2016.
@@ -21,79 +20,146 @@
# import sys
# sys.path.insert(0, os.path.abspath('.'))
import torchvision
import os
import sys
import textwrap
from copy import copy
from pathlib import Path
import pytorch_sphinx_theme
import torchvision
import torchvision.models as M
from sphinx_gallery.sorting import ExplicitOrder
from tabulate import tabulate
sys.path.append(os.path.abspath("."))
# -- General configuration ------------------------------------------------
# If your documentation needs a minimal Sphinx version, state it here.
#
# needs_sphinx = '1.0'
# Required version of sphinx is set from docs/requirements.txt
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.autosummary',
'sphinx.ext.doctest',
'sphinx.ext.intersphinx',
'sphinx.ext.todo',
'sphinx.ext.coverage',
'sphinx.ext.mathjax',
'sphinx.ext.napoleon',
'sphinx.ext.viewcode',
'sphinx.ext.duration',
'sphinx_gallery.gen_gallery',
"sphinx_copybutton"
"sphinx.ext.autodoc",
"sphinx.ext.autosummary",
"sphinx.ext.doctest",
"sphinx.ext.intersphinx",
"sphinx.ext.todo",
"sphinx.ext.mathjax",
"sphinx.ext.napoleon",
"sphinx.ext.viewcode",
"sphinx.ext.duration",
"sphinx_gallery.gen_gallery",
"sphinx_copybutton",
"beta_status",
]
# We override sphinx-gallery's example header to prevent sphinx-gallery from
# creating a note at the top of the rendered notebook.
# https://github.com/sphinx-gallery/sphinx-gallery/blob/451ccba1007cc523f39cbcc960ebc21ca39f7b75/sphinx_gallery/gen_rst.py#L1267-L1271
# This is because we also want to add a link to Google Colab, so we write our own note in each example.
from sphinx_gallery import gen_rst
gen_rst.EXAMPLE_HEADER = """
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "{0}"
.. LINE NUMBERS ARE GIVEN BELOW.
.. rst-class:: sphx-glr-example-title
.. _sphx_glr_{1}:
"""
class CustomGalleryExampleSortKey:
# See https://sphinx-gallery.github.io/stable/configuration.html#sorting-gallery-examples
# and https://github.com/sphinx-gallery/sphinx-gallery/blob/master/sphinx_gallery/sorting.py
def __init__(self, src_dir):
self.src_dir = src_dir
transforms_subsection_order = [
"plot_transforms_getting_started.py",
"plot_transforms_illustrations.py",
"plot_transforms_e2e.py",
"plot_cutmix_mixup.py",
"plot_custom_transforms.py",
"plot_tv_tensors.py",
"plot_custom_tv_tensors.py",
]
def __call__(self, filename):
if "gallery/transforms" in self.src_dir:
try:
return self.transforms_subsection_order.index(filename)
except ValueError as e:
raise ValueError(
"Looks like you added an example in gallery/transforms? "
"You need to specify its order in docs/source/conf.py. Look for CustomGalleryExampleSortKey."
) from e
else:
# For other subsections we just sort alphabetically by filename
return filename
sphinx_gallery_conf = {
'examples_dirs': '../../gallery/', # path to your example scripts
'gallery_dirs': 'auto_examples', # path to where to save gallery generated output
'backreferences_dir': 'gen_modules/backreferences',
'doc_module': ('torchvision',),
"examples_dirs": "../../gallery/", # path to your example scripts
"gallery_dirs": "auto_examples", # path to where to save gallery generated output
"subsection_order": ExplicitOrder(["../../gallery/transforms", "../../gallery/others"]),
"backreferences_dir": "gen_modules/backreferences",
"doc_module": ("torchvision",),
"remove_config_comments": True,
"ignore_pattern": "helpers.py",
"within_subsection_order": CustomGalleryExampleSortKey,
}
napoleon_use_ivar = True
napoleon_numpy_docstring = False
napoleon_google_docstring = True
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
templates_path = ["_templates"]
# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
# source_suffix = ['.rst', '.md']
source_suffix = '.rst'
source_suffix = {
".rst": "restructuredtext",
}
# The master toctree document.
master_doc = 'index'
master_doc = "index"
# General information about the project.
project = 'Torchvision'
copyright = '2017-present, Torch Contributors'
author = 'Torch Contributors'
project = "Torchvision"
copyright = "2017-present, Torch Contributors"
author = "Torch Contributors"
# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
# built documents.
#
# The short X.Y version.
# TODO: change to [:2] at v1.0
version = '0.10.0'
# The full version, including alpha/beta/rc tags.
# TODO: verify this works as expected
release = torchvision.__version__
# version: The short X.Y version.
# release: The full version, including alpha/beta/rc tags.
if os.environ.get("TORCHVISION_SANITIZE_VERSION_STR_IN_DOCS", None):
# Turn 1.11.0aHASH into 1.11 (major.minor only)
version = release = ".".join(torchvision.__version__.split(".")[:2])
html_title = " ".join((project, version, "documentation"))
else:
version = f"main ({torchvision.__version__})"
release = "main"
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
language = None
language = "en"
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
@@ -101,7 +167,7 @@ language = None
exclude_patterns = []
# The name of the Pygments (syntax highlighting) style to use.
pygments_style = 'sphinx'
pygments_style = "sphinx"
# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = True
@@ -112,7 +178,7 @@ todo_include_todos = True
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'pytorch_sphinx_theme'
html_theme = "pytorch_sphinx_theme"
html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()]
# Theme options are theme-specific and customize the look and feel of a theme
@@ -120,58 +186,57 @@ html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()]
# documentation.
#
html_theme_options = {
'collapse_navigation': False,
'display_version': True,
'logo_only': True,
'pytorch_project': 'docs',
'navigation_with_keys': True,
'analytics_id': 'UA-117752657-2',
"collapse_navigation": False,
"display_version": True,
"logo_only": True,
"pytorch_project": "docs",
"navigation_with_keys": True,
"analytics_id": "GTM-T8XT4PS",
}
html_logo = '_static/img/pytorch-logo-dark.svg'
html_logo = "_static/img/pytorch-logo-dark.svg"
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
html_static_path = ["_static"]
# TODO: remove this once https://github.com/pytorch/pytorch_sphinx_theme/issues/125 is fixed
html_css_files = [
'css/custom_torchvision.css',
"css/custom_torchvision.css",
]
# -- Options for HTMLHelp output ------------------------------------------
# Output file base name for HTML help builder.
htmlhelp_basename = 'PyTorchdoc'
htmlhelp_basename = "PyTorchdoc"
# -- Options for LaTeX output ---------------------------------------------
autosummary_generate = True
# -- Options for LaTeX output ---------------------------------------------
latex_elements = {
# The paper size ('letterpaper' or 'a4paper').
#
# 'papersize': 'letterpaper',
# The font size ('10pt', '11pt' or '12pt').
#
# 'pointsize': '10pt',
# Additional stuff for the LaTeX preamble.
#
# 'preamble': '',
# Latex figure (float) alignment
#
# 'figure_align': 'htbp',
}
# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
(master_doc, 'pytorch.tex', 'torchvision Documentation',
'Torch Contributors', 'manual'),
(master_doc, "pytorch.tex", "torchvision Documentation", "Torch Contributors", "manual"),
]
@@ -179,10 +244,7 @@ latex_documents = [
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
(master_doc, 'torchvision', 'torchvision Documentation',
[author], 1)
]
man_pages = [(master_doc, "torchvision", "torchvision Documentation", [author], 1)]
# -- Options for Texinfo output -------------------------------------------
@@ -191,27 +253,33 @@
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
(master_doc, 'torchvision', 'torchvision Documentation',
author, 'torchvision', 'One line description of project.',
'Miscellaneous'),
(
master_doc,
"torchvision",
"torchvision Documentation",
author,
"torchvision",
"One line description of project.",
"Miscellaneous",
),
]
# Example configuration for intersphinx: refer to the Python standard library.
intersphinx_mapping = {
'python': ('https://docs.python.org/', None),
'torch': ('https://pytorch.org/docs/stable/', None),
'numpy': ('http://docs.scipy.org/doc/numpy/', None),
'PIL': ('https://pillow.readthedocs.io/en/stable/', None),
'matplotlib': ('https://matplotlib.org/stable/', None),
"python": ("https://docs.python.org/3/", None),
"torch": ("https://pytorch.org/docs/stable/", None),
"numpy": ("https://numpy.org/doc/stable/", None),
"PIL": ("https://pillow.readthedocs.io/en/stable/", None),
"matplotlib": ("https://matplotlib.org/stable/", None),
}
# -- A patch that prevents Sphinx from cross-referencing ivar tags -------
# See http://stackoverflow.com/a/41184353/3343043
from docutils import nodes
from sphinx.util.docfields import TypedField
from sphinx import addnodes
from sphinx.util.docfields import TypedField
def patched_make_field(self, types, domain, items, **kw):
@@ -221,40 +289,39 @@
# type: (list, unicode, tuple) -> nodes.field # noqa: F821
def handle_item(fieldarg, content):
par = nodes.paragraph()
par += addnodes.literal_strong('', fieldarg) # Patch: this line added
par += addnodes.literal_strong("", fieldarg) # Patch: this line added
# par.extend(self.make_xrefs(self.rolename, domain, fieldarg,
# addnodes.literal_strong))
if fieldarg in types:
par += nodes.Text(' (')
par += nodes.Text(" (")
# NOTE: using .pop() here to prevent a single type node to be
# inserted twice into the doctree, which leads to
# inconsistencies later when references are resolved
fieldtype = types.pop(fieldarg)
if len(fieldtype) == 1 and isinstance(fieldtype[0], nodes.Text):
typename = u''.join(n.astext() for n in fieldtype)
typename = typename.replace('int', 'python:int')
typename = typename.replace('long', 'python:long')
typename = typename.replace('float', 'python:float')
typename = typename.replace('type', 'python:type')
par.extend(self.make_xrefs(self.typerolename, domain, typename,
addnodes.literal_emphasis, **kw))
typename = "".join(n.astext() for n in fieldtype)
typename = typename.replace("int", "python:int")
typename = typename.replace("long", "python:long")
typename = typename.replace("float", "python:float")
typename = typename.replace("type", "python:type")
par.extend(self.make_xrefs(self.typerolename, domain, typename, addnodes.literal_emphasis, **kw))
else:
par += fieldtype
par += nodes.Text(')')
par += nodes.Text(' -- ')
par += nodes.Text(")")
par += nodes.Text(" -- ")
par += content
return par
fieldname = nodes.field_name('', self.label)
fieldname = nodes.field_name("", self.label)
if len(items) == 1 and self.can_collapse:
fieldarg, content = items[0]
bodynode = handle_item(fieldarg, content)
else:
bodynode = self.list_type()
for fieldarg, content in items:
bodynode += nodes.list_item('', handle_item(fieldarg, content))
fieldbody = nodes.field_body('', bodynode)
return nodes.field('', fieldname, fieldbody)
bodynode += nodes.list_item("", handle_item(fieldarg, content))
fieldbody = nodes.field_body("", bodynode)
return nodes.field("", fieldname, fieldbody)
TypedField.make_field = patched_make_field
@@ -286,5 +353,172 @@ def inject_minigalleries(app, what, name, obj, options, lines):
lines.append("\n")
def inject_weight_metadata(app, what, name, obj, options, lines):
"""This hook is used to generate docs for the models weights.
Objects like ResNet18_Weights are enums with fields, where each field is a Weight object.
Enums aren't easily documented in Python so the solution we're going for is to:
- add an autoclass directive in the model's builder docstring, e.g.
```
.. autoclass:: torchvision.models.ResNet34_Weights
:members:
```
(see resnet.py for an example)
- then this hook is called automatically when building the docs, and it generates the text that gets
used within the autoclass directive.
"""
if getattr(obj, "__name__", "").endswith(("_Weights", "_QuantizedWeights")):
if len(obj) == 0:
lines[:] = ["There are no available pre-trained weights."]
return
lines[:] = [
"The model builder above accepts the following values as the ``weights`` parameter.",
f"``{obj.__name__}.DEFAULT`` is equivalent to ``{obj.DEFAULT}``. You can also use strings, e.g. "
f"``weights='DEFAULT'`` or ``weights='{str(list(obj)[0]).split('.')[1]}'``.",
]
if obj.__doc__ != "An enumeration.":
# We only show the custom enum doc if it was overridden. The default one from Python is "An enumeration"
lines.append("")
lines.append(obj.__doc__)
lines.append("")
for field in obj:
meta = copy(field.meta)
lines += [f"**{str(field)}**:", ""]
lines += [meta.pop("_docs")]
if field == obj.DEFAULT:
lines += [f"Also available as ``{obj.__name__}.DEFAULT``."]
lines += [""]
table = []
metrics = meta.pop("_metrics")
for dataset, dataset_metrics in metrics.items():
for metric_name, metric_value in dataset_metrics.items():
table.append((f"{metric_name} (on {dataset})", str(metric_value)))
for k, v in meta.items():
if k in {"recipe", "license"}:
v = f"`link <{v}>`__"
elif k == "min_size":
v = f"height={v[0]}, width={v[1]}"
elif k in {"categories", "keypoint_names"} and isinstance(v, list):
max_visible = 3
v_sample = ", ".join(v[:max_visible])
v = f"{v_sample}, ... ({len(v)-max_visible} omitted)" if len(v) > max_visible else v_sample
elif k == "_ops":
v = f"{v:.2f}"
k = "GIPS" if obj.__name__.endswith("_QuantizedWeights") else "GFLOPS"
elif k == "_file_size":
k = "File size"
v = f"{v:.1f} MB"
table.append((str(k), str(v)))
table = tabulate(table, tablefmt="rst")
lines += [".. rst-class:: table-weights"] # Custom CSS class, see custom_torchvision.css
lines += [".. table::", ""]
lines += textwrap.indent(table, " " * 4).split("\n")
lines.append("")
lines.append(
f"The inference transforms are available at ``{str(field)}.transforms`` and "
f"perform the following preprocessing operations: {field.transforms().describe()}"
)
lines.append("")
def generate_weights_table(module, table_name, metrics, dataset, include_patterns=None, exclude_patterns=None):
weights_endswith = "_QuantizedWeights" if module.__name__.split(".")[-1] == "quantization" else "_Weights"
weight_enums = [getattr(module, name) for name in dir(module) if name.endswith(weights_endswith)]
weights = [w for weight_enum in weight_enums for w in weight_enum]
if include_patterns is not None:
weights = [w for w in weights if any(p in str(w) for p in include_patterns)]
if exclude_patterns is not None:
weights = [w for w in weights if all(p not in str(w) for p in exclude_patterns)]
ops_name = "GIPS" if "QuantizedWeights" in weights_endswith else "GFLOPS"
metrics_keys, metrics_names = zip(*metrics)
column_names = ["Weight"] + list(metrics_names) + ["Params"] + [ops_name, "Recipe"] # Final column order
column_names = [f"**{name}**" for name in column_names] # Add bold
content = []
for w in weights:
row = [
f":class:`{w} <{type(w).__name__}>`",
*(w.meta["_metrics"][dataset][metric] for metric in metrics_keys),
f"{w.meta['num_params']/1e6:.1f}M",
f"{w.meta['_ops']:.2f}",
f"`link <{w.meta['recipe']}>`__",
]
content.append(row)
column_widths = ["110"] + ["18"] * len(metrics_names) + ["18"] * 2 + ["10"]
widths_table = " ".join(column_widths)
table = tabulate(content, headers=column_names, tablefmt="rst")
generated_dir = Path("generated")
generated_dir.mkdir(exist_ok=True)
with open(generated_dir / f"{table_name}_table.rst", "w+") as table_file:
table_file.write(".. rst-class:: table-weights\n") # Custom CSS class, see custom_torchvision.css
table_file.write(".. table::\n")
table_file.write(f" :widths: {widths_table} \n\n")
table_file.write(f"{textwrap.indent(table, ' ' * 4)}\n\n")
generate_weights_table(
module=M, table_name="classification", metrics=[("acc@1", "Acc@1"), ("acc@5", "Acc@5")], dataset="ImageNet-1K"
)
generate_weights_table(
module=M.quantization,
table_name="classification_quant",
metrics=[("acc@1", "Acc@1"), ("acc@5", "Acc@5")],
dataset="ImageNet-1K",
)
generate_weights_table(
module=M.detection,
table_name="detection",
metrics=[("box_map", "Box MAP")],
exclude_patterns=["Mask", "Keypoint"],
dataset="COCO-val2017",
)
generate_weights_table(
module=M.detection,
table_name="instance_segmentation",
metrics=[("box_map", "Box MAP"), ("mask_map", "Mask MAP")],
dataset="COCO-val2017",
include_patterns=["Mask"],
)
generate_weights_table(
module=M.detection,
table_name="detection_keypoint",
metrics=[("box_map", "Box MAP"), ("kp_map", "Keypoint MAP")],
dataset="COCO-val2017",
include_patterns=["Keypoint"],
)
generate_weights_table(
module=M.segmentation,
table_name="segmentation",
metrics=[("miou", "Mean IoU"), ("pixel_acc", "pixelwise Acc")],
dataset="COCO-val2017-VOC-labels",
)
generate_weights_table(
module=M.video, table_name="video", metrics=[("acc@1", "Acc@1"), ("acc@5", "Acc@5")], dataset="Kinetics-400"
)
def setup(app):
app.connect('autodoc-process-docstring', inject_minigalleries)
app.connect("autodoc-process-docstring", inject_minigalleries)
app.connect("autodoc-process-docstring", inject_weight_metadata)
torchvision.datasets
====================
.. _datasets:
Datasets
========
Torchvision provides many built-in datasets in the ``torchvision.datasets``
module, as well as utility classes for building your own datasets.
Built-in datasets
-----------------
All datasets are subclasses of :class:`torch.utils.data.Dataset`
i.e., they have ``__getitem__`` and ``__len__`` methods implemented.
@@ -19,242 +27,157 @@ All the datasets have almost similar API. They all have two common arguments:
``transform`` and ``target_transform`` to transform the input and target respectively.
You can also create your own datasets using the provided :ref:`base classes <base_classes_datasets>`.
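Since every dataset is a standard :class:`torch.utils.data.Dataset`, it can be passed
directly to a :class:`torch.utils.data.DataLoader`. A minimal sketch, assuming a local
``data/`` directory and using ``CIFAR10`` purely for illustration:

.. code:: python

    from torch.utils.data import DataLoader
    from torchvision import datasets, transforms

    # Any of the built-in datasets listed below could be used here instead.
    dataset = datasets.CIFAR10(
        root="data/", train=True, download=True, transform=transforms.ToTensor()
    )
    loader = DataLoader(dataset, batch_size=64, shuffle=True)
    images, targets = next(iter(loader))  # images: [64, 3, 32, 32], targets: [64]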
Caltech
~~~~~~~
.. autoclass:: Caltech101
:members: __getitem__
:special-members:
.. autoclass:: Caltech256
:members: __getitem__
:special-members:
CelebA
~~~~~~
.. autoclass:: CelebA
:members: __getitem__
:special-members:
CIFAR
~~~~~
.. autoclass:: CIFAR10
:members: __getitem__
:special-members:
.. autoclass:: CIFAR100
Cityscapes
~~~~~~~~~~
.. note ::
Requires Cityscape to be downloaded.
.. autoclass:: Cityscapes
:members: __getitem__
:special-members:
COCO
~~~~
.. note ::
These require the `COCO API to be installed`_
.. _COCO API to be installed: https://github.com/pdollar/coco/tree/master/PythonAPI
Captions
^^^^^^^^
.. autoclass:: CocoCaptions
:members: __getitem__
:special-members:
Detection
^^^^^^^^^
.. autoclass:: CocoDetection
:members: __getitem__
:special-members:
EMNIST
~~~~~~
.. autoclass:: EMNIST
FakeData
~~~~~~~~
.. autoclass:: FakeData
Fashion-MNIST
~~~~~~~~~~~~~
.. autoclass:: FashionMNIST
Flickr
~~~~~~
.. autoclass:: Flickr8k
:members: __getitem__
:special-members:
.. autoclass:: Flickr30k
:members: __getitem__
:special-members:
HMDB51
~~~~~~~
.. autoclass:: HMDB51
:members: __getitem__
:special-members:
ImageNet
~~~~~~~~~~~
.. autoclass:: ImageNet
.. note ::
This requires `scipy` to be installed
Kinetics-400
Image classification
~~~~~~~~~~~~~~~~~~~~
.. autosummary::
:toctree: generated/
:template: class_dataset.rst
Caltech101
Caltech256
CelebA
CIFAR10
CIFAR100
Country211
DTD
EMNIST
EuroSAT
FakeData
FashionMNIST
FER2013
FGVCAircraft
Flickr8k
Flickr30k
Flowers102
Food101
GTSRB
INaturalist
ImageNet
Imagenette
KMNIST
LFWPeople
LSUN
MNIST
Omniglot
OxfordIIITPet
Places365
PCAM
QMNIST
RenderedSST2
SEMEION
SBU
StanfordCars
STL10
SUN397
SVHN
USPS
Image detection or segmentation
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autosummary::
:toctree: generated/
:template: class_dataset.rst
CocoDetection
CelebA
Cityscapes
Kitti
OxfordIIITPet
SBDataset
VOCSegmentation
VOCDetection
WIDERFace
Optical Flow
~~~~~~~~~~~~
.. autoclass:: Kinetics400
:members: __getitem__
:special-members:
KITTI
~~~~~~~~~
.. autoclass:: Kitti
:members: __getitem__
:special-members:
KMNIST
~~~~~~~~~~~~~
.. autoclass:: KMNIST
LSUN
~~~~
.. autoclass:: LSUN
:members: __getitem__
:special-members:
MNIST
~~~~~
.. autoclass:: MNIST
Omniglot
~~~~~~~~
.. autoclass:: Omniglot
PhotoTour
~~~~~~~~~
.. autoclass:: PhotoTour
:members: __getitem__
:special-members:
Places365
~~~~~~~~~
.. autoclass:: Places365
:members: __getitem__
:special-members:
QMNIST
~~~~~~
.. autoclass:: QMNIST
SBD
~~~~~~
.. autoclass:: SBDataset
:members: __getitem__
:special-members:
SBU
~~~
.. autoclass:: SBU
:members: __getitem__
:special-members:
SEMEION
~~~~~~~
.. autoclass:: SEMEION
:members: __getitem__
:special-members:
STL10
~~~~~
.. autoclass:: STL10
:members: __getitem__
:special-members:
SVHN
~~~~~
.. autosummary::
:toctree: generated/
:template: class_dataset.rst
FlyingChairs
FlyingThings3D
HD1K
KittiFlow
Sintel
Stereo Matching
~~~~~~~~~~~~~~~
.. autosummary::
:toctree: generated/
:template: class_dataset.rst
CarlaStereo
Kitti2012Stereo
Kitti2015Stereo
CREStereo
FallingThingsStereo
SceneFlowStereo
SintelStereo
InStereo2k
ETH3DStereo
Middlebury2014Stereo
Image pairs
~~~~~~~~~~~
.. autoclass:: SVHN
:members: __getitem__
:special-members:
.. autosummary::
:toctree: generated/
:template: class_dataset.rst
UCF101
~~~~~~~
LFWPairs
PhotoTour
.. autoclass:: UCF101
:members: __getitem__
:special-members:
Image captioning
~~~~~~~~~~~~~~~~
USPS
~~~~~
.. autosummary::
:toctree: generated/
:template: class_dataset.rst
.. autoclass:: USPS
:members: __getitem__
:special-members:
CocoCaptions
VOC
~~~~~~
Video classification
~~~~~~~~~~~~~~~~~~~~
.. autoclass:: VOCSegmentation
:members: __getitem__
:special-members:
.. autosummary::
:toctree: generated/
:template: class_dataset.rst
.. autoclass:: VOCDetection
:members: __getitem__
:special-members:
HMDB51
Kinetics
UCF101
WIDERFace
~~~~~~~~~
Video prediction
~~~~~~~~~~~~~~~~~~~~
.. autoclass:: WIDERFace
:members: __getitem__
:special-members:
.. autosummary::
:toctree: generated/
:template: class_dataset.rst
MovingMNIST
.. _base_classes_datasets:
Base classes for custom datasets
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
--------------------------------
.. autosummary::
:toctree: generated/
:template: class.rst
DatasetFolder
ImageFolder
VisionDataset
.. autoclass:: DatasetFolder
:members: __getitem__, find_classes, make_dataset
:special-members:
Transforms v2
-------------
.. autosummary::
:toctree: generated/
:template: function.rst
.. autoclass:: ImageFolder
:members: __getitem__
:special-members:
wrap_dataset_for_transforms_v2
# Necessary for the table generated by autosummary to look decent
[html writers]
table_style: colwidths-auto
Feature extraction for model inspection
=======================================
.. currentmodule:: torchvision.models.feature_extraction
The ``torchvision.models.feature_extraction`` package contains
feature extraction utilities that let us tap into our models to access intermediate
transformations of our inputs. This could be useful for a variety of
applications in computer vision. Just a few examples are:
- Visualizing feature maps.
- Extracting features to compute image descriptors for tasks like facial
recognition, copy-detection, or image retrieval.
- Passing selected features to downstream sub-networks for end-to-end training
with a specific task in mind. For example, passing a hierarchy of features
to a Feature Pyramid Network with object detection heads.
Torchvision provides :func:`create_feature_extractor` for this purpose.
It works by following roughly these steps:
1. Symbolically tracing the model to get a graphical representation of
how it transforms the input, step by step.
2. Setting the user-selected graph nodes as outputs.
3. Removing all redundant nodes (anything downstream of the output nodes).
4. Generating python code from the resulting graph and bundling that into a
PyTorch module together with the graph itself.
|
The `torch.fx documentation <https://pytorch.org/docs/stable/fx.html>`_
provides a more general and detailed explanation of the above procedure and
the inner workings of the symbolic tracing.
.. _about-node-names:
**About Node Names**
In order to specify which nodes should be output nodes for extracted
features, one should be familiar with the node naming convention used here
(which differs slightly from that used in ``torch.fx``). A node name is
specified as a ``.`` separated path walking the module hierarchy from top level
module down to leaf operation or leaf module. For instance ``"layer4.2.relu"``
in ResNet-50 represents the output of the ReLU of the 2nd block of the 4th
layer of the ``ResNet`` module. Here are some finer points to keep in mind:
- When specifying node names for :func:`create_feature_extractor`, you may
provide a truncated version of a node name as a shortcut. To see how this
works, try creating a ResNet-50 model and printing the node names with
``train_nodes, _ = get_graph_node_names(model); print(train_nodes)`` (see also the
short sketch after this list), and observe that the last node pertaining to ``layer4`` is
``"layer4.2.relu_2"``. One may specify ``"layer4.2.relu_2"`` as the return
node, or just ``"layer4"`` as this, by convention, refers to the last node
(in order of execution) of ``layer4``.
- If a certain module or operation is repeated more than once, node names get
an additional ``_{int}`` postfix to disambiguate. For instance, maybe the
addition (``+``) operation is used three times in the same ``forward``
method. Then there would be ``"path.to.module.add"``,
``"path.to.module.add_1"``, ``"path.to.module.add_2"``. The counter is
maintained within the scope of the direct parent. So in ResNet-50 there is
a ``"layer4.1.add"`` and a ``"layer4.2.add"``. Because the addition
operations reside in different blocks, there is no need for a postfix to
disambiguate.
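A short sketch of inspecting the node names before choosing return nodes (ResNet-50 is
used purely for illustration):

.. code-block:: python

    from torchvision.models import resnet50
    from torchvision.models.feature_extraction import get_graph_node_names

    model = resnet50()
    train_nodes, eval_nodes = get_graph_node_names(model)
    # Print the last few node names to see how the ``layer4`` nodes are numbered
    print(train_nodes[-5:])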
**An Example**
Here is an example of how we might extract features for MaskRCNN:
.. code-block:: python
import torch
from torchvision.models import resnet50
from torchvision.models.feature_extraction import get_graph_node_names
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.models.detection.mask_rcnn import MaskRCNN
from torchvision.models.detection.backbone_utils import LastLevelMaxPool
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork
# To assist you in designing the feature extractor you may want to print out
# the available nodes for resnet50.
m = resnet50()
train_nodes, eval_nodes = get_graph_node_names(resnet50())
# The lists returned are the names of all the graph nodes (in order of
# execution) for the input model traced in train mode and in eval mode
# respectively. You'll find that `train_nodes` and `eval_nodes` are the same
# for this example. But if the model contains control flow that's dependent
# on the training mode, they may be different.
# To specify the nodes you want to extract, you could select the final node
# that appears in each of the main layers:
return_nodes = {
# node_name: user-specified key for output dict
'layer1.2.relu_2': 'layer1',
'layer2.3.relu_2': 'layer2',
'layer3.5.relu_2': 'layer3',
'layer4.2.relu_2': 'layer4',
}
# But `create_feature_extractor` can also accept truncated node specifications
# like "layer1", as it will just pick the last node that's a descendent of
# of the specification. (Tip: be careful with this, especially when a layer
# has multiple outputs. It's not always guaranteed that the last operation
# performed is the one that corresponds to the output you desire. You should
# consult the source code for the input model to confirm.)
return_nodes = {
'layer1': 'layer1',
'layer2': 'layer2',
'layer3': 'layer3',
'layer4': 'layer4',
}
# Now you can build the feature extractor. This returns a module whose forward
# method returns a dictionary like:
# {
# 'layer1': output of layer 1,
# 'layer2': output of layer 2,
# 'layer3': output of layer 3,
# 'layer4': output of layer 4,
# }
create_feature_extractor(m, return_nodes=return_nodes)
# Let's put all that together to wrap resnet50 with MaskRCNN
# MaskRCNN requires a backbone with an attached FPN
class Resnet50WithFPN(torch.nn.Module):
def __init__(self):
super(Resnet50WithFPN, self).__init__()
# Get a resnet50 backbone
m = resnet50()
# Extract 4 main layers (note: MaskRCNN needs this particular name
# mapping for return nodes)
self.body = create_feature_extractor(
m, return_nodes={f'layer{k}': str(v)
for v, k in enumerate([1, 2, 3, 4])})
# Dry run to get number of channels for FPN
inp = torch.randn(2, 3, 224, 224)
with torch.no_grad():
out = self.body(inp)
in_channels_list = [o.shape[1] for o in out.values()]
# Build FPN
self.out_channels = 256
self.fpn = FeaturePyramidNetwork(
in_channels_list, out_channels=self.out_channels,
extra_blocks=LastLevelMaxPool())
def forward(self, x):
x = self.body(x)
x = self.fpn(x)
return x
# Now we can build our model!
model = MaskRCNN(Resnet50WithFPN(), num_classes=91).eval()
API Reference
-------------
.. autosummary::
:toctree: generated/
:template: function.rst
create_feature_extractor
get_graph_node_names
@@ -31,18 +31,21 @@ architectures, and common image transformations for computer vision.
:maxdepth: 2
:caption: Package Reference
datasets
io
models
ops
transforms
tv_tensors
models
datasets
utils
ops
io
feature_extraction
.. toctree::
:maxdepth: 1
:caption: Examples
:caption: Examples and training references
auto_examples/index
training_references
.. automodule:: torchvision
:members:
@@ -58,3 +61,9 @@ architectures, and common image transformations for computer vision.
TorchElastic <https://pytorch.org/elastic/>
TorchServe <https://pytorch.org/serve>
PyTorch on XLA Devices <http://pytorch.org/xla/>
Indices
-------
* :ref:`genindex`
torchvision.io
==============
Decoding / Encoding images and videos
=====================================
.. currentmodule:: torchvision.io
The :mod:`torchvision.io` package provides functions for performing IO
operations. They are currently specific to reading and writing video and
images.
operations. They are currently specific to reading and writing images and
videos.
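A minimal sketch of reading and re-encoding an image with this package (the file
paths are hypothetical):

.. code:: python

    from torchvision.io import read_image, write_png

    img = read_image("path/to/image.jpg")  # uint8 tensor of shape [C, H, W]
    write_png(img, "path/to/copy.png")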
Images
------
.. autosummary::
:toctree: generated/
:template: function.rst
read_image
decode_image
encode_jpeg
decode_jpeg
write_jpeg
decode_gif
encode_png
decode_png
write_png
read_file
write_file
.. autosummary::
:toctree: generated/
:template: class.rst
ImageReadMode
Video
-----
.. autofunction:: read_video
.. autofunction:: read_video_timestamps
.. autosummary::
:toctree: generated/
:template: function.rst
.. autofunction:: write_video
read_video
read_video_timestamps
write_video
Fine-grained video API
----------------------
^^^^^^^^^^^^^^^^^^^^^^
In addition to the :func:`read_video` function, we provide a high-performance
lower-level API for more fine-grained control. It does all this while fully
supporting TorchScript.
.. autoclass:: VideoReader
:members: __next__, get_metadata, set_current_stream, seek
.. betastatus:: fine-grained video API
.. autosummary::
:toctree: generated/
:template: class.rst
VideoReader
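A minimal sketch of iterating over decoded frames with this class (the video path is
hypothetical):

.. code:: python

    from torchvision.io import VideoReader

    reader = VideoReader("path/to/video.mp4", "video")
    for frame in reader:
        data = frame["data"]  # uint8 tensor of shape [C, H, W]
        pts = frame["pts"]    # presentation timestamp, in seconds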
Example of inspecting a video:
@@ -54,29 +88,3 @@ Example of inspecting a video:
# the constructor we select a default video stream, but
# in practice, we can set whichever stream we would like
video.set_current_stream("video:0")
Image
-----
.. autoclass:: ImageReadMode
.. autofunction:: read_image
.. autofunction:: decode_image
.. autofunction:: encode_jpeg
.. autofunction:: decode_jpeg
.. autofunction:: write_jpeg
.. autofunction:: encode_png
.. autofunction:: decode_png
.. autofunction:: write_png
.. autofunction:: read_file
.. autofunction:: write_file
torchvision.models
##################
.. _models:
Models and pre-trained weights
##############################
The models subpackage contains definitions of models for addressing
The ``torchvision.models`` subpackage contains definitions of models for addressing
different tasks, including: image classification, pixelwise semantic
segmentation, object detection, instance segmentation, person
keypoint detection and video classification.
keypoint detection, video classification, and optical flow.
General information on pre-trained weights
==========================================
Classification
==============
TorchVision offers pre-trained weights for every provided architecture, using
the PyTorch :mod:`torch.hub`. Instancing a pre-trained model will download its
weights to a cache directory. This directory can be set using the `TORCH_HOME`
environment variable. See :func:`torch.hub.load_state_dict_from_url` for details.
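For example, a minimal sketch of redirecting the cache before instantiating a model
(the cache directory path is hypothetical):

.. code:: python

    import os
    os.environ["TORCH_HOME"] = "/path/to/weights/cache"  # set before the first download

    from torchvision.models import resnet50, ResNet50_Weights
    model = resnet50(weights=ResNet50_Weights.DEFAULT)  # weights are cached under TORCH_HOME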
.. note::
The pre-trained models provided in this library may have their own licenses or
terms and conditions derived from the dataset used for training. It is your
responsibility to determine whether you have permission to use the models for
your use case.
.. note ::
Backward compatibility is guaranteed for loading a serialized
``state_dict`` into a model created with an older PyTorch version.
In contrast, loading entire saved models or serialized
``ScriptModules`` (serialized using older versions of PyTorch)
may not preserve the historic behaviour. Refer to the following
`documentation
<https://pytorch.org/docs/stable/notes/serialization.html#id6>`_
The models subpackage contains definitions for the following model
architectures for image classification:
- `AlexNet`_
- `VGG`_
- `ResNet`_
- `SqueezeNet`_
- `DenseNet`_
- `Inception`_ v3
- `GoogLeNet`_
- `ShuffleNet`_ v2
- `MobileNetV2`_
- `MobileNetV3`_
- `ResNeXt`_
- `Wide ResNet`_
- `MNASNet`_
You can construct a model with random weights by calling its constructor:
Initializing pre-trained models
-------------------------------
As of v0.13, TorchVision offers a new `Multi-weight support API
<https://pytorch.org/blog/introducing-torchvision-new-multi-weight-support-api/>`_
for loading different weights to the existing model builder methods:
.. code:: python
import torchvision.models as models
resnet18 = models.resnet18()
alexnet = models.alexnet()
vgg16 = models.vgg16()
squeezenet = models.squeezenet1_0()
densenet = models.densenet161()
inception = models.inception_v3()
googlenet = models.googlenet()
shufflenet = models.shufflenet_v2_x1_0()
mobilenet_v2 = models.mobilenet_v2()
mobilenet_v3_large = models.mobilenet_v3_large()
mobilenet_v3_small = models.mobilenet_v3_small()
resnext50_32x4d = models.resnext50_32x4d()
wide_resnet50_2 = models.wide_resnet50_2()
mnasnet = models.mnasnet1_0()
We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`.
These can be constructed by passing ``pretrained=True``:
from torchvision.models import resnet50, ResNet50_Weights
# Old weights with accuracy 76.130%
resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
# New weights with accuracy 80.858%
resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
# Best available weights (currently alias for IMAGENET1K_V2)
# Note that these weights may change across versions
resnet50(weights=ResNet50_Weights.DEFAULT)
# Strings are also supported
resnet50(weights="IMAGENET1K_V2")
# No weights - random initialization
resnet50(weights=None)
Migrating to the new API is very straightforward. The following method calls between the 2 APIs are all equivalent:
.. code:: python
import torchvision.models as models
resnet18 = models.resnet18(pretrained=True)
alexnet = models.alexnet(pretrained=True)
squeezenet = models.squeezenet1_0(pretrained=True)
vgg16 = models.vgg16(pretrained=True)
densenet = models.densenet161(pretrained=True)
inception = models.inception_v3(pretrained=True)
googlenet = models.googlenet(pretrained=True)
shufflenet = models.shufflenet_v2_x1_0(pretrained=True)
mobilenet_v2 = models.mobilenet_v2(pretrained=True)
mobilenet_v3_large = models.mobilenet_v3_large(pretrained=True)
mobilenet_v3_small = models.mobilenet_v3_small(pretrained=True)
resnext50_32x4d = models.resnext50_32x4d(pretrained=True)
wide_resnet50_2 = models.wide_resnet50_2(pretrained=True)
mnasnet = models.mnasnet1_0(pretrained=True)
Instancing a pre-trained model will download its weights to a cache directory.
This directory can be set using the `TORCH_MODEL_ZOO` environment variable. See
:func:`torch.utils.model_zoo.load_url` for details.
from torchvision.models import resnet50, ResNet50_Weights
# Using pretrained weights:
resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
resnet50(weights="IMAGENET1K_V1")
resnet50(pretrained=True) # deprecated
resnet50(True) # deprecated
# Using no weights:
resnet50(weights=None)
resnet50()
resnet50(pretrained=False) # deprecated
resnet50(False) # deprecated
Note that the ``pretrained`` parameter is now deprecated; using it will emit warnings and it will be removed in v0.15.
Using the pre-trained models
----------------------------
Before using the pre-trained models, one must preprocess the image
(resize with the right resolution/interpolation, apply inference transforms,
rescale the values, etc.). There is no standard way to do this as it depends on
how a given model was trained. It can vary across model families, variants or
even weight versions. Using the correct preprocessing method is critical and
failing to do so may lead to decreased accuracy or incorrect outputs.
All the necessary information for the inference transforms of each pre-trained
model is provided on its weights documentation. To simplify inference, TorchVision
bundles the necessary preprocessing transforms into each model weight. These are
accessible via the ``weight.transforms`` attribute:
.. code:: python
# Initialize the Weight Transforms
weights = ResNet50_Weights.DEFAULT
preprocess = weights.transforms()
# Apply it to the input image
img_transformed = preprocess(img)
Some models use modules which have different training and evaluation
behavior, such as batch normalization. To switch between these modes, use
``model.train()`` or ``model.eval()`` as appropriate. See
:meth:`~torch.nn.Module.train` or :meth:`~torch.nn.Module.eval` for details.
All pre-trained models expect input images normalized in the same way,
i.e. mini-batches of 3-channel RGB images of shape (3 x H x W),
where H and W are expected to be at least 224.
The images have to be loaded in to a range of [0, 1] and then normalized
using ``mean = [0.485, 0.456, 0.406]`` and ``std = [0.229, 0.224, 0.225]``.
You can use the following transform to normalize::
.. code:: python
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
# Initialize model
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights)
An example of such normalization can be found in the imagenet example
`here <https://github.com/pytorch/examples/blob/42e5b996718797e45c46a25c55b031e6768f8440/imagenet/main.py#L89-L101>`_
# Set model to eval mode
model.eval()
The process for obtaining the values of `mean` and `std` is roughly equivalent
to::
Listing and retrieving available models
---------------------------------------
import torch
from torchvision import datasets, transforms as T
transform = T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor()])
dataset = datasets.ImageNet(".", split="train", transform=transform)
means = []
stds = []
for img in subset(dataset):
means.append(torch.mean(img))
stds.append(torch.std(img))
mean = torch.mean(torch.tensor(means))
std = torch.mean(torch.tensor(stds))
Unfortunately, the concrete `subset` that was used is lost. For more
information see `this discussion <https://github.com/pytorch/vision/issues/1439>`_
or `these experiments <https://github.com/pytorch/vision/pull/1965>`_.
ImageNet 1-crop error rates (224x224)
================================ ============= =============
Model Acc@1 Acc@5
================================ ============= =============
AlexNet 56.522 79.066
VGG-11 69.020 88.628
VGG-13 69.928 89.246
VGG-16 71.592 90.382
VGG-19 72.376 90.876
VGG-11 with batch normalization 70.370 89.810
VGG-13 with batch normalization 71.586 90.374
VGG-16 with batch normalization 73.360 91.516
VGG-19 with batch normalization 74.218 91.842
ResNet-18 69.758 89.078
ResNet-34 73.314 91.420
ResNet-50 76.130 92.862
ResNet-101 77.374 93.546
ResNet-152 78.312 94.046
SqueezeNet 1.0 58.092 80.420
SqueezeNet 1.1 58.178 80.624
Densenet-121 74.434 91.972
Densenet-169 75.600 92.806
Densenet-201 76.896 93.370
Densenet-161 77.138 93.560
Inception v3 77.294 93.450
GoogleNet 69.778 89.530
ShuffleNet V2 x1.0 69.362 88.316
ShuffleNet V2 x0.5 60.552 81.746
MobileNet V2 71.878 90.286
MobileNet V3 Large 74.042 91.340
MobileNet V3 Small 67.668 87.402
ResNeXt-50-32x4d 77.618 93.698
ResNeXt-101-32x8d 79.312 94.526
Wide ResNet-50-2 78.468 94.086
Wide ResNet-101-2 78.848 94.284
MNASNet 1.0 73.456 91.510
MNASNet 0.5 67.734 87.490
================================ ============= =============
.. _AlexNet: https://arxiv.org/abs/1404.5997
.. _VGG: https://arxiv.org/abs/1409.1556
.. _ResNet: https://arxiv.org/abs/1512.03385
.. _SqueezeNet: https://arxiv.org/abs/1602.07360
.. _DenseNet: https://arxiv.org/abs/1608.06993
.. _Inception: https://arxiv.org/abs/1512.00567
.. _GoogLeNet: https://arxiv.org/abs/1409.4842
.. _ShuffleNet: https://arxiv.org/abs/1807.11164
.. _MobileNetV2: https://arxiv.org/abs/1801.04381
.. _MobileNetV3: https://arxiv.org/abs/1905.02244
.. _ResNeXt: https://arxiv.org/abs/1611.05431
.. _MNASNet: https://arxiv.org/abs/1807.11626
As of v0.14, TorchVision offers a new mechanism which allows listing and
retrieving models and weights by their names. Here are a few examples of how to
use them:
.. currentmodule:: torchvision.models
.. code:: python
Alexnet
-------
# List available models
all_models = list_models()
classification_models = list_models(module=torchvision.models)
.. autofunction:: alexnet
# Initialize models
m1 = get_model("mobilenet_v3_large", weights=None)
m2 = get_model("quantized_mobilenet_v3_large", weights="DEFAULT")
VGG
---
# Fetch weights
weights = get_weight("MobileNet_V3_Large_QuantizedWeights.DEFAULT")
assert weights == MobileNet_V3_Large_QuantizedWeights.DEFAULT
.. autofunction:: vgg11
.. autofunction:: vgg11_bn
.. autofunction:: vgg13
.. autofunction:: vgg13_bn
.. autofunction:: vgg16
.. autofunction:: vgg16_bn
.. autofunction:: vgg19
.. autofunction:: vgg19_bn
weights_enum = get_model_weights("quantized_mobilenet_v3_large")
assert weights_enum == MobileNet_V3_Large_QuantizedWeights
weights_enum2 = get_model_weights(torchvision.models.quantization.mobilenet_v3_large)
assert weights_enum == weights_enum2
ResNet
------
Here are the available public functions to retrieve models and their corresponding weights:
.. autofunction:: resnet18
.. autofunction:: resnet34
.. autofunction:: resnet50
.. autofunction:: resnet101
.. autofunction:: resnet152
.. currentmodule:: torchvision.models
.. autosummary::
:toctree: generated/
:template: function.rst
SqueezeNet
----------
get_model
get_model_weights
get_weight
list_models
.. autofunction:: squeezenet1_0
.. autofunction:: squeezenet1_1
Using models from Hub
---------------------
DenseNet
---------
Most pre-trained models can be accessed directly via PyTorch Hub without having TorchVision installed:
.. autofunction:: densenet121
.. autofunction:: densenet169
.. autofunction:: densenet161
.. autofunction:: densenet201
.. code:: python
Inception v3
------------
import torch
.. autofunction:: inception_v3
# Option 1: passing weights param as string
model = torch.hub.load("pytorch/vision", "resnet50", weights="IMAGENET1K_V2")
.. note ::
This requires `scipy` to be installed
# Option 2: passing weights param as enum
weights = torch.hub.load("pytorch/vision", "get_weight", weights="ResNet50_Weights.IMAGENET1K_V2")
model = torch.hub.load("pytorch/vision", "resnet50", weights=weights)
You can also retrieve all the available weights of a specific model via PyTorch Hub by doing:
GoogLeNet
------------
.. code:: python
.. autofunction:: googlenet
import torch
.. note ::
This requires `scipy` to be installed
weight_enum = torch.hub.load("pytorch/vision", "get_model_weights", name="resnet50")
print([weight for weight in weight_enum])
The only exceptions to the above are the detection models included in
:mod:`torchvision.models.detection`. These models require TorchVision
to be installed because they depend on custom C++ operators.
Classification
==============
.. currentmodule:: torchvision.models
ShuffleNet v2
-------------
The following classification models are available, with or without pre-trained
weights:
.. toctree::
:maxdepth: 1
models/alexnet
models/convnext
models/densenet
models/efficientnet
models/efficientnetv2
models/googlenet
models/inception
models/maxvit
models/mnasnet
models/mobilenetv2
models/mobilenetv3
models/regnet
models/resnet
models/resnext
models/shufflenetv2
models/squeezenet
models/swin_transformer
models/vgg
models/vision_transformer
models/wide_resnet
|
Here is an example of how to use the pre-trained image classification models:
.. autofunction:: shufflenet_v2_x0_5
.. autofunction:: shufflenet_v2_x1_0
.. autofunction:: shufflenet_v2_x1_5
.. autofunction:: shufflenet_v2_x2_0
.. code:: python
MobileNet v2
-------------
from torchvision.io import read_image
from torchvision.models import resnet50, ResNet50_Weights
.. autofunction:: mobilenet_v2
img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")
MobileNet v3
-------------
# Step 1: Initialize model with the best available weights
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights)
model.eval()
.. autofunction:: mobilenet_v3_large
.. autofunction:: mobilenet_v3_small
# Step 2: Initialize the inference transforms
preprocess = weights.transforms()
ResNext
-------
# Step 3: Apply inference preprocessing transforms
batch = preprocess(img).unsqueeze(0)
.. autofunction:: resnext50_32x4d
.. autofunction:: resnext101_32x8d
# Step 4: Use the model and print the predicted category
prediction = model(batch).squeeze(0).softmax(0)
class_id = prediction.argmax().item()
score = prediction[class_id].item()
category_name = weights.meta["categories"][class_id]
print(f"{category_name}: {100 * score:.1f}%")
Wide ResNet
-----------
The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.
.. autofunction:: wide_resnet50_2
.. autofunction:: wide_resnet101_2
Table of all available classification weights
---------------------------------------------
MNASNet
--------
Accuracies are reported on ImageNet-1K using single crops:
.. autofunction:: mnasnet0_5
.. autofunction:: mnasnet0_75
.. autofunction:: mnasnet1_0
.. autofunction:: mnasnet1_3
.. include:: generated/classification_table.rst
Quantized Models
Quantized models
----------------
The following architectures provide support for INT8 quantized models. You can get
a model with random weights by calling its constructor:
.. currentmodule:: torchvision.models.quantization
.. code:: python
The following architectures provide support for INT8 quantized models, with or without
pre-trained weights:
.. toctree::
:maxdepth: 1
import torchvision.models as models
googlenet = models.quantization.googlenet()
inception_v3 = models.quantization.inception_v3()
mobilenet_v2 = models.quantization.mobilenet_v2()
mobilenet_v3_large = models.quantization.mobilenet_v3_large()
resnet18 = models.quantization.resnet18()
resnet50 = models.quantization.resnet50()
resnext101_32x8d = models.quantization.resnext101_32x8d()
shufflenet_v2_x0_5 = models.quantization.shufflenet_v2_x0_5()
shufflenet_v2_x1_0 = models.quantization.shufflenet_v2_x1_0()
shufflenet_v2_x1_5 = models.quantization.shufflenet_v2_x1_5()
shufflenet_v2_x2_0 = models.quantization.shufflenet_v2_x2_0()
Obtaining a pre-trained quantized model can be done with a few lines of code:
models/googlenet_quant
models/inception_quant
models/mobilenetv2_quant
models/mobilenetv3_quant
models/resnet_quant
models/resnext_quant
models/shufflenetv2_quant
|
Here is an example of how to use the pre-trained quantized image classification models:
.. code:: python
import torchvision.models as models
model = models.quantization.mobilenet_v2(pretrained=True, quantize=True)
from torchvision.io import read_image
from torchvision.models.quantization import resnet50, ResNet50_QuantizedWeights
img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")
# Step 1: Initialize model with the best available weights
weights = ResNet50_QuantizedWeights.DEFAULT
model = resnet50(weights=weights, quantize=True)
model.eval()
# run the model with quantized inputs and weights
out = model(torch.rand(1, 3, 224, 224))
We provide pre-trained quantized weights for the following models:
# Step 2: Initialize the inference transforms
preprocess = weights.transforms()
# Step 3: Apply inference preprocessing transforms
batch = preprocess(img).unsqueeze(0)
# Step 4: Use the model and print the predicted category
prediction = model(batch).squeeze(0).softmax(0)
class_id = prediction.argmax().item()
score = prediction[class_id].item()
category_name = weights.meta["categories"][class_id]
print(f"{category_name}: {100 * score}%")
================================ ============= =============
Model Acc@1 Acc@5
================================ ============= =============
MobileNet V2 71.658 90.150
MobileNet V3 Large 73.004 90.858
ShuffleNet V2 68.360 87.582
ResNet 18 69.494 88.882
ResNet 50 75.920 92.814
ResNext 101 32x8d 78.986 94.480
Inception V3 77.176 93.354
GoogleNet 69.826 89.404
================================ ============= =============
The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.
Table of all available quantized classification weights
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Accuracies are reported on ImageNet-1K using single crops:
.. include:: generated/classification_quant_table.rst
Semantic Segmentation
=====================
The models subpackage contains definitions for the following model
architectures for semantic segmentation:
.. currentmodule:: torchvision.models.segmentation
- `FCN ResNet50, ResNet101 <https://arxiv.org/abs/1411.4038>`_
- `DeepLabV3 ResNet50, ResNet101, MobileNetV3-Large <https://arxiv.org/abs/1706.05587>`_
- `LR-ASPP MobileNetV3-Large <https://arxiv.org/abs/1905.02244>`_
.. betastatus:: segmentation module
As with image classification models, all pre-trained models expect input images normalized in the same way.
The images have to be loaded into a range of ``[0, 1]`` and then normalized using
``mean = [0.485, 0.456, 0.406]`` and ``std = [0.229, 0.224, 0.225]``.
They have been trained on images resized such that their minimum size is 520.
The following semantic segmentation models are available, with or without
pre-trained weights:
For details on how to plot the masks of such models, you may refer to :ref:`semantic_seg_output`.
.. toctree::
:maxdepth: 1
The pre-trained models have been trained on a subset of COCO train2017, on the 20 categories that are
present in the Pascal VOC dataset. You can see more information on how the subset has been selected in
``references/segmentation/coco_utils.py``. The classes that the pre-trained model outputs are the following,
in order:
models/deeplabv3
models/fcn
models/lraspp
.. code-block:: python
|
['__background__', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']
Here is an example of how to use the pre-trained semantic segmentation models:
The accuracies of the pre-trained models evaluated on COCO val2017 are as follows
.. code:: python
================================ ============= ====================
Network mean IoU global pixelwise acc
================================ ============= ====================
FCN ResNet50 60.5 91.4
FCN ResNet101 63.7 91.9
DeepLabV3 ResNet50 66.4 92.4
DeepLabV3 ResNet101 67.4 92.4
DeepLabV3 MobileNetV3-Large 60.3 91.2
LR-ASPP MobileNetV3-Large 57.9 91.2
================================ ============= ====================
from torchvision.io.image import read_image
from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights
from torchvision.transforms.functional import to_pil_image
img = read_image("gallery/assets/dog1.jpg")
Fully Convolutional Networks
----------------------------
# Step 1: Initialize model with the best available weights
weights = FCN_ResNet50_Weights.DEFAULT
model = fcn_resnet50(weights=weights)
model.eval()
.. autofunction:: torchvision.models.segmentation.fcn_resnet50
.. autofunction:: torchvision.models.segmentation.fcn_resnet101
# Step 2: Initialize the inference transforms
preprocess = weights.transforms()
# Step 3: Apply inference preprocessing transforms
batch = preprocess(img).unsqueeze(0)
DeepLabV3
---------
# Step 4: Use the model and visualize the prediction
prediction = model(batch)["out"]
normalized_masks = prediction.softmax(dim=1)
class_to_idx = {cls: idx for (idx, cls) in enumerate(weights.meta["categories"])}
mask = normalized_masks[0, class_to_idx["dog"]]
to_pil_image(mask).show()
.. autofunction:: torchvision.models.segmentation.deeplabv3_resnet50
.. autofunction:: torchvision.models.segmentation.deeplabv3_resnet101
.. autofunction:: torchvision.models.segmentation.deeplabv3_mobilenet_v3_large
The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.
The output format of the models is illustrated in :ref:`semantic_seg_output`.
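The per-pixel predicted class can also be recovered with an ``argmax`` over the class dimension;
a minimal sketch reusing ``normalized_masks`` and ``class_to_idx`` from the example above:

.. code:: python

    # a minimal sketch, reusing `normalized_masks` and `class_to_idx` from the example above
    from torchvision.transforms.functional import to_pil_image

    class_map = normalized_masks[0].argmax(dim=0)      # (H, W) tensor of predicted class indices
    dog_pixels = class_map == class_to_idx["dog"]      # boolean mask of pixels predicted as "dog"
    to_pil_image(dog_pixels.float()).show()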
LR-ASPP
-------
Table of all available semantic segmentation weights
----------------------------------------------------
All models are evaluated on a subset of COCO val2017, on the 20 categories that are present in the Pascal VOC dataset:
.. include:: generated/segmentation_table.rst
.. autofunction:: torchvision.models.segmentation.lraspp_mobilenet_v3_large
.. _object_det_inst_seg_pers_keypoint_det:
Object Detection, Instance Segmentation and Person Keypoint Detection
=====================================================================
The models subpackage contains definitions for the following model
architectures for detection:
- `Faster R-CNN <https://arxiv.org/abs/1506.01497>`_
- `Mask R-CNN <https://arxiv.org/abs/1703.06870>`_
- `RetinaNet <https://arxiv.org/abs/1708.02002>`_
- `SSD <https://arxiv.org/abs/1512.02325>`_
- `SSDlite <https://arxiv.org/abs/1801.04381>`_
The pre-trained models for detection, instance segmentation and
keypoint detection are initialized with the classification models
in torchvision.
The models expect a list of ``Tensor[C, H, W]``, in the range ``0-1``.
The models internally resize the images but the behaviour varies depending
on the model. Check the constructor of the models for more information. The
output format of such models is illustrated in :ref:`instance_seg_output`.
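In other words, the raw input contract is a plain Python list of float tensors in ``[0, 1]``,
possibly of different sizes. A minimal sketch with dummy data (``model`` is assumed to be any of
the detection models above, in eval mode):

.. code:: python

    # a minimal sketch of the raw input contract (dummy data, `model` is assumed)
    import torch

    images = [torch.rand(3, 300, 400), torch.rand(3, 500, 600)]  # floats in [0, 1], varying sizes
    with torch.no_grad():
        outputs = model(images)  # one dict per image, with "boxes", "labels" and "scores"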
For object detection and instance segmentation, the pre-trained
models return the predictions of the following classes:
.. code-block:: python
COCO_INSTANCE_CATEGORY_NAMES = [
'__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]
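The ``'N/A'`` entries are placeholders for COCO category ids that are not used; keeping them
preserves the alignment between the label ids returned by the models and this list. A minimal
sketch of the lookup (the label ids below are made up for illustration):

.. code-block:: python

    # label ids returned by the models index directly into the list above
    labels = [1, 18, 17]  # made-up label ids from a hypothetical prediction
    names = [COCO_INSTANCE_CATEGORY_NAMES[i] for i in labels]
    print(names)          # ['person', 'dog', 'cat']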
Here is a summary of the accuracies for the models trained on
the instances set of COCO train2017 and evaluated on COCO val2017:
====================================== ======= ======== ===========
Network                                box AP  mask AP  keypoint AP
====================================== ======= ======== ===========
Faster R-CNN ResNet-50 FPN             37.0    -        -
Faster R-CNN MobileNetV3-Large FPN     32.8    -        -
Faster R-CNN MobileNetV3-Large 320 FPN 22.8    -        -
RetinaNet ResNet-50 FPN                36.4    -        -
SSD300 VGG16                           25.1    -        -
SSDlite320 MobileNetV3-Large           21.3    -        -
Mask R-CNN ResNet-50 FPN               37.9    34.6     -
====================================== ======= ======== ===========
.. betastatus:: detection module
Object Detection
----------------
.. currentmodule:: torchvision.models.detection
The following object detection models are available, with or without pre-trained
weights:
.. toctree::
:maxdepth: 1
models/faster_rcnn
models/fcos
models/retinanet
models/ssd
models/ssdlite
|
For person keypoint detection, the accuracies for the pre-trained
models are as follows:
Here is an example of how to use the pre-trained object detection models:
================================ ======= ======== ===========
Network                          box AP  mask AP  keypoint AP
================================ ======= ======== ===========
Keypoint R-CNN ResNet-50 FPN     54.6    -        65.0
================================ ======= ======== ===========
.. code:: python
from torchvision.io.image import read_image
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights
from torchvision.utils import draw_bounding_boxes
from torchvision.transforms.functional import to_pil_image
img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")
# Step 1: Initialize model with the best available weights
weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
model = fasterrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=0.9)
model.eval()
For person keypoint detection, the pre-trained models return the
keypoints in the following order:
# Step 2: Initialize the inference transforms
preprocess = weights.transforms()
.. code-block:: python
# Step 3: Apply inference preprocessing transforms
batch = [preprocess(img)]
COCO_PERSON_KEYPOINT_NAMES = [
'nose',
'left_eye',
'right_eye',
'left_ear',
'right_ear',
'left_shoulder',
'right_shoulder',
'left_elbow',
'right_elbow',
'left_wrist',
'right_wrist',
'left_hip',
'right_hip',
'left_knee',
'right_knee',
'left_ankle',
'right_ankle'
]
# Step 4: Use the model and visualize the prediction
prediction = model(batch)[0]
labels = [weights.meta["categories"][i] for i in prediction["labels"]]
box = draw_bounding_boxes(img, boxes=prediction["boxes"],
labels=labels,
colors="red",
width=4, font_size=30)
im = to_pil_image(box.detach())
im.show()
Runtime characteristics
-----------------------
The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.
For details on how to plot the bounding boxes of the models, you may refer to :ref:`instance_seg_output`.
The implementations of the models for object detection, instance segmentation
and keypoint detection are efficient.
Table of all available Object detection weights
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
In the following table, we use 8 GPUs to report the results. During training,
we use a batch size of 2 per GPU for all models, except SSD which uses 4
and SSDlite which uses 24. During testing, a batch size of 1 is used.
Box MAPs are reported on COCO val2017:
For test time, we report the time for the model evaluation and postprocessing
(including mask pasting in the image), but not the time for computing the
precision-recall.
.. include:: generated/detection_table.rst
====================================== =================== ================== ===========
Network                                train time (s / it) test time (s / it) memory (GB)
====================================== =================== ================== ===========
Faster R-CNN ResNet-50 FPN             0.2288              0.0590             5.2
Faster R-CNN MobileNetV3-Large FPN     0.1020              0.0415             1.0
Faster R-CNN MobileNetV3-Large 320 FPN 0.0978              0.0376             0.6
RetinaNet ResNet-50 FPN                0.2514              0.0939             4.1
SSD300 VGG16                           0.2093              0.0744             1.5
SSDlite320 MobileNetV3-Large           0.1773              0.0906             1.5
Mask R-CNN ResNet-50 FPN               0.2728              0.0903             5.4
Keypoint R-CNN ResNet-50 FPN           0.3789              0.1242             6.8
====================================== =================== ================== ===========
Instance Segmentation
---------------------
Faster R-CNN
------------
.. currentmodule:: torchvision.models.detection
.. autofunction:: torchvision.models.detection.fasterrcnn_resnet50_fpn
.. autofunction:: torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn
.. autofunction:: torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn
The following instance segmentation models are available, with or without pre-trained
weights:
.. toctree::
:maxdepth: 1
RetinaNet
---------
models/mask_rcnn
.. autofunction:: torchvision.models.detection.retinanet_resnet50_fpn
|
SSD
---
For details on how to plot the masks of the models, you may refer to :ref:`instance_seg_output`.
.. autofunction:: torchvision.models.detection.ssd300_vgg16
Table of all available Instance segmentation weights
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Box and Mask MAPs are reported on COCO val2017:
SSDlite
-------
.. include:: generated/instance_segmentation_table.rst
.. autofunction:: torchvision.models.detection.ssdlite320_mobilenet_v3_large
Keypoint Detection
------------------
.. currentmodule:: torchvision.models.detection
Mask R-CNN
----------
The following person keypoint detection models are available, with or without
pre-trained weights:
.. autofunction:: torchvision.models.detection.maskrcnn_resnet50_fpn
.. toctree::
:maxdepth: 1
models/keypoint_rcnn
Keypoint R-CNN
--------------
|
.. autofunction:: torchvision.models.detection.keypointrcnn_resnet50_fpn
The classes of the pre-trained model outputs can be found at ``weights.meta["keypoint_names"]``.
For details on how to plot the bounding boxes of the models, you may refer to :ref:`keypoint_output`.
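Here is a sketch of how the pre-trained keypoint detection model could be used, following the
same pattern as the object detection example above (the drawing step relies on
``torchvision.utils.draw_keypoints``):

.. code:: python

    from torchvision.io.image import read_image
    from torchvision.models.detection import keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights
    from torchvision.utils import draw_keypoints
    from torchvision.transforms.functional import to_pil_image

    img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

    # Step 1: Initialize model with the best available weights
    weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
    model = keypointrcnn_resnet50_fpn(weights=weights)
    model.eval()

    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms()

    # Step 3: Apply inference preprocessing transforms
    batch = [preprocess(img)]

    # Step 4: Use the model, keep confident detections and draw their keypoints
    prediction = model(batch)[0]
    keep = prediction["scores"] > 0.9
    keypoints = prediction["keypoints"][keep][:, :, :2]  # (num_kept, 17, 2), visibility flag dropped
    overlay = draw_keypoints(img, keypoints, colors="blue", radius=3)
    to_pil_image(overlay).show()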
Table of all available Keypoint detection weights
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Box and Keypoint MAPs are reported on COCO val2017:
.. include:: generated/detection_keypoint_table.rst
Video Classification
====================
We provide models for action recognition pre-trained on Kinetics-400.
They have all been trained with the scripts provided in ``references/video_classification``.
.. currentmodule:: torchvision.models.video
All pre-trained models expect input images normalized in the same way,
i.e. mini-batches of 3-channel RGB videos of shape (3 x T x H x W),
where H and W are expected to be 112, and T is the number of video frames in a clip.
The images have to be loaded into a range of [0, 1] and then normalized
using ``mean = [0.43216, 0.394666, 0.37645]`` and ``std = [0.22803, 0.22145, 0.216989]``.
.. betastatus:: video module
The following video classification models are available, with or without
pre-trained weights:
.. note::
The normalization parameters are different from the image classification ones, and correspond
to the mean and std from Kinetics-400.
.. toctree::
:maxdepth: 1
.. note::
    For now, normalization code can be found in ``references/video_classification/transforms.py``;
    see the ``Normalize`` function there. Note that it differs from standard normalization for
    images because it assumes the video is 4d.
models/video_mvit
models/video_resnet
models/video_s3d
models/video_swin_transformer
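As the note above mentions, normalization is applied to a 4d ``(C, T, H, W)`` clip rather than to
a single image; a minimal sketch of what this amounts to on a dummy clip (the bundled
``weights.transforms()`` used below takes care of it for you):

.. code:: python

    # a minimal sketch of 4d video normalization on a dummy clip (illustrative only)
    import torch

    mean = torch.tensor([0.43216, 0.394666, 0.37645]).view(3, 1, 1, 1)
    std = torch.tensor([0.22803, 0.22145, 0.216989]).view(3, 1, 1, 1)

    clip = torch.rand(3, 16, 112, 112)  # (C, T, H, W), values in [0, 1]
    normalized = (clip - mean) / std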
|
Here is an example of how to use the pre-trained video classification models:
.. code:: python
from torchvision.io.video import read_video
from torchvision.models.video import r3d_18, R3D_18_Weights
vid, _, _ = read_video("test/assets/videos/v_SoccerJuggling_g23_c01.avi", output_format="TCHW")
vid = vid[:32] # optionally shorten duration
# Step 1: Initialize model with the best available weights
weights = R3D_18_Weights.DEFAULT
model = r3d_18(weights=weights)
model.eval()
# Step 2: Initialize the inference transforms
preprocess = weights.transforms()
# Step 3: Apply inference preprocessing transforms
batch = preprocess(vid).unsqueeze(0)
# Step 4: Use the model and print the predicted category
prediction = model(batch).squeeze(0).softmax(0)
label = prediction.argmax().item()
score = prediction[label].item()
category_name = weights.meta["categories"][label]
print(f"{category_name}: {100 * score}%")
The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.
Kinetics 1-crop accuracies for clip length 16 (16x112x112):
================================ ============= =============
Network                          Clip acc@1    Clip acc@5
================================ ============= =============
ResNet 3D 18                     52.75         75.45
ResNet MC 18                     53.90         76.29
ResNet (2+1)D                    57.50         78.81
================================ ============= =============
Table of all available video classification weights
---------------------------------------------------
Accuracies are reported on Kinetics-400 using single crops for clip length 16:
ResNet 3D
----------
.. include:: generated/video_table.rst
.. autofunction:: torchvision.models.video.r3d_18
Optical Flow
============
ResNet Mixed Convolution
------------------------
.. currentmodule:: torchvision.models.optical_flow
.. autofunction:: torchvision.models.video.mc3_18
The following Optical Flow models are available, with or without pre-trained weights:
ResNet (2+1)D
-------------
.. toctree::
:maxdepth: 1
.. autofunction:: torchvision.models.video.r2plus1d_18
models/raft
AlexNet
=======
.. currentmodule:: torchvision.models
The AlexNet model was originally introduced in the
`ImageNet Classification with Deep Convolutional Neural Networks
<https://papers.nips.cc/paper/2012/hash/c399862d3b9d6b76c8436e924a68c45b-Abstract.html>`__
paper. The implemented architecture is slightly different from the original one,
and is based on `One weird trick for parallelizing convolutional neural networks
<https://arxiv.org/abs/1404.5997>`__.
Model builders
--------------
The following model builders can be used to instantiate an AlexNet model, with or
without pre-trained weights. All the model builders internally rely on the
``torchvision.models.alexnet.AlexNet`` base class. Please refer to the `source
code
<https://github.com/pytorch/vision/blob/main/torchvision/models/alexnet.py>`_ for
more details about this class.
.. autosummary::
:toctree: generated/
:template: function.rst
alexnet
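Usage follows the same pattern for all classification model builders documented on this page; a
minimal sketch:

.. code:: python

    # a minimal usage sketch; the same pattern applies to the other model builders below
    from torchvision.models import alexnet, AlexNet_Weights

    model = alexnet(weights=AlexNet_Weights.DEFAULT)  # pass weights=None for random initialization
    model.eval()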
ConvNeXt
========
.. currentmodule:: torchvision.models
The ConvNeXt model is based on the `A ConvNet for the 2020s
<https://arxiv.org/abs/2201.03545>`_ paper.
Model builders
--------------
The following model builders can be used to instantiate a ConvNeXt model, with or
without pre-trained weights. All the model builders internally rely on the
``torchvision.models.convnext.ConvNeXt`` base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py>`_ for
more details about this class.
.. autosummary::
:toctree: generated/
:template: function.rst
convnext_tiny
convnext_small
convnext_base
convnext_large
DeepLabV3
=========
.. currentmodule:: torchvision.models.segmentation
The DeepLabV3 model is based on the `Rethinking Atrous Convolution for Semantic
Image Segmentation <https://arxiv.org/abs/1706.05587>`__ paper.
.. betastatus:: segmentation module
Model builders
--------------
The following model builders can be used to instantiate a DeepLabV3 model with
different backbones, with or without pre-trained weights. All the model builders
internally rely on the ``torchvision.models.segmentation.deeplabv3.DeepLabV3`` base class. Please
refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/segmentation/deeplabv3.py>`_
for more details about this class.
.. autosummary::
:toctree: generated/
:template: function.rst
deeplabv3_mobilenet_v3_large
deeplabv3_resnet50
deeplabv3_resnet101
DenseNet
========
.. currentmodule:: torchvision.models
The DenseNet model is based on the `Densely Connected Convolutional Networks
<https://arxiv.org/abs/1608.06993>`_ paper.
Model builders
--------------
The following model builders can be used to instantiate a DenseNet model, with or
without pre-trained weights. All the model builders internally rely on the
``torchvision.models.densenet.DenseNet`` base class. Please refer to the `source
code
<https://github.com/pytorch/vision/blob/main/torchvision/models/densenet.py>`_ for
more details about this class.
.. autosummary::
:toctree: generated/
:template: function.rst
densenet121
densenet161
densenet169
densenet201
EfficientNet
============
.. currentmodule:: torchvision.models
The EfficientNet model is based on the `EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks <https://arxiv.org/abs/1905.11946>`__
paper.
Model builders
--------------
The following model builders can be used to instantiate an EfficientNet model, with or
without pre-trained weights. All the model builders internally rely on the
``torchvision.models.efficientnet.EfficientNet`` base class. Please refer to the `source
code
<https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_ for
more details about this class.
.. autosummary::
:toctree: generated/
:template: function.rst
efficientnet_b0
efficientnet_b1
efficientnet_b2
efficientnet_b3
efficientnet_b4
efficientnet_b5
efficientnet_b6
efficientnet_b7