Commit e0303225 authored by jungpark-mlir's avatar jungpark-mlir
Browse files

Apply MLIR versioning

parent 6372469b
...@@ -30,10 +30,25 @@ ...@@ -30,10 +30,25 @@
#include <mlir-c/BuiltinTypes.h> #include <mlir-c/BuiltinTypes.h>
#include <mlir-c/Diagnostics.h> #include <mlir-c/Diagnostics.h>
#include <mlir-c/Dialect/MIGraphX.h> #include <mlir-c/Dialect/MIGraphX.h>
#if defined(MLIR_MIGRAPHX_DIALECT_API_VERSION) && MLIR_MIGRAPHX_DIALECT_API_VERSION >= 2
#define MIGRAPHX_MLIR_BARE_POINTER
#endif
#if defined(MLIR_MIGRAPHX_DIALECT_API_VERSION) && MLIR_MIGRAPHX_DIALECT_API_VERSION >= 3
#define MIGRAPHX_MLIR_REGISTER_ROCMLIR
#define MIGRAPHX_MLIR_USE_FULL_ARCH
#endif
#include <mlir-c/IntegerSet.h> #include <mlir-c/IntegerSet.h>
#include <mlir-c/Pass.h> #include <mlir-c/Pass.h>
#ifdef MIGRAPHX_MLIR_REGISTER_ROCMLIR
#include <mlir-c/RegisterEverything.h> #include <mlir-c/RegisterEverything.h>
#include <mlir-c/RegisterRocMLIR.h> #include <mlir-c/RegisterRocMLIR.h>
#else
#include <mlir-c/Registration.h>
#endif
#endif #endif
#include <migraphx/env.hpp> #include <migraphx/env.hpp>
...@@ -51,10 +66,6 @@ ...@@ -51,10 +66,6 @@
#include <deque> #include <deque>
#include <variant> #include <variant>
#if defined(MLIR_MIGRAPHX_DIALECT_API_VERSION) && MLIR_MIGRAPHX_DIALECT_API_VERSION >= 2
#define MIGRAPHX_MLIR_BARE_POINTER
#endif
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
...@@ -169,6 +180,7 @@ struct mlir_program ...@@ -169,6 +180,7 @@ struct mlir_program
location(mlirLocationUnknownGet(ctx.get())), location(mlirLocationUnknownGet(ctx.get())),
mmodule(mlirModuleCreateEmpty(location)) mmodule(mlirModuleCreateEmpty(location))
{ {
#ifdef MIGRAPHX_MLIR_REGISTER_ROCMLIR
MlirDialectRegistry registry = mlirDialectRegistryCreate(); MlirDialectRegistry registry = mlirDialectRegistryCreate();
mlirRegisterRocMLIRDialects(registry); mlirRegisterRocMLIRDialects(registry);
mlirRegisterAllDialects(registry); mlirRegisterAllDialects(registry);
...@@ -176,6 +188,11 @@ struct mlir_program ...@@ -176,6 +188,11 @@ struct mlir_program
mlirContextLoadAllAvailableDialects(ctx.get()); mlirContextLoadAllAvailableDialects(ctx.get());
mlirDialectRegistryDestroy(registry); mlirDialectRegistryDestroy(registry);
mlirContextSetAllowUnregisteredDialects(ctx.get(), true /*allow*/); mlirContextSetAllowUnregisteredDialects(ctx.get(), true /*allow*/);
#else
MlirDialectHandle mixr_handle = mlirGetDialectHandle__migraphx__();
mlirDialectHandleRegisterDialect(mixr_handle, ctx.get());
mlirRegisterAllDialects(ctx.get());
#endif
} }
MlirType make_type(shape::type_t t) const MlirType make_type(shape::type_t t) const
...@@ -456,8 +473,12 @@ struct mlir_program ...@@ -456,8 +473,12 @@ struct mlir_program
auto ops = create_operation_state("func.func"); auto ops = create_operation_state("func.func");
ops.add_attributes({{"function_type", make_function_type(inputs, outputs)}, ops.add_attributes({{"function_type", make_function_type(inputs, outputs)},
{"sym_name", std::string("main")}, {"sym_name", std::string("main")},
#ifdef MIGRAPHX_MLIR_USE_FULL_ARCH
{"kernel", std::string("mixr")}, {"kernel", std::string("mixr")},
{"arch", target_arch}}); {"arch", target_arch}});
#else
{"kernel", std::string("mixr")}});
#endif
ops.add_region(std::move(region)); ops.add_region(std::move(region));
insert(body, std::move(ops)); insert(body, std::move(ops));
...@@ -517,8 +538,12 @@ struct mlir_program ...@@ -517,8 +538,12 @@ struct mlir_program
pp = pp =
problem_params{ins->get_operator(), to_shapes(ins->inputs()), ins->get_shape()}; problem_params{ins->get_operator(), to_shapes(ins->inputs()), ins->get_shape()};
// check if HW supports xdlops // check if HW supports xdlops
auto target_chip = trim(split_string(target_arch, ':').front()); #ifdef MIGRAPHX_MLIR_USE_FULL_ARCH
bool xdlops = contains(get_xdlops_archs(), target_chip); auto target_chip = trim(split_string(target_arch, ':').front());
bool xdlops = contains(get_xdlops_archs(), target_chip);
#else
bool xdlops = contains(get_xdlops_archs(), target_name);
#endif
std::string tuned = get_tune_params(xdlops); std::string tuned = get_tune_params(xdlops);
if(not tuned.empty()) if(not tuned.empty())
ops.add_attributes({{"perf_config", tuned}}); ops.add_attributes({{"perf_config", tuned}});
...@@ -546,7 +571,11 @@ struct mlir_program ...@@ -546,7 +571,11 @@ struct mlir_program
// 1st pipeline to call // 1st pipeline to call
mlirMIGraphXAddHighLevelPipeline(pm.get()); mlirMIGraphXAddHighLevelPipeline(pm.get());
// 2nd pipeline to call // 2nd pipeline to call
#ifdef MIGRAPHX_MLIR_USE_FULL_ARCH
mlirMIGraphXAddBackendPipeline(pm.get(), target_arch.c_str()); mlirMIGraphXAddBackendPipeline(pm.get(), target_arch.c_str());
#else
mlirMIGraphXAddBackendPipeline(pm.get(), target_name.c_str(), "amdgcn-amd-amdhsa", "");
#endif
mlirPassManagerRun(pm.get(), mmodule.get()); mlirPassManagerRun(pm.get(), mmodule.get());
code_object_op op{}; code_object_op op{};
...@@ -556,7 +585,20 @@ struct mlir_program ...@@ -556,7 +585,20 @@ struct mlir_program
return op; return op;
} }
#ifdef MIGRAPHX_MLIR_USE_FULL_ARCH
void find_target() { target_arch = get_device_name(); } void find_target() { target_arch = get_device_name(); }
#else
void find_target()
{
std::string tname = get_device_name();
// HACK: Since MLIR can't handle the full target name
target_name = trim(split_string(tname, ':').front());
if(tname.size() != target_name.size())
std::cout
<< "*************** WARNING: MLIR may not compile the correct target features for: "
<< tname << std::endl;
}
#endif
std::pair<std::size_t, std::size_t> get_launch_params() const std::pair<std::size_t, std::size_t> get_launch_params() const
{ {
...@@ -585,7 +627,11 @@ struct mlir_program ...@@ -585,7 +627,11 @@ struct mlir_program
mlir_module mmodule; mlir_module mmodule;
problem_params pp; problem_params pp;
std::deque<std::string> strings{}; std::deque<std::string> strings{};
#ifdef MIGRAPHX_MLIR_USE_FULL_ARCH
std::string target_arch; std::string target_arch;
#else
std::string target_name;
#endif
}; };
std::string dump_mlir(const module& m) std::string dump_mlir(const module& m)
...@@ -685,8 +731,8 @@ instruction_ref insert_mlir(module& m, ...@@ -685,8 +731,8 @@ instruction_ref insert_mlir(module& m,
for(auto input : inputs) for(auto input : inputs)
{ {
const size_t offset = 0; const size_t offset = 0;
auto s = input->get_shape(); auto s = input->get_shape();
last = refs.size(); last = refs.size();
refs.push_back(input); refs.push_back(input);
refs.push_back(input); refs.push_back(input);
refs.push_back(get_literal(offset)); // offset refs.push_back(get_literal(offset)); // offset
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment