Unverified Commit be70702d authored by jungpark-mlir's avatar jungpark-mlir Committed by GitHub
Browse files

Update MLIR integration (#1451)

Update dialect registration interface
Update 2nd build pipeline call and use full arch name
parent fdc3f00a
......@@ -87,7 +87,7 @@ RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXR
ADD tools/build_and_test_onnxrt.sh /onnxruntime/build_and_test_onnxrt.sh
RUN cget -p /usr/local install ROCmSoftwarePlatform/llvm-project-mlir@c0723a7e50043d973cb73ae51dc30d36679ee7e5 -DBUILD_MIXR_TARGET=On
RUN cget -p /usr/local install ROCmSoftwarePlatform/rocMLIR@0f38fb33f518b53b94b541feb9b079668c5518e8 -DBUILD_MIXR_TARGET=On -DLLVM_ENABLE_ZSTD=Off -DLLVM_ENABLE_THREADS=Off
ENV MIOPEN_FIND_DB_PATH=/tmp/miopen/find-db
ENV MIOPEN_USER_DB_PATH=/tmp/miopen/user-db
......
......@@ -24,7 +24,6 @@
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/mlir.hpp>
namespace migraphx {
......
......@@ -32,7 +32,13 @@
#include <mlir-c/Dialect/MIGraphX.h>
#include <mlir-c/IntegerSet.h>
#include <mlir-c/Pass.h>
#include <mlir-c/Registration.h>
#include <mutex>
#if !defined(MLIR_MIGRAPHX_DIALECT_API_VERSION) || MLIR_MIGRAPHX_DIALECT_API_VERSION != 3
#warning "Incompatible version of rocMLIR library used, disabling"
#undef MIGRAPHX_MLIR
#else
#include <mlir-c/RegisterRocMLIR.h>
#endif
#endif
#include <migraphx/env.hpp>
......@@ -50,10 +56,6 @@
#include <deque>
#include <variant>
#if defined(MLIR_MIGRAPHX_DIALECT_API_VERSION) && MLIR_MIGRAPHX_DIALECT_API_VERSION >= 2
#define MIGRAPHX_MLIR_BARE_POINTER
#endif
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
......@@ -168,9 +170,11 @@ struct mlir_program
location(mlirLocationUnknownGet(ctx.get())),
mmodule(mlirModuleCreateEmpty(location))
{
MlirDialectHandle mixr_handle = mlirGetDialectHandle__migraphx__();
mlirDialectHandleRegisterDialect(mixr_handle, ctx.get());
mlirRegisterAllDialects(ctx.get());
MlirDialectRegistry registry = mlirDialectRegistryCreate();
mlirRegisterRocMLIRDialects(registry);
mlirContextAppendDialectRegistry(ctx.get(), registry);
mlirContextLoadAllAvailableDialects(ctx.get());
mlirDialectRegistryDestroy(registry);
mlirContextSetAllowUnregisteredDialects(ctx.get(), true /*allow*/);
}
......@@ -452,7 +456,8 @@ struct mlir_program
auto ops = create_operation_state("func.func");
ops.add_attributes({{"function_type", make_function_type(inputs, outputs)},
{"sym_name", std::string("main")},
{"kernel", std::string("mixr")}});
{"kernel", std::string("mixr")},
{"arch", target_arch}});
ops.add_region(std::move(region));
insert(body, std::move(ops));
......@@ -512,7 +517,8 @@ struct mlir_program
pp =
problem_params{ins->get_operator(), to_shapes(ins->inputs()), ins->get_shape()};
// check if HW supports xdlops
bool xdlops = contains(get_xdlops_archs(), target_name);
auto target_chip = trim(split_string(target_arch, ':').front());
bool xdlops = contains(get_xdlops_archs(), target_chip);
std::string tuned = get_tune_params(xdlops);
if(not tuned.empty())
ops.add_attributes({{"perf_config", tuned}});
......@@ -540,7 +546,7 @@ struct mlir_program
// 1st pipeline to call
mlirMIGraphXAddHighLevelPipeline(pm.get());
// 2nd pipeline to call
mlirMIGraphXAddBackendPipeline(pm.get(), target_name.c_str(), "amdgcn-amd-amdhsa", "");
mlirMIGraphXAddBackendPipeline(pm.get(), target_arch.c_str());
mlirPassManagerRun(pm.get(), mmodule.get());
code_object_op op{};
......@@ -550,16 +556,7 @@ struct mlir_program
return op;
}
void find_target()
{
std::string tname = get_device_name();
// HACK: Since MLIR can't handle the full target name
target_name = trim(split_string(tname, ':').front());
if(tname.size() != target_name.size())
std::cout
<< "*************** WARNING: MLIR may not compile the correct target features for: "
<< tname << std::endl;
}
void find_target() { target_arch = get_device_name(); }
std::pair<std::size_t, std::size_t> get_launch_params() const
{
......@@ -588,7 +585,7 @@ struct mlir_program
mlir_module mmodule;
problem_params pp;
std::deque<std::string> strings{};
std::string target_name;
std::string target_arch;
};
std::string dump_mlir(const module& m)
......@@ -650,6 +647,10 @@ code_object_op compile_mlir(const context&, module m, const std::vector<instruct
const bool trace = enabled(MIGRAPHX_TRACE_MLIR{});
if(trace)
std::cout << m << std::endl;
// set mutex while llvm thread support is disabled.
static std::mutex g_mlirc_mutex; // NOLINT
const std::lock_guard<std::mutex> lock(g_mlirc_mutex);
mlir_program mp;
mp.find_target();
mp.parse(m);
......@@ -669,46 +670,9 @@ instruction_ref insert_mlir(module& m,
std::vector<instruction_ref> refs;
std::size_t last = 0;
#ifdef MIGRAPHX_MLIR_BARE_POINTER
refs.reserve(inputs.size());
std::copy(inputs.begin(), inputs.end(), std::back_inserter(refs));
last = refs.size() - 1;
#else
refs.reserve(inputs.size() * 15);
std::unordered_map<uint64_t, instruction_ref> literal_map{};
auto get_literal = [&](uint64_t value) {
auto fi = literal_map.find(value);
if(fi != literal_map.end())
return fi->second;
auto lit = m.add_literal(value);
literal_map.emplace(value, lit);
return lit;
};
for(auto input : inputs)
{
const size_t offset = 0;
auto s = input->get_shape();
last = refs.size();
refs.push_back(input);
refs.push_back(input);
refs.push_back(get_literal(offset)); // offset
// dim sizes
std::transform(s.lens().begin(),
s.lens().end(),
std::back_inserter(refs),
[&](const auto& lval) { return get_literal(lval); });
// refs.push_back(get_literal(1)); // G
// dim strides
std::transform(s.strides().begin(),
s.strides().end(),
std::back_inserter(refs),
[&](const auto& lval) { return get_literal(lval); });
// refs.push_back(get_literal(1)); // G
}
#endif
last = refs.size() - 1;
co.expected_inputs = to_shapes(refs);
co.output_arg = last;
return m.insert_instruction(ins, co, refs);
......
......@@ -27,6 +27,7 @@
#include <migraphx/stringutils.hpp>
#include <migraphx/permutation.hpp>
#include <fstream>
#include <mutex>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -88,6 +89,9 @@ std::string generate_miopen_config(const problem_params& pp)
auto query_miopen_db(const std::string& query)
{
static std::mutex g_db_mutex; // NOLINT
const std::lock_guard<std::mutex> lock(g_db_mutex);
// TODO: Store db as a static variable
const auto dbpath = fs::path{"/opt"} / "rocm" / "share" / "miopen" / "db" / "miopen.db";
// Check if db file exists.
......
......@@ -140,7 +140,7 @@ TEST_CASE(conv)
{
const std::string mlir_output = R"__migraphx__(
module {
func.func @main(%arg0: tensor<2x8x3x3xf32>, %arg1: tensor<1x8x4x4xf32>) -> tensor<1x2x2x2xf32> attributes {kernel = "mixr"} {
func.func @main(%arg0: tensor<2x8x3x3xf32>, %arg1: tensor<1x8x4x4xf32>) -> tensor<1x2x2x2xf32> attributes {arch = "", kernel = "mixr"} {
%0 = migraphx.convolution(%arg1, %arg0) {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : (tensor<1x8x4x4xf32>, tensor<2x8x3x3xf32>) -> tensor<1x2x2x2xf32>
return %0 : tensor<1x2x2x2xf32>
}
......@@ -163,7 +163,7 @@ TEST_CASE(conv_add_relu)
{
const std::string mlir_output = R"__migraphx__(
module {
func.func @main(%arg0: tensor<1x2x2x2xf32>, %arg1: tensor<2x8x3x3xf32>, %arg2: tensor<1x8x4x4xf32>) -> tensor<1x2x2x2xf32> attributes {kernel = "mixr"} {
func.func @main(%arg0: tensor<1x2x2x2xf32>, %arg1: tensor<2x8x3x3xf32>, %arg2: tensor<1x8x4x4xf32>) -> tensor<1x2x2x2xf32> attributes {arch = "", kernel = "mixr"} {
%0 = migraphx.convolution(%arg2, %arg1) {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : (tensor<1x8x4x4xf32>, tensor<2x8x3x3xf32>) -> tensor<1x2x2x2xf32>
%1 = migraphx.add(%0, %arg0) : (tensor<1x2x2x2xf32>, tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32>
%2 = migraphx.relu(%1) : (tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment