Commit 83b9164b authored by Manupa Karunaratne

* fix pm + attention test

parent a22ec139
@@ -411,7 +411,7 @@ struct find_mlir_standalone_attention_op
         mm->set_bypass();
         std::unordered_map<instruction_ref, instruction_ref> ins_map;
-        auto top_ins = r.instructions["top_dot"];
+        auto top_ins = r.instructions["gemm0"];
         auto [new_top_ins, top_inputs] = fuse_input_ops_and_gemm_based_op(mm, top_ins);
         inputs.insert(inputs.begin(), top_inputs.begin(), top_inputs.end());
         ins_map[top_ins] = new_top_ins;
@@ -424,8 +424,8 @@ struct find_mlir_standalone_attention_op
                          [&](auto input) { return input != top_ins; });
         }
         auto softmax = mm->add_instruction(r.instructions["softmax"]->get_operator(), new_top_ins);
-        std::transform(r.instructions["bottom_dot"]->inputs().begin(),
-                       r.instructions["bottom_dot"]->inputs().end(),
+        std::transform(r.instructions["gemm1"]->inputs().begin(),
+                       r.instructions["gemm1"]->inputs().end(),
                        std::inserter(ins_map, ins_map.end()),
                        [&](auto old_ins) {
                            if(old_ins == r.instructions["softmax"]){
@@ -433,22 +433,30 @@ struct find_mlir_standalone_attention_op
                            }
                            inputs.push_back(old_ins);
                            return std::make_pair(old_ins,
-                                                 mm->add_parameter("bdot_non_smax_in", old_ins->get_shape()));
+                                                 mm->add_parameter("v", old_ins->get_shape()));
                        });
-        auto bottom_dot_a = ins_map[r.instructions["bottom_dot"]->inputs().front()];
-        auto bottom_dot_b = ins_map[r.instructions["bottom_dot"]->inputs().back()];
-        auto new_bottom_dot = mm->add_instruction(make_op("dot"), {bottom_dot_a, bottom_dot_b});
+        auto gemm1_a = ins_map[r.instructions["gemm1"]->inputs().front()];
+        auto gemm1_b = ins_map[r.instructions["gemm1"]->inputs().back()];
+        auto new_gemm1 = mm->add_instruction(make_op("dot"), {gemm1_a, gemm1_b});
+        ins_map[r.instructions["gemm1"]] = new_gemm1;
+        auto ins_to_replace = new_gemm1;
+        auto ins_to_be_replaced = r.instructions["gemm1"];
         if(r.instructions.find("trailing_pm") != r.instructions.end()){
-            new_bottom_dot = fold_pointwise_mod(r.instructions["trailing_pm"], mm, ins_map)[0];
+            ins_to_replace = fold_pointwise_mod(r.instructions["trailing_pm"], mm, ins_map)[0];
+            std::copy_if(r.instructions["trailing_pm"]->inputs().begin(),
+                         r.instructions["trailing_pm"]->inputs().end(),
+                         std::back_inserter(inputs),
+                         [&](auto input) { return input != r.instructions["gemm1"]; });
+            ins_to_be_replaced = r.instructions["trailing_pm"];
         }
-        mm->add_return({new_bottom_dot});
+        mm->add_return({ins_to_replace});
         mpm.get_module().replace_instruction(
-            r.instructions["bottom_dot"], mlir_op{new_bottom_dot->get_operator()}, inputs, {mm});
+            ins_to_be_replaced, mlir_op{new_gemm1->get_operator()}, inputs, {mm});
     }

     auto matcher() const {
-        auto match_softmax_input = match::any_of[match::inputs()](match::name("dot").bind("top_dot"), match::name("pointwise")(match::any_of[match::inputs()](match::name("dot").bind("top_dot"))).bind("scale"));
-        auto is_mlir_attention = match::name("dot")(match::any_of[match::inputs()](match::name("softmax")(match_softmax_input).bind("softmax"))).bind("bottom_dot");
+        auto match_softmax_input = match::any_of[match::inputs()](match::name("dot").bind("gemm0"), match::name("pointwise")(match::any_of[match::inputs()](match::name("dot").bind("gemm0"))).bind("scale"));
+        auto is_mlir_attention = match::name("dot")(match::any_of[match::inputs()](match::name("softmax")(match_softmax_input).bind("softmax"))).bind("gemm1");
         return is_mlir_attention;
     }
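
This hunk carries the substance of the fix: besides renaming the matcher bindings from top_dot/bottom_dot to gemm0/gemm1, the pass now tracks ins_to_replace (the value returned from the fused module) separately from ins_to_be_replaced (the outer instruction that gets swapped for the mlir_op). When a trailing pointwise mod follows gemm1, it is folded into the fused module, its extra inputs are appended to the fusion's input list, and the pointwise instruction, not the gemm, is replaced. The bookkeeping rests on two standard-library idioms; below is a minimal, self-contained sketch of both, where Ins, gemm1_inputs, and trailing_inputs are illustrative stand-ins, not MIGraphX API:

#include <algorithm>
#include <iostream>
#include <iterator>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

int main()
{
    using Ins = std::string; // illustrative stand-in for instruction_ref

    std::vector<Ins> gemm1_inputs = {"softmax_out", "v_input"};
    std::unordered_map<Ins, Ins> ins_map;

    // std::transform + std::inserter: map each gemm1 input to a fresh
    // parameter of the fused module (modelled here as a renamed string).
    std::transform(gemm1_inputs.begin(),
                   gemm1_inputs.end(),
                   std::inserter(ins_map, ins_map.end()),
                   [](const Ins& old_ins) { return std::make_pair(old_ins, "param_" + old_ins); });

    // std::copy_if + std::back_inserter: gather the trailing pointwise op's
    // inputs while skipping the gemm1 result, since it is computed inside
    // the fused module and must not become an outer input.
    std::vector<Ins> trailing_inputs = {"gemm1_out", "bias"};
    std::vector<Ins> fusion_inputs;
    std::copy_if(trailing_inputs.begin(),
                 trailing_inputs.end(),
                 std::back_inserter(fusion_inputs),
                 [](const Ins& in) { return in != "gemm1_out"; });

    for(const auto& in : fusion_inputs)
        std::cout << in << "\n"; // prints: bias
}

The same reasoning explains the std::copy_if filter in the patch: every input of trailing_pm except the gemm1 result must be threaded into the fusion from outside.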
@@ -458,7 +466,7 @@ struct find_mlir_standalone_attention_op
         if(mode != mlir_mode::all){
             return false;
         }
-        auto top_dot = r.instructions["top_dot"];
+        auto gemm0 = r.instructions["gemm0"];
         // Check the pointwise mod only contains a single mul
         if(r.instructions.find("scale") != r.instructions.end()){
             auto scale_pm = r.instructions["scale"];
@@ -475,7 +483,7 @@ struct find_mlir_standalone_attention_op
             }
         }
         // enable only for fp32/fp16/i8 types
-        if(std::any_of(top_dot->inputs().begin(), top_dot->inputs().end(), [&](auto i) {
+        if(std::any_of(gemm0->inputs().begin(), gemm0->inputs().end(), [&](auto i) {
                return not contains(
                    {shape::type_t::float_type, shape::type_t::half_type, shape::type_t::int8_type},
                    i->get_shape().type());
...
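
The final hunk only updates the dtype gate to the new gemm0 binding: attention fusion is attempted solely when every gemm0 input is fp32, fp16, or int8. A hedged, standalone illustration of that check follows; dtype and the local contains helper are stand-ins for shape::type_t and MIGraphX's contains utility, not the real API:

#include <algorithm>
#include <initializer_list>
#include <vector>

// Illustrative stand-in for shape::type_t; only a few cases shown.
enum class dtype { float_type, half_type, int8_type, double_type };

// Local helper standing in for MIGraphX's contains() utility.
static bool contains(std::initializer_list<dtype> types, dtype t)
{
    return std::find(types.begin(), types.end(), t) != types.end();
}

// Mirrors the std::any_of(... not contains ...) gate in the patch: reject
// the fusion if any input has an element type outside fp32/fp16/int8.
bool all_inputs_supported(const std::vector<dtype>& input_types)
{
    return not std::any_of(input_types.begin(), input_types.end(), [](dtype t) {
        return not contains({dtype::float_type, dtype::half_type, dtype::int8_type}, t);
    });
}

int main()
{
    bool ok  = all_inputs_supported({dtype::half_type, dtype::half_type}); // true
    bool bad = all_inputs_supported({dtype::double_type});                 // false
    return (ok && not bad) ? 0 : 1;
}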