"vscode:/vscode.git/clone" did not exist on "3fd6a5d0812c040537cea6d9e74a91defac799a9"
models_new.rst 10.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
.. _models_new:

Models and pre-trained weights - New
####################################

.. note::

    These are the new models docs, documenting the new multi-weight API.
    TODO: Once all is done, remove the "- New" part in the title above, and
    rename this file as models.rst


The ``torchvision.models`` subpackage contains definitions of models for addressing
different tasks, including: image classification, pixelwise semantic
segmentation, object detection, instance segmentation, person
keypoint detection, video classification, and optical flow.

.. note::
    Backward compatibility is guaranteed for loading a serialized
    ``state_dict`` into a model created with an older PyTorch version.
    In contrast, loading entire saved models or serialized
    ``ScriptModules`` (serialized with older versions of PyTorch)
    may not preserve their historic behaviour. Refer to the following
    `documentation
    <https://pytorch.org/docs/stable/notes/serialization.html#id6>`_.
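
As a minimal sketch of the recommended ``state_dict`` pattern (the file name
below is purely illustrative):

.. code:: python

    import torch
    from torchvision.models import resnet50, ResNet50_Weights

    model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)

    # Persist only the state_dict, then restore it into a freshly
    # constructed model; this is the pattern covered by the
    # backward-compatibility guarantee above
    torch.save(model.state_dict(), "resnet50.pth")

    model = resnet50()  # random initialization
    model.load_state_dict(torch.load("resnet50.pth"))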

As of v0.13, TorchVision offers a new `Multi-weight support API
<https://pytorch.org/blog/introducing-torchvision-new-multi-weight-support-api/>`_ for loading different weights to the
existing model builder methods:

.. code:: python

    from torchvision.models import resnet50, ResNet50_Weights

    # Old weights with accuracy 76.130%
    resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)

    # New weights with accuracy 80.858%
    resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

    # Best available weights (currently alias for IMAGENET1K_V2)
    # Note that these weights may change across versions
    resnet50(weights=ResNet50_Weights.DEFAULT)

    # Strings are also supported
    resnet50(weights="IMAGENET1K_V2")

    # No weights - random initialization
    resnet50(weights=None)  # or resnet50()
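
Each weights enum value also exposes a ``meta`` dictionary with metadata about
the weights (for example the prediction categories). Here is a small sketch of
inspecting it; note that the exact keys may vary across torchvision versions:

.. code:: python

    from torchvision.models import ResNet50_Weights

    # List every available set of weights for the model and its metadata keys
    for weights in ResNet50_Weights:
        print(weights, sorted(weights.meta.keys()))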


Migrating to the new API is straightforward. The following method calls are all equivalent across the two APIs:

.. code:: python

    from torchvision.models import resnet50, ResNet50_Weights

    # Using pretrained weights:
    resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
    resnet50(pretrained=True)  # deprecated
    resnet50(True)  # deprecated

    # Using no weights:
    resnet50(weights=None)
    resnet50(pretrained=False)  # deprecated
    resnet50(False)  # deprecated

Note that the ``pretrained`` parameter is now deprecated; using it will emit warnings, and it will be removed in v0.15.


Classification
==============

.. currentmodule:: torchvision.models

The following classification models are available, with or without pre-trained
weights:

.. toctree::
   :maxdepth: 1

   models/alexnet
   models/convnext
   models/densenet
   models/efficientnet
   models/efficientnetv2
   models/googlenet
   models/inception
   models/mnasnet
   models/mobilenetv2
   models/mobilenetv3
   models/regnet
   models/resnet
   models/resnext
   models/shufflenetv2
   models/squeezenet
   models/swin_transformer
   models/vgg
   models/vision_transformer
   models/wide_resnet

|

Here is an example of how to use the pre-trained image classification models:

.. code:: python

    from torchvision.io import read_image
    from torchvision.models import resnet50, ResNet50_Weights

    img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

    # Step 1: Initialize model with the best available weights
    weights = ResNet50_Weights.DEFAULT
    model = resnet50(weights=weights)
    model.eval()

    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms()

    # Step 3: Apply inference preprocessing transforms
    batch = preprocess(img).unsqueeze(0)

    # Step 4: Use the model and print the predicted category
    prediction = model(batch).squeeze(0).softmax(0)
    class_id = prediction.argmax().item()
    score = prediction[class_id].item()
    category_name = weights.meta["categories"][class_id]
    print(f"{category_name}: {100 * score:.1f}%")

Table of all available classification weights
---------------------------------------------

Accuracies are reported on ImageNet-1K:

.. include:: generated/classification_table.rst

Quantized models
----------------

.. currentmodule:: torchvision.models.quantization

The following quantized classification models are available, with or without
pre-trained weights:

.. toctree::
   :maxdepth: 1

   models/googlenet_quant
   models/mobilenetv2_quant

|

Here is an example of how to use the pre-trained quantized image classification models:

.. code:: python

    from torchvision.io import read_image
    from torchvision.models.quantization import resnet50, ResNet50_QuantizedWeights

    img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

    # Step 1: Initialize model with the best available weights
    weights = ResNet50_QuantizedWeights.DEFAULT
    model = resnet50(weights=weights, quantize=True)
    model.eval()

    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms()

    # Step 3: Apply inference preprocessing transforms
    batch = preprocess(img).unsqueeze(0)

    # Step 4: Use the model and print the predicted category
    prediction = model(batch).squeeze(0).softmax(0)
    class_id = prediction.argmax().item()
    score = prediction[class_id].item()
    category_name = weights.meta["categories"][class_id]
    print(f"{category_name}: {100 * score}%")


Table of all available quantized classification weights
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Accuracies are reported on ImageNet-1K:

.. include:: generated/classification_quant_table.rst

Semantic Segmentation
=====================

.. currentmodule:: torchvision.models.segmentation

The following semantic segmentation models are available, with or without
pre-trained weights:

.. toctree::
   :maxdepth: 1

   models/deeplabv3
   models/fcn
   models/lraspp

|

Here is an example of how to use the pre-trained semantic segmentation models:

.. code:: python

    from torchvision.io.image import read_image
    from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights
    from torchvision.transforms.functional import to_pil_image

    img = read_image("gallery/assets/dog1.jpg")

    # Step 1: Initialize model with the best available weights
    weights = FCN_ResNet50_Weights.DEFAULT
    model = fcn_resnet50(weights=weights)
    model.eval()

    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms()

    # Step 3: Apply inference preprocessing transforms
    batch = preprocess(img).unsqueeze(0)

    # Step 4: Use the model and visualize the prediction
    prediction = model(batch)["out"]
    normalized_masks = prediction.softmax(dim=1)
    class_to_idx = {cls: idx for (idx, cls) in enumerate(weights.meta["categories"])}
    mask = normalized_masks[0, class_to_idx["dog"]]
    to_pil_image(mask).show()
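
To overlay the mask on the input image instead, here is a small follow-up
sketch using ``torchvision.utils.draw_segmentation_masks`` (the 0.5 threshold
below is an arbitrary illustrative choice):

.. code:: python

    from torchvision.utils import draw_segmentation_masks

    # Threshold the soft mask and draw it on top of the original uint8 image
    bool_mask = mask > 0.5
    overlay = draw_segmentation_masks(img, masks=bool_mask, alpha=0.7, colors="blue")
    to_pil_image(overlay).show()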


Table of all available semantic segmentation weights
----------------------------------------------------

All models are evaluated on COCO val2017:

.. include:: generated/segmentation_table.rst


Object Detection, Instance Segmentation and Person Keypoint Detection
=====================================================================

Object Detection
----------------

.. currentmodule:: torchvision.models.detection

The following object detection models are available, with or without pre-trained
weights:

.. toctree::
   :maxdepth: 1

   models/faster_rcnn
   models/fcos
   models/retinanet
   models/ssd
   models/ssdlite

|

Here is an example of how to use the pre-trained object detection models:

.. code:: python


    from torchvision.io.image import read_image
    from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights
    from torchvision.utils import draw_bounding_boxes
    from torchvision.transforms.functional import to_pil_image

    img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

    # Step 1: Initialize model with the best available weights
    weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
    model = fasterrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=0.9)
    model.eval()

    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms()

    # Step 3: Apply inference preprocessing transforms
    batch = [preprocess(img)]

    # Step 4: Use the model and visualize the prediction
    prediction = model(batch)[0]
    labels = [weights.meta["categories"][i] for i in prediction["labels"]]
    box = draw_bounding_boxes(img, boxes=prediction["boxes"],
                              labels=labels,
                              colors="red",
                              width=4, font_size=30)
    im = to_pil_image(box.detach())
    im.show()
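
Each prediction is a dictionary holding ``boxes``, ``labels`` and ``scores``
tensors; as a small follow-up to the snippet above, the detected categories can
be printed together with their confidence:

.. code:: python

    # Print every detection kept by the 0.9 score threshold set above
    for label, score in zip(labels, prediction["scores"].tolist()):
        print(f"{label}: {100 * score:.1f}%")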

Table of all available Object detection weights
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Box MAPs are reported on COCO:

.. include:: generated/detection_table.rst

Instance Segmentation
---------------------

.. currentmodule:: torchvision.models.detection

The following instance segmentation models are available, with or without pre-trained
weights:

.. toctree::
   :maxdepth: 1

   models/mask_rcnn

Table of all available Instance segmentation weights
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Box and Mask MAPs are reported on COCO:

.. include:: generated/instance_segmentation_table.rst

Keypoint Detection
------------------

.. currentmodule:: torchvision.models.detection

The following keypoint detection models are available, with or without
pre-trained weights:

.. toctree::
   :maxdepth: 1

   models/keypoint_rcnn

Table of all available Keypoint detection weights
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Box and Keypoint MAPs are reported on COCO:

.. include:: generated/detection_keypoint_table.rst


Video Classification
====================

.. currentmodule:: torchvision.models.video

The following video classification models are available, with or without
pre-trained weights:

.. toctree::
   :maxdepth: 1

   models/video_resnet

|

Here is an example of how to use the pre-trained video classification models:

.. code:: python


    from torchvision.io.video import read_video
    from torchvision.models.video import r3d_18, R3D_18_Weights

    vid, _, _ = read_video("test/assets/videos/v_SoccerJuggling_g23_c01.avi")
    vid = vid[:32]  # optionally shorten duration

    # Step 1: Initialize model with the best available weights
    weights = R3D_18_Weights.DEFAULT
    model = r3d_18(weights=weights)
    model.eval()

    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms()

    # Step 3: Apply inference preprocessing transforms
    batch = preprocess(vid).unsqueeze(0)

    # Step 4: Use the model and print the predicted category
    prediction = model(batch).squeeze(0).softmax(0)
    label = prediction.argmax().item()
    score = prediction[label].item()
    category_name = weights.meta["categories"][label]
    print(f"{category_name}: {100 * score}%")


Table of all available video classification weights
---------------------------------------------------

Accuracies are reported on Kinetics-400:

.. include:: generated/video_table.rst