verify_onnx.cpp 3.06 KB
Newer Older
Paul's avatar
Paul committed
1

Paul's avatar
Paul committed
2
#include <migraph/onnx.hpp>
Paul's avatar
Paul committed
3

Paul's avatar
Paul committed
4
#include <migraph/cpu/cpu_target.hpp>
Paul's avatar
Paul committed
5
6
#include <migraph/gpu/target.hpp>
#include <migraph/gpu/hip.hpp>
Paul's avatar
Paul committed
7
#include <migraph/generate.hpp>
8
#include <migraph/verify_args.hpp>
9
#include <migraph/instruction.hpp>
Paul's avatar
Paul committed
10

11
12
// Hash a value with the standard-library hasher for its type.
// run_cpu and run_gpu both pass a parameter's name through this and hand the
// result to migraph::generate_argument, so the two targets receive the same
// value for the same parameter (presumably used as a generation seed).
template <class T>
auto get_hash(const T& x)
{
    const std::hash<T> hasher{};
    return hasher(x);
}

Paul's avatar
Paul committed
17
template <class F>
18
19
20
migraph::argument run_cpu(F f)
{
    auto p = f();
Paul's avatar
Paul committed
21
    p.compile(migraph::cpu::cpu_target{});
Paul's avatar
Paul committed
22
23
24
    migraph::program::parameter_map m;
    for(auto&& x : p.get_parameter_shapes())
    {
Paul's avatar
Paul committed
25
        m[x.first] = migraph::generate_argument(x.second, get_hash(x.first));
Paul's avatar
Paul committed
26
27
    }
    auto out = p.eval(m);
Paul's avatar
Paul committed
28
29
30
31
    std::cout << p << std::endl;
    return out;
}

Paul's avatar
Paul committed
32
template <class F>
33
migraph::argument run_gpu(F f)
Paul's avatar
Paul committed
34
{
35
    auto p = f();
Paul's avatar
Paul committed
36
    p.compile(migraph::gpu::target{});
Paul's avatar
Paul committed
37

Paul's avatar
Paul committed
38
39
40
    migraph::program::parameter_map m;
    for(auto&& x : p.get_parameter_shapes())
    {
41
        m[x.first] = migraph::gpu::to_gpu(migraph::generate_argument(x.second, get_hash(x.first)));
Paul's avatar
Paul committed
42
43
    }
    auto out = migraph::gpu::from_gpu(p.eval(m));
Paul's avatar
Paul committed
44
    std::cout << p << std::endl;
Paul's avatar
Paul committed
45
    return migraph::gpu::from_gpu(out);
Paul's avatar
Paul committed
46
47
}

Paul's avatar
Paul committed
48
49
template <class F>
void verify_program(const std::string& name, F f, double tolerance = 100)
50
51
52
53
54
55
{
    auto x = run_cpu(f);
    auto y = run_gpu(f);
    migraph::verify_args(name, x, y, tolerance);
}

Paul's avatar
Paul committed
56
// Verify `prog` one instruction at a time.
// For each instruction that is not skipped, a single-instruction program is
// built and compared between CPU and GPU via verify_program:
//  - "@literal" arguments become literals;
//  - every other argument becomes a parameter named after its operand index.
// Instructions whose op name starts with '@', and the ops broadcast, transpose
// and reshape, are skipped. The reason for the skip list is not stated in this
// file (NOTE(review): presumably they only alter shape/layout metadata).
// If a check throws, the offending op name is printed and the exception is
// rethrown.
void verify_instructions(const migraph::program& prog, double tolerance = 80)
{
    for(auto&& ins : prog)
    {
        const auto op_name = ins.op.name();

        const bool skipped = op_name.front() == '@' or op_name == "broadcast" or
                             op_name == "transpose" or op_name == "reshape";
        if(skipped)
            continue;

        auto create_program = [&] {
            migraph::program single_op;
            std::vector<migraph::instruction_ref> operands;
            for(auto&& arg : ins.arguments)
            {
                if(arg->op.name() != "@literal")
                {
                    const auto param_name = std::to_string(operands.size());
                    operands.push_back(single_op.add_parameter(param_name, arg->get_shape()));
                }
                else
                {
                    operands.push_back(single_op.add_literal(arg->lit));
                }
            }
            single_op.add_instruction(ins.op, operands);
            return single_op;
        };

        try
        {
            std::cout << "Verify: " << op_name << std::endl;
            std::cout << create_program() << std::endl;
            verify_program(op_name, create_program, tolerance);
        }
        catch(...)
        {
            std::cout << "Instruction " << op_name << " threw an exception." << std::endl;
            throw;
        }
    }
}

Paul's avatar
Paul committed
96
97
int main(int argc, char const* argv[])
{
Paul's avatar
Paul committed
98
    std::vector<std::string> args(argv + 1, argv + argc);
99
    if(not args.empty())
Paul's avatar
Paul committed
100
    {
101
        std::string file = args.front();
Paul's avatar
Paul committed
102
        auto p           = migraph::parse_onnx(file);
Paul's avatar
Paul committed
103
104
        std::cout << p << std::endl;

105
106
107
108
109
110
111
112
        if(std::any_of(args.begin(), args.end(), [](const auto& s) { return s == "-i"; }))
        {
            verify_instructions(p);
        }
        else
        {
            verify_program(file, [&] { return migraph::parse_onnx(file); });
        }
Paul's avatar
Paul committed
113
114
    }
}