.. _models:

Models and pre-trained weights
##############################

The ``torchvision.models`` subpackage contains definitions of models for addressing
different tasks, including: image classification, pixelwise semantic
segmentation, object detection, instance segmentation, person
keypoint detection, video classification, and optical flow.

General information on pre-trained weights
==========================================

TorchVision offers pre-trained weights for every provided architecture, using
the PyTorch :mod:`torch.hub`. Instantiating a pre-trained model will download its
weights to a cache directory. This directory can be set using the ``TORCH_HOME``
environment variable. See :func:`torch.hub.load_state_dict_from_url` for details.
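
The cache location can be redirected before the first download; a minimal sketch,
assuming a hypothetical ``/data/torch_cache`` directory (the variable must be set
before the weights are fetched):

.. code:: python

    import os

    # Hypothetical cache directory; any writable path works
    os.environ["TORCH_HOME"] = "/data/torch_cache"

    from torchvision.models import resnet18, ResNet18_Weights

    # The checkpoint is stored under $TORCH_HOME/hub/checkpoints and re-used afterwards
    model = resnet18(weights=ResNet18_Weights.DEFAULT)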

.. note::

    The pre-trained models provided in this library may have their own licenses or
    terms and conditions derived from the dataset used for training. It is your
    responsibility to determine whether you have permission to use the models for
    your use case.

.. note::
    Backward compatibility is guaranteed for loading a serialized
    ``state_dict`` into a model created with an older PyTorch version.
    In contrast, loading entire saved models or serialized
    ``ScriptModules`` (serialized using older versions of PyTorch)
    may not preserve the historic behaviour. Refer to the following
    `documentation
    <https://pytorch.org/docs/stable/notes/serialization.html#id6>`_.

Initializing pre-trained models
-------------------------------

As of v0.13, TorchVision offers a new `Multi-weight support API
<https://pytorch.org/blog/introducing-torchvision-new-multi-weight-support-api/>`_
for loading different weights to the existing model builder methods:

.. code:: python

    from torchvision.models import resnet50, ResNet50_Weights

    # Old weights with accuracy 76.130%
    resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)

    # New weights with accuracy 80.858%
    resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

    # Best available weights (currently alias for IMAGENET1K_V2)
    # Note that these weights may change across versions
    resnet50(weights=ResNet50_Weights.DEFAULT)

    # Strings are also supported
    resnet50(weights="IMAGENET1K_V2")

    # No weights - random initialization
    resnet50(weights=None)


Migrating to the new API is very straightforward. The following method calls between the two APIs are all equivalent:

.. code:: python

    from torchvision.models import resnet50, ResNet50_Weights

    # Using pretrained weights:
    resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
    resnet50(weights="IMAGENET1K_V1")
    resnet50(pretrained=True)  # deprecated
    resnet50(True)  # deprecated

    # Using no weights:
    resnet50(weights=None)
    resnet50()
    resnet50(pretrained=False)  # deprecated
    resnet50(False)  # deprecated

Note that the ``pretrained`` parameter is now deprecated; using it will emit warnings, and it will be removed in v0.15.

Using the pre-trained models
----------------------------

Before using the pre-trained models, one must preprocess the image (resize it with
the right resolution/interpolation, apply inference transforms, rescale the values,
etc.). There is no standard way to do this, as it depends on how a given model was
trained. It can vary across model families, variants, or even weight versions. Using
the correct preprocessing method is critical, and failing to do so may lead to
decreased accuracy or incorrect outputs.

All the necessary information for the inference transforms of each pre-trained
model is provided in its weights documentation. To simplify inference, TorchVision
bundles the necessary preprocessing transforms into each model weight. These are
accessible via the ``weight.transforms`` attribute:

.. code:: python

    # Initialize the Weight Transforms
    weights = ResNet50_Weights.DEFAULT
    preprocess = weights.transforms()

    # Apply it to the input image
    img_transformed = preprocess(img)

Some models use modules which have different training and evaluation
behavior, such as batch normalization. To switch between these modes, use
``model.train()`` or ``model.eval()`` as appropriate. See
:meth:`~torch.nn.Module.train` or :meth:`~torch.nn.Module.eval` for details.

.. code:: python

    # Initialize model
    weights = ResNet50_Weights.DEFAULT
    model = resnet50(weights=weights)

    # Set model to eval mode
    model.eval()

Listing and retrieving available models
---------------------------------------

As of v0.14, TorchVision offers a new mechanism which allows listing and
retrieving models and weights by their names. Here are a few examples of how to
use them:

.. code:: python

    import torchvision
    from torchvision.models import get_model, get_model_weights, get_weight, list_models
    from torchvision.models.quantization import MobileNet_V3_Large_QuantizedWeights

    # List available models
    all_models = list_models()
    classification_models = list_models(module=torchvision.models)

    # Initialize models
    m1 = get_model("mobilenet_v3_large", weights=None)
    m2 = get_model("quantized_mobilenet_v3_large", weights="DEFAULT")

    # Fetch weights
    weights = get_weight("MobileNet_V3_Large_QuantizedWeights.DEFAULT")
    assert weights == MobileNet_V3_Large_QuantizedWeights.DEFAULT

    weights_enum = get_model_weights("quantized_mobilenet_v3_large")
    assert weights_enum == MobileNet_V3_Large_QuantizedWeights

    weights_enum2 = get_model_weights(torchvision.models.quantization.mobilenet_v3_large)
    assert weights_enum == weights_enum2

Here are the available public functions to retrieve models and their corresponding weights:

.. currentmodule:: torchvision.models
.. autosummary::
    :toctree: generated/
    :template: function.rst

    get_model
    get_model_weights
    get_weight
    list_models

Using models from Hub
---------------------

Most pre-trained models can be accessed directly via PyTorch Hub without having TorchVision installed:

.. code:: python

    import torch

    # Option 1: passing weights param as string
    model = torch.hub.load("pytorch/vision", "resnet50", weights="IMAGENET1K_V2")

    # Option 2: passing weights param as enum
    weights = torch.hub.load("pytorch/vision", "get_weight", weights="ResNet50_Weights.IMAGENET1K_V2")
    model = torch.hub.load("pytorch/vision", "resnet50", weights=weights)

You can also retrieve all the available weights of a specific model via PyTorch Hub by doing:

.. code:: python

    import torch

    weight_enum = torch.hub.load("pytorch/vision", "get_model_weights", name="resnet50")
    print([weight for weight in weight_enum])

The only exceptions to the above are the detection models included in
:mod:`torchvision.models.detection`. These models require TorchVision
to be installed because they depend on custom C++ operators.

Classification
==============

.. currentmodule:: torchvision.models

The following classification models are available, with or without pre-trained
weights:

.. toctree::
   :maxdepth: 1

   models/alexnet
   models/convnext
   models/densenet
   models/efficientnet
   models/efficientnetv2
   models/googlenet
   models/inception
   models/maxvit
   models/mnasnet
   models/mobilenetv2
   models/mobilenetv3
   models/regnet
   models/resnet
   models/resnext
   models/shufflenetv2
   models/squeezenet
   models/swin_transformer
   models/vgg
   models/vision_transformer
   models/wide_resnet

|

Here is an example of how to use the pre-trained image classification models:

.. code:: python

    from torchvision.io import read_image
    from torchvision.models import resnet50, ResNet50_Weights

    img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

    # Step 1: Initialize model with the best available weights
    weights = ResNet50_Weights.DEFAULT
    model = resnet50(weights=weights)
    model.eval()

    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms()

    # Step 3: Apply inference preprocessing transforms
    batch = preprocess(img).unsqueeze(0)

    # Step 4: Use the model and print the predicted category
    prediction = model(batch).squeeze(0).softmax(0)
    class_id = prediction.argmax().item()
    score = prediction[class_id].item()
    category_name = weights.meta["categories"][class_id]
    print(f"{category_name}: {100 * score:.1f}%")

The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.
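
To see what a given set of weights provides beyond the class names, the bundled
metadata can be inspected directly; a minimal sketch continuing the example above
(the exact set of ``meta`` keys varies between weights):

.. code:: python

    weights = ResNet50_Weights.DEFAULT
    print(weights.meta["categories"][:5])  # first few ImageNet category names
    print(weights.meta.get("num_params"))  # parameter count, when available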

Table of all available classification weights
---------------------------------------------

Accuracies are reported on ImageNet-1K using single crops:

.. include:: generated/classification_table.rst

Quantized models
----------------

.. currentmodule:: torchvision.models.quantization

The following architectures provide support for INT8 quantized models, with or without
pre-trained weights:

.. toctree::
   :maxdepth: 1

   models/googlenet_quant
   models/inception_quant
   models/mobilenetv2_quant
   models/mobilenetv3_quant
   models/resnet_quant
   models/resnext_quant
   models/shufflenetv2_quant

|

Here is an example of how to use the pre-trained quantized image classification models:

.. code:: python

    from torchvision.io import read_image
    from torchvision.models.quantization import resnet50, ResNet50_QuantizedWeights

    img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

    # Step 1: Initialize model with the best available weights
    weights = ResNet50_QuantizedWeights.DEFAULT
    model = resnet50(weights=weights, quantize=True)
    model.eval()

    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms()

    # Step 3: Apply inference preprocessing transforms
    batch = preprocess(img).unsqueeze(0)

    # Step 4: Use the model and print the predicted category
    prediction = model(batch).squeeze(0).softmax(0)
    class_id = prediction.argmax().item()
    score = prediction[class_id].item()
    category_name = weights.meta["categories"][class_id]
    print(f"{category_name}: {100 * score}%")

The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.

Table of all available quantized classification weights
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Accuracies are reported on ImageNet-1K using single crops:

.. include:: generated/classification_quant_table.rst

Semantic Segmentation
=====================

.. currentmodule:: torchvision.models.segmentation

.. betastatus:: segmentation module

The following semantic segmentation models are available, with or without
pre-trained weights:

.. toctree::
   :maxdepth: 1

   models/deeplabv3
   models/fcn
   models/lraspp

|

Here is an example of how to use the pre-trained semantic segmentation models:

.. code:: python

    from torchvision.io.image import read_image
    from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights
    from torchvision.transforms.functional import to_pil_image

    img = read_image("gallery/assets/dog1.jpg")

    # Step 1: Initialize model with the best available weights
    weights = FCN_ResNet50_Weights.DEFAULT
    model = fcn_resnet50(weights=weights)
    model.eval()

    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms()

    # Step 3: Apply inference preprocessing transforms
    batch = preprocess(img).unsqueeze(0)

    # Step 4: Use the model and visualize the prediction
    prediction = model(batch)["out"]
    normalized_masks = prediction.softmax(dim=1)
    class_to_idx = {cls: idx for (idx, cls) in enumerate(weights.meta["categories"])}
    mask = normalized_masks[0, class_to_idx["dog"]]
    to_pil_image(mask).show()

The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.
The output format of the models is illustrated in :ref:`semantic_seg_output`.
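
Since the raw output has shape ``[batch, num_classes, H, W]``, a per-pixel class map
can be obtained with an ``argmax`` over the class dimension; a minimal sketch
continuing the example above:

.. code:: python

    # Per-pixel predicted class indices, shape [H, W]
    class_map = normalized_masks[0].argmax(dim=0)
    # Look up the category name predicted at the image centre
    h, w = class_map.shape
    print(weights.meta["categories"][class_map[h // 2, w // 2].item()])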

Table of all available semantic segmentation weights
----------------------------------------------------

All models are evaluated on a subset of COCO val2017, on the 20 categories that are present in the Pascal VOC dataset:

.. include:: generated/segmentation_table.rst

.. _object_det_inst_seg_pers_keypoint_det:

Object Detection, Instance Segmentation and Person Keypoint Detection
======================================================================

The pre-trained models for detection, instance segmentation and
keypoint detection are initialized with the classification models
in torchvision. The models expect a list of ``Tensor[C, H, W]``.
Check the constructor of the models for more information.
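
As a small illustration of that input format, a minimal sketch of an inference call
on two hypothetical images of different sizes (random tensors here) might look like:

.. code:: python

    import torch
    from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights

    model = fasterrcnn_resnet50_fpn_v2(weights=FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT)
    model.eval()

    # The images in the list may have different sizes; each is a Tensor[C, H, W]
    images = [torch.rand(3, 480, 640), torch.rand(3, 600, 800)]
    predictions = model(images)  # one dict per image: boxes, labels, scores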

.. betastatus:: detection module

Object Detection
----------------

.. currentmodule:: torchvision.models.detection

The following object detection models are available, with or without pre-trained
weights:

.. toctree::
   :maxdepth: 1

   models/faster_rcnn
   models/fcos
   models/retinanet
   models/ssd
   models/ssdlite

|

Here is an example of how to use the pre-trained object detection models:

.. code:: python

    from torchvision.io.image import read_image
    from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights
    from torchvision.utils import draw_bounding_boxes
    from torchvision.transforms.functional import to_pil_image

    img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

    # Step 1: Initialize model with the best available weights
    weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
    model = fasterrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=0.9)
    model.eval()

    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms()

    # Step 3: Apply inference preprocessing transforms
    batch = [preprocess(img)]

    # Step 4: Use the model and visualize the prediction
    prediction = model(batch)[0]
    labels = [weights.meta["categories"][i] for i in prediction["labels"]]
    box = draw_bounding_boxes(img, boxes=prediction["boxes"],
                              labels=labels,
                              colors="red",
                              width=4, font_size=30)
    im = to_pil_image(box.detach())
    im.show()

The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.
For details on how to plot the bounding boxes of the models, you may refer to :ref:`instance_seg_output`.

Table of all available Object detection weights
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Box MAPs are reported on COCO val2017:

.. include:: generated/detection_table.rst


Instance Segmentation
---------------------

.. currentmodule:: torchvision.models.detection

The following instance segmentation models are available, with or without pre-trained
weights:

.. toctree::
   :maxdepth: 1

   models/mask_rcnn

|
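
A minimal sketch of running one of these models, following the pattern of the object
detection example above (the 0.5 mask threshold is an arbitrary choice):

.. code:: python

    from torchvision.io.image import read_image
    from torchvision.models.detection import maskrcnn_resnet50_fpn_v2, MaskRCNN_ResNet50_FPN_V2_Weights
    from torchvision.utils import draw_segmentation_masks
    from torchvision.transforms.functional import to_pil_image

    img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

    weights = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT
    model = maskrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=0.9)
    model.eval()

    preprocess = weights.transforms()
    prediction = model([preprocess(img)])[0]

    # "masks" holds per-instance soft masks of shape [N, 1, H, W]; threshold to booleans
    masks = (prediction["masks"] > 0.5).squeeze(1)
    overlay = draw_segmentation_masks(img, masks=masks, alpha=0.7)
    to_pil_image(overlay).show()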


For details on how to plot the masks of the models, you may refer to :ref:`instance_seg_output`.

Table of all available Instance segmentation weights
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Box and Mask MAPs are reported on COCO val2017:

.. include:: generated/instance_segmentation_table.rst


Keypoint Detection
------------------

.. currentmodule:: torchvision.models.detection

The following person keypoint detection models are available, with or without
pre-trained weights:

.. toctree::
   :maxdepth: 1

   models/keypoint_rcnn

|
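
A minimal sketch of running keypoint detection, again following the detection example
above (the predicted keypoint tensors have shape ``[num_instances, num_keypoints, 3]``):

.. code:: python

    from torchvision.io.image import read_image
    from torchvision.models.detection import keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights
    from torchvision.utils import draw_keypoints
    from torchvision.transforms.functional import to_pil_image

    img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

    weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
    model = keypointrcnn_resnet50_fpn(weights=weights, box_score_thresh=0.9)
    model.eval()

    preprocess = weights.transforms()
    prediction = model([preprocess(img)])[0]

    # Keep only the (x, y) coordinates and draw them on the original image
    keypoints = prediction["keypoints"][:, :, :2]
    overlay = draw_keypoints(img, keypoints, colors="blue", radius=3)
    to_pil_image(overlay).show()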

The keypoint names used by the pre-trained model outputs can be found at ``weights.meta["keypoint_names"]``.
For details on how to plot the bounding boxes of the models, you may refer to :ref:`keypoint_output`.

Table of all available Keypoint detection weights
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Box and Keypoint MAPs are reported on COCO val2017:

.. include:: generated/detection_keypoint_table.rst


Video Classification
====================

.. currentmodule:: torchvision.models.video

.. betastatus:: video module

The following video classification models are available, with or without
pre-trained weights:

.. toctree::
   :maxdepth: 1

   models/video_mvit
   models/video_resnet
   models/video_s3d
   models/video_swin_transformer

|

Here is an example of how to use the pre-trained video classification models:

.. code:: python


    from torchvision.io.video import read_video
    from torchvision.models.video import r3d_18, R3D_18_Weights

    vid, _, _ = read_video("test/assets/videos/v_SoccerJuggling_g23_c01.avi", output_format="TCHW")
    vid = vid[:32]  # optionally shorten duration

    # Step 1: Initialize model with the best available weights
    weights = R3D_18_Weights.DEFAULT
    model = r3d_18(weights=weights)
    model.eval()

    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms()

    # Step 3: Apply inference preprocessing transforms
    batch = preprocess(vid).unsqueeze(0)

    # Step 4: Use the model and print the predicted category
    prediction = model(batch).squeeze(0).softmax(0)
    label = prediction.argmax().item()
    score = prediction[label].item()
    category_name = weights.meta["categories"][label]
    print(f"{category_name}: {100 * score}%")

The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.

Table of all available video classification weights
---------------------------------------------------

Accuracies are reported on Kinetics-400 using single crops for clip length 16:

.. include:: generated/video_table.rst

Optical Flow
============

.. currentmodule:: torchvision.models.optical_flow

The following Optical Flow models are available, with or without pre-trained
weights:

.. toctree::
   :maxdepth: 1

   models/raft
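
|

A minimal sketch of running RAFT on a pair of frames, mirroring the examples above and
assuming two hypothetical consecutive frames ``frame1`` and ``frame2``, each a
``Tensor[3, H, W]``:

.. code:: python

    from torchvision.models.optical_flow import raft_large, Raft_Large_Weights

    weights = Raft_Large_Weights.DEFAULT
    model = raft_large(weights=weights)
    model.eval()

    # The optical-flow transforms take and return a pair of image batches
    preprocess = weights.transforms()
    img1, img2 = preprocess(frame1.unsqueeze(0), frame2.unsqueeze(0))

    # RAFT returns a list of iteratively refined flow maps; the last one is the final estimate
    list_of_flows = model(img1, img2)
    predicted_flow = list_of_flows[-1]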