resnet.h 7.04 KB
Newer Older
Shahriar's avatar
Shahriar committed
1
2
3
4
#ifndef RESNET_H
#define RESNET_H

#include <torch/torch.h>
5
#include "general.h"
Shahriar's avatar
Shahriar committed
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22

namespace vision {
namespace models {
// Forward declaration of the ResNet implementation template. It is needed
// because the residual block types declared below name every ResNetImpl
// instantiation as a friend before the template itself is defined.
template <typename Block>
struct ResNetImpl;

namespace _resnetimpl {
// 3x3 convolution with padding. Only declared here; the definition lives
// outside this header. `stride` and `groups` both default to 1.
// NOTE(review): `in`/`out` are presumably the input/output channel counts --
// confirm against the definition.
torch::nn::Conv2d conv3x3(
    int64_t in, int64_t out,
    int64_t stride = 1, int64_t groups = 1);

// 1x1 convolution. Only declared here; `stride` defaults to 1.
torch::nn::Conv2d conv1x1(
    int64_t in,
    int64_t out,
    int64_t stride = 1);

23
struct VISION_API BasicBlock : torch::nn::Module {
Shahriar's avatar
Shahriar committed
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
  template <typename Block>
  friend struct vision::models::ResNetImpl;

  int64_t stride;
  torch::nn::Sequential downsample;

  torch::nn::Conv2d conv1{nullptr}, conv2{nullptr};
  torch::nn::BatchNorm bn1{nullptr}, bn2{nullptr};

  static int expansion;

  BasicBlock(
      int64_t inplanes,
      int64_t planes,
      int64_t stride = 1,
      torch::nn::Sequential downsample = nullptr,
      int64_t groups = 1,
      int64_t base_width = 64);

  torch::Tensor forward(torch::Tensor x);
};

46
struct VISION_API Bottleneck : torch::nn::Module {
Shahriar's avatar
Shahriar committed
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
  template <typename Block>
  friend struct vision::models::ResNetImpl;

  int64_t stride;
  torch::nn::Sequential downsample;

  torch::nn::Conv2d conv1{nullptr}, conv2{nullptr}, conv3{nullptr};
  torch::nn::BatchNorm bn1{nullptr}, bn2{nullptr}, bn3{nullptr};

  static int expansion;

  Bottleneck(
      int64_t inplanes,
      int64_t planes,
      int64_t stride = 1,
      torch::nn::Sequential downsample = nullptr,
      int64_t groups = 1,
      int64_t base_width = 64);

  torch::Tensor forward(torch::Tensor X);
};
} // namespace _resnetimpl

// Generic ResNet module parameterised on the residual block type. `Block`
// must provide a static `expansion` member and a constructor taking
// (inplanes, planes, stride, downsample, groups, base_width).
//
// NOTE: the member declaration order below is load-bearing. The constructor
// builds layer1..layer4 through _make_layer(), which reads `groups`,
// `base_width` and `inplanes`; members are initialised in declaration order,
// so those three must stay ahead of the layers.
template <typename Block>
struct ResNetImpl : torch::nn::Module {
  int64_t groups, base_width, inplanes;
  torch::nn::Conv2d conv1;
  torch::nn::BatchNorm bn1;
  torch::nn::Sequential layer1, layer2, layer3, layer4;
  torch::nn::Linear fc;

  // Builds one stage of `blocks` residual blocks and advances `inplanes` to
  // `planes * Block::expansion`.
  torch::nn::Sequential _make_layer(
      int64_t planes,
      int64_t blocks,
      int64_t stride = 1);

  // `layers` holds the block count for each of the four stages (entries
  // 0..3 are read). When `zero_init_residual` is set, the last batch-norm
  // weight of every residual block is zero-initialised.
  ResNetImpl(
      const std::vector<int>& layers,
      int64_t num_classes = 1000,
      bool zero_init_residual = false,
      int64_t groups = 1,
      int64_t width_per_group = 64);

  // Runs the network on `X` and returns the output of the final linear
  // layer.
  torch::Tensor forward(torch::Tensor X);
};

// Builds one ResNet stage: a Sequential of `blocks` instances of `Block`.
//
// The first block receives `stride` and, when the stride is not 1 or the
// channel count changes, a downsample path made of a 1x1 convolution followed
// by batch norm. The remaining blocks use stride 1 and no downsample.
//
// Side effect: `inplanes` is updated to `planes * Block::expansion`, so
// successive calls chain the channel counts from one stage to the next.
template <typename Block>
torch::nn::Sequential ResNetImpl<Block>::_make_layer(
    int64_t planes,
    int64_t blocks,
    int64_t stride) {
  // Output channel count of every block in this stage.
  const int64_t out_planes = planes * Block::expansion;

  torch::nn::Sequential downsample = nullptr;
  if (stride != 1 || inplanes != out_planes) {
    downsample = torch::nn::Sequential(
        _resnetimpl::conv1x1(inplanes, out_planes, stride),
        torch::nn::BatchNorm(out_planes));
  }

  torch::nn::Sequential layers;
  layers->push_back(
      Block(inplanes, planes, stride, downsample, groups, base_width));

  inplanes = out_planes;

  // The counter uses the same type as `blocks` so the comparison never mixes
  // integer widths.
  for (int64_t i = 1; i < blocks; ++i) {
    layers->push_back(Block(inplanes, planes, 1, nullptr, groups, base_width));
  }

  return layers;
}

// Constructs the network: a 7x7 stride-2 stem convolution with batch norm,
// four residual stages of 64/128/256/512 planes, and a linear layer with
// `num_classes` outputs.
//
// `layers` must hold at least four entries (block counts for stages 1-4).
// It is read with at(), so a shorter vector throws std::out_of_range instead
// of reading out of bounds.
//
// The stages are built in the initializer list through _make_layer(), which
// reads `groups`, `base_width` and `inplanes`; this relies on those members
// being declared (and therefore initialised) before layer1..layer4.
template <typename Block>
ResNetImpl<Block>::ResNetImpl(
    const std::vector<int>& layers,
    int64_t num_classes,
    bool zero_init_residual,
    int64_t groups,
    int64_t width_per_group)
    : groups(groups),
      base_width(width_per_group),
      inplanes(64),
      conv1(torch::nn::Conv2dOptions(3, 64, 7).stride(2).padding(3).with_bias(
          false)),
      bn1(64),
      layer1(_make_layer(64, layers.at(0))),
      layer2(_make_layer(128, layers.at(1), 2)),
      layer3(_make_layer(256, layers.at(2), 2)),
      layer4(_make_layer(512, layers.at(3), 2)),
      fc(512 * Block::expansion, num_classes) {
  register_module("conv1", conv1);
  register_module("bn1", bn1);
  register_module("fc", fc);

  register_module("layer1", layer1);
  register_module("layer2", layer2);
  register_module("layer3", layer3);
  register_module("layer4", layer4);

  // Weight initialisation over every registered sub-module: Kaiming-normal
  // (fan-out, ReLU) for convolutions, weight 1 and bias 0 for batch norms.
  for (auto& module : modules(/*include_self=*/false)) {
    if (auto* conv = dynamic_cast<torch::nn::Conv2dImpl*>(module.get())) {
      torch::nn::init::kaiming_normal_(
          conv->weight,
          /*a=*/0,
          torch::nn::init::FanMode::FanOut,
          torch::nn::init::Nonlinearity::ReLU);
    } else if (
        auto* bn = dynamic_cast<torch::nn::BatchNormImpl*>(module.get())) {
      torch::nn::init::constant_(bn->weight, 1);
      torch::nn::init::constant_(bn->bias, 0);
    }
  }

  // Zero-initialize the last BN in each residual branch, so that the residual
  // branch starts with zeros, and each residual block behaves like an
  // identity. This improves the model by 0.2~0.3% according to
  // https://arxiv.org/abs/1706.02677
  if (zero_init_residual) {
    for (auto& module : modules(/*include_self=*/false)) {
      if (auto* bottleneck =
              dynamic_cast<_resnetimpl::Bottleneck*>(module.get())) {
        torch::nn::init::constant_(bottleneck->bn3->weight, 0);
      } else if (
          auto* basic = dynamic_cast<_resnetimpl::BasicBlock*>(module.get())) {
        torch::nn::init::constant_(basic->bn2->weight, 0);
      }
    }
  }
}

// Forward pass: stem (conv -> batch norm -> in-place ReLU -> 3x3 max pool
// with stride 2 and padding 1), the four residual stages in order, adaptive
// average pooling to 1x1, flattening to (batch, -1), and the final linear
// layer. Returns the output of `fc`.
template <typename Block>
torch::Tensor ResNetImpl<Block>::forward(torch::Tensor x) {
  // Stem.
  x = bn1->forward(conv1->forward(x)).relu_();
  x = torch::max_pool2d(x, /*kernel_size=*/3, /*stride=*/2, /*padding=*/1);

  // Residual stages.
  x = layer1->forward(x);
  x = layer2->forward(x);
  x = layer3->forward(x);
  x = layer4->forward(x);

  // Head: global pooling, keep dimension 0 and flatten the rest.
  x = torch::adaptive_avg_pool2d(x, {1, 1});
  x = x.reshape({x.size(0), -1});
  return fc->forward(x);
}

188
// ResNetImpl instantiated with BasicBlock. The constructor is defined out of
// line, so the per-stage block counts are not visible in this header.
struct VISION_API ResNet18Impl : ResNetImpl<_resnetimpl::BasicBlock> {
  ResNet18Impl(int64_t num_classes = 1000, bool zero_init_residual = false);
};

192
// ResNetImpl instantiated with BasicBlock. The constructor is defined out of
// line, so the per-stage block counts are not visible in this header.
struct VISION_API ResNet34Impl : ResNetImpl<_resnetimpl::BasicBlock> {
  ResNet34Impl(int64_t num_classes = 1000, bool zero_init_residual = false);
};

196
// ResNetImpl instantiated with Bottleneck. The constructor is defined out of
// line, so the per-stage block counts are not visible in this header.
struct VISION_API ResNet50Impl : ResNetImpl<_resnetimpl::Bottleneck> {
  ResNet50Impl(int64_t num_classes = 1000, bool zero_init_residual = false);
};

200
// ResNetImpl instantiated with Bottleneck. The constructor is defined out of
// line, so the per-stage block counts are not visible in this header.
struct VISION_API ResNet101Impl : ResNetImpl<_resnetimpl::Bottleneck> {
  ResNet101Impl(int64_t num_classes = 1000, bool zero_init_residual = false);
};

204
// ResNetImpl instantiated with Bottleneck. The constructor is defined out of
// line, so the per-stage block counts are not visible in this header.
struct VISION_API ResNet152Impl : ResNetImpl<_resnetimpl::Bottleneck> {
  ResNet152Impl(int64_t num_classes = 1000, bool zero_init_residual = false);
};

208
// ResNetImpl instantiated with Bottleneck. The constructor is defined out of
// line; NOTE(review): the name suggests groups=32 / width_per_group=4 are
// passed to the base class -- confirm against the definition.
struct VISION_API ResNext50_32x4dImpl : ResNetImpl<_resnetimpl::Bottleneck> {
  ResNext50_32x4dImpl(
      int64_t num_classes = 1000,
      bool zero_init_residual = false);
};

214
// ResNetImpl instantiated with Bottleneck. The constructor is defined out of
// line; NOTE(review): the name suggests groups=32 / width_per_group=8 are
// passed to the base class -- confirm against the definition.
struct VISION_API ResNext101_32x8dImpl : ResNetImpl<_resnetimpl::Bottleneck> {
  ResNext101_32x8dImpl(
      int64_t num_classes = 1000,
      bool zero_init_residual = false);
};

220
221
222
223
224
225
226
227
228
229
230
231
// ResNetImpl instantiated with Bottleneck. The constructor is defined out of
// line, so the configuration it forwards to the base class is not visible in
// this header.
struct VISION_API WideResNet50_2Impl : ResNetImpl<_resnetimpl::Bottleneck> {
  WideResNet50_2Impl(int64_t num_classes = 1000,
                     bool zero_init_residual = false);
};

// ResNetImpl instantiated with Bottleneck. The constructor is defined out of
// line, so the configuration it forwards to the base class is not visible in
// this header.
struct VISION_API WideResNet101_2Impl : ResNetImpl<_resnetimpl::Bottleneck> {
  WideResNet101_2Impl(int64_t num_classes = 1000,
                      bool zero_init_residual = false);
};

Shahriar's avatar
Shahriar committed
232
template <typename Block>
233
struct VISION_API ResNet : torch::nn::ModuleHolder<ResNetImpl<Block>> {
Shahriar's avatar
Shahriar committed
234
235
236
237
238
239
240
241
242
243
  using torch::nn::ModuleHolder<ResNetImpl<Block>>::ModuleHolder;
};

// Module holders for the concrete architectures: each TORCH_MODULE(Name)
// line pairs with the `NameImpl` struct declared above.
TORCH_MODULE(ResNet18);
TORCH_MODULE(ResNet34);
TORCH_MODULE(ResNet50);
TORCH_MODULE(ResNet101);
TORCH_MODULE(ResNet152);
TORCH_MODULE(ResNext50_32x4d);
TORCH_MODULE(ResNext101_32x8d);
TORCH_MODULE(WideResNet50_2);
TORCH_MODULE(WideResNet101_2);
Shahriar's avatar
Shahriar committed
246
247
248
249
250

} // namespace models
} // namespace vision

#endif // RESNET_H