mobilenet.h 543 Bytes
Newer Older
1
#pragma once
Shahriar's avatar
Shahriar committed
2

3
#include <torch/nn.h>
4
#include "../macros.h"
Shahriar's avatar
Shahriar committed
5
6
7

namespace vision {
namespace models {
8
struct VISION_API MobileNetV2Impl : torch::nn::Module {
Shahriar's avatar
Shahriar committed
9
10
11
  int64_t last_channel;
  torch::nn::Sequential features, classifier;

12
  explicit MobileNetV2Impl(
13
14
15
16
      int64_t num_classes = 1000,
      double width_mult = 1.0,
      std::vector<std::vector<int64_t>> inverted_residual_settings = {},
      int64_t round_nearest = 8);
Shahriar's avatar
Shahriar committed
17
18
19
20
21
22
23

  torch::Tensor forward(torch::Tensor x);
};

// Declares `MobileNetV2` as the PyTorch C++ frontend module holder for
// `MobileNetV2Impl` (a torch::nn::ModuleHolder, i.e. a shared-pointer wrapper
// with reference semantics). Constructor arguments given to `MobileNetV2(...)`
// are forwarded to the `MobileNetV2Impl` constructor.
TORCH_MODULE(MobileNetV2);
} // namespace models
} // namespace vision