# resnet.py: quantizable ResNet / ResNeXt model definitions.
1
from typing import Any, Type, Union, List, Optional
2

3
4
import torch
import torch.nn as nn
5
from torch import Tensor
6
from torchvision.models.resnet import Bottleneck, BasicBlock, ResNet, model_urls
7

8
from ..._internally_replaced_utils import load_state_dict_from_url
9
from .utils import _fuse_modules, _replace_relu, quantize_model
10

11
__all__ = ["QuantizableResNet", "resnet18", "resnet50", "resnext101_32x8d"]


# Download locations of the pre-quantized checkpoints. Keys have the form
# "<arch>_<backend>", matching how ``_resnet`` builds its lookup key
# (``arch + "_" + backend``).
quant_model_urls = {
    "resnet18_fbgemm": "https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth",
    "resnet50_fbgemm": "https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth",
    "resnext101_32x8d_fbgemm": "https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth",
}


class QuantizableBasicBlock(BasicBlock):
    """``BasicBlock`` variant prepared for quantization.

    The residual add followed by ReLU goes through a ``FloatFunctional``
    (``self.add_relu``) instead of plain tensor ops, so the quantization
    tooling can handle the skip connection as a single ``add_relu`` op.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Module wrapper for the skip-connection add + ReLU used in forward().
        self.add_relu = torch.nn.quantized.FloatFunctional()

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        # Residual add and final ReLU in one FloatFunctional call.
        out = self.add_relu.add_relu(out, identity)

        return out

    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
        """Fuse this block's conv/bn(/relu) groups in place.

        Fuses ``conv1 + bn1 + relu`` and ``conv2 + bn2``, and the two
        children of ``downsample`` when present.

        Args:
            is_qat: Forwarded unchanged to ``_fuse_modules``; presumably
                selects QAT-style fusion (TODO confirm against ``_fuse_modules``).
        """
        _fuse_modules(self, [["conv1", "bn1", "relu"], ["conv2", "bn2"]], is_qat, inplace=True)
        if self.downsample:
            # Assumes downsample is a container whose children are named
            # "0" (conv) and "1" (bn) -- TODO confirm against the ResNet base class.
            _fuse_modules(self.downsample, ["0", "1"], is_qat, inplace=True)


class QuantizableBottleneck(Bottleneck):
    """``Bottleneck`` variant prepared for quantization.

    Compared with the parent, the residual add followed by ReLU goes through
    a ``FloatFunctional`` (``skip_add_relu``). The two inner activations use
    dedicated modules (``relu1`` / ``relu2``), so ``fuse_model`` can fuse
    each conv/bn group with its own ReLU.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Module wrapper for the skip-connection add + ReLU used in forward().
        self.skip_add_relu = nn.quantized.FloatFunctional()
        # One non-inplace ReLU per conv/bn group; see fuse_model().
        self.relu1 = nn.ReLU(inplace=False)
        self.relu2 = nn.ReLU(inplace=False)

    def forward(self, x: Tensor) -> Tensor:
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu2(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)
        # Residual add and final ReLU in one FloatFunctional call.
        out = self.skip_add_relu.add_relu(out, identity)

        return out

    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
        """Fuse this block's conv/bn(/relu) groups in place.

        Fuses ``conv1 + bn1 + relu1``, ``conv2 + bn2 + relu2`` and
        ``conv3 + bn3``, plus the two children of ``downsample`` when present.

        Args:
            is_qat: Forwarded unchanged to ``_fuse_modules``; presumably
                selects QAT-style fusion (TODO confirm against ``_fuse_modules``).
        """
        _fuse_modules(
            self, [["conv1", "bn1", "relu1"], ["conv2", "bn2", "relu2"], ["conv3", "bn3"]], is_qat, inplace=True
        )
        if self.downsample:
            # Assumes downsample is a container whose children are named
            # "0" (conv) and "1" (bn) -- TODO confirm against the ResNet base class.
            _fuse_modules(self.downsample, ["0", "1"], is_qat, inplace=True)


class QuantizableResNet(ResNet):
    """``ResNet`` with a quantize stub on the input and a dequantize stub on the output."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)

        # Boundary stubs wrapping the whole network in forward().
        self.quant = torch.ao.quantization.QuantStub()
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x: Tensor) -> Tensor:
        x = self.quant(x)
        # Ensure scriptability: super(QuantizableResNet, self).forward(x)
        # is not scriptable, so the parent's _forward_impl is called directly.
        x = self._forward_impl(x)
        x = self.dequant(x)
        return x

    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
        r"""Fuse conv/bn/relu modules in resnet models.

        Fuse conv+bn+relu / conv+relu / conv+bn modules to prepare for quantization.
        The model is modified in place. Note that this operation does not change
        numerics, and the model after modification is still in floating point.

        Args:
            is_qat: Forwarded to ``_fuse_modules`` for the stem and to every
                quantizable block's own ``fuse_model``.
        """
        # Stem: conv1 + bn1 + relu.
        _fuse_modules(self, ["conv1", "bn1", "relu"], is_qat, inplace=True)
        for m in self.modules():
            # Exact type match (not isinstance), as in the original check.
            if type(m) in (QuantizableBottleneck, QuantizableBasicBlock):
                m.fuse_model(is_qat)


111
112
def _resnet(
    arch: str,
    block: Type[Union[QuantizableBasicBlock, QuantizableBottleneck]],
    layers: List[int],
    pretrained: bool,
    progress: bool,
    quantize: bool,
    **kwargs: Any,
) -> QuantizableResNet:
    """Build a ``QuantizableResNet``, optionally quantize it, optionally load weights.

    Args:
        arch: Architecture name. Used as the key into ``model_urls`` and, with
            ``"_" + backend`` appended, into ``quant_model_urls``.
        block: Block class forwarded to ``QuantizableResNet``.
        layers: ``layers`` list forwarded to ``QuantizableResNet``
            (e.g. ``[2, 2, 2, 2]`` for resnet18).
        pretrained: If True, download and load a checkpoint.
        progress: Passed to ``load_state_dict_from_url`` as ``progress``.
        quantize: If True, quantize the model with the ``"fbgemm"`` backend
            before any weights are loaded.
        **kwargs: Forwarded to ``QuantizableResNet``.

    Returns:
        The constructed model.
    """
    model = QuantizableResNet(block, layers, **kwargs)
    _replace_relu(model)
    if quantize:
        # TODO use pretrained as a string to specify the backend
        backend = "fbgemm"
        quantize_model(model, backend)
    else:
        # Without quantization, pretrained must be a plain bool.
        assert pretrained in [True, False]

    if pretrained:
        # Weights are loaded after quantization, presumably because the
        # quantized checkpoint matches the converted model's state_dict.
        if quantize:
            model_url = quant_model_urls[arch + "_" + backend]
        else:
            model_url = model_urls[arch]

        state_dict = load_state_dict_from_url(model_url, progress=progress)

        model.load_state_dict(state_dict)
    return model


142
143
144
145
146
147
def resnet18(
    pretrained: bool = False,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableResNet:
    r"""ResNet-18 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, return a quantized version of the model
        **kwargs: Forwarded to ``QuantizableResNet``.
    """
    return _resnet("resnet18", QuantizableBasicBlock, [2, 2, 2, 2], pretrained, progress, quantize, **kwargs)


159
160
161
162
163
164
165
def resnet50(
    pretrained: bool = False,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableResNet:
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, return a quantized version of the model
        **kwargs: Forwarded to ``QuantizableResNet``.
    """
    return _resnet("resnet50", QuantizableBottleneck, [3, 4, 6, 3], pretrained, progress, quantize, **kwargs)


177
178
179
180
181
182
def resnext101_32x8d(
    pretrained: bool = False,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableResNet:
    r"""ResNeXt-101 32x8d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, return a quantized version of the model
        **kwargs: Forwarded to ``QuantizableResNet``. Any ``groups`` or
            ``width_per_group`` given here is overwritten with 32 and 8.
    """
    # The "32x8d" variant is defined by these two settings.
    kwargs["groups"] = 32
    kwargs["width_per_group"] = 8
    return _resnet("resnext101_32x8d", QuantizableBottleneck, [3, 4, 23, 3], pretrained, progress, quantize, **kwargs)