Commit 8e50db0f authored by Paul's avatar Paul
Browse files

Formatting

parent 45b4f134
......@@ -15,7 +15,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template<class T, class F, F f>
template <class T, class F, F f>
struct mlir_handle
{
struct ptr
......@@ -30,20 +30,11 @@ struct mlir_handle
return reinterpret_cast<const std::intptr_t&>(obj);
}
T get() const
{
return obj;
}
T get() const { return obj; }
friend bool operator==(ptr x, ptr y)
{
return x.get_value() == y.get_value();
}
friend bool operator==(ptr x, ptr y) { return x.get_value() == y.get_value(); }
friend bool operator!=(ptr x, ptr y)
{
return !(x == y);
}
friend bool operator!=(ptr x, ptr y) { return !(x == y); }
T obj{};
};
......@@ -60,57 +51,44 @@ struct mlir_handle
}
};
mlir_handle()
: handle(nullptr)
{}
mlir_handle() : handle(nullptr) {}
mlir_handle(T p)
: handle(ptr{p})
{}
mlir_handle(T p) : handle(ptr{p}) {}
T get() const
{
return handle.get().get();
}
T get() const { return handle.get().get(); }
T release() const
{
return handle.release().get();
}
T release() const { return handle.release().get(); }
private:
private:
std::unique_ptr<ptr, deleter> handle;
};
#define MIGRAPHX_MANAGE_MLIR_HANDLE(T, F) \
migraphx::mlir_handle<T, decltype(&F), &F> // NOLINT
#define MIGRAPHX_MANAGE_MLIR_HANDLE(T, F) migraphx::mlir_handle<T, decltype(&F), &F> // NOLINT
using mlir_context = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirContext, mlirContextDestroy);
using mlir_module = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirModule, mlirModuleDestroy);
using mlir_operation = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirOperation, mlirOperationDestroy);
using mlir_op_printing_flags = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirOpPrintingFlags, mlirOpPrintingFlagsDestroy);
using mlir_op_printing_flags = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirOpPrintingFlags,
mlirOpPrintingFlagsDestroy);
using mlir_region = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirRegion, mlirRegionDestroy);
using mlir_block = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirBlock, mlirBlockDestroy);
std::string_view to_string_view(MlirStringRef s)
{
return {s.data, s.length};
}
std::string_view to_string_view(MlirStringRef s) { return {s.data, s.length}; }
MlirStringRef make_mlir_string_ref(const std::string_view& s)
{
return mlirStringRefCreate(s.data(), s.size());
}
template<class F, class T, class Printer>
template <class F, class T, class Printer>
void mlir_print(F f, T x, Printer printer)
{
f(x, +[](MlirStringRef s, void* data) {
(*reinterpret_cast<Printer*>(data))(to_string_view(s));
}, &printer);
f(x,
+[](MlirStringRef s, void* data) { (*reinterpret_cast<Printer*>(data))(to_string_view(s)); },
&printer);
}
template<class F, class T>
template <class F, class T>
void mlir_print(F f, T x, std::ostream& os)
{
mlir_print(f, x, [&](auto s) { os << s; });
......@@ -118,26 +96,25 @@ void mlir_print(F f, T x, std::ostream& os)
struct mlir_program
{
mlir_program()
: ctx(mlirContextCreate())
mlir_program() : ctx(mlirContextCreate())
{
mlirRegisterAllDialects(ctx.get());
mlirContextSetAllowUnregisteredDialects(ctx.get(), true/*allow*/);
mlirContextSetAllowUnregisteredDialects(ctx.get(), true /*allow*/);
}
MlirType make_type(shape::type_t t) const
{
MlirType result;
shape::visit(t, [&](auto as) {
if (as.type_enum() == shape::float_type)
if(as.type_enum() == shape::float_type)
result = mlirF32TypeGet(ctx.get());
else if (as.type_enum() == shape::half_type)
else if(as.type_enum() == shape::half_type)
result = mlirF16TypeGet(ctx.get());
else if (as.type_enum() == shape::double_type)
else if(as.type_enum() == shape::double_type)
result = mlirF64TypeGet(ctx.get());
else if (as.is_integral())
else if(as.is_integral())
{
if (as.is_signed())
if(as.is_signed())
result = mlirIntegerTypeSignedGet(ctx.get(), as.size() * 8);
else
result = mlirIntegerTypeGet(ctx.get(), as.size() * 8);
......@@ -164,18 +141,9 @@ struct mlir_program
{
return mlirIntegerAttrGet(mlirIntegerTypeSignedGet(ctx.get(), 64), i);
}
MlirAttribute attribute(std::uint64_t i) const
{
return attribute(std::int64_t(i));
}
MlirAttribute attribute(unsigned char i) const
{
return attribute(std::int64_t(i));
}
MlirAttribute attribute(bool b) const
{
return mlirBoolAttrGet(ctx.get(), b);
}
MlirAttribute attribute(std::uint64_t i) const { return attribute(std::int64_t(i)); }
MlirAttribute attribute(unsigned char i) const { return attribute(std::int64_t(i)); }
MlirAttribute attribute(bool b) const { return mlirBoolAttrGet(ctx.get(), b); }
MlirAttribute attribute(double d) const
{
return mlirFloatAttrDoubleGet(ctx.get(), mlirF64TypeGet(ctx.get()), d);
......@@ -184,11 +152,8 @@ struct mlir_program
{
return mlirStringAttrGet(ctx.get(), make_mlir_string_ref(s));
}
MlirAttribute attribute(std::nullptr_t) const
{
return {};
}
template<class T>
MlirAttribute attribute(std::nullptr_t) const { return {}; }
template <class T>
MlirAttribute attribute(const std::vector<T>& v) const
{
std::vector<MlirAttribute> attributes;
......@@ -201,9 +166,7 @@ struct mlir_program
MlirAttribute attribute(const value& v) const
{
MlirAttribute attr;
v.visit_value([&](auto&& x) {
attr = attribute(x);
});
v.visit_value([&](auto&& x) { attr = attribute(x); });
return attr;
}
MlirAttribute attribute(const std::vector<value>& v) const
......@@ -238,7 +201,5 @@ struct mlir_program
mlir_context ctx;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment