resnet.py 10.4 KB
Newer Older
1
from functools import partial
2
from typing import Any, Type, Union, List, Optional
3

4
5
import torch
import torch.nn as nn
6
from torch import Tensor
7
8
9
10
11
12
13
14
15
from torchvision.models.resnet import (
    Bottleneck,
    BasicBlock,
    ResNet,
    ResNet18_Weights,
    ResNet50_Weights,
    ResNeXt101_32X8D_Weights,
)

16
from ...transforms._presets import ImageClassification
17
18
19
from .._api import WeightsEnum, Weights
from .._meta import _IMAGENET_CATEGORIES
from .._utils import handle_legacy_interface, _ovewrite_named_param
20
from .utils import _fuse_modules, _replace_relu, quantize_model
21
22


23
24
25
26
27
28
29
30
31
# Public API of this module: the quantizable ResNet model class, the quantized
# weight enums, and the model builder functions defined below.
__all__ = [
    "QuantizableResNet",
    "ResNet18_QuantizedWeights",
    "ResNet50_QuantizedWeights",
    "ResNeXt101_32X8D_QuantizedWeights",
    "resnet18",
    "resnet50",
    "resnext101_32x8d",
]
32
33
34


class QuantizableBasicBlock(BasicBlock):
    """ResNet ``BasicBlock`` adapted for eager-mode quantization.

    The residual addition and the final ReLU are performed through a
    ``FloatFunctional`` instead of a plain ``+`` so the quantization flow can
    handle them, and :meth:`fuse_model` fuses the block's conv/bn(/relu) groups.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Combined "add then ReLU" for the skip connection. Spelled
        # ``nn.quantized`` to match QuantizableBottleneck (same object as
        # ``torch.nn.quantized``).
        self.add_relu = nn.quantized.FloatFunctional()

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # When present, ``downsample`` is applied to the skip path before the add.
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.add_relu.add_relu(out, identity)

        return out

    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
        """Fuse conv1/bn1/relu and conv2/bn2 in place, plus the downsample conv/bn if present.

        Args:
            is_qat: forwarded unchanged to ``_fuse_modules``.
        """
        _fuse_modules(self, [["conv1", "bn1", "relu"], ["conv2", "bn2"]], is_qat, inplace=True)
        if self.downsample:
            # NOTE(review): assumes ``downsample`` is a container whose children
            # "0" and "1" are a conv and its batch norm — confirm upstream.
            _fuse_modules(self.downsample, ["0", "1"], is_qat, inplace=True)
60
61
62


class QuantizableBottleneck(Bottleneck):
    """ResNet ``Bottleneck`` adapted for eager-mode quantization.

    Compared with the parent block it adds a dedicated ReLU module after each
    of the first two conv/bn pairs, so each conv/bn/relu group can be fused on
    its own. It also performs the residual add + ReLU through a ``FloatFunctional``.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Residual "add then ReLU" as a FloatFunctional op rather than ``+``.
        self.skip_add_relu = nn.quantized.FloatFunctional()
        # One ReLU module per fusion group (see fuse_model); non-inplace.
        self.relu1 = nn.ReLU(inplace=False)
        self.relu2 = nn.ReLU(inplace=False)

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu2(out)

        out = self.conv3(out)
        out = self.bn3(out)

        # When present, ``downsample`` is applied to the skip path before the add.
        if self.downsample is not None:
            identity = self.downsample(x)
        out = self.skip_add_relu.add_relu(out, identity)

        return out

    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
        """Fuse conv1/bn1/relu1, conv2/bn2/relu2 and conv3/bn3 in place, plus the downsample conv/bn if present.

        Args:
            is_qat: forwarded unchanged to ``_fuse_modules``.
        """
        _fuse_modules(
            self, [["conv1", "bn1", "relu1"], ["conv2", "bn2", "relu2"], ["conv3", "bn3"]], is_qat, inplace=True
        )
        if self.downsample:
            # NOTE(review): assumes ``downsample`` is a container whose children
            # "0" and "1" are a conv and its batch norm — confirm upstream.
            _fuse_modules(self.downsample, ["0", "1"], is_qat, inplace=True)
93
94
95


class QuantizableResNet(ResNet):
    """ResNet wrapped with quant/dequant stubs for eager-mode quantization.

    Inputs pass through ``self.quant`` before the network body and outputs pass
    through ``self.dequant`` afterwards.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)

        self.quant = torch.ao.quantization.QuantStub()
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x: Tensor) -> Tensor:
        x = self.quant(x)
        # Ensure scriptability
        # super(QuantizableResNet,self).forward(x)
        # is not scriptable
        x = self._forward_impl(x)
        x = self.dequant(x)
        return x

    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
        r"""Fuse conv/bn/relu modules in resnet models.

        Fuses the stem's conv1+bn1+relu. It then calls ``fuse_model`` on every
        quantizable residual block (including subclasses of them), which fuse
        their conv+bn+relu and conv+bn groups. This prepares the model for
        quantization. The model is modified in place. Note that this operation
        does not change numerics and the model after modification is in floating
        point.

        Args:
            is_qat: forwarded unchanged to every fusion call.
        """
        _fuse_modules(self, ["conv1", "bn1", "relu"], is_qat, inplace=True)
        # isinstance rather than an exact type check, so subclasses of the
        # quantizable blocks (which inherit fuse_model) are fused as well.
        for m in self.modules():
            if isinstance(m, (QuantizableBottleneck, QuantizableBasicBlock)):
                m.fuse_model(is_qat)
122
123


124
def _resnet(
    block: Type[Union[QuantizableBasicBlock, QuantizableBottleneck]],
    layers: List[int],
    weights: Optional[WeightsEnum],
    progress: bool,
    quantize: bool,
    **kwargs: Any,
) -> QuantizableResNet:
    """Build a :class:`QuantizableResNet`, optionally quantize it, and load weights.

    Args:
        block: residual block class used for every stage.
        layers: number of blocks per stage, passed to ``QuantizableResNet``.
        weights: weights to load, or ``None`` to keep the freshly built model.
            When given, ``num_classes`` is set from ``len(weights.meta["categories"])``.
            ``backend`` is set from ``weights.meta["backend"]`` if that key exists.
            Both are set via ``_ovewrite_named_param``.
        progress: forwarded to ``weights.get_state_dict``.
        quantize: if True, ``quantize_model`` is applied before the weights are loaded.
        **kwargs: forwarded to ``QuantizableResNet``, except ``backend``
            (default ``"fbgemm"``), which is consumed here and only used when
            ``quantize`` is True.

    Returns:
        The constructed model.
    """
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
        if "backend" in weights.meta:
            _ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
    # Pop ``backend`` so it is not forwarded to the model constructor.
    backend = kwargs.pop("backend", "fbgemm")

    model = QuantizableResNet(block, layers, **kwargs)
    _replace_relu(model)
    if quantize:
        quantize_model(model, backend)

    # Loaded after quantization, presumably so a quantized checkpoint's state
    # dict matches the converted module structure.
    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model


149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
# Metadata shared by all quantized weight entries in this module.
_COMMON_META = dict(
    min_size=(1, 1),
    categories=_IMAGENET_CATEGORIES,
    backend="fbgemm",
    recipe="https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models",
)


class ResNet18_QuantizedWeights(WeightsEnum):
    """Quantized (fbgemm) ImageNet-1K weights for :func:`resnet18`.

    ``DEFAULT`` is an alias of ``IMAGENET1K_FBGEMM_V1``.
    """

    IMAGENET1K_FBGEMM_V1 = Weights(
        url="https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 11689512,
            # Float weights this entry was quantized from.
            "unquantized": ResNet18_Weights.IMAGENET1K_V1,
            "metrics": {
                "acc@1": 69.494,
                "acc@5": 88.882,
            },
        },
    )
    DEFAULT = IMAGENET1K_FBGEMM_V1


class ResNet50_QuantizedWeights(WeightsEnum):
    """Quantized (fbgemm) ImageNet-1K weights for :func:`resnet50`.

    ``V2`` is derived from ``ResNet50_Weights.IMAGENET1K_V2``. It is evaluated
    with ``resize_size=232`` and is the ``DEFAULT``.
    """

    IMAGENET1K_FBGEMM_V1 = Weights(
        url="https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 25557032,
            # Float weights this entry was quantized from.
            "unquantized": ResNet50_Weights.IMAGENET1K_V1,
            "metrics": {
                "acc@1": 75.920,
                "acc@5": 92.814,
            },
        },
    )
    IMAGENET1K_FBGEMM_V2 = Weights(
        url="https://download.pytorch.org/models/quantized/resnet50_fbgemm-23753f79.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 25557032,
            # Float weights this entry was quantized from.
            "unquantized": ResNet50_Weights.IMAGENET1K_V2,
            "metrics": {
                "acc@1": 80.282,
                "acc@5": 94.976,
            },
        },
    )
    DEFAULT = IMAGENET1K_FBGEMM_V2


class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum):
    """Quantized (fbgemm) ImageNet-1K weights for :func:`resnext101_32x8d`.

    ``V2`` is derived from ``ResNeXt101_32X8D_Weights.IMAGENET1K_V2``. It is
    evaluated with ``resize_size=232`` and is the ``DEFAULT``.
    """

    IMAGENET1K_FBGEMM_V1 = Weights(
        url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 88791336,
            # Float weights this entry was quantized from.
            "unquantized": ResNeXt101_32X8D_Weights.IMAGENET1K_V1,
            "metrics": {
                "acc@1": 78.986,
                "acc@5": 94.480,
            },
        },
    )
    IMAGENET1K_FBGEMM_V2 = Weights(
        url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm-ee16d00c.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 88791336,
            # Float weights this entry was quantized from.
            "unquantized": ResNeXt101_32X8D_Weights.IMAGENET1K_V2,
            "metrics": {
                "acc@1": 82.574,
                "acc@5": 96.132,
            },
        },
    )
    DEFAULT = IMAGENET1K_FBGEMM_V2


@handle_legacy_interface(
    weights=(
        "pretrained",
        lambda kwargs: ResNet18_QuantizedWeights.IMAGENET1K_FBGEMM_V1
        if kwargs.get("quantize", False)
        else ResNet18_Weights.IMAGENET1K_V1,
    )
)
def resnet18(
    *,
    weights: Optional[Union[ResNet18_QuantizedWeights, ResNet18_Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableResNet:
    r"""ResNet-18 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    A legacy ``pretrained`` argument is handled by the decorator. It maps to
    ``ResNet18_QuantizedWeights.IMAGENET1K_FBGEMM_V1`` when ``quantize`` is
    True, else to ``ResNet18_Weights.IMAGENET1K_V1``.

    Args:
        weights (ResNet18_QuantizedWeights or ResNet18_Weights, optional): The pretrained
            weights for the model. The value is passed to
            ``ResNet18_QuantizedWeights.verify`` when ``quantize`` is True,
            otherwise to ``ResNet18_Weights.verify``.
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, return a quantized version of the model
        **kwargs: forwarded to ``_resnet`` and from there to ``QuantizableResNet``.
            A ``backend`` entry selects the quantization backend (default
            ``"fbgemm"``); it is overridden by the weights' ``meta["backend"]``
            when present.
    """
    weights = (ResNet18_QuantizedWeights if quantize else ResNet18_Weights).verify(weights)

    return _resnet(QuantizableBasicBlock, [2, 2, 2, 2], weights, progress, quantize, **kwargs)
261
262


263
264
265
266
267
268
269
270
@handle_legacy_interface(
    weights=(
        "pretrained",
        lambda kwargs: ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1
        if kwargs.get("quantize", False)
        else ResNet50_Weights.IMAGENET1K_V1,
    )
)
def resnet50(
    *,
    weights: Optional[Union[ResNet50_QuantizedWeights, ResNet50_Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableResNet:
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    A legacy ``pretrained`` argument is handled by the decorator. It maps to
    ``ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1`` when ``quantize`` is
    True, else to ``ResNet50_Weights.IMAGENET1K_V1``. These are the V1 entries,
    not the ``DEFAULT`` ones.

    Args:
        weights (ResNet50_QuantizedWeights or ResNet50_Weights, optional): The pretrained
            weights for the model. The value is passed to
            ``ResNet50_QuantizedWeights.verify`` when ``quantize`` is True,
            otherwise to ``ResNet50_Weights.verify``.
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, return a quantized version of the model
        **kwargs: forwarded to ``_resnet`` and from there to ``QuantizableResNet``.
            A ``backend`` entry selects the quantization backend (default
            ``"fbgemm"``); it is overridden by the weights' ``meta["backend"]``
            when present.
    """
    weights = (ResNet50_QuantizedWeights if quantize else ResNet50_Weights).verify(weights)

    return _resnet(QuantizableBottleneck, [3, 4, 6, 3], weights, progress, quantize, **kwargs)
290

291
292
293
294
295
296
297
298
299

@handle_legacy_interface(
    weights=(
        "pretrained",
        lambda kwargs: ResNeXt101_32X8D_QuantizedWeights.IMAGENET1K_FBGEMM_V1
        if kwargs.get("quantize", False)
        else ResNeXt101_32X8D_Weights.IMAGENET1K_V1,
    )
)
def resnext101_32x8d(
    *,
    weights: Optional[Union[ResNeXt101_32X8D_QuantizedWeights, ResNeXt101_32X8D_Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableResNet:
    r"""ResNeXt-101 32x8d model from
    `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    A legacy ``pretrained`` argument is handled by the decorator. It maps to
    ``ResNeXt101_32X8D_QuantizedWeights.IMAGENET1K_FBGEMM_V1`` when ``quantize``
    is True, else to ``ResNeXt101_32X8D_Weights.IMAGENET1K_V1``.

    Args:
        weights (ResNeXt101_32X8D_QuantizedWeights or ResNeXt101_32X8D_Weights, optional): The pretrained
            weights for the model. The value is passed to
            ``ResNeXt101_32X8D_QuantizedWeights.verify`` when ``quantize`` is
            True, otherwise to ``ResNeXt101_32X8D_Weights.verify``.
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, return a quantized version of the model
        **kwargs: forwarded to ``_resnet`` and from there to ``QuantizableResNet``.
            ``groups`` and ``width_per_group`` are set to 32 and 8 here via
            ``_ovewrite_named_param``. A ``backend`` entry selects the
            quantization backend (default ``"fbgemm"``).
    """
    weights = (ResNeXt101_32X8D_QuantizedWeights if quantize else ResNeXt101_32X8D_Weights).verify(weights)

    # The "32x8d" in the name: 32 groups, 8 channels of width per group.
    _ovewrite_named_param(kwargs, "groups", 32)
    _ovewrite_named_param(kwargs, "width_per_group", 8)
    return _resnet(QuantizableBottleneck, [3, 4, 23, 3], weights, progress, quantize, **kwargs)