// resnet18.cpp
#include <cstdio>
#include <exception>
#include <fstream>
#include <numeric>
#include <stdexcept>
#include <string>

#include <migraph/generate.hpp>
#include <migraph/onnx.hpp>

#include <migraph/gpu/hip.hpp>
#include <migraph/gpu/target.hpp>

/// Driver: parse the ONNX model whose path is given as the first command-line
/// argument, compile it for the MIGraph GPU target, and evaluate it once on a
/// generated input tensor (fixed seed 12345).
///
/// @param argc number of command-line arguments; at least 2 are required.
/// @param argv argv[1] is the path to the ONNX model file.
/// @return 0 on success; 1 on missing arguments or if parsing, compilation,
///         or evaluation throws a std::exception.
int main(int argc, char const* argv[])
{
    // Without this guard a missing argument makes argv[1] the terminating
    // null pointer, and constructing a std::string from it is undefined.
    if(argc < 2)
    {
        std::fprintf(stderr, "usage: %s <model.onnx>\n", argc > 0 ? argv[0] : "resnet18");
        return 1;
    }

    try
    {
        std::string file = argv[1];
        auto prog        = migraph::parse_onnx(file);

        // GPU target
        prog.compile(migraph::gpu::target{});

        // NOTE(review): the input shape is hard-coded; confirm it matches the
        // shape the model expects for its "0" parameter.
        const auto input_shape = migraph::shape{migraph::shape::float_type, {1, 3, 32, 32}};

        migraph::program::parameter_map m;
        // The compiled program is queried for an "output" parameter shape;
        // a generated argument of that shape is copied to the device for it.
        m["output"] =
            migraph::gpu::to_gpu(migraph::generate_argument(prog.get_parameter_shape("output")));
        // Input "0": generated with a fixed seed so repeated runs see the same data.
        m["0"] = migraph::gpu::to_gpu(migraph::generate_argument(input_shape, 12345));

        // The result is copied back to the host but not otherwise inspected.
        auto result = migraph::gpu::from_gpu(prog.eval(m));
        (void)result;

        // // CPU target (would also need the CPU target header, which is not
        // // included in this file)
        // prog.compile(migraph::cpu::cpu_target{});
        // auto s = migraph::shape{migraph::shape::float_type, {1, 3, 32, 32}};
        // auto input3 = migraph::generate_argument(s, 12345);
        // auto result = prog.eval({{"0", input3}});
    }
    catch(const std::exception& e)
    {
        // Report the failure instead of letting the exception reach std::terminate.
        std::fprintf(stderr, "error: %s\n", e.what());
        return 1;
    }
    return 0;
}