Unverified Commit 52a4480d authored by Nicolas Hug, committed by GitHub

Models doc revamp: final cleanup (#6049)

* Remove models.rst

* Remove '- New'

* Put back torchhub section where it originally was
parent 5985504c
@@ -6,9 +6,7 @@ ifneq ($(EXAMPLES_PATTERN),)
 endif

 # You can set these variables from the command line.
-# TODO: Once the models doc revamp is done, set back the -W option to raise
-# errors on warnings. See https://github.com/pytorch/vision/pull/5821#discussion_r850500693
-SPHINXOPTS = -j auto $(EXAMPLES_PATTERN_OPTS)
+SPHINXOPTS = -W -j auto $(EXAMPLES_PATTERN_OPTS)
 SPHINXBUILD = sphinx-build
 SPHINXPROJ = torchvision
 SOURCEDIR = source
...
@@ -347,7 +347,6 @@ def inject_weight_metadata(app, what, name, obj, options, lines):
             metrics = meta.pop("_metrics")
             for dataset, dataset_metrics in metrics.items():
                 for metric_name, metric_value in dataset_metrics.items():
-                    metric_name = metric_name.replace("_", "-")
                     table.append((f"{metric_name} (on {dataset})", str(metric_value)))

         for k, v in meta.items():
...
@@ -38,7 +38,6 @@ architectures, and common image transformations for computer vision.
    ops
    io
    feature_extraction
-   models_new

 .. toctree::
    :maxdepth: 1
...
@@ -3,863 +3,513 @@

Models and pre-trained weights
##############################

The ``torchvision.models`` subpackage contains definitions of models for addressing
different tasks, including: image classification, pixelwise semantic
segmentation, object detection, instance segmentation, person
keypoint detection, video classification, and optical flow.

.. note ::
Backward compatibility is guaranteed for loading a serialized
``state_dict`` to the model created using old PyTorch version.
On the contrary, loading entire saved models or serialized
``ScriptModules`` (serialized using older versions of PyTorch)
may not preserve the historic behaviour. Refer to the following
`documentation
<https://pytorch.org/docs/stable/notes/serialization.html#id6>`_

Classification
==============

The models subpackage contains definitions for the following model
architectures for image classification:
- `AlexNet`_
- `VGG`_
- `ResNet`_
- `SqueezeNet`_
- `DenseNet`_
- `Inception`_ v3
- `GoogLeNet`_
- `ShuffleNet`_ v2
- `MobileNetV2`_
- `MobileNetV3`_
- `ResNeXt`_
- `Wide ResNet`_
- `MNASNet`_
- `EfficientNet`_ v1 & v2
- `RegNet`_
- `VisionTransformer`_
- `ConvNeXt`_
- `SwinTransformer`_
You can construct a model with random weights by calling its constructor:
.. code:: python

import torchvision.models as models
resnet18 = models.resnet18()
alexnet = models.alexnet()
vgg16 = models.vgg16()
squeezenet = models.squeezenet1_0()
densenet = models.densenet161()
inception = models.inception_v3()
googlenet = models.googlenet()
shufflenet = models.shufflenet_v2_x1_0()
mobilenet_v2 = models.mobilenet_v2()
mobilenet_v3_large = models.mobilenet_v3_large()
mobilenet_v3_small = models.mobilenet_v3_small()
resnext50_32x4d = models.resnext50_32x4d()
resnext101_32x8d = models.resnext101_32x8d()
resnext101_64x4d = models.resnext101_64x4d()
wide_resnet50_2 = models.wide_resnet50_2()
mnasnet = models.mnasnet1_0()
efficientnet_b0 = models.efficientnet_b0()
efficientnet_b1 = models.efficientnet_b1()
efficientnet_b2 = models.efficientnet_b2()
efficientnet_b3 = models.efficientnet_b3()
efficientnet_b4 = models.efficientnet_b4()
efficientnet_b5 = models.efficientnet_b5()
efficientnet_b6 = models.efficientnet_b6()
efficientnet_b7 = models.efficientnet_b7()
efficientnet_v2_s = models.efficientnet_v2_s()
efficientnet_v2_m = models.efficientnet_v2_m()
efficientnet_v2_l = models.efficientnet_v2_l()
regnet_y_400mf = models.regnet_y_400mf()
regnet_y_800mf = models.regnet_y_800mf()
regnet_y_1_6gf = models.regnet_y_1_6gf()
regnet_y_3_2gf = models.regnet_y_3_2gf()
regnet_y_8gf = models.regnet_y_8gf()
regnet_y_16gf = models.regnet_y_16gf()
regnet_y_32gf = models.regnet_y_32gf()
regnet_y_128gf = models.regnet_y_128gf()
regnet_x_400mf = models.regnet_x_400mf()
regnet_x_800mf = models.regnet_x_800mf()
regnet_x_1_6gf = models.regnet_x_1_6gf()
regnet_x_3_2gf = models.regnet_x_3_2gf()
regnet_x_8gf = models.regnet_x_8gf()
regnet_x_16gf = models.regnet_x_16gf()
regnet_x_32gf = models.regnet_x_32gf()
vit_b_16 = models.vit_b_16()
vit_b_32 = models.vit_b_32()
vit_l_16 = models.vit_l_16()
vit_l_32 = models.vit_l_32()
vit_h_14 = models.vit_h_14()
convnext_tiny = models.convnext_tiny()
convnext_small = models.convnext_small()
convnext_base = models.convnext_base()
convnext_large = models.convnext_large()
swin_t = models.swin_t()
We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`.
Instancing a pre-trained model will download its weights to a cache directory.
This directory can be set using the `TORCH_HOME` environment variable. See
:func:`torch.hub.load_state_dict_from_url` for details.
Some models use modules which have different training and evaluation
behavior, such as batch normalization. To switch between these modes, use
``model.train()`` or ``model.eval()`` as appropriate. See
:meth:`~torch.nn.Module.train` or :meth:`~torch.nn.Module.eval` for details.
All pre-trained models expect input images normalized in the same way,
i.e. mini-batches of 3-channel RGB images of shape (3 x H x W),
where H and W are expected to be at least 224.
The images have to be loaded in to a range of [0, 1] and then normalized
using ``mean = [0.485, 0.456, 0.406]`` and ``std = [0.229, 0.224, 0.225]``.
You can use the following transform to normalize::
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])

An example of such normalization can be found in the imagenet example
`here <https://github.com/pytorch/examples/blob/42e5b996718797e45c46a25c55b031e6768f8440/imagenet/main.py#L89-L101>`_

The process for obtaining the values of `mean` and `std` is roughly equivalent
to::
import torch
from torchvision import datasets, transforms as T
transform = T.Compose([T.Resize(256), T.CenterCrop(224), T.PILToTensor(), T.ConvertImageDtype(torch.float)])
dataset = datasets.ImageNet(".", split="train", transform=transform)
means = []
stds = []
for img in subset(dataset):
means.append(torch.mean(img))
stds.append(torch.std(img))
mean = torch.mean(torch.tensor(means))
std = torch.mean(torch.tensor(stds))
Unfortunately, the concrete `subset` that was used is lost. For more
information see `this discussion <https://github.com/pytorch/vision/issues/1439>`_
or `these experiments <https://github.com/pytorch/vision/pull/1965>`_.
The sizes of the EfficientNet models depend on the variant. For the exact input sizes
`check here <https://github.com/pytorch/vision/blob/d2bfd639e46e1c5dc3c177f889dc7750c8d137c7/references/classification/train.py#L92-L93>`_
ImageNet 1-crop error rates
================================ ============= =============
Model Acc@1 Acc@5
================================ ============= =============
AlexNet 56.522 79.066
VGG-11 69.020 88.628
VGG-13 69.928 89.246
VGG-16 71.592 90.382
VGG-19 72.376 90.876
VGG-11 with batch normalization 70.370 89.810
VGG-13 with batch normalization 71.586 90.374
VGG-16 with batch normalization 73.360 91.516
VGG-19 with batch normalization 74.218 91.842
ResNet-18 69.758 89.078
ResNet-34 73.314 91.420
ResNet-50 76.130 92.862
ResNet-101 77.374 93.546
ResNet-152 78.312 94.046
SqueezeNet 1.0 58.092 80.420
SqueezeNet 1.1 58.178 80.624
Densenet-121 74.434 91.972
Densenet-169 75.600 92.806
Densenet-201 76.896 93.370
Densenet-161 77.138 93.560
Inception v3 77.294 93.450
GoogleNet 69.778 89.530
ShuffleNet V2 x0.5 60.552 81.746
ShuffleNet V2 x1.0 69.362 88.316
ShuffleNet V2 x1.5 72.996 91.086
ShuffleNet V2 x2.0 76.230 93.006
MobileNet V2 71.878 90.286
MobileNet V3 Large 74.042 91.340
MobileNet V3 Small 67.668 87.402
ResNeXt-50-32x4d 77.618 93.698
ResNeXt-101-32x8d 79.312 94.526
ResNeXt-101-64x4d 83.246 96.454
Wide ResNet-50-2 78.468 94.086
Wide ResNet-101-2 78.848 94.284
MNASNet 1.0 73.456 91.510
MNASNet 0.5 67.734 87.490
EfficientNet-B0 77.692 93.532
EfficientNet-B1 78.642 94.186
EfficientNet-B2 80.608 95.310
EfficientNet-B3 82.008 96.054
EfficientNet-B4 83.384 96.594
EfficientNet-B5 83.444 96.628
EfficientNet-B6 84.008 96.916
EfficientNet-B7 84.122 96.908
EfficientNetV2-s 84.228 96.878
EfficientNetV2-m 85.112 97.156
EfficientNetV2-l 85.810 97.792
regnet_x_400mf 72.834 90.950
regnet_x_800mf 75.212 92.348
regnet_x_1_6gf 77.040 93.440
regnet_x_3_2gf 78.364 93.992
regnet_x_8gf 79.344 94.686
regnet_x_16gf 80.058 94.944
regnet_x_32gf 80.622 95.248
regnet_y_400mf 74.046 91.716
regnet_y_800mf 76.420 93.136
regnet_y_1_6gf 77.950 93.966
regnet_y_3_2gf 78.948 94.576
regnet_y_8gf 80.032 95.048
regnet_y_16gf 80.424 95.240
regnet_y_32gf 80.878 95.340
vit_b_16 81.072 95.318
vit_b_32 75.912 92.466
vit_l_16 79.662 94.638
vit_l_32 76.972 93.070
vit_h_14 88.552 98.694
convnext_tiny 82.520 96.146
convnext_small 83.616 96.650
convnext_base 84.062 96.870
convnext_large 84.414 96.976
swin_t 81.358 95.526
================================ ============= =============
.. _AlexNet: https://arxiv.org/abs/1404.5997
.. _VGG: https://arxiv.org/abs/1409.1556
.. _ResNet: https://arxiv.org/abs/1512.03385
.. _SqueezeNet: https://arxiv.org/abs/1602.07360
.. _DenseNet: https://arxiv.org/abs/1608.06993
.. _Inception: https://arxiv.org/abs/1512.00567
.. _GoogLeNet: https://arxiv.org/abs/1409.4842
.. _ShuffleNet: https://arxiv.org/abs/1807.11164
.. _MobileNetV2: https://arxiv.org/abs/1801.04381
.. _MobileNetV3: https://arxiv.org/abs/1905.02244
.. _ResNeXt: https://arxiv.org/abs/1611.05431
.. _MNASNet: https://arxiv.org/abs/1807.11626
.. _EfficientNet: https://arxiv.org/abs/1905.11946
.. _RegNet: https://arxiv.org/abs/2003.13678
.. _VisionTransformer: https://arxiv.org/abs/2010.11929
.. _ConvNeXt: https://arxiv.org/abs/2201.03545
.. _SwinTransformer: https://arxiv.org/abs/2103.14030
.. currentmodule:: torchvision.models

Alexnet
-------

.. autosummary::
:toctree: generated/
:template: function.rst

alexnet

VGG
---

.. autosummary::
:toctree: generated/
:template: function.rst

vgg11
vgg11_bn
vgg13
vgg13_bn
vgg16
vgg16_bn
vgg19
vgg19_bn
ResNet
------

.. autosummary::
:toctree: generated/
:template: function.rst

resnet18
resnet34
resnet50
resnet101
resnet152
SqueezeNet
----------
.. autosummary::
:toctree: generated/
:template: function.rst
squeezenet1_0
squeezenet1_1

DenseNet
---------
.. autosummary::
:toctree: generated/
:template: function.rst

densenet121
densenet169
densenet161
densenet201
Inception v3
------------

.. autosummary::
:toctree: generated/
:template: function.rst
inception_v3

GoogLeNet
------------
.. autosummary::
:toctree: generated/
:template: function.rst
googlenet

ShuffleNet v2
-------------

.. autosummary::
:toctree: generated/
:template: function.rst
shufflenet_v2_x0_5
shufflenet_v2_x1_0
shufflenet_v2_x1_5
shufflenet_v2_x2_0
MobileNet v2
-------------
.. autosummary::
:toctree: generated/
:template: function.rst
mobilenet_v2

MobileNet v3
-------------
.. autosummary::
:toctree: generated/
:template: function.rst

mobilenet_v3_large
mobilenet_v3_small
ResNext
-------
.. autosummary::
:toctree: generated/
:template: function.rst
resnext50_32x4d
resnext101_32x8d
resnext101_64x4d
Wide ResNet
-----------

.. autosummary::
:toctree: generated/
:template: function.rst

wide_resnet50_2
wide_resnet101_2
MNASNet
--------

.. autosummary::
:toctree: generated/
:template: function.rst
mnasnet0_5
mnasnet0_75
mnasnet1_0
mnasnet1_3
EfficientNet
------------

.. autosummary::
:toctree: generated/
:template: function.rst
efficientnet_b0
efficientnet_b1
efficientnet_b2
efficientnet_b3
efficientnet_b4
efficientnet_b5
efficientnet_b6
efficientnet_b7
efficientnet_v2_s
efficientnet_v2_m
efficientnet_v2_l
RegNet
------
.. autosummary::
:toctree: generated/
:template: function.rst
regnet_y_400mf
regnet_y_800mf
regnet_y_1_6gf
regnet_y_3_2gf
regnet_y_8gf
regnet_y_16gf
regnet_y_32gf
regnet_y_128gf
regnet_x_400mf
regnet_x_800mf
regnet_x_1_6gf
regnet_x_3_2gf
regnet_x_8gf
regnet_x_16gf
regnet_x_32gf
VisionTransformer
-----------------
.. autosummary::
:toctree: generated/
:template: function.rst
vit_b_16
vit_b_32
vit_l_16
vit_l_32
vit_h_14
ConvNeXt
--------
.. autosummary::
:toctree: generated/
:template: function.rst
convnext_tiny
convnext_small
convnext_base
convnext_large
SwinTransformer
---------------
.. autosummary::
:toctree: generated/
:template: function.rst
swin_t
Quantized Models
----------------
The following architectures provide support for INT8 quantized models. You can get
a model with random weights by calling its constructor:

.. code:: python

import torchvision.models as models
googlenet = models.quantization.googlenet()
inception_v3 = models.quantization.inception_v3()
mobilenet_v2 = models.quantization.mobilenet_v2()
mobilenet_v3_large = models.quantization.mobilenet_v3_large()
resnet18 = models.quantization.resnet18()
resnet50 = models.quantization.resnet50()
resnext101_32x8d = models.quantization.resnext101_32x8d()
resnext101_64x4d = models.quantization.resnext101_64x4d()
shufflenet_v2_x0_5 = models.quantization.shufflenet_v2_x0_5()
shufflenet_v2_x1_0 = models.quantization.shufflenet_v2_x1_0()
shufflenet_v2_x1_5 = models.quantization.shufflenet_v2_x1_5()
shufflenet_v2_x2_0 = models.quantization.shufflenet_v2_x2_0()
Obtaining a pre-trained quantized model can be done with a few lines of code:
.. code:: python

import torch
import torchvision.models as models
from torchvision.models.quantization import MobileNet_V2_QuantizedWeights

model = models.quantization.mobilenet_v2(weights=MobileNet_V2_QuantizedWeights.IMAGENET1K_QNNPACK_V1, quantize=True)
model.eval()
# run the model with quantized inputs and weights
out = model(torch.rand(1, 3, 224, 224))
We provide pre-trained quantized weights for the following models:
================================ ============= =============
Model Acc@1 Acc@5
================================ ============= =============
MobileNet V2 71.658 90.150
MobileNet V3 Large 73.004 90.858
ShuffleNet V2 x0.5 57.972 79.780
ShuffleNet V2 x1.0 68.360 87.582
ShuffleNet V2 x1.5 72.052 90.700
ShuffleNet V2 x2.0 75.354 92.488
ResNet 18 69.494 88.882
ResNet 50 75.920 92.814
ResNext 101 32x8d 78.986 94.480
ResNext 101 64x4d 82.898 96.326
Inception V3 77.176 93.354
GoogleNet 69.826 89.404
================================ ============= =============
Semantic Segmentation
=====================

The models subpackage contains definitions for the following model
architectures for semantic segmentation:
- `FCN ResNet50, ResNet101 <https://arxiv.org/abs/1411.4038>`_
- `DeepLabV3 ResNet50, ResNet101, MobileNetV3-Large <https://arxiv.org/abs/1706.05587>`_
- `LR-ASPP MobileNetV3-Large <https://arxiv.org/abs/1905.02244>`_
As with image classification models, all pre-trained models expect input images normalized in the same way.
The images have to be loaded in to a range of ``[0, 1]`` and then normalized using
``mean = [0.485, 0.456, 0.406]`` and ``std = [0.229, 0.224, 0.225]``.
They have been trained on images resized such that their minimum size is 520.
For details on how to plot the masks of such models, you may refer to :ref:`semantic_seg_output`.

The pre-trained models have been trained on a subset of COCO train2017, on the 20 categories that are
present in the Pascal VOC dataset. You can see more information on how the subset has been selected in
``references/segmentation/coco_utils.py``. The classes that the pre-trained model outputs are the following,
in order:

.. code-block:: python

['__background__', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']
The accuracies of the pre-trained models evaluated on COCO val2017 are as follows

================================ ============= ====================
Network mean IoU global pixelwise acc
================================ ============= ====================
FCN ResNet50 60.5 91.4
FCN ResNet101 63.7 91.9
DeepLabV3 ResNet50 66.4 92.4
DeepLabV3 ResNet101 67.4 92.4
DeepLabV3 MobileNetV3-Large 60.3 91.2
LR-ASPP MobileNetV3-Large 57.9 91.2
================================ ============= ====================
Fully Convolutional Networks
----------------------------

.. autosummary::
:toctree: generated/
:template: function.rst

torchvision.models.segmentation.fcn_resnet50
torchvision.models.segmentation.fcn_resnet101
DeepLabV3
---------

.. autosummary::
:toctree: generated/
:template: function.rst

torchvision.models.segmentation.deeplabv3_resnet50
torchvision.models.segmentation.deeplabv3_resnet101
torchvision.models.segmentation.deeplabv3_mobilenet_v3_large
LR-ASPP
-------

.. autosummary::
:toctree: generated/
:template: function.rst
torchvision.models.segmentation.lraspp_mobilenet_v3_large
.. _object_det_inst_seg_pers_keypoint_det:

Object Detection, Instance Segmentation and Person Keypoint Detection
=====================================================================
The models subpackage contains definitions for the following model
architectures for detection:
- `Faster R-CNN <https://arxiv.org/abs/1506.01497>`_
- `FCOS <https://arxiv.org/abs/1904.01355>`_
- `Mask R-CNN <https://arxiv.org/abs/1703.06870>`_
- `RetinaNet <https://arxiv.org/abs/1708.02002>`_
- `SSD <https://arxiv.org/abs/1512.02325>`_
- `SSDlite <https://arxiv.org/abs/1801.04381>`_
The pre-trained models for detection, instance segmentation and
keypoint detection are initialized with the classification models
in torchvision.
The models expect a list of ``Tensor[C, H, W]``, in the range ``0-1``.
The models internally resize the images but the behaviour varies depending
on the model. Check the constructor of the models for more information. The
output format of such models is illustrated in :ref:`instance_seg_output`.
For object detection and instance segmentation, the pre-trained
models return the predictions of the following classes:
.. code-block:: python
COCO_INSTANCE_CATEGORY_NAMES = [
'__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]
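A minimal sketch (not part of the original docs) of mapping the predicted label indices back to these names; it assumes an input image tensor ``img`` of shape ``[C, H, W]`` with values in ``[0, 1]``:

.. code:: python

    import torch
    import torchvision.models as models

    model = models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
    model.eval()

    with torch.no_grad():
        output = model([img])[0]

    # Translate the predicted label indices into human-readable category names
    names = [COCO_INSTANCE_CATEGORY_NAMES[int(i)] for i in output["labels"]]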
Here is a summary of the accuracies for the models trained on
the instances set of COCO train2017 and evaluated on COCO val2017.
====================================== ======= ======== ===========
Network box AP mask AP keypoint AP
====================================== ======= ======== ===========
Faster R-CNN ResNet-50 FPN 37.0 - -
Faster R-CNN MobileNetV3-Large FPN 32.8 - -
Faster R-CNN MobileNetV3-Large 320 FPN 22.8 - -
FCOS ResNet-50 FPN 39.2 - -
RetinaNet ResNet-50 FPN 36.4 - -
SSD300 VGG16 25.1 - -
SSDlite320 MobileNetV3-Large 21.3 - -
Mask R-CNN ResNet-50 FPN 37.9 34.6 -
====================================== ======= ======== ===========
For person keypoint detection, the accuracies for the pre-trained
models are as follows
================================ ======= ======== ===========
Network box AP mask AP keypoint AP
================================ ======= ======== ===========
Keypoint R-CNN ResNet-50 FPN 54.6 - 65.0
================================ ======= ======== ===========
For person keypoint detection, the pre-trained models return the
keypoints in the following order:
.. code-block:: python
COCO_PERSON_KEYPOINT_NAMES = [
'nose',
'left_eye',
'right_eye',
'left_ear',
'right_ear',
'left_shoulder',
'right_shoulder',
'left_elbow',
'right_elbow',
'left_wrist',
'right_wrist',
'left_hip',
'right_hip',
'left_knee',
'right_knee',
'left_ankle',
'right_ankle'
]
Runtime characteristics
-----------------------
The implementations of the models for object detection, instance segmentation
and keypoint detection are efficient.

In the following table, we use 8 GPUs to report the results. During training,
we use a batch size of 2 per GPU for all models except SSD which uses 4
and SSDlite which uses 24. During testing a batch size of 1 is used.
For test time, we report the time for the model evaluation and postprocessing
(including mask pasting in image), but not the time for computing the
precision-recall.

====================================== =================== ================== ===========
Network train time (s / it) test time (s / it) memory (GB)
====================================== =================== ================== ===========
Faster R-CNN ResNet-50 FPN 0.2288 0.0590 5.2
Faster R-CNN MobileNetV3-Large FPN 0.1020 0.0415 1.0
Faster R-CNN MobileNetV3-Large 320 FPN 0.0978 0.0376 0.6
FCOS ResNet-50 FPN 0.1450 0.0539 3.3
RetinaNet ResNet-50 FPN 0.2514 0.0939 4.1
SSD300 VGG16 0.2093 0.0744 1.5
SSDlite320 MobileNetV3-Large 0.1773 0.0906 1.5
Mask R-CNN ResNet-50 FPN 0.2728 0.0903 5.4
Keypoint R-CNN ResNet-50 FPN 0.3789 0.1242 6.8
====================================== =================== ================== ===========
Faster R-CNN
------------
.. autosummary::
:toctree: generated/
:template: function.rst
torchvision.models.detection.fasterrcnn_resnet50_fpn
torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn
torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn
FCOS
----
.. autosummary::
:toctree: generated/
:template: function.rst
torchvision.models.detection.fcos_resnet50_fpn
RetinaNet
---------

.. autosummary::
:toctree: generated/
:template: function.rst
torchvision.models.detection.retinanet_resnet50_fpn

SSD
---
.. autosummary::
:toctree: generated/
:template: function.rst
torchvision.models.detection.ssd300_vgg16

SSDlite
-------

.. autosummary::
:toctree: generated/
:template: function.rst
torchvision.models.detection.ssdlite320_mobilenet_v3_large

Mask R-CNN
----------

.. autosummary::
:toctree: generated/
:template: function.rst
torchvision.models.detection.maskrcnn_resnet50_fpn

Keypoint R-CNN
--------------

.. autosummary::
:toctree: generated/
:template: function.rst
torchvision.models.detection.keypointrcnn_resnet50_fpn

Video classification
====================

We provide models for action recognition pre-trained on Kinetics-400.
They have all been trained with the scripts provided in ``references/video_classification``.
All pre-trained models expect input images normalized in the same way,
i.e. mini-batches of 3-channel RGB videos of shape (3 x T x H x W),
where H and W are expected to be 112, and T is a number of video frames in a clip.
The images have to be loaded in to a range of [0, 1] and then normalized
using ``mean = [0.43216, 0.394666, 0.37645]`` and ``std = [0.22803, 0.22145, 0.216989]``.
.. note::
The normalization parameters are different from the image classification ones, and correspond
to the mean and std from Kinetics-400.
.. note::
For now, normalization code can be found in ``references/video_classification/transforms.py``,
see the ``Normalize`` function there. Note that it differs from standard normalization for
images because it assumes the video is 4d.
Kinetics 1-crop accuracies for clip length 16 (16x112x112)
================================ ============= =============
Network Clip acc@1 Clip acc@5
================================ ============= =============
ResNet 3D 18 52.75 75.45
ResNet MC 18 53.90 76.29
ResNet (2+1)D 57.50 78.81
================================ ============= =============
ResNet 3D
----------

.. autosummary::
:toctree: generated/
:template: function.rst
torchvision.models.video.r3d_18

ResNet Mixed Convolution
------------------------
.. autosummary::
:toctree: generated/
:template: function.rst
torchvision.models.video.mc3_18
ResNet (2+1)D
-------------

.. autosummary::
:toctree: generated/
:template: function.rst
torchvision.models.video.r2plus1d_18

Optical flow
============

Raft
----
.. autosummary::
:toctree: generated/
:template: function.rst
torchvision.models.optical_flow.raft_large
torchvision.models.optical_flow.raft_small
.. _models_new:
Models and pre-trained weights - New
####################################
The ``torchvision.models`` subpackage contains definitions of models for addressing
different tasks, including: image classification, pixelwise semantic
segmentation, object detection, instance segmentation, person
keypoint detection, video classification, and optical flow.
General information on pre-trained weights
==========================================
TorchVision offers pre-trained weights for every provided architecture, using
the PyTorch :mod:`torch.hub`. Instancing a pre-trained model will download its
weights to a cache directory. This directory can be set using the `TORCH_HOME`
environment variable. See :func:`torch.hub.load_state_dict_from_url` for details.
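For example, a small sketch (the paths below are purely illustrative) of redirecting where the weights are cached:

.. code:: python

    import os
    import torch

    # Either export TORCH_HOME before starting Python, or set it early in the process
    os.environ["TORCH_HOME"] = "/tmp/torch_cache"  # hypothetical path

    # Alternatively, torch.hub.set_dir() controls the hub cache directory directly
    torch.hub.set_dir("/tmp/torch_hub_cache")  # hypothetical path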
.. note::
The pre-trained models provided in this library may have their own licenses or
terms and conditions derived from the dataset used for training. It is your
responsibility to determine whether you have permission to use the models for
your use case.
.. note ::
Backward compatibility is guaranteed for loading a serialized
``state_dict`` to the model created using old PyTorch version.
On the contrary, loading entire saved models or serialized
``ScriptModules`` (serialized using older versions of PyTorch)
may not preserve the historic behaviour. Refer to the following
`documentation
<https://pytorch.org/docs/stable/notes/serialization.html#id6>`_
Initializing pre-trained models
-------------------------------
As of v0.13, TorchVision offers a new `Multi-weight support API
<https://pytorch.org/blog/introducing-torchvision-new-multi-weight-support-api/>`_
for loading different weights to the existing model builder methods:
.. code:: python
from torchvision.models import resnet50, ResNet50_Weights
# Old weights with accuracy 76.130%
resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
# New weights with accuracy 80.858%
resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
# Best available weights (currently alias for IMAGENET1K_V2)
# Note that these weights may change across versions
resnet50(weights=ResNet50_Weights.DEFAULT)
# Strings are also supported
resnet50(weights="IMAGENET1K_V2")
# No weights - random initialization
resnet50(weights=None)
Migrating to the new API is very straightforward. The following method calls between the 2 APIs are all equivalent:
.. code:: python
from torchvision.models import resnet50, ResNet50_Weights
# Using pretrained weights:
resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
resnet50(weights="IMAGENET1K_V1")
resnet50(pretrained=True) # deprecated
resnet50(True) # deprecated
# Using no weights:
resnet50(weights=None)
resnet50()
resnet50(pretrained=False) # deprecated
resnet50(False) # deprecated
Note that the ``pretrained`` parameter is now deprecated; using it will emit warnings and it will be removed in v0.15.
Using the pre-trained models
----------------------------
Before using the pre-trained models, one must preprocess the image
(resize with right resolution/interpolation, apply inference transforms,
rescale the values etc). There is no standard way to do this as it depends on
how a given model was trained. It can vary across model families, variants or
even weight versions. Using the correct preprocessing method is critical and
failing to do so may lead to decreased accuracy or incorrect outputs.
All the necessary information for the inference transforms of each pre-trained
model is provided on its weights documentation. To simplify inference, TorchVision
bundles the necessary preprocessing transforms into each model weight. These are
accessible via the ``weight.transforms`` attribute:
.. code:: python
# Initialize the Weight Transforms
weights = ResNet50_Weights.DEFAULT
preprocess = weights.transforms()
# Apply it to the input image
img_transformed = preprocess(img)
Some models use modules which have different training and evaluation
behavior, such as batch normalization. To switch between these modes, use
``model.train()`` or ``model.eval()`` as appropriate. See
:meth:`~torch.nn.Module.train` or :meth:`~torch.nn.Module.eval` for details.
.. code:: python
# Initialize model
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights)
# Set model to eval mode
model.eval()
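Since the models here are only used for inference, it is also common (though not required by the API) to disable gradient tracking; a minimal sketch, assuming the ``model`` and the preprocessed ``img_transformed`` from the snippets above:

.. code:: python

    import torch

    with torch.no_grad():
        output = model(img_transformed.unsqueeze(0))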
Classification
==============
.. currentmodule:: torchvision.models
The following classification models are available, with or without pre-trained
weights:
.. toctree::
:maxdepth: 1
models/alexnet
models/convnext
models/densenet
models/efficientnet
models/efficientnetv2
models/googlenet
models/inception
models/mnasnet
models/mobilenetv2
models/mobilenetv3
models/regnet
models/resnet
models/resnext
models/shufflenetv2
models/squeezenet
models/swin_transformer
models/vgg
models/vision_transformer
models/wide_resnet
|
Here is an example of how to use the pre-trained image classification models:
.. code:: python
from torchvision.io import read_image
from torchvision.models import resnet50, ResNet50_Weights
img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")
# Step 1: Initialize model with the best available weights
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights)
model.eval()
# Step 2: Initialize the inference transforms
preprocess = weights.transforms()
# Step 3: Apply inference preprocessing transforms
batch = preprocess(img).unsqueeze(0)
# Step 4: Use the model and print the predicted category
prediction = model(batch).squeeze(0).softmax(0)
class_id = prediction.argmax().item()
score = prediction[class_id].item()
category_name = weights.meta["categories"][class_id]
print(f"{category_name}: {100 * score:.1f}%")
The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.
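Building on the example above, a small optional sketch that prints the top-5 categories instead of only the best one:

.. code:: python

    top5_scores, top5_ids = prediction.topk(5)
    for score, class_id in zip(top5_scores.tolist(), top5_ids.tolist()):
        print(f"{weights.meta['categories'][class_id]}: {100 * score:.1f}%")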
Table of all available classification weights
---------------------------------------------
Accuracies are reported on ImageNet-1K using single crops:
.. include:: generated/classification_table.rst
Quantized models
----------------
.. currentmodule:: torchvision.models.quantization
The following architectures provide support for INT8 quantized models, with or without
pre-trained weights:
.. toctree::
:maxdepth: 1
models/googlenet_quant
models/inception_quant
models/mobilenetv2_quant
models/mobilenetv3_quant
models/resnet_quant
models/resnext_quant
models/shufflenetv2_quant
|
Here is an example of how to use the pre-trained quantized image classification models:
.. code:: python
from torchvision.io import read_image
from torchvision.models.quantization import resnet50, ResNet50_QuantizedWeights
img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")
# Step 1: Initialize model with the best available weights
weights = ResNet50_QuantizedWeights.DEFAULT
model = resnet50(weights=weights, quantize=True)
model.eval()
# Step 2: Initialize the inference transforms
preprocess = weights.transforms()
# Step 3: Apply inference preprocessing transforms
batch = preprocess(img).unsqueeze(0)
# Step 4: Use the model and print the predicted category
prediction = model(batch).squeeze(0).softmax(0)
class_id = prediction.argmax().item()
score = prediction[class_id].item()
category_name = weights.meta["categories"][class_id]
print(f"{category_name}: {100 * score}%")
The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.
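The bundled quantized weights are typically prepared for a specific quantization backend (for instance FBGEMM on x86 or QNNPACK on ARM); a hedged sketch of checking and selecting the engine before running the example above:

.. code:: python

    import torch

    # Engines available in this build, e.g. ['qnnpack', 'fbgemm', 'none']
    print(torch.backends.quantized.supported_engines)
    torch.backends.quantized.engine = "fbgemm"  # pick one of the supported engines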
Table of all available quantized classification weights
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Accuracies are reported on ImageNet-1K using single crops:
.. include:: generated/classification_quant_table.rst
Semantic Segmentation
=====================
.. currentmodule:: torchvision.models.segmentation
The following semantic segmentation models are available, with or without
pre-trained weights:
.. toctree::
:maxdepth: 1
models/deeplabv3
models/fcn
models/lraspp
|
Here is an example of how to use the pre-trained semantic segmentation models:
.. code:: python
from torchvision.io.image import read_image
from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights
from torchvision.transforms.functional import to_pil_image
img = read_image("gallery/assets/dog1.jpg")
# Step 1: Initialize model with the best available weights
weights = FCN_ResNet50_Weights.DEFAULT
model = fcn_resnet50(weights=weights)
model.eval()
# Step 2: Initialize the inference transforms
preprocess = weights.transforms()
# Step 3: Apply inference preprocessing transforms
batch = preprocess(img).unsqueeze(0)
# Step 4: Use the model and visualize the prediction
prediction = model(batch)["out"]
normalized_masks = prediction.softmax(dim=1)
class_to_idx = {cls: idx for (idx, cls) in enumerate(weights.meta["categories"])}
mask = normalized_masks[0, class_to_idx["dog"]]
to_pil_image(mask).show()
The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.
The output format of the models is illustrated in :ref:`semantic_seg_output`.
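To overlay the prediction on the input image rather than viewing the raw mask, one possible continuation of the example above uses :func:`torchvision.utils.draw_segmentation_masks`; resizing the output back to the input resolution is an assumption about how the bundled transform rescales the image:

.. code:: python

    from torchvision.transforms.functional import resize
    from torchvision.utils import draw_segmentation_masks

    # The model output lives at the preprocessed resolution; bring it back to the input size
    masks_at_input_size = resize(normalized_masks, list(img.shape[-2:]))
    dog_mask = masks_at_input_size.argmax(dim=1)[0] == class_to_idx["dog"]
    overlay = draw_segmentation_masks(img, masks=dog_mask, alpha=0.7)
    to_pil_image(overlay).show()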
Table of all available semantic segmentation weights
----------------------------------------------------
All models are evaluated on a subset of COCO val2017, on the 20 categories that are present in the Pascal VOC dataset:
.. include:: generated/segmentation_table.rst
Object Detection, Instance Segmentation and Person Keypoint Detection
=====================================================================
The pre-trained models for detection, instance segmentation and
keypoint detection are initialized with the classification models
in torchvision. The models expect a list of ``Tensor[C, H, W]``.
Check the constructor of the models for more information.
Object Detection
----------------
.. currentmodule:: torchvision.models.detection
The following object detection models are available, with or without pre-trained
weights:
.. toctree::
:maxdepth: 1
models/faster_rcnn
models/fcos
models/retinanet
models/ssd
models/ssdlite
|
Here is an example of how to use the pre-trained object detection models:
.. code:: python
from torchvision.io.image import read_image
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights
from torchvision.utils import draw_bounding_boxes
from torchvision.transforms.functional import to_pil_image
img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")
# Step 1: Initialize model with the best available weights
weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
model = fasterrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=0.9)
model.eval()
# Step 2: Initialize the inference transforms
preprocess = weights.transforms()
# Step 3: Apply inference preprocessing transforms
batch = [preprocess(img)]
# Step 4: Use the model and visualize the prediction
prediction = model(batch)[0]
labels = [weights.meta["categories"][i] for i in prediction["labels"]]
box = draw_bounding_boxes(img, boxes=prediction["boxes"],
labels=labels,
colors="red",
width=4, font_size=30)
im = to_pil_image(box.detach())
im.show()
The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.
For details on how to plot the bounding boxes of the models, you may refer to :ref:`instance_seg_output`.
Table of all available Object detection weights
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Box MAPs are reported on COCO val2017:
.. include:: generated/detection_table.rst
Instance Segmentation
---------------------
.. currentmodule:: torchvision.models.detection
The following instance segmentation models are available, with or without pre-trained
weights:
.. toctree::
:maxdepth: 1
models/mask_rcnn
|
For details on how to plot the masks of the models, you may refer to :ref:`instance_seg_output`.
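For illustration, a hedged sketch modelled on the object detection example above (not taken from the original page) that runs Mask R-CNN and overlays the predicted instance masks:

.. code:: python

    from torchvision.io.image import read_image
    from torchvision.models.detection import maskrcnn_resnet50_fpn_v2, MaskRCNN_ResNet50_FPN_V2_Weights
    from torchvision.utils import draw_segmentation_masks
    from torchvision.transforms.functional import to_pil_image

    img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

    weights = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT
    model = maskrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=0.9)
    model.eval()

    preprocess = weights.transforms()
    prediction = model([preprocess(img)])[0]

    # Masks come back as [N, 1, H, W] probabilities at the input resolution; binarize them
    masks = prediction["masks"].squeeze(1) > 0.5
    overlay = draw_segmentation_masks(img, masks=masks, alpha=0.6)
    to_pil_image(overlay).show()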
Table of all available Instance segmentation weights
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Box and Mask MAPs are reported on COCO val2017:
.. include:: generated/instance_segmentation_table.rst
Keypoint Detection
------------------
.. currentmodule:: torchvision.models.detection
The following person keypoint detection models are available, with or without
pre-trained weights:
.. toctree::
:maxdepth: 1
models/keypoint_rcnn
|
The classes of the pre-trained model outputs can be found at ``weights.meta["keypoint_names"]``.
For details on how to plot the bounding boxes of the models, you may refer to :ref:`keypoint_output`.
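A hedged sketch following the same pattern as the other examples (not part of the original page) that runs Keypoint R-CNN and draws the detected keypoints with :func:`torchvision.utils.draw_keypoints`:

.. code:: python

    from torchvision.io.image import read_image
    from torchvision.models.detection import keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights
    from torchvision.utils import draw_keypoints
    from torchvision.transforms.functional import to_pil_image

    img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

    weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
    model = keypointrcnn_resnet50_fpn(weights=weights, box_score_thresh=0.9)
    model.eval()

    preprocess = weights.transforms()
    prediction = model([preprocess(img)])[0]

    # Each detected person gets [num_keypoints, 3] entries of (x, y, visibility)
    keypoints = prediction["keypoints"][:, :, :2]
    overlay = draw_keypoints(img, keypoints, colors="blue", radius=3)
    to_pil_image(overlay).show()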
Table of all available Keypoint detection weights
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Box and Keypoint MAPs are reported on COCO val2017:
.. include:: generated/detection_keypoint_table.rst
Video Classification
====================
.. currentmodule:: torchvision.models.video
The following video classification models are available, with or without
pre-trained weights:
.. toctree::
:maxdepth: 1
models/video_resnet
|
Here is an example of how to use the pre-trained video classification models:
.. code:: python
from torchvision.io.video import read_video
from torchvision.models.video import r3d_18, R3D_18_Weights
vid, _, _ = read_video("test/assets/videos/v_SoccerJuggling_g23_c01.avi")
vid = vid[:32] # optionally shorten duration
# Step 1: Initialize model with the best available weights
weights = R3D_18_Weights.DEFAULT
model = r3d_18(weights=weights)
model.eval()
# Step 2: Initialize the inference transforms
preprocess = weights.transforms()
# Step 3: Apply inference preprocessing transforms
batch = preprocess(vid).unsqueeze(0)
# Step 4: Use the model and print the predicted category
prediction = model(batch).squeeze(0).softmax(0)
label = prediction.argmax().item()
score = prediction[label].item()
category_name = weights.meta["categories"][label]
print(f"{category_name}: {100 * score}%")
The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.
Table of all available video classification weights
---------------------------------------------------
Accuracies are reported on Kinetics-400 using single crops for clip length 16:
.. include:: generated/video_table.rst
Optical Flow
============
.. currentmodule:: torchvision.models.optical_flow
The following Optical Flow models are available, with or without pre-trained weights:
.. toctree::
:maxdepth: 1
models/raft
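A hedged sketch (not in the original page) of estimating the flow between two consecutive video frames with RAFT; the pair-wise transform call and :func:`torchvision.utils.flow_to_image` follow the RAFT reference example and are assumptions here:

.. code:: python

    from torchvision.io.video import read_video
    from torchvision.models.optical_flow import raft_large, Raft_Large_Weights
    from torchvision.utils import flow_to_image
    from torchvision.transforms.functional import to_pil_image

    frames, _, _ = read_video("test/assets/videos/v_SoccerJuggling_g23_c01.avi")
    frames = frames.permute(0, 3, 1, 2)  # [T, H, W, C] -> [T, C, H, W]
    img1_batch, img2_batch = frames[0:1], frames[1:2]

    weights = Raft_Large_Weights.DEFAULT
    model = raft_large(weights=weights)
    model.eval()

    preprocess = weights.transforms()
    img1_batch, img2_batch = preprocess(img1_batch, img2_batch)

    # The model returns one flow estimate per refinement iteration; the last is the most accurate
    flows = model(img1_batch, img2_batch)
    predicted_flow = flows[-1][0]
    to_pil_image(flow_to_image(predicted_flow)).show()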
Using models from Hub
=====================
Most pre-trained models can be accessed directly via PyTorch Hub without having TorchVision installed:
.. code:: python
import torch
# Option 1: passing weights param as string
model = torch.hub.load("pytorch/vision", "resnet50", weights="IMAGENET1K_V2")
# Option 2: passing weights param as enum
weights = torch.hub.load("pytorch/vision", "get_weight", weights="ResNet50_Weights.IMAGENET1K_V2")
model = torch.hub.load("pytorch/vision", "resnet50", weights=weights)
The only exception to the above are the detection models included on
:mod:`torchvision.models.detection`. These models require TorchVision
to be installed because they depend on custom C++ operators.
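To discover what can be loaded this way, a small supplementary sketch (not from the original page) lists the hub entrypoints and inspects a weight enum's metadata:

.. code:: python

    import torch

    # Every model builder exposed through the hubconf of pytorch/vision
    print(torch.hub.list("pytorch/vision"))

    # Retrieve a weight enum through the hub and inspect its metadata
    weights = torch.hub.load("pytorch/vision", "get_weight", weights="ResNet50_Weights.IMAGENET1K_V2")
    print(weights.meta["categories"][:5])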