"template/chatml.gotmpl" did not exist on "b0135f4b9b176eab9155b660d04c9ca2a1ec2341"
verify_onnx.cpp 3.01 KB
Newer Older
Paul's avatar
Paul committed
1

Paul's avatar
Paul committed
2
#include <migraph/onnx.hpp>
Paul's avatar
Paul committed
3

Paul's avatar
Paul committed
4
#include <migraph/cpu/cpu_target.hpp>
Paul's avatar
Paul committed
5
6
#include <migraph/gpu/target.hpp>
#include <migraph/gpu/hip.hpp>
Paul's avatar
Paul committed
7
#include <migraph/generate.hpp>
8
#include <migraph/verify_args.hpp>
9
#include <migraph/instruction.hpp>
Paul's avatar
Paul committed
10

11
12
/// Hash `x` with the standard-library hasher for its type.
///
/// In this tool the result is used as a deterministic seed for
/// migraph::generate_argument, keyed by a parameter's name.
///
/// @param x  value to hash
/// @return   std::hash<T>{}(x)
template <class T>
auto get_hash(const T& x)
{
    const std::hash<T> hasher{};
    return hasher(x);
}

/// Compile the program produced by `f` for the reference CPU target, evaluate it
/// on generated inputs, print the compiled program, and return the result.
///
/// Every parameter is filled by migraph::generate_argument seeded with the hash of
/// the parameter's name. run_gpu seeds its inputs the same way, so the CPU and GPU
/// runs of one program receive identical input data and their outputs can be compared.
///
/// @param f  nullary callable returning a fresh migraph::program
/// @return   the argument returned by program::eval
template <class F>
migraph::argument run_cpu(F f)
{
    auto p = f();
    p.compile(migraph::cpu::cpu_target{});
    migraph::program::parameter_map m;
    for(auto&& x : p.get_parameter_shapes())
    {
        // Must match run_gpu's per-name seed; otherwise verify_args would compare
        // outputs computed from different random inputs.
        m[x.first] = migraph::generate_argument(x.second, get_hash(x.first));
    }
    auto out = p.eval(m);
    std::cout << p << std::endl;
    return out;
}

32
33
/// Compile the program produced by `f` for the GPU target, evaluate it on
/// generated inputs, print the compiled program, and return the result on the host.
///
/// Every parameter is filled by migraph::generate_argument seeded with the hash of
/// the parameter's name (the same seeding run_cpu uses) and copied with to_gpu.
///
/// @param f  nullary callable returning a fresh migraph::program
/// @return   the evaluation result after a single from_gpu copy
template <class F>
migraph::argument run_gpu(F f)
{
    auto p = f();
    p.compile(migraph::gpu::target{});
    migraph::program::parameter_map m;
    for(auto&& x : p.get_parameter_shapes())
    {
        m[x.first] = migraph::gpu::to_gpu(migraph::generate_argument(x.second, get_hash(x.first)));
    }
    // The eval result is copied back exactly once; `out` is already the host copy,
    // so it must not be passed through from_gpu a second time.
    auto out = migraph::gpu::from_gpu(p.eval(m));
    std::cout << p << std::endl;
    return out;
}

48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
/// Run the program built by `f` on both the CPU and GPU targets and compare the
/// two results with migraph::verify_args.
///
/// `f` is invoked once by each run, so each target gets its own fresh program.
///
/// @param name       label forwarded to verify_args
/// @param f          nullary callable returning a fresh migraph::program
/// @param tolerance  comparison tolerance forwarded to verify_args
template <class F>
void verify_program(const std::string& name, F f, double tolerance = 100)
{
    // CPU first, then GPU: kept as separate statements so the order is fixed.
    auto cpu_result = run_cpu(f);
    auto gpu_result = run_gpu(f);
    migraph::verify_args(name, cpu_result, gpu_result, tolerance);
}

/// Verify each instruction of `prog` in isolation by rebuilding it as a
/// one-instruction program and comparing CPU and GPU results via verify_program.
///
/// Instructions whose op name starts with '@' are skipped, as are the
/// "broadcast", "transpose" and "reshape" ops. When rebuilding, "@literal"
/// arguments are re-added as literals. Every other argument becomes a parameter
/// named after its index in the input list.
///
/// @param prog       program whose instructions are verified one by one
/// @param tolerance  comparison tolerance forwarded to verify_program
/// @throws whatever a failing instruction's verification throws; the op name is
///         printed before the exception is rethrown
void verify_instructions(const migraph::program& prog, double tolerance = 100)
{
    // Ops excluded from per-instruction verification.
    auto is_skipped = [](const std::string& op_name) {
        return op_name.front() == '@' or op_name == "broadcast" or op_name == "transpose" or
               op_name == "reshape";
    };

    for(auto&& ins : prog)
    {
        const std::string op_name = ins.op.name();
        if(is_skipped(op_name))
            continue;

        // Builds a fresh program whose only instruction is `ins`.
        auto create_program = [&] {
            migraph::program single;
            std::vector<migraph::instruction_ref> inputs;
            for(auto&& arg : ins.arguments)
            {
                const bool is_literal = arg->op.name() == "@literal";
                inputs.push_back(is_literal ? single.add_literal(arg->lit)
                                            : single.add_parameter(std::to_string(inputs.size()),
                                                                   arg->get_shape()));
            }
            single.add_instruction(ins.op, inputs);
            return single;
        };

        try
        {
            std::cout << "Verify: " << op_name << std::endl;
            std::cout << create_program() << std::endl;
            verify_program(op_name, create_program, tolerance);
        }
        catch(...)
        {
            // Report which instruction failed, then let the exception propagate.
            std::cout << "Instruction " << op_name << " threw an exception." << std::endl;
            throw;
        }
    }
}

Paul's avatar
Paul committed
95
96
int main(int argc, char const* argv[])
{
97
98
    std::vector<std::string> args(argv+1, argv+argc);
    if(not args.empty())
Paul's avatar
Paul committed
99
    {
100
        std::string file = args.front();
Paul's avatar
Paul committed
101
        auto p           = migraph::parse_onnx(file);
Paul's avatar
Paul committed
102
103
        std::cout << p << std::endl;

104
105
106
107
108
109
110
111
        if(std::any_of(args.begin(), args.end(), [](const auto& s) { return s == "-i"; }))
        {
            verify_instructions(p);
        }
        else
        {
            verify_program(file, [&] { return migraph::parse_onnx(file); });
        }
Paul's avatar
Paul committed
112
113
    }
}