googlenet.h 1.96 KB
Newer Older
Shahriar's avatar
Shahriar committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
#ifndef GOOGLENET_H
#define GOOGLENET_H

#include <torch/torch.h>

namespace vision {
namespace models {

namespace _googlenetimpl {
/// Convolution + batch-norm building block used by every other module in
/// this header (Inception branches, InceptionAux, and GoogLeNet itself).
///
/// Owns a `conv` submodule configured from the options passed to the
/// constructor and a `bn` submodule; both start as null holders and are
/// expected to be created by the constructor.
/// NOTE(review): the exact forward computation (e.g. whether an activation
/// follows the batch norm) is defined in googlenet.cpp — confirm there.
struct BasicConv2dImpl : torch::nn::Module {
  torch::nn::Conv2d conv{nullptr};
  // NOTE(review): `torch::nn::BatchNorm` is deprecated in newer libtorch
  // releases in favour of `torch::nn::BatchNorm2d`. Switching needs a matching
  // change in the .cpp (it takes BatchNorm2dOptions), so it is left as-is here.
  torch::nn::BatchNorm bn{nullptr};

  /// @param options Configuration for the convolution layer.
  /// `explicit` so a bare Conv2dOptions can never silently convert into a
  /// module; all construction (including via the BasicConv2d holder) is
  /// direct-initialization and is unaffected.
  explicit BasicConv2dImpl(torch::nn::Conv2dOptions options);

  /// Runs the block on `x` and returns the result.
  torch::Tensor forward(torch::Tensor x);
};

TORCH_MODULE(BasicConv2d);

/// A single Inception module made of four parallel branches.
///
/// The constructor receives the channel counts for every branch.
/// NOTE(review): presumably forward() concatenates the four branch outputs
/// along the channel dimension — confirm in googlenet.cpp.
struct InceptionImpl : torch::nn::Module {
  // First branch: one BasicConv2d, null until the constructor assigns it.
  BasicConv2d branch1{nullptr};

  // Remaining branches: default-constructed (empty but non-null) Sequential
  // containers, presumably populated by the constructor.
  torch::nn::Sequential branch2;
  torch::nn::Sequential branch3;
  torch::nn::Sequential branch4;

  /// @param in_channels Number of channels of the module's input.
  /// @param ch1x1       Channel count for the 1x1 branch.
  /// @param ch3x3red    Channel count feeding the 3x3 branch.
  /// @param ch3x3       Channel count for the 3x3 branch.
  /// @param ch5x5red    Channel count feeding the 5x5 branch.
  /// @param ch5x5       Channel count for the 5x5 branch.
  /// @param pool_proj   Channel count for the pooling branch.
  /// NOTE(review): the per-parameter roles above follow the parameter names;
  /// verify against the constructor definition in googlenet.cpp.
  InceptionImpl(
      int64_t in_channels,
      int64_t ch1x1,
      int64_t ch3x3red,
      int64_t ch3x3,
      int64_t ch5x5red,
      int64_t ch5x5,
      int64_t pool_proj);

  /// Applies the module to `x` and returns the combined branch output.
  torch::Tensor forward(torch::Tensor x);
};

TORCH_MODULE(Inception);

/// Auxiliary classifier head attached to intermediate GoogLeNet features.
///
/// Made of one BasicConv2d followed by two linear layers; all three start as
/// null holders that the constructor is expected to create.
struct InceptionAuxImpl : torch::nn::Module {
  BasicConv2d conv{nullptr};

  torch::nn::Linear fc1{nullptr};
  torch::nn::Linear fc2{nullptr};

  /// @param in_channels Number of channels of the incoming feature map.
  /// @param num_classes Number of output classes of the auxiliary head.
  InceptionAuxImpl(int64_t in_channels, int64_t num_classes);

  /// Produces the auxiliary prediction for feature map `x`.
  torch::Tensor forward(torch::Tensor x);
};

TORCH_MODULE(InceptionAux);

} // namespace _googlenetimpl

/// Result bundle returned by GoogLeNetImpl::forward.
///
/// Any field that forward() does not fill stays a default-constructed
/// (undefined) tensor.
/// NOTE(review): presumably aux1/aux2 are only populated when auxiliary
/// logits are enabled — confirm in googlenet.cpp.
struct GoogLeNetOutput {
  torch::Tensor output{}; // main classifier output
  torch::Tensor aux1{};   // first auxiliary head output
  torch::Tensor aux2{};   // second auxiliary head output
};

struct GoogLeNetImpl : torch::nn::Module {
  bool aux_logits, transform_input;

  _googlenetimpl::BasicConv2d conv1{nullptr}, conv2{nullptr}, conv3{nullptr};

  _googlenetimpl::Inception inception3a{nullptr}, inception3b{nullptr},
      inception4a{nullptr}, inception4b{nullptr}, inception4c{nullptr},
      inception4d{nullptr}, inception4e{nullptr}, inception5a{nullptr},
      inception5b{nullptr};

  _googlenetimpl::InceptionAux aux1{nullptr}, aux2{nullptr};

  torch::nn::Dropout dropout{nullptr};
  torch::nn::Linear fc{nullptr};

  GoogLeNetImpl(
      int64_t num_classes = 1000,
      bool aux_logits = true,
      bool transform_input = false,
      bool init_weights = true);

  void _initialize_weights();

  GoogLeNetOutput forward(torch::Tensor x);
};

TORCH_MODULE(GoogLeNet);

} // namespace models
} // namespace vision

#endif // GOOGLENET_H