// vgg.h — declarations of the VGG model family (VGG11/13/16/19, with and without batch normalization).
#pragma once
#include <torch/nn.h>
#include "../macros.h"
namespace vision {
namespace models {
8
struct VISION_API VGGImpl : torch::nn::Module {
Shahriar's avatar
Shahriar committed
9
10
11
12
  torch::nn::Sequential features{nullptr}, classifier{nullptr};

  void _initialize_weights();

13
  explicit VGGImpl(
14
      const torch::nn::Sequential& features,
Shahriar's avatar
Shahriar committed
15
16
17
18
19
20
21
      int64_t num_classes = 1000,
      bool initialize_weights = true);

  torch::Tensor forward(torch::Tensor x);
};

// VGG 11-layer model (configuration "A")
22
struct VISION_API VGG11Impl : VGGImpl {
23
24
25
  explicit VGG11Impl(
      int64_t num_classes = 1000,
      bool initialize_weights = true);
Shahriar's avatar
Shahriar committed
26
27
28
};

// VGG 13-layer model (configuration "B")
29
struct VISION_API VGG13Impl : VGGImpl {
30
31
32
  explicit VGG13Impl(
      int64_t num_classes = 1000,
      bool initialize_weights = true);
Shahriar's avatar
Shahriar committed
33
34
35
};

// VGG 16-layer model (configuration "D")
36
struct VISION_API VGG16Impl : VGGImpl {
37
38
39
  explicit VGG16Impl(
      int64_t num_classes = 1000,
      bool initialize_weights = true);
Shahriar's avatar
Shahriar committed
40
41
42
};

// VGG 19-layer model (configuration "E")
43
struct VISION_API VGG19Impl : VGGImpl {
44
45
46
  explicit VGG19Impl(
      int64_t num_classes = 1000,
      bool initialize_weights = true);
Shahriar's avatar
Shahriar committed
47
48
49
};

// VGG 11-layer model (configuration "A") with batch normalization
50
struct VISION_API VGG11BNImpl : VGGImpl {
51
52
53
  explicit VGG11BNImpl(
      int64_t num_classes = 1000,
      bool initialize_weights = true);
Shahriar's avatar
Shahriar committed
54
55
56
};

// VGG 13-layer model (configuration "B") with batch normalization
57
struct VISION_API VGG13BNImpl : VGGImpl {
58
59
60
  explicit VGG13BNImpl(
      int64_t num_classes = 1000,
      bool initialize_weights = true);
Shahriar's avatar
Shahriar committed
61
62
63
};

// VGG 16-layer model (configuration "D") with batch normalization
64
struct VISION_API VGG16BNImpl : VGGImpl {
65
66
67
  explicit VGG16BNImpl(
      int64_t num_classes = 1000,
      bool initialize_weights = true);
Shahriar's avatar
Shahriar committed
68
69
70
};

// VGG 19-layer model (configuration 'E') with batch normalization
71
struct VISION_API VGG19BNImpl : VGGImpl {
72
73
74
  explicit VGG19BNImpl(
      int64_t num_classes = 1000,
      bool initialize_weights = true);
Shahriar's avatar
Shahriar committed
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
};

// Module-holder declarations: TORCH_MODULE(X) declares the holder type X for
// the implementation struct XImpl (libtorch convention).
TORCH_MODULE(VGG);

// Each depth is listed with its batch-normalized counterpart.
TORCH_MODULE(VGG11);
TORCH_MODULE(VGG11BN);

TORCH_MODULE(VGG13);
TORCH_MODULE(VGG13BN);

TORCH_MODULE(VGG16);
TORCH_MODULE(VGG16BN);

TORCH_MODULE(VGG19);
TORCH_MODULE(VGG19BN);

} // namespace models
} // namespace vision