#pragma once

#include <torch/nn.h>
#include "../macros.h"

namespace vision {
namespace models {
// Forward declaration so that the block types in _resnetimpl can name
// ResNetImpl in their friend declarations before its full definition, which
// appears later in this header.
template <typename Block>
struct ResNetImpl;

namespace _resnetimpl {
// 3x3 convolution with padding.
// Declaration only; the definition is not part of this header. `stride` and
// `groups` both default to 1.
// NOTE(review): `in` / `out` are presumably the input and output channel
// counts -- confirm against the definition.
torch::nn::Conv2d conv3x3(
    int64_t in,
    int64_t out,
    int64_t stride = 1,
    int64_t groups = 1);

// 1x1 convolution.
// Declaration only; the definition is not part of this header. `stride`
// defaults to 1. ResNetImpl::_make_layer uses this to build the first module
// of a block's downsample branch.
torch::nn::Conv2d conv1x1(int64_t in, int64_t out, int64_t stride = 1);

22
struct VISION_API BasicBlock : torch::nn::Module {
Shahriar's avatar
Shahriar committed
23
24
25
26
27
28
29
  template <typename Block>
  friend struct vision::models::ResNetImpl;

  int64_t stride;
  torch::nn::Sequential downsample;

  torch::nn::Conv2d conv1{nullptr}, conv2{nullptr};
30
  torch::nn::BatchNorm2d bn1{nullptr}, bn2{nullptr};
Shahriar's avatar
Shahriar committed
31
32
33
34
35
36
37

  static int expansion;

  BasicBlock(
      int64_t inplanes,
      int64_t planes,
      int64_t stride = 1,
38
      const torch::nn::Sequential& downsample = nullptr,
Shahriar's avatar
Shahriar committed
39
40
41
42
43
44
      int64_t groups = 1,
      int64_t base_width = 64);

  torch::Tensor forward(torch::Tensor x);
};

45
struct VISION_API Bottleneck : torch::nn::Module {
Shahriar's avatar
Shahriar committed
46
47
48
49
50
51
52
  template <typename Block>
  friend struct vision::models::ResNetImpl;

  int64_t stride;
  torch::nn::Sequential downsample;

  torch::nn::Conv2d conv1{nullptr}, conv2{nullptr}, conv3{nullptr};
53
  torch::nn::BatchNorm2d bn1{nullptr}, bn2{nullptr}, bn3{nullptr};
Shahriar's avatar
Shahriar committed
54
55
56
57
58
59
60

  static int expansion;

  Bottleneck(
      int64_t inplanes,
      int64_t planes,
      int64_t stride = 1,
61
      const torch::nn::Sequential& downsample = nullptr,
Shahriar's avatar
Shahriar committed
62
63
64
65
66
67
68
69
70
71
72
      int64_t groups = 1,
      int64_t base_width = 64);

  torch::Tensor forward(torch::Tensor X);
};
} // namespace _resnetimpl

// Generic ResNet module parameterised on the residual block type. `Block`
// must expose a static `expansion` member and be constructible from
// (inplanes, planes, stride, downsample, groups, base_width), which is how
// _make_layer uses it.
template <typename Block>
struct ResNetImpl : torch::nn::Module {
  // Must stay declared before the layer members: _make_layer reads these
  // while layer1..layer4 are being initialised, and members are initialised
  // in declaration order.
  int64_t groups, base_width, inplanes;
  torch::nn::Conv2d conv1;
  torch::nn::BatchNorm2d bn1;
  torch::nn::Sequential layer1, layer2, layer3, layer4;
  torch::nn::Linear fc;

  // Builds one stage made of `blocks` Block modules and advances `inplanes`
  // to planes * Block::expansion. Only the first block receives `stride`.
  torch::nn::Sequential _make_layer(
      int64_t planes,
      int64_t blocks,
      int64_t stride = 1);

  // `layers` supplies the number of blocks for each of the four stages
  // (entries 0..3 are read).
  explicit ResNetImpl(
      const std::vector<int>& layers,
      int64_t num_classes = 1000,
      bool zero_init_residual = false,
      int64_t groups = 1,
      int64_t width_per_group = 64);

  // Parameter spelled `x` to match the out-of-line definition below.
  torch::Tensor forward(torch::Tensor x);
};

// Builds one stage of the network as a Sequential of `blocks` Block modules.
//
// The first block is created with the requested `stride`. When that stride is
// not 1, or the current `inplanes` differs from planes * Block::expansion, the
// first block also receives a downsample branch (1x1 convolution followed by
// batch norm) producing planes * Block::expansion channels. The remaining
// blocks use stride 1 and no downsample branch.
//
// Side effect: `inplanes` is updated to planes * Block::expansion, so stages
// must be built in network order.
template <typename Block>
torch::nn::Sequential ResNetImpl<Block>::_make_layer(
    int64_t planes,
    int64_t blocks,
    int64_t stride) {
  torch::nn::Sequential downsample = nullptr;
  if (stride != 1 || inplanes != planes * Block::expansion) {
    downsample = torch::nn::Sequential(
        _resnetimpl::conv1x1(inplanes, planes * Block::expansion, stride),
        torch::nn::BatchNorm2d(planes * Block::expansion));
  }

  torch::nn::Sequential layers;
  layers->push_back(
      Block(inplanes, planes, stride, downsample, groups, base_width));

  inplanes = planes * Block::expansion;

  // The index has the same type as `blocks` so the comparison never mixes
  // integer widths.
  for (int64_t i = 1; i < blocks; ++i) {
    layers->push_back(Block(inplanes, planes, 1, nullptr, groups, base_width));
  }

  return layers;
}

// Constructs the network and initialises its weights.
//
// layers             number of blocks for each of the four stages; entries
//                    0..3 are read with bounds checking, so a vector with
//                    fewer than four entries throws std::out_of_range.
// num_classes        output size of the final linear layer.
// zero_init_residual when true, the weight of the last batch norm in every
//                    residual block is set to 0 after the general init pass.
// groups             forwarded to every Block.
// width_per_group    stored as `base_width` and forwarded to every Block.
//
// Members are initialised in declaration order, so groups / base_width /
// inplanes are already set when _make_layer runs for layer1..layer4, and each
// _make_layer call sees the `inplanes` left behind by the previous one.
template <typename Block>
ResNetImpl<Block>::ResNetImpl(
    const std::vector<int>& layers,
    int64_t num_classes,
    bool zero_init_residual,
    int64_t groups,
    int64_t width_per_group)
    : groups(groups),
      base_width(width_per_group),
      inplanes(64),
      conv1(
          torch::nn::Conv2dOptions(3, 64, 7).stride(2).padding(3).bias(false)),
      bn1(64),
      // at() rather than operator[]: a too-short `layers` raises
      // std::out_of_range instead of reading out of bounds.
      layer1(_make_layer(64, layers.at(0))),
      layer2(_make_layer(128, layers.at(1), 2)),
      layer3(_make_layer(256, layers.at(2), 2)),
      layer4(_make_layer(512, layers.at(3), 2)),
      fc(512 * Block::expansion, num_classes) {
  register_module("conv1", conv1);
  register_module("bn1", bn1);
  register_module("fc", fc);

  register_module("layer1", layer1);
  register_module("layer2", layer2);
  register_module("layer3", layer3);
  register_module("layer4", layer4);

  // General init pass: Kaiming-normal (fan-out, ReLU) for every Conv2d
  // weight, and weight = 1 / bias = 0 for every BatchNorm2d.
  for (auto& module : modules(/*include_self=*/false)) {
    if (auto* conv = dynamic_cast<torch::nn::Conv2dImpl*>(module.get())) {
      torch::nn::init::kaiming_normal_(
          conv->weight,
          /*a=*/0,
          torch::kFanOut,
          torch::kReLU);
    } else if (
        auto* bn = dynamic_cast<torch::nn::BatchNorm2dImpl*>(module.get())) {
      torch::nn::init::constant_(bn->weight, 1);
      torch::nn::init::constant_(bn->bias, 0);
    }
  }

  // Zero-initialize the last BN in each residual branch, so that the residual
  // branch starts with zeros, and each residual block behaves like an
  // identity. This improves the model by 0.2~0.3% according to
  // https://arxiv.org/abs/1706.02677
  if (zero_init_residual) {
    for (auto& module : modules(/*include_self=*/false)) {
      if (auto* bottleneck =
              dynamic_cast<_resnetimpl::Bottleneck*>(module.get())) {
        torch::nn::init::constant_(bottleneck->bn3->weight, 0);
      } else if (
          auto* basic = dynamic_cast<_resnetimpl::BasicBlock*>(module.get())) {
        torch::nn::init::constant_(basic->bn2->weight, 0);
      }
    }
  }
}

// Forward pass: stem convolution, batch norm with in-place ReLU, max pooling,
// the four stages in order, adaptive average pooling to 1x1, flattening to
// (x.size(0), -1), and finally the linear layer `fc`.
template <typename Block>
torch::Tensor ResNetImpl<Block>::forward(torch::Tensor x) {
  // Stem.
  x = bn1->forward(conv1->forward(x)).relu_();
  x = torch::max_pool2d(x, /*kernel_size=*/3, /*stride=*/2, /*padding=*/1);

  // Stages, applied in declaration order.
  for (auto* stage : {&layer1, &layer2, &layer3, &layer4}) {
    x = (*stage)->forward(x);
  }

  // Head: pool to 1x1, keep the leading dimension, flatten the rest.
  x = torch::adaptive_avg_pool2d(x, {1, 1});
  return fc->forward(x.reshape({x.size(0), -1}));
}

187
struct VISION_API ResNet18Impl : ResNetImpl<_resnetimpl::BasicBlock> {
188
189
190
  explicit ResNet18Impl(
      int64_t num_classes = 1000,
      bool zero_init_residual = false);
Shahriar's avatar
Shahriar committed
191
192
};

193
struct VISION_API ResNet34Impl : ResNetImpl<_resnetimpl::BasicBlock> {
194
195
196
  explicit ResNet34Impl(
      int64_t num_classes = 1000,
      bool zero_init_residual = false);
Shahriar's avatar
Shahriar committed
197
198
};

199
struct VISION_API ResNet50Impl : ResNetImpl<_resnetimpl::Bottleneck> {
200
201
202
  explicit ResNet50Impl(
      int64_t num_classes = 1000,
      bool zero_init_residual = false);
Shahriar's avatar
Shahriar committed
203
204
};

205
struct VISION_API ResNet101Impl : ResNetImpl<_resnetimpl::Bottleneck> {
206
207
208
  explicit ResNet101Impl(
      int64_t num_classes = 1000,
      bool zero_init_residual = false);
Shahriar's avatar
Shahriar committed
209
210
};

211
struct VISION_API ResNet152Impl : ResNetImpl<_resnetimpl::Bottleneck> {
212
213
214
  explicit ResNet152Impl(
      int64_t num_classes = 1000,
      bool zero_init_residual = false);
Shahriar's avatar
Shahriar committed
215
216
};

217
struct VISION_API ResNext50_32x4dImpl : ResNetImpl<_resnetimpl::Bottleneck> {
218
  explicit ResNext50_32x4dImpl(
Shahriar's avatar
Shahriar committed
219
220
221
222
      int64_t num_classes = 1000,
      bool zero_init_residual = false);
};

223
struct VISION_API ResNext101_32x8dImpl : ResNetImpl<_resnetimpl::Bottleneck> {
224
  explicit ResNext101_32x8dImpl(
Shahriar's avatar
Shahriar committed
225
226
227
228
      int64_t num_classes = 1000,
      bool zero_init_residual = false);
};

229
struct VISION_API WideResNet50_2Impl : ResNetImpl<_resnetimpl::Bottleneck> {
230
  explicit WideResNet50_2Impl(
231
232
233
234
235
      int64_t num_classes = 1000,
      bool zero_init_residual = false);
};

// Wide ResNet-101-2 implementation: ResNetImpl specialised on Bottleneck. The
// constructor is defined out of line and exposes neither `layers` nor
// `width_per_group`, so those are chosen there.
struct VISION_API WideResNet101_2Impl : ResNetImpl<_resnetimpl::Bottleneck> {
  explicit WideResNet101_2Impl(
      int64_t num_classes = 1000,
      bool zero_init_residual = false);
};

Shahriar's avatar
Shahriar committed
241
template <typename Block>
242
struct VISION_API ResNet : torch::nn::ModuleHolder<ResNetImpl<Block>> {
Shahriar's avatar
Shahriar committed
243
244
245
246
247
248
249
250
251
252
  using torch::nn::ModuleHolder<ResNetImpl<Block>>::ModuleHolder;
};

// Holder types for the concrete models: TORCH_MODULE(Name) declares `Name`
// as the module holder wrapping `NameImpl`.
TORCH_MODULE(ResNet18);
TORCH_MODULE(ResNet34);
TORCH_MODULE(ResNet50);
TORCH_MODULE(ResNet101);
TORCH_MODULE(ResNet152);
TORCH_MODULE(ResNext50_32x4d);
TORCH_MODULE(ResNext101_32x8d);
TORCH_MODULE(WideResNet50_2);
TORCH_MODULE(WideResNet101_2);

} // namespace models
} // namespace vision