// shufflenetv2.h: ShuffleNetV2 model declarations (vision::models).
#pragma once
#include <torch/nn.h>
#include "../macros.h"

namespace vision {
namespace models {

struct VISION_API ShuffleNetV2Impl : torch::nn::Module {
Shahriar's avatar
Shahriar committed
10
11
12
13
14
15
16
17
18
19
20
21
  std::vector<int64_t> _stage_out_channels;
  torch::nn::Sequential conv1{nullptr}, stage2, stage3, stage4, conv5{nullptr};
  torch::nn::Linear fc{nullptr};

  ShuffleNetV2Impl(
      const std::vector<int64_t>& stage_repeats,
      const std::vector<int64_t>& stage_out_channels,
      int64_t num_classes = 1000);

  torch::Tensor forward(torch::Tensor x);
};

struct VISION_API ShuffleNetV2_x0_5Impl : ShuffleNetV2Impl {
23
  explicit ShuffleNetV2_x0_5Impl(int64_t num_classes = 1000);
Shahriar's avatar
Shahriar committed
24
25
};

struct VISION_API ShuffleNetV2_x1_0Impl : ShuffleNetV2Impl {
27
  explicit ShuffleNetV2_x1_0Impl(int64_t num_classes = 1000);
Shahriar's avatar
Shahriar committed
28
29
};

struct VISION_API ShuffleNetV2_x1_5Impl : ShuffleNetV2Impl {
31
  explicit ShuffleNetV2_x1_5Impl(int64_t num_classes = 1000);
Shahriar's avatar
Shahriar committed
32
33
};

struct VISION_API ShuffleNetV2_x2_0Impl : ShuffleNetV2Impl {
35
  explicit ShuffleNetV2_x2_0Impl(int64_t num_classes = 1000);
Shahriar's avatar
Shahriar committed
36
37
38
39
40
41
42
43
44
45
};

// User-facing module holders. TORCH_MODULE(Name) defines `Name` as a
// torch::nn::ModuleHolder around the corresponding `NameImpl` struct.

// Caller-configured network.
TORCH_MODULE(ShuffleNetV2);

// Preset variants; each Impl constructor takes only `num_classes`.
TORCH_MODULE(ShuffleNetV2_x0_5);
TORCH_MODULE(ShuffleNetV2_x1_0);
TORCH_MODULE(ShuffleNetV2_x1_5);
TORCH_MODULE(ShuffleNetV2_x2_0);

} // namespace models
} // namespace vision