# resnet.py
from functools import partial
from typing import Any, Type, Union, List, Optional

import torch
import torch.nn as nn
from torch import Tensor
from torchvision.models.resnet import (
    Bottleneck,
    BasicBlock,
    ResNet,
    ResNet18_Weights,
    ResNet50_Weights,
    ResNeXt101_32X8D_Weights,
    ResNeXt101_64X4D_Weights,
)

from ...transforms._presets import ImageClassification
from .._api import WeightsEnum, Weights
from .._meta import _IMAGENET_CATEGORIES
from .._utils import handle_legacy_interface, _ovewrite_named_param
from .utils import _fuse_modules, _replace_relu, quantize_model


# Public API of this module: the quantizable model class, the quantized weight
# enums, and the model builder functions.
__all__ = [
    "QuantizableResNet",
    "ResNet18_QuantizedWeights",
    "ResNet50_QuantizedWeights",
    "ResNeXt101_32X8D_QuantizedWeights",
    "ResNeXt101_64X4D_QuantizedWeights",
    "resnet18",
    "resnet50",
    "resnext101_32x8d",
    "resnext101_64x4d",
]


class QuantizableBasicBlock(BasicBlock):
    """``BasicBlock`` variant prepared for eager-mode quantization.

    The residual addition and the final activation are performed by a single
    ``FloatFunctional.add_relu`` call instead of plain tensor ops, so the combined
    add+ReLU can be observed and quantized.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Quantization-friendly wrapper for the residual "add then ReLU".
        self.add_relu = nn.quantized.FloatFunctional()

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        # Residual add and final activation in one op (see __init__).
        out = self.add_relu.add_relu(out, identity)

        return out

    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
        """Fuse this block's conv/bn(/relu) sequences in place.

        ``conv2``/``bn2`` are fused without a ReLU because, in ``forward``, the
        activation after ``bn2`` is applied only after the residual add.

        Args:
            is_qat: forwarded unchanged to ``_fuse_modules`` (presumably selects
                QAT-style fusion — confirm against ``_fuse_modules``).
        """
        _fuse_modules(self, [["conv1", "bn1", "relu"], ["conv2", "bn2"]], is_qat, inplace=True)
        if self.downsample:
            # Submodules are addressed as "0" and "1"; assumes downsample is a
            # Sequential(conv, bn) — TODO confirm for custom blocks.
            _fuse_modules(self.downsample, ["0", "1"], is_qat, inplace=True)


class QuantizableBottleneck(Bottleneck):
    """``Bottleneck`` variant prepared for eager-mode quantization.

    The first two activations are separate, non-inplace ``nn.ReLU`` modules
    (``relu1``, ``relu2``), so each can be fused with its own conv/bn pair in
    ``fuse_model``. The residual add and final activation go through
    ``FloatFunctional.add_relu``.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Quantization-friendly wrapper for the residual "add then ReLU".
        self.skip_add_relu = nn.quantized.FloatFunctional()
        # Dedicated activation modules for the conv1/bn1 and conv2/bn2 stages.
        self.relu1 = nn.ReLU(inplace=False)
        self.relu2 = nn.ReLU(inplace=False)

    def forward(self, x: Tensor) -> Tensor:
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu2(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)
        # No activation directly after bn3: it is applied together with the residual add.
        out = self.skip_add_relu.add_relu(out, identity)

        return out

    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
        """Fuse this block's conv/bn(/relu) sequences in place.

        ``conv3``/``bn3`` are fused without a ReLU because, in ``forward``, the
        activation after ``bn3`` happens only after the residual add.

        Args:
            is_qat: forwarded unchanged to ``_fuse_modules`` (presumably selects
                QAT-style fusion — confirm against ``_fuse_modules``).
        """
        _fuse_modules(
            self, [["conv1", "bn1", "relu1"], ["conv2", "bn2", "relu2"], ["conv3", "bn3"]], is_qat, inplace=True
        )
        if self.downsample:
            # Submodules are addressed as "0" and "1"; assumes downsample is a
            # Sequential(conv, bn) — TODO confirm for custom blocks.
            _fuse_modules(self.downsample, ["0", "1"], is_qat, inplace=True)


class QuantizableResNet(ResNet):
    """``ResNet`` with quant/dequant stubs around the network for eager-mode quantization."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)

        # Mark where the input enters the quantized region and where the output leaves it.
        self.quant = torch.ao.quantization.QuantStub()
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x: Tensor) -> Tensor:
        x = self.quant(x)
        # Call _forward_impl directly instead of super(QuantizableResNet, self).forward(x),
        # because the super() call is not scriptable with TorchScript.
        x = self._forward_impl(x)
        x = self.dequant(x)
        return x

    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
        r"""Fuse conv/bn/relu modules in resnet models.

        Fuses conv+bn+relu, conv+relu and conv+bn sequences to prepare for quantization.
        The model is modified in place. This operation does not change numerics, and the
        model is still in floating point after fusion.

        Args:
            is_qat: forwarded unchanged to ``_fuse_modules`` and to each block's
                ``fuse_model``.
        """
        # Stem: conv1 -> bn1 -> relu.
        _fuse_modules(self, ["conv1", "bn1", "relu"], is_qat, inplace=True)
        for m in self.modules():
            # Exact type match (not isinstance), so subclasses of the quantizable
            # blocks are not fused by this loop.
            if type(m) is QuantizableBottleneck or type(m) is QuantizableBasicBlock:
                m.fuse_model(is_qat)


def _resnet(
    block: Type[Union[QuantizableBasicBlock, QuantizableBottleneck]],
    layers: List[int],
    weights: Optional[WeightsEnum],
    progress: bool,
    quantize: bool,
    **kwargs: Any,
) -> QuantizableResNet:
    """Build a ``QuantizableResNet``, optionally quantize it, and optionally load weights.

    Args:
        block: residual block class used throughout the network.
        layers: per-stage block counts, passed through to ``QuantizableResNet``.
        weights: pretrained weights to load, or ``None``. When given, ``num_classes``
            is taken from ``len(weights.meta["categories"])``. ``backend`` is also taken
            from ``weights.meta["backend"]`` when that key is present.
        progress: forwarded to ``weights.get_state_dict`` to control the download
            progress display.
        quantize: if True, ``quantize_model`` is applied before any weights are loaded.
        **kwargs: forwarded to ``QuantizableResNet``. The ``backend`` key (default
            ``"fbgemm"``) is consumed here and not forwarded.

    Returns:
        The constructed model.
    """
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
        if "backend" in weights.meta:
            _ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
    # Pop "backend" so it is not passed on to the QuantizableResNet constructor.
    backend = kwargs.pop("backend", "fbgemm")

    model = QuantizableResNet(block, layers, **kwargs)
    _replace_relu(model)
    if quantize:
        quantize_model(model, backend)

    # Weights are loaded after the optional quantization step, presumably so that a
    # quantized checkpoint matches the quantized module structure — confirm.
    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model


# Metadata shared by every quantized weight entry in this module. Each entry spreads
# it with ``**_COMMON_META`` and then adds (or overrides) its own keys.
_COMMON_META = dict(
    min_size=(1, 1),
    categories=_IMAGENET_CATEGORIES,
    backend="fbgemm",
    recipe="https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models",
)


class ResNet18_QuantizedWeights(WeightsEnum):
    """Quantized (fbgemm backend) ImageNet weights for the quantizable ResNet-18."""

    IMAGENET1K_FBGEMM_V1 = Weights(
        url="https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 11689512,
            # Float weights this quantized checkpoint corresponds to.
            "unquantized": ResNet18_Weights.IMAGENET1K_V1,
            "metrics": {
                "acc@1": 69.494,
                "acc@5": 88.882,
            },
        },
    )
    DEFAULT = IMAGENET1K_FBGEMM_V1


class ResNet50_QuantizedWeights(WeightsEnum):
    """Quantized (fbgemm backend) ImageNet weights for the quantizable ResNet-50.

    ``V2`` corresponds to ``ResNet50_Weights.IMAGENET1K_V2``, uses ``resize_size=232``
    in its transforms, and is the ``DEFAULT``.
    """

    IMAGENET1K_FBGEMM_V1 = Weights(
        url="https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 25557032,
            "unquantized": ResNet50_Weights.IMAGENET1K_V1,
            "metrics": {
                "acc@1": 75.920,
                "acc@5": 92.814,
            },
        },
    )
    IMAGENET1K_FBGEMM_V2 = Weights(
        url="https://download.pytorch.org/models/quantized/resnet50_fbgemm-23753f79.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 25557032,
            "unquantized": ResNet50_Weights.IMAGENET1K_V2,
            "metrics": {
                "acc@1": 80.282,
                "acc@5": 94.976,
            },
        },
    )
    DEFAULT = IMAGENET1K_FBGEMM_V2


class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum):
    """Quantized (fbgemm backend) ImageNet weights for the quantizable ResNeXt-101 32x8d.

    ``V2`` corresponds to ``ResNeXt101_32X8D_Weights.IMAGENET1K_V2``, uses
    ``resize_size=232`` in its transforms, and is the ``DEFAULT``.
    """

    IMAGENET1K_FBGEMM_V1 = Weights(
        url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 88791336,
            "unquantized": ResNeXt101_32X8D_Weights.IMAGENET1K_V1,
            "metrics": {
                "acc@1": 78.986,
                "acc@5": 94.480,
            },
        },
    )
    IMAGENET1K_FBGEMM_V2 = Weights(
        url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm-ee16d00c.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 88791336,
            "unquantized": ResNeXt101_32X8D_Weights.IMAGENET1K_V2,
            "metrics": {
                "acc@1": 82.574,
                "acc@5": 96.132,
            },
        },
    )
    DEFAULT = IMAGENET1K_FBGEMM_V2


class ResNeXt101_64X4D_QuantizedWeights(WeightsEnum):
    """Quantized (fbgemm backend) ImageNet weights for the quantizable ResNeXt-101 64x4d.

    This entry overrides the shared ``recipe`` link with the PR that produced it.
    """

    IMAGENET1K_FBGEMM_V1 = Weights(
        url="https://download.pytorch.org/models/quantized/resnext101_64x4d_fbgemm-605a1cb3.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta=dict(
            _COMMON_META,
            num_params=83455272,
            recipe="https://github.com/pytorch/vision/pull/5935",
            unquantized=ResNeXt101_64X4D_Weights.IMAGENET1K_V1,
            metrics={"acc@1": 82.898, "acc@5": 96.326},
        ),
    )
    DEFAULT = IMAGENET1K_FBGEMM_V1


@handle_legacy_interface(
    weights=(
        "pretrained",
        lambda kwargs: ResNet18_QuantizedWeights.IMAGENET1K_FBGEMM_V1
        if kwargs.get("quantize", False)
        else ResNet18_Weights.IMAGENET1K_V1,
    )
)
def resnet18(
    *,
    weights: Optional[Union[ResNet18_QuantizedWeights, ResNet18_Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableResNet:
    r"""ResNet-18 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        weights (ResNet18_QuantizedWeights or ResNet18_Weights, optional): The pretrained
            weights for the model. They are verified against ``ResNet18_QuantizedWeights``
            when ``quantize`` is True, and against ``ResNet18_Weights`` otherwise.
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, return a quantized version of the model
        **kwargs: forwarded to ``_resnet``, and from there to ``QuantizableResNet``. An
            optional ``backend`` key selects the quantization backend.
    """
    weights = (ResNet18_QuantizedWeights if quantize else ResNet18_Weights).verify(weights)

    return _resnet(QuantizableBasicBlock, [2, 2, 2, 2], weights, progress, quantize, **kwargs)


@handle_legacy_interface(
    weights=(
        "pretrained",
        lambda kwargs: ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1
        if kwargs.get("quantize", False)
        else ResNet50_Weights.IMAGENET1K_V1,
    )
)
def resnet50(
    *,
    weights: Optional[Union[ResNet50_QuantizedWeights, ResNet50_Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableResNet:
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        weights (ResNet50_QuantizedWeights or ResNet50_Weights, optional): The pretrained
            weights for the model. They are verified against ``ResNet50_QuantizedWeights``
            when ``quantize`` is True, and against ``ResNet50_Weights`` otherwise.
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, return a quantized version of the model
        **kwargs: forwarded to ``_resnet``, and from there to ``QuantizableResNet``. An
            optional ``backend`` key selects the quantization backend.
    """
    weights = (ResNet50_QuantizedWeights if quantize else ResNet50_Weights).verify(weights)

    return _resnet(QuantizableBottleneck, [3, 4, 6, 3], weights, progress, quantize, **kwargs)

@handle_legacy_interface(
    weights=(
        "pretrained",
        lambda kwargs: ResNeXt101_32X8D_QuantizedWeights.IMAGENET1K_FBGEMM_V1
        if kwargs.get("quantize", False)
        else ResNeXt101_32X8D_Weights.IMAGENET1K_V1,
    )
)
def resnext101_32x8d(
    *,
    weights: Optional[Union[ResNeXt101_32X8D_QuantizedWeights, ResNeXt101_32X8D_Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableResNet:
    r"""ResNeXt-101 32x8d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    Args:
        weights (ResNeXt101_32X8D_QuantizedWeights or ResNeXt101_32X8D_Weights, optional): The pretrained
            weights for the model. They are verified against the quantized enum when
            ``quantize`` is True, and against the float enum otherwise.
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, return a quantized version of the model
        **kwargs: forwarded to ``_resnet``, and from there to ``QuantizableResNet``.
            ``groups`` and ``width_per_group`` are set by this builder.
    """
    weights = (ResNeXt101_32X8D_QuantizedWeights if quantize else ResNeXt101_32X8D_Weights).verify(weights)

    # The "32x8d" variant: groups=32, width_per_group=8.
    _ovewrite_named_param(kwargs, "groups", 32)
    _ovewrite_named_param(kwargs, "width_per_group", 8)
    return _resnet(QuantizableBottleneck, [3, 4, 23, 3], weights, progress, quantize, **kwargs)


def resnext101_64x4d(
    *,
    weights: Optional[Union[ResNeXt101_64X4D_QuantizedWeights, ResNeXt101_64X4D_Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableResNet:
    r"""ResNeXt-101 64x4d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    Args:
        weights (ResNeXt101_64X4D_QuantizedWeights or ResNeXt101_64X4D_Weights, optional): The pretrained
            weights for the model. The enum used for verification depends on ``quantize``.
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, return a quantized version of the model
        **kwargs: forwarded to ``_resnet``. ``groups`` and ``width_per_group`` are fixed
            by this builder.
    """
    if quantize:
        weights_enum = ResNeXt101_64X4D_QuantizedWeights
    else:
        weights_enum = ResNeXt101_64X4D_Weights
    weights = weights_enum.verify(weights)

    # The "64x4d" variant: groups=64, width_per_group=4.
    for param_name, param_value in (("groups", 64), ("width_per_group", 4)):
        _ovewrite_named_param(kwargs, param_name, param_value)

    stage_depths = [3, 4, 23, 3]
    return _resnet(QuantizableBottleneck, stage_depths, weights, progress, quantize, **kwargs)


# The dictionary below is internal implementation detail and will be removed in v0.15
from .._utils import _ModelURLs
from ..resnet import model_urls  # noqa: F401


# Legacy "<arch>_fbgemm" name -> checkpoint URL mapping, derived from the V1 weight entries.
quant_model_urls = _ModelURLs(
    {
        legacy_name: weight_entry.url
        for legacy_name, weight_entry in (
            ("resnet18_fbgemm", ResNet18_QuantizedWeights.IMAGENET1K_FBGEMM_V1),
            ("resnet50_fbgemm", ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1),
            ("resnext101_32x8d_fbgemm", ResNeXt101_32X8D_QuantizedWeights.IMAGENET1K_FBGEMM_V1),
        )
    }
)