Commit e0303225 authored by jungpark-mlir's avatar jungpark-mlir
Browse files

Apply MLIR versioning

parent 6372469b
......@@ -30,10 +30,25 @@
#include <mlir-c/BuiltinTypes.h>
#include <mlir-c/Diagnostics.h>
#include <mlir-c/Dialect/MIGraphX.h>
#if defined(MLIR_MIGRAPHX_DIALECT_API_VERSION) && MLIR_MIGRAPHX_DIALECT_API_VERSION >= 2
#define MIGRAPHX_MLIR_BARE_POINTER
#endif
#if defined(MLIR_MIGRAPHX_DIALECT_API_VERSION) && MLIR_MIGRAPHX_DIALECT_API_VERSION >= 3
#define MIGRAPHX_MLIR_REGISTER_ROCMLIR
#define MIGRAPHX_MLIR_USE_FULL_ARCH
#endif
#include <mlir-c/IntegerSet.h>
#include <mlir-c/Pass.h>
#ifdef MIGRAPHX_MLIR_REGISTER_ROCMLIR
#include <mlir-c/RegisterEverything.h>
#include <mlir-c/RegisterRocMLIR.h>
#else
#include <mlir-c/Registration.h>
#endif
#endif
#include <migraphx/env.hpp>
......@@ -51,10 +66,6 @@
#include <deque>
#include <variant>
#if defined(MLIR_MIGRAPHX_DIALECT_API_VERSION) && MLIR_MIGRAPHX_DIALECT_API_VERSION >= 2
#define MIGRAPHX_MLIR_BARE_POINTER
#endif
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
......@@ -169,6 +180,7 @@ struct mlir_program
location(mlirLocationUnknownGet(ctx.get())),
mmodule(mlirModuleCreateEmpty(location))
{
#ifdef MIGRAPHX_MLIR_REGISTER_ROCMLIR
MlirDialectRegistry registry = mlirDialectRegistryCreate();
mlirRegisterRocMLIRDialects(registry);
mlirRegisterAllDialects(registry);
......@@ -176,6 +188,11 @@ struct mlir_program
mlirContextLoadAllAvailableDialects(ctx.get());
mlirDialectRegistryDestroy(registry);
mlirContextSetAllowUnregisteredDialects(ctx.get(), true /*allow*/);
#else
MlirDialectHandle mixr_handle = mlirGetDialectHandle__migraphx__();
mlirDialectHandleRegisterDialect(mixr_handle, ctx.get());
mlirRegisterAllDialects(ctx.get());
#endif
}
MlirType make_type(shape::type_t t) const
......@@ -456,8 +473,12 @@ struct mlir_program
auto ops = create_operation_state("func.func");
ops.add_attributes({{"function_type", make_function_type(inputs, outputs)},
{"sym_name", std::string("main")},
#ifdef MIGRAPHX_MLIR_USE_FULL_ARCH
{"kernel", std::string("mixr")},
{"arch", target_arch}});
#else
{"kernel", std::string("mixr")}});
#endif
ops.add_region(std::move(region));
insert(body, std::move(ops));
......@@ -517,8 +538,12 @@ struct mlir_program
pp =
problem_params{ins->get_operator(), to_shapes(ins->inputs()), ins->get_shape()};
// check if HW supports xdlops
auto target_chip = trim(split_string(target_arch, ':').front());
bool xdlops = contains(get_xdlops_archs(), target_chip);
#ifdef MIGRAPHX_MLIR_USE_FULL_ARCH
auto target_chip = trim(split_string(target_arch, ':').front());
bool xdlops = contains(get_xdlops_archs(), target_chip);
#else
bool xdlops = contains(get_xdlops_archs(), target_name);
#endif
std::string tuned = get_tune_params(xdlops);
if(not tuned.empty())
ops.add_attributes({{"perf_config", tuned}});
......@@ -546,7 +571,11 @@ struct mlir_program
// 1st pipeline to call
mlirMIGraphXAddHighLevelPipeline(pm.get());
// 2nd pipeline to call
#ifdef MIGRAPHX_MLIR_USE_FULL_ARCH
mlirMIGraphXAddBackendPipeline(pm.get(), target_arch.c_str());
#else
mlirMIGraphXAddBackendPipeline(pm.get(), target_name.c_str(), "amdgcn-amd-amdhsa", "");
#endif
mlirPassManagerRun(pm.get(), mmodule.get());
code_object_op op{};
......@@ -556,7 +585,20 @@ struct mlir_program
return op;
}
#ifdef MIGRAPHX_MLIR_USE_FULL_ARCH
void find_target() { target_arch = get_device_name(); }
#else
void find_target()
{
std::string tname = get_device_name();
// HACK: Since MLIR can't handle the full target name
target_name = trim(split_string(tname, ':').front());
if(tname.size() != target_name.size())
std::cout
<< "*************** WARNING: MLIR may not compile the correct target features for: "
<< tname << std::endl;
}
#endif
std::pair<std::size_t, std::size_t> get_launch_params() const
{
......@@ -585,7 +627,11 @@ struct mlir_program
mlir_module mmodule;
problem_params pp;
std::deque<std::string> strings{};
#ifdef MIGRAPHX_MLIR_USE_FULL_ARCH
std::string target_arch;
#else
std::string target_name;
#endif
};
std::string dump_mlir(const module& m)
......@@ -685,8 +731,8 @@ instruction_ref insert_mlir(module& m,
for(auto input : inputs)
{
const size_t offset = 0;
auto s = input->get_shape();
last = refs.size();
auto s = input->get_shape();
last = refs.size();
refs.push_back(input);
refs.push_back(input);
refs.push_back(get_literal(offset)); // offset
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment