Unverified Commit 4996c6d7 authored by Manupa Karunaratne, committed by GitHub

[MLIR][5.7] add input fusion support for view ops (#1705)

Adds support for fusing slice, transpose, contiguous, and reshape operations into the input tensors of a fused MLIR kernel.
parent 4fb3fd4a
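In effect, a chain of view ops sitting between a kernel's inputs and the underlying GEMM or convolution no longer blocks fusion: the chain is replayed inside the MLIR submodule, and the fused op consumes the original, pre-view tensor. A rough before/after IR sketch (hypothetical shapes and value names, for illustration only):

// Before: the transpose keeps x0 outside the fused kernel.
//   t0 = transpose[permutation={0,2,1}](x0)
//   d0 = dot(t0, x1)
//   r0 = add(d0, x2)
// After: the view moves into the submodule; mlir_op takes x0 directly.
//   r0 = mlir_op[op=dot](x0, x1, x2)
//   where the submodule computes: y0 -> transpose -> dot(..., y1) -> add(..., y2)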
@@ -110,7 +110,7 @@ RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXR
ADD tools/build_and_test_onnxrt.sh /onnxruntime/build_and_test_onnxrt.sh
-RUN cget -p /usr/local install ROCmSoftwarePlatform/rocMLIR@55c6ee66cc7502db7950693b3e845676cbf400b1 -DBUILD_MIXR_TARGET=On -DLLVM_ENABLE_ZSTD=Off -DLLVM_ENABLE_THREADS=Off
+RUN cget -p /usr/local install ROCmSoftwarePlatform/rocMLIR@a997d5f51314b45d7a4c04f1599966dcf53f9b4d -DBUILD_MIXR_TARGET=On -DLLVM_ENABLE_ZSTD=Off -DLLVM_ENABLE_THREADS=Off
ENV MIOPEN_FIND_DB_PATH=/tmp/miopen/find-db
ENV MIOPEN_USER_DB_PATH=/tmp/miopen/user-db
@@ -79,11 +79,41 @@ struct mlir_op
MIGRAPHX_THROW("should have one submodule.");
if(inputs.size() < 2)
MIGRAPHX_THROW("should have at least two inputs.");
-auto n = inputs.size();
-auto* pm = mods.front();
-auto type = pm->get_output_shapes().front().type();
-auto shape = op.compute_shape({inputs[n - 2], inputs[n - 1]});
-return shape.with_type(type);
+module_ref mod = mods[0];
+auto type = mod->get_output_shapes().front().type();
+std::unordered_map<instruction_ref, shape> ins_shapes;
+size_t param_cnt = 0;
+std::vector<std::string> names = mod->get_parameter_names();
+std::sort(names.begin(), names.end());
+for(std::string param_name : names)
+{
+ins_shapes[mod->get_parameter(param_name)] = inputs[param_cnt++];
+}
+for(auto ins : iterator_for(*mod))
+{
+if(ins->name() == "@param")
+{
+continue;
+}
+if(ins->name() == "@literal")
+{
+ins_shapes[ins] = ins->get_shape();
+continue;
+}
+if(ins->name() == "@return")
+{
+return ins_shapes[ins->inputs().at(0)].with_type(type);
+}
+std::vector<shape> input_shapes;
+input_shapes.resize(ins->inputs().size());
+std::transform(ins->inputs().begin(),
+ins->inputs().end(),
+input_shapes.begin(),
+[&](auto in) { return ins_shapes[in]; });
+ins_shapes[ins] = ins->get_operator().compute_shape(input_shapes);
+}
+MIGRAPHX_THROW("No return found in the submodule");
}
};
MIGRAPHX_REGISTER_OP(mlir_op);
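The rewritten compute_shape no longer assumes the GEMM/convolution reads the last two inputs directly; with view ops fused in, the output shape has to be recovered by propagating shapes through the whole submodule. Parameters are bound to the inputs in sorted-name order, literals keep their stored shapes, and each remaining instruction's shape is folded from its inputs' entries until @return. A minimal standalone sketch of the same fold, with invented stand-in types rather than the MIGraphX API:

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <iterator>
#include <string>
#include <unordered_map>
#include <vector>

// Stand-ins for instruction_ref and shape, just to show the fold.
using shape_t = std::vector<std::size_t>; // tensor lengths only
struct ins_t
{
    std::string name; // "param", "transpose", "dot", "return"
    std::vector<const ins_t*> inputs;
};

// transpose swaps the last two dims; dot contracts them: {m,k} x {k,n} -> {m,n}
shape_t compute(const std::string& op, const std::vector<shape_t>& in)
{
    if(op == "transpose")
    {
        shape_t s = in[0];
        std::swap(s[s.size() - 2], s.back());
        return s;
    }
    return {in[0][0], in[1][1]}; // "dot"
}

int main()
{
    ins_t x{"param", {}};
    ins_t w{"param", {}};
    ins_t t{"transpose", {&x}};
    ins_t d{"dot", {&t, &w}};
    ins_t r{"return", {&d}};

    // Seed the parameter shapes, as the sorted-name loop above does.
    std::unordered_map<const ins_t*, shape_t> shapes{{&x, {4, 2}}, {&w, {4, 3}}};
    for(const ins_t* ins : {&t, &d, &r})
    {
        if(ins->name == "return")
        {
            for(auto dim : shapes[ins->inputs[0]])
                std::cout << dim << ' '; // prints "2 3": transpose(4x2) . (4x3)
            return 0;
        }
        std::vector<shape_t> in;
        std::transform(ins->inputs.begin(),
                       ins->inputs.end(),
                       std::back_inserter(in),
                       [&](const ins_t* i) { return shapes[i]; });
        shapes[ins] = compute(ins->name, in);
    }
}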
@@ -113,6 +143,53 @@ struct find_mlir_op
return match::name("pointwise")(match::any_of[match::inputs()](dot_or_conv.bind("x")));
}
+std::unordered_map<instruction_ref, instruction_ref>
+create_param_map_with_literals(module_ref mm, const module* pm, const shape& shape) const
+{
+std::unordered_map<instruction_ref, instruction_ref> ins_map;
+for(auto ins : iterator_for(*pm))
+{
+if(ins->name() != "@literal")
+{
+continue;
+}
+literal r = ins->get_literal();
+instruction_ref literal = mm->add_literal(r);
+instruction_ref mbcast = mm->add_instruction(
+make_op("multibroadcast", {{"out_lens", shape.lens()}}), literal);
+ins_map[ins] = mbcast;
+}
+return ins_map;
+}
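Literals from the pointwise module are re-materialized inside the new MLIR module and multibroadcast up to the GEMM/convolution output shape, so the spliced pointwise body can consume them elementwise. Roughly, for a hypothetical scalar literal (invented shapes, for illustration):

// pm:  l = @literal{0.5}                          -- shape: float, lens {1}
// mm:  l0 = @literal{0.5}
//      b0 = multibroadcast[out_lens={64,56,56}](l0)
// ins_map[l] == b0, consulted when the pointwise body is copied into mm.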
+std::tuple<instruction_ref, std::vector<instruction_ref>>
+fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op) const
+{
+std::vector<instruction_ref> top_inputs;
+std::vector<instruction_ref> imm_inputs;
+size_t input_cnt = 0;
+for(instruction_ref input : gemm_based_op->inputs())
+{
+std::vector<operation> op_stream;
+while(contains({"slice", "transpose", "contiguous", "reshape"}, input->name()))
+{
+op_stream.push_back(input->get_operator());
+input = input->inputs().at(0);
+}
+top_inputs.push_back(input);
+instruction_ref prev_input =
+mm->add_parameter("y" + std::to_string(input_cnt++), input->get_shape());
+for(const auto& op : reverse(op_stream))
+{
+prev_input = mm->add_instruction(op, {prev_input});
+}
+imm_inputs.push_back(prev_input);
+}
+instruction_ref new_gemm_based_op =
+mm->add_instruction(gemm_based_op->get_operator(), imm_inputs);
+return {new_gemm_based_op, top_inputs};
+}
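This is the heart of the change: for each GEMM/convolution input, the loop peels off any chain of slice/transpose/contiguous/reshape ops (recorded closest-to-the-op first), stops at the first non-view producer, then replays the chain in reverse onto a fresh y<N> parameter so the submodule reconstructs the original view. For a hypothetical chain transpose(slice(a)):

// outside: in = transpose(s); s = slice(a)   =>  op_stream = {transpose, slice}
// top_inputs gets a: the pre-view tensor becomes the fused op's real input
// inside mm: y0 = @param                     -- a's shape
//            v0 = slice(y0)                  -- reverse(op_stream): slice first
//            v1 = transpose(v0)              -- then transpose
// v1 stands in for the original input of the new gemm/conv.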
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto ins = r.result;
@@ -131,7 +208,8 @@ struct find_mlir_op
"add",
"relu",
"dequantizelinear",
"quantizelinear"},
"quantizelinear",
"mul"},
i.name());
}))
return;
@@ -147,19 +225,16 @@ struct find_mlir_op
std::sort(names.begin(), names.end());
module_ref mm = mpm.create_module("mlir_" + pm->name());
mm->set_bypass();
-std::unordered_map<instruction_ref, instruction_ref> param_map;
-auto x = mm->add_parameter("x" + std::to_string(names.size()),
-gemm_based_op->inputs().at(0)->get_shape());
-auto w = mm->add_parameter("x" + std::to_string(names.size() + 1),
-gemm_based_op->inputs().at(1)->get_shape());
-auto conv = mm->add_instruction(gemm_based_op->get_operator(), {x, w});
+std::unordered_map<instruction_ref, instruction_ref> param_map =
+create_param_map_with_literals(mm, pm, gemm_based_op->get_shape());
+auto [anchor_op, top_inputs] = fuse_input_ops_and_gemm_based_op(mm, gemm_based_op);
std::transform(names.begin(),
names.end(),
ins->inputs().begin(),
std::inserter(param_map, param_map.end()),
-[&](auto name, auto input) {
+[&, &anchor_op = anchor_op](auto name, auto input) {
if(input == x_ins)
-return std::make_pair(pm->get_parameter(name), conv);
+return std::make_pair(pm->get_parameter(name), anchor_op);
return std::make_pair(pm->get_parameter(name),
mm->add_parameter(name, input->get_shape()));
});
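The init-capture `&anchor_op = anchor_op` looks redundant but is required: in C++17 a name introduced by a structured binding (the `auto [anchor_op, top_inputs]` above) cannot be captured by a lambda directly, so a reference alias is created explicitly. A minimal standalone demonstration:

#include <iostream>
#include <tuple>

int main()
{
    auto [a, b] = std::make_tuple(1, 2);
    // auto bad = [&] { return a; };  // ill-formed in C++17 (allowed in C++20):
    //                                // structured bindings are not capturable
    auto ok = [&, &a = a] { return a + 1; }; // init-capture aliases the binding
    std::cout << ok() + b << '\n';           // prints 4
}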
@@ -170,7 +245,7 @@ struct find_mlir_op
ins->inputs().end(),
std::back_inserter(inputs),
[&](auto input) { return input != gemm_based_op; });
-inputs.insert(inputs.end(), gemm_based_op->inputs().begin(), gemm_based_op->inputs().end());
+inputs.insert(inputs.end(), top_inputs.begin(), top_inputs.end());
mpm.get_module().replace_instruction(
ins, mlir_op{gemm_based_op->get_operator()}, inputs, {mm});
}
@@ -324,7 +324,8 @@ struct mlir_program
std::string,
value,
std::vector<value>,
-MlirType>;
+MlirType,
+MlirAttribute>;
using named_attribute_t = std::pair<std::string_view, attribute_t>;
MlirNamedAttribute name_attribute(const named_attribute_t& na) const
@@ -481,6 +482,10 @@ struct mlir_program
{
if(ins->name() == "@return")
return "func.return";
if(ins->name() == "@literal")
{
return "tosa.const";
}
return "migraphx." + ins->name();
}
@@ -532,11 +537,24 @@ struct mlir_program
{
if(ins->name() == "@param")
continue;
if(ins->name() == "contiguous")
{
ins_map[ins] = ins_map[ins->inputs().at(0)];
continue;
}
auto name = get_name(ins);
auto ops = create_operation_state(name);
ops.add_attribute_value(get_operator_value(ins->get_operator()));
if(ins->name() != "@return")
ops.add_results({get_shape(ins)});
if(ins->name() == "@literal")
{
literal r = ins->get_literal();
MlirType tensor_type = make_tensor(ins->get_shape());
MlirAttribute mlir_value_attr =
mlirDenseElementsAttrRawBufferGet(tensor_type, r.get_shape().bytes(), r.data());
ops.add_attributes({{"value", mlir_value_attr}});
}
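Two lowering details here: `contiguous` has no MLIR counterpart (tensor-valued IR has no layout to normalize), so it is elided by aliasing its input's MLIR value; and literals reach MLIR as `tosa.const` ops whose payload is built with the MLIR C API `mlirDenseElementsAttrRawBufferGet`, wrapping the literal's raw bytes as a dense-elements attribute under the `value` key. A float literal with lens {2} would lower to roughly (generic MLIR form, values illustrative):

// %c = "tosa.const"() {value = dense<[2.500000e-01, 7.500000e-01]> : tensor<2xf32>}
//      : () -> tensor<2xf32>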
if(ins->name() == "convolution" or ins->name() == "dot")
{
pp =
@@ -739,12 +757,13 @@ code_object_op compile_mlir(const context&, module m, const std::vector<instruct
{
adjust_param_shapes(m, inputs);
const bool trace = enabled(MIGRAPHX_TRACE_MLIR{});
-if(trace)
-std::cout << m << std::endl;
// set mutex while llvm thread support is disabled.
static std::mutex g_mlirc_mutex; // NOLINT
const std::lock_guard<std::mutex> lock(g_mlirc_mutex);
+if(trace)
+std::cout << m << std::endl;
mlir_program mp;
mp.find_target();
mp.parse(m);
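Moving the trace dump after the lock puts the module print inside the serialized region; presumably this keeps dumps from concurrent compilations from interleaving, since everything from here on runs under the global mutex while LLVM thread support is off. A hypothetical illustration of the pattern (names invented):

#include <iostream>
#include <mutex>
#include <string>

// Print a multi-line trace atomically: without the lock, two compiling
// threads could interleave their lines on std::cout.
void trace_module(const std::string& dump)
{
    static std::mutex g_trace_mutex; // NOLINT
    const std::lock_guard<std::mutex> lock(g_trace_mutex);
    std::cout << dump << std::endl;
}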