// main.cpp
#include <exception>
#include <iostream>
#include <string>
#include <vector>

#include <torch/script.h>
#include <torch/torch.h>
#include <torchvision/vision.h>

int main() {
  torch::DeviceType device_type;
  device_type = torch::kCPU;
9

10
11
12
13
14
15
16
17
18
19
20
21
22
  torch::jit::script::Module model;
  try {
    std::cout << "Loading model\n";
    // Deserialize the ScriptModule from a file using torch::jit::load().
    model = torch::jit::load("resnet18.pt");
    std::cout << "Model loaded\n";
  } catch (const torch::Error& e) {
    std::cout << "error loading the model\n";
    return -1;
  } catch (const std::exception& e) {
    std::cout << "Other error: " << e.what() << "\n";
    return -1;
  }
23

24
25
26
27
28
29
30
  // TorchScript models require a List[IValue] as input
  std::vector<torch::jit::IValue> inputs;

  // Create a random input tensor and run it through the model.
  inputs.push_back(torch::rand({1, 3, 10, 10}));
  auto out = model.forward(inputs);
  std::cout << out << "\n";
31
32
33

  if (torch::cuda::is_available()) {
    // Move model and inputs to GPU
34
35
36
37
38
39
    model.to(torch::kCUDA);

    // Add GPU inputs
    inputs.clear();
    torch::TensorOptions options = torch::TensorOptions{torch::kCUDA};
    inputs.push_back(torch::rand({1, 3, 10, 10}, options));
40

41
42
    auto gpu_out = model.forward(inputs);
    std::cout << gpu_out << "\n";
43
  }
44
}