squeezenet.h 1.11 KB
Newer Older
Shahriar's avatar
Shahriar committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
#ifndef SQUEEZENET_H
#define SQUEEZENET_H

#include <torch/torch.h>

namespace vision {
namespace models {
struct SqueezeNetImpl : torch::nn::Module {
  int64_t num_classes;
  torch::nn::Sequential features{nullptr}, classifier{nullptr};

  SqueezeNetImpl(double version = 1.0, int64_t num_classes = 1000);

  torch::Tensor forward(torch::Tensor x);
};

// SqueezeNet model architecture from the "SqueezeNet: AlexNet-level
// accuracy with 50x fewer parameters and <0.5MB model size"
// <https://arxiv.org/abs/1602.07360> paper.
//
// @param num_classes  Number of output classes (default 1000).
//
// The constructor is `explicit` so an integer cannot implicitly convert into
// a model (e.g. when passed where a SqueezeNet1_0Impl is expected).
struct SqueezeNet1_0Impl : SqueezeNetImpl {
  explicit SqueezeNet1_0Impl(int64_t num_classes = 1000);
};

// SqueezeNet 1.1 model from the official SqueezeNet repo
// <https://github.com/DeepScale/SqueezeNet/tree/master/SqueezeNet_v1.1>.
// SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters
// than SqueezeNet 1.0, without sacrificing accuracy.
//
// @param num_classes  Number of output classes (default 1000).
//
// The constructor is `explicit` so an integer cannot implicitly convert into
// a model (e.g. when passed where a SqueezeNet1_1Impl is expected).
struct SqueezeNet1_1Impl : SqueezeNetImpl {
  explicit SqueezeNet1_1Impl(int64_t num_classes = 1000);
};

// Public handle types. By libtorch convention, TORCH_MODULE(Name) declares a
// module-holder class `Name` that wraps `NameImpl` (here SqueezeNetImpl,
// SqueezeNet1_0Impl and SqueezeNet1_1Impl) behind shared ownership; callers
// normally use these names rather than the *Impl structs directly.
TORCH_MODULE(SqueezeNet);
TORCH_MODULE(SqueezeNet1_0);
TORCH_MODULE(SqueezeNet1_1);

} // namespace models
} // namespace vision

#endif // SQUEEZENET_H