"src/include/threadwise_gemm.hpp" did not exist on "2058bec8cfe6c006409ec3d65b67229ce1c2e6f7"
Commit 73f78d8d authored by Paul
Browse files

Improve handling of reshape around gemm

parent 69799cb8
......@@ -797,6 +797,48 @@ struct find_transpose_slice
}
};
// Rewrites reshape(dot(a, b)) into dot(reshape(a), reshape(b)) when the
// reshape only introduces extra leading (batch) dimensions and leaves the
// trailing two (matrix) dimensions of the dot output untouched.
struct find_reshape_gemm
{
    // Matches every "reshape" instruction whose first argument is a "dot".
    auto matcher() const
    {
        return match::name("reshape")(match::arg(0)(match::name("dot")));
    }

    // Returns true when `ins` produces more dimensions than its first input
    // while keeping the last two dimensions identical.
    static bool is_batched_unsqueeze(instruction_ref ins)
    {
        const auto input  = ins->inputs().front()->get_shape().lens();
        const auto output = ins->get_shape().lens();
        // The comparisons below step two elements back from end(); bail out
        // on inputs that are too short for that to be valid.
        if(input.size() < 2)
            return false;
        if(output.size() <= input.size())
            return false;
        return std::equal(input.end() - 2, input.end(), output.end() - 2, output.end());
    }

    // Builds a reshape operation whose dims are `batches` followed by the
    // last two dimensions of `ins`.
    static operation make_reshape(std::vector<std::size_t> batches, instruction_ref ins)
    {
        const auto lens = ins->get_shape().lens();
        batches.insert(batches.end(), lens.end() - 2, lens.end());
        return make_op("reshape", {{"dims", batches}});
    }

    // Inserts reshapes on both dot operands and replaces the matched reshape
    // with a dot over the reshaped operands.
    void apply(module& m, const match::matcher_result& r) const
    {
        auto reshape_ins = r.result;
        auto dot_ins     = reshape_ins->inputs().front();
        // TODO: Put this in the matcher
        if(not is_batched_unsqueeze(reshape_ins))
            return;

        // Every dimension of the reshape output except the trailing two.
        const auto out_lens = reshape_ins->get_shape().lens();
        std::vector<std::size_t> batches(out_lens.begin(), out_lens.end() - 2);

        auto lhs = dot_ins->inputs()[0];
        auto rhs = dot_ins->inputs()[1];
        // NOTE(review): assumes both operands can be reshaped directly to
        // batches + their own last two dims (same element count, layout
        // accepted by "reshape") -- confirm against the dot/reshape ops.
        auto input0 = m.insert_instruction(dot_ins, make_reshape(batches, lhs), lhs);
        auto input1 = m.insert_instruction(dot_ins, make_reshape(batches, rhs), rhs);

        // Replace the reshape rather than the dot: the new dot has the
        // reshape's output shape, so substituting it for the original dot
        // would change the shape seen by any other consumer of that dot.
        // The original dot is left in place and is removed by dead code
        // elimination once it has no remaining users.
        m.replace_instruction(reshape_ins, make_op("dot"), input0, input1);
    }
};
void simplify_reshapes::apply(module& m) const
{
for(int i = 0; i < 4; i++)
......@@ -815,7 +857,8 @@ void simplify_reshapes::apply(module& m) const
find_nested_concat{},
find_transpose_slice{},
find_slice_transpose{},
find_transpose_contiguous_reshaper_unary{});
find_transpose_contiguous_reshaper_unary{},
find_reshape_gemm{});
dead_code_elimination{}.apply(m);
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment