Commit 82df48a6 authored by jungpark-mlir's avatar jungpark-mlir
Browse files

Fixes per review

parent fe2370e6
......@@ -30,25 +30,14 @@
#include <mlir-c/BuiltinTypes.h>
#include <mlir-c/Diagnostics.h>
#include <mlir-c/Dialect/MIGraphX.h>
#if defined(MLIR_MIGRAPHX_DIALECT_API_VERSION) && MLIR_MIGRAPHX_DIALECT_API_VERSION >= 2
#define MIGRAPHX_MLIR_BARE_POINTER
#endif
#if defined(MLIR_MIGRAPHX_DIALECT_API_VERSION) && MLIR_MIGRAPHX_DIALECT_API_VERSION >= 3
#define MIGRAPHX_MLIR_REGISTER_ROCMLIR
#define MIGRAPHX_MLIR_USE_FULL_ARCH
#if !defined(MLIR_MIGRAPHX_DIALECT_API_VERSION) || MLIR_MIGRAPHX_DIALECT_API_VERSION != 3
#warning "Incompatible version of rocMLIR library used, disabling"
#undef MIGRAPHX_MLIR
#endif
#include <mlir-c/IntegerSet.h>
#include <mlir-c/Pass.h>
#ifdef MIGRAPHX_MLIR_REGISTER_ROCMLIR
#include <mlir-c/RegisterEverything.h>
#include <mlir-c/RegisterRocMLIR.h>
#else
#include <mlir-c/Registration.h>
#endif
#endif
#include <migraphx/env.hpp>
......@@ -180,7 +169,6 @@ struct mlir_program
location(mlirLocationUnknownGet(ctx.get())),
mmodule(mlirModuleCreateEmpty(location))
{
#ifdef MIGRAPHX_MLIR_REGISTER_ROCMLIR
MlirDialectRegistry registry = mlirDialectRegistryCreate();
mlirRegisterRocMLIRDialects(registry);
mlirRegisterAllDialects(registry);
......@@ -188,11 +176,6 @@ struct mlir_program
mlirContextLoadAllAvailableDialects(ctx.get());
mlirDialectRegistryDestroy(registry);
mlirContextSetAllowUnregisteredDialects(ctx.get(), true /*allow*/);
#else
MlirDialectHandle mixr_handle = mlirGetDialectHandle__migraphx__();
mlirDialectHandleRegisterDialect(mixr_handle, ctx.get());
mlirRegisterAllDialects(ctx.get());
#endif
}
MlirType make_type(shape::type_t t) const
......@@ -473,12 +456,8 @@ struct mlir_program
auto ops = create_operation_state("func.func");
ops.add_attributes({{"function_type", make_function_type(inputs, outputs)},
{"sym_name", std::string("main")},
#ifdef MIGRAPHX_MLIR_USE_FULL_ARCH
{"kernel", std::string("mixr")},
{"arch", target_arch}});
#else
{"kernel", std::string("mixr")}});
#endif
ops.add_region(std::move(region));
insert(body, std::move(ops));
......@@ -538,12 +517,8 @@ struct mlir_program
pp =
problem_params{ins->get_operator(), to_shapes(ins->inputs()), ins->get_shape()};
// check if HW supports xdlops
#ifdef MIGRAPHX_MLIR_USE_FULL_ARCH
auto target_chip = trim(split_string(target_arch, ':').front());
bool xdlops = contains(get_xdlops_archs(), target_chip);
#else
bool xdlops = contains(get_xdlops_archs(), target_name);
#endif
std::string tuned = get_tune_params(xdlops);
if(not tuned.empty())
ops.add_attributes({{"perf_config", tuned}});
......@@ -571,11 +546,7 @@ struct mlir_program
// 1st pipeline to call
mlirMIGraphXAddHighLevelPipeline(pm.get());
// 2nd pipeline to call
#ifdef MIGRAPHX_MLIR_USE_FULL_ARCH
mlirMIGraphXAddBackendPipeline(pm.get(), target_arch.c_str());
#else
mlirMIGraphXAddBackendPipeline(pm.get(), target_name.c_str(), "amdgcn-amd-amdhsa", "");
#endif
mlirPassManagerRun(pm.get(), mmodule.get());
code_object_op op{};
......@@ -585,20 +556,7 @@ struct mlir_program
return op;
}
#ifdef MIGRAPHX_MLIR_USE_FULL_ARCH
void find_target() { target_arch = get_device_name(); }
#else
void find_target()
{
target_arch = get_device_name();
// HACK: Since MLIR can't handle the full target name
target_chip = trim(split_string(target_arch, ':').front());
if(target_arch.size() != target_chip.size())
std::cout
<< "*************** WARNING: MLIR may not compile the correct target features for: "
<< tname << std::endl;
}
#endif
std::pair<std::size_t, std::size_t> get_launch_params() const
{
......@@ -627,11 +585,7 @@ struct mlir_program
mlir_module mmodule;
problem_params pp;
std::deque<std::string> strings{};
#ifdef MIGRAPHX_MLIR_USE_FULL_ARCH
std::string target_arch;
#else
std::string target_name;
#endif
};
std::string dump_mlir(const module& m)
......@@ -712,46 +666,9 @@ instruction_ref insert_mlir(module& m,
std::vector<instruction_ref> refs;
std::size_t last = 0;
#ifdef MIGRAPHX_MLIR_BARE_POINTER
refs.reserve(inputs.size());
std::copy(inputs.begin(), inputs.end(), std::back_inserter(refs));
last = refs.size() - 1;
#else
refs.reserve(inputs.size() * 15);
std::unordered_map<uint64_t, instruction_ref> literal_map{};
auto get_literal = [&](uint64_t value) {
auto fi = literal_map.find(value);
if(fi != literal_map.end())
return fi->second;
auto lit = m.add_literal(value);
literal_map.emplace(value, lit);
return lit;
};
for(auto input : inputs)
{
const size_t offset = 0;
auto s = input->get_shape();
last = refs.size();
refs.push_back(input);
refs.push_back(input);
refs.push_back(get_literal(offset)); // offset
// dim sizes
std::transform(s.lens().begin(),
s.lens().end(),
std::back_inserter(refs),
[&](const auto& lval) { return get_literal(lval); });
// refs.push_back(get_literal(1)); // G
// dim strides
std::transform(s.strides().begin(),
s.strides().end(),
std::back_inserter(refs),
[&](const auto& lval) { return get_literal(lval); });
// refs.push_back(get_literal(1)); // G
}
#endif
co.expected_inputs = to_shapes(refs);
co.output_arg = last;
return m.insert_instruction(ins, co, refs);
......
......@@ -29,8 +29,6 @@
#include <fstream>
#include <mutex>
static std::mutex g_db_mutex; // NOLINT
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
......@@ -91,6 +89,7 @@ std::string generate_miopen_config(const problem_params& pp)
auto query_miopen_db(const std::string& query)
{
static std::mutex g_db_mutex; // NOLINT
const std::lock_guard<std::mutex> lock(g_db_mutex);
// TODO: Store db as a static variable
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment