Commit 25d6b2e2 authored by Manupa Karunaratne

[WIP] factoring out pointwise folding

* some skeleton code to handle attention patterns
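* The pattern the new matchers aim at is, roughly: a "top" dot (e.g. Q x K),
  optionally followed by a pointwise scale that must consist of a single mul,
  then a softmax, then a "bottom" dot (e.g. against V). The Q/K/V names are
  illustrative only; the matchers only look at the
  dot -> [single-mul pointwise] -> softmax -> dot shape of the graph.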
parent a3cf9951
@@ -164,18 +164,8 @@ MIGRAPHX_PRED_MATCHER(is_mlir_conv, instruction_ref ins)
return true;
}
std::unordered_map<instruction_ref, instruction_ref>
create_param_map_with_literals(module_ref mm, const module* pm, const shape& shape)
{
std::unordered_map<instruction_ref, instruction_ref> ins_map;
for(auto ins : iterator_for(*pm))
@@ -193,6 +183,37 @@ struct find_mlir_fused_ops
return ins_map;
}
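
// Inline the pointwise module attached to `pm_ins` into `parent_mod`:
// literals are copied in, inputs that already have a counterpart in `ins_map`
// are rewired to it, and any remaining inputs become new parameters of
// `parent_mod`. Returns the mapped outputs of the inlined pointwise module.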
std::vector<instruction_ref> fold_pointwise_mod(instruction_ref pm_ins,
                                                module_ref parent_mod,
                                                const std::unordered_map<instruction_ref, instruction_ref>& ins_map)
{
auto* pm = pm_ins->module_inputs().front();
auto names = pm->get_parameter_names();
std::sort(names.begin(), names.end());
std::unordered_map<instruction_ref, instruction_ref> param_map =
create_param_map_with_literals(parent_mod, pm, pm_ins->get_shape());
std::transform( names.begin(),
names.end(),
pm_ins->inputs().begin(),
std::inserter(param_map, param_map.end()),
[&](auto name, auto input) {
if(ins_map.count(input))
return std::make_pair(pm->get_parameter(name), ins_map.at(input));
return std::make_pair(pm->get_parameter(name),
parent_mod->add_parameter(name, input->get_shape()));
});
return parent_mod->insert_instructions(parent_mod->end(), pm, param_map);
}
struct find_mlir_fused_ops
{
auto matcher() const
{
auto dot_or_conv = match::skip(match::name("contiguous"))(
match::any_of(match::name("dot"), match::name("quant_dot"), is_mlir_conv())
.bind("gemm_based_op"));
return match::name("pointwise")(match::any_of[match::inputs()](dot_or_conv.bind("x")));
}
// Whitelist supported fusion options, including imposing type constraints
// for cases where MLIR only supports an operation (usually a pointwise function)
// on particular types.
@@ -260,7 +281,7 @@ struct find_mlir_fused_ops
return false;
}
void rewrite(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto ins = r.result;
auto gemm_based_op = r.instructions["gemm_based_op"];
@@ -276,20 +297,21 @@ struct find_mlir_fused_ops
std::sort(names.begin(), names.end());
module_ref mm = mpm.create_module("mlir_" + pm->name());
mm->set_bypass();
// std::unordered_map<instruction_ref, instruction_ref> param_map =
//     create_param_map_with_literals(mm, pm, gemm_based_op->get_shape());
auto [anchor_op, top_inputs] = fuse_input_ops_and_gemm_based_op(mm, gemm_based_op);
// std::transform(names.begin(),
//                names.end(),
//                ins->inputs().begin(),
//                std::inserter(param_map, param_map.end()),
//                [&, &anchor = anchor_op](auto name, auto input) {
//                    if(input == x_ins)
//                        return std::make_pair(pm->get_parameter(name), anchor);
//                    return std::make_pair(pm->get_parameter(name),
//                                          mm->add_parameter(name, input->get_shape()));
//                });
// mm->add_return(mm->insert_instructions(mm->end(), pm, param_map));
mm->add_return(fold_pointwise_mod(ins, mm, {{x_ins, anchor_op}}));
std::vector<instruction_ref> inputs;
std::copy_if(ins->inputs().begin(),
@@ -300,10 +322,70 @@ struct find_mlir_fused_ops
mpm.get_module().replace_instruction(
    ins, mlir_op{gemm_based_op->get_operator()}, inputs, {mm});
}
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto ins = r.result;
auto* pm = ins->module_inputs().front();
// Whitelist pointwise operators.
if(std::any_of(pm->begin(), pm->end(), [&](const auto& i) {
return not is_pointwise_op_supported_by_mlir(i);
}))
return;
rewrite(mpm, r);
}
};
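
// Skeleton pattern for attention feeding a pointwise op: matches a pointwise
// instruction whose input is the second dot of a
// dot -> (optional single-mul "scale" pointwise) -> softmax -> dot chain, and
// delegates the actual rewrite to the inherited find_mlir_fused_ops::rewrite.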
struct find_mlir_attention_fused_ops : public find_mlir_fused_ops
{
auto matcher() const
{
// The softmax input is either the top dot directly, or the top dot run through
// a pointwise scale (bound as "scale" and checked in apply()).
auto match_softmax_input = match::any_of[match::inputs()](
    match::name("dot"),
    match::name("pointwise")(match::any_of[match::inputs()](match::name("dot"))).bind("scale"));
auto is_mlir_attention =
    match::name("dot")(match::any_of[match::inputs()](match::name("softmax")(match_softmax_input)));
return match::name("pointwise")(match::any_of[match::inputs()](is_mlir_attention.bind("x")));
}
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto ins = r.result;
auto* pm = ins->module_inputs().front();
// Check the pointwise mod only contains a single mul
if(r.instructions.find("scale") != r.instructions.end()){
auto scale_pm = r.instructions["scale"];
bool found_mul = false;
for(const auto& scale_ins : *scale_pm->module_inputs().front()){
if(contains({"@param", "@literal", "@return"}, scale_ins.name())){
continue;
}
if(scale_ins.name() == "mul" && !found_mul){
found_mul = true;
continue;
}
return;
}
}
// Whitelist pointwise operators.
if(std::any_of(pm->begin(), pm->end(), [&](const auto& i) {
return not is_pointwise_op_supported_by_mlir(i);
}))
return;
rewrite(mpm, r);
}
};
struct find_mlir_standalone_op
{
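// Wrap a single gemm/conv (plus any fusable input ops) into its own mlir_*
// module and replace the original instruction with an mlir_op.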
void rewrite(module_pass_manager& mpm, instruction_ref top_ins) const
{
static size_t counter = 0;
module_ref mm = mpm.create_module("mlir_" + std::to_string(counter++));
mm->set_bypass();
auto [anchor_op, top_inputs] = fuse_input_ops_and_gemm_based_op(mm, top_ins);
mm->add_return({anchor_op});
mpm.get_module().replace_instruction(
top_ins, mlir_op{top_ins->get_operator()}, top_inputs, {mm});
}
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto conv_based_op = r.result;
@@ -314,14 +396,7 @@ struct find_mlir_standalone_op
    i->get_shape().type());
}))
return;
rewrite(mpm, conv_based_op);
}
};
@@ -335,6 +410,93 @@ struct find_mlir_standalone_dot_op : find_mlir_standalone_op
auto matcher() const { return match::any_of(match::name("dot"), match::name("quant_dot")); }
};
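
// Fuse a standalone attention chain (dot -> optional single-mul scale ->
// softmax -> dot) into a single mlir_op, without requiring a trailing
// pointwise consumer.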
struct find_mlir_standalone_attention_op : find_mlir_standalone_op
{
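// Record the fused-module counterpart of `old_ins`; if `new_ins` has itself
// already been remapped, follow that mapping first, and never overwrite an
// existing entry.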
void insert_to_map(std::unordered_map<instruction_ref, instruction_ref>& ins_map,
                   instruction_ref old_ins,
                   instruction_ref new_ins) const
{
if(ins_map.count(new_ins)){
new_ins = ins_map[new_ins];
}
if(!ins_map.count(old_ins)){
ins_map[old_ins] = new_ins;
}
}
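
// Look up the fused-module counterpart of `ins`, falling back to `ins` itself
// when no mapping has been recorded.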
instruction_ref get_from_map(std::unordered_map<instruction_ref, instruction_ref>& ins_map,
                             instruction_ref ins) const
{
if(ins_map.count(ins)){
return ins_map[ins];
}
return ins;
}
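
// Build the fused attention module: fuse the top dot (and its input ops),
// optionally inline the single-mul scale pointwise module, then re-create the
// softmax and the bottom dot inside the new module and swap the matched graph
// for a single mlir_op.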
void rewrite(module_pass_manager& mpm, const match::matcher_result& r) const
{
static size_t counter = 0;
module_ref mm = mpm.create_module("mlir_" + std::to_string(counter++));
std::vector<instruction_ref> inputs;
mm->set_bypass();
std::unordered_map<instruction_ref, instruction_ref> ins_map;
auto top_ins = r.instructions["top_dot"];
auto [new_top_ins, top_inputs] = fuse_input_ops_and_gemm_based_op(mm, top_ins);
insert_to_map(ins_map, top_ins, new_top_ins);
if(r.instructions.find("scale") != r.instructions.end()){
auto scale_ins = r.instructions["scale"];
new_top_ins = fold_pointwise_mod(scale_ins, mm, ins_map)[0];
std::copy_if(scale_ins->inputs().begin(),
scale_ins->inputs().end(),
std::back_inserter(inputs),
[&](auto input) { return input != top_ins; });
}
auto softmax = mm->add_instruction(r.instructions["softmax"]->get_operator(), new_top_ins);
insert_to_map(ins_map, r.instructions["softmax"], softmax);
auto bottom_dot_a = get_from_map(ins_map, r.instructions["bottom_dot"]->inputs().front());
auto bottom_dot_b = get_from_map(ins_map, r.instructions["bottom_dot"]->inputs().back());
auto new_bottom_dot = mm->add_instruction(make_op("dot"), {bottom_dot_a, bottom_dot_b});
mm->add_return({new_bottom_dot});
inputs.insert(inputs.end(), top_inputs.begin(), top_inputs.end());
mpm.get_module().replace_instruction(
    r.instructions["bottom_dot"], mlir_op{new_bottom_dot->get_operator()}, inputs, {mm});
}
auto matcher() const
{
    // The softmax input is either the top dot directly, or the top dot run
    // through a pointwise scale (bound as "scale" and validated in apply()).
    auto match_softmax_input = match::any_of[match::inputs()](
        match::name("dot").bind("top_dot"),
        match::name("pointwise")(match::any_of[match::inputs()](match::name("dot").bind("top_dot")))
            .bind("scale"));
    auto is_mlir_attention =
        match::name("dot")(match::any_of[match::inputs()](
                               match::name("softmax")(match_softmax_input).bind("softmax")))
            .bind("bottom_dot");
    return is_mlir_attention;
}
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto top_dot = r.instructions["top_dot"];
// Check the pointwise mod only contains a single mul
std::cerr << "standalone attention found!\n";
if(r.instructions.find("scale") != r.instructions.end()){
auto scale_pm = r.instructions["scale"];
bool found_mul = false;
for(const auto& scale_ins : *scale_pm->module_inputs().front()){
if(contains({"@param", "@literal", "@return"}, scale_ins.name())){
continue;
}
if(scale_ins.name() == "mul" && !found_mul){
found_mul = true;
continue;
}
std::cerr << "standalone attention scale not compatible!\n";
return;
}
}
// enable only for fp32/fp16/i8 types
if(std::any_of(top_dot->inputs().begin(), top_dot->inputs().end(), [&](auto i) {
return not contains(
{shape::type_t::float_type, shape::type_t::half_type, shape::type_t::int8_type},
i->get_shape().type());
})){
std::cerr << "standalone attention dtype not compatible!\n";
return;
}
rewrite(mpm, r);
}
};
/**
 * @brief Declares a new MIGraphX environment variable which forces generation of
 * only specific MLIR operations.
@@ -393,9 +555,11 @@ void fuse_mlir::apply(module_pass_manager& mpm) const
{
#ifdef MIGRAPHX_MLIR
if(is_enabled("fused", this->ctx))
{
match::find_matches(mpm, find_mlir_attention_fused_ops{});
match::find_matches(mpm, find_mlir_fused_ops{});
}
match::find_matches(mpm, find_mlir_standalone_attention_op{});
if(is_enabled("convolution", this->ctx)) if(is_enabled("convolution", this->ctx))
{ {