Commit 59a53257 authored by Manupa Karunaratne's avatar Manupa Karunaratne
Browse files

* clang-format

parent 83b9164b
...@@ -206,40 +206,44 @@ auto is_mlir_conv(mlir_mode mode) ...@@ -206,40 +206,44 @@ auto is_mlir_conv(mlir_mode mode)
} }
std::unordered_map<instruction_ref, instruction_ref> std::unordered_map<instruction_ref, instruction_ref>
create_param_map_with_literals(module_ref mm, const module* pm, const shape& shape) create_param_map_with_literals(module_ref mm, const module* pm, const shape& shape)
{
std::unordered_map<instruction_ref, instruction_ref> ins_map;
for(auto ins : iterator_for(*pm))
{ {
std::unordered_map<instruction_ref, instruction_ref> ins_map; if(ins->name() != "@literal")
for(auto ins : iterator_for(*pm))
{ {
if(ins->name() != "@literal") continue;
{
continue;
}
literal r = ins->get_literal();
instruction_ref literal = mm->add_literal(r);
instruction_ref mbcast = mm->add_instruction(
make_op("multibroadcast", {{"out_lens", shape.lens()}}), literal);
ins_map[ins] = mbcast;
} }
return ins_map; literal r = ins->get_literal();
instruction_ref literal = mm->add_literal(r);
instruction_ref mbcast =
mm->add_instruction(make_op("multibroadcast", {{"out_lens", shape.lens()}}), literal);
ins_map[ins] = mbcast;
} }
return ins_map;
}
std::vector<instruction_ref> fold_pointwise_mod(instruction_ref pm_ins, module_ref parent_mod, const std::unordered_map<instruction_ref, instruction_ref>& ins_map){ std::vector<instruction_ref>
auto* pm = pm_ins->module_inputs().front(); fold_pointwise_mod(instruction_ref pm_ins,
auto names = pm->get_parameter_names(); module_ref parent_mod,
const std::unordered_map<instruction_ref, instruction_ref>& ins_map)
{
auto* pm = pm_ins->module_inputs().front();
auto names = pm->get_parameter_names();
std::sort(names.begin(), names.end()); std::sort(names.begin(), names.end());
std::unordered_map<instruction_ref, instruction_ref> param_map = std::unordered_map<instruction_ref, instruction_ref> param_map =
create_param_map_with_literals(parent_mod, pm, pm_ins->get_shape()); create_param_map_with_literals(parent_mod, pm, pm_ins->get_shape());
std::transform( names.begin(), std::transform(names.begin(),
names.end(), names.end(),
pm_ins->inputs().begin(), pm_ins->inputs().begin(),
std::inserter(param_map, param_map.end()), std::inserter(param_map, param_map.end()),
[&](auto name, auto input) { [&](auto name, auto input) {
if(ins_map.count(input)) if(ins_map.count(input))
return std::make_pair(pm->get_parameter(name), ins_map.at(input)); return std::make_pair(pm->get_parameter(name), ins_map.at(input));
return std::make_pair(pm->get_parameter(name), return std::make_pair(pm->get_parameter(name),
parent_mod->add_parameter(name, input->get_shape())); parent_mod->add_parameter(name, input->get_shape()));
}); });
return parent_mod->insert_instructions(parent_mod->end(), pm, param_map); return parent_mod->insert_instructions(parent_mod->end(), pm, param_map);
} }
...@@ -252,10 +256,10 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i) ...@@ -252,10 +256,10 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
const auto& name = i.name(); const auto& name = i.name();
const auto result_type = i.get_shape().type(); const auto result_type = i.get_shape().type();
const std::initializer_list<type_t> allowed_types = {type_t::float_type, const std::initializer_list<type_t> allowed_types = {type_t::float_type,
type_t::half_type, type_t::half_type,
type_t::int8_type, type_t::int8_type,
type_t::int32_type, type_t::int32_type,
type_t::bool_type}; type_t::bool_type};
// Preliminary type check. // Preliminary type check.
if(not contains(allowed_types, result_type)) if(not contains(allowed_types, result_type))
{ {
...@@ -310,8 +314,6 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i) ...@@ -310,8 +314,6 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
return false; return false;
} }
struct find_mlir_fused_ops struct find_mlir_fused_ops
{ {
mlir_mode conv_mode = mlir_mode::none; mlir_mode conv_mode = mlir_mode::none;
...@@ -354,8 +356,8 @@ struct find_mlir_fused_ops ...@@ -354,8 +356,8 @@ struct find_mlir_fused_ops
void apply(module_pass_manager& mpm, const match::matcher_result& r) const void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto* pm = ins->module_inputs().front(); auto* pm = ins->module_inputs().front();
// Whitelist pointwise operators. // Whitelist pointwise operators.
if(std::any_of(pm->begin(), pm->end(), [&](const auto& i) { if(std::any_of(pm->begin(), pm->end(), [&](const auto& i) {
return not is_pointwise_op_supported_by_mlir(i); return not is_pointwise_op_supported_by_mlir(i);
...@@ -373,13 +375,13 @@ struct find_mlir_standalone_op ...@@ -373,13 +375,13 @@ struct find_mlir_standalone_op
void rewrite(module_pass_manager& mpm, instruction_ref top_ins) const void rewrite(module_pass_manager& mpm, instruction_ref top_ins) const
{ {
static size_t counter = 0; static size_t counter = 0;
module_ref mm = mpm.create_module("mlir_" + std::to_string(counter++)); module_ref mm = mpm.create_module("mlir_" + std::to_string(counter++));
mm->set_bypass(); mm->set_bypass();
auto [anchor_op, top_inputs] = fuse_input_ops_and_gemm_based_op(mm, top_ins); auto [anchor_op, top_inputs] = fuse_input_ops_and_gemm_based_op(mm, top_ins);
mm->add_return({anchor_op}); mm->add_return({anchor_op});
mpm.get_module().replace_instruction( mpm.get_module().replace_instruction(
top_ins, mlir_op{top_ins->get_operator()}, top_inputs, {mm}); top_ins, mlir_op{top_ins->get_operator()}, top_inputs, {mm});
} }
void apply(module_pass_manager& mpm, const match::matcher_result& r) const void apply(module_pass_manager& mpm, const match::matcher_result& r) const
...@@ -405,77 +407,94 @@ struct find_mlir_standalone_attention_op ...@@ -405,77 +407,94 @@ struct find_mlir_standalone_attention_op
mlir_mode mode = mlir_mode::none; mlir_mode mode = mlir_mode::none;
void rewrite(module_pass_manager& mpm, const match::matcher_result& r) const void rewrite(module_pass_manager& mpm, const match::matcher_result& r) const
{ {
static size_t counter = 0; static size_t counter = 0;
module_ref mm = mpm.create_module("mlir_" + std::to_string(counter++)); module_ref mm = mpm.create_module("mlir_" + std::to_string(counter++));
std::vector<instruction_ref> inputs; std::vector<instruction_ref> inputs;
mm->set_bypass(); mm->set_bypass();
std::unordered_map<instruction_ref, instruction_ref> ins_map; std::unordered_map<instruction_ref, instruction_ref> ins_map;
auto top_ins = r.instructions["gemm0"]; auto top_ins = r.instructions["gemm0"];
auto [new_top_ins, top_inputs] = fuse_input_ops_and_gemm_based_op(mm, top_ins); auto [new_top_ins, top_inputs] = fuse_input_ops_and_gemm_based_op(mm, top_ins);
inputs.insert(inputs.begin(), top_inputs.begin(), top_inputs.end()); inputs.insert(inputs.begin(), top_inputs.begin(), top_inputs.end());
ins_map[top_ins] = new_top_ins; ins_map[top_ins] = new_top_ins;
if(r.instructions.find("scale") != r.instructions.end()){ if(r.instructions.find("scale") != r.instructions.end())
auto scale_ins = r.instructions["scale"]; {
new_top_ins = fold_pointwise_mod(scale_ins, mm, ins_map)[0]; auto scale_ins = r.instructions["scale"];
std::copy_if(scale_ins->inputs().begin(), new_top_ins = fold_pointwise_mod(scale_ins, mm, ins_map)[0];
scale_ins->inputs().end(), std::copy_if(scale_ins->inputs().begin(),
std::back_inserter(inputs), scale_ins->inputs().end(),
[&](auto input) { return input != top_ins; }); std::back_inserter(inputs),
} [&](auto input) { return input != top_ins; });
auto softmax = mm->add_instruction(r.instructions["softmax"]->get_operator(), new_top_ins); }
std::transform(r.instructions["gemm1"]->inputs().begin(), auto softmax = mm->add_instruction(r.instructions["softmax"]->get_operator(), new_top_ins);
r.instructions["gemm1"]->inputs().end(), std::transform(r.instructions["gemm1"]->inputs().begin(),
std::inserter(ins_map, ins_map.end()), r.instructions["gemm1"]->inputs().end(),
[&](auto old_ins) { std::inserter(ins_map, ins_map.end()),
if(old_ins == r.instructions["softmax"]){ [&](auto old_ins) {
return std::make_pair(old_ins, softmax); if(old_ins == r.instructions["softmax"])
} {
inputs.push_back(old_ins); return std::make_pair(old_ins, softmax);
return std::make_pair(old_ins, }
mm->add_parameter("v", old_ins->get_shape())); inputs.push_back(old_ins);
}); return std::make_pair(old_ins,
auto gemm1_a = ins_map[r.instructions["gemm1"]->inputs().front()]; mm->add_parameter("v", old_ins->get_shape()));
auto gemm1_b = ins_map[r.instructions["gemm1"]->inputs().back()]; });
auto new_gemm1 = mm->add_instruction(make_op("dot"), {gemm1_a, gemm1_b}); auto gemm1_a = ins_map[r.instructions["gemm1"]->inputs().front()];
ins_map[r.instructions["gemm1"]] = new_gemm1; auto gemm1_b = ins_map[r.instructions["gemm1"]->inputs().back()];
auto ins_to_replace = new_gemm1; auto new_gemm1 = mm->add_instruction(make_op("dot"), {gemm1_a, gemm1_b});
auto ins_to_be_replaced = r.instructions["gemm1"]; ins_map[r.instructions["gemm1"]] = new_gemm1;
if(r.instructions.find("trailing_pm") != r.instructions.end()){ auto ins_to_replace = new_gemm1;
ins_to_replace = fold_pointwise_mod(r.instructions["trailing_pm"], mm, ins_map)[0]; auto ins_to_be_replaced = r.instructions["gemm1"];
std::copy_if(r.instructions["trailing_pm"]->inputs().begin(), if(r.instructions.find("trailing_pm") != r.instructions.end())
r.instructions["trailing_pm"]->inputs().end(), {
std::back_inserter(inputs), ins_to_replace = fold_pointwise_mod(r.instructions["trailing_pm"], mm, ins_map)[0];
[&](auto input) { return input != r.instructions["gemm1"]; }); std::copy_if(r.instructions["trailing_pm"]->inputs().begin(),
ins_to_be_replaced = r.instructions["trailing_pm"]; r.instructions["trailing_pm"]->inputs().end(),
} std::back_inserter(inputs),
mm->add_return({ins_to_replace}); [&](auto input) { return input != r.instructions["gemm1"]; });
mpm.get_module().replace_instruction( ins_to_be_replaced = r.instructions["trailing_pm"];
ins_to_be_replaced, mlir_op{new_gemm1->get_operator()}, inputs, {mm}); }
mm->add_return({ins_to_replace});
mpm.get_module().replace_instruction(
ins_to_be_replaced, mlir_op{new_gemm1->get_operator()}, inputs, {mm});
} }
auto matcher() const { auto matcher() const
auto match_softmax_input = match::any_of[match::inputs()](match::name("dot").bind("gemm0"), match::name("pointwise")(match::any_of[match::inputs()](match::name("dot").bind("gemm0"))).bind("scale")); {
auto is_mlir_attention = match::name("dot")(match::any_of[match::inputs()](match::name("softmax")(match_softmax_input).bind("softmax"))).bind("gemm1"); auto match_softmax_input = match::any_of[match::inputs()](
match::name("dot").bind("gemm0"),
match::name("pointwise")(
match::any_of[match::inputs()](match::name("dot").bind("gemm0")))
.bind("scale"));
auto is_mlir_attention =
match::name("dot")(match::any_of[match::inputs()](
match::name("softmax")(match_softmax_input).bind("softmax")))
.bind("gemm1");
return is_mlir_attention; return is_mlir_attention;
} }
bool check(const match::matcher_result& r) const { bool check(const match::matcher_result& r) const
{
// We are only enabling attention // We are only enabling attention
// in the highest enablement mode for now // in the highest enablement mode for now
if(mode != mlir_mode::all){ if(mode != mlir_mode::all)
{
return false; return false;
} }
auto gemm0 = r.instructions["gemm0"]; auto gemm0 = r.instructions["gemm0"];
// Check the pointwise mod only contains a single mul // Check the pointwise mod only contains a single mul
if(r.instructions.find("scale") != r.instructions.end()){ if(r.instructions.find("scale") != r.instructions.end())
auto scale_pm = r.instructions["scale"]; {
auto scale_pm = r.instructions["scale"];
bool found_mul = false; bool found_mul = false;
for(const auto& scale_ins : *scale_pm->module_inputs().front()){ for(const auto& scale_ins : *scale_pm->module_inputs().front())
if(contains({"@param", "@literal", "@return"}, scale_ins.name())){ {
if(contains({"@param", "@literal", "@return"}, scale_ins.name()))
{
continue; continue;
} }
if(scale_ins.name() == "mul" && !found_mul){ if(scale_ins.name() == "mul" && !found_mul)
{
found_mul = true; found_mul = true;
continue; continue;
} }
...@@ -487,7 +506,8 @@ struct find_mlir_standalone_attention_op ...@@ -487,7 +506,8 @@ struct find_mlir_standalone_attention_op
return not contains( return not contains(
{shape::type_t::float_type, shape::type_t::half_type, shape::type_t::int8_type}, {shape::type_t::float_type, shape::type_t::half_type, shape::type_t::int8_type},
i->get_shape().type()); i->get_shape().type());
})){ }))
{
return false; return false;
} }
return true; return true;
...@@ -495,7 +515,8 @@ struct find_mlir_standalone_attention_op ...@@ -495,7 +515,8 @@ struct find_mlir_standalone_attention_op
void apply(module_pass_manager& mpm, const match::matcher_result& r) const void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{ {
if(!check(r)){ if(!check(r))
{
return; return;
} }
rewrite(mpm, r); rewrite(mpm, r);
...@@ -504,17 +525,22 @@ struct find_mlir_standalone_attention_op ...@@ -504,17 +525,22 @@ struct find_mlir_standalone_attention_op
struct find_mlir_attention_fused_ops : public find_mlir_standalone_attention_op struct find_mlir_attention_fused_ops : public find_mlir_standalone_attention_op
{ {
auto matcher() const { auto matcher() const
{
auto standalone_matcher = find_mlir_standalone_attention_op::matcher(); auto standalone_matcher = find_mlir_standalone_attention_op::matcher();
return match::name("pointwise")(match::any_of[match::inputs()](standalone_matcher).bind("trailing_pm"));; return match::name("pointwise")(
match::any_of[match::inputs()](standalone_matcher).bind("trailing_pm"));
;
} }
bool check(const match::matcher_result& r) const { bool check(const match::matcher_result& r) const
if(!find_mlir_standalone_attention_op::check(r)){ {
if(!find_mlir_standalone_attention_op::check(r))
{
return false; return false;
} }
auto trailing_pm_ins = r.instructions["trailing_pm"]; // input after contiguous auto trailing_pm_ins = r.instructions["trailing_pm"]; // input after contiguous
auto* trailing_pm = trailing_pm_ins->module_inputs().front(); auto* trailing_pm = trailing_pm_ins->module_inputs().front();
// Whitelist pointwise operators. // Whitelist pointwise operators.
if(std::any_of(trailing_pm->begin(), trailing_pm->end(), [&](const auto& i) { if(std::any_of(trailing_pm->begin(), trailing_pm->end(), [&](const auto& i) {
return not is_pointwise_op_supported_by_mlir(i); return not is_pointwise_op_supported_by_mlir(i);
...@@ -525,7 +551,8 @@ struct find_mlir_attention_fused_ops : public find_mlir_standalone_attention_op ...@@ -525,7 +551,8 @@ struct find_mlir_attention_fused_ops : public find_mlir_standalone_attention_op
void apply(module_pass_manager& mpm, const match::matcher_result& r) const void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{ {
if(!check(r)){ if(!check(r))
{
return; return;
} }
rewrite(mpm, r); rewrite(mpm, r);
...@@ -547,7 +574,7 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_USE_SPECIFIC_OPS); ...@@ -547,7 +574,7 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_USE_SPECIFIC_OPS);
bool is_requested(std::string_view option, bool fallback = false) bool is_requested(std::string_view option, bool fallback = false)
{ {
auto string_value = string_value_of(MIGRAPHX_MLIR_USE_SPECIFIC_OPS{}, ""); auto string_value = string_value_of(MIGRAPHX_MLIR_USE_SPECIFIC_OPS{}, "");
if(string_value.empty()) if(string_value.empty())
return fallback; return fallback;
const auto options = split_string(string_value, ','); const auto options = split_string(string_value, ',');
...@@ -575,8 +602,7 @@ void fuse_mlir::apply(module_pass_manager& mpm) const ...@@ -575,8 +602,7 @@ void fuse_mlir::apply(module_pass_manager& mpm) const
(enabled(MIGRAPHX_ENABLE_EXTRA_MLIR{}) or enable_extra) ? mlir_mode::fast : mlir_mode::none; (enabled(MIGRAPHX_ENABLE_EXTRA_MLIR{}) or enable_extra) ? mlir_mode::fast : mlir_mode::none;
// Attention offloads; default disabled // Attention offloads; default disabled
match::find_matches(mpm, match::find_matches(mpm, find_mlir_attention_fused_ops{get_mode("attention", mlir_mode::none)});
find_mlir_attention_fused_ops{get_mode("attention", mlir_mode::none)});
match::find_matches(mpm, match::find_matches(mpm,
find_mlir_standalone_attention_op{get_mode("attention", mlir_mode::none)}); find_mlir_standalone_attention_op{get_mode("attention", mlir_mode::none)});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment