# resnet.py -- quantizable ResNet / ResNeXt model definitions.
from functools import partial
from typing import Any, List, Optional, Type, Union

import torch
import torch.nn as nn
from torch import Tensor
from torchvision.models.resnet import (
    BasicBlock,
    Bottleneck,
    ResNet,
    ResNet18_Weights,
    ResNet50_Weights,
    ResNeXt101_32X8D_Weights,
    ResNeXt101_64X4D_Weights,
)

from ...transforms._presets import ImageClassification
from .._api import register_model, Weights, WeightsEnum
from .._meta import _IMAGENET_CATEGORIES
from .._utils import _ovewrite_named_param, handle_legacy_interface
from .utils import _fuse_modules, _replace_relu, quantize_model
22
23


24
25
26
27
28
# Names exported by ``from <module> import *``: the quantizable model class,
# one weights enum per architecture, and the four builder functions.
__all__ = [
    "QuantizableResNet",
    "ResNet18_QuantizedWeights",
    "ResNet50_QuantizedWeights",
    "ResNeXt101_32X8D_QuantizedWeights",
    "ResNeXt101_64X4D_QuantizedWeights",
    "resnet18",
    "resnet50",
    "resnext101_32x8d",
    "resnext101_64x4d",
]
35
36
37


class QuantizableBasicBlock(BasicBlock):
    """``BasicBlock`` subclass used to build :class:`QuantizableResNet` models.

    The residual addition and the final ReLU are performed by a single
    ``FloatFunctional.add_relu`` call, and :meth:`fuse_model` fuses the
    conv/bn(/relu) groups of the block in place.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Build the parent block, then attach the functional used for add+ReLU.

        Args:
            *args: Positional arguments forwarded unchanged to ``BasicBlock``.
            **kwargs: Keyword arguments forwarded unchanged to ``BasicBlock``.
        """
        super().__init__(*args, **kwargs)
        # Module wrapper through which ``forward`` performs the skip-connection
        # addition followed by ReLU.
        self.add_relu = torch.nn.quantized.FloatFunctional()

    def forward(self, x: Tensor) -> Tensor:
        """Run the block on ``x`` and return the activated residual sum.

        Args:
            x: Input tensor of the block.

        Returns:
            ``add_relu(main_path(x), identity)`` where ``identity`` is ``x`` or,
            when a downsample module is set, ``self.downsample(x)``.
        """
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        # Addition and ReLU go through the FloatFunctional created in __init__.
        out = self.add_relu.add_relu(out, identity)

        return out

    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
        """Fuse the conv/bn(/relu) groups of this block in place.

        Args:
            is_qat: Forwarded unchanged to ``_fuse_modules`` (presumably selects
                quantization-aware-training fusion -- confirm in ``.utils``).
        """
        _fuse_modules(self, [["conv1", "bn1", "relu"], ["conv2", "bn2"]], is_qat, inplace=True)
        if self.downsample:
            # Sub-modules "0" and "1" of the downsample container are fused as a pair.
            _fuse_modules(self.downsample, ["0", "1"], is_qat, inplace=True)
63
64
65


class QuantizableBottleneck(Bottleneck):
    """``Bottleneck`` subclass used to build :class:`QuantizableResNet` models.

    Each of the first two conv/bn stages gets its own ReLU module (``relu1``,
    ``relu2``), and the residual addition plus final ReLU are performed by a
    single ``FloatFunctional.add_relu`` call.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Build the parent block, then attach the extra modules used by ``forward``.

        Args:
            *args: Positional arguments forwarded unchanged to ``Bottleneck``.
            **kwargs: Keyword arguments forwarded unchanged to ``Bottleneck``.
        """
        super().__init__(*args, **kwargs)
        # Module wrapper through which ``forward`` performs the skip-connection
        # addition followed by ReLU.
        self.skip_add_relu = nn.quantized.FloatFunctional()
        # Distinct ReLU modules so ``fuse_model`` can name one per conv/bn stage.
        self.relu1 = nn.ReLU(inplace=False)
        self.relu2 = nn.ReLU(inplace=False)

    def forward(self, x: Tensor) -> Tensor:
        """Run the block on ``x`` and return the activated residual sum.

        Args:
            x: Input tensor of the block.

        Returns:
            ``add_relu(main_path(x), identity)`` where ``identity`` is ``x`` or,
            when a downsample module is set, ``self.downsample(x)``.
        """
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu2(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)
        # Addition and ReLU go through the FloatFunctional created in __init__.
        out = self.skip_add_relu.add_relu(out, identity)

        return out

    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
        """Fuse the conv/bn(/relu) groups of this block in place.

        Args:
            is_qat: Forwarded unchanged to ``_fuse_modules`` (presumably selects
                quantization-aware-training fusion -- confirm in ``.utils``).
        """
        _fuse_modules(
            self, [["conv1", "bn1", "relu1"], ["conv2", "bn2", "relu2"], ["conv3", "bn3"]], is_qat, inplace=True
        )
        if self.downsample:
            # Sub-modules "0" and "1" of the downsample container are fused as a pair.
            _fuse_modules(self.downsample, ["0", "1"], is_qat, inplace=True)
96
97
98


class QuantizableResNet(ResNet):
    """``ResNet`` subclass that wraps the forward pass in quant/dequant stubs.

    The input goes through ``self.quant`` before the parent's
    ``_forward_impl`` and the result goes through ``self.dequant``.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Build the parent network, then attach the quant/dequant stubs.

        Args:
            *args: Positional arguments forwarded unchanged to ``ResNet``.
            **kwargs: Keyword arguments forwarded unchanged to ``ResNet``.
        """
        super().__init__(*args, **kwargs)

        self.quant = torch.ao.quantization.QuantStub()
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x: Tensor) -> Tensor:
        """Apply ``quant``, the parent's ``_forward_impl`` and ``dequant`` to ``x``."""
        x = self.quant(x)
        # Ensure scriptability
        # super(QuantizableResNet,self).forward(x)
        # is not scriptable
        x = self._forward_impl(x)
        x = self.dequant(x)
        return x

    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
        r"""Fuse conv/bn/relu modules in resnet models.

        Fuse conv+bn+relu / conv+relu / conv+bn modules to prepare for quantization.
        Model is modified in place. Note that this operation does not change numerics
        and the model after modification is in floating point.

        Args:
            is_qat: Forwarded unchanged to ``_fuse_modules`` and to each block's
                ``fuse_model``.
        """
        # Stem of the network first, then every quantizable residual block.
        _fuse_modules(self, ["conv1", "bn1", "relu"], is_qat, inplace=True)
        for m in self.modules():
            # Exact type match: only these two block classes are fused here.
            if type(m) is QuantizableBottleneck or type(m) is QuantizableBasicBlock:
                m.fuse_model(is_qat)
125
126


127
def _resnet(
    block: Type[Union[QuantizableBasicBlock, QuantizableBottleneck]],
    layers: List[int],
    weights: Optional[WeightsEnum],
    progress: bool,
    quantize: bool,
    **kwargs: Any,
) -> QuantizableResNet:
    """Shared builder behind the public model functions of this module.

    Args:
        block: Residual block class passed to ``QuantizableResNet``.
        layers: Per-stage layer counts passed to ``QuantizableResNet``.
        weights: Weights entry to load, or ``None`` to keep the freshly
            constructed parameters.
        progress: Forwarded to ``weights.get_state_dict`` as ``progress``.
        quantize: If True, ``quantize_model`` is applied to the model.
        **kwargs: Extra arguments for ``QuantizableResNet``. A ``"backend"``
            entry is removed and used as the quantization backend
            (default ``"fbgemm"``).

    Returns:
        The constructed (and optionally quantized / weight-loaded) model.
    """
    if weights is not None:
        # Take the class count (and backend, when declared) from the weights' metadata.
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
        if "backend" in weights.meta:
            _ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
    # ``backend`` is popped so it is not passed on to the model constructor.
    backend = kwargs.pop("backend", "fbgemm")

    model = QuantizableResNet(block, layers, **kwargs)
    _replace_relu(model)
    if quantize:
        quantize_model(model, backend)

    # The state dict is loaded after the optional quantization step, i.e. into
    # the model in its final form.
    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model


152
153
154
155
156
# Metadata shared by every weights entry in this module. Each entry spreads it
# with ``**_COMMON_META`` and then adds (or overrides) its own keys.
_COMMON_META = {
    "min_size": (1, 1),
    "categories": _IMAGENET_CATEGORIES,
    # Read by ``_resnet`` to choose the quantization backend for these weights.
    "backend": "fbgemm",
    "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models",
    "_docs": """
        These weights were produced by doing Post Training Quantization (eager mode) on top of the unquantized
        weights listed below.
    """,
}


# Quantized weights available for ``resnet18``. Each member bundles the
# checkpoint URL, the preprocessing preset and the metadata of one checkpoint.
class ResNet18_QuantizedWeights(WeightsEnum):
    IMAGENET1K_FBGEMM_V1 = Weights(
        url="https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 11689512,
            # Float weights entry this quantized checkpoint is paired with.
            "unquantized": ResNet18_Weights.IMAGENET1K_V1,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 69.494,
                    "acc@5": 88.882,
                }
            },
            "_ops": 1.814,
            "_file_size": 11.238,
        },
    )
    # Entry used when the caller asks for the default weights.
    DEFAULT = IMAGENET1K_FBGEMM_V1


# Quantized weights available for ``resnet50``. Each member bundles the
# checkpoint URL, the preprocessing preset and the metadata of one checkpoint.
class ResNet50_QuantizedWeights(WeightsEnum):
    IMAGENET1K_FBGEMM_V1 = Weights(
        url="https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 25557032,
            # Float weights entry this quantized checkpoint is paired with.
            "unquantized": ResNet50_Weights.IMAGENET1K_V1,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 75.920,
                    "acc@5": 92.814,
                }
            },
            "_ops": 4.089,
            "_file_size": 24.759,
        },
    )
    IMAGENET1K_FBGEMM_V2 = Weights(
        url="https://download.pytorch.org/models/quantized/resnet50_fbgemm-23753f79.pth",
        # V2 uses a different preprocessing preset: resize to 232 before the 224 crop.
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 25557032,
            "unquantized": ResNet50_Weights.IMAGENET1K_V2,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 80.282,
                    "acc@5": 94.976,
                }
            },
            "_ops": 4.089,
            "_file_size": 24.953,
        },
    )
    # Entry used when the caller asks for the default weights.
    DEFAULT = IMAGENET1K_FBGEMM_V2


# Quantized weights available for ``resnext101_32x8d``. Each member bundles the
# checkpoint URL, the preprocessing preset and the metadata of one checkpoint.
class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum):
    IMAGENET1K_FBGEMM_V1 = Weights(
        url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 88791336,
            # Float weights entry this quantized checkpoint is paired with.
            "unquantized": ResNeXt101_32X8D_Weights.IMAGENET1K_V1,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 78.986,
                    "acc@5": 94.480,
                }
            },
            "_ops": 16.414,
            "_file_size": 86.034,
        },
    )
    IMAGENET1K_FBGEMM_V2 = Weights(
        url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm-ee16d00c.pth",
        # V2 uses a different preprocessing preset: resize to 232 before the 224 crop.
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 88791336,
            "unquantized": ResNeXt101_32X8D_Weights.IMAGENET1K_V2,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 82.574,
                    "acc@5": 96.132,
                }
            },
            "_ops": 16.414,
            "_file_size": 86.645,
        },
    )
    # Entry used when the caller asks for the default weights.
    DEFAULT = IMAGENET1K_FBGEMM_V2


261
262
263
264
265
266
267
268
269
# Quantized weights available for ``resnext101_64x4d``. The single member
# bundles the checkpoint URL, the preprocessing preset and the metadata.
class ResNeXt101_64X4D_QuantizedWeights(WeightsEnum):
    IMAGENET1K_FBGEMM_V1 = Weights(
        url="https://download.pytorch.org/models/quantized/resnext101_64x4d_fbgemm-605a1cb3.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 83455272,
            # Overrides the "recipe" value inherited from _COMMON_META.
            "recipe": "https://github.com/pytorch/vision/pull/5935",
            # Float weights entry this quantized checkpoint is paired with.
            "unquantized": ResNeXt101_64X4D_Weights.IMAGENET1K_V1,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 82.898,
                    "acc@5": 96.326,
                }
            },
            "_ops": 15.46,
            "_file_size": 81.556,
        },
    )
    # Entry used when the caller asks for the default weights.
    DEFAULT = IMAGENET1K_FBGEMM_V1


283
@register_model(name="quantized_resnet18")
@handle_legacy_interface(
    weights=(
        "pretrained",
        lambda kwargs: ResNet18_QuantizedWeights.IMAGENET1K_FBGEMM_V1
        if kwargs.get("quantize", False)
        else ResNet18_Weights.IMAGENET1K_V1,
    )
)
def resnet18(
    *,
    weights: Optional[Union[ResNet18_QuantizedWeights, ResNet18_Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableResNet:
    """ResNet-18 model from
    `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`_

    .. note::
        Note that ``quantize = True`` returns a quantized model with 8 bit
        weights. Quantized models only support inference and run on CPUs.
        GPU inference is not yet supported.

    Args:
        weights (:class:`~torchvision.models.quantization.ResNet18_QuantizedWeights` or :class:`~torchvision.models.ResNet18_Weights`, optional): The
            pretrained weights for the model. See
            :class:`~torchvision.models.quantization.ResNet18_QuantizedWeights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        quantize (bool, optional): If True, return a quantized version of the model. Default is False.
        **kwargs: parameters passed to the ``torchvision.models.quantization.QuantizableResNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/quantization/resnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.quantization.ResNet18_QuantizedWeights
        :members:

    .. autoclass:: torchvision.models.ResNet18_Weights
        :members:
        :noindex:
    """
    # ``verify`` is called on the enum that matches the requested mode:
    # quantized weights when ``quantize`` is set, float weights otherwise.
    weights = (ResNet18_QuantizedWeights if quantize else ResNet18_Weights).verify(weights)

    return _resnet(QuantizableBasicBlock, [2, 2, 2, 2], weights, progress, quantize, **kwargs)
331
332


333
@register_model(name="quantized_resnet50")
@handle_legacy_interface(
    weights=(
        "pretrained",
        lambda kwargs: ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1
        if kwargs.get("quantize", False)
        else ResNet50_Weights.IMAGENET1K_V1,
    )
)
def resnet50(
    *,
    weights: Optional[Union[ResNet50_QuantizedWeights, ResNet50_Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableResNet:
    """ResNet-50 model from
    `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`_

    .. note::
        Note that ``quantize = True`` returns a quantized model with 8 bit
        weights. Quantized models only support inference and run on CPUs.
        GPU inference is not yet supported.

    Args:
        weights (:class:`~torchvision.models.quantization.ResNet50_QuantizedWeights` or :class:`~torchvision.models.ResNet50_Weights`, optional): The
            pretrained weights for the model. See
            :class:`~torchvision.models.quantization.ResNet50_QuantizedWeights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        quantize (bool, optional): If True, return a quantized version of the model. Default is False.
        **kwargs: parameters passed to the ``torchvision.models.quantization.QuantizableResNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/quantization/resnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.quantization.ResNet50_QuantizedWeights
        :members:

    .. autoclass:: torchvision.models.ResNet50_Weights
        :members:
        :noindex:
    """
    # ``verify`` is called on the enum that matches the requested mode:
    # quantized weights when ``quantize`` is set, float weights otherwise.
    weights = (ResNet50_QuantizedWeights if quantize else ResNet50_Weights).verify(weights)

    return _resnet(QuantizableBottleneck, [3, 4, 6, 3], weights, progress, quantize, **kwargs)
381

382

383
@register_model(name="quantized_resnext101_32x8d")
@handle_legacy_interface(
    weights=(
        "pretrained",
        lambda kwargs: ResNeXt101_32X8D_QuantizedWeights.IMAGENET1K_FBGEMM_V1
        if kwargs.get("quantize", False)
        else ResNeXt101_32X8D_Weights.IMAGENET1K_V1,
    )
)
def resnext101_32x8d(
    *,
    weights: Optional[Union[ResNeXt101_32X8D_QuantizedWeights, ResNeXt101_32X8D_Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableResNet:
    """ResNeXt-101 32x8d model from
    `Aggregated Residual Transformation for Deep Neural Networks <https://arxiv.org/abs/1611.05431>`_

    .. note::
        Note that ``quantize = True`` returns a quantized model with 8 bit
        weights. Quantized models only support inference and run on CPUs.
        GPU inference is not yet supported.

    Args:
        weights (:class:`~torchvision.models.quantization.ResNeXt101_32X8D_QuantizedWeights` or :class:`~torchvision.models.ResNeXt101_32X8D_Weights`, optional): The
            pretrained weights for the model. See
            :class:`~torchvision.models.quantization.ResNeXt101_32X8D_QuantizedWeights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        quantize (bool, optional): If True, return a quantized version of the model. Default is False.
        **kwargs: parameters passed to the ``torchvision.models.quantization.QuantizableResNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/quantization/resnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.quantization.ResNeXt101_32X8D_QuantizedWeights
        :members:

    .. autoclass:: torchvision.models.ResNeXt101_32X8D_Weights
        :members:
        :noindex:
    """
    # ``verify`` is called on the enum that matches the requested mode:
    # quantized weights when ``quantize`` is set, float weights otherwise.
    weights = (ResNeXt101_32X8D_QuantizedWeights if quantize else ResNeXt101_32X8D_Weights).verify(weights)

    # The "32x8d" variant: 32 groups with a width of 8 per group.
    _ovewrite_named_param(kwargs, "groups", 32)
    _ovewrite_named_param(kwargs, "width_per_group", 8)
    return _resnet(QuantizableBottleneck, [3, 4, 23, 3], weights, progress, quantize, **kwargs)
433
434


435
@register_model(name="quantized_resnext101_64x4d")
@handle_legacy_interface(
    weights=(
        "pretrained",
        lambda kwargs: ResNeXt101_64X4D_QuantizedWeights.IMAGENET1K_FBGEMM_V1
        if kwargs.get("quantize", False)
        else ResNeXt101_64X4D_Weights.IMAGENET1K_V1,
    )
)
def resnext101_64x4d(
    *,
    weights: Optional[Union[ResNeXt101_64X4D_QuantizedWeights, ResNeXt101_64X4D_Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableResNet:
    """ResNeXt-101 64x4d model from
    `Aggregated Residual Transformation for Deep Neural Networks <https://arxiv.org/abs/1611.05431>`_

    .. note::
        Note that ``quantize = True`` returns a quantized model with 8 bit
        weights. Quantized models only support inference and run on CPUs.
        GPU inference is not yet supported.

    Args:
        weights (:class:`~torchvision.models.quantization.ResNeXt101_64X4D_QuantizedWeights` or :class:`~torchvision.models.ResNeXt101_64X4D_Weights`, optional): The
            pretrained weights for the model. See
            :class:`~torchvision.models.quantization.ResNeXt101_64X4D_QuantizedWeights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        quantize (bool, optional): If True, return a quantized version of the model. Default is False.
        **kwargs: parameters passed to the ``torchvision.models.quantization.QuantizableResNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/quantization/resnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.quantization.ResNeXt101_64X4D_QuantizedWeights
        :members:

    .. autoclass:: torchvision.models.ResNeXt101_64X4D_Weights
        :members:
        :noindex:
    """
    # ``verify`` is called on the enum that matches the requested mode:
    # quantized weights when ``quantize`` is set, float weights otherwise.
    weights = (ResNeXt101_64X4D_QuantizedWeights if quantize else ResNeXt101_64X4D_Weights).verify(weights)

    # The "64x4d" variant: 64 groups with a width of 4 per group.
    _ovewrite_named_param(kwargs, "groups", 64)
    _ovewrite_named_param(kwargs, "width_per_group", 4)
    return _resnet(QuantizableBottleneck, [3, 4, 23, 3], weights, progress, quantize, **kwargs)