Commit e213e87e authored by Paul's avatar Paul
Browse files

Format

parent 73f78d8d
......@@ -799,25 +799,23 @@ struct find_transpose_slice
struct find_reshape_gemm
{
auto matcher() const
{
return match::name("reshape")(match::arg(0)(match::name("dot")));
}
auto matcher() const { return match::name("reshape")(match::arg(0)(match::name("dot"))); }
static bool is_batched_unsqueeze(instruction_ref ins)
{
auto input = ins->inputs().front()->get_shape().lens();
auto output = ins->get_shape().lens();
if (output.size() <= input.size())
if(output.size() <= input.size())
return false;
if (not std::equal(input.end() - 2, input.end(), output.end() - 2, output.end()))
if(not std::equal(input.end() - 2, input.end(), output.end() - 2, output.end()))
return false;
return true;
}
static operation make_reshape(std::vector<std::size_t> batches, instruction_ref ins)
{
batches.insert(batches.end(), ins->get_shape().lens().end() - 2, ins->get_shape().lens().end());
batches.insert(
batches.end(), ins->get_shape().lens().end() - 2, ins->get_shape().lens().end());
return make_op("reshape", {{"dims", batches}});
}
......@@ -827,14 +825,18 @@ struct find_reshape_gemm
auto dot_ins = reshape_ins->inputs().front();
// TODO: Put this in the matcher
if (not is_batched_unsqueeze(reshape_ins))
if(not is_batched_unsqueeze(reshape_ins))
return;
std::vector<std::size_t> batches;
std::copy(reshape_ins->get_shape().lens().begin(), reshape_ins->get_shape().lens().end() - 2, std::back_inserter(batches));
auto input0 = m.insert_instruction(dot_ins, make_reshape(batches, dot_ins->inputs()[0]), dot_ins->inputs()[0]);
auto input1 = m.insert_instruction(dot_ins, make_reshape(batches, dot_ins->inputs()[1]), dot_ins->inputs()[1]);
std::copy(reshape_ins->get_shape().lens().begin(),
reshape_ins->get_shape().lens().end() - 2,
std::back_inserter(batches));
auto input0 = m.insert_instruction(
dot_ins, make_reshape(batches, dot_ins->inputs()[0]), dot_ins->inputs()[0]);
auto input1 = m.insert_instruction(
dot_ins, make_reshape(batches, dot_ins->inputs()[1]), dot_ins->inputs()[1]);
m.replace_instruction(dot_ins, make_op("dot"), input0, input1);
}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment