.. _models_new:

Models and pre-trained weights - New
####################################

.. note::

    These are the new models docs, documenting the new multi-weight API.
    TODO: Once all is done, remove the "- New" part in the title above, and
    rename this file as models.rst


The ``torchvision.models`` subpackage contains definitions of models for addressing
different tasks, including: image classification, pixelwise semantic
segmentation, object detection, instance segmentation, person
keypoint detection, video classification, and optical flow.

.. note::
    Backward compatibility is guaranteed for loading a serialized
    ``state_dict`` into a model created with an older PyTorch version.
    In contrast, loading entire saved models or serialized
    ``ScriptModules`` (serialized using older versions of PyTorch)
    may not preserve their historic behaviour. Refer to the following
    `documentation
    <https://pytorch.org/docs/stable/notes/serialization.html#id6>`_.
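
For instance, here is a minimal sketch of the ``state_dict`` round trip that this
guarantee covers (the checkpoint path below is hypothetical):

.. code:: python

    import torch
    from torchvision.models import resnet50

    # Save only the state_dict, not the whole model object
    model = resnet50(weights=None)
    torch.save(model.state_dict(), "resnet50_checkpoint.pth")  # hypothetical path

    # Rebuild the model with the current torchvision code, then load the
    # serialized parameters into it
    model = resnet50(weights=None)
    model.load_state_dict(torch.load("resnet50_checkpoint.pth"))
    model.eval()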

As of v0.13, TorchVision offers a new `Multi-weight support API
<https://pytorch.org/blog/introducing-torchvision-new-multi-weight-support-api/>`_ for loading different weights to the
existing model builder methods:

.. code:: python

    from torchvision.models import resnet50, ResNet50_Weights

    # Old weights with accuracy 76.130%
    resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)

    # New weights with accuracy 80.858%
    resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

    # Best available weights (currently alias for IMAGENET1K_V2)
    # Note that these weights may change across versions
    resnet50(weights=ResNet50_Weights.DEFAULT)

    # Strings are also supported
    resnet50(weights="IMAGENET1K_V2")

    # No weights - random initialization
    resnet50(weights=None)  # or resnet50()


Migrating to the new API is straightforward. The following method calls between the two APIs are all equivalent:

.. code:: python

    from torchvision.models import resnet50, ResNet50_Weights

    # Using pretrained weights:
    resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
    resnet50(pretrained=True)  # deprecated
    resnet50(True)  # deprecated

    # Using no weights:
    resnet50(weights=None)
    resnet50(pretrained=False)  # deprecated
    resnet50(False)  # deprecated

Note that the ``pretrained`` parameter is now deprecated; using it will emit warnings and it will be removed in v0.15.
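
As a quick sanity check, the sketch below captures the deprecation warning emitted
by the legacy argument (assuming it is raised through Python's standard ``warnings``
machinery):

.. code:: python

    import warnings

    from torchvision.models import resnet50

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        resnet50(pretrained=False)  # deprecated argument, expected to warn

    for w in caught:
        print(w.category.__name__, w.message)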


Classification
==============

.. currentmodule:: torchvision.models

The following classification models are available, with or without pre-trained
weights:

.. toctree::
   :maxdepth: 1

   models/alexnet
   models/convnext
   models/densenet
   models/efficientnet
   models/efficientnetv2
   models/googlenet
   models/inception
   models/mnasnet
   models/mobilenetv2
   models/mobilenetv3
   models/regnet
   models/resnet
   models/resnext
   models/shufflenetv2
   models/squeezenet
   models/swin_transformer
   models/vgg
   models/vision_transformer
   models/wide_resnet

|

Here is an example of how to use the pre-trained image classification models:

.. code:: python

    from torchvision.io import read_image
    from torchvision.models import resnet50, ResNet50_Weights

    img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

    # Step 1: Initialize model with the best available weights
    weights = ResNet50_Weights.DEFAULT
    model = resnet50(weights=weights)
    model.eval()

    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms()

    # Step 3: Apply inference preprocessing transforms
    batch = preprocess(img).unsqueeze(0)

    # Step 4: Use the model and print the predicted category
    prediction = model(batch).squeeze(0).softmax(0)
    class_id = prediction.argmax().item()
    score = prediction[class_id].item()
    category_name = weights.meta["categories"][class_id]
    print(f"{category_name}: {100 * score:.1f}%")

Table of all available classification weights
---------------------------------------------

Accuracies are reported on ImageNet

.. include:: generated/classification_table.rst
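
Much of the information in this table is also exposed programmatically through the
``meta`` dictionary attached to each entry of a weights enum. A minimal sketch (the
exact set of metadata keys may vary across torchvision versions):

.. code:: python

    from torchvision.models import ResNet50_Weights

    # Iterate over all available weight entries for a given model
    for weights in ResNet50_Weights:
        print(weights)
        print("  metadata keys:", sorted(weights.meta))
        print("  number of categories:", len(weights.meta["categories"]))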

Quantized models
----------------

.. currentmodule:: torchvision.models.quantization

The following quantized classification models are available, with or without
pre-trained weights:

.. toctree::
   :maxdepth: 1

   models/googlenet_quant
   models/inception_quant
   models/mobilenetv2_quant
   models/resnet_quant

|

Here is an example of how to use the pre-trained quantized image classification models:

.. code:: python

    from torchvision.io import read_image
    from torchvision.models.quantization import resnet50, ResNet50_QuantizedWeights

    img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

    # Step 1: Initialize model with the best available weights
    weights = ResNet50_QuantizedWeights.DEFAULT
    model = resnet50(weights=weights, quantize=True)
    model.eval()

    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms()

    # Step 3: Apply inference preprocessing transforms
    batch = preprocess(img).unsqueeze(0)

    # Step 4: Use the model and print the predicted category
    prediction = model(batch).squeeze(0).softmax(0)
    class_id = prediction.argmax().item()
    score = prediction[class_id].item()
    category_name = weights.meta["categories"][class_id]
    print(f"{category_name}: {100 * score}%")


Table of all available quantized classification weights
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Accuracies are reported on ImageNet

.. include:: generated/classification_quant_table.rst

Semantic Segmentation
=====================

.. currentmodule:: torchvision.models.segmentation

The following semantic segmentation models are available, with or without
pre-trained weights:

.. toctree::
   :maxdepth: 1

   models/deeplabv3
   models/fcn
   models/lraspp

|

Here is an example of how to use the pre-trained semantic segmentation models:

.. code:: python

    from torchvision.io.image import read_image
    from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights
    from torchvision.transforms.functional import to_pil_image

    img = read_image("gallery/assets/dog1.jpg")

    # Step 1: Initialize model with the best available weights
    weights = FCN_ResNet50_Weights.DEFAULT
    model = fcn_resnet50(weights=weights)
    model.eval()

    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms()

    # Step 3: Apply inference preprocessing transforms
    batch = preprocess(img).unsqueeze(0)

    # Step 4: Use the model and visualize the prediction
    prediction = model(batch)["out"]
    normalized_masks = prediction.softmax(dim=1)
    class_to_idx = {cls: idx for (idx, cls) in enumerate(weights.meta["categories"])}
    mask = normalized_masks[0, class_to_idx["dog"]]
    to_pil_image(mask).show()


Table of all available semantic segmentation weights
----------------------------------------------------

All models are evaluated on COCO val2017:

.. include:: generated/segmentation_table.rst


Object Detection, Instance Segmentation and Person Keypoint Detection
=====================================================================

Object Detection
----------------

.. currentmodule:: torchvision.models.detection

The following object detection models are available, with or without pre-trained
weights:

.. toctree::
   :maxdepth: 1

   models/faster_rcnn
   models/fcos
   models/retinanet
   models/ssd
   models/ssdlite

|

Here is an example of how to use the pre-trained object detection models:

.. code:: python


    from torchvision.io.image import read_image
    from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights
    from torchvision.utils import draw_bounding_boxes
    from torchvision.transforms.functional import to_pil_image

    img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

    # Step 1: Initialize model with the best available weights
    weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
    model = fasterrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=0.9)
    model.eval()

    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms()

    # Step 3: Apply inference preprocessing transforms
    batch = [preprocess(img)]

    # Step 4: Use the model and visualize the prediction
    prediction = model(batch)[0]
    labels = [weights.meta["categories"][i] for i in prediction["labels"]]
    box = draw_bounding_boxes(img, boxes=prediction["boxes"],
                              labels=labels,
                              colors="red",
                              width=4, font_size=30)
    im = to_pil_image(box.detach())
    im.show()

Table of all available Object detection weights
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Box MAPs are reported on COCO

.. include:: generated/detection_table.rst

Instance Segmentation
---------------------

.. currentmodule:: torchvision.models.detection

The following instance segmentation models are available, with or without pre-trained
weights:

.. toctree::
   :maxdepth: 1

   models/mask_rcnn
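
|

Here is a sketch of how the pre-trained instance segmentation models can be used,
modeled on the object detection example above (thresholding the predicted soft
masks at 0.5 is an illustrative choice):

.. code:: python

    from torchvision.io.image import read_image
    from torchvision.models.detection import maskrcnn_resnet50_fpn_v2, MaskRCNN_ResNet50_FPN_V2_Weights
    from torchvision.utils import draw_segmentation_masks
    from torchvision.transforms.functional import to_pil_image

    img = read_image("gallery/assets/dog1.jpg")

    # Step 1: Initialize model with the best available weights
    weights = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT
    model = maskrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=0.9)
    model.eval()

    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms()

    # Step 3: Apply inference preprocessing transforms
    batch = [preprocess(img)]

    # Step 4: Use the model and visualize the predicted masks
    prediction = model(batch)[0]
    masks = prediction["masks"] > 0.5  # binarize the soft masks
    drawn = draw_segmentation_masks(img, masks=masks.squeeze(1), alpha=0.7)
    to_pil_image(drawn).show()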

Table of all available Instance segmentation weights
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Box and Mask MAPs are reported on COCO

.. include:: generated/instance_segmentation_table.rst

Keypoint Detection
------------------

.. currentmodule:: torchvision.models.detection

The following keypoint detection models are available, with or without
pre-trained weights:

.. toctree::
   :maxdepth: 1

   models/keypoint_rcnn
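
|

Here is a sketch of how the pre-trained keypoint detection models can be used,
following the same pattern as the detection example above (dropping the visibility
flag from the predicted keypoints is an illustrative simplification):

.. code:: python

    from torchvision.io.image import read_image
    from torchvision.models.detection import keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights
    from torchvision.utils import draw_keypoints
    from torchvision.transforms.functional import to_pil_image

    img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

    # Step 1: Initialize model with the best available weights
    weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
    model = keypointrcnn_resnet50_fpn(weights=weights, box_score_thresh=0.9)
    model.eval()

    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms()

    # Step 3: Apply inference preprocessing transforms
    batch = [preprocess(img)]

    # Step 4: Use the model and visualize the predicted keypoints
    prediction = model(batch)[0]
    keypoints = prediction["keypoints"][:, :, :2].detach()  # drop the visibility flag
    drawn = draw_keypoints(img, keypoints, colors="blue", radius=3)
    to_pil_image(drawn).show()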

Table of all available Keypoint detection weights
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Box and Keypoint MAPs are reported on COCO:

.. include:: generated/detection_keypoint_table.rst


Video Classification
====================

.. currentmodule:: torchvision.models.video

The following video classification models are available, with or without
pre-trained weights:

.. toctree::
   :maxdepth: 1

   models/video_resnet

|

Here is an example of how to use the pre-trained video classification models:

.. code:: python


    from torchvision.io.video import read_video
    from torchvision.models.video import r3d_18, R3D_18_Weights

    vid, _, _ = read_video("test/assets/videos/v_SoccerJuggling_g23_c01.avi")
    vid = vid[:32]  # optionally shorten duration

    # Step 1: Initialize model with the best available weights
    weights = R3D_18_Weights.DEFAULT
    model = r3d_18(weights=weights)
    model.eval()

    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms()

    # Step 3: Apply inference preprocessing transforms
    batch = preprocess(vid).unsqueeze(0)

    # Step 4: Use the model and print the predicted category
    prediction = model(batch).squeeze(0).softmax(0)
    label = prediction.argmax().item()
    score = prediction[label].item()
    category_name = weights.meta["categories"][label]
    print(f"{category_name}: {100 * score}%")


Table of all available video classification weights
---------------------------------------------------

Accuracies are reported on Kinetics-400

.. include:: generated/video_table.rst