// shufflenetv2.cpp
#include "shufflenetv2.h"

#include "modelsimpl.h"

namespace vision {
namespace models {

using Options = torch::nn::Conv2dOptions;

// Channel shuffle: splits the channel dimension (dim 1) of an
// (N, C, H, W) tensor into `groups` groups and interleaves them. In this file
// it is applied after concatenating the two branch outputs of an inverted
// residual unit, so channels from both branches get mixed.
//
// x      : 4-D tensor, indexed as (batch, channels, height, width).
// groups : number of channel groups; must be positive and divide channels.
// Returns a tensor with the same shape as `x` whose channels are permuted.
torch::Tensor channel_shuffle(torch::Tensor x, int64_t groups) {
  // x.sizes()[i] is unchecked, so validate rank before reading dims 0..3.
  TORCH_CHECK(
      x.dim() == 4,
      "channel_shuffle expects a 4-D (N, C, H, W) tensor, got ",
      x.dim(),
      " dims");
  // Guards the integer division below against division by zero.
  TORCH_CHECK(
      groups > 0, "channel_shuffle: groups must be positive, got ", groups);

  const auto batchsize = x.size(0);
  const auto num_channels = x.size(1);
  const auto height = x.size(2);
  const auto width = x.size(3);

  TORCH_CHECK(
      num_channels % groups == 0,
      "channel_shuffle: channels (",
      num_channels,
      ") not divisible by groups (",
      groups,
      ")");

  const auto channels_per_group = num_channels / groups;

  x = x.view({batchsize, groups, channels_per_group, height, width});
  // Swap the group and per-group axes; contiguous() materializes the
  // permutation so the flattening view below is legal.
  x = torch::transpose(x, 1, 2).contiguous();
  return x.view({batchsize, -1, height, width});
}

// Pointwise convolution factory: a 1x1 Conv2d from `input` to `output`
// channels with stride 1, zero padding and no bias term.
torch::nn::Conv2d conv11(int64_t input, int64_t output) {
  return torch::nn::Conv2d(Options(input, output, /*kernel_size=*/1)
                               .stride(1)
                               .padding(0)
                               .bias(false));
}

// 3x3 grouped convolution factory: groups == `input` channels, padding 1,
// caller-supplied stride, no bias. The call sites in this file pass
// output == input, which makes it a depthwise convolution.
torch::nn::Conv2d conv33(int64_t input, int64_t output, int64_t stride) {
  auto depthwise_opts = Options(input, output, /*kernel_size=*/3)
                            .stride(stride)
                            .padding(1)
                            .bias(false)
                            .groups(input);
  return torch::nn::Conv2d(depthwise_opts);
}

struct ShuffleNetV2InvertedResidualImpl : torch::nn::Module {
  int64_t stride;
  torch::nn::Sequential branch1{nullptr}, branch2{nullptr};

  ShuffleNetV2InvertedResidualImpl(int64_t inp, int64_t oup, int64_t stride)
      : stride(stride) {
44
    TORCH_CHECK(stride >= 1 && stride <= 3, "illegal stride value");
Shahriar's avatar
Shahriar committed
45
46

    auto branch_features = oup / 2;
47
    TORCH_CHECK(stride != 1 || inp == branch_features << 1);
Shahriar's avatar
Shahriar committed
48
49
50
51

    if (stride > 1) {
      branch1 = torch::nn::Sequential(
          conv33(inp, inp, stride),
52
          torch::nn::BatchNorm2d(inp),
Shahriar's avatar
Shahriar committed
53
          conv11(inp, branch_features),
54
          torch::nn::BatchNorm2d(branch_features),
Shahriar's avatar
Shahriar committed
55
56
57
58
59
          torch::nn::Functional(modelsimpl::relu_));
    }

    branch2 = torch::nn::Sequential(
        conv11(stride > 1 ? inp : branch_features, branch_features),
60
        torch::nn::BatchNorm2d(branch_features),
Shahriar's avatar
Shahriar committed
61
62
        torch::nn::Functional(modelsimpl::relu_),
        conv33(branch_features, branch_features, stride),
63
        torch::nn::BatchNorm2d(branch_features),
Shahriar's avatar
Shahriar committed
64
        conv11(branch_features, branch_features),
65
        torch::nn::BatchNorm2d(branch_features),
Shahriar's avatar
Shahriar committed
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
        torch::nn::Functional(modelsimpl::relu_));

    if (!branch1.is_empty())
      register_module("branch1", branch1);

    register_module("branch2", branch2);
  }

  torch::Tensor forward(torch::Tensor x) {
    torch::Tensor out;

    if (stride == 1) {
      auto chunks = x.chunk(2, 1);
      out = torch::cat({chunks[0], branch2->forward(chunks[1])}, 1);
    } else
      out = torch::cat({branch1->forward(x), branch2->forward(x)}, 1);

83
    out = ::vision::models::channel_shuffle(out, 2);
Shahriar's avatar
Shahriar committed
84
85
86
87
88
89
90
91
92
93
    return out;
  }
};

TORCH_MODULE(ShuffleNetV2InvertedResidual);

// Builds a ShuffleNetV2 classifier:
//   conv1 -> (max pool in forward) -> stage2 -> stage3 -> stage4 -> conv5 -> fc
//
// stage_repeats      : 3 positive ints, the number of inverted-residual units
//                      in stage2/3/4. The first unit of each stage uses
//                      stride 2, the remaining ones stride 1.
// stage_out_channels : 5 positive ints: output channels of conv1, of
//                      stage2/3/4 (these must be even because every unit
//                      splits its output across two branches), and of conv5.
// num_classes        : output features of the final linear layer.
ShuffleNetV2Impl::ShuffleNetV2Impl(
    const std::vector<int64_t>& stage_repeats,
    const std::vector<int64_t>& stage_out_channels,
    int64_t num_classes) {
  TORCH_CHECK(
      stage_repeats.size() == 3,
      "expected stage_repeats as vector of 3 positive ints");
  for (const auto repeats : stage_repeats) {
    // Each stage needs at least its leading stride-2 unit; a non-positive
    // count would otherwise produce a nonsensical (or unbounded) unit loop.
    TORCH_CHECK(
        repeats > 0,
        "expected stage_repeats as vector of 3 positive ints, got ",
        repeats);
  }

  TORCH_CHECK(
      stage_out_channels.size() == 5,
      "expected stage_out_channels as vector of 5 positive ints");
  for (size_t i = 0; i < stage_out_channels.size(); ++i) {
    TORCH_CHECK(
        stage_out_channels[i] > 0,
        "expected stage_out_channels as vector of 5 positive ints, got ",
        stage_out_channels[i],
        " at index ",
        i);
    // Stage widths are split in half between two branches; an odd width
    // silently loses a channel and breaks the next layer's input size.
    if (i >= 1 && i <= 3) {
      TORCH_CHECK(
          stage_out_channels[i] % 2 == 0,
          "expected stage_out_channels[",
          i,
          "] to be even, got ",
          stage_out_channels[i]);
    }
  }

  _stage_out_channels = stage_out_channels;
  int64_t input_channels = 3;
  auto output_channels = _stage_out_channels[0];

  conv1 = torch::nn::Sequential(
      torch::nn::Conv2d(Options(input_channels, output_channels, 3)
                            .stride(2)
                            .padding(1)
                            .bias(false)),
      torch::nn::BatchNorm2d(output_channels),
      torch::nn::Functional(modelsimpl::relu_));

  input_channels = output_channels;
  // ModuleHolder copies share the underlying module, so appending through
  // these copies populates the stage2/stage3/stage4 members.
  std::vector<torch::nn::Sequential> stages = {stage2, stage3, stage4};

  for (size_t i = 0; i < stages.size(); ++i) {
    auto& seq = stages[i];
    const auto repeats = stage_repeats[i];
    const auto stage_channels = _stage_out_channels[i + 1];

    // Leading unit downsamples and changes the channel count.
    seq->push_back(
        ShuffleNetV2InvertedResidual(input_channels, stage_channels, 2));

    // Remaining (repeats - 1) units keep shape; signed counter avoids the
    // unsigned wrap-around of size_t(repeats - 1).
    for (int64_t j = 1; j < repeats; ++j)
      seq->push_back(
          ShuffleNetV2InvertedResidual(stage_channels, stage_channels, 1));

    input_channels = stage_channels;
  }

  output_channels = _stage_out_channels.back();
  conv5 = torch::nn::Sequential(
      torch::nn::Conv2d(Options(input_channels, output_channels, 1)
                            .stride(1)
                            .padding(0)
                            .bias(false)),
      torch::nn::BatchNorm2d(output_channels),
      torch::nn::Functional(modelsimpl::relu_));

  fc = torch::nn::Linear(output_channels, num_classes);

  register_module("conv1", conv1);
  register_module("stage2", stage2);
  register_module("stage3", stage3);
  register_module("stage4", stage4);
  // NOTE(review): `conv5` is registered under the name "conv2"; kept as-is
  // because renaming would change parameter names in existing checkpoints.
  register_module("conv2", conv5);
  register_module("fc", fc);
}

// Runs the network on `x` and returns the output of the final linear layer.
torch::Tensor ShuffleNetV2Impl::forward(torch::Tensor x) {
  // Stem: conv1 block, then 3x3 max pooling with stride 2 and padding 1.
  auto features = torch::max_pool2d(
      conv1->forward(x), /*kernel_size=*/3, /*stride=*/2, /*padding=*/1);

  // The three stages are applied in order stage2 -> stage3 -> stage4.
  features = stage4->forward(stage3->forward(stage2->forward(features)));

  // Final 1x1 conv block, then average over dims 2 and 3 before the classifier.
  auto pooled = conv5->forward(features).mean({2, 3});
  return fc->forward(pooled);
}

// ShuffleNetV2 x0.5 configuration.
ShuffleNetV2_x0_5Impl::ShuffleNetV2_x0_5Impl(int64_t num_classes)
    : ShuffleNetV2Impl(
          /*stage_repeats=*/{4, 8, 4},
          /*stage_out_channels=*/{24, 48, 96, 192, 1024},
          num_classes) {}

// ShuffleNetV2 x1.0 configuration.
ShuffleNetV2_x1_0Impl::ShuffleNetV2_x1_0Impl(int64_t num_classes)
    : ShuffleNetV2Impl(
          /*stage_repeats=*/{4, 8, 4},
          /*stage_out_channels=*/{24, 116, 232, 464, 1024},
          num_classes) {}

// ShuffleNetV2 x1.5 configuration.
ShuffleNetV2_x1_5Impl::ShuffleNetV2_x1_5Impl(int64_t num_classes)
    : ShuffleNetV2Impl(
          /*stage_repeats=*/{4, 8, 4},
          /*stage_out_channels=*/{24, 176, 352, 704, 1024},
          num_classes) {}

// ShuffleNetV2 x2.0 configuration (wider final conv: 2048 channels).
ShuffleNetV2_x2_0Impl::ShuffleNetV2_x2_0Impl(int64_t num_classes)
    : ShuffleNetV2Impl(
          /*stage_repeats=*/{4, 8, 4},
          /*stage_out_channels=*/{24, 244, 488, 976, 2048},
          num_classes) {}

} // namespace models
} // namespace vision