vgg.cpp 3.83 KB
Newer Older
Shahriar's avatar
Shahriar committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
#include "vgg.h"

#include <unordered_map>
#include "modelsimpl.h"

namespace vision {
namespace models {
/// Builds the feature-extraction stack described by `cfg`.
///
/// Entries are consumed in order. A value <= -1 appends a
/// `modelsimpl::max_pool2d` functional bound with the arguments (2, 2)
/// (presumably kernel size and stride -- confirm in modelsimpl.h). Any other
/// value `V` appends a Conv2d (kernel size 3, padding 1) mapping the current
/// channel count to `V`, optionally a `BatchNorm2d(V)`, and a
/// `modelsimpl::relu_` functional. The channel count starts at 3 and becomes
/// `V` after each convolution.
///
/// @param cfg         Layer specification: output channels, or <= -1 for a
///                    pooling step.
/// @param batch_norm  When true, a BatchNorm2d follows every convolution.
/// @return            Sequential container holding the layers in `cfg` order.
torch::nn::Sequential makeLayers(
    const std::vector<int>& cfg,
    bool batch_norm = false) {
  torch::nn::Sequential seq;
  int channels = 3;

  for (const int V : cfg) {
    // Non-positive marker: pooling step, channel count is unchanged.
    if (V <= -1) {
      seq->push_back(torch::nn::Functional(modelsimpl::max_pool2d, 2, 2));
      continue;
    }

    seq->push_back(torch::nn::Conv2d(
        torch::nn::Conv2dOptions(channels, V, 3).padding(1)));

    if (batch_norm) {
      seq->push_back(torch::nn::BatchNorm2d(V));
    }
    seq->push_back(torch::nn::Functional(modelsimpl::relu_));

    // The next convolution reads what this one produced.
    channels = V;
  }

  return seq;
}

/// Re-initializes the parameters of every submodule (this module itself is
/// excluded from the traversal).
///
/// - Conv2d:      weight via `kaiming_normal_` (a = 0, fan-out, ReLU);
///                bias set to 0.
/// - BatchNorm2d: weight set to 1, bias set to 0.
/// - Linear:      weight via `normal_(0, 0.01)`; bias set to 0.
///
/// Parameters that are undefined tensors (for instance a layer constructed
/// with its bias disabled) are skipped instead of being passed to the
/// initializer.
void VGGImpl::_initialize_weights() {
  for (const auto& module : modules(/*include_self=*/false)) {
    if (auto* conv = dynamic_cast<torch::nn::Conv2dImpl*>(module.get())) {
      torch::nn::init::kaiming_normal_(
          conv->weight,
          /*a=*/0,
          torch::kFanOut,
          torch::kReLU);
      if (conv->bias.defined()) {
        torch::nn::init::constant_(conv->bias, 0);
      }
    } else if (
        auto* bn = dynamic_cast<torch::nn::BatchNorm2dImpl*>(module.get())) {
      // Both affine parameters are undefined when the layer is non-affine.
      if (bn->weight.defined()) {
        torch::nn::init::constant_(bn->weight, 1);
      }
      if (bn->bias.defined()) {
        torch::nn::init::constant_(bn->bias, 0);
      }
    } else if (
        auto* linear = dynamic_cast<torch::nn::LinearImpl*>(module.get())) {
      torch::nn::init::normal_(linear->weight, 0, 0.01);
      if (linear->bias.defined()) {
        torch::nn::init::constant_(linear->bias, 0);
      }
    }
  }
}

/// Assembles a VGG network from a ready-made feature stack.
///
/// The classifier is Linear(512*7*7 -> 4096), relu_, Dropout,
/// Linear(4096 -> 4096), relu_, Dropout, Linear(4096 -> num_classes).
/// "features" is registered before "classifier".
///
/// @param features            Feature-extraction layers; stored in the
///                            `features` member and registered as a submodule.
/// @param num_classes         Output width of the final Linear layer.
/// @param initialize_weights  When true, `_initialize_weights()` runs after
///                            both submodules have been registered.
VGGImpl::VGGImpl(
    const torch::nn::Sequential& features,
    int64_t num_classes,
    bool initialize_weights) {
  classifier = torch::nn::Sequential(
      torch::nn::Linear(512 * 7 * 7, 4096),
      torch::nn::Functional(modelsimpl::relu_),
      torch::nn::Dropout(),
      torch::nn::Linear(4096, 4096),
      torch::nn::Functional(modelsimpl::relu_),
      torch::nn::Dropout(),
      torch::nn::Linear(4096, num_classes));

  // The parameter shadows the member, hence the explicit `this->`.
  this->features = features;

  register_module("features", this->features);
  register_module("classifier", classifier);

  // Must come after registration: the initializer walks modules().
  if (initialize_weights) {
    _initialize_weights();
  }

  modelsimpl::deprecation_warning();
}

/// Runs the network on `x` and returns the classifier output.
///
/// The feature map is adaptively average-pooled to 7x7 and flattened from
/// dimension 1 onward before it reaches the classifier, whose first Linear
/// layer expects `512 * 7 * 7` inputs (so `features` is assumed to end with
/// 512 channels).
///
/// @param x  Input batch; dimension 0 is preserved through the flatten.
/// @return   Output of `classifier`.
torch::Tensor VGGImpl::forward(torch::Tensor x) {
  x = features->forward(x);
  x = torch::adaptive_avg_pool2d(x, {7, 7});
  // flatten() keeps dim 0 and merges the rest. Unlike view({N, -1}) it also
  // works when the strides cannot be merged without a copy and when the
  // batch dimension is empty.
  x = torch::flatten(x, /*start_dim=*/1);
  x = classifier->forward(x);
  return x;
}

// Layer configurations consumed by makeLayers(), keyed by a letter.
// A positive entry is the output channel count of a convolution; -1 marks a
// max-pooling step. 'A', 'B', 'D' and 'E' back the VGG11, VGG13, VGG16 and
// VGG19 constructors respectively (with and without batch normalization).
// clang-format off
static std::unordered_map<char, std::vector<int>> cfgs = {
  {'A', {64, -1, 128, -1, 256, 256, -1, 512, 512, -1, 512, 512, -1}},
  {'B', {64, 64, -1, 128, 128, -1, 256, 256, -1, 512, 512, -1, 512, 512, -1}},
  {'D', {64, 64, -1, 128, 128, -1, 256, 256, 256, -1, 512, 512, 512, -1, 512, 512, 512, -1}},
  {'E', {64, 64, -1, 128, 128, -1, 256, 256, 256, 256, -1, 512, 512, 512, 512, -1, 512, 512, 512, 512, -1}}};
// clang-format on

/// VGG11: feature stack built from configuration 'A' without batch
/// normalization. `at()` is used so a missing key throws instead of silently
/// inserting an empty configuration.
VGG11Impl::VGG11Impl(int64_t num_classes, bool initialize_weights)
    : VGGImpl(makeLayers(cfgs.at('A')), num_classes, initialize_weights) {}
Shahriar's avatar
Shahriar committed
94
95

/// VGG13: feature stack built from configuration 'B' without batch
/// normalization. `at()` is used so a missing key throws instead of silently
/// inserting an empty configuration.
VGG13Impl::VGG13Impl(int64_t num_classes, bool initialize_weights)
    : VGGImpl(makeLayers(cfgs.at('B')), num_classes, initialize_weights) {}
Shahriar's avatar
Shahriar committed
97
98

/// VGG16: feature stack built from configuration 'D' without batch
/// normalization. `at()` is used so a missing key throws instead of silently
/// inserting an empty configuration.
VGG16Impl::VGG16Impl(int64_t num_classes, bool initialize_weights)
    : VGGImpl(makeLayers(cfgs.at('D')), num_classes, initialize_weights) {}
Shahriar's avatar
Shahriar committed
100
101

/// VGG19: feature stack built from configuration 'E' without batch
/// normalization. `at()` is used so a missing key throws instead of silently
/// inserting an empty configuration.
VGG19Impl::VGG19Impl(int64_t num_classes, bool initialize_weights)
    : VGGImpl(makeLayers(cfgs.at('E')), num_classes, initialize_weights) {}
Shahriar's avatar
Shahriar committed
103
104

/// VGG11 with batch normalization: feature stack built from configuration
/// 'A' with a BatchNorm2d after every convolution. `at()` is used so a
/// missing key throws instead of silently inserting an empty configuration.
VGG11BNImpl::VGG11BNImpl(int64_t num_classes, bool initialize_weights)
    : VGGImpl(
          makeLayers(cfgs.at('A'), /*batch_norm=*/true),
          num_classes,
          initialize_weights) {}
Shahriar's avatar
Shahriar committed
106
107

/// VGG13 with batch normalization: feature stack built from configuration
/// 'B' with a BatchNorm2d after every convolution. `at()` is used so a
/// missing key throws instead of silently inserting an empty configuration.
VGG13BNImpl::VGG13BNImpl(int64_t num_classes, bool initialize_weights)
    : VGGImpl(
          makeLayers(cfgs.at('B'), /*batch_norm=*/true),
          num_classes,
          initialize_weights) {}
Shahriar's avatar
Shahriar committed
109
110

/// VGG16 with batch normalization: feature stack built from configuration
/// 'D' with a BatchNorm2d after every convolution. `at()` is used so a
/// missing key throws instead of silently inserting an empty configuration.
VGG16BNImpl::VGG16BNImpl(int64_t num_classes, bool initialize_weights)
    : VGGImpl(
          makeLayers(cfgs.at('D'), /*batch_norm=*/true),
          num_classes,
          initialize_weights) {}
Shahriar's avatar
Shahriar committed
112
113

/// VGG19 with batch normalization: feature stack built from configuration
/// 'E' with a BatchNorm2d after every convolution. `at()` is used so a
/// missing key throws instead of silently inserting an empty configuration.
VGG19BNImpl::VGG19BNImpl(int64_t num_classes, bool initialize_weights)
    : VGGImpl(
          makeLayers(cfgs.at('E'), /*batch_norm=*/true),
          num_classes,
          initialize_weights) {}
Shahriar's avatar
Shahriar committed
115
116
117

} // namespace models
} // namespace vision