// mnasnet.h
// Last committed by Shahriar.
#ifndef MNASNET_H
#define MNASNET_H

#include <torch/torch.h>
#include "general.h"

namespace vision {
namespace models {
struct VISION_API MNASNetImpl : torch::nn::Module {
  torch::nn::Sequential layers, classifier;

  void _initialize_weights();

14
15
16
17
  explicit MNASNetImpl(
      double alpha,
      int64_t num_classes = 1000,
      double dropout = .2);
Shahriar's avatar
Shahriar committed
18
19
20
21
22

  torch::Tensor forward(torch::Tensor x);
};

struct MNASNet0_5Impl : MNASNetImpl {
23
  explicit MNASNet0_5Impl(int64_t num_classes = 1000, double dropout = .2);
Shahriar's avatar
Shahriar committed
24
25
26
};

struct MNASNet0_75Impl : MNASNetImpl {
27
  explicit MNASNet0_75Impl(int64_t num_classes = 1000, double dropout = .2);
Shahriar's avatar
Shahriar committed
28
29
30
};

struct MNASNet1_0Impl : MNASNetImpl {
31
  explicit MNASNet1_0Impl(int64_t num_classes = 1000, double dropout = .2);
Shahriar's avatar
Shahriar committed
32
33
34
};

struct MNASNet1_3Impl : MNASNetImpl {
35
  explicit MNASNet1_3Impl(int64_t num_classes = 1000, double dropout = .2);
Shahriar's avatar
Shahriar committed
36
37
38
39
40
41
42
43
44
45
46
47
};

// Public holder types (torch::nn::ModuleHolder wrappers) for each *Impl
// struct. TORCH_MODULE(Name) is libtorch shorthand for
// TORCH_MODULE_IMPL(Name, Name##Impl); the long form is written out here so
// the holder/implementation pairing is visible at a glance.
TORCH_MODULE_IMPL(MNASNet, MNASNetImpl);
TORCH_MODULE_IMPL(MNASNet0_5, MNASNet0_5Impl);
TORCH_MODULE_IMPL(MNASNet0_75, MNASNet0_75Impl);
TORCH_MODULE_IMPL(MNASNet1_0, MNASNet1_0Impl);
TORCH_MODULE_IMPL(MNASNet1_3, MNASNet1_3Impl);

} // namespace models
} // namespace vision

#endif // MNASNET_H