// shufflenetv2.h (last committed by Shahriar)
#ifndef SHUFFLENETV2_H
#define SHUFFLENETV2_H

#include <cstdint>
#include <vector>

#include <torch/torch.h>

#include "general.h"

namespace vision {
namespace models {

10
struct VISION_API ShuffleNetV2Impl : torch::nn::Module {
Shahriar's avatar
Shahriar committed
11
12
13
14
15
16
17
18
19
20
21
22
  std::vector<int64_t> _stage_out_channels;
  torch::nn::Sequential conv1{nullptr}, stage2, stage3, stage4, conv5{nullptr};
  torch::nn::Linear fc{nullptr};

  ShuffleNetV2Impl(
      const std::vector<int64_t>& stage_repeats,
      const std::vector<int64_t>& stage_out_channels,
      int64_t num_classes = 1000);

  torch::Tensor forward(torch::Tensor x);
};

23
struct VISION_API ShuffleNetV2_x0_5Impl : ShuffleNetV2Impl {
Shahriar's avatar
Shahriar committed
24
25
26
  ShuffleNetV2_x0_5Impl(int64_t num_classes = 1000);
};

27
struct VISION_API ShuffleNetV2_x1_0Impl : ShuffleNetV2Impl {
Shahriar's avatar
Shahriar committed
28
29
30
  ShuffleNetV2_x1_0Impl(int64_t num_classes = 1000);
};

31
struct VISION_API ShuffleNetV2_x1_5Impl : ShuffleNetV2Impl {
Shahriar's avatar
Shahriar committed
32
33
34
  ShuffleNetV2_x1_5Impl(int64_t num_classes = 1000);
};

35
struct VISION_API ShuffleNetV2_x2_0Impl : ShuffleNetV2Impl {
Shahriar's avatar
Shahriar committed
36
37
38
39
40
41
42
43
44
45
46
47
48
  ShuffleNetV2_x2_0Impl(int64_t num_classes = 1000);
};

// Module-holder types (generated by TORCH_MODULE from the matching *Impl
// structs above). The generic network comes first, then each width preset.
TORCH_MODULE(ShuffleNetV2);

TORCH_MODULE(ShuffleNetV2_x0_5);
TORCH_MODULE(ShuffleNetV2_x1_0);
TORCH_MODULE(ShuffleNetV2_x1_5);
TORCH_MODULE(ShuffleNetV2_x2_0);

} // namespace models
} // namespace vision

#endif // SHUFFLENETV2_H