"vscode:/vscode.git/clone" did not exist on "2726b309dba058f427a0352f349d5901c9cf245b"
squeezenet.h 1.19 KB
Newer Older
Shahriar's avatar
Shahriar committed
1
2
3
4
#ifndef SQUEEZENET_H
#define SQUEEZENET_H

#include <torch/torch.h>
5
#include "general.h"
Shahriar's avatar
Shahriar committed
6
7
8

namespace vision {
namespace models {
9
struct VISION_API SqueezeNetImpl : torch::nn::Module {
Shahriar's avatar
Shahriar committed
10
11
12
  int64_t num_classes;
  torch::nn::Sequential features{nullptr}, classifier{nullptr};

13
  explicit SqueezeNetImpl(double version = 1.0, int64_t num_classes = 1000);
Shahriar's avatar
Shahriar committed
14
15
16
17
18
19
20

  torch::Tensor forward(torch::Tensor x);
};

// SqueezeNet model architecture from the "SqueezeNet: AlexNet-level
// accuracy with 50x fewer parameters and <0.5MB model size"
// <https://arxiv.org/abs/1602.07360> paper.
21
struct VISION_API SqueezeNet1_0Impl : SqueezeNetImpl {
22
  explicit SqueezeNet1_0Impl(int64_t num_classes = 1000);
Shahriar's avatar
Shahriar committed
23
24
25
26
27
28
};

// SqueezeNet 1.1 model from the official SqueezeNet repo
// <https://github.com/DeepScale/SqueezeNet/tree/master/SqueezeNet_v1.1>.
// SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters
// than SqueezeNet 1.0, without sacrificing accuracy.
29
struct VISION_API SqueezeNet1_1Impl : SqueezeNetImpl {
30
  explicit SqueezeNet1_1Impl(int64_t num_classes = 1000);
Shahriar's avatar
Shahriar committed
31
32
33
34
35
36
37
38
39
40
};

// Module holders: each TORCH_MODULE(X) declares the user-facing type `X`
// wrapping the corresponding `XImpl` declared above.

// Generic, version-parameterized network.
TORCH_MODULE(SqueezeNet);

// Fixed-variant networks.
TORCH_MODULE(SqueezeNet1_0);
TORCH_MODULE(SqueezeNet1_1);

} // namespace models
} // namespace vision

#endif // SQUEEZENET_H