"git@developer.sourcefind.cn:orangecat/ollama.git" did not exist on "f8241bfba384cf8c888847dc44b73d7f43a42d82"
Commit bf491463 authored by limm's avatar limm
Browse files

add v0.19.1 release

parent e17f5ea2
@@ -24,7 +24,7 @@ docset: html
	convert $(SPHINXPROJ).docset/icon@2x.png -resize 16x16 $(SPHINXPROJ).docset/icon.png

html-noplot: # Avoids running the gallery examples, which may take time
	$(SPHINXBUILD) -D plot_gallery=0 -b html "${SOURCEDIR}" "$(BUILDDIR)"/html
	@echo
	@echo "Build finished. The HTML pages are in $(BUILDDIR)/html."
@@ -32,6 +32,8 @@ clean:
	rm -rf $(BUILDDIR)/*
	rm -rf $(SOURCEDIR)/auto_examples/ # sphinx-gallery
	rm -rf $(SOURCEDIR)/gen_modules/ # sphinx-gallery
	rm -rf $(SOURCEDIR)/generated/ # autosummary
	rm -rf $(SOURCEDIR)/models/generated # autosummary

.PHONY: help Makefile docset
matplotlib
numpy
sphinx-copybutton>=0.3.1
sphinx-gallery>=0.11.1
sphinx==5.0.0
tabulate
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
pycocotools
/* This rule should be removed once
https://github.com/pytorch/pytorch_sphinx_theme/issues/125 is fixed.
We override the rule so that the links to the notebooks aren't hidden in the
@@ -9,4 +9,27 @@ torchvision it just hides the links. So we have to put them back here */
article.pytorch-article .sphx-glr-download-link-note.admonition.note,
article.pytorch-article .reference.download.internal, article.pytorch-article .sphx-glr-signature {
    display: block;
}
/* These 2 rules below are for the weight tables (generated in conf.py) to look
* better. In particular we make their row height shorter */
.table-weights td, .table-weights th {
margin-bottom: 0.2rem;
padding: 0 !important;
line-height: 1 !important;
}
.table-weights p {
margin-bottom: 0.2rem !important;
}
/* Fix for Sphinx gallery 0.11
See https://github.com/sphinx-gallery/sphinx-gallery/issues/990
*/
article.pytorch-article .sphx-glr-thumbnails .sphx-glr-thumbcontainer {
width: unset;
margin-right: 0;
margin-left: 0;
}
article.pytorch-article div.section div.wy-table-responsive tbody td {
width: 50%;
}
@@ -30,4 +30,4 @@
     style="fill:#9e529f"
     id="path4698"
     d="m 24.075479,-7.6293945e-7 c -0.5,0 -1.8,2.49999996293945 -1.8,3.59999996293945 0,1.5 1,2 1.8,2 0.8,0 1.8,-0.5 1.8,-2 -0.1,-1.1 -1.4,-3.59999996293945 -1.8,-3.59999996293945 z"
     class="st1" /></svg>
.. role:: hidden
:class: hidden-section
.. currentmodule:: {{ module }}
{{ name | underline}}
.. autoclass:: {{ name }}
:members:
.. role:: hidden
:class: hidden-section
.. currentmodule:: {{ module }}
{{ name | underline}}
.. autoclass:: {{ name }}
:members:
__getitem__,
{% if "category_name" in methods %} category_name {% endif %}
:special-members:
.. role:: hidden
:class: hidden-section
.. currentmodule:: {{ module }}
{{ name | underline}}
.. autofunction:: {{ name }}
from docutils import nodes
from docutils.parsers.rst import Directive


class BetaStatus(Directive):
    has_content = True
    text = "The {api_name} is in Beta stage, and backward compatibility is not guaranteed."
    node = nodes.warning

    def run(self):
        text = self.text.format(api_name=" ".join(self.content))
        return [self.node("", nodes.paragraph("", "", nodes.Text(text)))]


def setup(app):
    app.add_directive("betastatus", BetaStatus)
    return {
        "version": "0.1",
        "parallel_read_safe": True,
        "parallel_write_safe": True,
    }
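# --- Illustration (not part of the commit) ----------------------------------
# A minimal sketch of how the directive above expands, assuming beta_status.py
# is importable from the current directory. It registers the directive with
# plain docutils and renders a one-line RST snippet.
from docutils.core import publish_string
from docutils.parsers.rst import directives

from beta_status import BetaStatus  # assumption: docs/source is on sys.path

directives.register_directive("betastatus", BetaStatus)
rst_source = ".. betastatus:: fine-grained video API\n"
print(publish_string(rst_source, writer_name="pseudoxml").decode())
# Expected output (abridged): a <warning> node wrapping the sentence
# "The fine-grained video API is in Beta stage, and backward compatibility is not guaranteed."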
#!/usr/bin/env python3
#
# PyTorch documentation build configuration file, created by
# sphinx-quickstart on Fri Dec 23 13:31:47 2016.
@@ -21,79 +20,146 @@
# import sys
# sys.path.insert(0, os.path.abspath('.'))

import os
import sys
import textwrap
from copy import copy
from pathlib import Path

import pytorch_sphinx_theme
import torchvision
import torchvision.models as M
from sphinx_gallery.sorting import ExplicitOrder
from tabulate import tabulate

sys.path.append(os.path.abspath("."))

# -- General configuration ------------------------------------------------

# Required version of sphinx is set from docs/requirements.txt

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
    "sphinx.ext.autodoc",
    "sphinx.ext.autosummary",
    "sphinx.ext.doctest",
    "sphinx.ext.intersphinx",
    "sphinx.ext.todo",
    "sphinx.ext.mathjax",
    "sphinx.ext.napoleon",
    "sphinx.ext.viewcode",
    "sphinx.ext.duration",
    "sphinx_gallery.gen_gallery",
    "sphinx_copybutton",
    "beta_status",
]
# We override sphinx-gallery's example header to prevent sphinx-gallery from
# creating a note at the top of the rendered notebook.
# https://github.com/sphinx-gallery/sphinx-gallery/blob/451ccba1007cc523f39cbcc960ebc21ca39f7b75/sphinx_gallery/gen_rst.py#L1267-L1271
# This is because we also want to add a link to Google Colab, so we write our own note in each example.
from sphinx_gallery import gen_rst
gen_rst.EXAMPLE_HEADER = """
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "{0}"
.. LINE NUMBERS ARE GIVEN BELOW.
.. rst-class:: sphx-glr-example-title
.. _sphx_glr_{1}:
"""
class CustomGalleryExampleSortKey:
    # See https://sphinx-gallery.github.io/stable/configuration.html#sorting-gallery-examples
    # and https://github.com/sphinx-gallery/sphinx-gallery/blob/master/sphinx_gallery/sorting.py

    def __init__(self, src_dir):
        self.src_dir = src_dir

    transforms_subsection_order = [
        "plot_transforms_getting_started.py",
        "plot_transforms_illustrations.py",
        "plot_transforms_e2e.py",
        "plot_cutmix_mixup.py",
        "plot_custom_transforms.py",
        "plot_tv_tensors.py",
        "plot_custom_tv_tensors.py",
    ]

    def __call__(self, filename):
        if "gallery/transforms" in self.src_dir:
            try:
                return self.transforms_subsection_order.index(filename)
            except ValueError as e:
                raise ValueError(
                    "Looks like you added an example in gallery/transforms? "
                    "You need to specify its order in docs/source/conf.py. Look for CustomGalleryExampleSortKey."
                ) from e
        else:
            # For other subsections we just sort alphabetically by filename
            return filename
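# --- Illustration (not part of conf.py) --------------------------------------
# Quick sanity check of the sort key above; the directory and file names below
# are hypothetical. sphinx-gallery instantiates the key with the subsection
# source dir and calls it once per example file to obtain a sort value.
_key = CustomGalleryExampleSortKey("../../gallery/transforms")
assert _key("plot_cutmix_mixup.py") == 3  # position in transforms_subsection_order
_other = CustomGalleryExampleSortKey("../../gallery/others")
assert _other("plot_video_api.py") == "plot_video_api.py"  # falls back to alphabetical sorting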
sphinx_gallery_conf = {
    "examples_dirs": "../../gallery/",  # path to your example scripts
    "gallery_dirs": "auto_examples",  # path to where to save gallery generated output
    "subsection_order": ExplicitOrder(["../../gallery/transforms", "../../gallery/others"]),
    "backreferences_dir": "gen_modules/backreferences",
    "doc_module": ("torchvision",),
    "remove_config_comments": True,
    "ignore_pattern": "helpers.py",
    "within_subsection_order": CustomGalleryExampleSortKey,
}
napoleon_use_ivar = True
napoleon_numpy_docstring = False
napoleon_google_docstring = True

# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]

# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
source_suffix = {
    ".rst": "restructuredtext",
}

# The master toctree document.
master_doc = "index"

# General information about the project.
project = "Torchvision"
copyright = "2017-present, Torch Contributors"
author = "Torch Contributors"

# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
# built documents.
# version: The short X.Y version.
# release: The full version, including alpha/beta/rc tags.
if os.environ.get("TORCHVISION_SANITIZE_VERSION_STR_IN_DOCS", None):
    # Turn 1.11.0aHASH into 1.11 (major.minor only)
    version = release = ".".join(torchvision.__version__.split(".")[:2])
    html_title = " ".join((project, version, "documentation"))
else:
    version = f"main ({torchvision.__version__})"
    release = "main"

# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
language = "en"
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
@@ -101,7 +167,7 @@ language = None
exclude_patterns = []

# The name of the Pygments (syntax highlighting) style to use.
pygments_style = "sphinx"

# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = True
@@ -112,7 +178,7 @@ todo_include_todos = True

# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = "pytorch_sphinx_theme"
html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()]

# Theme options are theme-specific and customize the look and feel of a theme
@@ -120,58 +186,57 @@ html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()]
# documentation.
#
html_theme_options = {
    "collapse_navigation": False,
    "display_version": True,
    "logo_only": True,
    "pytorch_project": "docs",
    "navigation_with_keys": True,
    "analytics_id": "GTM-T8XT4PS",
}

html_logo = "_static/img/pytorch-logo-dark.svg"

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ["_static"]

# TODO: remove this once https://github.com/pytorch/pytorch_sphinx_theme/issues/125 is fixed
html_css_files = [
    "css/custom_torchvision.css",
]

# -- Options for HTMLHelp output ------------------------------------------

# Output file base name for HTML help builder.
htmlhelp_basename = "PyTorchdoc"

autosummary_generate = True

# -- Options for LaTeX output ---------------------------------------------
latex_elements = {
    # The paper size ('letterpaper' or 'a4paper').
    #
    # 'papersize': 'letterpaper',
    # The font size ('10pt', '11pt' or '12pt').
    #
    # 'pointsize': '10pt',
    # Additional stuff for the LaTeX preamble.
    #
    # 'preamble': '',
    # Latex figure (float) alignment
    #
    # 'figure_align': 'htbp',
}

# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
#  author, documentclass [howto, manual, or own class]).
latex_documents = [
    (master_doc, "pytorch.tex", "torchvision Documentation", "Torch Contributors", "manual"),
]
@@ -179,10 +244,7 @@ latex_documents = [

# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [(master_doc, "torchvision", "torchvision Documentation", [author], 1)]

# -- Options for Texinfo output -------------------------------------------
@@ -191,27 +253,33 @@ man_pages = [
# (source start file, target name, title, author,
#  dir menu entry, description, category)
texinfo_documents = [
    (
        master_doc,
        "torchvision",
        "torchvision Documentation",
        author,
        "torchvision",
        "One line description of project.",
        "Miscellaneous",
    ),
]

# Example configuration for intersphinx: refer to the Python standard library.
intersphinx_mapping = {
    "python": ("https://docs.python.org/3/", None),
    "torch": ("https://pytorch.org/docs/stable/", None),
    "numpy": ("https://numpy.org/doc/stable/", None),
    "PIL": ("https://pillow.readthedocs.io/en/stable/", None),
    "matplotlib": ("https://matplotlib.org/stable/", None),
}

# -- A patch that prevents Sphinx from cross-referencing ivar tags -------
# See http://stackoverflow.com/a/41184353/3343043

from docutils import nodes
from sphinx import addnodes
from sphinx.util.docfields import TypedField
def patched_make_field(self, types, domain, items, **kw):
@@ -221,40 +289,39 @@ def patched_make_field(self, types, domain, items, **kw):
    # type: (list, unicode, tuple) -> nodes.field  # noqa: F821
    def handle_item(fieldarg, content):
        par = nodes.paragraph()
        par += addnodes.literal_strong("", fieldarg)  # Patch: this line added
        # par.extend(self.make_xrefs(self.rolename, domain, fieldarg,
        #                            addnodes.literal_strong))
        if fieldarg in types:
            par += nodes.Text(" (")
            # NOTE: using .pop() here to prevent a single type node to be
            # inserted twice into the doctree, which leads to
            # inconsistencies later when references are resolved
            fieldtype = types.pop(fieldarg)
            if len(fieldtype) == 1 and isinstance(fieldtype[0], nodes.Text):
                typename = "".join(n.astext() for n in fieldtype)
                typename = typename.replace("int", "python:int")
                typename = typename.replace("long", "python:long")
                typename = typename.replace("float", "python:float")
                typename = typename.replace("type", "python:type")
                par.extend(self.make_xrefs(self.typerolename, domain, typename, addnodes.literal_emphasis, **kw))
            else:
                par += fieldtype
            par += nodes.Text(")")
        par += nodes.Text(" -- ")
        par += content
        return par

    fieldname = nodes.field_name("", self.label)
    if len(items) == 1 and self.can_collapse:
        fieldarg, content = items[0]
        bodynode = handle_item(fieldarg, content)
    else:
        bodynode = self.list_type()
        for fieldarg, content in items:
            bodynode += nodes.list_item("", handle_item(fieldarg, content))
    fieldbody = nodes.field_body("", bodynode)
    return nodes.field("", fieldname, fieldbody)


TypedField.make_field = patched_make_field
@@ -286,5 +353,172 @@ def inject_minigalleries(app, what, name, obj, options, lines):
        lines.append("\n")
def inject_weight_metadata(app, what, name, obj, options, lines):
"""This hook is used to generate docs for the models weights.
Objects like ResNet18_Weights are enums with fields, where each field is a Weight object.
Enums aren't easily documented in Python so the solution we're going for is to:
- add an autoclass directive in the model's builder docstring, e.g.
```
.. autoclass:: torchvision.models.ResNet34_Weights
:members:
```
(see resnet.py for an example)
- then this hook is called automatically when building the docs, and it generates the text that gets
used within the autoclass directive.
"""
if getattr(obj, "__name__", "").endswith(("_Weights", "_QuantizedWeights")):
if len(obj) == 0:
lines[:] = ["There are no available pre-trained weights."]
return
lines[:] = [
"The model builder above accepts the following values as the ``weights`` parameter.",
f"``{obj.__name__}.DEFAULT`` is equivalent to ``{obj.DEFAULT}``. You can also use strings, e.g. "
f"``weights='DEFAULT'`` or ``weights='{str(list(obj)[0]).split('.')[1]}'``.",
]
if obj.__doc__ != "An enumeration.":
# We only show the custom enum doc if it was overridden. The default one from Python is "An enumeration"
lines.append("")
lines.append(obj.__doc__)
lines.append("")
for field in obj:
meta = copy(field.meta)
lines += [f"**{str(field)}**:", ""]
lines += [meta.pop("_docs")]
if field == obj.DEFAULT:
lines += [f"Also available as ``{obj.__name__}.DEFAULT``."]
lines += [""]
table = []
metrics = meta.pop("_metrics")
for dataset, dataset_metrics in metrics.items():
for metric_name, metric_value in dataset_metrics.items():
table.append((f"{metric_name} (on {dataset})", str(metric_value)))
for k, v in meta.items():
if k in {"recipe", "license"}:
v = f"`link <{v}>`__"
elif k == "min_size":
v = f"height={v[0]}, width={v[1]}"
elif k in {"categories", "keypoint_names"} and isinstance(v, list):
max_visible = 3
v_sample = ", ".join(v[:max_visible])
v = f"{v_sample}, ... ({len(v)-max_visible} omitted)" if len(v) > max_visible else v_sample
elif k == "_ops":
v = f"{v:.2f}"
k = "GIPS" if obj.__name__.endswith("_QuantizedWeights") else "GFLOPS"
elif k == "_file_size":
k = "File size"
v = f"{v:.1f} MB"
table.append((str(k), str(v)))
table = tabulate(table, tablefmt="rst")
lines += [".. rst-class:: table-weights"] # Custom CSS class, see custom_torchvision.css
lines += [".. table::", ""]
lines += textwrap.indent(table, " " * 4).split("\n")
lines.append("")
lines.append(
f"The inference transforms are available at ``{str(field)}.transforms`` and "
f"perform the following preprocessing operations: {field.transforms().describe()}"
)
lines.append("")
def generate_weights_table(module, table_name, metrics, dataset, include_patterns=None, exclude_patterns=None):
weights_endswith = "_QuantizedWeights" if module.__name__.split(".")[-1] == "quantization" else "_Weights"
weight_enums = [getattr(module, name) for name in dir(module) if name.endswith(weights_endswith)]
weights = [w for weight_enum in weight_enums for w in weight_enum]
if include_patterns is not None:
weights = [w for w in weights if any(p in str(w) for p in include_patterns)]
if exclude_patterns is not None:
weights = [w for w in weights if all(p not in str(w) for p in exclude_patterns)]
ops_name = "GIPS" if "QuantizedWeights" in weights_endswith else "GFLOPS"
metrics_keys, metrics_names = zip(*metrics)
column_names = ["Weight"] + list(metrics_names) + ["Params"] + [ops_name, "Recipe"] # Final column order
column_names = [f"**{name}**" for name in column_names] # Add bold
content = []
for w in weights:
row = [
f":class:`{w} <{type(w).__name__}>`",
*(w.meta["_metrics"][dataset][metric] for metric in metrics_keys),
f"{w.meta['num_params']/1e6:.1f}M",
f"{w.meta['_ops']:.2f}",
f"`link <{w.meta['recipe']}>`__",
]
content.append(row)
column_widths = ["110"] + ["18"] * len(metrics_names) + ["18"] * 2 + ["10"]
widths_table = " ".join(column_widths)
table = tabulate(content, headers=column_names, tablefmt="rst")
generated_dir = Path("generated")
generated_dir.mkdir(exist_ok=True)
with open(generated_dir / f"{table_name}_table.rst", "w+") as table_file:
table_file.write(".. rst-class:: table-weights\n") # Custom CSS class, see custom_torchvision.css
table_file.write(".. table::\n")
table_file.write(f" :widths: {widths_table} \n\n")
table_file.write(f"{textwrap.indent(table, ' ' * 4)}\n\n")
generate_weights_table(
module=M, table_name="classification", metrics=[("acc@1", "Acc@1"), ("acc@5", "Acc@5")], dataset="ImageNet-1K"
)
generate_weights_table(
module=M.quantization,
table_name="classification_quant",
metrics=[("acc@1", "Acc@1"), ("acc@5", "Acc@5")],
dataset="ImageNet-1K",
)
generate_weights_table(
module=M.detection,
table_name="detection",
metrics=[("box_map", "Box MAP")],
exclude_patterns=["Mask", "Keypoint"],
dataset="COCO-val2017",
)
generate_weights_table(
module=M.detection,
table_name="instance_segmentation",
metrics=[("box_map", "Box MAP"), ("mask_map", "Mask MAP")],
dataset="COCO-val2017",
include_patterns=["Mask"],
)
generate_weights_table(
module=M.detection,
table_name="detection_keypoint",
metrics=[("box_map", "Box MAP"), ("kp_map", "Keypoint MAP")],
dataset="COCO-val2017",
include_patterns=["Keypoint"],
)
generate_weights_table(
module=M.segmentation,
table_name="segmentation",
metrics=[("miou", "Mean IoU"), ("pixel_acc", "pixelwise Acc")],
dataset="COCO-val2017-VOC-labels",
)
generate_weights_table(
module=M.video, table_name="video", metrics=[("acc@1", "Acc@1"), ("acc@5", "Acc@5")], dataset="Kinetics-400"
)
def setup(app):
    app.connect("autodoc-process-docstring", inject_minigalleries)
    app.connect("autodoc-process-docstring", inject_weight_metadata)
.. _datasets:

Datasets
========

Torchvision provides many built-in datasets in the ``torchvision.datasets``
module, as well as utility classes for building your own datasets.

Built-in datasets
-----------------

All datasets are subclasses of :class:`torch.utils.data.Dataset`
i.e, they have ``__getitem__`` and ``__len__`` methods implemented.
@@ -19,242 +27,157 @@ All the datasets have almost similar API. They all have two common arguments:
``transform`` and ``target_transform`` to transform the input and target respectively.
You can also create your own datasets using the provided :ref:`base classes <base_classes_datasets>`.
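As a quick illustration (a minimal sketch, not part of the upstream page; the ``root`` path and batch size are arbitrary), any built-in dataset can be passed straight to a :class:`torch.utils.data.DataLoader`:

.. code:: python

    import torch
    from torchvision import datasets, transforms

    # Any built-in dataset follows the same pattern; CIFAR10 is used as an example.
    dataset = datasets.CIFAR10(root="data/", train=True, download=True,
                               transform=transforms.ToTensor())
    loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)
    images, labels = next(iter(loader))  # images: (64, 3, 32, 32), labels: (64,)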
Image classification
~~~~~~~~~~~~~~~~~~~~

.. autosummary::
    :toctree: generated/
    :template: class_dataset.rst

    Caltech101
    Caltech256
    CelebA
    CIFAR10
    CIFAR100
    Country211
    DTD
    EMNIST
    EuroSAT
    FakeData
    FashionMNIST
    FER2013
    FGVCAircraft
    Flickr8k
    Flickr30k
    Flowers102
    Food101
    GTSRB
    INaturalist
    ImageNet
    Imagenette
    KMNIST
    LFWPeople
    LSUN
    MNIST
    Omniglot
    OxfordIIITPet
    Places365
    PCAM
    QMNIST
    RenderedSST2
    SEMEION
    SBU
    StanfordCars
    STL10
    SUN397
    SVHN
    USPS

Image detection or segmentation
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autosummary::
    :toctree: generated/
    :template: class_dataset.rst

    CocoDetection
    CelebA
    Cityscapes
    Kitti
    OxfordIIITPet
    SBDataset
    VOCSegmentation
    VOCDetection
    WIDERFace

Optical Flow
~~~~~~~~~~~~

.. autosummary::
    :toctree: generated/
    :template: class_dataset.rst

    FlyingChairs
    FlyingThings3D
    HD1K
    KittiFlow
    Sintel

Stereo Matching
~~~~~~~~~~~~~~~

.. autosummary::
    :toctree: generated/
    :template: class_dataset.rst

    CarlaStereo
    Kitti2012Stereo
    Kitti2015Stereo
    CREStereo
    FallingThingsStereo
    SceneFlowStereo
    SintelStereo
    InStereo2k
    ETH3DStereo
    Middlebury2014Stereo

Image pairs
~~~~~~~~~~~

.. autosummary::
    :toctree: generated/
    :template: class_dataset.rst

    LFWPairs
    PhotoTour

Image captioning
~~~~~~~~~~~~~~~~

.. autosummary::
    :toctree: generated/
    :template: class_dataset.rst

    CocoCaptions

Video classification
~~~~~~~~~~~~~~~~~~~~

.. autosummary::
    :toctree: generated/
    :template: class_dataset.rst

    HMDB51
    Kinetics
    UCF101

Video prediction
~~~~~~~~~~~~~~~~~~~~

.. autosummary::
    :toctree: generated/
    :template: class_dataset.rst

    MovingMNIST

.. _base_classes_datasets:

Base classes for custom datasets
--------------------------------

.. autosummary::
    :toctree: generated/
    :template: class.rst

    DatasetFolder
    ImageFolder
    VisionDataset

Transforms v2
-------------

.. autosummary::
    :toctree: generated/
    :template: function.rst

    wrap_dataset_for_transforms_v2
# Necessary for the table generated by autosummary to look decent
[html writers]
table_style: colwidths-auto
Feature extraction for model inspection
=======================================
.. currentmodule:: torchvision.models.feature_extraction
The ``torchvision.models.feature_extraction`` package contains
feature extraction utilities that let us tap into our models to access intermediate
transformations of our inputs. This could be useful for a variety of
applications in computer vision. Just a few examples are:
- Visualizing feature maps.
- Extracting features to compute image descriptors for tasks like facial
recognition, copy-detection, or image retrieval.
- Passing selected features to downstream sub-networks for end-to-end training
with a specific task in mind. For example, passing a hierarchy of features
to a Feature Pyramid Network with object detection heads.
Torchvision provides :func:`create_feature_extractor` for this purpose.
It works by following roughly these steps:
1. Symbolically tracing the model to get a graphical representation of
how it transforms the input, step by step.
2. Setting the user-selected graph nodes as outputs.
3. Removing all redundant nodes (anything downstream of the output nodes).
4. Generating python code from the resulting graph and bundling that into a
PyTorch module together with the graph itself.
|
The `torch.fx documentation <https://pytorch.org/docs/stable/fx.html>`_
provides a more general and detailed explanation of the above procedure and
the inner workings of the symbolic tracing.
.. _about-node-names:
**About Node Names**
In order to specify which nodes should be output nodes for extracted
features, one should be familiar with the node naming convention used here
(which differs slightly from that used in ``torch.fx``). A node name is
specified as a ``.`` separated path walking the module hierarchy from top level
module down to leaf operation or leaf module. For instance ``"layer4.2.relu"``
in ResNet-50 represents the output of the ReLU of the 2nd block of the 4th
layer of the ``ResNet`` module. Here are some finer points to keep in mind:
- When specifying node names for :func:`create_feature_extractor`, you may
provide a truncated version of a node name as a shortcut. To see how this
works, try creating a ResNet-50 model and printing the node names with
``train_nodes, _ = get_graph_node_names(model) print(train_nodes)`` and
observe that the last node pertaining to ``layer4`` is
``"layer4.2.relu_2"``. One may specify ``"layer4.2.relu_2"`` as the return
node, or just ``"layer4"`` as this, by convention, refers to the last node
(in order of execution) of ``layer4``.
- If a certain module or operation is repeated more than once, node names get
an additional ``_{int}`` postfix to disambiguate. For instance, maybe the
addition (``+``) operation is used three times in the same ``forward``
method. Then there would be ``"path.to.module.add"``,
``"path.to.module.add_1"``, ``"path.to.module.add_2"``. The counter is
maintained within the scope of the direct parent. So in ResNet-50 there is
a ``"layer4.1.add"`` and a ``"layer4.2.add"``. Because the addition
operations reside in different blocks, there is no need for a postfix to
disambiguate.
**An Example**
Here is an example of how we might extract features for MaskRCNN:
.. code-block:: python
import torch
from torchvision.models import resnet50
from torchvision.models.feature_extraction import get_graph_node_names
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.models.detection.mask_rcnn import MaskRCNN
from torchvision.models.detection.backbone_utils import LastLevelMaxPool
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork
# To assist you in designing the feature extractor you may want to print out
# the available nodes for resnet50.
m = resnet50()
train_nodes, eval_nodes = get_graph_node_names(resnet50())
# The lists returned, are the names of all the graph nodes (in order of
# execution) for the input model traced in train mode and in eval mode
# respectively. You'll find that `train_nodes` and `eval_nodes` are the same
# for this example. But if the model contains control flow that's dependent
# on the training mode, they may be different.
# To specify the nodes you want to extract, you could select the final node
# that appears in each of the main layers:
return_nodes = {
# node_name: user-specified key for output dict
'layer1.2.relu_2': 'layer1',
'layer2.3.relu_2': 'layer2',
'layer3.5.relu_2': 'layer3',
'layer4.2.relu_2': 'layer4',
}
# But `create_feature_extractor` can also accept truncated node specifications
# like "layer1", as it will just pick the last node that's a descendent of
# of the specification. (Tip: be careful with this, especially when a layer
# has multiple outputs. It's not always guaranteed that the last operation
# performed is the one that corresponds to the output you desire. You should
# consult the source code for the input model to confirm.)
return_nodes = {
'layer1': 'layer1',
'layer2': 'layer2',
'layer3': 'layer3',
'layer4': 'layer4',
}
# Now you can build the feature extractor. This returns a module whose forward
# method returns a dictionary like:
# {
# 'layer1': output of layer 1,
# 'layer2': output of layer 2,
# 'layer3': output of layer 3,
# 'layer4': output of layer 4,
# }
create_feature_extractor(m, return_nodes=return_nodes)
# Let's put all that together to wrap resnet50 with MaskRCNN
# MaskRCNN requires a backbone with an attached FPN
class Resnet50WithFPN(torch.nn.Module):
def __init__(self):
super(Resnet50WithFPN, self).__init__()
# Get a resnet50 backbone
m = resnet50()
# Extract 4 main layers (note: MaskRCNN needs this particular name
# mapping for return nodes)
self.body = create_feature_extractor(
m, return_nodes={f'layer{k}': str(v)
for v, k in enumerate([1, 2, 3, 4])})
# Dry run to get number of channels for FPN
inp = torch.randn(2, 3, 224, 224)
with torch.no_grad():
out = self.body(inp)
in_channels_list = [o.shape[1] for o in out.values()]
# Build FPN
self.out_channels = 256
self.fpn = FeaturePyramidNetwork(
in_channels_list, out_channels=self.out_channels,
extra_blocks=LastLevelMaxPool())
def forward(self, x):
x = self.body(x)
x = self.fpn(x)
return x
# Now we can build our model!
model = MaskRCNN(Resnet50WithFPN(), num_classes=91).eval()
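As a follow-up sketch (not from the original docs), the wrapped model above can be smoke-tested on random inputs; the output keys are the standard eval-mode detection results of ``MaskRCNN``:

.. code-block:: python

    # Run the MaskRCNN built above on a dummy batch (eval mode, no gradients).
    images = [torch.rand(3, 320, 320), torch.rand(3, 480, 640)]
    with torch.no_grad():
        predictions = model(images)

    # One dict per input image with detection results.
    print(predictions[0].keys())  # dict_keys(['boxes', 'labels', 'scores', 'masks'])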
API Reference
-------------
.. autosummary::
:toctree: generated/
:template: function.rst
create_feature_extractor
get_graph_node_names
@@ -31,18 +31,21 @@ architectures, and common image transformations for computer vision.
   :maxdepth: 2
   :caption: Package Reference

   transforms
   tv_tensors
   models
   datasets
   utils
   ops
   io
   feature_extraction

.. toctree::
   :maxdepth: 1
   :caption: Examples and training references

   auto_examples/index
   training_references

.. automodule:: torchvision
   :members:
@@ -58,3 +61,9 @@ architectures, and common image transformations for computer vision.
   TorchElastic <https://pytorch.org/elastic/>
   TorchServe <https://pytorch.org/serve>
   PyTorch on XLA Devices <http://pytorch.org/xla/>
Indices
-------
* :ref:`genindex`
Decoding / Encoding images and videos
=====================================

.. currentmodule:: torchvision.io

The :mod:`torchvision.io` package provides functions for performing IO
operations. They are currently specific to reading and writing images and
videos.

Images
------

.. autosummary::
    :toctree: generated/
    :template: function.rst

    read_image
    decode_image
    encode_jpeg
    decode_jpeg
    write_jpeg
    decode_gif
    encode_png
    decode_png
    write_png
    read_file
    write_file

.. autosummary::
    :toctree: generated/
    :template: class.rst

    ImageReadMode

Video
-----

.. autosummary::
    :toctree: generated/
    :template: function.rst

    read_video
    read_video_timestamps
    write_video

Fine-grained video API
^^^^^^^^^^^^^^^^^^^^^^

In addition to the :mod:`read_video` function, we provide a high-performance
lower-level API for more fine-grained control compared to the :mod:`read_video` function.
It does all this whilst fully supporting torchscript.

.. betastatus:: fine-grained video API

.. autosummary::
    :toctree: generated/
    :template: class.rst

    VideoReader

Example of inspecting a video:
@@ -54,29 +88,3 @@ Example of inspecting a video:
    # the constructor we select a default video stream, but
    # in practice, we can set whichever stream we would like
    video.set_current_stream("video:0")
torchvision.models .. _models:
##################
Models and pre-trained weights
##############################
The models subpackage contains definitions of models for addressing The ``torchvision.models`` subpackage contains definitions of models for addressing
different tasks, including: image classification, pixelwise semantic different tasks, including: image classification, pixelwise semantic
segmentation, object detection, instance segmentation, person segmentation, object detection, instance segmentation, person
keypoint detection and video classification. keypoint detection, video classification, and optical flow.
General information on pre-trained weights
==========================================
Classification TorchVision offers pre-trained weights for every provided architecture, using
============== the PyTorch :mod:`torch.hub`. Instancing a pre-trained model will download its
weights to a cache directory. This directory can be set using the `TORCH_HOME`
environment variable. See :func:`torch.hub.load_state_dict_from_url` for details.
.. note::
The pre-trained models provided in this library may have their own licenses or
terms and conditions derived from the dataset used for training. It is your
responsibility to determine whether you have permission to use the models for
your use case.
.. note ::
Backward compatibility is guaranteed for loading a serialized
``state_dict`` to the model created using old PyTorch version.
On the contrary, loading entire saved models or serialized
``ScriptModules`` (serialized using older versions of PyTorch)
may not preserve the historic behaviour. Refer to the following
`documentation
<https://pytorch.org/docs/stable/notes/serialization.html#id6>`_
The models subpackage contains definitions for the following model
architectures for image classification: Initializing pre-trained models
-------------------------------
- `AlexNet`_
- `VGG`_ As of v0.13, TorchVision offers a new `Multi-weight support API
- `ResNet`_ <https://pytorch.org/blog/introducing-torchvision-new-multi-weight-support-api/>`_
- `SqueezeNet`_ for loading different weights to the existing model builder methods:
- `DenseNet`_
- `Inception`_ v3
- `GoogLeNet`_
- `ShuffleNet`_ v2
- `MobileNetV2`_
- `MobileNetV3`_
- `ResNeXt`_
- `Wide ResNet`_
- `MNASNet`_
You can construct a model with random weights by calling its constructor:
.. code:: python .. code:: python
import torchvision.models as models from torchvision.models import resnet50, ResNet50_Weights
resnet18 = models.resnet18()
alexnet = models.alexnet() # Old weights with accuracy 76.130%
vgg16 = models.vgg16() resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
squeezenet = models.squeezenet1_0()
densenet = models.densenet161() # New weights with accuracy 80.858%
inception = models.inception_v3() resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
googlenet = models.googlenet()
shufflenet = models.shufflenet_v2_x1_0() # Best available weights (currently alias for IMAGENET1K_V2)
mobilenet_v2 = models.mobilenet_v2() # Note that these weights may change across versions
mobilenet_v3_large = models.mobilenet_v3_large() resnet50(weights=ResNet50_Weights.DEFAULT)
mobilenet_v3_small = models.mobilenet_v3_small()
resnext50_32x4d = models.resnext50_32x4d() # Strings are also supported
wide_resnet50_2 = models.wide_resnet50_2() resnet50(weights="IMAGENET1K_V2")
mnasnet = models.mnasnet1_0()
# No weights - random initialization
We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`. resnet50(weights=None)
These can be constructed by passing ``pretrained=True``:
Migrating to the new API is very straightforward. The following method calls between the 2 APIs are all equivalent:
.. code:: python .. code:: python
import torchvision.models as models from torchvision.models import resnet50, ResNet50_Weights
resnet18 = models.resnet18(pretrained=True)
alexnet = models.alexnet(pretrained=True) # Using pretrained weights:
squeezenet = models.squeezenet1_0(pretrained=True) resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
vgg16 = models.vgg16(pretrained=True) resnet50(weights="IMAGENET1K_V1")
densenet = models.densenet161(pretrained=True) resnet50(pretrained=True) # deprecated
inception = models.inception_v3(pretrained=True) resnet50(True) # deprecated
googlenet = models.googlenet(pretrained=True)
shufflenet = models.shufflenet_v2_x1_0(pretrained=True) # Using no weights:
mobilenet_v2 = models.mobilenet_v2(pretrained=True) resnet50(weights=None)
mobilenet_v3_large = models.mobilenet_v3_large(pretrained=True) resnet50()
mobilenet_v3_small = models.mobilenet_v3_small(pretrained=True) resnet50(pretrained=False) # deprecated
resnext50_32x4d = models.resnext50_32x4d(pretrained=True) resnet50(False) # deprecated
wide_resnet50_2 = models.wide_resnet50_2(pretrained=True)
mnasnet = models.mnasnet1_0(pretrained=True) Note that the ``pretrained`` parameter is now deprecated, using it will emit warnings and will be removed on v0.15.
Instancing a pre-trained model will download its weights to a cache directory. Using the pre-trained models
This directory can be set using the `TORCH_MODEL_ZOO` environment variable. See ----------------------------
:func:`torch.utils.model_zoo.load_url` for details.
Before using the pre-trained models, one must preprocess the image
(resize with right resolution/interpolation, apply inference transforms,
rescale the values etc). There is no standard way to do this as it depends on
how a given model was trained. It can vary across model families, variants or
even weight versions. Using the correct preprocessing method is critical and
failing to do so may lead to decreased accuracy or incorrect outputs.
All the necessary information for the inference transforms of each pre-trained
model is provided on its weights documentation. To simplify inference, TorchVision
bundles the necessary preprocessing transforms into each model weight. These are
accessible via the ``weight.transforms`` attribute:
.. code:: python
# Initialize the Weight Transforms
weights = ResNet50_Weights.DEFAULT
preprocess = weights.transforms()
# Apply it to the input image
img_transformed = preprocess(img)
Some models use modules which have different training and evaluation Some models use modules which have different training and evaluation
behavior, such as batch normalization. To switch between these modes, use behavior, such as batch normalization. To switch between these modes, use
``model.train()`` or ``model.eval()`` as appropriate. See ``model.train()`` or ``model.eval()`` as appropriate. See
:meth:`~torch.nn.Module.train` or :meth:`~torch.nn.Module.eval` for details. :meth:`~torch.nn.Module.train` or :meth:`~torch.nn.Module.eval` for details.
All pre-trained models expect input images normalized in the same way, .. code:: python
i.e. mini-batches of 3-channel RGB images of shape (3 x H x W),
where H and W are expected to be at least 224.
The images have to be loaded in to a range of [0, 1] and then normalized
using ``mean = [0.485, 0.456, 0.406]`` and ``std = [0.229, 0.224, 0.225]``.
You can use the following transform to normalize::
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], # Initialize model
std=[0.229, 0.224, 0.225]) weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights)
An example of such normalization can be found in the imagenet example # Set model to eval mode
`here <https://github.com/pytorch/examples/blob/42e5b996718797e45c46a25c55b031e6768f8440/imagenet/main.py#L89-L101>`_ model.eval()
The process for obtaining the values of `mean` and `std` is roughly equivalent Listing and retrieving available models
to:: ---------------------------------------
import torch As of v0.14, TorchVision offers a new mechanism which allows listing and
from torchvision import datasets, transforms as T retrieving models and weights by their names. Here are a few examples on how to
use them:
transform = T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor()])
dataset = datasets.ImageNet(".", split="train", transform=transform)
means = []
stds = []
for img in subset(dataset):
means.append(torch.mean(img))
stds.append(torch.std(img))
mean = torch.mean(torch.tensor(means))
std = torch.mean(torch.tensor(stds))
Unfortunately, the concrete `subset` that was used is lost. For more
information see `this discussion <https://github.com/pytorch/vision/issues/1439>`_
or `these experiments <https://github.com/pytorch/vision/pull/1965>`_.
ImageNet 1-crop error rates (224x224)
================================ ============= =============
Model Acc@1 Acc@5
================================ ============= =============
AlexNet 56.522 79.066
VGG-11 69.020 88.628
VGG-13 69.928 89.246
VGG-16 71.592 90.382
VGG-19 72.376 90.876
VGG-11 with batch normalization 70.370 89.810
VGG-13 with batch normalization 71.586 90.374
VGG-16 with batch normalization 73.360 91.516
VGG-19 with batch normalization 74.218 91.842
ResNet-18 69.758 89.078
ResNet-34 73.314 91.420
ResNet-50 76.130 92.862
ResNet-101 77.374 93.546
ResNet-152 78.312 94.046
SqueezeNet 1.0 58.092 80.420
SqueezeNet 1.1 58.178 80.624
Densenet-121 74.434 91.972
Densenet-169 75.600 92.806
Densenet-201 76.896 93.370
Densenet-161 77.138 93.560
Inception v3 77.294 93.450
GoogleNet 69.778 89.530
ShuffleNet V2 x1.0 69.362 88.316
ShuffleNet V2 x0.5 60.552 81.746
MobileNet V2 71.878 90.286
MobileNet V3 Large 74.042 91.340
MobileNet V3 Small 67.668 87.402
ResNeXt-50-32x4d 77.618 93.698
ResNeXt-101-32x8d 79.312 94.526
Wide ResNet-50-2 78.468 94.086
Wide ResNet-101-2 78.848 94.284
MNASNet 1.0 73.456 91.510
MNASNet 0.5 67.734 87.490
================================ ============= =============
.. _AlexNet: https://arxiv.org/abs/1404.5997
.. _VGG: https://arxiv.org/abs/1409.1556
.. _ResNet: https://arxiv.org/abs/1512.03385
.. _SqueezeNet: https://arxiv.org/abs/1602.07360
.. _DenseNet: https://arxiv.org/abs/1608.06993
.. _Inception: https://arxiv.org/abs/1512.00567
.. _GoogLeNet: https://arxiv.org/abs/1409.4842
.. _ShuffleNet: https://arxiv.org/abs/1807.11164
.. _MobileNetV2: https://arxiv.org/abs/1801.04381
.. _MobileNetV3: https://arxiv.org/abs/1905.02244
.. _ResNeXt: https://arxiv.org/abs/1611.05431
.. _MNASNet: https://arxiv.org/abs/1807.11626
.. currentmodule:: torchvision.models .. code:: python
Alexnet # List available models
------- all_models = list_models()
classification_models = list_models(module=torchvision.models)
.. autofunction:: alexnet # Initialize models
m1 = get_model("mobilenet_v3_large", weights=None)
m2 = get_model("quantized_mobilenet_v3_large", weights="DEFAULT")
VGG # Fetch weights
--- weights = get_weight("MobileNet_V3_Large_QuantizedWeights.DEFAULT")
assert weights == MobileNet_V3_Large_QuantizedWeights.DEFAULT
.. autofunction:: vgg11 weights_enum = get_model_weights("quantized_mobilenet_v3_large")
.. autofunction:: vgg11_bn assert weights_enum == MobileNet_V3_Large_QuantizedWeights
.. autofunction:: vgg13
.. autofunction:: vgg13_bn
.. autofunction:: vgg16
.. autofunction:: vgg16_bn
.. autofunction:: vgg19
.. autofunction:: vgg19_bn
weights_enum2 = get_model_weights(torchvision.models.quantization.mobilenet_v3_large)
assert weights_enum == weights_enum2
ResNet Here are the available public functions to retrieve models and their corresponding weights:
------
.. autofunction:: resnet18 .. currentmodule:: torchvision.models
.. autofunction:: resnet34 .. autosummary::
.. autofunction:: resnet50 :toctree: generated/
.. autofunction:: resnet101 :template: function.rst
.. autofunction:: resnet152
SqueezeNet get_model
---------- get_model_weights
get_weight
list_models
.. autofunction:: squeezenet1_0 Using models from Hub
.. autofunction:: squeezenet1_1 ---------------------
DenseNet Most pre-trained models can be accessed directly via PyTorch Hub without having TorchVision installed:
---------
.. autofunction:: densenet121 .. code:: python
.. autofunction:: densenet169
.. autofunction:: densenet161
.. autofunction:: densenet201
Inception v3 import torch
------------
.. autofunction:: inception_v3 # Option 1: passing weights param as string
model = torch.hub.load("pytorch/vision", "resnet50", weights="IMAGENET1K_V2")
.. note :: # Option 2: passing weights param as enum
This requires `scipy` to be installed weights = torch.hub.load("pytorch/vision", "get_weight", weights="ResNet50_Weights.IMAGENET1K_V2")
model = torch.hub.load("pytorch/vision", "resnet50", weights=weights)
You can also retrieve all the available weights of a specific model via PyTorch Hub by doing:
GoogLeNet .. code:: python
------------
.. autofunction:: googlenet import torch
.. note :: weight_enum = torch.hub.load("pytorch/vision", "get_model_weights", name="resnet50")
This requires `scipy` to be installed print([weight for weight in weight_enum])
The only exception to the above are the detection models included on
:mod:`torchvision.models.detection`. These models require TorchVision
to be installed because they depend on custom C++ operators.
Classification
==============
.. currentmodule:: torchvision.models
The following classification models are available, with or without pre-trained
weights:

.. toctree::
   :maxdepth: 1

   models/alexnet
   models/convnext
   models/densenet
   models/efficientnet
   models/efficientnetv2
   models/googlenet
   models/inception
   models/maxvit
   models/mnasnet
   models/mobilenetv2
   models/mobilenetv3
   models/regnet
   models/resnet
   models/resnext
   models/shufflenetv2
   models/squeezenet
   models/swin_transformer
   models/vgg
   models/vision_transformer
   models/wide_resnet

|

Here is an example of how to use the pre-trained image classification models:

.. code:: python

    from torchvision.io import read_image
    from torchvision.models import resnet50, ResNet50_Weights

    img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

    # Step 1: Initialize model with the best available weights
    weights = ResNet50_Weights.DEFAULT
    model = resnet50(weights=weights)
    model.eval()

    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms()

    # Step 3: Apply inference preprocessing transforms
    batch = preprocess(img).unsqueeze(0)

    # Step 4: Use the model and print the predicted category
    prediction = model(batch).squeeze(0).softmax(0)
    class_id = prediction.argmax().item()
    score = prediction[class_id].item()
    category_name = weights.meta["categories"][class_id]
    print(f"{category_name}: {100 * score:.1f}%")

The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.

Table of all available classification weights
---------------------------------------------

Accuracies are reported on ImageNet-1K using single crops:

.. include:: generated/classification_table.rst
Quantized models
----------------

.. currentmodule:: torchvision.models.quantization

The following architectures provide support for INT8 quantized models, with or without
pre-trained weights:

.. toctree::
   :maxdepth: 1

   models/googlenet_quant
   models/inception_quant
   models/mobilenetv2_quant
   models/mobilenetv3_quant
   models/resnet_quant
   models/resnext_quant
   models/shufflenetv2_quant

|

Here is an example of how to use the pre-trained quantized image classification models:

.. code:: python

    from torchvision.io import read_image
    from torchvision.models.quantization import resnet50, ResNet50_QuantizedWeights

    img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

    # Step 1: Initialize model with the best available weights
    weights = ResNet50_QuantizedWeights.DEFAULT
    model = resnet50(weights=weights, quantize=True)
    model.eval()

    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms()

    # Step 3: Apply inference preprocessing transforms
    batch = preprocess(img).unsqueeze(0)

    # Step 4: Use the model and print the predicted category
    prediction = model(batch).squeeze(0).softmax(0)
    class_id = prediction.argmax().item()
    score = prediction[class_id].item()
    category_name = weights.meta["categories"][class_id]
    print(f"{category_name}: {100 * score:.1f}%")

The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.

Table of all available quantized classification weights
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Accuracies are reported on ImageNet-1K using single crops:

.. include:: generated/classification_quant_table.rst
Semantic Segmentation
=====================

.. currentmodule:: torchvision.models.segmentation

.. betastatus:: segmentation module

The following semantic segmentation models are available, with or without
pre-trained weights:

.. toctree::
   :maxdepth: 1

   models/deeplabv3
   models/fcn
   models/lraspp

|

Here is an example of how to use the pre-trained semantic segmentation models:

.. code:: python

    from torchvision.io.image import read_image
    from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights
    from torchvision.transforms.functional import to_pil_image

    img = read_image("gallery/assets/dog1.jpg")

    # Step 1: Initialize model with the best available weights
    weights = FCN_ResNet50_Weights.DEFAULT
    model = fcn_resnet50(weights=weights)
    model.eval()

    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms()

    # Step 3: Apply inference preprocessing transforms
    batch = preprocess(img).unsqueeze(0)

    # Step 4: Use the model and visualize the prediction
    prediction = model(batch)["out"]
    normalized_masks = prediction.softmax(dim=1)
    class_to_idx = {cls: idx for (idx, cls) in enumerate(weights.meta["categories"])}
    mask = normalized_masks[0, class_to_idx["dog"]]
    to_pil_image(mask).show()

The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.
The output format of the models is illustrated in :ref:`semantic_seg_output`.

Table of all available semantic segmentation weights
------------------------------------------------------

All models are evaluated on a subset of COCO val2017, on the 20 categories that are
present in the Pascal VOC dataset:

.. include:: generated/segmentation_table.rst
.. _object_det_inst_seg_pers_keypoint_det:

Object Detection, Instance Segmentation and Person Keypoint Detection
========================================================================

The pre-trained models for detection, instance segmentation and
keypoint detection are initialized with the classification models
in torchvision. The models expect a list of ``Tensor[C, H, W]``.
Check the constructor of the models for more information.

.. betastatus:: detection module

Object Detection
----------------

.. currentmodule:: torchvision.models.detection

The following object detection models are available, with or without pre-trained
weights:

.. toctree::
   :maxdepth: 1

   models/faster_rcnn
   models/fcos
   models/retinanet
   models/ssd
   models/ssdlite

|

Here is an example of how to use the pre-trained object detection models:

.. code:: python

    from torchvision.io.image import read_image
    from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights
    from torchvision.utils import draw_bounding_boxes
    from torchvision.transforms.functional import to_pil_image

    img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

    # Step 1: Initialize model with the best available weights
    weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
    model = fasterrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=0.9)
    model.eval()

    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms()

    # Step 3: Apply inference preprocessing transforms
    batch = [preprocess(img)]

    # Step 4: Use the model and visualize the prediction
    prediction = model(batch)[0]
    labels = [weights.meta["categories"][i] for i in prediction["labels"]]
    box = draw_bounding_boxes(img, boxes=prediction["boxes"],
                              labels=labels,
                              colors="red",
                              width=4, font_size=30)
    im = to_pil_image(box.detach())
    im.show()

The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.
For details on how to plot the bounding boxes of the models, you may refer to :ref:`instance_seg_output`.

Table of all available Object detection weights
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Box MAPs are reported on COCO val2017:

.. include:: generated/detection_table.rst
Instance Segmentation
---------------------

.. currentmodule:: torchvision.models.detection

The following instance segmentation models are available, with or without pre-trained
weights:

.. toctree::
   :maxdepth: 1

   models/mask_rcnn

|
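Below is a minimal sketch of how the pre-trained instance segmentation models could be
used, mirroring the detection example above. It assumes the ``maskrcnn_resnet50_fpn_v2``
builder and its ``MaskRCNN_ResNet50_FPN_V2_Weights`` enum from
``torchvision.models.detection``; the image path and score threshold are illustrative:

.. code:: python

    from torchvision.io.image import read_image
    from torchvision.models.detection import maskrcnn_resnet50_fpn_v2, MaskRCNN_ResNet50_FPN_V2_Weights
    from torchvision.utils import draw_segmentation_masks
    from torchvision.transforms.functional import to_pil_image

    img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

    # Step 1: Initialize model with the best available weights
    weights = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT
    model = maskrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=0.9)
    model.eval()

    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms()

    # Step 3: Apply inference preprocessing transforms
    batch = [preprocess(img)]

    # Step 4: Use the model and overlay the predicted masks
    prediction = model(batch)[0]
    masks = (prediction["masks"] > 0.5).squeeze(1)  # threshold the soft masks
    overlay = draw_segmentation_masks(img, masks=masks, alpha=0.7)
    to_pil_image(overlay).show()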
For details on how to plot the masks of the models, you may refer to :ref:`instance_seg_output`.

Table of all available Instance segmentation weights
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Box and Mask MAPs are reported on COCO val2017:

.. include:: generated/instance_segmentation_table.rst
Keypoint Detection
------------------

.. currentmodule:: torchvision.models.detection

The following person keypoint detection models are available, with or without
pre-trained weights:

.. toctree::
   :maxdepth: 1

   models/keypoint_rcnn

|
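Below is a minimal sketch of how the pre-trained person keypoint detection model could be
used, following the same four-step pattern. It assumes the ``keypointrcnn_resnet50_fpn``
builder and its ``KeypointRCNN_ResNet50_FPN_Weights`` enum from
``torchvision.models.detection``; the image path and score threshold are illustrative:

.. code:: python

    from torchvision.io.image import read_image
    from torchvision.models.detection import keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights
    from torchvision.utils import draw_keypoints
    from torchvision.transforms.functional import to_pil_image

    img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

    # Step 1: Initialize model with the best available weights
    weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
    model = keypointrcnn_resnet50_fpn(weights=weights, box_score_thresh=0.9)
    model.eval()

    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms()

    # Step 3: Apply inference preprocessing transforms
    batch = [preprocess(img)]

    # Step 4: Use the model and overlay the predicted keypoints
    prediction = model(batch)[0]
    keypoints = prediction["keypoints"][..., :2].detach()  # drop the visibility flag
    overlay = draw_keypoints(img, keypoints, colors="blue", radius=3)
    to_pil_image(overlay).show()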
The classes of the pre-trained model outputs can be found at ``weights.meta["keypoint_names"]``.
For details on how to plot the bounding boxes of the models, you may refer to :ref:`keypoint_output`.

Table of all available Keypoint detection weights
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Box and Keypoint MAPs are reported on COCO val2017:

.. include:: generated/detection_keypoint_table.rst
Video Classification
====================

.. currentmodule:: torchvision.models.video

.. betastatus:: video module

The following video classification models are available, with or without
pre-trained weights:

.. toctree::
   :maxdepth: 1

   models/video_mvit
   models/video_resnet
   models/video_s3d
   models/video_swin_transformer

|

Here is an example of how to use the pre-trained video classification models:

.. code:: python

    from torchvision.io.video import read_video
    from torchvision.models.video import r3d_18, R3D_18_Weights

    vid, _, _ = read_video("test/assets/videos/v_SoccerJuggling_g23_c01.avi", output_format="TCHW")
    vid = vid[:32]  # optionally shorten duration

    # Step 1: Initialize model with the best available weights
    weights = R3D_18_Weights.DEFAULT
    model = r3d_18(weights=weights)
    model.eval()

    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms()

    # Step 3: Apply inference preprocessing transforms
    batch = preprocess(vid).unsqueeze(0)

    # Step 4: Use the model and print the predicted category
    prediction = model(batch).squeeze(0).softmax(0)
    label = prediction.argmax().item()
    score = prediction[label].item()
    category_name = weights.meta["categories"][label]
    print(f"{category_name}: {100 * score:.1f}%")

The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.

Table of all available video classification weights
-----------------------------------------------------

Accuracies are reported on Kinetics-400 using single crops for clip length 16:

.. include:: generated/video_table.rst
Optical Flow
============

.. currentmodule:: torchvision.models.optical_flow

The following Optical Flow models are available, with or without pre-trained
weights:

.. toctree::
   :maxdepth: 1

   models/raft
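Below is a minimal sketch of how a pre-trained optical flow model could be run on two
consecutive frames, in the same spirit as the examples above. It assumes the ``raft_large``
builder and its ``Raft_Large_Weights`` enum from ``torchvision.models.optical_flow``; the
frame paths and resize dimensions are illustrative (RAFT expects spatial dimensions
divisible by 8):

.. code:: python

    import torch
    from torchvision.io import read_image
    from torchvision.models.optical_flow import raft_large, Raft_Large_Weights
    from torchvision.transforms.functional import resize, to_pil_image
    from torchvision.utils import flow_to_image

    # Two consecutive frames of the same scene, as batched uint8 tensors
    img1 = resize(read_image("frame_0001.png").unsqueeze(0), size=[520, 960], antialias=False)
    img2 = resize(read_image("frame_0002.png").unsqueeze(0), size=[520, 960], antialias=False)

    # Step 1: Initialize model with the best available weights
    weights = Raft_Large_Weights.DEFAULT
    model = raft_large(weights=weights)
    model.eval()

    # Step 2: Initialize the inference transforms and apply them to both frames
    preprocess = weights.transforms()
    img1, img2 = preprocess(img1, img2)

    # Step 3: Run the model; RAFT returns one flow estimate per refinement iteration,
    # and the last element of the list is the most accurate prediction
    with torch.no_grad():
        list_of_flows = model(img1, img2)
    predicted_flow = list_of_flows[-1][0]

    # Step 4: Visualize the predicted flow as an RGB image
    to_pil_image(flow_to_image(predicted_flow)).show()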
AlexNet
=======
.. currentmodule:: torchvision.models
The AlexNet model was originally introduced in the
`ImageNet Classification with Deep Convolutional Neural Networks
<https://papers.nips.cc/paper/2012/hash/c399862d3b9d6b76c8436e924a68c45b-Abstract.html>`__
paper. The implemented architecture is slightly different from the original one,
and is based on `One weird trick for parallelizing convolutional neural networks
<https://arxiv.org/abs/1404.5997>`__.
Model builders
--------------
The following model builders can be used to instantiate an AlexNet model, with or
without pre-trained weights. All the model builders internally rely on the
``torchvision.models.alexnet.AlexNet`` base class. Please refer to the `source
code
<https://github.com/pytorch/vision/blob/main/torchvision/models/alexnet.py>`_ for
more details about this class.
.. autosummary::
    :toctree: generated/
    :template: function.rst

    alexnet
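As a quick, minimal sketch of the builder in use (following the general pattern from the
classification section above, and assuming the ``AlexNet_Weights`` enum that accompanies it):

.. code:: python

    from torchvision.models import alexnet, AlexNet_Weights

    # Instantiate AlexNet with the most accurate available pre-trained weights
    weights = AlexNet_Weights.DEFAULT
    model = alexnet(weights=weights)
    model.eval()

    # The inference transforms that match these weights
    preprocess = weights.transforms()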
ConvNeXt
========
.. currentmodule:: torchvision.models
The ConvNeXt model is based on the `A ConvNet for the 2020s
<https://arxiv.org/abs/2201.03545>`_ paper.
Model builders
--------------
The following model builders can be used to instantiate a ConvNeXt model, with or
without pre-trained weights. All the model builders internally rely on the
``torchvision.models.convnext.ConvNeXt`` base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py>`_ for
more details about this class.
.. autosummary::
    :toctree: generated/
    :template: function.rst

    convnext_tiny
    convnext_small
    convnext_base
    convnext_large
DeepLabV3
=========
.. currentmodule:: torchvision.models.segmentation
The DeepLabV3 model is based on the `Rethinking Atrous Convolution for Semantic
Image Segmentation <https://arxiv.org/abs/1706.05587>`__ paper.
.. betastatus:: segmentation module
Model builders
--------------
The following model builders can be used to instantiate a DeepLabV3 model with
different backbones, with or without pre-trained weights. All the model builders
internally rely on the ``torchvision.models.segmentation.deeplabv3.DeepLabV3`` base class. Please
refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/segmentation/deeplabv3.py>`_
for more details about this class.
.. autosummary::
    :toctree: generated/
    :template: function.rst

    deeplabv3_mobilenet_v3_large
    deeplabv3_resnet50
    deeplabv3_resnet101
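As a minimal sketch of the output structure (assuming the ``DeepLabV3_ResNet50_Weights``
enum that accompanies ``deeplabv3_resnet50``; a random input is enough to illustrate the
shapes, see the semantic segmentation example above for proper preprocessing):

.. code:: python

    import torch
    from torchvision.models.segmentation import deeplabv3_resnet50, DeepLabV3_ResNet50_Weights

    weights = DeepLabV3_ResNet50_Weights.DEFAULT
    model = deeplabv3_resnet50(weights=weights)
    model.eval()

    # The model returns a dict of tensors; "out" holds the per-class logits
    # with shape (N, num_classes, H, W)
    with torch.no_grad():
        out = model(torch.rand(1, 3, 520, 520))["out"]
    print(out.shape)  # torch.Size([1, 21, 520, 520])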
DenseNet
========
.. currentmodule:: torchvision.models
The DenseNet model is based on the `Densely Connected Convolutional Networks
<https://arxiv.org/abs/1608.06993>`_ paper.
Model builders
--------------
The following model builders can be used to instantiate a DenseNet model, with or
without pre-trained weights. All the model builders internally rely on the
``torchvision.models.densenet.DenseNet`` base class. Please refer to the `source
code
<https://github.com/pytorch/vision/blob/main/torchvision/models/densenet.py>`_ for
more details about this class.
.. autosummary::
    :toctree: generated/
    :template: function.rst

    densenet121
    densenet161
    densenet169
    densenet201
EfficientNet
============
.. currentmodule:: torchvision.models
The EfficientNet model is based on the `EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks <https://arxiv.org/abs/1905.11946>`__
paper.
Model builders
--------------
The following model builders can be used to instantiate an EfficientNet model, with or
without pre-trained weights. All the model builders internally rely on the
``torchvision.models.efficientnet.EfficientNet`` base class. Please refer to the `source
code
<https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_ for
more details about this class.
.. autosummary::
    :toctree: generated/
    :template: function.rst

    efficientnet_b0
    efficientnet_b1
    efficientnet_b2
    efficientnet_b3
    efficientnet_b4
    efficientnet_b5
    efficientnet_b6
    efficientnet_b7
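To see which pre-trained weights exist for any of these builders, they can be enumerated
through the registration API described at the top of this page (a small sketch; the
``num_params`` metadata key is assumed to be populated for these weights):

.. code:: python

    from torchvision.models import get_model_weights

    # Enumerate all registered weights for one of the builders above
    for weights in get_model_weights("efficientnet_b0"):
        print(weights, weights.meta["num_params"])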