// shufflenetv2.h
#ifndef SHUFFLENETV2_H
#define SHUFFLENETV2_H

#include <torch/torch.h>
5
#include "general.h"
Shahriar's avatar
Shahriar committed
6
7
8
9

namespace vision {
namespace models {

10
struct VISION_API ShuffleNetV2Impl : torch::nn::Module {
Shahriar's avatar
Shahriar committed
11
12
13
14
15
16
17
18
19
20
21
22
  std::vector<int64_t> _stage_out_channels;
  torch::nn::Sequential conv1{nullptr}, stage2, stage3, stage4, conv5{nullptr};
  torch::nn::Linear fc{nullptr};

  ShuffleNetV2Impl(
      const std::vector<int64_t>& stage_repeats,
      const std::vector<int64_t>& stage_out_channels,
      int64_t num_classes = 1000);

  torch::Tensor forward(torch::Tensor x);
};

23
struct VISION_API ShuffleNetV2_x0_5Impl : ShuffleNetV2Impl {
24
  explicit ShuffleNetV2_x0_5Impl(int64_t num_classes = 1000);
Shahriar's avatar
Shahriar committed
25
26
};

27
struct VISION_API ShuffleNetV2_x1_0Impl : ShuffleNetV2Impl {
28
  explicit ShuffleNetV2_x1_0Impl(int64_t num_classes = 1000);
Shahriar's avatar
Shahriar committed
29
30
};

31
struct VISION_API ShuffleNetV2_x1_5Impl : ShuffleNetV2Impl {
32
  explicit ShuffleNetV2_x1_5Impl(int64_t num_classes = 1000);
Shahriar's avatar
Shahriar committed
33
34
};

35
struct VISION_API ShuffleNetV2_x2_0Impl : ShuffleNetV2Impl {
36
  explicit ShuffleNetV2_x2_0Impl(int64_t num_classes = 1000);
Shahriar's avatar
Shahriar committed
37
38
39
40
41
42
43
44
45
46
47
48
};

// Public module holder types, generated by libtorch's TORCH_MODULE macro
// from the corresponding *Impl structs above.

// Generic network.
TORCH_MODULE(ShuffleNetV2);

// Fixed-width variants.
TORCH_MODULE(ShuffleNetV2_x0_5);
TORCH_MODULE(ShuffleNetV2_x1_0);
TORCH_MODULE(ShuffleNetV2_x1_5);
TORCH_MODULE(ShuffleNetV2_x2_0);

} // namespace models
} // namespace vision

#endif // SHUFFLENETV2_H