"docs/en/get_started/introduction.md" did not exist on "259ea5f63fb42c612331589bb66c539ac9df5fe4"
mobilenet.cpp 4.05 KB
Newer Older
Shahriar's avatar
Shahriar committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
#include "mobilenet.h"

#include <cassert>
#include <cmath>
#include <limits>
#include <vector>

#include "modelsimpl.h"

namespace vision {
namespace models {
using Options = torch::nn::Conv2dOptions;

// Conv2d -> BatchNorm -> ReLU6: the elementary MobileNetV2 building block.
struct ConvBNReLUImpl : torch::nn::SequentialImpl {
  ConvBNReLUImpl(
      int64_t in_planes,
      int64_t out_planes,
      int64_t kernel_size = 3,
      int64_t stride = 1,
      int64_t groups = 1) {
    // "Same" padding for odd kernel sizes.
    const auto pad = (kernel_size - 1) / 2;

    auto conv_options = Options(in_planes, out_planes, kernel_size)
                            .stride(stride)
                            .padding(pad)
                            .groups(groups)
                            .with_bias(false);  // BN provides the affine shift

    push_back(torch::nn::Conv2d(conv_options));
    push_back(torch::nn::BatchNorm(out_planes));
    push_back(torch::nn::Functional(modelsimpl::relu6_));
  }

  torch::Tensor forward(torch::Tensor x) {
    return torch::nn::SequentialImpl::forward(x);
  }
};

TORCH_MODULE(ConvBNReLU);

// Inverted-residual block: optional 1x1 expansion, 3x3 depthwise conv,
// then a linear (no activation) 1x1 projection; skip connection when the
// block preserves both stride and channel count.
struct MobileNetInvertedResidualImpl : torch::nn::Module {
  int64_t stride;
  bool use_res_connect;
  torch::nn::Sequential conv;

  MobileNetInvertedResidualImpl(
      int64_t input,
      int64_t output,
      int64_t stride,
      double expand_ratio)
      : stride(stride), use_res_connect(stride == 1 && input == output) {
    assert(stride == 1 || stride == 2);

    const auto hidden_dim = int64_t(std::round(input * expand_ratio));

    // Epsilon-tolerant check for expand_ratio != 1 (it arrives as a double).
    const bool has_expansion = !(
        std::abs(expand_ratio - 1.0) < std::numeric_limits<double>::epsilon());

    // 1x1 pointwise expansion — skipped when expand_ratio == 1, since the
    // depthwise conv below already operates on `input` channels.
    if (has_expansion)
      conv->push_back(ConvBNReLU(input, hidden_dim, 1));

    // 3x3 depthwise conv (groups == channels) followed by the linear
    // 1x1 projection back down to `output` channels.
    conv->push_back(ConvBNReLU(hidden_dim, hidden_dim, 3, stride, hidden_dim));
    conv->push_back(torch::nn::Conv2d(
        Options(hidden_dim, output, 1).stride(1).padding(0).with_bias(false)));
    conv->push_back(torch::nn::BatchNorm(output));

    register_module("conv", conv);
  }

  torch::Tensor forward(torch::Tensor x) {
    auto out = conv->forward(x);
    return use_res_connect ? x + out : out;
  }
};

TORCH_MODULE(MobileNetInvertedResidual);

// Builds MobileNetV2: a stride-2 stem conv, seven stages of inverted
// residual blocks, a 1x1 head conv, and a dropout+linear classifier.
// width_mult scales every channel count (last_channel only scales up).
MobileNetV2Impl::MobileNetV2Impl(int64_t num_classes, double width_mult) {
  using Block = MobileNetInvertedResidual;
  int64_t input_channel = 32;
  int64_t last_channel = 1280;

  // Per-stage settings: t = expansion ratio, c = output channels,
  // n = number of blocks in the stage, s = stride of the first block.
  std::vector<std::vector<int64_t>> inverted_residual_settings = {
      // t, c, n, s
      {1, 16, 1, 1},
      {6, 24, 2, 2},
      {6, 32, 3, 2},
      {6, 64, 4, 2},
      {6, 96, 3, 1},
      {6, 160, 3, 2},
      {6, 320, 1, 1},
  };

  input_channel = int64_t(input_channel * width_mult);
  // The head width is never shrunk below its default, only widened.
  this->last_channel = int64_t(last_channel * std::max(1.0, width_mult));

  // Stem: 3x3 stride-2 conv over the 3-channel input.
  features->push_back(ConvBNReLU(3, input_channel, 3, 2));

  // Iterate by const reference: each setting is a std::vector and the
  // previous by-value loop copied it on every iteration.
  for (const auto& setting : inverted_residual_settings) {
    auto output_channel = int64_t(setting[1] * width_mult);

    for (int64_t i = 0; i < setting[2]; ++i) {
      // Only the first block of a stage applies the stage stride.
      auto stride = i == 0 ? setting[3] : 1;
      features->push_back(
          Block(input_channel, output_channel, stride, setting[0]));
      input_channel = output_channel;
    }
  }

  // Head: 1x1 conv up to last_channel.
  features->push_back(ConvBNReLU(input_channel, this->last_channel, 1));

  classifier->push_back(torch::nn::Dropout(0.2));
  classifier->push_back(torch::nn::Linear(this->last_channel, num_classes));

  register_module("features", features);
  register_module("classifier", classifier);

  // Per-layer-type weight init: kaiming-normal (fan-out) for convs,
  // ones/zeros for batch-norm scale/shift, N(0, 0.01) for linear weights.
  for (auto& module : modules(/*include_self=*/false)) {
    if (auto M = dynamic_cast<torch::nn::Conv2dImpl*>(module.get())) {
      torch::nn::init::kaiming_normal_(
          M->weight, 0, torch::nn::init::FanMode::FanOut);
      if (M->options.with_bias())
        torch::nn::init::zeros_(M->bias);
    } else if (auto M = dynamic_cast<torch::nn::BatchNormImpl*>(module.get())) {
      torch::nn::init::ones_(M->weight);
      torch::nn::init::zeros_(M->bias);
    } else if (auto M = dynamic_cast<torch::nn::LinearImpl*>(module.get())) {
      torch::nn::init::normal_(M->weight, 0, 0.01);
      torch::nn::init::zeros_(M->bias);
    }
  }
}

// Feature extractor, global average pool over the spatial dims (2, 3),
// then the classifier head.
torch::Tensor MobileNetV2Impl::forward(at::Tensor x) {
  auto feature_maps = features->forward(x);
  auto pooled = feature_maps.mean({2, 3});
  return classifier->forward(pooled);
}

} // namespace models
} // namespace vision