// googlenet.h - declarations of the GoogLeNet model and its submodules.
#pragma once
#include <torch/nn.h>
#include "../macros.h"

namespace vision {
namespace models {

namespace _googlenetimpl {
10
struct VISION_API BasicConv2dImpl : torch::nn::Module {
Shahriar's avatar
Shahriar committed
11
  torch::nn::Conv2d conv{nullptr};
12
  torch::nn::BatchNorm2d bn{nullptr};
Shahriar's avatar
Shahriar committed
13

14
  explicit BasicConv2dImpl(torch::nn::Conv2dOptions options);
Shahriar's avatar
Shahriar committed
15
16
17
18
19
20

  torch::Tensor forward(torch::Tensor x);
};

TORCH_MODULE(BasicConv2d);

21
struct VISION_API InceptionImpl : torch::nn::Module {
Shahriar's avatar
Shahriar committed
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
  BasicConv2d branch1{nullptr};
  torch::nn::Sequential branch2, branch3, branch4;

  InceptionImpl(
      int64_t in_channels,
      int64_t ch1x1,
      int64_t ch3x3red,
      int64_t ch3x3,
      int64_t ch5x5red,
      int64_t ch5x5,
      int64_t pool_proj);

  torch::Tensor forward(torch::Tensor x);
};

TORCH_MODULE(Inception);

39
struct VISION_API InceptionAuxImpl : torch::nn::Module {
Shahriar's avatar
Shahriar committed
40
41
42
43
44
45
46
47
48
49
50
51
  BasicConv2d conv{nullptr};
  torch::nn::Linear fc1{nullptr}, fc2{nullptr};

  InceptionAuxImpl(int64_t in_channels, int64_t num_classes);

  torch::Tensor forward(torch::Tensor x);
};

TORCH_MODULE(InceptionAux);

} // namespace _googlenetimpl

52
struct VISION_API GoogLeNetOutput {
Shahriar's avatar
Shahriar committed
53
54
55
56
57
  torch::Tensor output;
  torch::Tensor aux1;
  torch::Tensor aux2;
};

58
struct VISION_API GoogLeNetImpl : torch::nn::Module {
Shahriar's avatar
Shahriar committed
59
60
61
62
63
64
65
66
67
68
69
70
71
72
  bool aux_logits, transform_input;

  _googlenetimpl::BasicConv2d conv1{nullptr}, conv2{nullptr}, conv3{nullptr};

  _googlenetimpl::Inception inception3a{nullptr}, inception3b{nullptr},
      inception4a{nullptr}, inception4b{nullptr}, inception4c{nullptr},
      inception4d{nullptr}, inception4e{nullptr}, inception5a{nullptr},
      inception5b{nullptr};

  _googlenetimpl::InceptionAux aux1{nullptr}, aux2{nullptr};

  torch::nn::Dropout dropout{nullptr};
  torch::nn::Linear fc{nullptr};

73
  explicit GoogLeNetImpl(
Shahriar's avatar
Shahriar committed
74
75
76
77
78
79
80
81
82
83
84
85
86
87
      int64_t num_classes = 1000,
      bool aux_logits = true,
      bool transform_input = false,
      bool init_weights = true);

  void _initialize_weights();

  GoogLeNetOutput forward(torch::Tensor x);
};

TORCH_MODULE(GoogLeNet);

} // namespace models
} // namespace vision