resnet.py 17.4 KB
Newer Older
1
from functools import partial
2
from typing import Any, List, Optional, Type, Union
3

4
5
import torch
import torch.nn as nn
6
from torch import Tensor
7
8
from torchvision.models.resnet import (
    BasicBlock,
9
    Bottleneck,
10
11
12
13
    ResNet,
    ResNet18_Weights,
    ResNet50_Weights,
    ResNeXt101_32X8D_Weights,
14
    ResNeXt101_64X4D_Weights,
15
16
)

17
from ...transforms._presets import ImageClassification
18
from .._api import register_model, Weights, WeightsEnum
19
from .._meta import _IMAGENET_CATEGORIES
20
from .._utils import _ovewrite_named_param, handle_legacy_interface
21
from .utils import _fuse_modules, _replace_relu, quantize_model
22
23


24
25
26
27
28
# Names exported by ``from ... import *``: the quantizable architecture,
# the quantized weight enums, and the model builder functions.
__all__ = [
    "QuantizableResNet",
    "ResNet18_QuantizedWeights",
    "ResNet50_QuantizedWeights",
    "ResNeXt101_32X8D_QuantizedWeights",
    "ResNeXt101_64X4D_QuantizedWeights",
    "resnet18",
    "resnet50",
    "resnext101_32x8d",
    "resnext101_64x4d",
]
35
36
37


class QuantizableBasicBlock(BasicBlock):
    """``BasicBlock`` whose residual addition and final ReLU are performed by a ``FloatFunctional``.

    The constructor accepts exactly the arguments of ``BasicBlock`` and forwards them unchanged.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Module that performs the skip-connection "add then ReLU" in ``forward``.
        self.add_relu = torch.nn.quantized.FloatFunctional()

    def forward(self, x: Tensor) -> Tensor:
        """Run conv1-bn1-relu, conv2-bn2, then add the (optionally downsampled) input and apply ReLU."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        # Fused add + ReLU through the FloatFunctional instead of ``out += identity; relu(out)``.
        out = self.add_relu.add_relu(out, identity)

        return out

    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
        """Fuse ``conv1/bn1/relu`` and ``conv2/bn2`` in place, plus the downsample pair when present.

        Args:
            is_qat (bool, optional): Forwarded as-is to ``_fuse_modules``.
        """
        _fuse_modules(self, [["conv1", "bn1", "relu"], ["conv2", "bn2"]], is_qat, inplace=True)
        # Truthiness check: skipped when ``downsample`` is ``None`` (or otherwise falsy).
        if self.downsample:
            # Children "0" and "1" of ``downsample`` are fused together.
            _fuse_modules(self.downsample, ["0", "1"], is_qat, inplace=True)
63
64
65


class QuantizableBottleneck(Bottleneck):
    """``Bottleneck`` with per-stage ReLU modules and a ``FloatFunctional`` for the residual add + ReLU.

    The constructor accepts exactly the arguments of ``Bottleneck`` and forwards them unchanged.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Module that performs the skip-connection "add then ReLU" in ``forward``.
        self.skip_add_relu = nn.quantized.FloatFunctional()
        # Separate, non-inplace ReLUs so that each one can be named in its own
        # fusion group in ``fuse_model`` ("relu1" with conv1/bn1, "relu2" with conv2/bn2).
        self.relu1 = nn.ReLU(inplace=False)
        self.relu2 = nn.ReLU(inplace=False)

    def forward(self, x: Tensor) -> Tensor:
        """Run the three conv/bn stages, then add the (optionally downsampled) input and apply ReLU."""
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu2(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)
        # Fused add + ReLU through the FloatFunctional.
        out = self.skip_add_relu.add_relu(out, identity)

        return out

    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
        """Fuse the three conv/bn(/relu) groups in place, plus the downsample pair when present.

        Args:
            is_qat (bool, optional): Forwarded as-is to ``_fuse_modules``.
        """
        _fuse_modules(
            self, [["conv1", "bn1", "relu1"], ["conv2", "bn2", "relu2"], ["conv3", "bn3"]], is_qat, inplace=True
        )
        # Truthiness check: skipped when ``downsample`` is ``None`` (or otherwise falsy).
        if self.downsample:
            # Children "0" and "1" of ``downsample`` are fused together.
            _fuse_modules(self.downsample, ["0", "1"], is_qat, inplace=True)
96
97
98


class QuantizableResNet(ResNet):
    """``ResNet`` wrapped between a ``QuantStub`` on the input and a ``DeQuantStub`` on the output.

    The constructor accepts exactly the arguments of ``ResNet`` and forwards them unchanged.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)

        # Stubs applied around ``_forward_impl`` in ``forward``.
        self.quant = torch.ao.quantization.QuantStub()
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x: Tensor) -> Tensor:
        """Apply ``quant``, the regular ResNet forward implementation, then ``dequant``."""
        x = self.quant(x)
        # Ensure scriptability:
        # calling super(QuantizableResNet, self).forward(x) is not scriptable,
        # so the implementation method is invoked directly.
        x = self._forward_impl(x)
        x = self.dequant(x)
        return x

    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
        r"""Fuse conv/bn/relu modules in resnet models

        Fuse conv+bn+relu/ Conv+relu/conv+Bn modules to prepare for quantization.
        Model is modified in place.  Note that this operation does not change numerics
        and the model after modification is in floating point

        Args:
            is_qat (bool, optional): Forwarded as-is to ``_fuse_modules`` and to each
                block's ``fuse_model``.
        """
        # Stem: conv1 + bn1 + relu.
        _fuse_modules(self, ["conv1", "bn1", "relu"], is_qat, inplace=True)
        for m in self.modules():
            # Exact type match on purpose-built blocks only; subclasses of these
            # two block classes are not visited by this check.
            if type(m) is QuantizableBottleneck or type(m) is QuantizableBasicBlock:
                m.fuse_model(is_qat)
125
126


127
def _resnet(
    block: Type[Union[QuantizableBasicBlock, QuantizableBottleneck]],
    layers: List[int],
    weights: Optional[WeightsEnum],
    progress: bool,
    quantize: bool,
    **kwargs: Any,
) -> QuantizableResNet:
    """Build a ``QuantizableResNet``, optionally quantize it, and optionally load weights.

    Args:
        block: Block class passed to ``QuantizableResNet``.
        layers: Per-stage block counts passed to ``QuantizableResNet``.
        weights: Weights entry to load, or ``None`` to skip loading. When given, its
            ``meta["categories"]`` length overrides ``num_classes`` and, if present,
            its ``meta["backend"]`` overrides ``backend`` in ``kwargs``.
        progress: Forwarded to ``weights.get_state_dict``.
        quantize: If True, ``quantize_model`` is applied before the weights are loaded.
        **kwargs: Remaining arguments for ``QuantizableResNet``. A ``backend`` entry is
            consumed here (default ``"fbgemm"``) and not forwarded.

    Returns:
        The constructed model.
    """
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
        if "backend" in weights.meta:
            _ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
    # ``backend`` is not a ``QuantizableResNet`` argument, so remove it from kwargs.
    backend = kwargs.pop("backend", "fbgemm")

    model = QuantizableResNet(block, layers, **kwargs)
    _replace_relu(model)
    if quantize:
        quantize_model(model, backend)

    # Loading happens after the optional quantization step, i.e. the state dict is
    # loaded into the model in whatever form the steps above left it.
    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model


152
153
154
155
156
# Metadata entries shared by every quantized weights definition in this module;
# each ``Weights(meta=...)`` below starts from a copy via ``**_COMMON_META``.
_COMMON_META = {
    "min_size": (1, 1),
    "categories": _IMAGENET_CATEGORIES,
    "backend": "fbgemm",
    "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models",
    "_docs": """
        These weights were produced by doing Post Training Quantization (eager mode) on top of the unquantized
        weights listed below.
    """,
}


class ResNet18_QuantizedWeights(WeightsEnum):
    # Quantized (FBGEMM) ResNet-18 checkpoint; ``meta["unquantized"]`` points at
    # the float weights entry it is paired with.
    IMAGENET1K_FBGEMM_V1 = Weights(
        url="https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 11689512,
            "unquantized": ResNet18_Weights.IMAGENET1K_V1,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 69.494,
                    "acc@5": 88.882,
                }
            },
        },
    )
    DEFAULT = IMAGENET1K_FBGEMM_V1


class ResNet50_QuantizedWeights(WeightsEnum):
    # Quantized (FBGEMM) ResNet-50 checkpoints; each ``meta["unquantized"]``
    # points at the float weights entry it is paired with.
    IMAGENET1K_FBGEMM_V1 = Weights(
        url="https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 25557032,
            "unquantized": ResNet50_Weights.IMAGENET1K_V1,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 75.920,
                    "acc@5": 92.814,
                }
            },
        },
    )
    # V2 uses a different preprocessing preset (resize_size=232) than V1.
    IMAGENET1K_FBGEMM_V2 = Weights(
        url="https://download.pytorch.org/models/quantized/resnet50_fbgemm-23753f79.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 25557032,
            "unquantized": ResNet50_Weights.IMAGENET1K_V2,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 80.282,
                    "acc@5": 94.976,
                }
            },
        },
    )
    DEFAULT = IMAGENET1K_FBGEMM_V2


class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum):
    # Quantized (FBGEMM) ResNeXt-101 32x8d checkpoints; each ``meta["unquantized"]``
    # points at the float weights entry it is paired with.
    IMAGENET1K_FBGEMM_V1 = Weights(
        url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 88791336,
            "unquantized": ResNeXt101_32X8D_Weights.IMAGENET1K_V1,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 78.986,
                    "acc@5": 94.480,
                }
            },
        },
    )
    # V2 uses a different preprocessing preset (resize_size=232) than V1.
    IMAGENET1K_FBGEMM_V2 = Weights(
        url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm-ee16d00c.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 88791336,
            "unquantized": ResNeXt101_32X8D_Weights.IMAGENET1K_V2,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 82.574,
                    "acc@5": 96.132,
                }
            },
        },
    )
    DEFAULT = IMAGENET1K_FBGEMM_V2


251
252
253
254
255
256
257
258
259
class ResNeXt101_64X4D_QuantizedWeights(WeightsEnum):
    # Quantized (FBGEMM) ResNeXt-101 64x4d checkpoint; ``meta["unquantized"]``
    # points at the float weights entry it is paired with.
    IMAGENET1K_FBGEMM_V1 = Weights(
        url="https://download.pytorch.org/models/quantized/resnext101_64x4d_fbgemm-605a1cb3.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 83455272,
            # Listed after ``**_COMMON_META`` so it replaces the shared "recipe" value.
            "recipe": "https://github.com/pytorch/vision/pull/5935",
            "unquantized": ResNeXt101_64X4D_Weights.IMAGENET1K_V1,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 82.898,
                    "acc@5": 96.326,
                }
            },
        },
    )
    DEFAULT = IMAGENET1K_FBGEMM_V1


271
@register_model(name="quantized_resnet18")
@handle_legacy_interface(
    weights=(
        "pretrained",
        lambda kwargs: ResNet18_QuantizedWeights.IMAGENET1K_FBGEMM_V1
        if kwargs.get("quantize", False)
        else ResNet18_Weights.IMAGENET1K_V1,
    )
)
def resnet18(
    *,
    weights: Optional[Union[ResNet18_QuantizedWeights, ResNet18_Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableResNet:
    """ResNet-18 model from
    `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`_

    .. note::
        Note that ``quantize = True`` returns a quantized model with 8 bit
        weights. Quantized models only support inference and run on CPUs.
        GPU inference is not yet supported.

    Args:
        weights (:class:`~torchvision.models.quantization.ResNet18_QuantizedWeights` or :class:`~torchvision.models.ResNet18_Weights`, optional): The
            pretrained weights for the model. See
            :class:`~torchvision.models.quantization.ResNet18_QuantizedWeights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        quantize (bool, optional): If True, return a quantized version of the model. Default is False.
        **kwargs: parameters passed to the ``torchvision.models.quantization.QuantizableResNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/quantization/resnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.quantization.ResNet18_QuantizedWeights
        :members:

    .. autoclass:: torchvision.models.ResNet18_Weights
        :members:
        :noindex:
    """
    # The enum used to verify ``weights`` depends on the requested mode:
    # quantized weights when ``quantize`` is set, float weights otherwise.
    weights = (ResNet18_QuantizedWeights if quantize else ResNet18_Weights).verify(weights)

    return _resnet(QuantizableBasicBlock, [2, 2, 2, 2], weights, progress, quantize, **kwargs)
319
320


321
@register_model(name="quantized_resnet50")
@handle_legacy_interface(
    weights=(
        "pretrained",
        lambda kwargs: ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1
        if kwargs.get("quantize", False)
        else ResNet50_Weights.IMAGENET1K_V1,
    )
)
def resnet50(
    *,
    weights: Optional[Union[ResNet50_QuantizedWeights, ResNet50_Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableResNet:
    """ResNet-50 model from
    `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`_

    .. note::
        Note that ``quantize = True`` returns a quantized model with 8 bit
        weights. Quantized models only support inference and run on CPUs.
        GPU inference is not yet supported.

    Args:
        weights (:class:`~torchvision.models.quantization.ResNet50_QuantizedWeights` or :class:`~torchvision.models.ResNet50_Weights`, optional): The
            pretrained weights for the model. See
            :class:`~torchvision.models.quantization.ResNet50_QuantizedWeights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        quantize (bool, optional): If True, return a quantized version of the model. Default is False.
        **kwargs: parameters passed to the ``torchvision.models.quantization.QuantizableResNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/quantization/resnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.quantization.ResNet50_QuantizedWeights
        :members:

    .. autoclass:: torchvision.models.ResNet50_Weights
        :members:
        :noindex:
    """
    # The enum used to verify ``weights`` depends on the requested mode:
    # quantized weights when ``quantize`` is set, float weights otherwise.
    weights = (ResNet50_QuantizedWeights if quantize else ResNet50_Weights).verify(weights)

    return _resnet(QuantizableBottleneck, [3, 4, 6, 3], weights, progress, quantize, **kwargs)
369

370

371
@register_model(name="quantized_resnext101_32x8d")
@handle_legacy_interface(
    weights=(
        "pretrained",
        lambda kwargs: ResNeXt101_32X8D_QuantizedWeights.IMAGENET1K_FBGEMM_V1
        if kwargs.get("quantize", False)
        else ResNeXt101_32X8D_Weights.IMAGENET1K_V1,
    )
)
def resnext101_32x8d(
    *,
    weights: Optional[Union[ResNeXt101_32X8D_QuantizedWeights, ResNeXt101_32X8D_Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableResNet:
    """ResNeXt-101 32x8d model from
    `Aggregated Residual Transformation for Deep Neural Networks <https://arxiv.org/abs/1611.05431>`_

    .. note::
        Note that ``quantize = True`` returns a quantized model with 8 bit
        weights. Quantized models only support inference and run on CPUs.
        GPU inference is not yet supported.

    Args:
        weights (:class:`~torchvision.models.quantization.ResNeXt101_32X8D_QuantizedWeights` or :class:`~torchvision.models.ResNeXt101_32X8D_Weights`, optional): The
            pretrained weights for the model. See
            :class:`~torchvision.models.quantization.ResNeXt101_32X8D_QuantizedWeights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        quantize (bool, optional): If True, return a quantized version of the model. Default is False.
        **kwargs: parameters passed to the ``torchvision.models.quantization.QuantizableResNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/quantization/resnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.quantization.ResNeXt101_32X8D_QuantizedWeights
        :members:

    .. autoclass:: torchvision.models.ResNeXt101_32X8D_Weights
        :members:
        :noindex:
    """
    # The enum used to verify ``weights`` depends on the requested mode:
    # quantized weights when ``quantize`` is set, float weights otherwise.
    weights = (ResNeXt101_32X8D_QuantizedWeights if quantize else ResNeXt101_32X8D_Weights).verify(weights)

    # The "32x8d" configuration: 32 groups, 8 channels of width per group.
    _ovewrite_named_param(kwargs, "groups", 32)
    _ovewrite_named_param(kwargs, "width_per_group", 8)
    return _resnet(QuantizableBottleneck, [3, 4, 23, 3], weights, progress, quantize, **kwargs)
421
422


423
# NOTE(review): unlike the sibling builders in this module, this one is not wrapped in
# ``handle_legacy_interface`` — presumably intentional for a newer model; confirm.
@register_model(name="quantized_resnext101_64x4d")
def resnext101_64x4d(
    *,
    weights: Optional[Union[ResNeXt101_64X4D_QuantizedWeights, ResNeXt101_64X4D_Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableResNet:
    """ResNeXt-101 64x4d model from
    `Aggregated Residual Transformation for Deep Neural Networks <https://arxiv.org/abs/1611.05431>`_

    .. note::
        Note that ``quantize = True`` returns a quantized model with 8 bit
        weights. Quantized models only support inference and run on CPUs.
        GPU inference is not yet supported.

    Args:
        weights (:class:`~torchvision.models.quantization.ResNeXt101_64X4D_QuantizedWeights` or :class:`~torchvision.models.ResNeXt101_64X4D_Weights`, optional): The
            pretrained weights for the model. See
            :class:`~torchvision.models.quantization.ResNeXt101_64X4D_QuantizedWeights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        quantize (bool, optional): If True, return a quantized version of the model. Default is False.
        **kwargs: parameters passed to the ``torchvision.models.quantization.QuantizableResNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/quantization/resnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.quantization.ResNeXt101_64X4D_QuantizedWeights
        :members:

    .. autoclass:: torchvision.models.ResNeXt101_64X4D_Weights
        :members:
        :noindex:
    """
    # The enum used to verify ``weights`` depends on the requested mode:
    # quantized weights when ``quantize`` is set, float weights otherwise.
    weights = (ResNeXt101_64X4D_QuantizedWeights if quantize else ResNeXt101_64X4D_Weights).verify(weights)

    # The "64x4d" configuration: 64 groups, 4 channels of width per group.
    _ovewrite_named_param(kwargs, "groups", 64)
    _ovewrite_named_param(kwargs, "width_per_group", 4)
    return _resnet(QuantizableBottleneck, [3, 4, 23, 3], weights, progress, quantize, **kwargs)
465
466
467
468
469
470
471
472
473
474
475
476
477
478


# The dictionary below is internal implementation detail and will be removed in v0.15
from .._utils import _ModelURLs
from ..resnet import model_urls  # noqa: F401


# Legacy name -> download URL mapping for the FBGEMM V1 checkpoints,
# wrapped in ``_ModelURLs``.
quant_model_urls = _ModelURLs(
    dict(
        resnet18_fbgemm=ResNet18_QuantizedWeights.IMAGENET1K_FBGEMM_V1.url,
        resnet50_fbgemm=ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1.url,
        resnext101_32x8d_fbgemm=ResNeXt101_32X8D_QuantizedWeights.IMAGENET1K_FBGEMM_V1.url,
    )
)