// mobilenet.h
// Committed by Shahriar.
#ifndef MOBILENET_H
#define MOBILENET_H

#include <torch/torch.h>

namespace vision {
namespace models {
struct MobileNetV2Impl : torch::nn::Module {
  int64_t last_channel;
  torch::nn::Sequential features, classifier;

  MobileNetV2Impl(int64_t num_classes = 1000, double width_mult = 1.0);

  torch::Tensor forward(torch::Tensor x);
};

// TORCH_MODULE is libtorch's module-holder macro: per the libtorch
// documentation it declares the `MobileNetV2` handle type, a holder around a
// shared pointer to `MobileNetV2Impl`, so callers use `MobileNetV2` rather
// than the Impl type directly.
TORCH_MODULE(MobileNetV2);
} // namespace models
} // namespace vision

#endif // MOBILENET_H