mlir.cpp 20.6 KB
Newer Older
1
#include <migraphx/gpu/mlir.hpp>
Paul's avatar
Paul committed
2

Paul's avatar
Paul committed
3
#ifdef MIGRAPHX_MLIR
Paul's avatar
Paul committed
4
5
6
7
8
9
#include <mlir-c/IR.h>
#include <mlir-c/BuiltinAttributes.h>
#include <mlir-c/BuiltinTypes.h>
#include <mlir-c/Diagnostics.h>
#include <mlir-c/Dialect/MIGraphX.h>
#include <mlir-c/IntegerSet.h>
Paul's avatar
Paul committed
10
#include <mlir-c/Pass.h>
Paul's avatar
Paul committed
11
#include <mlir-c/Registration.h>
Paul's avatar
Paul committed
12
#endif
Paul's avatar
Paul committed
13

Paul's avatar
Paul committed
14
#include <migraphx/env.hpp>
Paul's avatar
Paul committed
15
16
#include <migraphx/manage_ptr.hpp>
#include <migraphx/module.hpp>
Paul's avatar
Paul committed
17
#include <migraphx/instruction.hpp>
Paul's avatar
Paul committed
18
#include <migraphx/config.hpp>
Paul's avatar
Paul committed
19
#include <migraphx/ranges.hpp>
Paul's avatar
Paul committed
20
#include <migraphx/gpu/code_object_op.hpp>
Paul's avatar
Paul committed
21
#include <migraphx/gpu/context.hpp>
Paul's avatar
Paul committed
22
#include <migraphx/gpu/device_name.hpp>
Paul's avatar
Paul committed
23
#include <migraphx/iterator_for.hpp>
Paul's avatar
Paul committed
24
#include <migraphx/gpu/perfdb.hpp>
Paul's avatar
Paul committed
25
26
#include <deque>
#include <variant>
Paul's avatar
Paul committed
27
28
29

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
30
namespace gpu {
Paul's avatar
Paul committed
31

Paul's avatar
Paul committed
32
33
// When enabled, compile_mlir prints the input module and the generated MLIR to stdout.
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MLIR);

Paul's avatar
Paul committed
34
#ifdef MIGRAPHX_MLIR
Paul's avatar
Paul committed
35
template <class T, class F, F f> // NOLINT
Paul's avatar
Paul committed
36
37
38
39
40
41
42
43
44
45
46
47
48
49
struct mlir_handle
{
    struct ptr
    {
        ptr() = default;
        ptr(std::nullptr_t) {}
        ptr(T x) : obj(x) {}

        std::intptr_t get_value() const
        {
            static_assert(sizeof(T) == sizeof(std::intptr_t), "MLIR Handle different size");
            return reinterpret_cast<const std::intptr_t&>(obj);
        }

Paul's avatar
Paul committed
50
        T get() const { return obj; }
Paul's avatar
Paul committed
51

Paul's avatar
Paul committed
52
        friend bool operator==(ptr x, ptr y) { return x.get_value() == y.get_value(); }
Paul's avatar
Paul committed
53

Paul's avatar
Paul committed
54
        friend bool operator!=(ptr x, ptr y) { return !(x == y); }
Paul's avatar
Paul committed
55
56
        T obj{};
    };
Paul's avatar
Paul committed
57

Paul's avatar
Paul committed
58
59
60
61
62
63
64
65
66
67
68
69
70
    struct deleter
    {
        using pointer = ptr;

        void operator()(pointer x) const
        {
            if(x != nullptr)
            {
                (void)f(x.obj);
            }
        }
    };

Paul's avatar
Paul committed
71
    mlir_handle() : handle(nullptr) {}
Paul's avatar
Paul committed
72

Paul's avatar
Paul committed
73
    mlir_handle(T p) : handle(ptr{p}) {}
Paul's avatar
Paul committed
74

Paul's avatar
Paul committed
75
    T get() const { return handle.get().get(); }
Paul's avatar
Paul committed
76

Paul's avatar
Paul committed
77
    T release() { return handle.release().get(); }
Paul's avatar
Paul committed
78

Paul's avatar
Paul committed
79
    private:
Paul's avatar
Paul committed
80
81
82
    std::unique_ptr<ptr, deleter> handle;
};

83
// Declare an owning mlir_handle type for handle type T whose destructor calls F.
#define MIGRAPHX_MANAGE_MLIR_HANDLE(T, F) migraphx::gpu::mlir_handle<T, decltype(&F), &F> // NOLINT

// Owning wrappers for the MLIR C-API objects used below. Each calls its *Destroy function
// when it goes out of scope, unless ownership was handed to MLIR via release().
using mlir_context           = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirContext, mlirContextDestroy);
using mlir_module            = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirModule, mlirModuleDestroy);
using mlir_operation         = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirOperation, mlirOperationDestroy);
using mlir_op_printing_flags = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirOpPrintingFlags,
                                                           mlirOpPrintingFlagsDestroy);
using mlir_region            = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirRegion, mlirRegionDestroy);
using mlir_block             = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirBlock, mlirBlockDestroy);
using mlir_pass_manager      = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirPassManager, mlirPassManagerDestroy);
Paul's avatar
Paul committed
93

Paul's avatar
Paul committed
94
std::string_view to_string_view(MlirStringRef s) { return {s.data, s.length}; }
Paul's avatar
Paul committed
95
96
97
98
99
100

MlirStringRef make_mlir_string_ref(const std::string_view& s)
{
    return mlirStringRefCreate(s.data(), s.size());
}

Paul's avatar
Paul committed
101
template <class F, class T, class Printer>
Paul's avatar
Paul committed
102
103
void mlir_print(F f, T x, Printer printer)
{
Paul's avatar
Format  
Paul committed
104
105
106
107
108
109
    f(
        x,
        +[](MlirStringRef s, void* data) {
            (*reinterpret_cast<Printer*>(data))(to_string_view(s));
        },
        &printer);
Paul's avatar
Paul committed
110
111
}

Paul's avatar
Paul committed
112
template <class F, class T>
// Print x through the MLIR print entry point f, streaming every chunk into os.
void mlir_print(F f, T x, std::ostream& os)
{
    mlir_print(f, x, [&](auto s) { os << s; });
}

Paul's avatar
Paul committed
118
119
120
121
122
123
124
125
// Print x through the MLIR print entry point f and return the collected text.
template <class F, class T>
std::string mlir_print(F f, T x)
{
    std::ostringstream collected;
    auto append = [&collected](auto chunk) { collected << chunk; };
    mlir_print(f, x, append);
    return collected.str();
}

Paul's avatar
Paul committed
126
127
struct mlir_program
{
Paul's avatar
Paul committed
128
129
130
131
    mlir_program()
        : ctx(mlirContextCreate()),
          location(mlirLocationUnknownGet(ctx.get())),
          mmodule(mlirModuleCreateEmpty(location))
Paul's avatar
Paul committed
132
    {
Paul's avatar
Paul committed
133
134
        MlirDialectHandle mixr_handle = mlirGetDialectHandle__migraphx__();
        mlirDialectHandleRegisterDialect(mixr_handle, ctx.get());
Paul's avatar
Paul committed
135
        mlirRegisterAllDialects(ctx.get());
Paul's avatar
Paul committed
136
        mlirContextSetAllowUnregisteredDialects(ctx.get(), true /*allow*/);
Paul's avatar
Paul committed
137
138
139
140
141
142
    }

    MlirType make_type(shape::type_t t) const
    {
        MlirType result;
        shape::visit(t, [&](auto as) {
Paul's avatar
Paul committed
143
            if(as.type_enum() == shape::float_type)
Paul's avatar
Paul committed
144
                result = mlirF32TypeGet(ctx.get());
Paul's avatar
Paul committed
145
            else if(as.type_enum() == shape::half_type)
Paul's avatar
Paul committed
146
                result = mlirF16TypeGet(ctx.get());
Paul's avatar
Paul committed
147
            else if(as.type_enum() == shape::double_type)
Paul's avatar
Paul committed
148
                result = mlirF64TypeGet(ctx.get());
Paul's avatar
Paul committed
149
            else if(as.is_integral())
Paul's avatar
Paul committed
150
            {
Paul's avatar
Paul committed
151
                if(as.is_signed())
Paul's avatar
Paul committed
152
153
154
155
156
157
158
159
160
161
162
163
164
165
                    result = mlirIntegerTypeSignedGet(ctx.get(), as.size() * 8);
                else
                    result = mlirIntegerTypeGet(ctx.get(), as.size() * 8);
            }
            else
                MIGRAPHX_THROW("Unsupported type: " + std::to_string(as.type_enum()));
        });
        return result;
    }

    MlirType make_tensor(const shape& s) const
    {
        assert(s.standard());
        std::vector<int64_t> lens(s.lens().begin(), s.lens().end());
Paul's avatar
Format  
Paul committed
166
167
        return mlirRankedTensorTypeGet(
            lens.size(), lens.data(), make_type(s.type()), mlirAttributeGetNull());
Paul's avatar
Paul committed
168
169
    }

Paul's avatar
Paul committed
170
    template <class Range>
Paul's avatar
Paul committed
171
172
173
174
175
176
177
178
179
180
181
    std::vector<MlirType> make_tensors(const Range& r)
    {
        std::vector<MlirType> result;
        std::transform(r.begin(), r.end(), std::back_inserter(result), [&](const auto& s) {
            return make_tensor(s);
        });
        return result;
    }

    MlirType make_function_type(const std::vector<shape>& inputs, const std::vector<shape>& outputs)
    {
Paul's avatar
Paul committed
182
        auto in  = make_tensors(inputs);
Paul's avatar
Paul committed
183
184
185
186
        auto out = make_tensors(outputs);
        return mlirFunctionTypeGet(ctx.get(), in.size(), in.data(), out.size(), out.data());
    }

Paul's avatar
Paul committed
187
188
189
190
191
192
193
    MlirIdentifier id(const std::string_view& s) const
    {
        return mlirIdentifierGet(ctx.get(), make_mlir_string_ref(s));
    }

    MlirAttribute attribute(std::int64_t i) const
    {
Paul's avatar
Format  
Paul committed
194
        if(i < 0)
195
196
            MIGRAPHX_THROW("MLIR cant handle negative values since they are ambiguous");
        return mlirIntegerAttrGet(mlirIntegerTypeGet(ctx.get(), 64), i);
Paul's avatar
Paul committed
197
    }
Paul's avatar
Paul committed
198
199
    MlirAttribute attribute(std::uint64_t i) const
    {
Paul's avatar
Format  
Paul committed
200
        if(i > (std::numeric_limits<std::uint64_t>::max() / 2))
201
202
            MIGRAPHX_THROW("MLIR cant handle large integer values since they are ambiguous");
        return mlirIntegerAttrGet(mlirIntegerTypeGet(ctx.get(), 64), i);
Paul's avatar
Paul committed
203
204
    }
    MlirAttribute attribute(unsigned char i) const { return attribute(std::uint64_t(i)); }
Paul's avatar
Paul committed
205
    MlirAttribute attribute(bool b) const { return mlirBoolAttrGet(ctx.get(), b ? 1 : 0); }
Paul's avatar
Paul committed
206
207
208
209
210
211
212
213
    MlirAttribute attribute(double d) const
    {
        return mlirFloatAttrDoubleGet(ctx.get(), mlirF64TypeGet(ctx.get()), d);
    }
    MlirAttribute attribute(const std::string& s) const
    {
        return mlirStringAttrGet(ctx.get(), make_mlir_string_ref(s));
    }
Paul's avatar
Paul committed
214
215
    MlirAttribute attribute(std::nullptr_t) const { return {}; }
    template <class T>
Paul's avatar
Paul committed
216
217
218
219
220
221
222
223
224
225
226
227
    MlirAttribute attribute(const std::vector<T>& v) const
    {
        std::vector<MlirAttribute> attributes;
        attributes.reserve(v.size());
        std::transform(v.begin(), v.end(), std::back_inserter(attributes), [&](auto&& x) {
            return attribute(x);
        });
        return mlirArrayAttrGet(ctx.get(), attributes.size(), attributes.data());
    }
    MlirAttribute attribute(const value& v) const
    {
        MlirAttribute attr;
Paul's avatar
Paul committed
228
        v.visit_value([&](auto&& x) { attr = attribute(x); });
Paul's avatar
Paul committed
229
230
231
232
233
234
235
236
237
238
        return attr;
    }
    MlirAttribute attribute(const std::vector<value>& v) const
    {
        if(v.empty())
        {
            return mlirArrayAttrGet(ctx.get(), 0, nullptr);
        }
        if(not v.front().get_key().empty())
        {
Paul's avatar
Paul committed
239
            std::vector<MlirNamedAttribute> attributes = name_attributes(v);
Paul's avatar
Paul committed
240
241
242
243
244
245
246
247
248
249
250
251
252
            return mlirDictionaryAttrGet(ctx.get(), attributes.size(), attributes.data());
        }
        else
        {
            std::vector<MlirAttribute> attributes;
            attributes.reserve(v.size());
            std::transform(v.begin(), v.end(), std::back_inserter(attributes), [&](auto&& x) {
                return attribute(x);
            });
            return mlirArrayAttrGet(ctx.get(), attributes.size(), attributes.data());
        }
    }

Paul's avatar
Paul committed
253
    MlirAttribute attribute(MlirType t) const { return mlirTypeAttrGet(t); }
Paul's avatar
Paul committed
254

Paul's avatar
Paul committed
255
256
    MlirAttribute attribute(MlirAttribute a) const { return a; }

Paul's avatar
Paul committed
257
    template <class T>
Paul's avatar
Paul committed
258
259
260
261
262
263
264
265
    MlirNamedAttribute name_attribute(const std::string_view& key, const T& x) const
    {
        MlirNamedAttribute attr;
        attr.name      = id(key);
        attr.attribute = attribute(x);
        return attr;
    }

Paul's avatar
Paul committed
266
267
268
269
270
271
272
273
274
    using attribute_t       = std::variant<std::nullptr_t,
                                     std::uint64_t,
                                     unsigned char,
                                     bool,
                                     double,
                                     std::string,
                                     value,
                                     std::vector<value>,
                                     MlirType>;
Paul's avatar
Paul committed
275
276
277
278
    using named_attribute_t = std::pair<std::string_view, attribute_t>;

    MlirNamedAttribute name_attribute(const named_attribute_t& na) const
    {
Paul's avatar
Paul committed
279
280
        return name_attribute(na.first,
                              std::visit([&](const auto& x) { return attribute(x); }, na.second));
Paul's avatar
Paul committed
281
282
    }

Paul's avatar
Paul committed
283
284
    std::vector<MlirNamedAttribute>
    name_attributes(const std::vector<named_attribute_t>& named_attrs) const
Paul's avatar
Paul committed
285
286
287
    {
        std::vector<MlirNamedAttribute> attributes;
        attributes.reserve(named_attrs.size());
Paul's avatar
Paul committed
288
289
290
291
        std::transform(named_attrs.begin(),
                       named_attrs.end(),
                       std::back_inserter(attributes),
                       [&](const named_attribute_t& a) { return name_attribute(a); });
Paul's avatar
Paul committed
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
        return attributes;
    }

    std::vector<MlirNamedAttribute> name_attributes(const value& v) const
    {
        std::vector<MlirNamedAttribute> attributes;
        attributes.reserve(v.size());
        std::transform(v.begin(), v.end(), std::back_inserter(attributes), [&](const value& x) {
            return name_attribute(x.get_key(), x.without_key());
        });
        return attributes;
    }

    struct mlir_operation_state
    {
Paul's avatar
Paul committed
307
308
309
310
        mlir_operation_state(mlir_program& p, const std::string_view& name)
            : prog(&p), op_state(mlirOperationStateGet(make_mlir_string_ref(name), p.location))
        {
        }
Paul's avatar
Paul committed
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359

        mlir_operation_state& add_attributes(const std::vector<named_attribute_t>& named_attrs)
        {
            auto attributes = prog->name_attributes(named_attrs);
            mlirOperationStateAddAttributes(&op_state, attributes.size(), attributes.data());
            return *this;
        }

        mlir_operation_state& add_attribute_value(const value& v)
        {
            auto attributes = prog->name_attributes(v);
            mlirOperationStateAddAttributes(&op_state, attributes.size(), attributes.data());
            return *this;
        }

        mlir_operation_state& add_regions(std::vector<mlir_region> rs)
        {
            regions = std::move(rs);
            return *this;
        }

        mlir_operation_state& add_region(mlir_region r)
        {
            regions.emplace_back(std::move(r));
            return *this;
        }

        mlir_operation_state& add_results(const std::vector<shape>& outputs)
        {
            auto x = prog->make_tensors(outputs);
            mlirOperationStateAddResults(&op_state, x.size(), x.data());
            return *this;
        }

        mlir_operation_state& add_operands(const std::vector<MlirValue>& inputs)
        {
            mlirOperationStateAddOperands(&op_state, inputs.size(), inputs.data());
            return *this;
        }

        mlir_operation create_operation()
        {
            std::vector<MlirRegion> mregions(regions.size());
            std::transform(regions.begin(), regions.end(), mregions.begin(), [](const auto& r) {
                return r.get();
            });
            mlirOperationStateAddOwnedRegions(&op_state, mregions.size(), mregions.data());
            mlir_operation op(mlirOperationCreate(&op_state));
            // Release memory since mlir_operation owns it
Paul's avatar
Paul committed
360
            for(auto& r : regions)
Paul's avatar
Paul committed
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
                r.release();
            regions.clear();
            return op;
        }

        mlir_program* prog;
        MlirOperationState op_state;
        std::vector<mlir_region> regions = {};
    };

    mlir_operation_state create_operation_state(const std::string_view& name)
    {
        return {*this, name};
    }

    std::vector<MlirValue> insert(MlirBlock body, mlir_operation_state ops)
    {
        std::vector<MlirValue> result;
        mlir_operation op = ops.create_operation();
Paul's avatar
Paul committed
380
        auto weak_op      = op.get();
Paul's avatar
Paul committed
381
        mlirBlockAppendOwnedOperation(body, op.release());
Paul's avatar
Paul committed
382
383
384
385
386
387
388
389
390

        auto n = mlirOperationGetNumResults(weak_op);
        result.reserve(n);
        transform(range(n), std::back_inserter(result), [&](auto i) {
            return mlirOperationGetResult(weak_op, i);
        });
        return result;
    }

Paul's avatar
Paul committed
391
392
    MlirBlock
    insert(MlirBlock body, const module& m, std::unordered_map<instruction_ref, MlirValue>& ins_map)
Paul's avatar
Paul committed
393
394
    {
        auto names = m.get_parameter_names();
Paul's avatar
Paul committed
395
        std::sort(names.begin(), names.end());
Paul's avatar
Paul committed
396
        std::vector<shape> inputs;
Paul's avatar
Paul committed
397
398
399
400
        std::transform(names.begin(),
                       names.end(),
                       std::back_inserter(inputs),
                       [&](const std::string& name) { return m.get_parameter_shape(name); });
Paul's avatar
Paul committed
401
402
        std::vector<shape> outputs = m.get_output_shapes();

Paul's avatar
Paul committed
403
        std::vector<MlirLocation> arg_locs(inputs.size(), location);
Paul's avatar
Paul committed
404
        auto body_inputs   = make_tensors(inputs);
Paul's avatar
Paul committed
405
        mlir_region region = mlirRegionCreate();
Paul's avatar
Format  
Paul committed
406
407
        mlir_block fbody = mlirBlockCreate(body_inputs.size(), body_inputs.data(), arg_locs.data());
        MlirBlock result = fbody.get();
Paul's avatar
Paul committed
408
409
        mlirRegionAppendOwnedBlock(region.get(), fbody.release());

Paul's avatar
Paul committed
410
        auto ops = create_operation_state("func.func");
Paul's avatar
Format  
Paul committed
411
412
413
        ops.add_attributes({{"function_type", make_function_type(inputs, outputs)},
                            {"sym_name", std::string("main")},
                            {"kernel", std::string("mixr")}});
Paul's avatar
Paul committed
414
415
416
        ops.add_region(std::move(region));
        insert(body, std::move(ops));

Paul's avatar
Paul committed
417
        for(auto i : range(names.size()))
Paul's avatar
Paul committed
418
419
420
421
            ins_map[m.get_parameter(names[i])] = mlirBlockGetArgument(result, i);
        return result;
    }

Paul's avatar
Paul committed
422
423
    static std::string get_name(instruction_ref ins)
    {
Paul's avatar
Format  
Paul committed
424
        if(ins->name() == "@return")
Paul's avatar
Paul committed
425
            return "func.return";
Paul's avatar
Paul committed
426
427
428
        return "migraphx." + ins->name();
    }

Paul's avatar
Paul committed
429
430
431
    static value get_operator_value(const operation& op)
    {
        auto v = op.to_value();
Paul's avatar
Format  
Paul committed
432
        if(op.name() == "convolution")
Paul's avatar
Paul committed
433
434
        {
            // Adjust symetrical padding
Paul's avatar
Format  
Paul committed
435
            if(v.at("padding").size() == v.at("stride").size())
Paul's avatar
Paul committed
436
437
438
439
440
441
442
443
            {
                auto padding = v.at("padding");
                std::copy(padding.begin(), padding.end(), std::back_inserter(v.at("padding")));
            }
        }
        return v;
    }

Paul's avatar
Paul committed
444
445
    static shape get_shape(instruction_ref ins)
    {
Paul's avatar
Format  
Paul committed
446
        if(ins->name() == "@return")
Paul's avatar
Paul committed
447
448
449
450
451
452
453
        {
            assert(ins->inputs().size() == 1);
            return ins->inputs().front()->get_shape();
        }
        return ins->get_shape();
    }

Paul's avatar
Paul committed
454
455
456
457
458
    void parse(const module& m)
    {
        auto mbody = mlirModuleGetBody(mmodule.get());
        std::unordered_map<instruction_ref, MlirValue> ins_map;
        auto fbody = insert(mbody, m, ins_map);
Paul's avatar
Paul committed
459
        for(auto ins : iterator_for(m))
Paul's avatar
Paul committed
460
        {
Paul's avatar
Paul committed
461
            if(ins->name() == "@param")
Paul's avatar
Paul committed
462
                continue;
Paul's avatar
Paul committed
463
            auto name = get_name(ins);
Paul's avatar
Paul committed
464
            auto ops  = create_operation_state(name);
Paul's avatar
Paul committed
465
            ops.add_attribute_value(get_operator_value(ins->get_operator()));
466
467
            if(ins->name() != "@return")
                ops.add_results({get_shape(ins)});
Paul's avatar
Format  
Paul committed
468
            if(ins->name())
Paul's avatar
Paul committed
469
                pp = {ins->get_operator(), ins->inputs(), ins->get_shape()};
Paul's avatar
Paul committed
470
471

            std::vector<MlirValue> inputs;
Paul's avatar
Paul committed
472
473
            transform(
                ins->inputs(), std::back_inserter(inputs), [&](auto i) { return ins_map.at(i); });
Paul's avatar
Paul committed
474
475
476
            ops.add_operands(inputs);

            auto outputs = insert(fbody, std::move(ops));
477
478
479
480
481
            if(ins->name() != "@return")
            {
                assert(outputs.size() == 1);
                ins_map[ins] = outputs.front();
            }
Paul's avatar
Paul committed
482
483
484
        }
    }

Paul's avatar
Paul committed
485
    code_object_op compile() MIGRAPHX_TIDY_CONST
Paul's avatar
Paul committed
486
487
488
489
490
    {
        mlir_pass_manager pm{mlirPassManagerCreate(ctx.get())};
        // 1st pipeline to call
        mlirMIGraphXAddHighLevelPipeline(pm.get());
        // 2nd pipeline to call
Paul's avatar
Paul committed
491
492
        std::string tname = get_device_name();
        // HACK: Since MLIR can't handle the full target name
Paul's avatar
Paul committed
493
        auto hacked_tname = tname.substr(0, tname.find(':'));
Paul's avatar
Format  
Paul committed
494
495
496
497
498
        if(tname.size() != hacked_tname.size())
            std::cout
                << "*************** WARNING: MLIR may not compile the correct target features for: "
                << tname << std::endl;
        mlirMIGraphXAddBackendPipeline(pm.get(), hacked_tname.c_str(), "amdgcn-amd-amdhsa", "");
Paul's avatar
Paul committed
499
500
        mlirPassManagerRun(pm.get(), mmodule.get());

Paul's avatar
Paul committed
501
        code_object_op op{};
Paul's avatar
Paul committed
502
        op.symbol_name                = "main";
Paul's avatar
Format  
Paul committed
503
        op.code_object                = get_binary();
Paul's avatar
Paul committed
504
505
506
507
        std::tie(op.global, op.local) = get_launch_params();
        return op;
    }

Paul's avatar
Paul committed
508
    std::pair<std::size_t, std::size_t> get_launch_params() const
Paul's avatar
Paul committed
509
    {
Paul's avatar
Paul committed
510
        uint32_t attrs[2];
Paul's avatar
Paul committed
511
512
513
514
515
516
517
        // returns block and grid sizes
        mlirGetKernelAttrs(mmodule.get(), attrs);
        std::size_t local  = attrs[0];
        std::size_t global = local * attrs[1];
        return {global, local};
    }

Paul's avatar
Paul committed
518
    value::binary get_binary() const
Paul's avatar
Paul committed
519
    {
Paul's avatar
Paul committed
520
521
522
523
        int size = 0;
        mlirGetBinary(mmodule.get(), &size, nullptr);
        value::binary result(size);
        if(mlirGetBinary(mmodule.get(), &size, reinterpret_cast<char*>(result.data())))
Paul's avatar
Paul committed
524
525
526
527
            return result;
        MIGRAPHX_THROW("Failed to compile mlir program");
    }

Paul's avatar
Format  
Paul committed
528
    std::string get_tune_params() { return get_mlir_perf_for_conv(pp); }
Paul's avatar
Paul committed
529

Paul's avatar
Paul committed
530
    mlir_context ctx;
Paul's avatar
Paul committed
531
532
    MlirLocation location;
    mlir_module mmodule;
Paul's avatar
Paul committed
533
    problem_params pp;
Paul's avatar
Paul committed
534
    std::deque<std::string> strings{};
Paul's avatar
Paul committed
535
536
};

Paul's avatar
Paul committed
537
538
539
540
541
542
543
544
std::string dump_mlir(const module& m)
{
    mlir_program mp;
    mp.parse(m);
    auto mod_op = mlirModuleGetOperation(mp.mmodule.get());
    return mlir_print(&mlirOperationPrint, mod_op);
}

Paul's avatar
Paul committed
545
// Compile m through MLIR into a code_object_op whose output shape is m's first output.
// With MIGRAPHX_TRACE_MLIR set, prints the module and the generated MLIR to stdout.
code_object_op compile_mlir(const context&, const module& m)
{
    const bool trace = enabled(MIGRAPHX_TRACE_MLIR{});
    if(trace)
        std::cout << m << std::endl;
    // Checked up front so we never call front() on an empty vector below.
    const auto outputs = m.get_output_shapes();
    if(outputs.empty())
        MIGRAPHX_THROW("compile_mlir: module has no outputs");

    mlir_program mp;
    mp.parse(m);
    if(trace)
    {
        auto mod_op = mlirModuleGetOperation(mp.mmodule.get());
        std::cout << mlir_print(&mlirOperationPrint, mod_op) << std::endl;
    }
    auto co   = mp.compile();
    co.output = outputs.front();
    return co;
}

Paul's avatar
Format  
Paul committed
560
561
// Insert a call to the compiled MLIR kernel co before ins.
// Each input is expanded into: the buffer twice, an offset of 0, its lengths, then its
// strides. This presumably mirrors the MLIR memref descriptor
// (allocated ptr, aligned ptr, offset, sizes, strides).
// Scalars are added as literals, and each distinct value gets one literal.
// The last input is treated as the output buffer (co.output_arg).
instruction_ref insert_mlir(module& m,
                            instruction_ref ins,
                            code_object_op co,
                            const std::vector<instruction_ref>& inputs)
{
    std::vector<instruction_ref> refs;
    refs.reserve(inputs.size() * 15);

    std::unordered_map<uint64_t, instruction_ref> literal_map{};
    // Named `x`, not `value`, so it does not shadow migraphx::value.
    auto get_literal = [&](uint64_t x) {
        auto fi = literal_map.find(x);
        if(fi != literal_map.end())
            return fi->second;
        auto lit = m.add_literal(x);
        literal_map.emplace(x, lit);
        return lit;
    };

    std::size_t last = 0;
    for(auto input : inputs)
    {
        const auto& s = input->get_shape();
        // Index where this input's argument group starts; the final one is the output.
        last = refs.size();
        refs.push_back(input);
        refs.push_back(input);
        refs.push_back(get_literal(0)); // offset

        // dim sizes
        std::transform(s.lens().begin(),
                       s.lens().end(),
                       std::back_inserter(refs),
                       [&](const auto& lval) { return get_literal(lval); });

        // dim strides
        std::transform(s.strides().begin(),
                       s.strides().end(),
                       std::back_inserter(refs),
                       [&](const auto& lval) { return get_literal(lval); });
    }
    co.expected_inputs = to_shapes(refs);
    co.output_arg      = last;
    return m.insert_instruction(ins, co, refs);
}

Paul's avatar
Paul committed
607
608
#else

Paul's avatar
Format  
Paul committed
609
std::string dump_mlir(const module&) { return {}; }
Paul's avatar
Paul committed
610

Paul's avatar
Paul committed
611
612
code_object_op compile_mlir(const context&, const module&) { return {}; }

Paul's avatar
Format  
Paul committed
613
614
615
616
template <class T>
void use(T&)
{
}
Paul's avatar
Paul committed
617

Paul's avatar
Format  
Paul committed
618
instruction_ref
Paul's avatar
Paul committed
619
// cppcheck-suppress funcArgNamesDifferent
Paul's avatar
Format  
Paul committed
620
insert_mlir(module& m, instruction_ref, code_object_op co, const std::vector<instruction_ref>&)
Paul's avatar
Paul committed
621
{
Paul's avatar
Paul committed
622
    use(co);
Paul's avatar
Paul committed
623
624
625
    return m.end();
}

Paul's avatar
Paul committed
626
627
#endif

628
} // namespace gpu
Paul's avatar
Paul committed
629
630
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx