main.cpp 3.37 KB
Newer Older
Paul's avatar
Paul committed
1
2
#include "argument_parser.hpp"
#include "command.hpp"
Paul's avatar
Paul committed
3
#include "verify.hpp"
Paul's avatar
Paul committed
4

Paul's avatar
Paul committed
5
6
7
8
9
10
11
12
13
14
15
16
17
#include <migraphx/tf.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/stringutils.hpp>

#include <exception>
#include <iostream>
#include <iterator>
#include <string>
#include <vector>

namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {

struct loader
{
    std::string file;
    std::string type;
    bool is_nhwc = false;
Paul's avatar
Paul committed
18
    unsigned trim = 0;
Paul's avatar
Paul committed
19
20
21
22
23
24
25

    void parse(argument_parser& ap)
    {
        ap.add(file, {}, ap.metavar("<input file>"));
        ap.add(type, {"--onnx"}, ap.help("Load as onnx"), ap.set_value("onnx"));
        ap.add(type, {"--tf"}, ap.help("Load as tensorflow"), ap.set_value("tf"));
        ap.add(is_nhwc, {"--nhwc"}, ap.help("Treat tensorflow format as nhwc"), ap.set_value(true));
Paul's avatar
Paul committed
26
27
        ap.add(
            is_nhwc, {"--nchw"}, ap.help("Treat tensorflow format as nchw"), ap.set_value(false));
Paul's avatar
Paul committed
28
        ap.add(trim, {"--trim", "-t"}, ap.help("Trim instructions from the end"));
Paul's avatar
Paul committed
29
30
    }

Paul's avatar
Paul committed
31
    program load()
Paul's avatar
Paul committed
32
33
    {
        program p;
Paul's avatar
Paul committed
34
        if(type.empty())
Paul's avatar
Paul committed
35
        {
Paul's avatar
Paul committed
36
            if(ends_with(file, ".onnx"))
Paul's avatar
Paul committed
37
38
39
40
                type = "onnx";
            else
                type = "tf";
        }
Paul's avatar
Paul committed
41
        std::cout << "Reading: " << file << std::endl;
Paul's avatar
Paul committed
42
        if(type == "onnx")
Paul's avatar
Paul committed
43
            p = parse_onnx(file);
Paul's avatar
Paul committed
44
        else if(type == "tf")
Paul's avatar
Paul committed
45
            p = parse_tf(file, is_nhwc);
Paul's avatar
Paul committed
46
47
48
49
50
        if (trim > 0)
        {
            auto last           = std::prev(p.end(), trim);
            p.remove_instructions(last, p.end());
        }
Paul's avatar
Paul committed
51
52
53
54
55
56
57
        return p;
    }
};

/// `read` command: load a model and print the resulting program.
struct read : command<read>
{
    loader l;

    /// Only the loader's options are accepted.
    void parse(argument_parser& ap) { l.parse(ap); }

    /// Load the program and dump it to stdout.
    void run()
    {
        program prog = l.load();
        std::cout << prog << std::endl;
    }
};

Paul's avatar
Paul committed
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
/// `verify` command: load a model, print it, then check it with one of the
/// verification helpers, using `tolerance` as the error tolerance.
struct verify : command<verify>
{
    loader l;
    double tolerance     = 80;
    bool per_instruction = false;
    bool reduce          = false;

    /// Register the loader's options followed by the verification options.
    void parse(argument_parser& ap)
    {
        l.parse(ap);
        ap.add(tolerance, {"--tolerance"}, ap.help("Tolerance for errors"));
        ap.add(per_instruction,
               {"-i", "--per-instruction"},
               ap.help("Verify each instruction"),
               ap.set_value(true));
        ap.add(reduce, {"-r", "--reduce"}, ap.help("Reduce program and verify"), ap.set_value(true));
    }

    /// Run exactly one verification mode. If both are requested,
    /// --per-instruction takes precedence over --reduce.
    void run()
    {
        auto prog = l.load();
        std::cout << prog << std::endl;

        if(per_instruction)
        {
            verify_instructions(prog, tolerance);
            return;
        }
        if(reduce)
        {
            verify_reduced_program(prog, tolerance);
            return;
        }
        // Default mode: verify the whole program.
        verify_program(l.file, prog, tolerance);
    }
};

Paul's avatar
Paul committed
101
102
103
104
105
struct main_command
{
    static std::string get_command_help()
    {
        std::string result = "Commands:\n";
Paul's avatar
Paul committed
106
        for(const auto& p : get_commands())
Paul's avatar
Paul committed
107
108
109
            result += "    " + p.first + "\n";
        return result;
    }
Paul's avatar
Paul committed
110
    void parse(argument_parser& ap)
Paul's avatar
Paul committed
111
112
113
114
115
116
117
    {
        ap.add(nullptr, {"-h", "--help"}, ap.help("Show help"), ap.show_help(get_command_help()));
    }

    void run() {}
};

Paul's avatar
Paul committed
118
119
120
121
122
} // namespace MIGRAPHX_INLINE_NS
} // namespace driver
} // namespace migraphx

using namespace migraphx::driver;
Paul's avatar
Paul committed
123
124
int main(int argc, const char* argv[])
{
Paul's avatar
Paul committed
125
    std::vector<std::string> args(argv + 1, argv + argc);
Paul's avatar
Paul committed
126
    if(args.empty())
Paul's avatar
Paul committed
127
        return 0;
Paul's avatar
Paul committed
128
    auto&& m = get_commands();
Paul's avatar
Paul committed
129
    auto cmd = args.front();
Paul's avatar
Paul committed
130
    if(m.count(cmd) > 0)
Paul's avatar
Paul committed
131
    {
Paul's avatar
Paul committed
132
        m.at(cmd)({args.begin() + 1, args.end()});
Paul's avatar
Paul committed
133
    }
Paul's avatar
Paul committed
134
    else
Paul's avatar
Paul committed
135
    {
Paul's avatar
Paul committed
136
        run_command<main_command>(args);
Paul's avatar
Paul committed
137
138
139
    }
    return 0;
}