squeezenet.cpp 3.68 KB
Newer Older
Shahriar's avatar
Shahriar committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
#include "squeezenet.h"

#include <limits>
#include <stdexcept>
#include <string>
#include "modelsimpl.h"

namespace vision {
namespace models {
// Fire block: a 1x1 "squeeze" convolution whose ReLU output feeds two parallel
// "expand" convolutions (1x1 and padded 3x3). The two ReLU-activated expand
// outputs are concatenated along dimension 1.
struct Fire : torch::nn::Module {
  // Holders start empty; the constructor creates and registers the modules.
  torch::nn::Conv2d squeeze{nullptr};
  torch::nn::Conv2d expand1x1{nullptr};
  torch::nn::Conv2d expand3x3{nullptr};

  // inplanes:          input channels of the squeeze convolution.
  // squeeze_planes:    output channels of squeeze / input of both expands.
  // expand1x1_planes:  output channels of the 1x1 expand convolution.
  // expand3x3_planes:  output channels of the 3x3 expand convolution.
  Fire(
      int64_t inplanes,
      int64_t squeeze_planes,
      int64_t expand1x1_planes,
      int64_t expand3x3_planes) {
    squeeze = register_module(
        "squeeze",
        torch::nn::Conv2d(
            torch::nn::Conv2dOptions(inplanes, squeeze_planes, 1)));

    expand1x1 = register_module(
        "expand1x1",
        torch::nn::Conv2d(
            torch::nn::Conv2dOptions(squeeze_planes, expand1x1_planes, 1)));

    // padding(1) keeps the 3x3 branch the same spatial size as the 1x1
    // branch so the two outputs can be concatenated.
    expand3x3 = register_module(
        "expand3x3",
        torch::nn::Conv2d(
            torch::nn::Conv2dOptions(squeeze_planes, expand3x3_planes, 3)
                .padding(1)));
  }

  // Applies squeeze -> ReLU, then both expand branches (each followed by
  // ReLU), and returns the branches joined along dimension 1.
  torch::Tensor forward(torch::Tensor x) {
    torch::Tensor squeezed = torch::relu(squeeze->forward(x));
    torch::Tensor branch1x1 = torch::relu(expand1x1->forward(squeezed));
    torch::Tensor branch3x3 = torch::relu(expand3x3->forward(squeezed));
    return torch::cat({branch1x1, branch3x3}, /*dim=*/1);
  }
};

// Builds the SqueezeNet feature extractor and classifier.
//
// version:      architecture revision; only 1.0 and 1.1 are accepted (matched
//               with modelsimpl::double_compare).
// num_classes:  number of output channels of the final 1x1 convolution.
//
// Throws std::invalid_argument when `version` is neither 1.0 nor 1.1.
//
// In the pooling layers below, the arguments bound to torch::max_pool2d follow
// its (kernel_size, stride, padding, dilation, ceil_mode) parameter order.
SqueezeNetImpl::SqueezeNetImpl(double version, int64_t num_classes)
    : num_classes(num_classes) {
  if (modelsimpl::double_compare(version, 1.0)) {
    features = torch::nn::Sequential(
        torch::nn::Conv2d(torch::nn::Conv2dOptions(3, 96, 7).stride(2)),
        torch::nn::Functional(modelsimpl::relu_),
        torch::nn::Functional(torch::max_pool2d, 3, 2, 0, 1, true),
        Fire(96, 16, 64, 64),
        Fire(128, 16, 64, 64),
        Fire(128, 32, 128, 128),
        torch::nn::Functional(torch::max_pool2d, 3, 2, 0, 1, true),
        Fire(256, 32, 128, 128),
        Fire(256, 48, 192, 192),
        Fire(384, 48, 192, 192),
        Fire(384, 64, 256, 256),
        torch::nn::Functional(torch::max_pool2d, 3, 2, 0, 1, true),
        Fire(512, 64, 256, 256));
  } else if (modelsimpl::double_compare(version, 1.1)) {
    features = torch::nn::Sequential(
        torch::nn::Conv2d(torch::nn::Conv2dOptions(3, 64, 3).stride(2)),
        torch::nn::Functional(modelsimpl::relu_),
        torch::nn::Functional(torch::max_pool2d, 3, 2, 0, 1, true),
        Fire(64, 16, 64, 64),
        Fire(128, 16, 64, 64),
        torch::nn::Functional(torch::max_pool2d, 3, 2, 0, 1, true),
        Fire(128, 32, 128, 128),
        Fire(256, 32, 128, 128),
        torch::nn::Functional(torch::max_pool2d, 3, 2, 0, 1, true),
        Fire(256, 48, 192, 192),
        Fire(384, 48, 192, 192),
        Fire(384, 64, 256, 256),
        Fire(512, 64, 256, 256));
  } else {
    // Fail in every build type. A bare assert(false) is compiled out under
    // NDEBUG, which would let construction continue and register a `features`
    // member that was never assigned here.
    throw std::invalid_argument(
        "Wrong version number (" + std::to_string(version) +
        ") is passed to the SqueezeNet constructor; expected 1.0 or 1.1");
  }

  // Final convolution is initialized differently from the rest, so keep a
  // handle to it for the pointer comparison in the init loop below.
  auto final_conv =
      torch::nn::Conv2d(torch::nn::Conv2dOptions(512, num_classes, 1));

  classifier = torch::nn::Sequential(
      torch::nn::Dropout(0.5),
      final_conv,
      torch::nn::Functional(modelsimpl::relu_),
      torch::nn::Functional(modelsimpl::adaptive_avg_pool2d, 1));

  register_module("features", features);
  register_module("classifier", classifier);

  // Weight initialization: walk every registered sub-module (excluding this
  // one) and re-initialize each Conv2d found.
  for (auto& module : modules(/*include_self=*/false)) {
    auto* conv = dynamic_cast<torch::nn::Conv2dImpl*>(module.get());
    if (conv == nullptr) {
      continue; // not a convolution; leave as constructed
    }

    if (conv == final_conv.get()) {
      // Classifier convolution: normal(mean=0.0, std=0.01).
      torch::nn::init::normal_(conv->weight, 0.0, 0.01);
    } else {
      torch::nn::init::kaiming_uniform_(conv->weight);
    }

    // Biases, when the layer has one, start at zero.
    if (conv->options.with_bias()) {
      torch::nn::init::constant_(conv->bias, 0);
    }
  }
}

// Runs the input through the feature extractor and the classifier, then
// reshapes the classifier output to (x.size(0), -1), i.e. one flat row per
// entry along dimension 0.
torch::Tensor SqueezeNetImpl::forward(torch::Tensor x) {
  torch::Tensor feature_maps = features->forward(x);
  torch::Tensor scores = classifier->forward(feature_maps);

  const int64_t leading_dim = scores.size(0);
  return scores.view({leading_dim, -1});
}

SqueezeNet1_0Impl::SqueezeNet1_0Impl(int64_t num_classes)
    : SqueezeNetImpl(1.0, num_classes) {}

SqueezeNet1_1Impl::SqueezeNet1_1Impl(int64_t num_classes)
    : SqueezeNetImpl(1.1, num_classes) {}

} // namespace models
} // namespace vision