parse_load_save.cpp 3.12 KB
Newer Older
turneram's avatar
turneram committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
#include <algorithm>
#include <cstdlib>
#include <exception>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

// MIGraphX C++ API
#include <migraphx/migraphx.hpp>

// Looks up the option named by the third parameter in the argument range
// [first, second) and returns the argument that follows it, or nullptr if the
// option is not present or has no usable value after it.
char* getCmdOption(char**, char**, const std::string&);

// Returns true if the option named by the third parameter appears as a whole
// argument anywhere in the argument range [first, second).
bool cmdOptionExists(char**, char**, const std::string&);

// Write the command-line synopsis for this example program to `os`.
static void print_usage(std::ostream& os, const char* prog)
{
    os << "Usage: " << prog << " <input_file> "
       << "[options]" << std::endl;
    os << "options:" << std::endl;
    os << "\t--parse onnx" << std::endl;
    os << "\t--load  json/msgpack" << std::endl;
    os << "\t--save  <output_file>" << std::endl;
}

// Entry point: build a MIGraphX program from argv[1], print it, and optionally
// save it.
//
//   argv[1]               input file
//   --parse onnx          parse the input as an ONNX model; this is also what
//                         happens when --load is absent, and it takes
//                         precedence when both --parse and --load are given
//   --load json|msgpack   load the input as a previously saved program; any
//                         other value falls back to migraphx::load's default
//   --save [name]         save the program as "<name>.msgpack" ("out" when no
//                         name follows --save)
//
// Returns 0 on success (or when only the usage text is printed) and
// EXIT_FAILURE on incorrect usage or when a MIGraphX call throws.
int main(int argc, char** argv)
{
    // argv[0] may be null when argc == 0; never stream a null char*.
    const char* prog = (argc > 0 && argv[0] != nullptr) ? argv[0] : "parse_load_save";

    if(argc < 2)
    {
        print_usage(std::cout, prog);
        return 0;
    }

    // Options are everything after the input file.
    char** const opt_begin = argv + 2;
    char** const opt_end   = argv + argc;

    const bool parse_requested = cmdOptionExists(opt_begin, opt_end, "--parse");
    const bool load_requested  = cmdOptionExists(opt_begin, opt_end, "--load");
    const bool save_requested  = cmdOptionExists(opt_begin, opt_end, "--save");
    const char* load_arg       = getCmdOption(opt_begin, opt_end, "--load");
    const char* save_arg       = getCmdOption(opt_begin, opt_end, "--save");
    const char* input_file     = argv[1];

    const bool use_parser = parse_requested || !load_requested;

    // "--load" is only meaningful with a format argument after it.
    if(!use_parser && load_arg == nullptr)
    {
        std::cerr << "Error: Incorrect Usage" << std::endl;
        print_usage(std::cerr, prog);
        return EXIT_FAILURE;
    }

    // The MIGraphX C++ wrappers report failures (missing file, bad model,
    // unsupported format) by throwing; turn those into a clean error exit
    // instead of std::terminate.
    try
    {
        migraphx::program p;

        if(use_parser)
        {
            std::cout << "Parsing ONNX File" << std::endl;
            migraphx::onnx_options options;
            p = migraphx::parse_onnx(input_file, options);
        }
        else
        {
            std::cout << "Loading Graph File" << std::endl;
            const std::string format = load_arg;
            if(format == "json" || format == "msgpack")
            {
                migraphx_file_options options{};
                // load_arg points into argv, so it outlives the load call.
                options.format = load_arg;
                p              = migraphx::load(input_file, options);
            }
            else
            {
                p = migraphx::load(input_file);
            }
        }

        std::cout << "Input Graph: " << std::endl;
        p.print();
        std::cout << std::endl;

        if(save_requested)
        {
            std::cout << "Saving program..." << std::endl;
            std::string output_file = save_arg == nullptr ? "out" : save_arg;
            output_file.append(".msgpack");

            migraphx_file_options options{};
            options.format = "msgpack";
            migraphx::save(p, output_file.c_str(), options);
            std::cout << "Program has been saved as ./" << output_file << std::endl;
        }
    }
    catch(const std::exception& e)
    {
        std::cerr << "Error: " << e.what() << std::endl;
        return EXIT_FAILURE;
    }

    return 0;
}

// Returns the argument that immediately follows `option` in the argument
// range [begin, end). Returns nullptr in three cases: `option` is absent, it
// is the last argument, or it is followed by another "--"-prefixed option.
//
// The last case matters for invocations such as "--save --parse onnx".
// Without it, "--parse" would be taken as the value of "--save". The returned
// pointer aliases the caller's argument storage; it is not a copy.
char* getCmdOption(char** begin, char** end, const std::string& option)
{
    char** itr = std::find(begin, end, option);
    if(itr == end || ++itr == end)
    {
        return nullptr;
    }

    // A following "--xxx" token is the next option, not this option's value.
    // Reading value[1] is safe: if value[0] is '-', the C string has at least
    // one more character (possibly the terminator).
    const char* value = *itr;
    if(value[0] == '-' && value[1] == '-')
    {
        return nullptr;
    }

    return *itr;
}

// Reports whether `option` occurs as an exact, whole-argument match anywhere
// in the half-open argument range [begin, end). The scan does not tell options
// apart from values, so a value equal to `option` also counts.
bool cmdOptionExists(char** begin, char** end, const std::string& option)
{
    const auto matches = [&option](const char* arg) { return option == arg; };
    return std::any_of(begin, end, matches);
}