Unverified Commit 769ae132 authored by Vasilis Vryniotis, committed by GitHub

Add more info on new models.rst (#6025)



* Minor updates on model examples.

* Improving wording of auto-generated docs.

* Add general info for pre-trained weights.

* Updating torch hub

* Minor updates

* Make lengthy meta-data partially visible

* Adding meta-data and reference info.

* Minor corrections

* Update docs/source/models_new.rst
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>

* Moving Torch hub section at the end
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
parent 44252c81
......@@ -347,10 +347,6 @@ def inject_weight_metadata(app, what, name, obj, options, lines):
metrics = meta.pop("metrics", {})
meta_with_metrics = dict(meta, **metrics)
# We don't want to document these, they can be too long
for k in ["categories", "keypoint_names"]:
meta_with_metrics.pop(k, None)
custom_docs = meta_with_metrics.pop("_docs", None) # Custom per-Weights docs
if custom_docs is not None:
lines += [custom_docs, ""]
......@@ -360,6 +356,10 @@ def inject_weight_metadata(app, what, name, obj, options, lines):
v = f"`link <{v}>`__"
elif k == "min_size":
v = f"height={v[0]}, width={v[1]}"
elif k in {"categories", "keypoint_names"} and isinstance(v, list):
max_visible = 3
v_sample = ", ".join(v[:max_visible])
v = f"{v_sample}, ... ({len(v)-max_visible} omitted)" if len(v) > max_visible else v_sample
table.append((str(k), str(v)))
table = tabulate(table, tablefmt="rst")
lines += [".. rst-class:: table-weights"] # Custom CSS class, see custom_torchvision.css
......@@ -367,7 +367,7 @@ def inject_weight_metadata(app, what, name, obj, options, lines):
lines += textwrap.indent(table, " " * 4).split("\n")
lines.append("")
lines.append(
f"The inference transforms are available at ``{str(field)}.transforms`` and "
f"The preprocessing/inference transforms are available at ``{str(field)}.transforms`` and "
f"perform the following operations: {field.transforms().describe()}"
)
lines.append("")
......
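A standalone sketch of the truncation above, using an illustrative list of category names (not taken from any real weight entry), shows how long metadata lists are rendered:

.. code:: python

categories = ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead"]
max_visible = 3
v_sample = ", ".join(categories[:max_visible])
# Lists longer than max_visible are summarized instead of being dropped from the table
v = f"{v_sample}, ... ({len(categories) - max_visible} omitted)" if len(categories) > max_visible else v_sample
print(v)  # tench, goldfish, great white shark, ... (2 omitted)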
......@@ -3,30 +3,42 @@
Models and pre-trained weights - New
####################################
.. note::
These are the new models docs, documenting the new multi-weight API.
TODO: Once all is done, remove the "- New" part in the title above, and
rename this file as models.rst
The ``torchvision.models`` subpackage contains definitions of models for addressing
different tasks, including: image classification, pixelwise semantic
segmentation, object detection, instance segmentation, person
keypoint detection, video classification, and optical flow.
General information on pre-trained weights
==========================================
TorchVision offers pre-trained weights for every provided architecture, using
the PyTorch :mod:`torch.hub`. Instantiating a pre-trained model will download its
weights to a cache directory. This directory can be set using the `TORCH_HOME`
environment variable. See :func:`torch.hub.load_state_dict_from_url` for details.
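For example, a minimal sketch of downloading weights into a custom cache directory (the path below is only illustrative):

.. code:: python

import os

# Must be set before the first download; "~/.my_torch_cache" is just an illustrative path
os.environ["TORCH_HOME"] = os.path.expanduser("~/.my_torch_cache")

from torchvision.models import resnet50, ResNet50_Weights

# The checkpoint is fetched via torch.hub and cached under $TORCH_HOME/hub/checkpoints
model = resnet50(weights=ResNet50_Weights.DEFAULT)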
.. note::
The pre-trained models provided in this library may have their own licenses or
terms and conditions derived from the dataset used for training. It is your
responsibility to determine whether you have permission to use the models for
your use case.
.. note::
Backward compatibility is guaranteed for loading a serialized
``state_dict`` into a model created using an older PyTorch version.
In contrast, loading entire saved models or serialized
``ScriptModules`` (serialized using older versions of PyTorch)
may not preserve the historic behaviour. Refer to the following
`documentation
<https://pytorch.org/docs/stable/notes/serialization.html#id6>`_.
Initializing pre-trained models
-------------------------------
As of v0.13, TorchVision offers a new `Multi-weight support API
<https://pytorch.org/blog/introducing-torchvision-new-multi-weight-support-api/>`_
for loading different weights to the existing model builder methods:
.. code:: python
......@@ -46,7 +58,7 @@ existing model builder methods:
resnet50(weights="IMAGENET1K_V2")
# No weights - random initialization
resnet50(weights=None)
Migrating to the new API is very straightforward. The following method calls between the two APIs are all equivalent:
......@@ -57,16 +69,57 @@ Migrating to the new API is very straightforward. The following method calls bet
# Using pretrained weights:
resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
resnet50(weights="IMAGENET1K_V1")
resnet50(pretrained=True) # deprecated
resnet50(True) # deprecated
# Using no weights:
resnet50(weights=None)
resnet50()
resnet50(pretrained=False) # deprecated
resnet50(False) # deprecated
Note that the ``pretrained`` parameter is now deprecated; using it will emit warnings, and it will be removed in v0.15.
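If needed, the deprecation warning can be inspected with the standard :mod:`warnings` machinery; a minimal sketch:

.. code:: python

import warnings

from torchvision.models import resnet50

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    model = resnet50(pretrained=True)  # deprecated, equivalent to weights=ResNet50_Weights.IMAGENET1K_V1

for w in caught:
    print(w.message)  # mentions that ``pretrained`` is deprecated in favour of ``weights``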
Using the pre-trained models
----------------------------
Before using the pre-trained models, one must preprocess the image
(resize it with the right resolution/interpolation, apply inference transforms,
rescale the values, etc.). There is no standard way to do this as it depends on
how a given model was trained. It can vary across model families, variants or
even weight versions. Using the correct preprocessing method is critical, and
failing to do so may lead to decreased accuracy or incorrect outputs.
All the necessary information for the inference transforms of each pre-trained
model is provided in its weights documentation. To simplify inference, TorchVision
bundles the necessary preprocessing transforms into each model weight. These are
accessible via the ``weight.transforms`` attribute:
.. code:: python
# Initialize the Weight Transforms
weights = ResNet50_Weights.DEFAULT
preprocess = weights.transforms()
# Apply it to the input image
img_transformed = preprocess(img)
Some models use modules which have different training and evaluation
behavior, such as batch normalization. To switch between these modes, use
``model.train()`` or ``model.eval()`` as appropriate. See
:meth:`~torch.nn.Module.train` or :meth:`~torch.nn.Module.eval` for details.
.. code:: python
# Initialize model
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights)
# Set model to eval mode
model.eval()
Classification
==============
......@@ -128,10 +181,12 @@ Here is an example of how to use the pre-trained image classification models:
category_name = weights.meta["categories"][class_id]
print(f"{category_name}: {100 * score:.1f}%")
The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.
Table of all available classification weights
---------------------------------------------
Accuracies are reported on ImageNet-1K using single crops:
.. include:: generated/classification_table.rst
......@@ -140,7 +195,7 @@ Quantized models
.. currentmodule:: torchvision.models.quantization
The following architectures provide support for INT8 quantized models, with or without
pre-trained weights:
.. toctree::
......@@ -181,11 +236,13 @@ Here is an example of how to use the pre-trained quantized image classification
category_name = weights.meta["categories"][class_id]
print(f"{category_name}: {100 * score}%")
The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.
Table of all available quantized classification weights
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Accuracies are reported on ImageNet-1K using single crops:
.. include:: generated/classification_quant_table.rst
......@@ -234,11 +291,14 @@ Here is an example of how to use the pre-trained semantic segmentation models:
mask = normalized_masks[0, class_to_idx["dog"]]
to_pil_image(mask).show()
The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.
The output format of the models is illustrated in :ref:`semantic_seg_output`.
Table of all available semantic segmentation weights
----------------------------------------------------
All models are evaluated on a subset of COCO val2017, on the 20 categories that are
present in the Pascal VOC dataset:
.. include:: generated/segmentation_table.rst
......@@ -247,6 +307,11 @@ All models are evaluated on COCO val2017:
Object Detection, Instance Segmentation and Person Keypoint Detection
=====================================================================
The pre-trained models for detection, instance segmentation and
keypoint detection are initialized with the classification models
in torchvision. The models expect a list of ``Tensor[C, H, W]``.
Check the constructor of the models for more information.
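A minimal sketch of the expected input format (the builder, weights and image sizes below are only illustrative):

.. code:: python

import torch

from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_Weights

weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
model = fasterrcnn_resnet50_fpn(weights=weights)
model.eval()

# A list of 3-channel image tensors; the images do not need to share the same spatial size
images = [torch.rand(3, 480, 640), torch.rand(3, 512, 512)]
predictions = model(images)  # one dict (boxes, labels, scores) per input image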
Object Detection
----------------
......@@ -299,10 +364,13 @@ Here is an example of how to use the pre-trained object detection models:
im = to_pil_image(box.detach())
im.show()
The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.
For details on how to plot the bounding boxes of the models, you may refer to :ref:`instance_seg_output`.
Table of all available Object detection weights
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Box MAPs are reported on COCO val2017:
.. include:: generated/detection_table.rst
......@@ -319,10 +387,15 @@ weights:
models/mask_rcnn
|
For details on how to plot the masks of the models, you may refer to :ref:`instance_seg_output`.
Table of all available Instance segmentation weights
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Box and Mask MAPs are reported on COCO val2017:
.. include:: generated/instance_segmentation_table.rst
......@@ -331,7 +404,7 @@ Keypoint Detection
.. currentmodule:: torchvision.models.detection
The following person keypoint detection models are available, with or without
pre-trained weights:
.. toctree::
......@@ -339,10 +412,15 @@ pre-trained weights:
models/keypoint_rcnn
|
The classes of the pre-trained model outputs can be found at ``weights.meta["keypoint_names"]``.
For details on how to plot the keypoints of the models, you may refer to :ref:`keypoint_output`.
Table of all available Keypoint detection weights
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Box and Keypoint MAPs are reported on COCO val2017:
.. include:: generated/detection_keypoint_table.rst
......@@ -391,10 +469,32 @@ Here is an example of how to use the pre-trained video classification models:
category_name = weights.meta["categories"][label]
print(f"{category_name}: {100 * score}%")
The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.
Table of all available video classification weights
---------------------------------------------------
Accuracies are reported on Kinetics-400 using single crops for clip length 16:
.. include:: generated/video_table.rst
Using models from Hub
=====================
Most pre-trained models can be accessed directly via PyTorch Hub without having TorchVision installed:
.. code:: python
import torch
# Option 1: passing weights param as string
model = torch.hub.load("pytorch/vision", "resnet50", weights="IMAGENET1K_V2")
# Option 2: passing weights param as enum
weights = torch.hub.load("pytorch/vision", "get_weight", weights="ResNet50_Weights.IMAGENET1K_V2")
model = torch.hub.load("pytorch/vision", "resnet50", weights=weights)
The only exceptions to the above are the detection models included in
:mod:`torchvision.models.detection`. These models require TorchVision
to be installed because they depend on custom C++ operators.
......@@ -379,6 +379,8 @@ show(dogs_with_masks)
# instance with class 15 (which corresponds to 'bench') was not selected.
#####################################
# .. _keypoint_output:
#
# Visualizing keypoints
# ------------------------------
# The :func:`~torchvision.utils.draw_keypoints` function can be used to
......
......@@ -71,8 +71,8 @@ class ImageClassification(nn.Module):
def describe(self) -> str:
return (
f"The images are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``, "
f"followed by a central crop of ``crop_size={self.crop_size}``. Then the values are rescaled to "
f"``[0.0, 1.0]`` and normalized using ``mean={self.mean}`` and ``std={self.std}``."
f"followed by a central crop of ``crop_size={self.crop_size}``. Finally the values are first rescaled to "
f"``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and ``std={self.std}``."
)
......@@ -127,8 +127,8 @@ class VideoClassification(nn.Module):
def describe(self) -> str:
return (
f"The video frames are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``, "
f"followed by a central crop of ``crop_size={self.crop_size}``. Then the values are rescaled to "
f"``[0.0, 1.0]`` and normalized using ``mean={self.mean}`` and ``std={self.std}``."
f"followed by a central crop of ``crop_size={self.crop_size}``. Finally the values are first rescaled to "
f"``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and ``std={self.std}``."
)
......@@ -168,7 +168,8 @@ class SemanticSegmentation(nn.Module):
def describe(self) -> str:
return (
f"The images are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``. "
f"Then the values are rescaled to ``[0.0, 1.0]`` and normalized using ``mean={self.mean}`` and ``std={self.std}``."
f"Finally the values are first rescaled to ``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and "
f"``std={self.std}``."
)
......
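The reworded ``describe()`` strings above are what the doc generator injects into each weight entry; a minimal sketch of inspecting one directly (assuming torchvision v0.13+):

.. code:: python

from torchvision.models import ResNet50_Weights

preprocess = ResNet50_Weights.DEFAULT.transforms()
# Prints the resize / central-crop / rescale / normalize summary used in the generated docs
print(preprocess.describe())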