// verify_onnx.cpp

#include <algorithm>
#include <functional>
#include <iostream>
#include <iterator>
#include <stdexcept>
#include <string>
#include <vector>

#include <migraphx/onnx.hpp>

#include <migraphx/cpu/target.hpp>
#include <migraphx/gpu/target.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/verify_args.hpp>
#include <migraphx/instruction.hpp>

/// Hash any value that has a std::hash specialization.
///
/// In this tool it turns a parameter name into the seed handed to
/// migraphx::generate_argument, so that runs which hash the same name
/// presumably receive the same generated input data.
template <class T>
auto get_hash(const T& x)
{
    const std::hash<T> hasher{};
    return hasher(x);
}

template <class F>
Paul's avatar
Paul committed
18
migraphx::argument run_cpu(F f)
19
20
{
    auto p = f();
Paul's avatar
Paul committed
21
22
    p.compile(migraphx::cpu::target{});
    migraphx::program::parameter_map m;
Paul's avatar
Paul committed
23
24
    for(auto&& x : p.get_parameter_shapes())
    {
Paul's avatar
Paul committed
25
        m[x.first] = migraphx::generate_argument(x.second, get_hash(x.first));
Paul's avatar
Paul committed
26
27
    }
    auto out = p.eval(m);
Paul's avatar
Paul committed
28
29
30
31
    std::cout << p << std::endl;
    return out;
}

Paul's avatar
Paul committed
32
template <class F>
Paul's avatar
Paul committed
33
migraphx::argument run_gpu(F f)
Paul's avatar
Paul committed
34
{
35
    auto p = f();
Paul's avatar
Paul committed
36
    p.compile(migraphx::gpu::target{});
Paul's avatar
Paul committed
37

Paul's avatar
Paul committed
38
    migraphx::program::parameter_map m;
Paul's avatar
Paul committed
39
40
    for(auto&& x : p.get_parameter_shapes())
    {
Paul's avatar
Paul committed
41
        m[x.first] = migraphx::gpu::to_gpu(migraphx::generate_argument(x.second, get_hash(x.first)));
Paul's avatar
Paul committed
42
    }
Paul's avatar
Paul committed
43
    auto out = migraphx::gpu::from_gpu(p.eval(m));
Paul's avatar
Paul committed
44
    std::cout << p << std::endl;
Paul's avatar
Paul committed
45
    return migraphx::gpu::from_gpu(out);
Paul's avatar
Paul committed
46
47
}

Paul's avatar
Paul committed
48
49
template <class F>
void verify_program(const std::string& name, F f, double tolerance = 100)
50
51
52
{
    auto x = run_cpu(f);
    auto y = run_gpu(f);
Paul's avatar
Paul committed
53
    migraphx::verify_args(name, x, y, tolerance);
Paul's avatar
Paul committed
54
55
    // std::cout << "cpu: " << x << std::endl;
    // std::cout << "gpu: " << y << std::endl;
56
57
}

Paul's avatar
Paul committed
58
/// Verify each instruction of `prog` in isolation on CPU vs. GPU.
///
/// For every instruction that is not skipped, a one-instruction program is
/// rebuilt. Literal inputs are copied in as literals. Every other input becomes
/// a parameter named after its position in the input list ("0", "1", ...).
/// That program is printed and then checked with verify_program.
///
/// Skipped instructions are those whose name starts with '@' and the
/// "broadcast", "transpose" and "reshape" operators.
///
/// @param prog       program whose instructions are checked one by one
/// @param tolerance  forwarded to verify_program
/// @throws rethrows whatever verification throws, after printing the name of
///         the offending instruction
void verify_instructions(const migraphx::program& prog, double tolerance = 80)
{
    // NOTE(review): the three named operators are presumably skipped because
    // they only change how their input is viewed; confirm before extending.
    const auto is_skipped = [](const std::string& name) {
        // Builtin instructions such as "@literal" carry a leading '@'. Guard
        // against an empty name, where front() would be undefined behaviour.
        if(not name.empty() and name.front() == '@')
            return true;
        return name == "broadcast" or name == "transpose" or name == "reshape";
    };

    for(auto&& ins : prog)
    {
        if(is_skipped(ins.name()))
            continue;

        // Builds a standalone program containing only `ins`. It is called once
        // for printing, and verify_program calls it again for each target.
        auto create_program = [&] {
            migraphx::program p;
            std::vector<migraphx::instruction_ref> inputs;
            for(auto&& arg : ins.inputs())
            {
                if(arg->name() == "@literal")
                    inputs.push_back(p.add_literal(arg->get_literal()));
                else
                    inputs.push_back(
                        p.add_parameter(std::to_string(inputs.size()), arg->get_shape()));
            }
            p.add_instruction(ins.get_operator(), inputs);
            return p;
        };
        try
        {
            std::cout << "Verify: " << ins.name() << std::endl;
            std::cout << create_program() << std::endl;
            verify_program(ins.name(), create_program, tolerance);
        }
        catch(...)
        {
            std::cout << "Instruction " << ins.name() << " threw an exception." << std::endl;
            throw;
        }
    }
}

template <class F>
Paul's avatar
Paul committed
99
100
void verify_reduced(F f, int n, double tolerance = 80)
{
Paul's avatar
Paul committed
101

Paul's avatar
Paul committed
102
    auto create_program = [&] {
Paul's avatar
Paul committed
103
        migraphx::program p = f();
Paul's avatar
Paul committed
104
        auto last          = std::prev(p.end(), n + 1);
Paul's avatar
Paul committed
105
106
107
108
109
110
111
112
        p.remove_instructions(last, p.end());
        return p;
    };
    std::cout << "Verify: " << std::endl;
    std::cout << create_program() << std::endl;
    verify_program(std::to_string(n), create_program, tolerance);
}

Paul's avatar
Paul committed
113
template <class F>
Paul's avatar
Paul committed
114
115
void verify_reduced_program(F f, double tolerance = 80)
{
Paul's avatar
Paul committed
116
    migraphx::program p = f();
Paul's avatar
Paul committed
117
118
    auto n             = std::distance(p.begin(), p.end());
    for(int i = 0; i < n; i++)
Paul's avatar
Paul committed
119
120
121
122
123
    {
        verify_reduced(f, i, tolerance);
    }
}

Paul's avatar
Paul committed
124
125
int main(int argc, char const* argv[])
{
Paul's avatar
Paul committed
126
    std::vector<std::string> args(argv + 1, argv + argc);
127
    if(not args.empty())
Paul's avatar
Paul committed
128
    {
129
        std::string file = args.front();
Paul's avatar
Paul committed
130
        auto p           = migraphx::parse_onnx(file);
Paul's avatar
Paul committed
131
132
        std::cout << p << std::endl;

133
134
135
136
        if(std::any_of(args.begin(), args.end(), [](const auto& s) { return s == "-i"; }))
        {
            verify_instructions(p);
        }
Paul's avatar
Paul committed
137
138
        else if(std::any_of(args.begin(), args.end(), [](const auto& s) { return s == "-r"; }))
        {
Paul's avatar
Paul committed
139
            verify_reduced_program([&] { return migraphx::parse_onnx(file); });
Paul's avatar
Paul committed
140
        }
141
142
        else
        {
Paul's avatar
Paul committed
143
            verify_program(file, [&] { return migraphx::parse_onnx(file); });
144
        }
Paul's avatar
Paul committed
145
146
    }
}