mlir.cpp 7.32 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <iterator>
#include <memory>
#include <ostream>
#include <string>
#include <string_view>
#include <type_traits>
#include <vector>

#include <mlir-c/IR.h>
#include <mlir-c/BuiltinAttributes.h>
#include <mlir-c/BuiltinTypes.h>
#include <mlir-c/Diagnostics.h>
#include <mlir-c/Dialect/Standard.h>
#include <mlir-c/Dialect/MIGraphX.h>
#include <mlir-c/IntegerSet.h>
#include <mlir-c/Registration.h>

#include <migraphx/manage_ptr.hpp>
#include <migraphx/module.hpp>
#include <migraphx/config.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

Paul's avatar
Paul committed
18
/// Owning wrapper around a by-value MLIR C-API handle.
///
/// @tparam T  The C handle type. Its size must equal sizeof(std::intptr_t) so
///            its bits can be compared as an integer (checked by static_assert).
/// @tparam F  Type of the destroy function.
/// @tparam f  Destroy function; called once on a non-null handle when the
///            wrapper is destroyed (its return value, if any, is ignored).
template <class T, class F, F f>
struct mlir_handle
{
    // Pointer-like adaptor so a by-value handle can be stored as the `pointer`
    // type of std::unique_ptr. A value-initialized handle (obj{}) acts as null;
    // equality compares the handle's raw bits.
    struct ptr
    {
        ptr() = default;
        ptr(std::nullptr_t) {}
        ptr(T x) : obj(x) {}

        // Reads the handle's bits as an integer. std::memcpy is used instead of
        // reinterpret_cast so the read does not violate strict aliasing.
        std::intptr_t get_value() const
        {
            static_assert(sizeof(T) == sizeof(std::intptr_t), "MLIR Handle different size");
            static_assert(std::is_trivially_copyable<T>::value,
                          "MLIR Handle must be trivially copyable");
            std::intptr_t value = 0;
            std::memcpy(&value, &obj, sizeof(value));
            return value;
        }

        T get() const { return obj; }

        friend bool operator==(ptr x, ptr y) { return x.get_value() == y.get_value(); }

        friend bool operator!=(ptr x, ptr y) { return !(x == y); }

        T obj{};
    };

    // Deleter used by std::unique_ptr; null handles are skipped.
    struct deleter
    {
        using pointer = ptr;

        void operator()(pointer x) const
        {
            if(x != nullptr)
            {
                (void)f(x.get());
            }
        }
    };

    mlir_handle() : handle(nullptr) {}

    mlir_handle(T p) : handle(ptr{p}) {}

    /// Returns the managed handle; ownership is kept.
    T get() const { return handle.get().get(); }

    /// Gives up ownership and returns the handle; the caller becomes
    /// responsible for destroying it and this wrapper is null afterwards.
    /// Non-const because it modifies the wrapper (unique_ptr::release is non-const).
    T release() { return handle.release().get(); }

    private:
    std::unique_ptr<ptr, deleter> handle;
};

Paul's avatar
Paul committed
66
// Names the owning mlir_handle specialization for the C handle type `Handle`,
// released with the C-API function `Destroy`.
#define MIGRAPHX_MANAGE_MLIR_HANDLE(Handle, Destroy) migraphx::mlir_handle<Handle, decltype(&Destroy), &Destroy> // NOLINT

// Context and module.
using mlir_context = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirContext, mlirContextDestroy);
using mlir_module  = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirModule, mlirModuleDestroy);

// Operations, regions and blocks.
using mlir_operation = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirOperation, mlirOperationDestroy);
using mlir_region    = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirRegion, mlirRegionDestroy);
using mlir_block     = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirBlock, mlirBlockDestroy);

// Printing flags.
using mlir_op_printing_flags =
    MIGRAPHX_MANAGE_MLIR_HANDLE(MlirOpPrintingFlags, mlirOpPrintingFlagsDestroy);
Paul's avatar
Paul committed
75

Paul's avatar
Paul committed
76
// Views the characters of an MLIR string reference without copying them.
// The returned view is only valid while the referenced storage is alive.
std::string_view to_string_view(MlirStringRef s)
{
    const auto* const first = s.data;
    const auto count        = s.length;
    return std::string_view{first, count};
}
Paul's avatar
Paul committed
77
78
79
80
81
82

// Builds an MLIR string reference over the characters of `s`.
// NOTE(review): MlirStringRef is presumably non-owning (per the MLIR C API),
// so the characters behind `s` must outlive the returned reference.
MlirStringRef make_mlir_string_ref(const std::string_view& s)
{
    const auto* chars = s.data();
    const auto len    = s.size();
    return mlirStringRefCreate(chars, len);
}

Paul's avatar
Paul committed
83
template <class F, class T, class Printer>
Paul's avatar
Paul committed
84
85
void mlir_print(F f, T x, Printer printer)
{
Paul's avatar
Paul committed
86
87
88
    f(x,
      +[](MlirStringRef s, void* data) { (*reinterpret_cast<Printer*>(data))(to_string_view(s)); },
      &printer);
Paul's avatar
Paul committed
89
90
}

Paul's avatar
Paul committed
91
// Prints `x` with the MLIR C-API print function `f`, writing each emitted
// chunk to `os`.
template <class F, class T>
void mlir_print(F f, T x, std::ostream& os)
{
    auto write_chunk = [&os](auto chunk) { os << chunk; };
    mlir_print(f, x, write_chunk);
}

struct mlir_program
{
Paul's avatar
Paul committed
99
    mlir_program() : ctx(mlirContextCreate())
Paul's avatar
Paul committed
100
101
    {
        mlirRegisterAllDialects(ctx.get());
Paul's avatar
Paul committed
102
        mlirContextSetAllowUnregisteredDialects(ctx.get(), true /*allow*/);
Paul's avatar
Paul committed
103
104
105
106
107
108
    }

    MlirType make_type(shape::type_t t) const
    {
        MlirType result;
        shape::visit(t, [&](auto as) {
Paul's avatar
Paul committed
109
            if(as.type_enum() == shape::float_type)
Paul's avatar
Paul committed
110
                result = mlirF32TypeGet(ctx.get());
Paul's avatar
Paul committed
111
            else if(as.type_enum() == shape::half_type)
Paul's avatar
Paul committed
112
                result = mlirF16TypeGet(ctx.get());
Paul's avatar
Paul committed
113
            else if(as.type_enum() == shape::double_type)
Paul's avatar
Paul committed
114
                result = mlirF64TypeGet(ctx.get());
Paul's avatar
Paul committed
115
            else if(as.is_integral())
Paul's avatar
Paul committed
116
            {
Paul's avatar
Paul committed
117
                if(as.is_signed())
Paul's avatar
Paul committed
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
                    result = mlirIntegerTypeSignedGet(ctx.get(), as.size() * 8);
                else
                    result = mlirIntegerTypeGet(ctx.get(), as.size() * 8);
            }
            else
                MIGRAPHX_THROW("Unsupported type: " + std::to_string(as.type_enum()));
        });
        return result;
    }

    MlirType make_tensor(const shape& s) const
    {
        assert(s.standard());
        std::vector<int64_t> lens(s.lens().begin(), s.lens().end());
        return mlirRankedTensorTypeGet(lens.size(), lens.data(), make_type(s.type()));
    }

Paul's avatar
Paul committed
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
    template<class Range>
    std::vector<MlirType> make_tensors(const Range& r)
    {
        std::vector<MlirType> result;
        std::transform(r.begin(), r.end(), std::back_inserter(result), [&](const auto& s) {
            return make_tensor(s);
        });
        return result;
    }

    MlirType make_function_type(const std::vector<shape>& inputs, const std::vector<shape>& outputs)
    {
        auto in = make_tensors(inputs);
        auto out = make_tensors(outputs);
        return mlirFunctionTypeGet(ctx.get(), in.size(), in.data(), out.size(), out.data());
    }

Paul's avatar
Paul committed
152
153
154
155
156
157
158
159
160
    MlirIdentifier id(const std::string_view& s) const
    {
        return mlirIdentifierGet(ctx.get(), make_mlir_string_ref(s));
    }

    MlirAttribute attribute(std::int64_t i) const
    {
        return mlirIntegerAttrGet(mlirIntegerTypeSignedGet(ctx.get(), 64), i);
    }
Paul's avatar
Paul committed
161
162
163
    MlirAttribute attribute(std::uint64_t i) const { return attribute(std::int64_t(i)); }
    MlirAttribute attribute(unsigned char i) const { return attribute(std::int64_t(i)); }
    MlirAttribute attribute(bool b) const { return mlirBoolAttrGet(ctx.get(), b); }
Paul's avatar
Paul committed
164
165
166
167
168
169
170
171
    MlirAttribute attribute(double d) const
    {
        return mlirFloatAttrDoubleGet(ctx.get(), mlirF64TypeGet(ctx.get()), d);
    }
    MlirAttribute attribute(const std::string& s) const
    {
        return mlirStringAttrGet(ctx.get(), make_mlir_string_ref(s));
    }
Paul's avatar
Paul committed
172
173
    MlirAttribute attribute(std::nullptr_t) const { return {}; }
    template <class T>
Paul's avatar
Paul committed
174
175
176
177
178
179
180
181
182
183
184
185
    MlirAttribute attribute(const std::vector<T>& v) const
    {
        std::vector<MlirAttribute> attributes;
        attributes.reserve(v.size());
        std::transform(v.begin(), v.end(), std::back_inserter(attributes), [&](auto&& x) {
            return attribute(x);
        });
        return mlirArrayAttrGet(ctx.get(), attributes.size(), attributes.data());
    }
    MlirAttribute attribute(const value& v) const
    {
        MlirAttribute attr;
Paul's avatar
Paul committed
186
        v.visit_value([&](auto&& x) { attr = attribute(x); });
Paul's avatar
Paul committed
187
188
189
190
191
192
193
194
195
196
197
198
199
        return attr;
    }
    MlirAttribute attribute(const std::vector<value>& v) const
    {
        if(v.empty())
        {
            return mlirArrayAttrGet(ctx.get(), 0, nullptr);
        }
        if(not v.front().get_key().empty())
        {
            std::vector<MlirNamedAttribute> attributes;
            attributes.reserve(v.size());
            std::transform(v.begin(), v.end(), std::back_inserter(attributes), [&](auto&& x) {
Paul's avatar
Paul committed
200
                return name_attribute(x.get_key(), x.without_key());
Paul's avatar
Paul committed
201
202
203
204
205
206
207
208
209
210
211
212
213
214
            });
            return mlirDictionaryAttrGet(ctx.get(), attributes.size(), attributes.data());
        }
        else
        {
            std::vector<MlirAttribute> attributes;
            attributes.reserve(v.size());
            std::transform(v.begin(), v.end(), std::back_inserter(attributes), [&](auto&& x) {
                return attribute(x);
            });
            return mlirArrayAttrGet(ctx.get(), attributes.size(), attributes.data());
        }
    }

Paul's avatar
Paul committed
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
    MlirAttribute attribute(MlirType t) const
    {
        return mlirTypeAttrGet(t);
    }

    template<class T>
    MlirNamedAttribute name_attribute(const std::string_view& key, const T& x) const
    {
        MlirNamedAttribute attr;
        attr.name      = id(key);
        attr.attribute = attribute(x);
        return attr;
    }


Paul's avatar
Paul committed
230
231
232
233
234
    mlir_context ctx;
};

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx