Unverified Commit 7c4dc99a authored by Manupa Karunaratne, committed by GitHub

[MLIR] add dot offloads with manual tuning support (#1631)

* [MLIR] add dot offloads with manual tuning support
* This commit adds dot + pointwise fusion support
along with manual tuning using rocMLIR.
parent c614588b
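As context for the change above, here is a minimal sketch (not part of the commit) of the kind of subgraph the new fusion targets: a dot followed by a whitelisted pointwise op such as add. It mirrors the dot_add test added at the end of this diff; the shapes, the headers, and the helper name make_dot_add are illustrative assumptions only.

#include <migraphx/make_op.hpp>
#include <migraphx/module.hpp>
#include <migraphx/shape.hpp>

// Build a dot + add module; with MIGRAPHX_ENABLE_MLIR set, the fuse_mlir pass in
// this commit replaces the pointwise instruction with a single gpu::mlir_op.
migraphx::module make_dot_add()
{
    migraphx::module m;
    auto a = m.add_parameter("a", {migraphx::shape::float_type, {1, 5, 4}});
    auto b = m.add_parameter("b", {migraphx::shape::float_type, {1, 4, 3}});
    auto c = m.add_parameter("c", {migraphx::shape::float_type, {1, 5, 3}});
    auto dot = m.add_instruction(migraphx::make_op("dot"), a, b);   // gemm-based op
    auto add = m.add_instruction(migraphx::make_op("add"), dot, c); // fused pointwise
    m.add_return({add});
    return m;
}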
......@@ -110,7 +110,7 @@ RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXR
ADD tools/build_and_test_onnxrt.sh /onnxruntime/build_and_test_onnxrt.sh
RUN cget -p /usr/local install ROCmSoftwarePlatform/rocMLIR@78b706fe9879587ab98b6614ae539265374a3fae -DBUILD_MIXR_TARGET=On -DLLVM_ENABLE_ZSTD=Off -DLLVM_ENABLE_THREADS=Off
RUN cget -p /usr/local install ROCmSoftwarePlatform/rocMLIR@acb727b348086b58a7f261b32c0e4f0686a4c0ee -DBUILD_MIXR_TARGET=On -DLLVM_ENABLE_ZSTD=Off -DLLVM_ENABLE_THREADS=Off
ENV MIOPEN_FIND_DB_PATH=/tmp/miopen/find-db
ENV MIOPEN_USER_DB_PATH=/tmp/miopen/user-db
......
......@@ -39,8 +39,10 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR);
#ifdef MIGRAPHX_MLIR
struct mlir_conv
struct mlir_op
{
std::string name() const { return "gpu::mlir_op"; }
operation op = make_op("convolution");
template <class Self, class F>
......@@ -49,7 +51,6 @@ struct mlir_conv
return pack(f(self.op, "op"));
}
std::string name() const { return "gpu::mlir_conv"; }
shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
{
check_shapes{inputs, *this}.packed_or_broadcasted();
......@@ -61,7 +62,7 @@ struct mlir_conv
return op.compute_shape({inputs[n - 2], inputs[n - 1]});
}
};
MIGRAPHX_REGISTER_OP(mlir_conv);
MIGRAPHX_REGISTER_OP(mlir_op);
namespace {
......@@ -79,27 +80,27 @@ MIGRAPHX_PRED_MATCHER(is_mlir_conv, instruction_ref ins)
return true;
}
struct find_conv_pointwise
struct find_mlir_op
{
// Find a convolution followed by a pointwise operation.
auto matcher() const
{
auto convolution =
match::skip(match::name("contiguous"))(is_mlir_conv().bind("convolution"));
return match::name("pointwise")(match::any_of[match::inputs()](convolution.bind("x")));
auto dot_or_conv = match::skip(match::name("contiguous"))(
match::any_of(match::name("dot"), is_mlir_conv()).bind("gemm_based_op"));
return match::name("pointwise")(match::any_of[match::inputs()](dot_or_conv.bind("x")));
}
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto ins = r.result;
auto conv_ins = r.instructions["convolution"];
auto x_ins = r.instructions["x"]; // input after contiguous
auto* pm = ins->module_inputs().front();
auto names = pm->get_parameter_names();
auto ins = r.result;
auto gemm_based_op = r.instructions["gemm_based_op"];
auto x_ins = r.instructions["x"]; // input after contiguous
auto* pm = ins->module_inputs().front();
auto names = pm->get_parameter_names();
// Whitelist pointwise operators
if(std::any_of(pm->begin(), pm->end(), [](const auto& i) {
return not contains({"@literal", "@param", "@return", "convolution", "add", "relu"},
i.name());
return not contains(
{"@literal", "@param", "@return", "convolution", "dot", "add", "relu"},
i.name());
}))
return;
// Only fuse with fp32/fp16
......@@ -113,10 +114,10 @@ struct find_conv_pointwise
mm->set_bypass();
std::unordered_map<instruction_ref, instruction_ref> param_map;
auto x = mm->add_parameter("x" + std::to_string(names.size()),
conv_ins->inputs().at(0)->get_shape());
gemm_based_op->inputs().at(0)->get_shape());
auto w = mm->add_parameter("x" + std::to_string(names.size() + 1),
conv_ins->inputs().at(1)->get_shape());
auto conv = mm->add_instruction(conv_ins->get_operator(), {x, w});
gemm_based_op->inputs().at(1)->get_shape());
auto conv = mm->add_instruction(gemm_based_op->get_operator(), {x, w});
std::transform(names.begin(),
names.end(),
ins->inputs().begin(),
......@@ -133,12 +134,13 @@ struct find_conv_pointwise
std::copy_if(ins->inputs().begin(),
ins->inputs().end(),
std::back_inserter(inputs),
[&](auto input) { return input != conv_ins; });
inputs.insert(inputs.end(), conv_ins->inputs().begin(), conv_ins->inputs().end());
[&](auto input) { return input != gemm_based_op; });
inputs.insert(inputs.end(), gemm_based_op->inputs().begin(), gemm_based_op->inputs().end());
mpm.get_module().replace_instruction(
ins, mlir_conv{conv_ins->get_operator()}, inputs, {mm});
ins, mlir_op{gemm_based_op->get_operator()}, inputs, {mm});
}
};
} // namespace
#endif
......@@ -149,7 +151,7 @@ void fuse_mlir::apply(module_pass_manager& mpm) const
const bool mlir_enabled = enabled(MIGRAPHX_ENABLE_MLIR{});
if(mlir_enabled)
{
match::find_matches(mpm, find_conv_pointwise{});
match::find_matches(mpm, find_mlir_op{});
}
else
{
......
......@@ -32,7 +32,7 @@ namespace gpu {
struct mlir_compiler : compiler<mlir_compiler>
{
std::vector<std::string> names() const { return {"gpu::mlir_conv"}; }
std::vector<std::string> names() const { return {"gpu::mlir_op"}; }
operation compile_op(context&, const std::vector<shape>&, const value&) const { return {}; }
......
......@@ -30,6 +30,7 @@
#include <mlir-c/BuiltinTypes.h>
#include <mlir-c/Diagnostics.h>
#include <mlir-c/Dialect/MIGraphX.h>
#include <mlir-c/Dialect/Rock.h>
#include <mlir-c/IntegerSet.h>
#include <mlir-c/Pass.h>
#include <mutex>
......@@ -55,12 +56,16 @@
#include <migraphx/permutation.hpp>
#include <deque>
#include <variant>
#include <fstream>
#include <sstream>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MLIR);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNING_DB);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNING_CFG);
#ifdef MIGRAPHX_MLIR
template <class T, class F, F f> // NOLINT
......@@ -124,6 +129,8 @@ using mlir_op_printing_flags = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirOpPrintingFlags,
using mlir_region = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirRegion, mlirRegionDestroy);
using mlir_block = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirBlock, mlirBlockDestroy);
using mlir_pass_manager = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirPassManager, mlirPassManagerDestroy);
using mlir_tuning_table = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirRockTuningTable,
mlirRockTuningTableDestroy);
std::string_view to_string_view(MlirStringRef s) { return {s.data, s.length}; }
......@@ -455,7 +462,7 @@ struct mlir_program
auto ops = create_operation_state("func.func");
ops.add_attributes({{"function_type", make_function_type(inputs, outputs)},
{"sym_name", std::string("main")},
{"sym_name", sym_name},
{"kernel", std::string("mixr")},
{"arch", target_arch}});
ops.add_region(std::move(region));
......@@ -498,11 +505,25 @@ struct mlir_program
return ins->get_shape();
}
static std::string get_symbol_name(const module& m)
{
for(auto ins : iterator_for(m))
{
if(ins->name() == "convolution" or ins->name() == "dot")
{
return "mlir_" + ins->name();
}
}
return "main";
}
void parse(const module& m)
{
sym_name = get_symbol_name(m);
auto mbody = mlirModuleGetBody(mmodule.get());
std::unordered_map<instruction_ref, MlirValue> ins_map;
auto fbody = insert(mbody, m, ins_map);
for(auto ins : iterator_for(m))
{
if(ins->name() == "@param")
......@@ -512,16 +533,13 @@ struct mlir_program
ops.add_attribute_value(get_operator_value(ins->get_operator()));
if(ins->name() != "@return")
ops.add_results({get_shape(ins)});
if(ins->name() == "convolution")
if(ins->name() == "convolution" or ins->name() == "dot")
{
pp =
problem_params{ins->get_operator(), to_shapes(ins->inputs()), ins->get_shape()};
// check if HW supports xdlops
auto target_chip = trim(split_string(target_arch, ':').front());
bool xdlops = contains(get_xdlops_archs(), target_chip);
std::string tuned = get_tune_params(xdlops);
if(not tuned.empty())
ops.add_attributes({{"perf_config", tuned}});
auto target_chip = trim(split_string(target_arch, ':').front());
bool xdlops = contains(get_xdlops_archs(), target_chip);
if(xdlops)
ops.add_attributes({{"xdlopsV2", true}});
}
......@@ -542,15 +560,19 @@ struct mlir_program
code_object_op compile() MIGRAPHX_TIDY_CONST
{
mlir_pass_manager pm{mlirPassManagerCreate(ctx.get())};
mlir_pass_manager pm_front{mlirPassManagerCreate(ctx.get())};
mlir_pass_manager pm_back{mlirPassManagerCreate(ctx.get())};
// 1st pipeline to call
mlirMIGraphXAddHighLevelPipeline(pm.get());
mlirMIGraphXAddHighLevelPipeline(pm_front.get());
mlirPassManagerRun(pm_front.get(), mmodule.get());
// 2nd pipeline to call
mlirMIGraphXAddBackendPipeline(pm.get(), target_arch.c_str());
mlirPassManagerRun(pm.get(), mmodule.get());
get_module_tuned();
mlirMIGraphXAddBackendPipeline(pm_back.get(), target_arch.c_str());
mlirPassManagerRun(pm_back.get(), mmodule.get());
code_object_op op{};
op.symbol_name = "main";
op.symbol_name = sym_name;
op.code_object = get_binary();
std::tie(op.global, op.local) = get_launch_params();
return op;
......@@ -578,7 +600,74 @@ struct mlir_program
MIGRAPHX_THROW("Failed to compile mlir program");
}
std::string get_tune_params(bool xdlops) { return get_mlir_perf_for_conv(pp, xdlops); }
std::string get_tune_params(bool xdlops) const { return get_mlir_perf_for_conv(pp, xdlops); }
// This function appends to the tuning cfg file, which can be
// used with rocMLIR tuning scripts.
void dump_tuning_cfg(const char* prob_config) const
{
std::string tuning_cfg_path = string_value_of(MIGRAPHX_MLIR_TUNING_CFG{});
if(!tuning_cfg_path.empty())
{
std::vector<std::string> tokens = split_string(prob_config, '\t');
std::string prob = tokens[1];
if(starts_with(prob, "conv"))
{
tuning_cfg_path += ".conv";
}
else
{
tuning_cfg_path += ".gemm";
}
std::ofstream tuning_cfg(tuning_cfg_path, std::ios::app);
tuning_cfg << prob << std::endl;
}
}
static mlir_tuning_table create_tuning_table()
{
mlir_tuning_table tuning_table{mlirRockTuningTableCreate()};
std::string tuning_db_path = string_value_of(MIGRAPHX_MLIR_TUNING_DB{});
if(!tuning_db_path.empty())
{
std::ifstream tuning_db_tsv(tuning_db_path);
if(tuning_db_tsv)
{
std::string line;
while(std::getline(tuning_db_tsv, line))
{
std::vector<std::string> tokens = split_string(line, '\t');
std::string arch = tokens[0];
std::string prob = tokens[1];
std::string perf = tokens[2];
std::string key = arch.append("\t").append(prob);
mlirRockTuningUpdateTable(tuning_table.get(), key.c_str(), perf.c_str(), 1.0);
}
}
}
else
{
std::cerr
<< "WARNING: MLIR tuning db not found. Please set MIGRAPHX_MLIR_TUNING_DB for "
"optimal performance."
<< std::endl;
}
return tuning_table;
}
bool get_module_tuned() const
{
static mlir_tuning_table tuning_table = create_tuning_table();
if(!mlirRockTuningSetFromTable(tuning_table.get(), mmodule.get()))
{
const char* prob_config = mlirRockTuningGetKey(tuning_table.get(), mmodule.get());
std::stringstream key(prob_config);
std::cerr << "fails to set param on" << prob_config << std::endl;
dump_tuning_cfg(prob_config);
return false;
}
return true;
}
mlir_context ctx;
MlirLocation location;
......@@ -586,6 +675,7 @@ struct mlir_program
problem_params pp;
std::deque<std::string> strings{};
std::string target_arch;
std::string sym_name;
};
std::string dump_mlir(const module& m)
......
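For reference, a standalone sketch (not part of the commit) of the tuning-db layout that create_tuning_table above expects: MIGRAPHX_MLIR_TUNING_DB points at a tab-separated file whose lines are <arch>\t<problem>\t<perf_config>, and the first two fields joined by a tab form the lookup key passed to mlirRockTuningUpdateTable. The entry string below is a placeholder, not a real tuning result.

#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Parse a single (placeholder) tuning-db line the same way create_tuning_table does.
int main()
{
    std::string line = "gfx90a\t<problem-key>\t<perf-config>"; // hypothetical entry
    std::vector<std::string> tokens;
    std::stringstream ss(line);
    for(std::string tok; std::getline(ss, tok, '\t');)
        tokens.push_back(tok);
    std::string key = tokens[0] + "\t" + tokens[1]; // arch + "\t" + problem
    std::cout << "key:  " << key << "\n"
              << "perf: " << tokens[2] << "\n";
    return 0;
}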
......@@ -140,7 +140,7 @@ TEST_CASE(conv)
{
const std::string mlir_output = R"__migraphx__(
module {
func.func @main(%arg0: tensor<2x8x3x3xf32>, %arg1: tensor<1x8x4x4xf32>) -> tensor<1x2x2x2xf32> attributes {arch = "", kernel = "mixr"} {
func.func @mlir_convolution(%arg0: tensor<2x8x3x3xf32>, %arg1: tensor<1x8x4x4xf32>) -> tensor<1x2x2x2xf32> attributes {arch = "", kernel = "mixr"} {
%0 = migraphx.convolution(%arg1, %arg0) {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : (tensor<1x8x4x4xf32>, tensor<2x8x3x3xf32>) -> tensor<1x2x2x2xf32>
return %0 : tensor<1x2x2x2xf32>
}
......@@ -163,7 +163,7 @@ TEST_CASE(conv_add_relu)
{
const std::string mlir_output = R"__migraphx__(
module {
func.func @main(%arg0: tensor<1x2x2x2xf32>, %arg1: tensor<2x8x3x3xf32>, %arg2: tensor<1x8x4x4xf32>) -> tensor<1x2x2x2xf32> attributes {arch = "", kernel = "mixr"} {
func.func @mlir_convolution(%arg0: tensor<1x2x2x2xf32>, %arg1: tensor<2x8x3x3xf32>, %arg2: tensor<1x8x4x4xf32>) -> tensor<1x2x2x2xf32> attributes {arch = "", kernel = "mixr"} {
%0 = migraphx.convolution(%arg2, %arg1) {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : (tensor<1x8x4x4xf32>, tensor<2x8x3x3xf32>) -> tensor<1x2x2x2xf32>
%1 = migraphx.add(%0, %arg0) : (tensor<1x2x2x2xf32>, tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32>
%2 = migraphx.relu(%1) : (tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32>
......@@ -187,4 +187,30 @@ module {
EXPECT(verify_mlir(m));
}
TEST_CASE(dot_add)
{
const std::string mlir_output = R"__migraphx__(
module {
func.func @mlir_dot(%arg0: tensor<1x5x4xf32>, %arg1: tensor<1x4x3xf32>, %arg2: tensor<1x5x3xf32>) -> tensor<1x5x3xf32> attributes {arch = "", kernel = "mixr"} {
%0 = migraphx.dot(%arg0, %arg1) : tensor<1x5x4xf32>, tensor<1x4x3xf32> -> tensor<1x5x3xf32>
%1 = migraphx.add(%0, %arg2) : (tensor<1x5x3xf32>, tensor<1x5x3xf32>) -> tensor<1x5x3xf32>
return %1 : tensor<1x5x3xf32>
}
}
)__migraphx__";
migraphx::module m;
auto arg0 = m.add_parameter("arg0", {migraphx::shape::float_type, {1, 5, 4}});
auto arg1 = m.add_parameter("arg1", {migraphx::shape::float_type, {1, 4, 3}});
auto arg2 = m.add_parameter("arg2", {migraphx::shape::float_type, {1, 5, 3}});
auto conv = m.add_instruction(migraphx::make_op("dot"), arg0, arg1);
auto add = m.add_instruction(migraphx::make_op("add"), conv, arg2);
m.add_return({add});
auto s = migraphx::gpu::dump_mlir(m);
// Skip test if MLIR is not enabled
if(s.empty())
return;
CHECK(encode(s) == encode(mlir_output));
EXPECT(verify_mlir(m));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }