Commit b717b473 authored by Manupa Karunaratne

* standalone attention sanely offloads to mlir, but mlir is broken

parent 25d6b2e2
@@ -81,6 +81,7 @@ struct mlir_op
            MIGRAPHX_THROW("should have at least two inputs.");
        module_ref mod = mods[0];
+       std::cerr << "mod:" << *mod << std::endl;
        auto type = mod->get_output_shapes().front().type();
        std::unordered_map<instruction_ref, shape> ins_shapes;
        size_t param_cnt = 0;
@@ -448,19 +449,31 @@ struct find_mlir_standalone_attention_op : find_mlir_standalone_op
                       [&](auto input) { return input != top_ins; });
        }
        auto softmax =
            mm->add_instruction(r.instructions["softmax"]->get_operator(), new_top_ins);
        insert_to_map(ins_map, r.instructions["softmax"], softmax);
        std::transform(r.instructions["bottom_dot"]->inputs().begin(),
                       r.instructions["bottom_dot"]->inputs().end(),
                       std::inserter(ins_map, ins_map.end()),
                       [&](auto old_ins) {
                           if(old_ins == r.instructions["softmax"])
                           {
                               return std::make_pair(old_ins, softmax);
                           }
                           inputs.push_back(old_ins);
                           return std::make_pair(
                               old_ins,
                               mm->add_parameter("bdot_non_smax_in", old_ins->get_shape()));
                       });
        auto bottom_dot_a = get_from_map(ins_map, r.instructions["bottom_dot"]->inputs().front());
        auto bottom_dot_b = get_from_map(ins_map, r.instructions["bottom_dot"]->inputs().back());
        auto new_bottom_dot = mm->add_instruction(make_op("dot"), {bottom_dot_a, bottom_dot_b});
        mm->add_return({new_bottom_dot});
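        // At this point ins_map sends each bottom_dot input either to the in-module
        // softmax result or to a fresh "bdot_non_smax_in" parameter, and `inputs`
        // holds the matching outer-graph values. Note the parameter name is a fixed
        // literal, so two non-softmax operands would collide; presumably the softmax
        // is expected to feed exactly one operand of bottom_dot.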
        inputs.insert(inputs.end(), top_inputs.begin(), top_inputs.end());
        mpm.get_module().replace_instruction(
-           top_ins, mlir_op{new_bottom_dot->get_operator()}, inputs, {mm});
+           r.instructions["bottom_dot"], mlir_op{new_bottom_dot->get_operator()}, inputs, {mm});
}
    auto matcher() const
    {
+       auto match_softmax_input = match::any_of[match::inputs()](
+           match::name("dot").bind("top_dot"),
+           match::name("pointwise")(match::any_of[match::inputs()](
+               match::name("dot").bind("top_dot"))).bind("scale"));
-       auto is_mlir_attention = match::name("dot")(match::any_of[match::inputs()](
-           match::name("softmax").bind("softmax"))).bind("bottom_dot");
+       auto is_mlir_attention = match::name("dot")(match::any_of[match::inputs()](
+           match::name("softmax")(match_softmax_input).bind("softmax"))).bind("bottom_dot");
        return is_mlir_attention;
    }
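For reference, a minimal sketch of the subgraph this matcher targets: a dot feeding a softmax (optionally through a pointwise scale op), whose result is consumed by a second dot. The shapes, parameter names, and softmax axis below are illustrative assumptions, not code from this change:

    #include <migraphx/program.hpp>
    #include <migraphx/make_op.hpp>

    // hypothetical dot -> softmax -> dot graph the matcher should bind
    migraphx::program p;
    auto* main_mod = p.get_main_module();
    migraphx::shape s{migraphx::shape::float_type, {1, 8, 64, 64}};
    auto q      = main_mod->add_parameter("q", s);
    auto k      = main_mod->add_parameter("k", s);
    auto v      = main_mod->add_parameter("v", s);
    auto top    = main_mod->add_instruction(migraphx::make_op("dot"), q, k); // "top_dot"
    auto smax   = main_mod->add_instruction(
        migraphx::make_op("softmax", {{"axis", 3}}), top);                   // "softmax"
    auto bottom = main_mod->add_instruction(migraphx::make_op("dot"), smax, v); // "bottom_dot"
    main_mod->add_return({bottom});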
@@ -554,12 +567,12 @@ bool is_enabled(std::string_view op_name, context* ctx)
void fuse_mlir::apply(module_pass_manager& mpm) const
{
#ifdef MIGRAPHX_MLIR
-   match::find_matches(mpm, find_mlir_standalone_attention_op{});
    if(is_enabled("fused", this->ctx))
    {
        match::find_matches(mpm, find_mlir_attention_fused_ops{});
        match::find_matches(mpm, find_mlir_fused_ops{});
    }
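    // the standalone attention matcher now runs after the fused passes, presumably
    // so find_mlir_attention_fused_ops can claim attention subgraphs first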
+   match::find_matches(mpm, find_mlir_standalone_attention_op{});
    if(is_enabled("convolution", this->ctx))
    {
......