main.cpp 26.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
24

kahmed10's avatar
kahmed10 committed
25
#include "verify.hpp"
Paul's avatar
Paul committed
26
27
#include "argument_parser.hpp"
#include "command.hpp"
kahmed10's avatar
kahmed10 committed
28
#include "precision.hpp"
29
#include "passes.hpp"
Paul's avatar
Paul committed
30
#include "perf.hpp"
31
#include "models.hpp"
32
#include "marker_roctx.hpp"
Paul's avatar
Paul committed
33

Paul's avatar
Paul committed
34
35
#include <migraphx/tf.hpp>
#include <migraphx/onnx.hpp>
36
#ifdef MIGRAPHX_ENABLE_PYTHON
37
#include <migraphx/py.hpp>
38
#endif
Paul's avatar
Paul committed
39
#include <migraphx/stringutils.hpp>
40
#include <migraphx/convert_to_json.hpp>
41
42
#include <migraphx/load_save.hpp>
#include <migraphx/json.hpp>
43
#include <migraphx/version.h>
Paul's avatar
Paul committed
44

45
46
47
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/eliminate_pad.hpp>
48
49
#include <migraphx/generate.hpp>
#include <migraphx/pass_manager.hpp>
50
#include <migraphx/propagate_constant.hpp>
51
#include <migraphx/quantization.hpp>
52
#include <migraphx/register_op.hpp>
53
54
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/simplify_reshapes.hpp>
55
#include <migraphx/register_target.hpp>
56

57
58
#include <fstream>

Paul's avatar
Paul committed
59
60
61
62
namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {

kahmed10's avatar
kahmed10 committed
63
64
65
66
inline std::string get_version()
{
    return "MIGraphX Version: " + std::to_string(MIGRAPHX_VERSION_MAJOR) + "." +
           std::to_string(MIGRAPHX_VERSION_MINOR) + "." + std::to_string(MIGRAPHX_VERSION_PATCH) +
Artur Wojcik's avatar
Artur Wojcik committed
67
           "." MIGRAPHX_VERSION_TWEAK;
kahmed10's avatar
kahmed10 committed
68
69
}

Paul's avatar
Paul committed
70
71
struct loader
{
72
    std::string model;
Paul's avatar
Paul committed
73
    std::string file;
Paul's avatar
Paul committed
74
    std::string file_type;
75
76
77
78
79
    unsigned batch              = 1;
    bool is_nhwc                = true;
    unsigned trim               = 0;
    bool optimize               = false;
    bool skip_unknown_operators = false;
80
81
82
    bool brief                  = false;
    std::string output_type;
    std::string output;
83
    std::string default_dyn_dim;
Shucai Xiao's avatar
Shucai Xiao committed
84
    std::vector<std::string> param_dims;
85
    std::vector<std::string> dyn_param_dims;
kahmed10's avatar
kahmed10 committed
86
    std::vector<std::string> output_names;
87
    std::vector<std::string> passes;
Paul's avatar
Paul committed
88
89
90

    void parse(argument_parser& ap)
    {
91
92
93
94
95
        ap(file, {}, ap.metavar("<input file>"), ap.file_exist(), ap.required(), ap.group("input"));
        ap(model,
           {"--model"},
           ap.help("Load model"),
           ap.type("resnet50|inceptionv3|alexnet"),
96
           ap.matches({"resnet50", "inceptionv3", "alexnet"}),
97
           ap.group("input"));
Paul's avatar
Paul committed
98
99
        ap(file_type, {"--onnx"}, ap.help("Load as onnx"), ap.set_value("onnx"));
        ap(file_type, {"--tf"}, ap.help("Load as tensorflow"), ap.set_value("tf"));
100
101
        ap(file_type, {"--migraphx"}, ap.help("Load as MIGraphX"), ap.set_value("migraphx"));
        ap(file_type, {"--migraphx-json"}, ap.help("Load as MIGraphX JSON"), ap.set_value("json"));
102
103
104
105
106
        ap(batch,
           {"--batch"},
           ap.help("For a static model, sets default_dim_value size (commonly batch size). For a "
                   "dynamic batch model, sets the batch "
                   "size at runtime."));
Paul's avatar
Paul committed
107
        ap(is_nhwc, {"--nhwc"}, ap.help("Treat tensorflow format as nhwc"), ap.set_value(true));
108
109
110
111
        ap(skip_unknown_operators,
           {"--skip-unknown-operators"},
           ap.help("Skip unknown operators when parsing and continue to parse."),
           ap.set_value(true));
Paul's avatar
Paul committed
112
        ap(is_nhwc, {"--nchw"}, ap.help("Treat tensorflow format as nchw"), ap.set_value(false));
Paul's avatar
Paul committed
113
        ap(trim, {"--trim", "-t"}, ap.help("Trim instructions from the end"));
Shucai Xiao's avatar
Shucai Xiao committed
114
115
116
117
118
        ap(param_dims,
           {"--input-dim"},
           ap.help("Dim of a parameter (format: \"@name d1 d2 dn\")"),
           ap.append(),
           ap.nargs(2));
119
120
121
122
123
124
125
126
127
128
        ap(dyn_param_dims,
           {"--dyn-input-dim"},
           ap.help("Dynamic dimensions of a parameter (format: \"@name_1\" \"[{min:x, max:y, "
                   "optimals:[o1,o2,...]}, dim2,dim3, ...]\", \"@name_2\", ... You can supply a "
                   "single integer value for a dimension to specify it as fixed."),
           ap.append(),
           ap.nargs(2));
        ap(default_dyn_dim,
           {"--default-dyn-dim"},
           ap.help("Default dynamic dimension (format: \"{min:x, max:y, optimals:[o1,o2]}\")."));
kahmed10's avatar
kahmed10 committed
129
130
131
132
133
        ap(output_names,
           {"--output-names"},
           ap.help("Names of node output (format: \"name_1 name_2 name_n\")"),
           ap.append(),
           ap.nargs(2));
134
        ap(optimize, {"--optimize", "-O"}, ap.help("Optimize when reading"), ap.set_value(true));
135
        ap(passes, {"--apply-pass", "-p"}, ap.help("Passes to apply to model"), ap.append());
136
137
138
139
140
141
142
        ap(output_type,
           {"--graphviz", "-g"},
           ap.help("Print out a graphviz representation."),
           ap.set_value("graphviz"));
        ap(brief, {"--brief"}, ap.help("Make the output brief."), ap.set_value(true));
        ap(output_type,
           {"--cpp"},
143
           ap.help("Print out the program as C++ program."),
144
           ap.set_value("cpp"));
145
146
147
148
        ap(output_type,
           {"--python", "--py"},
           ap.help("Print out the program as python program."),
           ap.set_value("py"));
149
150
151
152
153
154
155
156
157
158
        ap(output_type, {"--json"}, ap.help("Print out program as json."), ap.set_value("json"));
        ap(output_type,
           {"--text"},
           ap.help("Print out program in text format."),
           ap.set_value("text"));
        ap(output_type,
           {"--binary"},
           ap.help("Print out program in binary format."),
           ap.set_value("binary"));
        ap(output, {"--output", "-o"}, ap.help("Output to file."));
Paul's avatar
Paul committed
159
160
    }

Shucai Xiao's avatar
Shucai Xiao committed
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
    static auto parse_param_dims(const std::vector<std::string>& param_dims_info)
    {
        std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims;
        std::string name = "";
        for(auto&& x : param_dims_info)
        {
            if(x[0] == '@')
            {
                name = x.substr(1);
            }
            else
            {
                map_input_dims[name].push_back(value_parser<std::size_t>::apply(x));
            }
        }

        return map_input_dims;
    }

180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
    static auto parse_dyn_dims_json(const std::string& dd_json)
    {
        // expecting a json string like "[{min:1,max:64,optimals:[1,2,4,8]},3,224,224]"
        auto v = from_json_string(convert_to_json(dd_json));
        std::vector<migraphx::shape::dynamic_dimension> dyn_dims;
        std::transform(v.begin(), v.end(), std::back_inserter(dyn_dims), [&](auto x) {
            if(x.is_object())
                return from_value<migraphx::shape::dynamic_dimension>(x);
            auto d = x.template to<std::size_t>();
            return migraphx::shape::dynamic_dimension{d, d};
        });
        return dyn_dims;
    }

    static auto parse_dyn_dims_map(const std::vector<std::string>& param_dyn_dims)
    {
        // expecting vector of strings formatted like
        // {"@param_name_0", "dd_json_0", "@param_name_1", "dd_json_1", ...}
        std::unordered_map<std::string, std::vector<shape::dynamic_dimension>> map_dyn_input_dims;
        std::string name = "";
        for(auto&& x : param_dyn_dims)
        {
            if(x[0] == '@')
            {
                name = x.substr(1);
            }
            else
            {
                map_dyn_input_dims[name] = parse_dyn_dims_json(x);
            }
        }
        return map_dyn_input_dims;
    }

kahmed10's avatar
kahmed10 committed
214
215
216
217
218
219
220
221
222
223
224
    static auto parse_output_names(const std::vector<std::string>& output_names_info)
    {
        std::vector<std::string> output_node_names;
        std::transform(output_names_info.begin(),
                       output_names_info.end(),
                       std::back_inserter(output_node_names),
                       [&](auto x) { return value_parser<std::string>::apply(x); });

        return output_node_names;
    }

225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
    tf_options get_tf_options() const
    {
        auto map_input_dims    = parse_param_dims(param_dims);
        auto output_node_names = parse_output_names(output_names);
        tf_options options;
        options.is_nhwc           = is_nhwc;
        options.batch_size        = batch;
        options.map_input_dims    = map_input_dims;
        options.output_node_names = output_node_names;
        return options;
    }

    onnx_options get_onnx_options() const
    {
        auto map_input_dims     = parse_param_dims(param_dims);
        auto map_dyn_input_dims = parse_dyn_dims_map(dyn_param_dims);
        onnx_options options;
        if(default_dyn_dim.empty())
        {
            options.default_dim_value = batch;
        }
        else
        {
            auto v                        = from_json_string(convert_to_json(default_dyn_dim));
            options.default_dyn_dim_value = from_value<migraphx::shape::dynamic_dimension>(v);
        }
        options.skip_unknown_operators = skip_unknown_operators;
        options.print_program_on_error = true;
        options.map_input_dims         = map_input_dims;
        options.map_dyn_input_dims     = map_dyn_input_dims;
        return options;
    }

258
259
260
261
262
263
264
265
266
267
268
269
270
271
    static std::string get_file_type(const std::string& file)
    {
        if(ends_with(file, ".onnx"))
            return "onnx";
        else if(ends_with(file, ".pb"))
            return "tf";
        else if(ends_with(file, ".json"))
            return "json";
        else if(ends_with(file, ".py"))
            return "py";
        else
            return "migraphx";
    }

Paul's avatar
Paul committed
272
    program load()
Paul's avatar
Paul committed
273
274
    {
        program p;
275
        if(model.empty())
Paul's avatar
Paul committed
276
        {
277
278
            if(file_type.empty())
            {
279
                file_type = get_file_type(file);
280
281
282
            }
            std::cout << "Reading: " << file << std::endl;
            if(file_type == "onnx")
283
            {
284
                p = parse_onnx(file, get_onnx_options());
285
            }
286
            else if(file_type == "tf")
287
            {
288
                p = parse_tf(file, get_tf_options());
289
            }
290
291
292
293
294
295
            else if(file_type == "json")
            {
                file_options options;
                options.format = "json";
                p              = migraphx::load(file, options);
            }
296
#ifdef MIGRAPHX_ENABLE_PYTHON
297
298
299
300
            else if(file_type == "py")
            {
                p = migraphx::load_py(file);
            }
301
#endif
302
303
304
305
            else if(file_type == "migraphx")
            {
                p = migraphx::load(file);
            }
306
307
308
309
310
311
312
313
314
315
316
        }
        else
        {
            if(model == "resnet50")
                p = resnet50(batch);
            else if(model == "inceptionv3")
                p = inceptionv3(batch);
            else if(model == "alexnet")
                p = alexnet(batch);
            else
                MIGRAPHX_THROW("Unknown model: " + model);
Paul's avatar
Paul committed
317
        }
Paul's avatar
Paul committed
318
        if(trim > 0)
Paul's avatar
Paul committed
319
        {
320
            auto* mm  = p.get_main_module();
Shucai Xiao's avatar
Shucai Xiao committed
321
322
            auto last = std::prev(mm->end(), trim);
            mm->remove_instructions(last, mm->end());
Paul's avatar
Paul committed
323
        }
324
325
326
        // Remove unused variable when exporting to cpp
        if(output_type == "cpp")
            migraphx::run_passes(*p.get_main_module(), {migraphx::dead_code_elimination{}});
Paul's avatar
Paul committed
327
        if(optimize)
328
329
        {
            migraphx::run_passes(*p.get_main_module(),
Paul's avatar
Paul committed
330
331
332
333
334
335
336
337
338
339
340
341
                                 {
                                     migraphx::eliminate_identity{},
                                     migraphx::dead_code_elimination{},
                                     migraphx::simplify_algebra{},
                                     migraphx::dead_code_elimination{},
                                     migraphx::simplify_reshapes{},
                                     migraphx::dead_code_elimination{},
                                     migraphx::propagate_constant{},
                                     migraphx::dead_code_elimination{},
                                     migraphx::eliminate_pad{},
                                     migraphx::dead_code_elimination{},
                                 });
342
        }
343
344
        if(not passes.empty())
            migraphx::run_passes(*p.get_main_module(), get_passes(passes));
Paul's avatar
Paul committed
345
346
        return p;
    }
347
348
349
350
351
352

    static void write(std::ostream& os, const std::vector<char>& buffer)
    {
        os.write(buffer.data(), buffer.size());
    }

353
    void save(const program& p) const
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
    {
        auto* os = &std::cout;
        std::ofstream fs;
        if(not output.empty())
        {
            fs.open(output);
            os = &fs;
        }

        std::string type = output_type;
        if(type.empty())
        {
            if(output.empty())
                type = "text";
            else
                type = "binary";
        }

372
373
374
        if(type == "py")
            p.print_py(*os);
        else if(type == "cpp")
375
376
377
378
379
380
381
382
383
384
            p.print_cpp(*os);
        else if(type == "graphviz")
            p.print_graph(*os, brief);
        else if(type == "text")
            *os << p << std::endl;
        else if(type == "json")
            *os << to_json_string(p.to_value()) << std::endl;
        else if(type == "binary")
            write(*os, save_buffer(p));
    }
Paul's avatar
Paul committed
385
386
};

387
388
389
390
391
392
struct program_params
{
    std::vector<std::string> fill0{};
    std::vector<std::string> fill1{};
    void parse(argument_parser& ap)
    {
Shucai Xiao's avatar
Shucai Xiao committed
393
394
        ap(fill0, {"--fill0"}, ap.help("Fill parameter with 0s"), ap.append(), ap.nargs(2));
        ap(fill1, {"--fill1"}, ap.help("Fill parameter with 1s"), ap.append(), ap.nargs(2));
395
396
    }

397
    auto generate(const program& p, const target& t, bool offload, unsigned batch)
398
    {
399
        parameter_map m;
400
401
402
403
404
405
406
        auto param_shapes = p.get_parameter_shapes();
        std::unordered_map<std::string, shape> static_param_shapes;
        std::transform(
            param_shapes.cbegin(),
            param_shapes.cend(),
            std::inserter(static_param_shapes, static_param_shapes.end()),
            [&](const auto& x) { return std::make_pair(x.first, x.second.to_static(batch)); });
407
        for(auto&& s : fill0)
408
            m[s] = fill_argument(static_param_shapes.at(s), 0);
409
        for(auto&& s : fill1)
410
411
            m[s] = fill_argument(static_param_shapes.at(s), 1);
        fill_param_map(m, static_param_shapes, t, offload);
412
413
414
415
        return m;
    }
};

416
417
418
419
struct compiler_target
{
#ifdef HAVE_GPU
    std::string target_name = "gpu";
420
#elif defined(HAVE_CPU)
421
    std::string target_name = "cpu";
422
423
#elif defined(HAVE_FPGA)
    std::string target_name = "fpga";
424
#else
425
    std::string target_name = "ref";
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
#endif

    void parse(argument_parser& ap)
    {
        ap(target_name, {"--gpu"}, ap.help("Compile on the gpu"), ap.set_value("gpu"));
        ap(target_name, {"--cpu"}, ap.help("Compile on the cpu"), ap.set_value("cpu"));
        ap(target_name,
           {"--ref"},
           ap.help("Compile on the reference implementation"),
           ap.set_value("ref"));
    }

    target get_target() const { return make_target(target_name); }
};

Paul's avatar
Paul committed
441
442
443
struct compiler
{
    loader l;
444
    program_params parameters;
445
    compiler_target ct;
446
    compile_options co;
447
448
    bool to_fp16 = false;
    bool to_int8 = false;
449

kahmed10's avatar
kahmed10 committed
450
    std::vector<std::string> fill0;
Paul's avatar
Paul committed
451
    std::vector<std::string> fill1;
Paul's avatar
Paul committed
452
453
    void parse(argument_parser& ap)
    {
Paul's avatar
Paul committed
454
        l.parse(ap);
455
        parameters.parse(ap);
456
        ct.parse(ap);
457
        ap(co.offload_copy,
458
459
           {"--enable-offload-copy"},
           ap.help("Enable implicit offload copying"),
460
           ap.set_value(true));
461
        ap(co.fast_math,
kahmed10's avatar
kahmed10 committed
462
463
464
           {"--disable-fast-math"},
           ap.help("Disable fast math optimization"),
           ap.set_value(false));
465
466
467
468
        ap(co.exhaustive_tune,
           {"--exhaustive-tune"},
           ap.help("Exhastively search for best tuning parameters for kernels"),
           ap.set_value(true));
469
470
        ap(to_fp16, {"--fp16"}, ap.help("Quantize for fp16"), ap.set_value(true));
        ap(to_int8, {"--int8"}, ap.help("Quantize for int8"), ap.set_value(true));
Paul's avatar
Paul committed
471
472
    }

473
474
    auto params(const program& p)
    {
475
        return parameters.generate(p, ct.get_target(), co.offload_copy, l.batch);
476
    }
477

478
479
480
481
482
    auto host_params(const program& p)
    {
        return parameters.generate(p, ct.get_target(), true, l.batch);
    }

483
484
485
    program compile()
    {
        auto p = l.load();
486
        // Dont compile if its already been compiled
487

488
        if(p.is_compiled())
489
490
491
492
493
        {
            if(ct.target_name == "gpu")
            {
                if(is_offload_copy_set(p) and not co.offload_copy)
                {
494
495
496
497
498
                    std::cout
                        << "[WARNING]: MIGraphX program was likely compiled with offload_copy "
                           "set, Try "
                           "passing "
                           "`--enable-offload-copy` if program run fails.\n";
499
500
501
                }
                else if(co.offload_copy)
                {
502
                    std::cout << "[WARNING]: MIGraphX program was likely compiled without "
503
504
505
506
507
508
509
                                 "offload_copy set, Try "
                                 "removing "
                                 "`--enable-offload-copy` flag if passed to driver, if program run "
                                 "fails.\n";
                }
            }

510
            return p;
511
        }
512
        auto t = ct.get_target();
513
        if(to_fp16)
514
515
516
        {
            quantize_fp16(p);
        }
517
        if(to_int8)
518
        {
519
            quantize_int8(p, t, {host_params(p)});
520
        }
521
        p.compile(t, co);
522
        l.save(p);
523
524
        return p;
    }
Paul's avatar
Paul committed
525
526
};

Paul's avatar
Paul committed
527
528
529
// `read` command: load a program and write it back out in the requested format.
struct read : command<read>
{
    loader l;

    // Only the loader's input/output options apply to this command.
    void parse(argument_parser& ap) { l.parse(ap); }

    void run() { l.save(l.load()); }
};

Paul's avatar
Paul committed
539
540
541
542
543
544
545
546
// `params` command: list every parameter of the loaded program with its shape.
struct params : command<params>
{
    loader l;
    void parse(argument_parser& ap) { l.parse(ap); }

    void run()
    {
        auto prog = l.load();
        for(auto&& [param_name, param_shape] : prog.get_parameter_shapes())
            std::cout << param_name << ": " << param_shape << std::endl;
    }
};

Paul's avatar
Paul committed
552
553
// `verify` command: check the program against the reference results, either for the
// whole program, per instruction, or on a reduced program.
struct verify : command<verify>
{
    compiler c;
    // User-supplied tolerances; presumably get_tolerances() picks defaults for unset ones.
    std::optional<double> rms_tol;
    std::optional<double> atol;
    std::optional<double> rtol;
    bool per_instruction = false;
    bool reduce          = false;
    void parse(argument_parser& ap)
    {
        c.parse(ap);
        ap(rms_tol, {"--rms-tol"}, ap.help("Tolerance for the RMS error"));
        ap(atol, {"--atol"}, ap.help("Tolerance for the elementwise absolute difference"));
        ap(rtol, {"--rtol"}, ap.help("Tolerance for the elementwise relative difference"));
        ap(per_instruction,
           {"-i", "--per-instruction"},
           ap.help("Verify each instruction"),
           ap.set_value(true));
        ap(reduce, {"-r", "--reduce"}, ap.help("Reduce program and verify"), ap.set_value(true));
    }

    void run()
    {
        auto prog = c.l.load();
        c.l.save(prog);
        std::cout << prog << std::endl;

        auto tgt    = c.ct.get_target();
        auto inputs = c.parameters.generate(prog, tgt, true, c.l.batch);

        // int8 wins over fp16 when both flags are given.
        auto quantize = precision::fp32;
        if(c.to_int8)
            quantize = precision::int8;
        else if(c.to_fp16)
            quantize = precision::fp16;

        auto tols = get_tolerances(prog, quantize, rms_tol, atol, rtol);
        std::cout << "rms_tol: " << tols.rms_tol << std::endl;
        std::cout << "atol: " << tols.atol << std::endl;
        std::cout << "rtol: " << tols.rtol << std::endl;

        if(per_instruction)
            verify_instructions(prog, tgt, c.co, quantize, tols);
        else if(reduce)
            verify_reduced_program(prog, tgt, c.co, quantize, inputs, tols);
        else
            verify_program(c.l.file, prog, tgt, c.co, quantize, inputs, tols);
    }
};

Paul's avatar
Paul committed
612
613
614
// `compile` command: compile the program; compiler::compile() saves the result itself
// whenever it actually compiles.
struct compile : command<compile>
{
    compiler c;

    void parse(argument_parser& ap) { c.parse(ap); }

    void run()
    {
        std::cout << "Compiling ... " << std::endl;
        static_cast<void>(c.compile());
    }
};

// `run` command: compile, evaluate once with generated parameters, then print the program.
struct run_cmd : command<run_cmd>
{
    compiler c;

    void parse(argument_parser& ap) { c.parse(ap); }

    void run()
    {
        std::cout << "Compiling ... " << std::endl;
        auto prog = c.compile();

        std::cout << "Allocating params ... " << std::endl;
        auto inputs = c.params(prog);

        prog.eval(inputs);
        std::cout << prog << std::endl;
    }
};

Paul's avatar
Paul committed
640
641
642
643
// `perf` command: compile the program and print a performance report over `n` iterations.
struct perf : command<perf>
{
    compiler c;
    unsigned n = 100;

    void parse(argument_parser& ap)
    {
        c.parse(ap);
        ap(n, {"--iterations", "-n"}, ap.help("Number of iterations to run for perf report"));
    }

    void run()
    {
        std::cout << "Compiling ... " << std::endl;
        auto prog = c.compile();

        std::cout << "Allocating params ... " << std::endl;
        auto inputs = c.params(prog);

        std::cout << "Running performance report ... " << std::endl;
        prog.perf_report(std::cout, n, inputs, c.l.batch);
    }
};

661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
// `roctx` command: compile the program and run it with a rocTX marker attached via mark().
struct roctx : command<roctx>
{
    compiler c;

    void parse(argument_parser& ap) { c.parse(ap); }

    void run()
    {
        std::cout << "Compiling ... " << std::endl;
        auto prog = c.compile();

        std::cout << "Allocating params ... " << std::endl;
        auto inputs = c.params(prog);

        std::cout << "rocTX:\tLoading rocTX library..." << std::endl;
        auto marker = create_marker_roctx();
        prog.mark(inputs, std::move(marker));
    }
};

678
679
680
// `op` command: list all registered operators, or print one operator's value as pretty JSON.
struct op : command<op>
{
    bool show_ops = false;
    std::string op_name{};

    void parse(argument_parser& ap)
    {
        ap(op_name, {}, ap.metavar("<MIGraphX operator name>"));
        ap(show_ops,
           {"--list", "-l"},
           ap.help("List all the operators of MIGraphX"),
           ap.set_value(true));
    }

    void run() const
    {
        if(not show_ops)
        {
            auto selected = load_op(op_name);
            std::cout << op_name << ": " << std::endl;
            std::cout << to_pretty_json_string(selected.to_value()) << std::endl;
            return;
        }
        for(const auto& registered : get_operators())
            std::cout << registered << std::endl;
    }
};

706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
// `onnx` command: with --list, print every onnx operator the parser supports.
struct onnx : command<onnx>
{
    bool show_ops = false;

    void parse(argument_parser& ap)
    {
        ap(show_ops,
           {"--list", "-l"},
           ap.help("List all onnx operators supported by MIGraphX"),
           ap.set_value(true));
    }

    void run() const
    {
        if(not show_ops)
            return;
        for(const auto& supported : get_onnx_operators())
            std::cout << supported << std::endl;
    }
};

726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
// `tf` command: with --list, print every tensorflow operator the parser supports.
struct tf : command<tf>
{
    bool show_ops = false;

    void parse(argument_parser& ap)
    {
        ap(show_ops,
           {"--list", "-l"},
           ap.help("List all tf operators supported by MIGraphX"),
           ap.set_value(true));
    }

    void run() const
    {
        if(not show_ops)
            return;
        for(const auto& supported : get_tf_operators())
            std::cout << supported << std::endl;
    }
};

Paul's avatar
Paul committed
746
747
struct main_command
{
748
749
    static std::string get_command_help(const std::string& title = colorize(color::fg_yellow,
                                                                            "COMMANDS:"))
Paul's avatar
Paul committed
750
    {
751
752
753
754
755
756
757
758
759
760
        std::string result = title + "\n";
        std::vector<std::string> commands(get_commands().size());
        std::transform(get_commands().begin(),
                       get_commands().end(),
                       commands.begin(),
                       [](const auto& p) { return colorize(color::fg_green, p.first); });
        std::sort(commands.begin(), commands.end());
        return std::accumulate(commands.begin(), commands.end(), result, [](auto r, auto&& s) {
            return r + "    " + s + "\n";
        });
Paul's avatar
Paul committed
761
    }
Paul's avatar
Paul committed
762
    void parse(argument_parser& ap)
Paul's avatar
Paul committed
763
    {
kahmed10's avatar
kahmed10 committed
764
        std::string version_str = get_version();
765
        ap(wrong_commands, {}, ap.metavar("<command>"), ap.append());
Paul's avatar
Paul committed
766
        ap(nullptr, {"-h", "--help"}, ap.help("Show help"), ap.show_help(get_command_help()));
767
768
769
770
        ap(nullptr,
           {"-v", "--version"},
           ap.help("Show MIGraphX version"),
           ap.show_help(version_str));
kahmed10's avatar
kahmed10 committed
771
        ap(nullptr, {"--ort-sha"}, ap.help("Show MIGraphX onnx runtime SHA"));
772
773
774
775

        // Trim command off of exe name
        ap.set_exe_name(ap.get_exe_name().substr(0, ap.get_exe_name().size() - 5));
        ap.set_exe_name_to(exe_name);
Paul's avatar
Paul committed
776
777
    }

778
779
780
781
782
783
784
785
786
787
788
789
790
    std::vector<std::string> wrong_commands{};
    std::string exe_name = "<exe>";

    void run()
    {
        std::cout << color::fg_red << color::bold << "error: " << color::reset;
        auto it = std::find_if(wrong_commands.begin(), wrong_commands.end(), [](const auto& c) {
            return get_commands().count(c) > 0;
        });
        if(it == wrong_commands.end())
        {
            std::cout << "'" << color::fg_yellow << wrong_commands.front() << color::reset
                      << "' is not a valid command." << std::endl;
791
            std::cout << get_command_help("Available commands:");
792
793
794
795
796
797
798
799
800
801
802
803
        }
        else
        {
            std::cout << "command '" << color::fg_yellow << *it << color::reset
                      << "' must be first argument" << std::endl;
            std::cout << std::endl;

            std::cout << color::fg_yellow << "USAGE:" << color::reset << std::endl;
            std::cout << "    " << exe_name << " " << *it << " <options>" << std::endl;
        }
        std::cout << std::endl;
    }
Paul's avatar
Paul committed
804
805
};

Paul's avatar
Paul committed
806
807
808
809
} // namespace MIGRAPHX_INLINE_NS
} // namespace driver
} // namespace migraphx

Paul's avatar
Paul committed
810
using namespace migraphx::driver; // NOLINT
Paul's avatar
Paul committed
811
812
int main(int argc, const char* argv[])
{
Paul's avatar
Paul committed
813
    std::vector<std::string> args(argv + 1, argv + argc);
Shucai Xiao's avatar
Shucai Xiao committed
814
    // no argument, print the help infomration by default
Paul's avatar
Paul committed
815
    if(args.empty())
Shucai Xiao's avatar
Shucai Xiao committed
816
817
818
819
    {
        args.push_back("-h");
    }

Paul's avatar
Paul committed
820
    auto&& m = get_commands();
Paul's avatar
Paul committed
821
    auto cmd = args.front();
822

kahmed10's avatar
kahmed10 committed
823
    if(cmd == "--ort-sha")
824
825
826
827
    {
        std::cout << MIGRAPHX_ORT_SHA1 << std::endl;
        return 0;
    }
kahmed10's avatar
kahmed10 committed
828
829
830
831
832
    if(cmd == "-v" or cmd == "--version")
    {
        std::cout << get_version() << std::endl;
        return 0;
    }
833

Paul's avatar
Paul committed
834
    if(m.count(cmd) > 0)
Paul's avatar
Paul committed
835
    {
kahmed10's avatar
kahmed10 committed
836
837
838
839
840
841
842
843
        std::string driver_invocation =
            std::string(argv[0]) + " " + migraphx::to_string_range(args, " ");
        std::cout << "Running [ " << get_version() << " ]: " << driver_invocation << std::endl;

        m.at(cmd)(argv[0],
                  {args.begin() + 1, args.end()}); // run driver command found in commands map

        std::cout << "[ " << get_version() << " ] Complete: " << driver_invocation << std::endl;
Paul's avatar
Paul committed
844
    }
Paul's avatar
Paul committed
845
    else
Paul's avatar
Paul committed
846
    {
847
        run_command<main_command>(argv[0], args);
Paul's avatar
Paul committed
848
    }
Shucai Xiao's avatar
Shucai Xiao committed
849

Paul's avatar
Paul committed
850
851
    return 0;
}