Commit 3af83bae authored by jungpark-mlir's avatar jungpark-mlir
Browse files

add mlir gemm-pointwise fusion

parent 962329f3
......@@ -136,6 +136,75 @@ struct find_conv_pointwise
ins, mlir_conv{conv_ins->get_operator()}, inputs, {mm});
}
};
MIGRAPHX_PRED_MATCHER(is_mlir_gemm, instruction_ref ins)
{
if(ins->name() != "dot")
return false;
return true;
}
struct find_gemm_pointwise
{
// Find a convolution followed by a pointwise operation.
auto matcher() const
{
auto gemm = match::skip(match::name("contiguous"))(is_mlir_gemm().bind("dot"));
return match::name("pointwise")(match::any_of[match::inputs()](gemm.bind("x")));
// turn match::any_of[match::inputs()](gemm.bind("x"));
}
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto ins = r.result;
auto gemm_ins = r.instructions["dot"];
auto x_ins = r.instructions["x"]; // input after contiguous
auto* pm = ins->module_inputs().front();
auto names = pm->get_parameter_names();
// Whitelist pointwise operators
if(std::any_of(pm->begin(), pm->end(), [](const auto& i) {
return not contains({"@literal", "@param", "@return", "dot", "add", "relu"},
i.name());
}))
return;
// Only fuse with fp32/fp16
if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [&](auto i) {
return not contains({shape::type_t::float_type, shape::type_t::half_type},
i->get_shape().type());
}))
return;
std::sort(names.begin(), names.end());
module_ref mm = mpm.create_module("mlir_" + pm->name());
mm->set_bypass();
std::unordered_map<instruction_ref, instruction_ref> param_map;
auto x = mm->add_parameter("x" + std::to_string(names.size()),
gemm_ins->inputs().at(0)->get_shape());
auto w = mm->add_parameter("x" + std::to_string(names.size() + 1),
gemm_ins->inputs().at(1)->get_shape());
auto gemm = mm->add_instruction(gemm_ins->get_operator(), {x, w});
std::transform(names.begin(),
names.end(),
ins->inputs().begin(),
std::inserter(param_map, param_map.end()),
[&](auto name, auto input) {
if(input == x_ins)
return std::make_pair(pm->get_parameter(name), gemm);
return std::make_pair(pm->get_parameter(name),
mm->add_parameter(name, input->get_shape()));
});
mm->add_return(mm->insert_instructions(mm->end(), pm, param_map));
std::vector<instruction_ref> inputs;
std::copy_if(ins->inputs().begin(),
ins->inputs().end(),
std::back_inserter(inputs),
[&](auto input) { return input != gemm_ins; });
inputs.insert(inputs.end(), gemm_ins->inputs().begin(), gemm_ins->inputs().end());
mpm.get_module().replace_instruction(
ins, mlir_conv{gemm_ins->get_operator()}, inputs, {mm});
}
};
} // namespace
#endif
......@@ -144,6 +213,7 @@ void fuse_mlir::apply(module_pass_manager& mpm) const
{
#ifdef MIGRAPHX_MLIR
match::find_matches(mpm, find_conv_pointwise{});
match::find_matches(mpm, find_gemm_pointwise{});
#else
(void)mpm;
#endif
......
......@@ -455,7 +455,7 @@ struct mlir_program
auto ops = create_operation_state("func.func");
ops.add_attributes({{"function_type", make_function_type(inputs, outputs)},
{"sym_name", std::string("main")},
{"sym_name", std::string("mlir_main")},
{"kernel", std::string("mixr")},
{"arch", target_arch}});
ops.add_region(std::move(region));
......@@ -550,7 +550,7 @@ struct mlir_program
mlirPassManagerRun(pm.get(), mmodule.get());
code_object_op op{};
op.symbol_name = "main";
op.symbol_name = "mlir_main";
op.code_object = get_binary();
std::tie(op.global, op.local) = get_launch_params();
return op;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment