#pragma once

#include <torch/nn.h>
#include "../macros.h"

namespace vision {
namespace models {
8
struct VISION_API SqueezeNetImpl : torch::nn::Module {
Shahriar's avatar
Shahriar committed
9
10
11
  int64_t num_classes;
  torch::nn::Sequential features{nullptr}, classifier{nullptr};

12
  explicit SqueezeNetImpl(double version = 1.0, int64_t num_classes = 1000);
Shahriar's avatar
Shahriar committed
13
14
15
16
17
18
19

  torch::Tensor forward(torch::Tensor x);
};

// SqueezeNet model architecture from the "SqueezeNet: AlexNet-level
// accuracy with 50x fewer parameters and <0.5MB model size"
// <https://arxiv.org/abs/1602.07360> paper.
20
struct VISION_API SqueezeNet1_0Impl : SqueezeNetImpl {
21
  explicit SqueezeNet1_0Impl(int64_t num_classes = 1000);
Shahriar's avatar
Shahriar committed
22
23
24
25
26
27
};

// SqueezeNet 1.1 model from the official SqueezeNet repo
// <https://github.com/DeepScale/SqueezeNet/tree/master/SqueezeNet_v1.1>.
// SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters
// than SqueezeNet 1.0, without sacrificing accuracy.
28
struct VISION_API SqueezeNet1_1Impl : SqueezeNetImpl {
29
  explicit SqueezeNet1_1Impl(int64_t num_classes = 1000);
Shahriar's avatar
Shahriar committed
30
31
32
33
34
35
36
37
};

TORCH_MODULE(SqueezeNet);
TORCH_MODULE(SqueezeNet1_0);
TORCH_MODULE(SqueezeNet1_1);

} // namespace models
} // namespace vision