// inception.h — declarations of the Inception v3 model and its building blocks.
#pragma once

#include <torch/nn.h>
#include "../macros.h"

namespace vision {
namespace models {
namespace _inceptionimpl {
9
struct VISION_API BasicConv2dImpl : torch::nn::Module {
Shahriar's avatar
Shahriar committed
10
  torch::nn::Conv2d conv{nullptr};
11
  torch::nn::BatchNorm2d bn{nullptr};
Shahriar's avatar
Shahriar committed
12

13
14
15
  explicit BasicConv2dImpl(
      torch::nn::Conv2dOptions options,
      double std_dev = 0.1);
Shahriar's avatar
Shahriar committed
16
17
18
19
20
21

  torch::Tensor forward(torch::Tensor x);
};

TORCH_MODULE(BasicConv2d);

22
struct VISION_API InceptionAImpl : torch::nn::Module {
Shahriar's avatar
Shahriar committed
23
24
25
26
27
  BasicConv2d branch1x1, branch5x5_1, branch5x5_2, branch3x3dbl_1,
      branch3x3dbl_2, branch3x3dbl_3, branch_pool;

  InceptionAImpl(int64_t in_channels, int64_t pool_features);

28
  torch::Tensor forward(const torch::Tensor& x);
Shahriar's avatar
Shahriar committed
29
30
};

31
struct VISION_API InceptionBImpl : torch::nn::Module {
Shahriar's avatar
Shahriar committed
32
33
  BasicConv2d branch3x3, branch3x3dbl_1, branch3x3dbl_2, branch3x3dbl_3;

34
  explicit InceptionBImpl(int64_t in_channels);
Shahriar's avatar
Shahriar committed
35

36
  torch::Tensor forward(const torch::Tensor& x);
Shahriar's avatar
Shahriar committed
37
38
};

39
struct VISION_API InceptionCImpl : torch::nn::Module {
Shahriar's avatar
Shahriar committed
40
41
42
43
44
45
46
  BasicConv2d branch1x1{nullptr}, branch7x7_1{nullptr}, branch7x7_2{nullptr},
      branch7x7_3{nullptr}, branch7x7dbl_1{nullptr}, branch7x7dbl_2{nullptr},
      branch7x7dbl_3{nullptr}, branch7x7dbl_4{nullptr}, branch7x7dbl_5{nullptr},
      branch_pool{nullptr};

  InceptionCImpl(int64_t in_channels, int64_t channels_7x7);

47
  torch::Tensor forward(const torch::Tensor& x);
Shahriar's avatar
Shahriar committed
48
49
};

50
struct VISION_API InceptionDImpl : torch::nn::Module {
Shahriar's avatar
Shahriar committed
51
52
53
  BasicConv2d branch3x3_1, branch3x3_2, branch7x7x3_1, branch7x7x3_2,
      branch7x7x3_3, branch7x7x3_4;

54
  explicit InceptionDImpl(int64_t in_channels);
Shahriar's avatar
Shahriar committed
55

56
  torch::Tensor forward(const torch::Tensor& x);
Shahriar's avatar
Shahriar committed
57
58
};

59
struct VISION_API InceptionEImpl : torch::nn::Module {
Shahriar's avatar
Shahriar committed
60
61
62
63
  BasicConv2d branch1x1, branch3x3_1, branch3x3_2a, branch3x3_2b,
      branch3x3dbl_1, branch3x3dbl_2, branch3x3dbl_3a, branch3x3dbl_3b,
      branch_pool;

64
  explicit InceptionEImpl(int64_t in_channels);
Shahriar's avatar
Shahriar committed
65

66
  torch::Tensor forward(const torch::Tensor& x);
Shahriar's avatar
Shahriar committed
67
68
};

69
struct VISION_API InceptionAuxImpl : torch::nn::Module {
Shahriar's avatar
Shahriar committed
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
  BasicConv2d conv0;
  BasicConv2d conv1;
  torch::nn::Linear fc;

  InceptionAuxImpl(int64_t in_channels, int64_t num_classes);

  torch::Tensor forward(torch::Tensor x);
};

// ModuleHolder wrappers for the Inception building blocks. InceptionV3Impl
// stores its sub-modules through these holder types.
TORCH_MODULE(InceptionA);
TORCH_MODULE(InceptionB);
TORCH_MODULE(InceptionC);
TORCH_MODULE(InceptionD);
TORCH_MODULE(InceptionE);

// Holder for the auxiliary classifier head.
TORCH_MODULE(InceptionAux);

} // namespace _inceptionimpl

88
struct VISION_API InceptionV3Output {
Shahriar's avatar
Shahriar committed
89
90
91
92
93
94
95
  torch::Tensor output;
  torch::Tensor aux;
};

// Inception v3 model architecture from
//"Rethinking the Inception Architecture for Computer Vision"
//<http://arxiv.org/abs/1512.00567>
96
struct VISION_API InceptionV3Impl : torch::nn::Module {
Shahriar's avatar
Shahriar committed
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
  bool aux_logits, transform_input;

  _inceptionimpl::BasicConv2d Conv2d_1a_3x3{nullptr}, Conv2d_2a_3x3{nullptr},
      Conv2d_2b_3x3{nullptr}, Conv2d_3b_1x1{nullptr}, Conv2d_4a_3x3{nullptr};

  _inceptionimpl::InceptionA Mixed_5b{nullptr}, Mixed_5c{nullptr},
      Mixed_5d{nullptr};
  _inceptionimpl::InceptionB Mixed_6a{nullptr};
  _inceptionimpl::InceptionC Mixed_6b{nullptr}, Mixed_6c{nullptr},
      Mixed_6d{nullptr}, Mixed_6e{nullptr};
  _inceptionimpl::InceptionD Mixed_7a{nullptr};
  _inceptionimpl::InceptionE Mixed_7b{nullptr}, Mixed_7c{nullptr};

  torch::nn::Linear fc{nullptr};

  _inceptionimpl::InceptionAux AuxLogits{nullptr};

114
  explicit InceptionV3Impl(
Shahriar's avatar
Shahriar committed
115
116
117
118
119
120
121
122
123
124
125
      int64_t num_classes = 1000,
      bool aux_logits = true,
      bool transform_input = false);

  InceptionV3Output forward(torch::Tensor x);
};

TORCH_MODULE(InceptionV3);

} // namespace models
} // namespace vision