// MobileNetV2 model implementation.
#include "mobilenet.h"

#include "modelsimpl.h"

namespace vision {
namespace models {
using Options = torch::nn::Conv2dOptions;

9
10
11
12
13
14
15
16
17
18
19
20
21
int64_t make_divisible(
    double value,
    int64_t divisor,
    c10::optional<int64_t> min_value = {}) {
  if (!min_value.has_value())
    min_value = divisor;
  auto new_value = std::max(
      min_value.value(), (int64_t(value + divisor / 2) / divisor) * divisor);
  if (new_value < .9 * value)
    new_value += divisor;
  return new_value;
}

Shahriar's avatar
Shahriar committed
22
23
24
25
26
27
28
29
30
31
32
33
34
struct ConvBNReLUImpl : torch::nn::SequentialImpl {
  ConvBNReLUImpl(
      int64_t in_planes,
      int64_t out_planes,
      int64_t kernel_size = 3,
      int64_t stride = 1,
      int64_t groups = 1) {
    auto padding = (kernel_size - 1) / 2;

    push_back(torch::nn::Conv2d(Options(in_planes, out_planes, kernel_size)
                                    .stride(stride)
                                    .padding(padding)
                                    .groups(groups)
35
                                    .bias(false)));
36
    push_back(torch::nn::BatchNorm2d(out_planes));
Shahriar's avatar
Shahriar committed
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
    push_back(torch::nn::Functional(modelsimpl::relu6_));
  }

  torch::Tensor forward(torch::Tensor x) {
    return torch::nn::SequentialImpl::forward(x);
  }
};

TORCH_MODULE(ConvBNReLU);

struct MobileNetInvertedResidualImpl : torch::nn::Module {
  int64_t stride;
  bool use_res_connect;
  torch::nn::Sequential conv;

  MobileNetInvertedResidualImpl(
      int64_t input,
      int64_t output,
      int64_t stride,
      double expand_ratio)
      : stride(stride), use_res_connect(stride == 1 && input == output) {
    auto double_compare = [](double a, double b) {
      return double(std::abs(a - b)) < std::numeric_limits<double>::epsilon();
    };

62
    TORCH_CHECK(stride == 1 || stride == 2);
Shahriar's avatar
Shahriar committed
63
64
65
66
67
68
69
    auto hidden_dim = int64_t(std::round(input * expand_ratio));

    if (!double_compare(expand_ratio, 1))
      conv->push_back(ConvBNReLU(input, hidden_dim, 1));

    conv->push_back(ConvBNReLU(hidden_dim, hidden_dim, 3, stride, hidden_dim));
    conv->push_back(torch::nn::Conv2d(
70
        Options(hidden_dim, output, 1).stride(1).padding(0).bias(false)));
71
    conv->push_back(torch::nn::BatchNorm2d(output));
Shahriar's avatar
Shahriar committed
72
73
74
75
76
77
78
79
80
81
82
83
84

    register_module("conv", conv);
  }

  torch::Tensor forward(torch::Tensor x) {
    if (use_res_connect)
      return x + conv->forward(x);
    return conv->forward(x);
  }
};

TORCH_MODULE(MobileNetInvertedResidual);

85
86
87
88
89
MobileNetV2Impl::MobileNetV2Impl(
    int64_t num_classes,
    double width_mult,
    std::vector<std::vector<int64_t>> inverted_residual_settings,
    int64_t round_nearest) {
Shahriar's avatar
Shahriar committed
90
91
92
93
  using Block = MobileNetInvertedResidual;
  int64_t input_channel = 32;
  int64_t last_channel = 1280;

94
95
96
97
98
99
100
101
102
103
104
105
  if (inverted_residual_settings.empty())
    inverted_residual_settings = {
        // t, c, n, s
        {1, 16, 1, 1},
        {6, 24, 2, 2},
        {6, 32, 3, 2},
        {6, 64, 4, 2},
        {6, 96, 3, 1},
        {6, 160, 3, 2},
        {6, 320, 1, 1},
    };

106
107
108
  TORCH_CHECK(
      inverted_residual_settings[0].size() == 4,
      "inverted_residual_settings should contain 4-element vectors");
109
110
111
112

  input_channel = make_divisible(input_channel * width_mult, round_nearest);
  this->last_channel =
      make_divisible(last_channel * std::max(1.0, width_mult), round_nearest);
Shahriar's avatar
Shahriar committed
113
114
115
  features->push_back(ConvBNReLU(3, input_channel, 3, 2));

  for (auto setting : inverted_residual_settings) {
116
117
    auto output_channel =
        make_divisible(setting[1] * width_mult, round_nearest);
Shahriar's avatar
Shahriar committed
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136

    for (int64_t i = 0; i < setting[2]; ++i) {
      auto stride = i == 0 ? setting[3] : 1;
      features->push_back(
          Block(input_channel, output_channel, stride, setting[0]));
      input_channel = output_channel;
    }
  }

  features->push_back(ConvBNReLU(input_channel, this->last_channel, 1));

  classifier->push_back(torch::nn::Dropout(0.2));
  classifier->push_back(torch::nn::Linear(this->last_channel, num_classes));

  register_module("features", features);
  register_module("classifier", classifier);

  for (auto& module : modules(/*include_self=*/false)) {
    if (auto M = dynamic_cast<torch::nn::Conv2dImpl*>(module.get())) {
Francisco Massa's avatar
Francisco Massa committed
137
      torch::nn::init::kaiming_normal_(M->weight, 0, torch::kFanOut);
138
      if (M->options.bias())
Shahriar's avatar
Shahriar committed
139
        torch::nn::init::zeros_(M->bias);
Francisco Massa's avatar
Francisco Massa committed
140
141
    } else if (
        auto M = dynamic_cast<torch::nn::BatchNorm2dImpl*>(module.get())) {
Shahriar's avatar
Shahriar committed
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
      torch::nn::init::ones_(M->weight);
      torch::nn::init::zeros_(M->bias);
    } else if (auto M = dynamic_cast<torch::nn::LinearImpl*>(module.get())) {
      torch::nn::init::normal_(M->weight, 0, 0.01);
      torch::nn::init::zeros_(M->bias);
    }
  }
}

// Extract features, global-average-pool over the spatial dims (H, W at
// indices 2 and 3), and run the pooled vector through the classifier.
torch::Tensor MobileNetV2Impl::forward(at::Tensor x) {
  auto pooled = features->forward(x).mean({2, 3});
  return classifier->forward(pooled);
}

} // namespace models
} // namespace vision