Commit 8e50db0f authored by Paul's avatar Paul
Browse files

Formatting

parent 45b4f134
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
template<class T, class F, F f> template <class T, class F, F f>
struct mlir_handle struct mlir_handle
{ {
struct ptr struct ptr
...@@ -30,23 +30,14 @@ struct mlir_handle ...@@ -30,23 +30,14 @@ struct mlir_handle
return reinterpret_cast<const std::intptr_t&>(obj); return reinterpret_cast<const std::intptr_t&>(obj);
} }
T get() const T get() const { return obj; }
{
return obj;
}
friend bool operator==(ptr x, ptr y) friend bool operator==(ptr x, ptr y) { return x.get_value() == y.get_value(); }
{
return x.get_value() == y.get_value();
}
friend bool operator!=(ptr x, ptr y) friend bool operator!=(ptr x, ptr y) { return !(x == y); }
{
return !(x == y);
}
T obj{}; T obj{};
}; };
struct deleter struct deleter
{ {
using pointer = ptr; using pointer = ptr;
...@@ -60,57 +51,44 @@ struct mlir_handle ...@@ -60,57 +51,44 @@ struct mlir_handle
} }
}; };
mlir_handle() mlir_handle() : handle(nullptr) {}
: handle(nullptr)
{}
mlir_handle(T p) mlir_handle(T p) : handle(ptr{p}) {}
: handle(ptr{p})
{}
T get() const T get() const { return handle.get().get(); }
{
return handle.get().get();
}
T release() const T release() const { return handle.release().get(); }
{
return handle.release().get();
}
private: private:
std::unique_ptr<ptr, deleter> handle; std::unique_ptr<ptr, deleter> handle;
}; };
#define MIGRAPHX_MANAGE_MLIR_HANDLE(T, F) \ #define MIGRAPHX_MANAGE_MLIR_HANDLE(T, F) migraphx::mlir_handle<T, decltype(&F), &F> // NOLINT
migraphx::mlir_handle<T, decltype(&F), &F> // NOLINT
using mlir_context = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirContext, mlirContextDestroy); using mlir_context = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirContext, mlirContextDestroy);
using mlir_module = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirModule, mlirModuleDestroy); using mlir_module = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirModule, mlirModuleDestroy);
using mlir_operation = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirOperation, mlirOperationDestroy); using mlir_operation = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirOperation, mlirOperationDestroy);
using mlir_op_printing_flags = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirOpPrintingFlags, mlirOpPrintingFlagsDestroy); using mlir_op_printing_flags = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirOpPrintingFlags,
using mlir_region = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirRegion, mlirRegionDestroy); mlirOpPrintingFlagsDestroy);
using mlir_block = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirBlock, mlirBlockDestroy); using mlir_region = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirRegion, mlirRegionDestroy);
using mlir_block = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirBlock, mlirBlockDestroy);
std::string_view to_string_view(MlirStringRef s) std::string_view to_string_view(MlirStringRef s) { return {s.data, s.length}; }
{
return {s.data, s.length};
}
MlirStringRef make_mlir_string_ref(const std::string_view& s) MlirStringRef make_mlir_string_ref(const std::string_view& s)
{ {
return mlirStringRefCreate(s.data(), s.size()); return mlirStringRefCreate(s.data(), s.size());
} }
template<class F, class T, class Printer> template <class F, class T, class Printer>
void mlir_print(F f, T x, Printer printer) void mlir_print(F f, T x, Printer printer)
{ {
f(x, +[](MlirStringRef s, void* data) { f(x,
(*reinterpret_cast<Printer*>(data))(to_string_view(s)); +[](MlirStringRef s, void* data) { (*reinterpret_cast<Printer*>(data))(to_string_view(s)); },
}, &printer); &printer);
} }
template<class F, class T> template <class F, class T>
void mlir_print(F f, T x, std::ostream& os) void mlir_print(F f, T x, std::ostream& os)
{ {
mlir_print(f, x, [&](auto s) { os << s; }); mlir_print(f, x, [&](auto s) { os << s; });
...@@ -118,26 +96,25 @@ void mlir_print(F f, T x, std::ostream& os) ...@@ -118,26 +96,25 @@ void mlir_print(F f, T x, std::ostream& os)
struct mlir_program struct mlir_program
{ {
mlir_program() mlir_program() : ctx(mlirContextCreate())
: ctx(mlirContextCreate())
{ {
mlirRegisterAllDialects(ctx.get()); mlirRegisterAllDialects(ctx.get());
mlirContextSetAllowUnregisteredDialects(ctx.get(), true/*allow*/); mlirContextSetAllowUnregisteredDialects(ctx.get(), true /*allow*/);
} }
MlirType make_type(shape::type_t t) const MlirType make_type(shape::type_t t) const
{ {
MlirType result; MlirType result;
shape::visit(t, [&](auto as) { shape::visit(t, [&](auto as) {
if (as.type_enum() == shape::float_type) if(as.type_enum() == shape::float_type)
result = mlirF32TypeGet(ctx.get()); result = mlirF32TypeGet(ctx.get());
else if (as.type_enum() == shape::half_type) else if(as.type_enum() == shape::half_type)
result = mlirF16TypeGet(ctx.get()); result = mlirF16TypeGet(ctx.get());
else if (as.type_enum() == shape::double_type) else if(as.type_enum() == shape::double_type)
result = mlirF64TypeGet(ctx.get()); result = mlirF64TypeGet(ctx.get());
else if (as.is_integral()) else if(as.is_integral())
{ {
if (as.is_signed()) if(as.is_signed())
result = mlirIntegerTypeSignedGet(ctx.get(), as.size() * 8); result = mlirIntegerTypeSignedGet(ctx.get(), as.size() * 8);
else else
result = mlirIntegerTypeGet(ctx.get(), as.size() * 8); result = mlirIntegerTypeGet(ctx.get(), as.size() * 8);
...@@ -164,18 +141,9 @@ struct mlir_program ...@@ -164,18 +141,9 @@ struct mlir_program
{ {
return mlirIntegerAttrGet(mlirIntegerTypeSignedGet(ctx.get(), 64), i); return mlirIntegerAttrGet(mlirIntegerTypeSignedGet(ctx.get(), 64), i);
} }
MlirAttribute attribute(std::uint64_t i) const MlirAttribute attribute(std::uint64_t i) const { return attribute(std::int64_t(i)); }
{ MlirAttribute attribute(unsigned char i) const { return attribute(std::int64_t(i)); }
return attribute(std::int64_t(i)); MlirAttribute attribute(bool b) const { return mlirBoolAttrGet(ctx.get(), b); }
}
MlirAttribute attribute(unsigned char i) const
{
return attribute(std::int64_t(i));
}
MlirAttribute attribute(bool b) const
{
return mlirBoolAttrGet(ctx.get(), b);
}
MlirAttribute attribute(double d) const MlirAttribute attribute(double d) const
{ {
return mlirFloatAttrDoubleGet(ctx.get(), mlirF64TypeGet(ctx.get()), d); return mlirFloatAttrDoubleGet(ctx.get(), mlirF64TypeGet(ctx.get()), d);
...@@ -184,11 +152,8 @@ struct mlir_program ...@@ -184,11 +152,8 @@ struct mlir_program
{ {
return mlirStringAttrGet(ctx.get(), make_mlir_string_ref(s)); return mlirStringAttrGet(ctx.get(), make_mlir_string_ref(s));
} }
MlirAttribute attribute(std::nullptr_t) const MlirAttribute attribute(std::nullptr_t) const { return {}; }
{ template <class T>
return {};
}
template<class T>
MlirAttribute attribute(const std::vector<T>& v) const MlirAttribute attribute(const std::vector<T>& v) const
{ {
std::vector<MlirAttribute> attributes; std::vector<MlirAttribute> attributes;
...@@ -201,9 +166,7 @@ struct mlir_program ...@@ -201,9 +166,7 @@ struct mlir_program
MlirAttribute attribute(const value& v) const MlirAttribute attribute(const value& v) const
{ {
MlirAttribute attr; MlirAttribute attr;
v.visit_value([&](auto&& x) { v.visit_value([&](auto&& x) { attr = attribute(x); });
attr = attribute(x);
});
return attr; return attr;
} }
MlirAttribute attribute(const std::vector<value>& v) const MlirAttribute attribute(const std::vector<value>& v) const
...@@ -218,7 +181,7 @@ struct mlir_program ...@@ -218,7 +181,7 @@ struct mlir_program
attributes.reserve(v.size()); attributes.reserve(v.size());
std::transform(v.begin(), v.end(), std::back_inserter(attributes), [&](auto&& x) { std::transform(v.begin(), v.end(), std::back_inserter(attributes), [&](auto&& x) {
MlirNamedAttribute attr; MlirNamedAttribute attr;
attr.name = id(x.get_key()); attr.name = id(x.get_key());
attr.attribute = attribute(x.without_key()); attr.attribute = attribute(x.without_key());
return attr; return attr;
}); });
...@@ -238,7 +201,5 @@ struct mlir_program ...@@ -238,7 +201,5 @@ struct mlir_program
mlir_context ctx; mlir_context ctx;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment