densenet.h 2.44 KB
Newer Older
Shahriar's avatar
Shahriar committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
#ifndef DENSENET_H
#define DENSENET_H

#include <torch/torch.h>

namespace vision {
namespace models {
// Densenet-BC model class, based on
// "Densely Connected Convolutional Networks"
// <https://arxiv.org/pdf/1608.06993.pdf>

// Args:
//     num_classes (int) - number of classification classes
//     growth_rate (int) - how many filters to add each layer (`k` in paper)
//     block_config (list of 4 ints) - how many layers in each pooling block
//     num_init_features (int) - the number of filters to learn in the first
//         convolution layer
//     bn_size (int) - multiplicative factor for number of bottle neck layers
//         (i.e. bn_size * k features in the bottleneck layer)
//     drop_rate (float) - dropout rate after each dense layer
struct DenseNetImpl : torch::nn::Module {
  torch::nn::Sequential features{nullptr};
  torch::nn::Linear classifier{nullptr};

  DenseNetImpl(
      int64_t num_classes = 1000,
      int64_t growth_rate = 32,
      std::vector<int64_t> block_config = {6, 12, 24, 16},
      int64_t num_init_features = 64,
      int64_t bn_size = 4,
      double drop_rate = 0);

  torch::Tensor forward(torch::Tensor x);
};

// DenseNet-121 variant of DenseNetImpl.
// Default configuration: growth_rate 32, block_config {6, 12, 24, 16},
// num_init_features 64, bn_size 4, drop_rate 0.
// The parameters have the same meaning as DenseNetImpl's constructor
// arguments; presumably they are forwarded to it unchanged (confirm in .cpp).
struct DenseNet121Impl : DenseNetImpl {
  // `explicit` because all parameters are defaulted; otherwise this would
  // also act as an implicit conversion from a single int64_t.
  explicit DenseNet121Impl(
      int64_t num_classes = 1000,
      int64_t growth_rate = 32,
      std::vector<int64_t> block_config = {6, 12, 24, 16},
      int64_t num_init_features = 64,
      int64_t bn_size = 4,
      double drop_rate = 0);
};

// DenseNet-169 variant of DenseNetImpl.
// Default configuration: growth_rate 32, block_config {6, 12, 32, 32},
// num_init_features 64, bn_size 4, drop_rate 0.
// The parameters have the same meaning as DenseNetImpl's constructor
// arguments; presumably they are forwarded to it unchanged (confirm in .cpp).
struct DenseNet169Impl : DenseNetImpl {
  // `explicit` because all parameters are defaulted; otherwise this would
  // also act as an implicit conversion from a single int64_t.
  explicit DenseNet169Impl(
      int64_t num_classes = 1000,
      int64_t growth_rate = 32,
      std::vector<int64_t> block_config = {6, 12, 32, 32},
      int64_t num_init_features = 64,
      int64_t bn_size = 4,
      double drop_rate = 0);
};

// DenseNet-201 variant of DenseNetImpl.
// Default configuration: growth_rate 32, block_config {6, 12, 48, 32},
// num_init_features 64, bn_size 4, drop_rate 0.
// The parameters have the same meaning as DenseNetImpl's constructor
// arguments; presumably they are forwarded to it unchanged (confirm in .cpp).
struct DenseNet201Impl : DenseNetImpl {
  // `explicit` because all parameters are defaulted; otherwise this would
  // also act as an implicit conversion from a single int64_t.
  explicit DenseNet201Impl(
      int64_t num_classes = 1000,
      int64_t growth_rate = 32,
      std::vector<int64_t> block_config = {6, 12, 48, 32},
      int64_t num_init_features = 64,
      int64_t bn_size = 4,
      double drop_rate = 0);
};

// DenseNet-161 variant of DenseNetImpl.
// Default configuration: growth_rate 48, block_config {6, 12, 36, 24},
// num_init_features 96, bn_size 4, drop_rate 0. Note the wider growth rate
// and stem than the other variants declared in this header.
// The parameters have the same meaning as DenseNetImpl's constructor
// arguments; presumably they are forwarded to it unchanged (confirm in .cpp).
struct DenseNet161Impl : DenseNetImpl {
  // `explicit` because all parameters are defaulted; otherwise this would
  // also act as an implicit conversion from a single int64_t.
  explicit DenseNet161Impl(
      int64_t num_classes = 1000,
      int64_t growth_rate = 48,
      std::vector<int64_t> block_config = {6, 12, 36, 24},
      int64_t num_init_features = 96,
      int64_t bn_size = 4,
      double drop_rate = 0);
};

// Module-holder types, following libtorch's TORCH_MODULE convention:
// TORCH_MODULE(X) declares the user-facing type `X` wrapping `XImpl`.
// The generic model comes first, then the variants by numeric suffix.
TORCH_MODULE(DenseNet);

TORCH_MODULE(DenseNet121);
TORCH_MODULE(DenseNet161);
TORCH_MODULE(DenseNet169);
TORCH_MODULE(DenseNet201);

} // namespace models
} // namespace vision

#endif // DENSENET_H