// inception.h
// Declarations for the Inception v3 model (libtorch C++ frontend).
// Committed by Shahriar.
#ifndef INCEPTION_H
#define INCEPTION_H

#include <torch/torch.h>
#include "general.h"

namespace vision {
namespace models {
namespace _inceptionimpl {
10
struct VISION_API BasicConv2dImpl : torch::nn::Module {
Shahriar's avatar
Shahriar committed
11
12
13
14
15
16
17
18
19
20
  torch::nn::Conv2d conv{nullptr};
  torch::nn::BatchNorm bn{nullptr};

  BasicConv2dImpl(torch::nn::Conv2dOptions options, double std_dev = 0.1);

  torch::Tensor forward(torch::Tensor x);
};

TORCH_MODULE(BasicConv2d);

21
struct VISION_API InceptionAImpl : torch::nn::Module {
Shahriar's avatar
Shahriar committed
22
23
24
25
26
27
28
29
  BasicConv2d branch1x1, branch5x5_1, branch5x5_2, branch3x3dbl_1,
      branch3x3dbl_2, branch3x3dbl_3, branch_pool;

  InceptionAImpl(int64_t in_channels, int64_t pool_features);

  torch::Tensor forward(torch::Tensor x);
};

30
struct VISION_API InceptionBImpl : torch::nn::Module {
Shahriar's avatar
Shahriar committed
31
32
33
34
35
36
37
  BasicConv2d branch3x3, branch3x3dbl_1, branch3x3dbl_2, branch3x3dbl_3;

  InceptionBImpl(int64_t in_channels);

  torch::Tensor forward(torch::Tensor x);
};

38
struct VISION_API InceptionCImpl : torch::nn::Module {
Shahriar's avatar
Shahriar committed
39
40
41
42
43
44
45
46
47
48
  BasicConv2d branch1x1{nullptr}, branch7x7_1{nullptr}, branch7x7_2{nullptr},
      branch7x7_3{nullptr}, branch7x7dbl_1{nullptr}, branch7x7dbl_2{nullptr},
      branch7x7dbl_3{nullptr}, branch7x7dbl_4{nullptr}, branch7x7dbl_5{nullptr},
      branch_pool{nullptr};

  InceptionCImpl(int64_t in_channels, int64_t channels_7x7);

  torch::Tensor forward(torch::Tensor x);
};

49
struct VISION_API InceptionDImpl : torch::nn::Module {
Shahriar's avatar
Shahriar committed
50
51
52
53
54
55
56
57
  BasicConv2d branch3x3_1, branch3x3_2, branch7x7x3_1, branch7x7x3_2,
      branch7x7x3_3, branch7x7x3_4;

  InceptionDImpl(int64_t in_channels);

  torch::Tensor forward(torch::Tensor x);
};

58
struct VISION_API InceptionEImpl : torch::nn::Module {
Shahriar's avatar
Shahriar committed
59
60
61
62
63
64
65
66
67
  BasicConv2d branch1x1, branch3x3_1, branch3x3_2a, branch3x3_2b,
      branch3x3dbl_1, branch3x3dbl_2, branch3x3dbl_3a, branch3x3dbl_3b,
      branch_pool;

  InceptionEImpl(int64_t in_channels);

  torch::Tensor forward(torch::Tensor x);
};

68
struct VISION_API InceptionAuxImpl : torch::nn::Module {
Shahriar's avatar
Shahriar committed
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
  BasicConv2d conv0;
  BasicConv2d conv1;
  torch::nn::Linear fc;

  InceptionAuxImpl(int64_t in_channels, int64_t num_classes);

  torch::Tensor forward(torch::Tensor x);
};

// Holder types generated by libtorch's TORCH_MODULE macro: each name X wraps
// the XImpl struct declared above. These holder names are the member types
// used by InceptionV3Impl (e.g. _inceptionimpl::InceptionA).
TORCH_MODULE(InceptionA);
TORCH_MODULE(InceptionB);
TORCH_MODULE(InceptionC);
TORCH_MODULE(InceptionD);
TORCH_MODULE(InceptionE);
TORCH_MODULE(InceptionAux);

} // namespace _inceptionimpl

87
struct VISION_API InceptionV3Output {
Shahriar's avatar
Shahriar committed
88
89
90
91
92
93
94
  torch::Tensor output;
  torch::Tensor aux;
};

// Inception v3 model architecture from
//"Rethinking the Inception Architecture for Computer Vision"
//<http://arxiv.org/abs/1512.00567>
95
struct VISION_API InceptionV3Impl : torch::nn::Module {
Shahriar's avatar
Shahriar committed
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
  bool aux_logits, transform_input;

  _inceptionimpl::BasicConv2d Conv2d_1a_3x3{nullptr}, Conv2d_2a_3x3{nullptr},
      Conv2d_2b_3x3{nullptr}, Conv2d_3b_1x1{nullptr}, Conv2d_4a_3x3{nullptr};

  _inceptionimpl::InceptionA Mixed_5b{nullptr}, Mixed_5c{nullptr},
      Mixed_5d{nullptr};
  _inceptionimpl::InceptionB Mixed_6a{nullptr};
  _inceptionimpl::InceptionC Mixed_6b{nullptr}, Mixed_6c{nullptr},
      Mixed_6d{nullptr}, Mixed_6e{nullptr};
  _inceptionimpl::InceptionD Mixed_7a{nullptr};
  _inceptionimpl::InceptionE Mixed_7b{nullptr}, Mixed_7c{nullptr};

  torch::nn::Linear fc{nullptr};

  _inceptionimpl::InceptionAux AuxLogits{nullptr};

  InceptionV3Impl(
      int64_t num_classes = 1000,
      bool aux_logits = true,
      bool transform_input = false);

  InceptionV3Output forward(torch::Tensor x);
};

TORCH_MODULE(InceptionV3);

} // namespace models
} // namespace vision

#endif // INCEPTION_H