// convert_models.cpp (committed by Shahriar)
#include <torch/script.h>
#include <torch/torch.h>

#include <exception>
#include <iostream>
#include <string>

#include "../models/models.h"

using namespace vision::models;

/// @brief Round-trips one model's serialized state through a C++ model type.
///
/// Default-constructs a `Model`, fills it from the archive at `input_path`
/// via `torch::load`, then writes it to `output_path` via `torch::save`.
/// Any exception raised by either call propagates to the caller.
///
/// @tparam Model  Default-constructible model type accepted by
///                `torch::load`/`torch::save`.
/// @param input_path   Archive to read (e.g. "alexnet_python.pt").
/// @param output_path  Archive to write (e.g. "alexnet_cpp.pt").
template <typename Model>
void convert_and_save_model(
    const std::string& input_path,
    const std::string& output_path) {
  Model net;
  torch::load(net, input_path);
  torch::save(net, output_path);

  // Label for the progress line: everything before the "_python" marker.
  // When the marker is missing, find() yields npos and the length argument is
  // clamped, so the whole input path is used.
  const std::string::size_type marker_pos = input_path.find("_python");
  const std::string label(input_path, 0, marker_pos);
  std::cout << "finished loading and saving " << label << std::endl;
}

int main(int argc, const char* argv[]) {
  convert_and_save_model<AlexNet>("alexnet_python.pt", "alexnet_cpp.pt");

  convert_and_save_model<VGG11>("vgg11_python.pt", "vgg11_cpp.pt");
  convert_and_save_model<VGG13>("vgg13_python.pt", "vgg13_cpp.pt");
  convert_and_save_model<VGG16>("vgg16_python.pt", "vgg16_cpp.pt");
  convert_and_save_model<VGG19>("vgg19_python.pt", "vgg19_cpp.pt");

  convert_and_save_model<VGG11BN>("vgg11bn_python.pt", "vgg11bn_cpp.pt");
  convert_and_save_model<VGG13BN>("vgg13bn_python.pt", "vgg13bn_cpp.pt");
  convert_and_save_model<VGG16BN>("vgg16bn_python.pt", "vgg16bn_cpp.pt");
  convert_and_save_model<VGG19BN>("vgg19bn_python.pt", "vgg19bn_cpp.pt");

  convert_and_save_model<ResNet18>("resnet18_python.pt", "resnet18_cpp.pt");
  convert_and_save_model<ResNet34>("resnet34_python.pt", "resnet34_cpp.pt");
  convert_and_save_model<ResNet50>("resnet50_python.pt", "resnet50_cpp.pt");
  convert_and_save_model<ResNet101>("resnet101_python.pt", "resnet101_cpp.pt");
  convert_and_save_model<ResNet152>("resnet152_python.pt", "resnet152_cpp.pt");
  convert_and_save_model<ResNext50_32x4d>(
      "resnext50_32x4d_python.pt", "resnext50_32x4d_cpp.pt");
  convert_and_save_model<ResNext101_32x8d>(
      "resnext101_32x8d_python.pt", "resnext101_32x8d_cpp.pt");
44
45
46
47
  convert_and_save_model<WideResNet50_2>(
      "wide_resnet50_2_python.pt", "wide_resnet50_2_cpp.pt");
  convert_and_save_model<WideResNet101_2>(
      "wide_resnet101_2_python.pt", "wide_resnet101_2_cpp.pt");
Shahriar's avatar
Shahriar committed
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78

  convert_and_save_model<SqueezeNet1_0>(
      "squeezenet1_0_python.pt", "squeezenet1_0_cpp.pt");
  convert_and_save_model<SqueezeNet1_1>(
      "squeezenet1_1_python.pt", "squeezenet1_1_cpp.pt");

  convert_and_save_model<DenseNet121>(
      "densenet121_python.pt", "densenet121_cpp.pt");
  convert_and_save_model<DenseNet169>(
      "densenet169_python.pt", "densenet169_cpp.pt");
  convert_and_save_model<DenseNet201>(
      "densenet201_python.pt", "densenet201_cpp.pt");
  convert_and_save_model<DenseNet161>(
      "densenet161_python.pt", "densenet161_cpp.pt");

  convert_and_save_model<MobileNetV2>(
      "mobilenetv2_python.pt", "mobilenetv2_cpp.pt");

  convert_and_save_model<ShuffleNetV2_x0_5>(
      "shufflenetv2_x0_5_python.pt", "shufflenetv2_x0_5_cpp.pt");
  convert_and_save_model<ShuffleNetV2_x1_0>(
      "shufflenetv2_x1_0_python.pt", "shufflenetv2_x1_0_cpp.pt");
  convert_and_save_model<ShuffleNetV2_x1_5>(
      "shufflenetv2_x1_5_python.pt", "shufflenetv2_x1_5_cpp.pt");
  convert_and_save_model<ShuffleNetV2_x2_0>(
      "shufflenetv2_x2_0_python.pt", "shufflenetv2_x2_0_cpp.pt");

  convert_and_save_model<GoogLeNet>("googlenet_python.pt", "googlenet_cpp.pt");
  convert_and_save_model<InceptionV3>(
      "inceptionv3_python.pt", "inceptionv3_cpp.pt");

Shahriar's avatar
Shahriar committed
79
80
81
82
83
84
85
86
87
  convert_and_save_model<MNASNet0_5>(
      "mnasnet0_5_python.pt", "mnasnet0_5_cpp.pt");
  convert_and_save_model<MNASNet0_75>(
      "mnasnet0_75_python.pt", "mnasnet0_75_cpp.pt");
  convert_and_save_model<MNASNet1_0>(
      "mnasnet1_0_python.pt", "mnasnet1_0_cpp.pt");
  convert_and_save_model<MNASNet1_3>(
      "mnasnet1_3_python.pt", "mnasnet1_3_cpp.pt");

Shahriar's avatar
Shahriar committed
88
89
  return 0;
}