Commit 812cd5c8 authored by Paul's avatar Paul
Browse files

Add mlir_compile

parent 6c97c8ea
......@@ -4,4 +4,4 @@ blaze,https://bitbucket.org/blaze-lib/blaze/get/f0755dea0e03.tar.gz -X header -D
half,https://github.com/pfultz2/half/archive/1.12.0.tar.gz -X header -H sha256:0a08660b68abb176ebc2a0cdf8de46e3182a7f46c66443bb80dbfaaec98cf969
pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build
msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off
# jungpark-mlir/llvm-project-mlir@7b82d84c795cc826b9607274695f3a9bee468102 -DBUILD_MIXR_TARGET=On
# jungpark-mlir/llvm-project-mlir@fc5bdee801385557c68f6bf5c9e0d59adbfec405 -DBUILD_MIXR_TARGET=On
......@@ -3,6 +3,7 @@
#include <string>
#include <migraphx/config.hpp>
#include <migraphx/gpu/code_object_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -10,6 +11,7 @@ struct module;
namespace gpu {
std::string dump_mlir(const module& m);
code_object_op compile_mlir(const module& m);
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -8,6 +8,7 @@
#include <mlir-c/Dialect/Standard.h>
#include <mlir-c/Dialect/MIGraphX.h>
#include <mlir-c/IntegerSet.h>
#include <mlir-c/Pass.h>
#include <mlir-c/Registration.h>
#endif
......@@ -16,6 +17,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/config.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/gpu/code_object_op.hpp>
#include <migraphx/iterator_for.hpp>
#include <deque>
#include <variant>
......@@ -82,6 +84,7 @@ using mlir_op_printing_flags = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirOpPrintingFlags,
mlirOpPrintingFlagsDestroy);
using mlir_region = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirRegion, mlirRegionDestroy);
using mlir_block = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirBlock, mlirBlockDestroy);
using mlir_pass_manager = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirPassManager, mlirPassManagerDestroy);
std::string_view to_string_view(MlirStringRef s) { return {s.data, s.length}; }
......@@ -438,6 +441,40 @@ struct mlir_program
}
}
code_object_op compile()
{
mlir_pass_manager pm{mlirPassManagerCreate(ctx.get())};
// 1st pipeline to call
mlirMIGraphXAddHighLevelPipeline(pm.get());
// 2nd pipeline to call
const char *deviceName = "gfx908";
mlirMIGraphXAddBackendPipeline(pm.get(), deviceName);
mlirPassManagerRun(pm.get(), mmodule.get());
code_object_op op;
op.code_object = get_binary();
std::tie(op.global, op.local) = get_launch_params();
return op;
}
std::pair<std::size_t, std::size_t> get_launch_params()
{
int attrs[2];
// returns block and grid sizes
mlirGetKernelAttrs(mmodule.get(), attrs);
std::size_t local = attrs[0];
std::size_t global = local * attrs[1];
return {global, local};
}
value::binary get_binary()
{
value::binary result(mlirGetBinarySize(mmodule.get()));
if(mlirGetBinary(mmodule.get(), reinterpret_cast<char*>(result.data())))
return result;
MIGRAPHX_THROW("Failed to compile mlir program");
}
mlir_context ctx;
MlirLocation location;
mlir_module mmodule;
......@@ -452,6 +489,13 @@ std::string dump_mlir(const module& m)
return mlir_print(&mlirOperationPrint, mod_op);
}
code_object_op compile_mlir(const module& m)
{
mlir_program mp;
mp.parse(m);
return mp.compile();
}
#else
std::string dump_mlir(const module&) { return {}; }
......
......@@ -49,6 +49,7 @@ si64]} : (tensor<1x8x4x4xf32>, tensor<2x8x3x3xf32>) -> tensor<1x2x2x2xf32>
if(s.empty())
return;
EXPECT(encode(s) == encode(mlir_output));
auto op = migraphx::gpu::compile_mlir(m);
}
TEST_CASE(conv_add_relu)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment