#pragma once

#include <torch/nn.h>
#include "../macros.h"
#include "modelsimpl.h"

namespace vision {
namespace models {
// Forward declaration of the generic ResNet implementation so that the block
// types declared in _resnetimpl below can name it in their friend
// declarations before the full definition appears.
template <typename Block>
struct ResNetImpl;

namespace _resnetimpl {
// Factory for a 3x3 convolution with padding.
//   in / out : channel counts of the convolution.
//   stride   : defaults to 1.
//   groups   : defaults to 1.
// Only the declaration is visible here; the definition lives elsewhere.
torch::nn::Conv2d conv3x3(int64_t in,
                          int64_t out,
                          int64_t stride = 1,
                          int64_t groups = 1);

// Factory for a 1x1 convolution.
//   in / out : channel counts of the convolution.
//   stride   : defaults to 1.
// ResNetImpl::_make_layer uses it to build the downsample (shortcut) branch.
// Only the declaration is visible here; the definition lives elsewhere.
torch::nn::Conv2d conv1x1(
    int64_t in,
    int64_t out,
    int64_t stride = 1);

23
struct VISION_API BasicBlock : torch::nn::Module {
Shahriar's avatar
Shahriar committed
24
25
26
27
28
29
30
  template <typename Block>
  friend struct vision::models::ResNetImpl;

  int64_t stride;
  torch::nn::Sequential downsample;

  torch::nn::Conv2d conv1{nullptr}, conv2{nullptr};
31
  torch::nn::BatchNorm2d bn1{nullptr}, bn2{nullptr};
Shahriar's avatar
Shahriar committed
32
33
34
35
36
37
38

  static int expansion;

  BasicBlock(
      int64_t inplanes,
      int64_t planes,
      int64_t stride = 1,
39
      const torch::nn::Sequential& downsample = nullptr,
Shahriar's avatar
Shahriar committed
40
41
42
43
44
45
      int64_t groups = 1,
      int64_t base_width = 64);

  torch::Tensor forward(torch::Tensor x);
};

46
struct VISION_API Bottleneck : torch::nn::Module {
Shahriar's avatar
Shahriar committed
47
48
49
50
51
52
53
  template <typename Block>
  friend struct vision::models::ResNetImpl;

  int64_t stride;
  torch::nn::Sequential downsample;

  torch::nn::Conv2d conv1{nullptr}, conv2{nullptr}, conv3{nullptr};
54
  torch::nn::BatchNorm2d bn1{nullptr}, bn2{nullptr}, bn3{nullptr};
Shahriar's avatar
Shahriar committed
55
56
57
58
59
60
61

  static int expansion;

  Bottleneck(
      int64_t inplanes,
      int64_t planes,
      int64_t stride = 1,
62
      const torch::nn::Sequential& downsample = nullptr,
Shahriar's avatar
Shahriar committed
63
64
65
66
67
68
69
70
71
72
73
      int64_t groups = 1,
      int64_t base_width = 64);

  torch::Tensor forward(torch::Tensor X);
};
} // namespace _resnetimpl

// Generic ResNet module, parameterised on the residual block type
// (BasicBlock or Bottleneck from _resnetimpl).
//
// NOTE: member declaration order is significant. The constructor's
// initializer list builds layer1..layer4 through _make_layer(), which reads
// and updates groups / base_width / inplanes, so those three must be
// declared (and therefore initialised) before the layers.
template <typename Block>
struct ResNetImpl : torch::nn::Module {
  int64_t groups;
  int64_t base_width;
  int64_t inplanes; // running input-channel count, advanced by _make_layer()

  torch::nn::Conv2d conv1;
  torch::nn::BatchNorm2d bn1;

  torch::nn::Sequential layer1;
  torch::nn::Sequential layer2;
  torch::nn::Sequential layer3;
  torch::nn::Sequential layer4;

  torch::nn::Linear fc;

  // Builds one stage made of `blocks` instances of Block and advances
  // `inplanes` to planes * Block::expansion.
  torch::nn::Sequential _make_layer(int64_t planes,
                                    int64_t blocks,
                                    int64_t stride = 1);

  // layers             : block count per stage; entries 0..3 are consumed.
  // num_classes        : output size of the final Linear layer.
  // zero_init_residual : zero the last BatchNorm weight of every block.
  // groups             : stored in `groups` and forwarded to each Block.
  // width_per_group    : stored in `base_width` and forwarded to each Block.
  explicit ResNetImpl(const std::vector<int>& layers,
                      int64_t num_classes = 1000,
                      bool zero_init_residual = false,
                      int64_t groups = 1,
                      int64_t width_per_group = 64);

  torch::Tensor forward(torch::Tensor X);
};

// Builds one residual stage.
//
//   planes : base channel count of the stage; the stage outputs
//            planes * Block::expansion channels.
//   blocks : number of Block instances to stack. The first block is always
//            created, so values below 1 still yield a single block.
//   stride : stride of the first block; every following block uses stride 1.
//
// Side effect: `inplanes` is updated to the stage's output channel count so
// that the next call starts from the right input width.
template <typename Block>
torch::nn::Sequential ResNetImpl<Block>::_make_layer(
    int64_t planes,
    int64_t blocks,
    int64_t stride) {
  const int64_t out_planes = planes * Block::expansion;

  // A shortcut projection is only needed when the first block changes the
  // spatial resolution (stride != 1) or the channel count.
  torch::nn::Sequential downsample = nullptr;
  if (stride != 1 || inplanes != out_planes) {
    downsample = torch::nn::Sequential(
        _resnetimpl::conv1x1(inplanes, out_planes, stride),
        torch::nn::BatchNorm2d(out_planes));
  }

  torch::nn::Sequential layers;
  layers->push_back(
      Block(inplanes, planes, stride, downsample, groups, base_width));

  inplanes = out_planes;

  // The counter has the same type as `blocks`, so the comparison no longer
  // mixes int with int64_t.
  for (int64_t i = 1; i < blocks; ++i) {
    layers->push_back(Block(inplanes, planes, 1, nullptr, groups, base_width));
  }

  return layers;
}

// Constructs the full network: a convolution + batch-norm stem, four
// residual stages and a final Linear classifier, then initialises weights.
//
//   layers             : block count for each of the four stages. Accessed
//                        with at(), so a vector with fewer than four entries
//                        throws std::out_of_range instead of reading out of
//                        bounds.
//   num_classes        : output size of the final Linear layer.
//   zero_init_residual : when true, zero the weight of the last BatchNorm in
//                        every residual block.
//   groups             : stored and forwarded to every block.
//   width_per_group    : stored as base_width and forwarded to every block.
//
// The initializer list relies on the member declaration order: groups,
// base_width and inplanes must be set before _make_layer() runs.
template <typename Block>
ResNetImpl<Block>::ResNetImpl(
    const std::vector<int>& layers,
    int64_t num_classes,
    bool zero_init_residual,
    int64_t groups,
    int64_t width_per_group)
    : groups(groups),
      base_width(width_per_group),
      inplanes(64),
      conv1(
          torch::nn::Conv2dOptions(3, 64, 7).stride(2).padding(3).bias(false)),
      bn1(64),
      layer1(_make_layer(64, layers.at(0))),
      layer2(_make_layer(128, layers.at(1), 2)),
      layer3(_make_layer(256, layers.at(2), 2)),
      layer4(_make_layer(512, layers.at(3), 2)),
      fc(512 * Block::expansion, num_classes) {
  register_module("conv1", conv1);
  register_module("bn1", bn1);
  register_module("fc", fc);

  register_module("layer1", layer1);
  register_module("layer2", layer2);
  register_module("layer3", layer3);
  register_module("layer4", layer4);

  // Generic initialisation pass over every registered sub-module:
  // Kaiming-normal (fan-out, ReLU) for convolutions, weight = 1 / bias = 0
  // for batch-norm layers.
  for (const auto& module : modules(/*include_self=*/false)) {
    if (auto* conv = dynamic_cast<torch::nn::Conv2dImpl*>(module.get())) {
      torch::nn::init::kaiming_normal_(
          conv->weight,
          /*a=*/0,
          torch::kFanOut,
          torch::kReLU);
    } else if (
        auto* bn = dynamic_cast<torch::nn::BatchNorm2dImpl*>(module.get())) {
      torch::nn::init::constant_(bn->weight, 1);
      torch::nn::init::constant_(bn->bias, 0);
    }
  }

  // Zero-initialize the last BN in each residual branch, so that the residual
  // branch starts with zeros, and each residual block behaves like an
  // identity. This improves the model by 0.2~0.3% according to
  // https://arxiv.org/abs/1706.02677
  //
  // This has to stay a separate, second pass: the loop above sets every
  // BatchNorm weight to 1 and would otherwise overwrite the zeros.
  if (zero_init_residual) {
    for (const auto& module : modules(/*include_self=*/false)) {
      if (auto* bottleneck =
              dynamic_cast<_resnetimpl::Bottleneck*>(module.get())) {
        torch::nn::init::constant_(bottleneck->bn3->weight, 0);
      } else if (
          auto* basic = dynamic_cast<_resnetimpl::BasicBlock*>(module.get())) {
        torch::nn::init::constant_(basic->bn2->weight, 0);
      }
    }
  }

  modelsimpl::deprecation_warning();
}

// Runs the network on `x` and returns the output of the final Linear layer.
template <typename Block>
torch::Tensor ResNetImpl<Block>::forward(torch::Tensor x) {
  // Stem: convolution -> batch norm -> in-place ReLU -> max pooling.
  x = bn1->forward(conv1->forward(x)).relu_();
  x = torch::max_pool2d(x, 3, 2, 1);

  // The four residual stages, applied in order.
  x = layer1->forward(x);
  x = layer2->forward(x);
  x = layer3->forward(x);
  x = layer4->forward(x);

  // Pool down to 1x1, then flatten everything but the leading dimension
  // before the classifier.
  x = torch::adaptive_avg_pool2d(x, {1, 1});
  x = x.reshape({x.size(0), -1});

  return fc->forward(x);
}

190
// ResNetImpl specialised on BasicBlock. Only the constructor is declared
// here; the per-stage block counts it passes to ResNetImpl are chosen in its
// definition, which is not part of this header.
struct VISION_API ResNet18Impl : ResNetImpl<_resnetimpl::BasicBlock> {
  explicit ResNet18Impl(int64_t num_classes = 1000,
                        bool zero_init_residual = false);
};

196
// ResNetImpl specialised on BasicBlock. Only the constructor is declared
// here; the per-stage block counts it passes to ResNetImpl are chosen in its
// definition, which is not part of this header.
struct VISION_API ResNet34Impl : ResNetImpl<_resnetimpl::BasicBlock> {
  explicit ResNet34Impl(int64_t num_classes = 1000,
                        bool zero_init_residual = false);
};

202
// ResNetImpl specialised on Bottleneck. Only the constructor is declared
// here; the per-stage block counts it passes to ResNetImpl are chosen in its
// definition, which is not part of this header.
struct VISION_API ResNet50Impl : ResNetImpl<_resnetimpl::Bottleneck> {
  explicit ResNet50Impl(int64_t num_classes = 1000,
                        bool zero_init_residual = false);
};

208
// ResNetImpl specialised on Bottleneck. Only the constructor is declared
// here; the per-stage block counts it passes to ResNetImpl are chosen in its
// definition, which is not part of this header.
struct VISION_API ResNet101Impl : ResNetImpl<_resnetimpl::Bottleneck> {
  explicit ResNet101Impl(int64_t num_classes = 1000,
                         bool zero_init_residual = false);
};

214
// ResNetImpl specialised on Bottleneck. Only the constructor is declared
// here; the per-stage block counts it passes to ResNetImpl are chosen in its
// definition, which is not part of this header.
struct VISION_API ResNet152Impl : ResNetImpl<_resnetimpl::Bottleneck> {
  explicit ResNet152Impl(int64_t num_classes = 1000,
                         bool zero_init_residual = false);
};

220
// ResNetImpl specialised on Bottleneck. Only the constructor is declared
// here. NOTE(review): the name suggests non-default groups /
// width_per_group are passed to ResNetImpl -- confirm in the definition.
struct VISION_API ResNext50_32x4dImpl : ResNetImpl<_resnetimpl::Bottleneck> {
  explicit ResNext50_32x4dImpl(int64_t num_classes = 1000,
                               bool zero_init_residual = false);
};

226
// ResNetImpl specialised on Bottleneck. Only the constructor is declared
// here. NOTE(review): the name suggests non-default groups /
// width_per_group are passed to ResNetImpl -- confirm in the definition.
struct VISION_API ResNext101_32x8dImpl : ResNetImpl<_resnetimpl::Bottleneck> {
  explicit ResNext101_32x8dImpl(int64_t num_classes = 1000,
                                bool zero_init_residual = false);
};

232
// ResNetImpl specialised on Bottleneck. Only the constructor is declared
// here. NOTE(review): the name suggests a non-default width_per_group is
// passed to ResNetImpl -- confirm in the definition.
struct VISION_API WideResNet50_2Impl : ResNetImpl<_resnetimpl::Bottleneck> {
  explicit WideResNet50_2Impl(int64_t num_classes = 1000,
                              bool zero_init_residual = false);
};

// ResNetImpl specialised on Bottleneck. Only the constructor is declared
// here. NOTE(review): the name suggests a non-default width_per_group is
// passed to ResNetImpl -- confirm in the definition.
struct VISION_API WideResNet101_2Impl : ResNetImpl<_resnetimpl::Bottleneck> {
  explicit WideResNet101_2Impl(int64_t num_classes = 1000,
                               bool zero_init_residual = false);
};

Shahriar's avatar
Shahriar committed
244
template <typename Block>
245
struct VISION_API ResNet : torch::nn::ModuleHolder<ResNetImpl<Block>> {
Shahriar's avatar
Shahriar committed
246
247
248
249
250
251
252
253
254
255
  using torch::nn::ModuleHolder<ResNetImpl<Block>>::ModuleHolder;
};

// TORCH_MODULE(Name) is the libtorch macro that defines a holder class `Name`
// wrapping the corresponding `NameImpl` struct declared above.
TORCH_MODULE(ResNet18);
TORCH_MODULE(ResNet34);
TORCH_MODULE(ResNet50);
TORCH_MODULE(ResNet101);
TORCH_MODULE(ResNet152);
TORCH_MODULE(ResNext50_32x4d);
TORCH_MODULE(ResNext101_32x8d);
256
257
// Holder classes for the wide ResNet variants (wrap the matching *Impl
// structs declared above).
TORCH_MODULE(WideResNet50_2);
TORCH_MODULE(WideResNet101_2);
Shahriar's avatar
Shahriar committed
258
259
260

} // namespace models
} // namespace vision