main.cpp 25.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
24

kahmed10's avatar
kahmed10 committed
25
#include "verify.hpp"
Paul's avatar
Paul committed
26
27
#include "argument_parser.hpp"
#include "command.hpp"
kahmed10's avatar
kahmed10 committed
28
#include "precision.hpp"
Paul's avatar
Paul committed
29
#include "perf.hpp"
30
#include "models.hpp"
31
#include "marker_roctx.hpp"
Paul's avatar
Paul committed
32

Paul's avatar
Paul committed
33
34
#include <migraphx/tf.hpp>
#include <migraphx/onnx.hpp>
35
#include <migraphx/py.hpp>
Paul's avatar
Paul committed
36
#include <migraphx/stringutils.hpp>
37
#include <migraphx/convert_to_json.hpp>
38
39
#include <migraphx/load_save.hpp>
#include <migraphx/json.hpp>
40
#include <migraphx/version.h>
Paul's avatar
Paul committed
41

42
43
44
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/eliminate_pad.hpp>
45
46
#include <migraphx/generate.hpp>
#include <migraphx/pass_manager.hpp>
47
#include <migraphx/propagate_constant.hpp>
48
#include <migraphx/quantization.hpp>
49
#include <migraphx/register_op.hpp>
50
51
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/simplify_reshapes.hpp>
52
#include <migraphx/register_target.hpp>
53

54
55
#include <fstream>

Paul's avatar
Paul committed
56
57
58
59
60
61
namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {

struct loader
{
62
    std::string model;
Paul's avatar
Paul committed
63
    std::string file;
Paul's avatar
Paul committed
64
    std::string file_type;
65
66
67
68
69
    unsigned batch              = 1;
    bool is_nhwc                = true;
    unsigned trim               = 0;
    bool optimize               = false;
    bool skip_unknown_operators = false;
70
71
72
    bool brief                  = false;
    std::string output_type;
    std::string output;
73
    std::string default_dyn_dim;
Shucai Xiao's avatar
Shucai Xiao committed
74
    std::vector<std::string> param_dims;
75
    std::vector<std::string> dyn_param_dims;
kahmed10's avatar
kahmed10 committed
76
    std::vector<std::string> output_names;
Paul's avatar
Paul committed
77
78
79

    // Register every loader option on the argument parser.
    //
    // The registration order is kept as-is because the options are registered
    // one after another on the same parser.
    void parse(argument_parser& ap)
    {
        // Long help texts are named up front so the registrations below stay short.
        const std::string batch_help =
            "For a static model, sets default_dim_value size (commonly batch size). For a "
            "dynamic batch model, sets the batch "
            "size at runtime.";
        const std::string dyn_input_dim_help =
            "Dynamic dimensions of a parameter (format: \"@name_1\" \"[{min:x, max:y, "
            "optimals:[o1,o2,...]}, dim2,dim3, ...]\", \"@name_2\", ... You can supply a "
            "single integer value for a dimension to specify it as fixed.";
        const std::string default_dyn_dim_help =
            "Default dynamic dimension (format: \"{min:x, max:y, optimals:[o1,o2]}\").";

        // --- Input selection -------------------------------------------------
        ap(file, {}, ap.metavar("<input file>"), ap.file_exist(), ap.required(), ap.group("input"));
        ap(model,
           {"--model"},
           ap.help("Load model"),
           ap.type("resnet50|inceptionv3|alexnet"),
           ap.group("input"));

        // --- Explicit input format (each flag stores a fixed value in file_type)
        ap(file_type, {"--onnx"}, ap.help("Load as onnx"), ap.set_value("onnx"));
        ap(file_type, {"--tf"}, ap.help("Load as tensorflow"), ap.set_value("tf"));
        ap(file_type, {"--migraphx"}, ap.help("Load as MIGraphX"), ap.set_value("migraphx"));
        ap(file_type, {"--migraphx-json"}, ap.help("Load as MIGraphX JSON"), ap.set_value("json"));

        // --- Parsing behaviour -----------------------------------------------
        ap(batch, {"--batch"}, ap.help(batch_help));
        ap(is_nhwc, {"--nhwc"}, ap.help("Treat tensorflow format as nhwc"), ap.set_value(true));
        ap(skip_unknown_operators,
           {"--skip-unknown-operators"},
           ap.help("Skip unknown operators when parsing and continue to parse."),
           ap.set_value(true));
        ap(is_nhwc, {"--nchw"}, ap.help("Treat tensorflow format as nchw"), ap.set_value(false));
        ap(trim, {"--trim", "-t"}, ap.help("Trim instructions from the end"));

        // --- Parameter shapes and output selection ---------------------------
        ap(param_dims,
           {"--input-dim"},
           ap.help("Dim of a parameter (format: \"@name d1 d2 dn\")"),
           ap.append(),
           ap.nargs(2));
        ap(dyn_param_dims, {"--dyn-input-dim"}, ap.help(dyn_input_dim_help), ap.append(), ap.nargs(2));
        ap(default_dyn_dim, {"--default-dyn-dim"}, ap.help(default_dyn_dim_help));
        ap(output_names,
           {"--output-names"},
           ap.help("Names of node output (format: \"name_1 name_2 name_n\")"),
           ap.append(),
           ap.nargs(2));
        ap(optimize, {"--optimize", "-O"}, ap.help("Optimize when reading"), ap.set_value(true));

        // --- Output format (each flag stores a fixed value in output_type) ---
        ap(output_type,
           {"--graphviz", "-g"},
           ap.help("Print out a graphviz representation."),
           ap.set_value("graphviz"));
        ap(brief, {"--brief"}, ap.help("Make the output brief."), ap.set_value(true));
        ap(output_type,
           {"--cpp"},
           ap.help("Print out the program as C++ program."),
           ap.set_value("cpp"));
        ap(output_type,
           {"--python", "--py"},
           ap.help("Print out the program as python program."),
           ap.set_value("py"));
        ap(output_type, {"--json"}, ap.help("Print out program as json."), ap.set_value("json"));
        ap(output_type,
           {"--text"},
           ap.help("Print out program in text format."),
           ap.set_value("text"));
        ap(output_type,
           {"--binary"},
           ap.help("Print out program in binary format."),
           ap.set_value("binary"));
        ap(output, {"--output", "-o"}, ap.help("Output to file."));
    }

Shucai Xiao's avatar
Shucai Xiao committed
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
    // Convert the flat `--input-dim` token list into a map from parameter name
    // to its dimensions.
    //
    // A token starting with '@' selects the current parameter name (the '@' is
    // stripped); every other token is parsed as a std::size_t and appended to
    // the dimensions of the current name. Tokens seen before any '@name' are
    // stored under the empty name.
    static auto parse_param_dims(const std::vector<std::string>& param_dims_info)
    {
        std::unordered_map<std::string, std::vector<std::size_t>> dims_by_param;
        std::string current_name;
        for(const auto& token : param_dims_info)
        {
            const bool selects_name = not token.empty() and token.front() == '@';
            if(selects_name)
            {
                current_name = token.substr(1);
                continue;
            }
            dims_by_param[current_name].push_back(value_parser<std::size_t>::apply(token));
        }
        return dims_by_param;
    }

167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
    static auto parse_dyn_dims_json(const std::string& dd_json)
    {
        // expecting a json string like "[{min:1,max:64,optimals:[1,2,4,8]},3,224,224]"
        auto v = from_json_string(convert_to_json(dd_json));
        std::vector<migraphx::shape::dynamic_dimension> dyn_dims;
        std::transform(v.begin(), v.end(), std::back_inserter(dyn_dims), [&](auto x) {
            if(x.is_object())
                return from_value<migraphx::shape::dynamic_dimension>(x);
            auto d = x.template to<std::size_t>();
            return migraphx::shape::dynamic_dimension{d, d};
        });
        return dyn_dims;
    }

    static auto parse_dyn_dims_map(const std::vector<std::string>& param_dyn_dims)
    {
        // expecting vector of strings formatted like
        // {"@param_name_0", "dd_json_0", "@param_name_1", "dd_json_1", ...}
        std::unordered_map<std::string, std::vector<shape::dynamic_dimension>> map_dyn_input_dims;
        std::string name = "";
        for(auto&& x : param_dyn_dims)
        {
            if(x[0] == '@')
            {
                name = x.substr(1);
            }
            else
            {
                map_dyn_input_dims[name] = parse_dyn_dims_json(x);
            }
        }
        return map_dyn_input_dims;
    }

kahmed10's avatar
kahmed10 committed
201
202
203
204
205
206
207
208
209
210
211
    // Run every `--output-names` token through value_parser<std::string> and
    // return the results in the same order.
    static auto parse_output_names(const std::vector<std::string>& output_names_info)
    {
        std::vector<std::string> names;
        names.reserve(output_names_info.size());
        for(const auto& raw : output_names_info)
        {
            names.push_back(value_parser<std::string>::apply(raw));
        }
        return names;
    }

212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
    // Assemble the tf_options from the parsed command-line values: layout flag,
    // batch size, static input dimensions and output node names.
    tf_options get_tf_options() const
    {
        tf_options result;
        result.map_input_dims    = parse_param_dims(param_dims);
        result.output_node_names = parse_output_names(output_names);
        result.is_nhwc           = is_nhwc;
        result.batch_size        = batch;
        return result;
    }

    // Assemble the onnx_options from the parsed command-line values.
    //
    // When `--default-dyn-dim` was given it is parsed into
    // default_dyn_dim_value; otherwise the batch size is used as
    // default_dim_value. print_program_on_error is always enabled.
    onnx_options get_onnx_options() const
    {
        // Parse both dimension maps first, before the default dimension.
        auto static_dims  = parse_param_dims(param_dims);
        auto dynamic_dims = parse_dyn_dims_map(dyn_param_dims);

        onnx_options result;
        if(not default_dyn_dim.empty())
        {
            auto parsed = from_json_string(convert_to_json(default_dyn_dim));
            result.default_dyn_dim_value =
                from_value<migraphx::shape::dynamic_dimension>(parsed);
        }
        else
        {
            result.default_dim_value = batch;
        }
        result.skip_unknown_operators = skip_unknown_operators;
        result.print_program_on_error = true;
        result.map_input_dims         = static_dims;
        result.map_dyn_input_dims     = dynamic_dims;
        return result;
    }

245
246
247
248
249
250
251
252
253
254
255
256
257
258
    // Deduce the input format from the file name suffix.
    //
    // ".onnx" -> "onnx", ".pb" -> "tf", ".json" -> "json", ".py" -> "py";
    // any other name is treated as "migraphx".
    static std::string get_file_type(const std::string& file)
    {
        if(ends_with(file, ".onnx"))
        {
            return "onnx";
        }
        if(ends_with(file, ".pb"))
        {
            return "tf";
        }
        if(ends_with(file, ".json"))
        {
            return "json";
        }
        if(ends_with(file, ".py"))
        {
            return "py";
        }
        return "migraphx";
    }

Paul's avatar
Paul committed
259
    // Build the program described by the parsed options.
    //
    // When no built-in model is named, the input file is read according to
    // `file_type`, which is deduced from the file name when it was not set on
    // the command line. Otherwise one of the built-in models is generated with
    // the requested batch size. Afterwards the optional trimming, the dead code
    // elimination used for C++ export and the `--optimize` passes are applied to
    // the main module.
    //
    // Throws via MIGRAPHX_THROW when `model` is not a known built-in model.
    program load()
    {
        program p;
        if(model.empty())
        {
            if(file_type.empty())
            {
                file_type = get_file_type(file);
            }
            std::cout << "Reading: " << file << std::endl;
            if(file_type == "onnx")
            {
                p = parse_onnx(file, get_onnx_options());
            }
            else if(file_type == "tf")
            {
                p = parse_tf(file, get_tf_options());
            }
            else if(file_type == "json")
            {
                file_options options;
                options.format = "json";
                p              = migraphx::load(file, options);
            }
            else if(file_type == "py")
            {
                p = migraphx::load_py(file);
            }
            else if(file_type == "migraphx")
            {
                p = migraphx::load(file);
            }
        }
        else
        {
            if(model == "resnet50")
                p = resnet50(batch);
            else if(model == "inceptionv3")
                p = inceptionv3(batch);
            else if(model == "alexnet")
                p = alexnet(batch);
            else
                MIGRAPHX_THROW("Unknown model: " + model);
        }
        if(trim > 0)
        {
            auto* mm = p.get_main_module();
            // Clamp the trim count to the number of instructions: moving an
            // iterator back past begin() with std::prev is undefined behaviour.
            const std::ptrdiff_t available = std::distance(mm->begin(), mm->end());
            auto count                     = static_cast<std::ptrdiff_t>(trim);
            if(count > available)
                count = available;
            auto last = std::prev(mm->end(), count);
            mm->remove_instructions(last, mm->end());
        }
        // Remove unused variable when exporting to cpp
        if(output_type == "cpp")
            migraphx::run_passes(*p.get_main_module(), {migraphx::dead_code_elimination{}});
        if(optimize)
        {
            migraphx::run_passes(*p.get_main_module(),
                                 {
                                     migraphx::eliminate_identity{},
                                     migraphx::dead_code_elimination{},
                                     migraphx::simplify_algebra{},
                                     migraphx::dead_code_elimination{},
                                     migraphx::simplify_reshapes{},
                                     migraphx::dead_code_elimination{},
                                     migraphx::propagate_constant{},
                                     migraphx::dead_code_elimination{},
                                     migraphx::eliminate_pad{},
                                     migraphx::dead_code_elimination{},
                                 });
        }
        return p;
    }
330
331
332
333
334
335

    // Write the raw bytes of `buffer` to `os` as one unformatted block.
    static void write(std::ostream& os, const std::vector<char>& buffer)
    {
        const auto byte_count = static_cast<std::streamsize>(buffer.size());
        os.write(buffer.data(), byte_count);
    }

336
    void save(const program& p) const
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
    {
        auto* os = &std::cout;
        std::ofstream fs;
        if(not output.empty())
        {
            fs.open(output);
            os = &fs;
        }

        std::string type = output_type;
        if(type.empty())
        {
            if(output.empty())
                type = "text";
            else
                type = "binary";
        }

355
356
357
        if(type == "py")
            p.print_py(*os);
        else if(type == "cpp")
358
359
360
361
362
363
364
365
366
367
            p.print_cpp(*os);
        else if(type == "graphviz")
            p.print_graph(*os, brief);
        else if(type == "text")
            *os << p << std::endl;
        else if(type == "json")
            *os << to_json_string(p.to_value()) << std::endl;
        else if(type == "binary")
            write(*os, save_buffer(p));
    }
Paul's avatar
Paul committed
368
369
};

370
371
372
373
374
375
struct program_params
{
    std::vector<std::string> fill0{};
    std::vector<std::string> fill1{};
    void parse(argument_parser& ap)
    {
Shucai Xiao's avatar
Shucai Xiao committed
376
377
        ap(fill0, {"--fill0"}, ap.help("Fill parameter with 0s"), ap.append(), ap.nargs(2));
        ap(fill1, {"--fill1"}, ap.help("Fill parameter with 1s"), ap.append(), ap.nargs(2));
378
379
    }

380
    auto generate(const program& p, const target& t, bool offload, unsigned batch)
381
    {
382
        parameter_map m;
383
384
385
386
387
388
389
        auto param_shapes = p.get_parameter_shapes();
        std::unordered_map<std::string, shape> static_param_shapes;
        std::transform(
            param_shapes.cbegin(),
            param_shapes.cend(),
            std::inserter(static_param_shapes, static_param_shapes.end()),
            [&](const auto& x) { return std::make_pair(x.first, x.second.to_static(batch)); });
390
        for(auto&& s : fill0)
391
            m[s] = fill_argument(static_param_shapes.at(s), 0);
392
        for(auto&& s : fill1)
393
394
            m[s] = fill_argument(static_param_shapes.at(s), 1);
        fill_param_map(m, static_param_shapes, t, offload);
395
396
397
398
        return m;
    }
};

399
400
401
402
struct compiler_target
{
#ifdef HAVE_GPU
    std::string target_name = "gpu";
403
#elif defined(HAVE_CPU)
404
    std::string target_name = "cpu";
405
406
#elif defined(HAVE_FPGA)
    std::string target_name = "fpga";
407
#else
408
    std::string target_name = "ref";
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
#endif

    void parse(argument_parser& ap)
    {
        ap(target_name, {"--gpu"}, ap.help("Compile on the gpu"), ap.set_value("gpu"));
        ap(target_name, {"--cpu"}, ap.help("Compile on the cpu"), ap.set_value("cpu"));
        ap(target_name,
           {"--ref"},
           ap.help("Compile on the reference implementation"),
           ap.set_value("ref"));
    }

    target get_target() const { return make_target(target_name); }
};

Paul's avatar
Paul committed
424
425
426
struct compiler
{
    loader l;
427
    program_params parameters;
428
    compiler_target ct;
429
    compile_options co;
430
431
    bool to_fp16 = false;
    bool to_int8 = false;
432

kahmed10's avatar
kahmed10 committed
433
    std::vector<std::string> fill0;
Paul's avatar
Paul committed
434
    std::vector<std::string> fill1;
Paul's avatar
Paul committed
435
436
    // Register the loader, parameter and target options, followed by the
    // compile options (offload copy, fast math, exhaustive tuning) and the
    // quantization flags.
    void parse(argument_parser& ap)
    {
        l.parse(ap);
        parameters.parse(ap);
        ct.parse(ap);
        ap(co.offload_copy,
           {"--enable-offload-copy"},
           ap.help("Enable implicit offload copying"),
           ap.set_value(true));
        ap(co.fast_math,
           {"--disable-fast-math"},
           ap.help("Disable fast math optimization"),
           ap.set_value(false));
        // Help text spelling fixed ("Exhastively" -> "Exhaustively").
        ap(co.exhaustive_tune,
           {"--exhaustive-tune"},
           ap.help("Exhaustively search for best tuning parameters for kernels"),
           ap.set_value(true));
        ap(to_fp16, {"--fp16"}, ap.help("Quantize for fp16"), ap.set_value(true));
        ap(to_int8, {"--int8"}, ap.help("Quantize for int8"), ap.set_value(true));
    }

456
457
    auto params(const program& p)
    {
458
        return parameters.generate(p, ct.get_target(), co.offload_copy, l.batch);
459
    }
460

461
462
463
464
465
    // Generate the arguments for `p` on the selected target with the offload
    // flag forced to true, regardless of the compile options.
    auto host_params(const program& p)
    {
        const auto selected_target = ct.get_target();
        return parameters.generate(p, selected_target, true, l.batch);
    }

466
467
468
    // Load the program and compile it for the selected target.
    //
    // A program that is already compiled is returned unchanged; on the gpu
    // target a hint is printed when the program's offload-copy state does not
    // match the `--enable-offload-copy` flag. Otherwise the requested
    // quantizations are applied, the program is compiled with the compile
    // options, passed to the loader's save(), and returned.
    program compile()
    {
        auto p = l.load();
        // Don't compile if it has already been compiled.
        if(p.is_compiled())
        {
            if(ct.target_name == "gpu")
            {
                const bool compiled_with_offload = is_offload_copy_set(p);
                if(compiled_with_offload and not co.offload_copy)
                {
                    std::cout << "MIGraphX program was likely compiled with offload_copy set, Try "
                                 "passing "
                                 "`--enable-offload-copy` if program run fails.\n";
                }
                // Only warn about a mismatch: previously this branch also fired
                // when the program WAS compiled with offload copy and the flag
                // was passed, which printed a misleading message.
                else if(not compiled_with_offload and co.offload_copy)
                {
                    std::cout << "MIGraphX program was likely compiled without "
                                 "offload_copy set, Try "
                                 "removing "
                                 "`--enable-offload-copy` flag if passed to driver, if program run "
                                 "fails.\n";
                }
            }

            return p;
        }
        auto t = ct.get_target();
        if(to_fp16)
        {
            quantize_fp16(p);
        }
        if(to_int8)
        {
            quantize_int8(p, t, {host_params(p)});
        }
        p.compile(t, co);
        l.save(p);
        return p;
    }
Paul's avatar
Paul committed
506
507
};

Paul's avatar
Paul committed
508
509
510
// `read` command: loads a program and passes it to the loader's save().
struct read : command<read>
{
    loader l;

    // All options come from the loader.
    void parse(argument_parser& ap)
    {
        l.parse(ap);
    }

    void run()
    {
        auto prog = l.load();
        l.save(prog);
    }
};

Paul's avatar
Paul committed
520
521
522
523
524
525
526
527
// `params` command: loads a program and prints "<name>: <shape>" for each of
// its parameters.
struct params : command<params>
{
    loader l;

    // All options come from the loader.
    void parse(argument_parser& ap)
    {
        l.parse(ap);
    }

    void run()
    {
        auto prog = l.load();
        for(auto&& entry : prog.get_parameter_shapes())
        {
            std::cout << entry.first << ": " << entry.second << std::endl;
        }
    }
};

Paul's avatar
Paul committed
533
534
// `verify` command: loads a program and hands it to one of the verification
// routines, chosen by the `--per-instruction` and `--reduce` flags.
struct verify : command<verify>
{
    compiler c;
    double tolerance     = 80;
    bool per_instruction = false;
    bool reduce          = false;

    // Register the compiler options plus the verify-specific ones.
    void parse(argument_parser& ap)
    {
        c.parse(ap);
        ap(tolerance, {"--tolerance"}, ap.help("Tolerance for errors"));
        ap(per_instruction,
           {"-i", "--per-instruction"},
           ap.help("Verify each instruction"),
           ap.set_value(true));
        ap(reduce, {"-r", "--reduce"}, ap.help("Reduce program and verify"), ap.set_value(true));
    }

    void run()
    {
        // The program is loaded but not compiled here; it is saved and printed as loaded.
        auto prog = c.l.load();
        c.l.save(prog);
        std::cout << prog << std::endl;

        auto t = c.ct.get_target();
        // Arguments are always generated with the offload flag set to true.
        auto m = c.parameters.generate(prog, t, true, c.l.batch);

        // int8 takes precedence over fp16 when both flags are given.
        auto quantize = precision::fp32;
        if(c.to_int8)
            quantize = precision::int8;
        else if(c.to_fp16)
            quantize = precision::fp16;

        if(per_instruction)
        {
            verify_instructions(prog, t, c.co, quantize, tolerance);
            return;
        }
        if(reduce)
        {
            verify_reduced_program(prog, t, c.co, quantize, m, tolerance);
            return;
        }
        verify_program(c.l.file, prog, t, c.co, quantize, m, tolerance);
    }
};

580
581
582
583
584
585
// `version` command: prints the version as major.minor.patch.tweak.
struct version : command<version>
{
    // This command takes no options.
    void parse(const argument_parser&) {}

    void run() const
    {
        std::cout << "MIGraphX Version: ";
        std::cout << MIGRAPHX_VERSION_MAJOR << "." << MIGRAPHX_VERSION_MINOR << "."
                  << MIGRAPHX_VERSION_PATCH << "." << MIGRAPHX_STRINGIZE(MIGRAPHX_VERSION_TWEAK);
        std::cout << std::endl;
    }
};

Paul's avatar
Paul committed
591
592
593
// `compile` command: runs the compiler and discards the returned program.
struct compile : command<compile>
{
    compiler c;

    // All options come from the compiler.
    void parse(argument_parser& ap)
    {
        c.parse(ap);
    }

    void run()
    {
        std::cout << "Compiling ... " << std::endl;
        // The returned program is not needed here.
        c.compile();
    }
};

// Command that compiles a program, evaluates it once with generated
// arguments, and then prints the program.
struct run_cmd : command<run_cmd>
{
    compiler c;

    // All options come from the compiler.
    void parse(argument_parser& ap)
    {
        c.parse(ap);
    }

    void run()
    {
        std::cout << "Compiling ... " << std::endl;
        auto prog = c.compile();

        std::cout << "Allocating params ... " << std::endl;
        auto arguments = c.params(prog);

        prog.eval(arguments);
        std::cout << prog << std::endl;
    }
};

Paul's avatar
Paul committed
619
620
621
622
// `perf` command: compiles a program and writes its performance report to
// std::cout.
struct perf : command<perf>
{
    compiler c;
    // Iteration count passed to perf_report (`--iterations` / `-n`).
    unsigned n = 100;

    // Register the compiler options plus the iteration count.
    void parse(argument_parser& ap)
    {
        c.parse(ap);
        ap(n, {"--iterations", "-n"}, ap.help("Number of iterations to run for perf report"));
    }

    void run()
    {
        std::cout << "Compiling ... " << std::endl;
        auto prog = c.compile();

        std::cout << "Allocating params ... " << std::endl;
        auto arguments = c.params(prog);

        std::cout << "Running performance report ... " << std::endl;
        prog.perf_report(std::cout, n, arguments, c.l.batch);
    }
};

640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
// `roctx` command: compiles a program and calls mark() on it with a marker
// created by create_marker_roctx().
struct roctx : command<roctx>
{
    compiler c;

    // All options come from the compiler.
    void parse(argument_parser& ap)
    {
        c.parse(ap);
    }

    void run()
    {
        std::cout << "Compiling ... " << std::endl;
        auto prog = c.compile();

        std::cout << "Allocating params ... " << std::endl;
        auto arguments = c.params(prog);

        std::cout << "rocTX:\tLoading rocTX library..." << std::endl;
        auto marker = create_marker_roctx();
        prog.mark(arguments, std::move(marker));
    }
};

657
658
659
// `op` command: lists all operator names, or prints one operator as pretty
// JSON.
struct op : command<op>
{
    bool show_ops = false;
    std::string op_name{};

    // Register the positional operator name and the listing flag.
    void parse(argument_parser& ap)
    {
        ap(op_name, {}, ap.metavar("<MIGraphX operator name>"));
        ap(show_ops,
           {"--list", "-l"},
           ap.help("List all the operators of MIGraphX"),
           ap.set_value(true));
    }

    void run() const
    {
        if(show_ops)
        {
            for(const auto& name : get_operators())
            {
                std::cout << name << std::endl;
            }
            return;
        }

        // Without --list, print the named operator.
        auto loaded = load_op(op_name);
        std::cout << op_name << ": " << std::endl;
        std::cout << to_pretty_json_string(loaded.to_value()) << std::endl;
    }
};

685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
// `onnx` command: with `--list`, prints each name returned by
// get_onnx_operators(); without it, does nothing.
struct onnx : command<onnx>
{
    bool show_ops = false;

    // Register the listing flag.
    void parse(argument_parser& ap)
    {
        ap(show_ops,
           {"--list", "-l"},
           ap.help("List all onnx operators supported by MIGraphX"),
           ap.set_value(true));
    }

    void run() const
    {
        if(not show_ops)
            return;
        for(const auto& name : get_onnx_operators())
        {
            std::cout << name << std::endl;
        }
    }
};

705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
// `tf` command: with `--list`, prints each name returned by
// get_tf_operators(); without it, does nothing.
struct tf : command<tf>
{
    bool show_ops = false;

    // Register the listing flag.
    void parse(argument_parser& ap)
    {
        ap(show_ops,
           {"--list", "-l"},
           ap.help("List all tf operators supported by MIGraphX"),
           ap.set_value(true));
    }

    void run() const
    {
        if(not show_ops)
            return;
        for(const auto& name : get_tf_operators())
        {
            std::cout << name << std::endl;
        }
    }
};

Paul's avatar
Paul committed
725
726
struct main_command
{
727
728
    static std::string get_command_help(const std::string& title = colorize(color::fg_yellow,
                                                                            "COMMANDS:"))
Paul's avatar
Paul committed
729
    {
730
731
732
733
734
735
736
737
738
739
        std::string result = title + "\n";
        std::vector<std::string> commands(get_commands().size());
        std::transform(get_commands().begin(),
                       get_commands().end(),
                       commands.begin(),
                       [](const auto& p) { return colorize(color::fg_green, p.first); });
        std::sort(commands.begin(), commands.end());
        return std::accumulate(commands.begin(), commands.end(), result, [](auto r, auto&& s) {
            return r + "    " + s + "\n";
        });
Paul's avatar
Paul committed
740
    }
Paul's avatar
Paul committed
741
    // Register the top-level options: a positional list that collects the
    // non-flag arguments, plus the help and version flags.
    void parse(argument_parser& ap)
    {
        std::string version_str = "MIGraphX Version: ";
        version_str += std::to_string(MIGRAPHX_VERSION_MAJOR);
        version_str += ".";
        version_str += std::to_string(MIGRAPHX_VERSION_MINOR);
        version_str += ".";
        version_str += std::to_string(MIGRAPHX_VERSION_PATCH);
        version_str += ".";
        version_str += MIGRAPHX_STRINGIZE(MIGRAPHX_VERSION_TWEAK);

        ap(wrong_commands, {}, ap.metavar("<command>"), ap.append());
        ap(nullptr, {"-h", "--help"}, ap.help("Show help"), ap.show_help(get_command_help()));
        ap(nullptr,
           {"-v", "--version"},
           ap.help("Show MIGraphX version"),
           ap.show_help(version_str));

        // Trim command off of exe name (drops the last 5 characters).
        const std::string full_name = ap.get_exe_name();
        ap.set_exe_name(full_name.substr(0, full_name.size() - 5));
        ap.set_exe_name_to(exe_name);
    }

759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
    std::vector<std::string> wrong_commands{};
    std::string exe_name = "<exe>";

    // Report a command-line error.
    //
    // This runs when the first argument is not a registered command. If one of
    // the collected arguments is a registered command, the user is told that
    // it must come first and shown the usage line. Otherwise the first
    // argument is reported as invalid and the available commands are listed.
    void run()
    {
        std::cout << color::fg_red << color::bold << "error: " << color::reset;
        // Guard: calling front() on an empty vector below would be undefined
        // behaviour, so handle the case where no argument was collected.
        if(wrong_commands.empty())
        {
            std::cout << "no command was given." << std::endl;
            std::cout << get_command_help("Available commands:") << std::endl;
            std::cout << std::endl;
            return;
        }
        auto it = std::find_if(wrong_commands.begin(), wrong_commands.end(), [](const auto& c) {
            return get_commands().count(c) > 0;
        });
        if(it == wrong_commands.end())
        {
            std::cout << "'" << color::fg_yellow << wrong_commands.front() << color::reset
                      << "' is not a valid command." << std::endl;
            std::cout << get_command_help("Available commands:") << std::endl;
        }
        else
        {
            std::cout << "command '" << color::fg_yellow << *it << color::reset
                      << "' must be first argument" << std::endl;
            std::cout << std::endl;

            std::cout << color::fg_yellow << "USAGE:" << color::reset << std::endl;
            std::cout << "    " << exe_name << " " << *it << " <options>" << std::endl;
        }
        std::cout << std::endl;
    }
Paul's avatar
Paul committed
785
786
};

Paul's avatar
Paul committed
787
788
789
790
} // namespace MIGRAPHX_INLINE_NS
} // namespace driver
} // namespace migraphx

Paul's avatar
Paul committed
791
using namespace migraphx::driver; // NOLINT
Paul's avatar
Paul committed
792
793
int main(int argc, const char* argv[])
{
Paul's avatar
Paul committed
794
    std::vector<std::string> args(argv + 1, argv + argc);
Shucai Xiao's avatar
Shucai Xiao committed
795
796

    // no argument, print the help infomration by default
Paul's avatar
Paul committed
797
    if(args.empty())
Shucai Xiao's avatar
Shucai Xiao committed
798
799
800
801
    {
        args.push_back("-h");
    }

Paul's avatar
Paul committed
802
    auto&& m = get_commands();
Paul's avatar
Paul committed
803
    auto cmd = args.front();
Paul's avatar
Paul committed
804
    if(m.count(cmd) > 0)
Paul's avatar
Paul committed
805
    {
806
        m.at(cmd)(argv[0], {args.begin() + 1, args.end()});
Paul's avatar
Paul committed
807
    }
Paul's avatar
Paul committed
808
    else
Paul's avatar
Paul committed
809
    {
810
        run_command<main_command>(argv[0], args);
Paul's avatar
Paul committed
811
    }
Shucai Xiao's avatar
Shucai Xiao committed
812

Paul's avatar
Paul committed
813
814
    return 0;
}