"python/git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "0d4f3a9fcdea60ac327a6a5897a281a1d763c3ac"
Commit b717b473 authored by Manupa Karunaratne

* standalone attention sanely offloads to mlir but mlir is broken

parent 25d6b2e2
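
For orientation only (this note and sketch are not part of the commit): the subgraph that `find_mlir_standalone_attention_op` matches is the usual attention chain, a first dot (bound as "top_dot"), an optional pointwise scale, a softmax, and a second dot against the value tensor (bound as "bottom_dot", the match root). A minimal, self-contained C++ sketch of that dot -> softmax -> dot computation on tiny row-major matrices follows; the names `q`, `kt`, `v` and the helper functions are illustrative assumptions, not MIGraphX APIs.

```cpp
// Illustrative only: the dot -> softmax -> dot chain that the standalone
// attention matcher is meant to offload to MLIR. Plain C++ on tiny
// row-major matrices; names and sizes are assumptions, not MIGraphX code.
#include <algorithm>
#include <cmath>
#include <iostream>
#include <vector>

using mat = std::vector<std::vector<float>>;

// C = A (m x k) * B (k x n), the "dot" in the matched pattern.
mat dot(const mat& a, const mat& b)
{
    size_t m = a.size(), k = b.size(), n = b[0].size();
    mat c(m, std::vector<float>(n, 0.0f));
    for(size_t i = 0; i < m; ++i)
        for(size_t p = 0; p < k; ++p)
            for(size_t j = 0; j < n; ++j)
                c[i][j] += a[i][p] * b[p][j];
    return c;
}

// Row-wise softmax, the "softmax" instruction between the two dots.
mat softmax(mat x)
{
    for(auto& row : x)
    {
        float mx  = *std::max_element(row.begin(), row.end());
        float sum = 0.0f;
        for(auto& v : row)
        {
            v = std::exp(v - mx);
            sum += v;
        }
        for(auto& v : row)
            v /= sum;
    }
    return x;
}

int main()
{
    mat q  = {{1, 0}, {0, 1}}; // queries
    mat kt = {{1, 2}, {3, 4}}; // keys, already transposed for brevity
    mat v  = {{1, 1}, {2, 2}}; // values

    auto scores = dot(q, kt);      // "top_dot"
    auto probs  = softmax(scores); // "softmax"
    auto out    = dot(probs, v);   // "bottom_dot" -- root of the matcher

    for(const auto& row : out)
    {
        for(float x : row)
            std::cout << x << ' ';
        std::cout << '\n';
    }
}
```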
@@ -81,6 +81,7 @@ struct mlir_op
         MIGRAPHX_THROW("should have at least two inputs.");
         module_ref mod = mods[0];
+        std::cerr << "mod:" << *mod << std::endl;
         auto type = mod->get_output_shapes().front().type();
         std::unordered_map<instruction_ref, shape> ins_shapes;
         size_t param_cnt = 0;
@@ -448,19 +449,31 @@ struct find_mlir_standalone_attention_op : find_mlir_standalone_op
                            [&](auto input) { return input != top_ins; });
         }
         auto softmax = mm->add_instruction(r.instructions["softmax"]->get_operator(), new_top_ins);
-        insert_to_map(ins_map, r.instructions["softmax"], softmax);
+        std::transform(r.instructions["bottom_dot"]->inputs().begin(),
+                       r.instructions["bottom_dot"]->inputs().end(),
+                       std::inserter(ins_map, ins_map.end()),
+                       [&](auto old_ins) {
+                           if(old_ins == r.instructions["softmax"]){
+                               return std::make_pair(old_ins, softmax);
+                           }
+                           inputs.push_back(old_ins);
+                           return std::make_pair(old_ins,
+                               mm->add_parameter("bdot_non_smax_in", old_ins->get_shape()));
+                       });
         auto bottom_dot_a = get_from_map(ins_map, r.instructions["bottom_dot"]->inputs().front());
         auto bottom_dot_b = get_from_map(ins_map, r.instructions["bottom_dot"]->inputs().back());
         auto new_bottom_dot = mm->add_instruction(make_op("dot"), {bottom_dot_a, bottom_dot_b});
         mm->add_return({new_bottom_dot});
         inputs.insert(inputs.end(), top_inputs.begin(), top_inputs.end());
         mpm.get_module().replace_instruction(
-            top_ins, mlir_op{new_bottom_dot->get_operator()}, inputs, {mm});
+            r.instructions["bottom_dot"], mlir_op{new_bottom_dot->get_operator()}, inputs, {mm});
     }
     auto matcher() const {
         auto match_softmax_input = match::any_of[match::inputs()](match::name("dot").bind("top_dot"), match::name("pointwise")(match::any_of[match::inputs()](match::name("dot").bind("top_dot"))).bind("scale"));
-        auto is_mlir_attention = match::name("dot")(match::any_of[match::inputs()](match::name("softmax").bind("softmax"))).bind("bottom_dot");
+        auto is_mlir_attention = match::name("dot")(match::any_of[match::inputs()](match::name("softmax")(match_softmax_input).bind("softmax"))).bind("bottom_dot");
         return is_mlir_attention;
     }
@@ -554,12 +567,12 @@ bool is_enabled(std::string_view op_name, context* ctx)
 void fuse_mlir::apply(module_pass_manager& mpm) const
 {
 #ifdef MIGRAPHX_MLIR
+    match::find_matches(mpm, find_mlir_standalone_attention_op{});
     if(is_enabled("fused", this->ctx))
     {
         match::find_matches(mpm, find_mlir_attention_fused_ops{});
         match::find_matches(mpm, find_mlir_fused_ops{});
     }
-    match::find_matches(mpm, find_mlir_standalone_attention_op{});
     if(is_enabled("convolution", this->ctx))
     {
...