# resnet.py — quantizable ResNet model definitions
1
from functools import partial
2
from typing import Any, Type, Union, List, Optional
3

4
5
import torch
import torch.nn as nn
6
from torch import Tensor
7
8
9
10
11
12
13
14
15
from torchvision.models.resnet import (
    Bottleneck,
    BasicBlock,
    ResNet,
    ResNet18_Weights,
    ResNet50_Weights,
    ResNeXt101_32X8D_Weights,
)

16
from ...transforms._presets import ImageClassification
17
18
19
from .._api import WeightsEnum, Weights
from .._meta import _IMAGENET_CATEGORIES
from .._utils import handle_legacy_interface, _ovewrite_named_param
20
from .utils import _fuse_modules, _replace_relu, quantize_model
21
22


23
24
25
26
27
28
29
30
31
# Public API: the quantizable model class, the quantized weight enums, and the
# three model builder functions.
__all__ = [
    "QuantizableResNet",
    "ResNet18_QuantizedWeights",
    "ResNet50_QuantizedWeights",
    "ResNeXt101_32X8D_QuantizedWeights",
    "resnet18",
    "resnet50",
    "resnext101_32x8d",
]
32
33
34


class QuantizableBasicBlock(BasicBlock):
    """BasicBlock variant prepared for quantization.

    Replaces the floating-point residual ``out += identity; relu(out)`` with a
    :class:`~torch.nn.quantized.FloatFunctional` so the fused add + ReLU can be
    observed and quantized.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Quantization-aware functional wrapper for the fused add + ReLU.
        self.add_relu = torch.nn.quantized.FloatFunctional()

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        # Residual addition and ReLU performed through FloatFunctional so the
        # quantization workflow can fuse and observe them.
        out = self.add_relu.add_relu(out, identity)

        return out

    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
        """Fuse conv/bn(/relu) submodules in place to prepare for quantization."""
        _fuse_modules(self, [["conv1", "bn1", "relu"], ["conv2", "bn2"]], is_qat, inplace=True)
        # Test against None explicitly, consistent with forward(), instead of
        # relying on the truthiness of an nn.Sequential.
        if self.downsample is not None:
            _fuse_modules(self.downsample, ["0", "1"], is_qat, inplace=True)
60
61
62


class QuantizableBottleneck(Bottleneck):
    """Bottleneck variant prepared for quantization.

    Uses a :class:`~torch.nn.quantized.FloatFunctional` for the residual
    add + ReLU, and separate non-inplace ReLU modules (``relu1``/``relu2``) so
    each conv/bn/relu triple can be fused independently.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Quantization-aware functional wrapper for the fused skip add + ReLU.
        self.skip_add_relu = nn.quantized.FloatFunctional()
        # Distinct ReLU instances (not inplace) so module fusion can target them.
        self.relu1 = nn.ReLU(inplace=False)
        self.relu2 = nn.ReLU(inplace=False)

    def forward(self, x: Tensor) -> Tensor:
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu2(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)
        # Residual addition and ReLU via FloatFunctional for quantization.
        out = self.skip_add_relu.add_relu(out, identity)

        return out

    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
        """Fuse conv/bn(/relu) submodules in place to prepare for quantization."""
        _fuse_modules(
            self, [["conv1", "bn1", "relu1"], ["conv2", "bn2", "relu2"], ["conv3", "bn3"]], is_qat, inplace=True
        )
        # Test against None explicitly, consistent with forward(), instead of
        # relying on the truthiness of an nn.Sequential.
        if self.downsample is not None:
            _fuse_modules(self.downsample, ["0", "1"], is_qat, inplace=True)
93
94
95


class QuantizableResNet(ResNet):
    """ResNet wrapped with quant/dequant stubs so the whole network can be quantized."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Entry/exit points for the quantized region of the graph.
        self.quant = torch.ao.quantization.QuantStub()
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x: Tensor) -> Tensor:
        x = self.quant(x)
        # Calling super().forward() directly is not TorchScript-scriptable, so
        # route through the shared _forward_impl instead.
        x = self._forward_impl(x)
        return self.dequant(x)

    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
        r"""Fuse conv/bn/relu modules in resnet models

        Fuse conv+bn+relu/ Conv+relu/conv+Bn modules to prepare for quantization.
        Model is modified in place.  Note that this operation does not change numerics
        and the model after modification is in floating point
        """
        _fuse_modules(self, ["conv1", "bn1", "relu"], is_qat, inplace=True)
        for module in self.modules():
            # Exact type check on purpose: fuse only the quantizable blocks
            # defined in this file, not arbitrary subclasses.
            if type(module) in (QuantizableBottleneck, QuantizableBasicBlock):
                module.fuse_model(is_qat)
122
123


124
def _resnet(
    block: Type[Union[QuantizableBasicBlock, QuantizableBottleneck]],
    layers: List[int],
    weights: Optional[WeightsEnum],
    progress: bool,
    quantize: bool,
    **kwargs: Any,
) -> QuantizableResNet:
    """Construct a quantizable ResNet, optionally quantized and/or pretrained.

    Args:
        block: residual block class used to build the network.
        layers: number of blocks per stage.
        weights: pretrained weights to load, or ``None`` for random init.
        progress: whether to show a download progress bar.
        quantize: whether to return the quantized version of the model.
        **kwargs: forwarded to ``QuantizableResNet``; may contain ``backend``.
    """
    if weights is not None:
        # Weight metadata pins num_classes (and backend, when present).
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
        if "backend" in weights.meta:
            _ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
    quant_backend = kwargs.pop("backend", "fbgemm")

    net = QuantizableResNet(block, layers, **kwargs)
    _replace_relu(net)
    if quantize:
        quantize_model(net, quant_backend)

    if weights is not None:
        net.load_state_dict(weights.get_state_dict(progress=progress))

    return net


149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
# Metadata shared by every quantized-weight entry below; each Weights entry
# spreads this dict and adds architecture-specific keys.
_COMMON_META = {
    "task": "image_classification",
    "size": (224, 224),
    "min_size": (1, 1),
    "categories": _IMAGENET_CATEGORIES,
    "backend": "fbgemm",
    "quantization": "Post Training Quantization",
    "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models",
}


class ResNet18_QuantizedWeights(WeightsEnum):
    """Post-training-quantized pretrained weights for ResNet-18 (fbgemm backend)."""

    IMAGENET1K_FBGEMM_V1 = Weights(
        url="https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "architecture": "ResNet",
            "num_params": 11689512,
            "unquantized": ResNet18_Weights.IMAGENET1K_V1,
            "acc@1": 69.494,
            "acc@5": 88.882,
        },
    )
    DEFAULT = IMAGENET1K_FBGEMM_V1


class ResNet50_QuantizedWeights(WeightsEnum):
    """Post-training-quantized pretrained weights for ResNet-50 (fbgemm backend)."""

    IMAGENET1K_FBGEMM_V1 = Weights(
        url="https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "architecture": "ResNet",
            "num_params": 25557032,
            "unquantized": ResNet50_Weights.IMAGENET1K_V1,
            "acc@1": 75.920,
            "acc@5": 92.814,
        },
    )
    # V2 uses a larger resize (232) at inference and a stronger unquantized recipe.
    IMAGENET1K_FBGEMM_V2 = Weights(
        url="https://download.pytorch.org/models/quantized/resnet50_fbgemm-23753f79.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "architecture": "ResNet",
            "num_params": 25557032,
            "unquantized": ResNet50_Weights.IMAGENET1K_V2,
            "acc@1": 80.282,
            "acc@5": 94.976,
        },
    )
    DEFAULT = IMAGENET1K_FBGEMM_V2


class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum):
    """Post-training-quantized pretrained weights for ResNeXt-101 32x8d (fbgemm backend)."""

    IMAGENET1K_FBGEMM_V1 = Weights(
        url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "architecture": "ResNeXt",
            "num_params": 88791336,
            "unquantized": ResNeXt101_32X8D_Weights.IMAGENET1K_V1,
            "acc@1": 78.986,
            "acc@5": 94.480,
        },
    )
    # V2 uses a larger resize (232) at inference and a stronger unquantized recipe.
    IMAGENET1K_FBGEMM_V2 = Weights(
        url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm-ee16d00c.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "architecture": "ResNeXt",
            "num_params": 88791336,
            "unquantized": ResNeXt101_32X8D_Weights.IMAGENET1K_V2,
            "acc@1": 82.574,
            "acc@5": 96.132,
        },
    )
    DEFAULT = IMAGENET1K_FBGEMM_V2


@handle_legacy_interface(
    weights=(
        "pretrained",
        lambda kwargs: ResNet18_Weights.IMAGENET1K_V1
        if not kwargs.get("quantize", False)
        else ResNet18_QuantizedWeights.IMAGENET1K_FBGEMM_V1,
    )
)
def resnet18(
    *,
    weights: Optional[Union[ResNet18_QuantizedWeights, ResNet18_Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableResNet:
    r"""ResNet-18 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        weights (ResNet18_QuantizedWeights or ResNet18_Weights, optional): The pretrained
            weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, return a quantized version of the model
    """
    # Which weight enum is legal depends on whether a quantized model was requested.
    weight_enum = ResNet18_QuantizedWeights if quantize else ResNet18_Weights
    weights = weight_enum.verify(weights)

    return _resnet(QuantizableBasicBlock, [2, 2, 2, 2], weights, progress, quantize, **kwargs)
259
260


261
262
263
264
265
266
267
268
@handle_legacy_interface(
    weights=(
        "pretrained",
        lambda kwargs: ResNet50_Weights.IMAGENET1K_V1
        if not kwargs.get("quantize", False)
        else ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1,
    )
)
def resnet50(
    *,
    weights: Optional[Union[ResNet50_QuantizedWeights, ResNet50_Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableResNet:
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        weights (ResNet50_QuantizedWeights or ResNet50_Weights, optional): The pretrained
            weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, return a quantized version of the model
    """
    # Which weight enum is legal depends on whether a quantized model was requested.
    weight_enum = ResNet50_QuantizedWeights if quantize else ResNet50_Weights
    weights = weight_enum.verify(weights)

    return _resnet(QuantizableBottleneck, [3, 4, 6, 3], weights, progress, quantize, **kwargs)
288

289
290
291
292
293
294
295
296
297

@handle_legacy_interface(
    weights=(
        "pretrained",
        lambda kwargs: ResNeXt101_32X8D_Weights.IMAGENET1K_V1
        if not kwargs.get("quantize", False)
        else ResNeXt101_32X8D_QuantizedWeights.IMAGENET1K_FBGEMM_V1,
    )
)
def resnext101_32x8d(
    *,
    weights: Optional[Union[ResNeXt101_32X8D_QuantizedWeights, ResNeXt101_32X8D_Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableResNet:
    r"""ResNeXt-101 32x8d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    Args:
        weights (ResNeXt101_32X8D_QuantizedWeights or ResNeXt101_32X8D_Weights, optional): The pretrained
            weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, return a quantized version of the model
    """
    # Which weight enum is legal depends on whether a quantized model was requested.
    weight_enum = ResNeXt101_32X8D_QuantizedWeights if quantize else ResNeXt101_32X8D_Weights
    weights = weight_enum.verify(weights)

    # The 32x8d variant is plain ResNet-101 with grouped convolutions.
    _ovewrite_named_param(kwargs, "groups", 32)
    _ovewrite_named_param(kwargs, "width_per_group", 8)
    return _resnet(QuantizableBottleneck, [3, 4, 23, 3], weights, progress, quantize, **kwargs)