verify_onnx.cpp 4 KB
Newer Older
Paul's avatar
Paul committed
1

Paul's avatar
Paul committed
2
#include <migraph/onnx.hpp>
Paul's avatar
Paul committed
3

Shucai Xiao's avatar
Shucai Xiao committed
4
#include <migraph/cpu/target.hpp>
Paul's avatar
Paul committed
5
6
#include <migraph/gpu/target.hpp>
#include <migraph/gpu/hip.hpp>
Paul's avatar
Paul committed
7
#include <migraph/generate.hpp>
8
#include <migraph/verify_args.hpp>
9
#include <migraph/instruction.hpp>
Paul's avatar
Paul committed
10

11
12
// Hash any std::hash-able value. Used to derive a deterministic seed from a
// parameter name so repeated runs generate the same input data.
template <class T>
auto get_hash(const T& x)
{
    std::hash<T> hasher{};
    return hasher(x);
}

Paul's avatar
Paul committed
17
template <class F>
18
19
20
migraph::argument run_cpu(F f)
{
    auto p = f();
Shucai Xiao's avatar
Shucai Xiao committed
21
    p.compile(migraph::cpu::target{});
Paul's avatar
Paul committed
22
23
24
    migraph::program::parameter_map m;
    for(auto&& x : p.get_parameter_shapes())
    {
Paul's avatar
Paul committed
25
        m[x.first] = migraph::generate_argument(x.second, get_hash(x.first));
Paul's avatar
Paul committed
26
27
    }
    auto out = p.eval(m);
Paul's avatar
Paul committed
28
29
30
31
    std::cout << p << std::endl;
    return out;
}

Paul's avatar
Paul committed
32
template <class F>
33
migraph::argument run_gpu(F f)
Paul's avatar
Paul committed
34
{
35
    auto p = f();
Paul's avatar
Paul committed
36
    p.compile(migraph::gpu::target{});
Paul's avatar
Paul committed
37

Paul's avatar
Paul committed
38
39
40
    migraph::program::parameter_map m;
    for(auto&& x : p.get_parameter_shapes())
    {
41
        m[x.first] = migraph::gpu::to_gpu(migraph::generate_argument(x.second, get_hash(x.first)));
Paul's avatar
Paul committed
42
43
    }
    auto out = migraph::gpu::from_gpu(p.eval(m));
Paul's avatar
Paul committed
44
    std::cout << p << std::endl;
Paul's avatar
Paul committed
45
    return migraph::gpu::from_gpu(out);
Paul's avatar
Paul committed
46
47
}

Paul's avatar
Paul committed
48
49
template <class F>
void verify_program(const std::string& name, F f, double tolerance = 100)
50
51
52
53
{
    auto x = run_cpu(f);
    auto y = run_gpu(f);
    migraph::verify_args(name, x, y, tolerance);
Paul's avatar
Paul committed
54
55
    // std::cout << "cpu: " << x << std::endl;
    // std::cout << "gpu: " << y << std::endl;
56
57
}

Paul's avatar
Paul committed
58
// Verify each instruction of `prog` in isolation, CPU vs GPU.
//
// Instructions whose name starts with '@', and the broadcast, transpose and
// reshape ops, are skipped. Each remaining instruction is wrapped in a
// one-instruction program: literal inputs are copied in as literals, and every
// other input becomes a parameter named after its input position ("0", "1", ...).
// If verification throws, the offending instruction's name is printed and the
// exception is rethrown.
void verify_instructions(const migraph::program& prog, double tolerance = 80)
{
    for(auto&& ins : prog)
    {
        const auto& op_name = ins.name();
        const bool skipped  = op_name.front() == '@' or op_name == "broadcast" or
                             op_name == "transpose" or op_name == "reshape";
        if(skipped)
            continue;

        auto create_program = [&] {
            migraph::program single;
            std::vector<migraph::instruction_ref> operands;
            for(auto&& arg : ins.inputs())
            {
                if(arg->name() != "@literal")
                {
                    operands.push_back(
                        single.add_parameter(std::to_string(operands.size()), arg->get_shape()));
                }
                else
                {
                    operands.push_back(single.add_literal(arg->get_literal()));
                }
            }
            single.add_instruction(ins.get_operator(), operands);
            return single;
        };

        try
        {
            std::cout << "Verify: " << ins.name() << std::endl;
            std::cout << create_program() << std::endl;
            verify_program(ins.name(), create_program, tolerance);
        }
        catch(...)
        {
            std::cout << "Instruction " << ins.name() << " threw an exception." << std::endl;
            throw;
        }
    }
}

Paul's avatar
Paul committed
98
template <class F>
Paul's avatar
Paul committed
99
100
void verify_reduced(F f, int n, double tolerance = 80)
{
Paul's avatar
Paul committed
101

Paul's avatar
Paul committed
102
103
    auto create_program = [&] {
        migraph::program p = f();
Paul's avatar
Paul committed
104
        auto last          = std::prev(p.end(), n + 1);
Paul's avatar
Paul committed
105
106
107
108
109
110
111
112
        p.remove_instructions(last, p.end());
        return p;
    };
    std::cout << "Verify: " << std::endl;
    std::cout << create_program() << std::endl;
    verify_program(std::to_string(n), create_program, tolerance);
}

Paul's avatar
Paul committed
113
template <class F>
Paul's avatar
Paul committed
114
115
116
void verify_reduced_program(F f, double tolerance = 80)
{
    migraph::program p = f();
Paul's avatar
Paul committed
117
118
    auto n             = std::distance(p.begin(), p.end());
    for(int i = 0; i < n; i++)
Paul's avatar
Paul committed
119
120
121
122
123
    {
        verify_reduced(f, i, tolerance);
    }
}

Paul's avatar
Paul committed
124
125
int main(int argc, char const* argv[])
{
Paul's avatar
Paul committed
126
    std::vector<std::string> args(argv + 1, argv + argc);
127
    if(not args.empty())
Paul's avatar
Paul committed
128
    {
129
        std::string file = args.front();
Paul's avatar
Paul committed
130
        auto p           = migraph::parse_onnx(file);
Paul's avatar
Paul committed
131
132
        std::cout << p << std::endl;

133
134
135
136
        if(std::any_of(args.begin(), args.end(), [](const auto& s) { return s == "-i"; }))
        {
            verify_instructions(p);
        }
Paul's avatar
Paul committed
137
138
139
140
        else if(std::any_of(args.begin(), args.end(), [](const auto& s) { return s == "-r"; }))
        {
            verify_reduced_program([&] { return migraph::parse_onnx(file); });
        }
141
142
143
144
        else
        {
            verify_program(file, [&] { return migraph::parse_onnx(file); });
        }
Paul's avatar
Paul committed
145
146
    }
}