// squeezenet.h
#ifndef SQUEEZENET_H
#define SQUEEZENET_H

#include <torch/torch.h>
#include "general.h"

namespace vision {
namespace models {
9
struct VISION_API SqueezeNetImpl : torch::nn::Module {
Shahriar's avatar
Shahriar committed
10
11
12
13
14
15
16
17
18
19
20
  int64_t num_classes;
  torch::nn::Sequential features{nullptr}, classifier{nullptr};

  SqueezeNetImpl(double version = 1.0, int64_t num_classes = 1000);

  torch::Tensor forward(torch::Tensor x);
};

// SqueezeNet model architecture from the "SqueezeNet: AlexNet-level
// accuracy with 50x fewer parameters and <0.5MB model size"
// <https://arxiv.org/abs/1602.07360> paper.
21
struct VISION_API SqueezeNet1_0Impl : SqueezeNetImpl {
Shahriar's avatar
Shahriar committed
22
23
24
25
26
27
28
  SqueezeNet1_0Impl(int64_t num_classes = 1000);
};

// SqueezeNet 1.1 model from the official SqueezeNet repo
// <https://github.com/DeepScale/SqueezeNet/tree/master/SqueezeNet_v1.1>.
// SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters
// than SqueezeNet 1.0, without sacrificing accuracy.
29
struct VISION_API SqueezeNet1_1Impl : SqueezeNetImpl {
Shahriar's avatar
Shahriar committed
30
31
32
33
34
35
36
37
38
39
40
  SqueezeNet1_1Impl(int64_t num_classes = 1000);
};

// Module-holder wrappers. Each line is the explicit form of TORCH_MODULE(Name)
// (which expands to TORCH_MODULE_IMPL(Name, Name##Impl)) and declares a
// torch::nn::ModuleHolder subclass that owns the matching *Impl module.
TORCH_MODULE_IMPL(SqueezeNet, SqueezeNetImpl);
TORCH_MODULE_IMPL(SqueezeNet1_0, SqueezeNet1_0Impl);
TORCH_MODULE_IMPL(SqueezeNet1_1, SqueezeNet1_1Impl);

} // namespace models
} // namespace vision

#endif // SQUEEZENET_H