inception.h 3.32 KB
Newer Older
Shahriar's avatar
Shahriar committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#ifndef INCEPTION_H
#define INCEPTION_H

#include <torch/torch.h>

namespace vision {
namespace models {
namespace _inceptionimpl {
// Convolution + batch-normalisation building block used by every Inception
// sub-module in this header (and by InceptionV3Impl's plain conv layers).
// NOTE(review): forward() and the weight initialisation live in the .cpp;
// presumably forward() applies `conv`, then `bn`, then an activation — confirm
// there.
struct BasicConv2dImpl : torch::nn::Module {
  torch::nn::Conv2d conv{nullptr};
  // NOTE(review): newer LibTorch releases deprecate torch::nn::BatchNorm in
  // favour of torch::nn::BatchNorm2d — confirm against the LibTorch version
  // this builds with. Migrating needs matching changes in the .cpp, so the
  // type is left unchanged here.
  torch::nn::BatchNorm bn{nullptr};

  // @param options  Convolution options (channels, kernel size, stride, ...).
  // @param std_dev  Presumably the standard deviation used to initialise the
  //                 convolution weights — TODO confirm in the .cpp.
  // `explicit`: this constructor is callable with a single argument, so without
  // it a Conv2dOptions would silently convert into a BasicConv2dImpl.
  explicit BasicConv2dImpl(
      torch::nn::Conv2dOptions options,
      double std_dev = 0.1);

  torch::Tensor forward(torch::Tensor x);
};

TORCH_MODULE(BasicConv2d);

struct InceptionAImpl : torch::nn::Module {
  BasicConv2d branch1x1, branch5x5_1, branch5x5_2, branch3x3dbl_1,
      branch3x3dbl_2, branch3x3dbl_3, branch_pool;

  InceptionAImpl(int64_t in_channels, int64_t pool_features);

  torch::Tensor forward(torch::Tensor x);
};

struct InceptionBImpl : torch::nn::Module {
  BasicConv2d branch3x3, branch3x3dbl_1, branch3x3dbl_2, branch3x3dbl_3;

  InceptionBImpl(int64_t in_channels);

  torch::Tensor forward(torch::Tensor x);
};

struct InceptionCImpl : torch::nn::Module {
  BasicConv2d branch1x1{nullptr}, branch7x7_1{nullptr}, branch7x7_2{nullptr},
      branch7x7_3{nullptr}, branch7x7dbl_1{nullptr}, branch7x7dbl_2{nullptr},
      branch7x7dbl_3{nullptr}, branch7x7dbl_4{nullptr}, branch7x7dbl_5{nullptr},
      branch_pool{nullptr};

  InceptionCImpl(int64_t in_channels, int64_t channels_7x7);

  torch::Tensor forward(torch::Tensor x);
};

struct InceptionDImpl : torch::nn::Module {
  BasicConv2d branch3x3_1, branch3x3_2, branch7x7x3_1, branch7x7x3_2,
      branch7x7x3_3, branch7x7x3_4;

  InceptionDImpl(int64_t in_channels);

  torch::Tensor forward(torch::Tensor x);
};

struct InceptionEImpl : torch::nn::Module {
  BasicConv2d branch1x1, branch3x3_1, branch3x3_2a, branch3x3_2b,
      branch3x3dbl_1, branch3x3dbl_2, branch3x3dbl_3a, branch3x3dbl_3b,
      branch_pool;

  InceptionEImpl(int64_t in_channels);

  torch::Tensor forward(torch::Tensor x);
};

struct InceptionAuxImpl : torch::nn::Module {
  BasicConv2d conv0;
  BasicConv2d conv1;
  torch::nn::Linear fc;

  InceptionAuxImpl(int64_t in_channels, int64_t num_classes);

  torch::Tensor forward(torch::Tensor x);
};

// Declare the holder types InceptionA ... InceptionAux for the corresponding
// *Impl structs above (LibTorch's TORCH_MODULE convention). InceptionV3Impl
// declares its sub-module members with these holder types.
TORCH_MODULE(InceptionA);
TORCH_MODULE(InceptionB);
TORCH_MODULE(InceptionC);
TORCH_MODULE(InceptionD);
TORCH_MODULE(InceptionE);
TORCH_MODULE(InceptionAux);

} // namespace _inceptionimpl

// Result type returned by InceptionV3Impl::forward.
// NOTE(review): `aux` presumably carries the auxiliary head's output and may be
// left undefined when that head is not run (e.g. aux_logits == false or eval
// mode) — confirm in the .cpp.
struct InceptionV3Output {
  torch::Tensor output, aux;
};

// Inception v3 model architecture from
//"Rethinking the Inception Architecture for Computer Vision"
//<http://arxiv.org/abs/1512.00567>
struct InceptionV3Impl : torch::nn::Module {
  bool aux_logits, transform_input;

  _inceptionimpl::BasicConv2d Conv2d_1a_3x3{nullptr}, Conv2d_2a_3x3{nullptr},
      Conv2d_2b_3x3{nullptr}, Conv2d_3b_1x1{nullptr}, Conv2d_4a_3x3{nullptr};

  _inceptionimpl::InceptionA Mixed_5b{nullptr}, Mixed_5c{nullptr},
      Mixed_5d{nullptr};
  _inceptionimpl::InceptionB Mixed_6a{nullptr};
  _inceptionimpl::InceptionC Mixed_6b{nullptr}, Mixed_6c{nullptr},
      Mixed_6d{nullptr}, Mixed_6e{nullptr};
  _inceptionimpl::InceptionD Mixed_7a{nullptr};
  _inceptionimpl::InceptionE Mixed_7b{nullptr}, Mixed_7c{nullptr};

  torch::nn::Linear fc{nullptr};

  _inceptionimpl::InceptionAux AuxLogits{nullptr};

  InceptionV3Impl(
      int64_t num_classes = 1000,
      bool aux_logits = true,
      bool transform_input = false);

  InceptionV3Output forward(torch::Tensor x);
};

TORCH_MODULE(InceptionV3);

} // namespace models
} // namespace vision

#endif // INCEPTION_H