// mobilenet.h
#ifndef MOBILENET_H
#define MOBILENET_H

#include <torch/torch.h>
#include "general.h"

namespace vision {
namespace models {
struct VISION_API MobileNetV2Impl : torch::nn::Module {
Shahriar's avatar
Shahriar committed
10
11
12
  int64_t last_channel;
  torch::nn::Sequential features, classifier;

13
14
15
16
17
  MobileNetV2Impl(
      int64_t num_classes = 1000,
      double width_mult = 1.0,
      std::vector<std::vector<int64_t>> inverted_residual_settings = {},
      int64_t round_nearest = 8);
Shahriar's avatar
Shahriar committed
18
19
20
21
22
23
24
25
26

  torch::Tensor forward(torch::Tensor x);
};

// Declares the `MobileNetV2` module-holder type for `MobileNetV2Impl`,
// following libtorch's TORCH_MODULE naming convention (Name wraps NameImpl).
TORCH_MODULE(MobileNetV2);
} // namespace models
} // namespace vision

#endif // MOBILENET_H