// inception.h (committed by Shahriar)
#ifndef INCEPTION_H
#define INCEPTION_H

#include <torch/torch.h>

#include "general.h"

namespace vision {
namespace models {
namespace _inceptionimpl {
10
struct VISION_API BasicConv2dImpl : torch::nn::Module {
Shahriar's avatar
Shahriar committed
11
  torch::nn::Conv2d conv{nullptr};
12
  torch::nn::BatchNorm2d bn{nullptr};
Shahriar's avatar
Shahriar committed
13

14
15
16
  explicit BasicConv2dImpl(
      torch::nn::Conv2dOptions options,
      double std_dev = 0.1);
Shahriar's avatar
Shahriar committed
17
18
19
20
21
22

  torch::Tensor forward(torch::Tensor x);
};

TORCH_MODULE(BasicConv2d);

23
struct VISION_API InceptionAImpl : torch::nn::Module {
Shahriar's avatar
Shahriar committed
24
25
26
27
28
  BasicConv2d branch1x1, branch5x5_1, branch5x5_2, branch3x3dbl_1,
      branch3x3dbl_2, branch3x3dbl_3, branch_pool;

  InceptionAImpl(int64_t in_channels, int64_t pool_features);

29
  torch::Tensor forward(const torch::Tensor& x);
Shahriar's avatar
Shahriar committed
30
31
};

32
struct VISION_API InceptionBImpl : torch::nn::Module {
Shahriar's avatar
Shahriar committed
33
34
  BasicConv2d branch3x3, branch3x3dbl_1, branch3x3dbl_2, branch3x3dbl_3;

35
  explicit InceptionBImpl(int64_t in_channels);
Shahriar's avatar
Shahriar committed
36

37
  torch::Tensor forward(const torch::Tensor& x);
Shahriar's avatar
Shahriar committed
38
39
};

40
struct VISION_API InceptionCImpl : torch::nn::Module {
Shahriar's avatar
Shahriar committed
41
42
43
44
45
46
47
  BasicConv2d branch1x1{nullptr}, branch7x7_1{nullptr}, branch7x7_2{nullptr},
      branch7x7_3{nullptr}, branch7x7dbl_1{nullptr}, branch7x7dbl_2{nullptr},
      branch7x7dbl_3{nullptr}, branch7x7dbl_4{nullptr}, branch7x7dbl_5{nullptr},
      branch_pool{nullptr};

  InceptionCImpl(int64_t in_channels, int64_t channels_7x7);

48
  torch::Tensor forward(const torch::Tensor& x);
Shahriar's avatar
Shahriar committed
49
50
};

51
struct VISION_API InceptionDImpl : torch::nn::Module {
Shahriar's avatar
Shahriar committed
52
53
54
  BasicConv2d branch3x3_1, branch3x3_2, branch7x7x3_1, branch7x7x3_2,
      branch7x7x3_3, branch7x7x3_4;

55
  explicit InceptionDImpl(int64_t in_channels);
Shahriar's avatar
Shahriar committed
56

57
  torch::Tensor forward(const torch::Tensor& x);
Shahriar's avatar
Shahriar committed
58
59
};

60
struct VISION_API InceptionEImpl : torch::nn::Module {
Shahriar's avatar
Shahriar committed
61
62
63
64
  BasicConv2d branch1x1, branch3x3_1, branch3x3_2a, branch3x3_2b,
      branch3x3dbl_1, branch3x3dbl_2, branch3x3dbl_3a, branch3x3dbl_3b,
      branch_pool;

65
  explicit InceptionEImpl(int64_t in_channels);
Shahriar's avatar
Shahriar committed
66

67
  torch::Tensor forward(const torch::Tensor& x);
Shahriar's avatar
Shahriar committed
68
69
};

70
struct VISION_API InceptionAuxImpl : torch::nn::Module {
Shahriar's avatar
Shahriar committed
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
  BasicConv2d conv0;
  BasicConv2d conv1;
  torch::nn::Linear fc;

  InceptionAuxImpl(int64_t in_channels, int64_t num_classes);

  torch::Tensor forward(torch::Tensor x);
};

TORCH_MODULE(InceptionA);
TORCH_MODULE(InceptionB);
TORCH_MODULE(InceptionC);
TORCH_MODULE(InceptionD);
TORCH_MODULE(InceptionE);
TORCH_MODULE(InceptionAux);

} // namespace _inceptionimpl

89
struct VISION_API InceptionV3Output {
Shahriar's avatar
Shahriar committed
90
91
92
93
94
95
96
  torch::Tensor output;
  torch::Tensor aux;
};

// Inception v3 model architecture from
//"Rethinking the Inception Architecture for Computer Vision"
//<http://arxiv.org/abs/1512.00567>
97
struct VISION_API InceptionV3Impl : torch::nn::Module {
Shahriar's avatar
Shahriar committed
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
  bool aux_logits, transform_input;

  _inceptionimpl::BasicConv2d Conv2d_1a_3x3{nullptr}, Conv2d_2a_3x3{nullptr},
      Conv2d_2b_3x3{nullptr}, Conv2d_3b_1x1{nullptr}, Conv2d_4a_3x3{nullptr};

  _inceptionimpl::InceptionA Mixed_5b{nullptr}, Mixed_5c{nullptr},
      Mixed_5d{nullptr};
  _inceptionimpl::InceptionB Mixed_6a{nullptr};
  _inceptionimpl::InceptionC Mixed_6b{nullptr}, Mixed_6c{nullptr},
      Mixed_6d{nullptr}, Mixed_6e{nullptr};
  _inceptionimpl::InceptionD Mixed_7a{nullptr};
  _inceptionimpl::InceptionE Mixed_7b{nullptr}, Mixed_7c{nullptr};

  torch::nn::Linear fc{nullptr};

  _inceptionimpl::InceptionAux AuxLogits{nullptr};

115
  explicit InceptionV3Impl(
Shahriar's avatar
Shahriar committed
116
117
118
119
120
121
122
123
124
125
126
127
128
      int64_t num_classes = 1000,
      bool aux_logits = true,
      bool transform_input = false);

  InceptionV3Output forward(torch::Tensor x);
};

TORCH_MODULE(InceptionV3);

} // namespace models
} // namespace vision

#endif // INCEPTION_H