Unverified Commit be70702d authored by jungpark-mlir's avatar jungpark-mlir Committed by GitHub
Browse files

Update MLIR integration (#1451)

Update dialect registration interface
Update 2nd build pipeline call and use full arch name
parent fdc3f00a
...@@ -87,7 +87,7 @@ RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXR ...@@ -87,7 +87,7 @@ RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXR
ADD tools/build_and_test_onnxrt.sh /onnxruntime/build_and_test_onnxrt.sh ADD tools/build_and_test_onnxrt.sh /onnxruntime/build_and_test_onnxrt.sh
RUN cget -p /usr/local install ROCmSoftwarePlatform/llvm-project-mlir@c0723a7e50043d973cb73ae51dc30d36679ee7e5 -DBUILD_MIXR_TARGET=On RUN cget -p /usr/local install ROCmSoftwarePlatform/rocMLIR@0f38fb33f518b53b94b541feb9b079668c5518e8 -DBUILD_MIXR_TARGET=On -DLLVM_ENABLE_ZSTD=Off -DLLVM_ENABLE_THREADS=Off
ENV MIOPEN_FIND_DB_PATH=/tmp/miopen/find-db ENV MIOPEN_FIND_DB_PATH=/tmp/miopen/find-db
ENV MIOPEN_USER_DB_PATH=/tmp/miopen/user-db ENV MIOPEN_USER_DB_PATH=/tmp/miopen/user-db
......
...@@ -24,7 +24,6 @@ ...@@ -24,7 +24,6 @@
#include <migraphx/gpu/compiler.hpp> #include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/mlir.hpp> #include <migraphx/gpu/mlir.hpp>
namespace migraphx { namespace migraphx {
......
...@@ -32,7 +32,13 @@ ...@@ -32,7 +32,13 @@
#include <mlir-c/Dialect/MIGraphX.h> #include <mlir-c/Dialect/MIGraphX.h>
#include <mlir-c/IntegerSet.h> #include <mlir-c/IntegerSet.h>
#include <mlir-c/Pass.h> #include <mlir-c/Pass.h>
#include <mlir-c/Registration.h> #include <mutex>
#if !defined(MLIR_MIGRAPHX_DIALECT_API_VERSION) || MLIR_MIGRAPHX_DIALECT_API_VERSION != 3
#warning "Incompatible version of rocMLIR library used, disabling"
#undef MIGRAPHX_MLIR
#else
#include <mlir-c/RegisterRocMLIR.h>
#endif
#endif #endif
#include <migraphx/env.hpp> #include <migraphx/env.hpp>
...@@ -50,10 +56,6 @@ ...@@ -50,10 +56,6 @@
#include <deque> #include <deque>
#include <variant> #include <variant>
#if defined(MLIR_MIGRAPHX_DIALECT_API_VERSION) && MLIR_MIGRAPHX_DIALECT_API_VERSION >= 2
#define MIGRAPHX_MLIR_BARE_POINTER
#endif
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
...@@ -168,9 +170,11 @@ struct mlir_program ...@@ -168,9 +170,11 @@ struct mlir_program
location(mlirLocationUnknownGet(ctx.get())), location(mlirLocationUnknownGet(ctx.get())),
mmodule(mlirModuleCreateEmpty(location)) mmodule(mlirModuleCreateEmpty(location))
{ {
MlirDialectHandle mixr_handle = mlirGetDialectHandle__migraphx__(); MlirDialectRegistry registry = mlirDialectRegistryCreate();
mlirDialectHandleRegisterDialect(mixr_handle, ctx.get()); mlirRegisterRocMLIRDialects(registry);
mlirRegisterAllDialects(ctx.get()); mlirContextAppendDialectRegistry(ctx.get(), registry);
mlirContextLoadAllAvailableDialects(ctx.get());
mlirDialectRegistryDestroy(registry);
mlirContextSetAllowUnregisteredDialects(ctx.get(), true /*allow*/); mlirContextSetAllowUnregisteredDialects(ctx.get(), true /*allow*/);
} }
...@@ -452,7 +456,8 @@ struct mlir_program ...@@ -452,7 +456,8 @@ struct mlir_program
auto ops = create_operation_state("func.func"); auto ops = create_operation_state("func.func");
ops.add_attributes({{"function_type", make_function_type(inputs, outputs)}, ops.add_attributes({{"function_type", make_function_type(inputs, outputs)},
{"sym_name", std::string("main")}, {"sym_name", std::string("main")},
{"kernel", std::string("mixr")}}); {"kernel", std::string("mixr")},
{"arch", target_arch}});
ops.add_region(std::move(region)); ops.add_region(std::move(region));
insert(body, std::move(ops)); insert(body, std::move(ops));
...@@ -512,7 +517,8 @@ struct mlir_program ...@@ -512,7 +517,8 @@ struct mlir_program
pp = pp =
problem_params{ins->get_operator(), to_shapes(ins->inputs()), ins->get_shape()}; problem_params{ins->get_operator(), to_shapes(ins->inputs()), ins->get_shape()};
// check if HW supports xdlops // check if HW supports xdlops
bool xdlops = contains(get_xdlops_archs(), target_name); auto target_chip = trim(split_string(target_arch, ':').front());
bool xdlops = contains(get_xdlops_archs(), target_chip);
std::string tuned = get_tune_params(xdlops); std::string tuned = get_tune_params(xdlops);
if(not tuned.empty()) if(not tuned.empty())
ops.add_attributes({{"perf_config", tuned}}); ops.add_attributes({{"perf_config", tuned}});
...@@ -540,7 +546,7 @@ struct mlir_program ...@@ -540,7 +546,7 @@ struct mlir_program
// 1st pipeline to call // 1st pipeline to call
mlirMIGraphXAddHighLevelPipeline(pm.get()); mlirMIGraphXAddHighLevelPipeline(pm.get());
// 2nd pipeline to call // 2nd pipeline to call
mlirMIGraphXAddBackendPipeline(pm.get(), target_name.c_str(), "amdgcn-amd-amdhsa", ""); mlirMIGraphXAddBackendPipeline(pm.get(), target_arch.c_str());
mlirPassManagerRun(pm.get(), mmodule.get()); mlirPassManagerRun(pm.get(), mmodule.get());
code_object_op op{}; code_object_op op{};
...@@ -550,16 +556,7 @@ struct mlir_program ...@@ -550,16 +556,7 @@ struct mlir_program
return op; return op;
} }
void find_target() void find_target() { target_arch = get_device_name(); }
{
std::string tname = get_device_name();
// HACK: Since MLIR can't handle the full target name
target_name = trim(split_string(tname, ':').front());
if(tname.size() != target_name.size())
std::cout
<< "*************** WARNING: MLIR may not compile the correct target features for: "
<< tname << std::endl;
}
std::pair<std::size_t, std::size_t> get_launch_params() const std::pair<std::size_t, std::size_t> get_launch_params() const
{ {
...@@ -588,7 +585,7 @@ struct mlir_program ...@@ -588,7 +585,7 @@ struct mlir_program
mlir_module mmodule; mlir_module mmodule;
problem_params pp; problem_params pp;
std::deque<std::string> strings{}; std::deque<std::string> strings{};
std::string target_name; std::string target_arch;
}; };
std::string dump_mlir(const module& m) std::string dump_mlir(const module& m)
...@@ -650,6 +647,10 @@ code_object_op compile_mlir(const context&, module m, const std::vector<instruct ...@@ -650,6 +647,10 @@ code_object_op compile_mlir(const context&, module m, const std::vector<instruct
const bool trace = enabled(MIGRAPHX_TRACE_MLIR{}); const bool trace = enabled(MIGRAPHX_TRACE_MLIR{});
if(trace) if(trace)
std::cout << m << std::endl; std::cout << m << std::endl;
// set mutex while llvm thread support is disabled.
static std::mutex g_mlirc_mutex; // NOLINT
const std::lock_guard<std::mutex> lock(g_mlirc_mutex);
mlir_program mp; mlir_program mp;
mp.find_target(); mp.find_target();
mp.parse(m); mp.parse(m);
...@@ -669,46 +670,9 @@ instruction_ref insert_mlir(module& m, ...@@ -669,46 +670,9 @@ instruction_ref insert_mlir(module& m,
std::vector<instruction_ref> refs; std::vector<instruction_ref> refs;
std::size_t last = 0; std::size_t last = 0;
#ifdef MIGRAPHX_MLIR_BARE_POINTER
refs.reserve(inputs.size()); refs.reserve(inputs.size());
std::copy(inputs.begin(), inputs.end(), std::back_inserter(refs)); std::copy(inputs.begin(), inputs.end(), std::back_inserter(refs));
last = refs.size() - 1; last = refs.size() - 1;
#else
refs.reserve(inputs.size() * 15);
std::unordered_map<uint64_t, instruction_ref> literal_map{};
auto get_literal = [&](uint64_t value) {
auto fi = literal_map.find(value);
if(fi != literal_map.end())
return fi->second;
auto lit = m.add_literal(value);
literal_map.emplace(value, lit);
return lit;
};
for(auto input : inputs)
{
const size_t offset = 0;
auto s = input->get_shape();
last = refs.size();
refs.push_back(input);
refs.push_back(input);
refs.push_back(get_literal(offset)); // offset
// dim sizes
std::transform(s.lens().begin(),
s.lens().end(),
std::back_inserter(refs),
[&](const auto& lval) { return get_literal(lval); });
// refs.push_back(get_literal(1)); // G
// dim strides
std::transform(s.strides().begin(),
s.strides().end(),
std::back_inserter(refs),
[&](const auto& lval) { return get_literal(lval); });
// refs.push_back(get_literal(1)); // G
}
#endif
co.expected_inputs = to_shapes(refs); co.expected_inputs = to_shapes(refs);
co.output_arg = last; co.output_arg = last;
return m.insert_instruction(ins, co, refs); return m.insert_instruction(ins, co, refs);
......
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/permutation.hpp> #include <migraphx/permutation.hpp>
#include <fstream> #include <fstream>
#include <mutex>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -88,6 +89,9 @@ std::string generate_miopen_config(const problem_params& pp) ...@@ -88,6 +89,9 @@ std::string generate_miopen_config(const problem_params& pp)
auto query_miopen_db(const std::string& query) auto query_miopen_db(const std::string& query)
{ {
static std::mutex g_db_mutex; // NOLINT
const std::lock_guard<std::mutex> lock(g_db_mutex);
// TODO: Store db as a static variable // TODO: Store db as a static variable
const auto dbpath = fs::path{"/opt"} / "rocm" / "share" / "miopen" / "db" / "miopen.db"; const auto dbpath = fs::path{"/opt"} / "rocm" / "share" / "miopen" / "db" / "miopen.db";
// Check if db file exists. // Check if db file exists.
......
...@@ -140,7 +140,7 @@ TEST_CASE(conv) ...@@ -140,7 +140,7 @@ TEST_CASE(conv)
{ {
const std::string mlir_output = R"__migraphx__( const std::string mlir_output = R"__migraphx__(
module { module {
func.func @main(%arg0: tensor<2x8x3x3xf32>, %arg1: tensor<1x8x4x4xf32>) -> tensor<1x2x2x2xf32> attributes {kernel = "mixr"} { func.func @main(%arg0: tensor<2x8x3x3xf32>, %arg1: tensor<1x8x4x4xf32>) -> tensor<1x2x2x2xf32> attributes {arch = "", kernel = "mixr"} {
%0 = migraphx.convolution(%arg1, %arg0) {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : (tensor<1x8x4x4xf32>, tensor<2x8x3x3xf32>) -> tensor<1x2x2x2xf32> %0 = migraphx.convolution(%arg1, %arg0) {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : (tensor<1x8x4x4xf32>, tensor<2x8x3x3xf32>) -> tensor<1x2x2x2xf32>
return %0 : tensor<1x2x2x2xf32> return %0 : tensor<1x2x2x2xf32>
} }
...@@ -163,7 +163,7 @@ TEST_CASE(conv_add_relu) ...@@ -163,7 +163,7 @@ TEST_CASE(conv_add_relu)
{ {
const std::string mlir_output = R"__migraphx__( const std::string mlir_output = R"__migraphx__(
module { module {
func.func @main(%arg0: tensor<1x2x2x2xf32>, %arg1: tensor<2x8x3x3xf32>, %arg2: tensor<1x8x4x4xf32>) -> tensor<1x2x2x2xf32> attributes {kernel = "mixr"} { func.func @main(%arg0: tensor<1x2x2x2xf32>, %arg1: tensor<2x8x3x3xf32>, %arg2: tensor<1x8x4x4xf32>) -> tensor<1x2x2x2xf32> attributes {arch = "", kernel = "mixr"} {
%0 = migraphx.convolution(%arg2, %arg1) {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : (tensor<1x8x4x4xf32>, tensor<2x8x3x3xf32>) -> tensor<1x2x2x2xf32> %0 = migraphx.convolution(%arg2, %arg1) {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : (tensor<1x8x4x4xf32>, tensor<2x8x3x3xf32>) -> tensor<1x2x2x2xf32>
%1 = migraphx.add(%0, %arg0) : (tensor<1x2x2x2xf32>, tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32> %1 = migraphx.add(%0, %arg0) : (tensor<1x2x2x2xf32>, tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32>
%2 = migraphx.relu(%1) : (tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32> %2 = migraphx.relu(%1) : (tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment