Commit ad6dcf26 authored by Manupa Karunaratne's avatar Manupa Karunaratne
Browse files

* Add input fusion for the `v` input of the attention op

parent f69d828d
...@@ -64,14 +64,14 @@ struct mlir_op ...@@ -64,14 +64,14 @@ struct mlir_op
shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
{ {
module_ref mod = mods[0];
check_shapes{inputs, *this}.packed_or_broadcasted(); check_shapes{inputs, *this}.packed_or_broadcasted();
if(mods.size() != 1) if(mods.size() != 1)
MIGRAPHX_THROW("should have one submodule."); MIGRAPHX_THROW("should have one submodule.");
if(inputs.size() < 2) if(inputs.size() < 2)
MIGRAPHX_THROW("should have at least two inputs."); MIGRAPHX_THROW("should have at least two inputs.");
module_ref mod = mods[0]; auto type = mod->get_output_shapes().front().type();
auto type = mod->get_output_shapes().front().type();
std::unordered_map<instruction_ref, shape> ins_shapes; std::unordered_map<instruction_ref, shape> ins_shapes;
for(auto ins : iterator_for(*mod)) for(auto ins : iterator_for(*mod))
{ {
...@@ -101,6 +101,27 @@ struct mlir_op ...@@ -101,6 +101,27 @@ struct mlir_op
MIGRAPHX_REGISTER_OP(mlir_op); MIGRAPHX_REGISTER_OP(mlir_op);
namespace { namespace {
// Walks up the producer chain of `lower_input`, collecting the stream of
// shape-only ops (slice/transpose/contiguous/reshape and reshape-like ops)
// that can be fused into an MLIR kernel's input.
//
// Returns {topmost non-fusable instruction, ops in bottom-up order}.
// squeeze/flatten/unsqueeze are normalized to an equivalent explicit
// `reshape` (using the instruction's own output dims) so the fused module
// only needs to replay one reshape-like op.
std::tuple<instruction_ref, std::vector<operation>>
get_fusable_input_op_stream(instruction_ref lower_input)
{
    std::vector<operation> collected_ops;
    instruction_ref current = lower_input;
    for(;;)
    {
        const auto& op_name = current->name();
        if(not contains(
               {"slice", "transpose", "contiguous", "reshape", "squeeze", "flatten", "unsqueeze"},
               op_name))
            break;
        if(contains({"squeeze", "flatten", "unsqueeze"}, op_name))
        {
            // Replace reshape-like ops with an explicit reshape to this
            // instruction's output dims; the effect on the data is identical.
            collected_ops.push_back(
                migraphx::make_op("reshape", {{"dims", current->get_shape().lens()}}));
        }
        else
        {
            collected_ops.push_back(current->get_operator());
        }
        // Shape-only ops are unary; follow their single input upward.
        current = current->inputs().at(0);
    }
    return {current, collected_ops};
}
std::tuple<instruction_ref, std::vector<instruction_ref>> std::tuple<instruction_ref, std::vector<instruction_ref>>
fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op) fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op)
{ {
...@@ -109,22 +130,10 @@ fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op) ...@@ -109,22 +130,10 @@ fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op)
size_t input_cnt = 0; size_t input_cnt = 0;
for(instruction_ref input : gemm_based_op->inputs()) for(instruction_ref input : gemm_based_op->inputs())
{ {
std::vector<operation> op_stream; auto [upper_input, op_stream] = get_fusable_input_op_stream(input);
while(contains( top_inputs.push_back(upper_input);
{"slice", "transpose", "contiguous", "reshape", "squeeze", "flatten", "unsqueeze"},
input->name()))
{
operation op = input->get_operator();
if(contains({"squeeze", "flatten", "unsqueeze"}, input->name()))
{
op = migraphx::make_op("reshape", {{"dims", input->get_shape().lens()}});
}
op_stream.push_back(op);
input = input->inputs().at(0);
}
top_inputs.push_back(input);
instruction_ref prev_input = instruction_ref prev_input =
mm->add_parameter("y" + std::to_string(input_cnt++), input->get_shape()); mm->add_parameter("y" + std::to_string(input_cnt++), upper_input->get_shape());
for(const auto& op : reverse(op_stream)) for(const auto& op : reverse(op_stream))
{ {
prev_input = mm->add_instruction(op, {prev_input}); prev_input = mm->add_instruction(op, {prev_input});
...@@ -424,9 +433,16 @@ struct find_mlir_standalone_attention_op ...@@ -424,9 +433,16 @@ struct find_mlir_standalone_attention_op
{ {
return std::make_pair(old_ins, softmax); return std::make_pair(old_ins, softmax);
} }
inputs.push_back(old_ins); auto [old_upper_ins, op_stream] = get_fusable_input_op_stream(old_ins);
return std::make_pair(old_ins, instruction_ref new_upper_ins =
mm->add_parameter("v", old_ins->get_shape())); mm->add_parameter("v", old_upper_ins->get_shape());
instruction_ref prev_input = new_upper_ins;
for(const auto& op : reverse(op_stream))
{
prev_input = mm->add_instruction(op, {prev_input});
}
inputs.push_back(old_upper_ins);
return std::make_pair(old_ins, prev_input);
}); });
auto gemm1_a = ins_map[r.instructions["gemm1"]->inputs().front()]; auto gemm1_a = ins_map[r.instructions["gemm1"]->inputs().front()];
auto gemm1_b = ins_map[r.instructions["gemm1"]->inputs().back()]; auto gemm1_b = ins_map[r.instructions["gemm1"]->inputs().back()];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment