// mnasnet.h - declarations of the C++ (libtorch) MNASNet model family.
#pragma once
#include <torch/nn.h>
#include "../macros.h"

namespace vision {
namespace models {
struct VISION_API MNASNetImpl : torch::nn::Module {
  torch::nn::Sequential layers, classifier;

  void _initialize_weights();

13
14
15
16
  explicit MNASNetImpl(
      double alpha,
      int64_t num_classes = 1000,
      double dropout = .2);
Shahriar's avatar
Shahriar committed
17
18
19
20
21

  torch::Tensor forward(torch::Tensor x);
};

struct MNASNet0_5Impl : MNASNetImpl {
22
  explicit MNASNet0_5Impl(int64_t num_classes = 1000, double dropout = .2);
Shahriar's avatar
Shahriar committed
23
24
25
};

struct MNASNet0_75Impl : MNASNetImpl {
26
  explicit MNASNet0_75Impl(int64_t num_classes = 1000, double dropout = .2);
Shahriar's avatar
Shahriar committed
27
28
29
};

struct MNASNet1_0Impl : MNASNetImpl {
30
  explicit MNASNet1_0Impl(int64_t num_classes = 1000, double dropout = .2);
Shahriar's avatar
Shahriar committed
31
32
33
};

struct MNASNet1_3Impl : MNASNetImpl {
34
  explicit MNASNet1_3Impl(int64_t num_classes = 1000, double dropout = .2);
Shahriar's avatar
Shahriar committed
35
36
37
38
39
40
41
42
43
44
};

TORCH_MODULE(MNASNet);
TORCH_MODULE(MNASNet0_5);
TORCH_MODULE(MNASNet0_75);
TORCH_MODULE(MNASNet1_0);
TORCH_MODULE(MNASNet1_3);

} // namespace models
} // namespace vision