// googlenet.h: declarations of the GoogLeNet model and its building blocks
// (BasicConv2d, Inception, InceptionAux) in the vision::models namespace.
#ifndef GOOGLENET_H
#define GOOGLENET_H

#include <torch/torch.h>
#include "general.h"

namespace vision {
namespace models {

namespace _googlenetimpl {
11
struct VISION_API BasicConv2dImpl : torch::nn::Module {
Shahriar's avatar
Shahriar committed
12
  torch::nn::Conv2d conv{nullptr};
13
  torch::nn::BatchNorm2d bn{nullptr};
Shahriar's avatar
Shahriar committed
14

15
  explicit BasicConv2dImpl(torch::nn::Conv2dOptions options);
Shahriar's avatar
Shahriar committed
16
17
18
19
20
21

  torch::Tensor forward(torch::Tensor x);
};

TORCH_MODULE(BasicConv2d);

22
struct VISION_API InceptionImpl : torch::nn::Module {
Shahriar's avatar
Shahriar committed
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
  BasicConv2d branch1{nullptr};
  torch::nn::Sequential branch2, branch3, branch4;

  InceptionImpl(
      int64_t in_channels,
      int64_t ch1x1,
      int64_t ch3x3red,
      int64_t ch3x3,
      int64_t ch5x5red,
      int64_t ch5x5,
      int64_t pool_proj);

  torch::Tensor forward(torch::Tensor x);
};

TORCH_MODULE(Inception);

40
struct VISION_API InceptionAuxImpl : torch::nn::Module {
Shahriar's avatar
Shahriar committed
41
42
43
44
45
46
47
48
49
50
51
52
  BasicConv2d conv{nullptr};
  torch::nn::Linear fc1{nullptr}, fc2{nullptr};

  InceptionAuxImpl(int64_t in_channels, int64_t num_classes);

  torch::Tensor forward(torch::Tensor x);
};

TORCH_MODULE(InceptionAux);

} // namespace _googlenetimpl

53
struct VISION_API GoogLeNetOutput {
Shahriar's avatar
Shahriar committed
54
55
56
57
58
  torch::Tensor output;
  torch::Tensor aux1;
  torch::Tensor aux2;
};

59
struct VISION_API GoogLeNetImpl : torch::nn::Module {
Shahriar's avatar
Shahriar committed
60
61
62
63
64
65
66
67
68
69
70
71
72
73
  bool aux_logits, transform_input;

  _googlenetimpl::BasicConv2d conv1{nullptr}, conv2{nullptr}, conv3{nullptr};

  _googlenetimpl::Inception inception3a{nullptr}, inception3b{nullptr},
      inception4a{nullptr}, inception4b{nullptr}, inception4c{nullptr},
      inception4d{nullptr}, inception4e{nullptr}, inception5a{nullptr},
      inception5b{nullptr};

  _googlenetimpl::InceptionAux aux1{nullptr}, aux2{nullptr};

  torch::nn::Dropout dropout{nullptr};
  torch::nn::Linear fc{nullptr};

74
  explicit GoogLeNetImpl(
Shahriar's avatar
Shahriar committed
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
      int64_t num_classes = 1000,
      bool aux_logits = true,
      bool transform_input = false,
      bool init_weights = true);

  void _initialize_weights();

  GoogLeNetOutput forward(torch::Tensor x);
};

TORCH_MODULE(GoogLeNet);

} // namespace models
} // namespace vision

#endif // GOOGLENET_H