Commit 9898823d authored by Manupa Karunaratne

* added fused attention kernels (see the illustrative sketch below)

parent b83d7788
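For context, a minimal sketch of the subgraph shape the new attention matchers target: a first `dot` ("top_dot"), an optional scale (a pointwise module containing a single `mul`), a `softmax`, a second `dot` ("bottom_dot"), and an optional trailing pointwise consumer, restricted to fp32/fp16/i8. The shapes, parameter names, and softmax axis below are assumptions for illustration, not taken from this commit; the graph-building calls mirror the module API already used in the diff.

#include <migraphx/program.hpp>
#include <migraphx/module.hpp>
#include <migraphx/make_op.hpp>

// Illustrative only: builds the dot -> softmax -> dot chain that
// find_mlir_standalone_attention_op looks for; find_mlir_attention_fused_ops
// additionally expects a pointwise consumer of the bottom dot.
inline migraphx::program attention_example()
{
    using migraphx::shape;
    migraphx::program p;
    auto* mm = p.get_main_module();
    shape qs{shape::half_type, {1, 12, 256, 64}};  // assumed shapes
    shape ks{shape::half_type, {1, 12, 64, 256}};
    auto q = mm->add_parameter("q", qs);
    auto k = mm->add_parameter("k", ks);
    auto v = mm->add_parameter("v", qs);
    // "top_dot": Q x K^T, optionally followed by a single-mul pointwise scale
    auto scores = mm->add_instruction(migraphx::make_op("dot"), q, k);
    // softmax over the last axis feeds the second dot
    auto probs = mm->add_instruction(migraphx::make_op("softmax", {{"axis", 3}}), scores);
    // "bottom_dot": softmax(QK^T) x V
    auto out = mm->add_instruction(migraphx::make_op("dot"), probs, v);
    mm->add_return({out});
    return p;
}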
@@ -202,23 +202,11 @@ std::vector<instruction_ref> fold_pointwise_mod(instruction_ref pm_ins, module_r
return parent_mod->insert_instructions(parent_mod->end(), pm, param_map);
}
struct find_mlir_fused_ops
// Whitelist supported fusion options, including imposing type constraints
// for cases where MLIR only supports an operation (usually a pointwise function)
// on particular types.
bool is_pointwise_op_supported_by_mlir(const instruction& i)
{
auto matcher() const
{
auto dot_or_conv = match::skip(match::name("contiguous"))(
match::any_of(match::name("dot"), match::name("quant_dot"), is_mlir_conv())
.bind("gemm_based_op"));
return match::name("pointwise")(match::any_of[match::inputs()](dot_or_conv.bind("x")));
}
// Whitelist supported fusion options, including imposing type constraints
// for cases where MLIR only supports an operation (usually a pointwise function)
// on particular types.
bool is_pointwise_op_supported_by_mlir(const instruction& i) const
{
using type_t = shape::type_t;
const auto& name = i.name();
const auto result_type = i.get_shape().type();
@@ -279,6 +267,18 @@ struct find_mlir_fused_ops
});
}
return false;
}
struct find_mlir_fused_ops
{
auto matcher() const
{
auto dot_or_conv = match::skip(match::name("contiguous"))(
match::any_of(match::name("dot"), match::name("quant_dot"), is_mlir_conv())
.bind("gemm_based_op"));
return match::name("pointwise")(match::any_of[match::inputs()](dot_or_conv.bind("x")));
}
void rewrite(module_pass_manager& mpm, const match::matcher_result& r) const
@@ -297,20 +297,7 @@ struct find_mlir_fused_ops
std::sort(names.begin(), names.end());
module_ref mm = mpm.create_module("mlir_" + pm->name());
mm->set_bypass();
// std::unordered_map<instruction_ref, instruction_ref> param_map =
// create_param_map_with_literals(mm, pm, gemm_based_op->get_shape());
auto [anchor_op, top_inputs] = fuse_input_ops_and_gemm_based_op(mm, gemm_based_op);
// std::transform(names.begin(),
// names.end(),
// ins->inputs().begin(),
// std::inserter(param_map, param_map.end()),
// [&, &anchor = anchor_op](auto name, auto input) {
// if(input == x_ins)
// return std::make_pair(pm->get_parameter(name), anchor);
// return std::make_pair(pm->get_parameter(name),
// mm->add_parameter(name, input->get_shape()));
// });
// mm->add_return(mm->insert_instructions(mm->end(), pm, param_map));
mm->add_return(fold_pointwise_mod(ins, mm, {{x_ins, anchor_op}}));
std::vector<instruction_ref> inputs;
@@ -336,42 +323,6 @@ struct find_mlir_fused_ops
}
};
struct find_mlir_attention_fused_ops : public find_mlir_fused_ops
{
auto matcher() const
{
auto match_softmax_input = match::any_of[match::inputs()](match::name("dot"), match::name("pointwise")(match::any_of[match::inputs()](match::name("dot"))).bind("scale"));
auto is_mlir_attention = match::name("dot")(match::any_of[match::inputs()](match::name("softmax")));
return match::name("pointwise")(match::any_of[match::inputs()](is_mlir_attention.bind("x")));
}
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto ins = r.result;
auto* pm = ins->module_inputs().front();
// Check the pointwise mod only contains a single mul
if(r.instructions.find("scale") != r.instructions.end()){
auto scale_pm = r.instructions["scale"];
bool found_mul = false;
for(const auto& scale_ins : *scale_pm->module_inputs().front()){
if(contains({"@param", "@literal", "@return"}, scale_ins.name())){
continue;
}
if(scale_ins.name() == "mul" && !found_mul){
found_mul = true;
continue;
}
return;
}
}
// Whitelist pointwise operators.
if(std::any_of(pm->begin(), pm->end(), [&](const auto& i) {
return not is_pointwise_op_supported_by_mlir(i);
}))
return;
rewrite(mpm, r);
}
};
struct find_mlir_standalone_op
{
@@ -410,24 +361,8 @@ struct find_mlir_standalone_dot_op : find_mlir_standalone_op
auto matcher() const { return match::any_of(match::name("dot"), match::name("quant_dot")); }
};
struct find_mlir_standalone_attention_op : find_mlir_standalone_op
struct find_mlir_standalone_attention_op
{
void insert_to_map(std::unordered_map<instruction_ref, instruction_ref>& ins_map, instruction_ref old_ins, instruction_ref new_ins) const {
if(ins_map.count(new_ins)){
new_ins = ins_map[new_ins];
}
if(!ins_map.count(old_ins)){
ins_map[old_ins] = new_ins;
}
}
instruction_ref get_from_map(std::unordered_map<instruction_ref, instruction_ref>& ins_map, instruction_ref ins) const {
if(ins_map.count(ins)){
return ins_map[ins];
}
return ins;
}
void rewrite(module_pass_manager& mpm, const match::matcher_result& r) const
{
static size_t counter = 0;
@@ -438,7 +373,7 @@ struct find_mlir_standalone_attention_op : find_mlir_standalone_op
std::unordered_map<instruction_ref, instruction_ref> ins_map;
auto top_ins = r.instructions["top_dot"];
auto [new_top_ins, top_inputs] = fuse_input_ops_and_gemm_based_op(mm, top_ins);
insert_to_map(ins_map, top_ins, new_top_ins);
ins_map[top_ins] = new_top_ins;
if(r.instructions.find("scale") != r.instructions.end()){
auto scale_ins = r.instructions["scale"];
new_top_ins = fold_pointwise_mod(scale_ins, mm, ins_map)[0];
@@ -459,12 +394,13 @@ struct find_mlir_standalone_attention_op : find_mlir_standalone_op
return std::make_pair(old_ins,
mm->add_parameter("bdot_non_smax_in", old_ins->get_shape()));
});
auto bottom_dot_a = get_from_map(ins_map, r.instructions["bottom_dot"]->inputs().front());
auto bottom_dot_b = get_from_map(ins_map, r.instructions["bottom_dot"]->inputs().back());
auto bottom_dot_a = ins_map[r.instructions["bottom_dot"]->inputs().front()];
auto bottom_dot_b = ins_map[r.instructions["bottom_dot"]->inputs().back()];
auto new_bottom_dot = mm->add_instruction(make_op("dot"), {bottom_dot_a, bottom_dot_b});
if(r.instructions.find("trailing_pm") != r.instructions.end()){
new_bottom_dot = fold_pointwise_mod(r.instructions["trailing_pm"], mm, ins_map)[0];
}
mm->add_return({new_bottom_dot});
inputs.insert(inputs.end(), top_inputs.begin(), top_inputs.end());
mpm.get_module().replace_instruction(
r.instructions["bottom_dot"], mlir_op{new_bottom_dot->get_operator()}, inputs, {mm});
@@ -476,8 +412,7 @@ struct find_mlir_standalone_attention_op : find_mlir_standalone_op
return is_mlir_attention;
}
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
bool check(const match::matcher_result& r) const {
auto top_dot = r.instructions["top_dot"];
// Check the pointwise mod only contains a single mul
if(r.instructions.find("scale") != r.instructions.end()){
@@ -491,7 +426,7 @@ struct find_mlir_standalone_attention_op : find_mlir_standalone_op
found_mul = true;
continue;
}
return;
return false;
}
}
// enable only for fp32/fp16/i8 types
@@ -500,6 +435,44 @@ struct find_mlir_standalone_attention_op : find_mlir_standalone_op
{shape::type_t::float_type, shape::type_t::half_type, shape::type_t::int8_type},
i->get_shape().type());
})){
return false;
}
return true;
}
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
if(!check(r)){
return;
}
rewrite(mpm, r);
}
};
struct find_mlir_attention_fused_ops : public find_mlir_standalone_attention_op
{
auto matcher() const {
auto standalone_matcher = find_mlir_standalone_attention_op::matcher();
return match::name("pointwise")(match::any_of[match::inputs()](standalone_matcher).bind("trailing_pm"));;
}
bool check(const match::matcher_result& r) const {
if(!find_mlir_standalone_attention_op::check(r)){
return false;
}
auto trailing_pm_ins = r.instructions["trailing_pm"]; // input after contiguous
auto* trailing_pm = trailing_pm_ins->module_inputs().front();
// Whitelist pointwise operators.
if(std::any_of(trailing_pm->begin(), trailing_pm->end(), [&](const auto& i) {
return not is_pointwise_op_supported_by_mlir(i);
}))
return false;
return true;
}
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
if(!check(r)){
return;
}
rewrite(mpm, r);
......