Commit ad6dcf26 authored by Manupa Karunaratne's avatar Manupa Karunaratne
Browse files

* add input fusion for v input

parent f69d828d
......@@ -64,14 +64,14 @@ struct mlir_op
shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
{
module_ref mod = mods[0];
check_shapes{inputs, *this}.packed_or_broadcasted();
if(mods.size() != 1)
MIGRAPHX_THROW("should have one submodule.");
if(inputs.size() < 2)
MIGRAPHX_THROW("should have at least two inputs.");
module_ref mod = mods[0];
auto type = mod->get_output_shapes().front().type();
auto type = mod->get_output_shapes().front().type();
std::unordered_map<instruction_ref, shape> ins_shapes;
for(auto ins : iterator_for(*mod))
{
......@@ -101,6 +101,27 @@ struct mlir_op
MIGRAPHX_REGISTER_OP(mlir_op);
namespace {
std::tuple<instruction_ref, std::vector<operation>>
get_fusable_input_op_stream(instruction_ref lower_input)
{
    // Walk upward from lower_input through the chain of shape-only operators,
    // collecting each one (in lower-to-upper order) so a caller can replay the
    // chain inside a fused submodule. Returns the first non-fusable producer
    // instruction together with the collected operator stream.
    std::vector<operation> op_stream;
    instruction_ref current = lower_input;
    for(;;)
    {
        const auto& op_name = current->name();
        if(not contains(
               {"slice", "transpose", "contiguous", "reshape", "squeeze", "flatten", "unsqueeze"},
               op_name))
            break;
        if(contains({"squeeze", "flatten", "unsqueeze"}, op_name))
        {
            // Canonicalize dimension-only reshapes into an explicit reshape op
            // carrying the concrete output dims of this instruction.
            op_stream.push_back(
                migraphx::make_op("reshape", {{"dims", current->get_shape().lens()}}));
        }
        else
        {
            op_stream.push_back(current->get_operator());
        }
        current = current->inputs().at(0);
    }
    return {current, op_stream};
}
std::tuple<instruction_ref, std::vector<instruction_ref>>
fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op)
{
......@@ -109,22 +130,10 @@ fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op)
size_t input_cnt = 0;
for(instruction_ref input : gemm_based_op->inputs())
{
std::vector<operation> op_stream;
while(contains(
{"slice", "transpose", "contiguous", "reshape", "squeeze", "flatten", "unsqueeze"},
input->name()))
{
operation op = input->get_operator();
if(contains({"squeeze", "flatten", "unsqueeze"}, input->name()))
{
op = migraphx::make_op("reshape", {{"dims", input->get_shape().lens()}});
}
op_stream.push_back(op);
input = input->inputs().at(0);
}
top_inputs.push_back(input);
auto [upper_input, op_stream] = get_fusable_input_op_stream(input);
top_inputs.push_back(upper_input);
instruction_ref prev_input =
mm->add_parameter("y" + std::to_string(input_cnt++), input->get_shape());
mm->add_parameter("y" + std::to_string(input_cnt++), upper_input->get_shape());
for(const auto& op : reverse(op_stream))
{
prev_input = mm->add_instruction(op, {prev_input});
......@@ -424,9 +433,16 @@ struct find_mlir_standalone_attention_op
{
return std::make_pair(old_ins, softmax);
}
inputs.push_back(old_ins);
return std::make_pair(old_ins,
mm->add_parameter("v", old_ins->get_shape()));
auto [old_upper_ins, op_stream] = get_fusable_input_op_stream(old_ins);
instruction_ref new_upper_ins =
mm->add_parameter("v", old_upper_ins->get_shape());
instruction_ref prev_input = new_upper_ins;
for(const auto& op : reverse(op_stream))
{
prev_input = mm->add_instruction(op, {prev_input});
}
inputs.push_back(old_upper_ins);
return std::make_pair(old_ins, prev_input);
});
auto gemm1_a = ins_map[r.instructions["gemm1"]->inputs().front()];
auto gemm1_b = ins_map[r.instructions["gemm1"]->inputs().back()];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment