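# resnet.py (6.6 KB) -- this text was captured from a repository web view;
# the "Newer / Older" navigation labels are not part of the module.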
1
2
3
import torch
from torchvision.models.resnet import Bottleneck, BasicBlock, ResNet, model_urls
import torch.nn as nn
4
5
6
from torch import Tensor
from typing import Any, Type, Union, List

7
from ..._internally_replaced_utils import load_state_dict_from_url
# (web-view residue: commit attribution "Aditya Oke committed"; not part of the module)
8
from torch.quantization import fuse_modules
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from .utils import _replace_relu, quantize_model

# Public names exported by this module.
__all__ = [
    'QuantizableResNet',
    'resnet18',
    'resnet50',
    'resnext101_32x8d',
]


# Checkpoint URLs for the quantized models. Keys are built in ``_resnet`` as
# ``arch + '_' + backend``.
quant_model_urls = {
    'resnet18_fbgemm': (
        'https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth'
    ),
    'resnet50_fbgemm': (
        'https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth'
    ),
    'resnext101_32x8d_fbgemm': (
        'https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth'
    ),
}


class QuantizableBasicBlock(BasicBlock):
    r"""``BasicBlock`` whose residual add + ReLU is routed through a
    ``FloatFunctional`` module and which knows how to fuse its own layers.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        r"""Construct the parent ``BasicBlock`` and add the functional wrapper.

        Args:
            *args: Positional arguments forwarded unchanged to ``BasicBlock``.
            **kwargs: Keyword arguments forwarded unchanged to ``BasicBlock``.
        """
        super().__init__(*args, **kwargs)
        # Module-based replacement for the skip-connection ``out += identity``
        # followed by ReLU; used in ``forward`` below.
        self.add_relu = torch.nn.quantized.FloatFunctional()

    def forward(self, x: Tensor) -> Tensor:
        r"""Run the block on ``x`` and return the result of the fused add+ReLU.

        Args:
            x (Tensor): Input to the block.

        Returns:
            Tensor: ``add_relu(main_branch(x), identity)``, where ``identity``
            is ``x`` or ``self.downsample(x)`` when a downsample branch exists.
        """
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        # Single call performs the residual addition and the final ReLU.
        out = self.add_relu.add_relu(out, identity)

        return out

    def fuse_model(self) -> None:
        r"""Fuse conv/bn(/relu) groups of this block in place.

        Fuses ``conv1+bn1+relu`` and ``conv2+bn2``; when a downsample branch is
        present, its children named ``'0'`` and ``'1'`` are fused as well.
        """
        fuse_modules(
            self,
            [['conv1', 'bn1', 'relu'], ['conv2', 'bn2']],
            inplace=True,
        )
        # Skipped when the block has no downsample branch.
        if self.downsample:
            # '0' and '1' are child names inside ``self.downsample``
            # (presumably its conv and norm layers -- confirm against ResNet).
            fuse_modules(self.downsample, ['0', '1'], inplace=True)


class QuantizableBottleneck(Bottleneck):
    r"""``Bottleneck`` whose residual add + ReLU is routed through a
    ``FloatFunctional`` module and which knows how to fuse its own layers.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        r"""Construct the parent ``Bottleneck`` and add quantization helpers.

        Args:
            *args: Positional arguments forwarded unchanged to ``Bottleneck``.
            **kwargs: Keyword arguments forwarded unchanged to ``Bottleneck``.
        """
        super().__init__(*args, **kwargs)
        # Module-based replacement for the skip-connection add followed by ReLU.
        self.skip_add_relu = nn.quantized.FloatFunctional()
        # Two distinct ReLU modules: ``fuse_model`` pairs 'relu1' with
        # conv1/bn1 and 'relu2' with conv2/bn2.
        self.relu1 = nn.ReLU(inplace=False)
        self.relu2 = nn.ReLU(inplace=False)

    def forward(self, x: Tensor) -> Tensor:
        r"""Run the block on ``x`` and return the result of the fused add+ReLU.

        Args:
            x (Tensor): Input to the block.

        Returns:
            Tensor: ``skip_add_relu.add_relu(main_branch(x), identity)``, where
            ``identity`` is ``x`` or ``self.downsample(x)`` when a downsample
            branch exists.
        """
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu2(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        # Single call performs the residual addition and the final ReLU.
        out = self.skip_add_relu.add_relu(out, identity)

        return out

    def fuse_model(self) -> None:
        r"""Fuse conv/bn(/relu) groups of this block in place.

        Fuses ``conv1+bn1+relu1``, ``conv2+bn2+relu2`` and ``conv3+bn3``; when
        a downsample branch is present, its children named ``'0'`` and ``'1'``
        are fused as well.
        """
        fuse_modules(
            self,
            [
                ['conv1', 'bn1', 'relu1'],
                ['conv2', 'bn2', 'relu2'],
                ['conv3', 'bn3'],
            ],
            inplace=True,
        )
        # Skipped when the block has no downsample branch.
        if self.downsample:
            # '0' and '1' are child names inside ``self.downsample``
            # (presumably its conv and norm layers -- confirm against ResNet).
            fuse_modules(self.downsample, ['0', '1'], inplace=True)


class QuantizableResNet(ResNet):
    r"""``ResNet`` wrapped with quant/dequant stubs and a ``fuse_model`` helper."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        r"""Construct the parent ``ResNet`` and add the quant/dequant stubs.

        Args:
            *args: Positional arguments forwarded unchanged to ``ResNet``.
            **kwargs: Keyword arguments forwarded unchanged to ``ResNet``.
        """
        super().__init__(*args, **kwargs)

        # Stubs applied to the input and output of the network in ``forward``.
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x: Tensor) -> Tensor:
        r"""Apply ``quant``, the ResNet body, then ``dequant``.

        Args:
            x (Tensor): Network input.

        Returns:
            Tensor: Output of ``self.dequant`` applied to the ResNet output.
        """
        x = self.quant(x)
        # Ensure scriptability
        # super(QuantizableResNet,self).forward(x)
        # is not scriptable
        x = self._forward_impl(x)
        x = self.dequant(x)
        return x

    def fuse_model(self) -> None:
        r"""Fuse conv/bn/relu modules in resnet models

        Fuse conv+bn+relu/ Conv+relu/conv+Bn modules to prepare for quantization.
        Model is modified in place.  Note that this operation does not change numerics
        and the model after modification is in floating point
        """
        # Stem of the network first, then every residual block.
        fuse_modules(self, ['conv1', 'bn1', 'relu'], inplace=True)
        fusable_block_types = (QuantizableBottleneck, QuantizableBasicBlock)
        for m in self.modules():
            # Exact-type match: subclasses of the two block classes are not
            # fused by this loop.
            if type(m) in fusable_block_types:
                m.fuse_model()


118
119
120
121
122
123
124
125
126
127
def _resnet(
    arch: str,
    block: Type[Union[BasicBlock, Bottleneck]],
    layers: List[int],
    pretrained: bool,
    progress: bool,
    quantize: bool,
    **kwargs: Any,
) -> QuantizableResNet:
    r"""Build a :class:`QuantizableResNet`, optionally quantized and/or pretrained.

    Args:
        arch (str): Architecture name; used as the key (or key prefix) when
            looking up the checkpoint URL.
        block: Residual block class passed to ``QuantizableResNet``.
        layers (List[int]): Passed to ``QuantizableResNet`` together with
            ``block``.
        pretrained (bool): If True, download and load a checkpoint.
        progress (bool): Forwarded to ``load_state_dict_from_url``.
        quantize (bool): If True, run ``quantize_model`` on the model and use
            the quantized checkpoint table (``quant_model_urls``).
        **kwargs: Extra keyword arguments forwarded to ``QuantizableResNet``.

    Returns:
        QuantizableResNet: The constructed model.

    Raises:
        KeyError: If ``pretrained`` is set and no URL is registered for the
            requested architecture.
    """
    model = QuantizableResNet(block, layers, **kwargs)
    _replace_relu(model)
    if quantize:
        # TODO use pretrained as a string to specify the backend
        backend = 'fbgemm'
        # Runs before the checkpoint is loaded, so quantized weights are
        # loaded into the model returned by this step.
        quantize_model(model, backend)
    else:
        assert pretrained in [True, False]

    if pretrained:
        if quantize:
            # ``backend`` is always bound here: it is set in the quantize
            # branch above.
            model_url = quant_model_urls[arch + '_' + backend]
        else:
            model_url = model_urls[arch]

        state_dict = load_state_dict_from_url(model_url, progress=progress)
        model.load_state_dict(state_dict)
    return model


150
151
152
153
154
155
def resnet18(
    pretrained: bool = False,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableResNet:
    r"""ResNet-18 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, return a quantized version of the model
        **kwargs: Extra keyword arguments forwarded to ``_resnet``.

    Returns:
        QuantizableResNet: Model built from ``QuantizableBasicBlock`` with
        layer counts ``[2, 2, 2, 2]``.
    """
    return _resnet('resnet18', QuantizableBasicBlock, [2, 2, 2, 2], pretrained, progress,
                   quantize, **kwargs)


168
169
170
171
172
173
174
def resnet50(
    pretrained: bool = False,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableResNet:
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, return a quantized version of the model
        **kwargs: Extra keyword arguments forwarded to ``_resnet``.

    Returns:
        QuantizableResNet: Model built from ``QuantizableBottleneck`` with
        layer counts ``[3, 4, 6, 3]``.
    """
    return _resnet('resnet50', QuantizableBottleneck, [3, 4, 6, 3], pretrained, progress,
                   quantize, **kwargs)


187
188
189
190
191
192
def resnext101_32x8d(
    pretrained: bool = False,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableResNet:
    r"""ResNeXt-101 32x8d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, return a quantized version of the model
        **kwargs: Extra keyword arguments forwarded to ``_resnet``. Any
            ``groups`` or ``width_per_group`` entry is overwritten.

    Returns:
        QuantizableResNet: Model built from ``QuantizableBottleneck`` with
        layer counts ``[3, 4, 23, 3]``, ``groups=32`` and ``width_per_group=8``.
    """
    # These two settings define the 32x8d variant and always win over
    # caller-supplied values.
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 8
    return _resnet('resnext101_32x8d', QuantizableBottleneck, [3, 4, 23, 3],
                   pretrained, progress, quantize, **kwargs)