Commit 6372469b authored by jungpark-mlir's avatar jungpark-mlir
Browse files

Update mlir interface

parent eb094e57
......@@ -32,7 +32,8 @@
#include <mlir-c/Dialect/MIGraphX.h>
#include <mlir-c/IntegerSet.h>
#include <mlir-c/Pass.h>
#include <mlir-c/Registration.h>
#include <mlir-c/RegisterEverything.h>
#include <mlir-c/RegisterRocMLIR.h>
#endif
#include <migraphx/env.hpp>
......@@ -168,9 +169,12 @@ struct mlir_program
location(mlirLocationUnknownGet(ctx.get())),
mmodule(mlirModuleCreateEmpty(location))
{
MlirDialectHandle mixr_handle = mlirGetDialectHandle__migraphx__();
mlirDialectHandleRegisterDialect(mixr_handle, ctx.get());
mlirRegisterAllDialects(ctx.get());
MlirDialectRegistry registry = mlirDialectRegistryCreate();
mlirRegisterRocMLIRDialects(registry);
mlirRegisterAllDialects(registry);
mlirContextAppendDialectRegistry(ctx.get(), registry);
mlirContextLoadAllAvailableDialects(ctx.get());
mlirDialectRegistryDestroy(registry);
mlirContextSetAllowUnregisteredDialects(ctx.get(), true /*allow*/);
}
......@@ -452,7 +456,8 @@ struct mlir_program
auto ops = create_operation_state("func.func");
ops.add_attributes({{"function_type", make_function_type(inputs, outputs)},
{"sym_name", std::string("main")},
{"kernel", std::string("mixr")}});
{"kernel", std::string("mixr")},
{"arch", target_arch}});
ops.add_region(std::move(region));
insert(body, std::move(ops));
......@@ -512,7 +517,8 @@ struct mlir_program
pp =
problem_params{ins->get_operator(), to_shapes(ins->inputs()), ins->get_shape()};
// check if HW supports xdlops
bool xdlops = contains(get_xdlops_archs(), target_name);
auto target_chip = trim(split_string(target_arch, ':').front());
bool xdlops = contains(get_xdlops_archs(), target_chip);
std::string tuned = get_tune_params(xdlops);
if(not tuned.empty())
ops.add_attributes({{"perf_config", tuned}});
......@@ -540,7 +546,7 @@ struct mlir_program
// 1st pipeline to call
mlirMIGraphXAddHighLevelPipeline(pm.get());
// 2nd pipeline to call
mlirMIGraphXAddBackendPipeline(pm.get(), target_name.c_str(), "amdgcn-amd-amdhsa", "");
mlirMIGraphXAddBackendPipeline(pm.get(), target_arch.c_str());
mlirPassManagerRun(pm.get(), mmodule.get());
code_object_op op{};
......@@ -550,16 +556,7 @@ struct mlir_program
return op;
}
void find_target()
{
std::string tname = get_device_name();
// HACK: Since MLIR can't handle the full target name
target_name = trim(split_string(tname, ':').front());
if(tname.size() != target_name.size())
std::cout
<< "*************** WARNING: MLIR may not compile the correct target features for: "
<< tname << std::endl;
}
void find_target() { target_arch = get_device_name(); }
std::pair<std::size_t, std::size_t> get_launch_params() const
{
......@@ -588,7 +585,7 @@ struct mlir_program
mlir_module mmodule;
problem_params pp;
std::deque<std::string> strings{};
std::string target_name;
std::string target_arch;
};
std::string dump_mlir(const module& m)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment