#pragma once

#include <cstdint>
#include <vector>

#include <torch/nn.h>
#include "../macros.h"

namespace vision {
namespace models {
// Densenet-BC model class, based on
// "Densely Connected Convolutional Networks"
// <https://arxiv.org/pdf/1608.06993.pdf>

// Args:
//     num_classes (int) - number of classification classes
//     growth_rate (int) - how many filters to add each layer (`k` in paper)
//     block_config (list of 4 ints) - how many layers in each pooling block
//     num_init_features (int) - the number of filters to learn in the first
//         convolution layer
//     bn_size (int) - multiplicative factor for number of bottle neck layers
//         (i.e. bn_size * k features in the bottleneck layer)
//     drop_rate (float) - dropout rate after each dense layer
21
struct VISION_API DenseNetImpl : torch::nn::Module {
Shahriar's avatar
Shahriar committed
22
23
24
  torch::nn::Sequential features{nullptr};
  torch::nn::Linear classifier{nullptr};

25
  explicit DenseNetImpl(
Shahriar's avatar
Shahriar committed
26
27
      int64_t num_classes = 1000,
      int64_t growth_rate = 32,
28
      const std::vector<int64_t>& block_config = {6, 12, 24, 16},
Shahriar's avatar
Shahriar committed
29
30
31
32
33
34
35
      int64_t num_init_features = 64,
      int64_t bn_size = 4,
      double drop_rate = 0);

  torch::Tensor forward(torch::Tensor x);
};

36
struct VISION_API DenseNet121Impl : DenseNetImpl {
37
  explicit DenseNet121Impl(
Shahriar's avatar
Shahriar committed
38
39
      int64_t num_classes = 1000,
      int64_t growth_rate = 32,
40
      const std::vector<int64_t>& block_config = {6, 12, 24, 16},
Shahriar's avatar
Shahriar committed
41
42
43
44
45
      int64_t num_init_features = 64,
      int64_t bn_size = 4,
      double drop_rate = 0);
};

46
struct VISION_API DenseNet169Impl : DenseNetImpl {
47
  explicit DenseNet169Impl(
Shahriar's avatar
Shahriar committed
48
49
      int64_t num_classes = 1000,
      int64_t growth_rate = 32,
50
      const std::vector<int64_t>& block_config = {6, 12, 32, 32},
Shahriar's avatar
Shahriar committed
51
52
53
54
55
      int64_t num_init_features = 64,
      int64_t bn_size = 4,
      double drop_rate = 0);
};

56
struct VISION_API DenseNet201Impl : DenseNetImpl {
57
  explicit DenseNet201Impl(
Shahriar's avatar
Shahriar committed
58
59
      int64_t num_classes = 1000,
      int64_t growth_rate = 32,
60
      const std::vector<int64_t>& block_config = {6, 12, 48, 32},
Shahriar's avatar
Shahriar committed
61
62
63
64
65
      int64_t num_init_features = 64,
      int64_t bn_size = 4,
      double drop_rate = 0);
};

66
struct VISION_API DenseNet161Impl : DenseNetImpl {
67
  explicit DenseNet161Impl(
Shahriar's avatar
Shahriar committed
68
69
      int64_t num_classes = 1000,
      int64_t growth_rate = 48,
70
      const std::vector<int64_t>& block_config = {6, 12, 36, 24},
Shahriar's avatar
Shahriar committed
71
72
73
74
75
76
77
78
79
80
81
82
83
      int64_t num_init_features = 96,
      int64_t bn_size = 4,
      double drop_rate = 0);
};

// Public module-holder types, following libtorch's TORCH_MODULE convention:
// each `Name` declared here wraps the matching `NameImpl` struct in this
// header.
TORCH_MODULE(DenseNet);

// Preset variants, listed in numeric order.
TORCH_MODULE(DenseNet121);
TORCH_MODULE(DenseNet161);
TORCH_MODULE(DenseNet169);
TORCH_MODULE(DenseNet201);

} // namespace models
} // namespace vision