Commit e213e87e authored by Paul's avatar Paul
Browse files

Format

parent 73f78d8d
...@@ -799,25 +799,23 @@ struct find_transpose_slice ...@@ -799,25 +799,23 @@ struct find_transpose_slice
struct find_reshape_gemm struct find_reshape_gemm
{ {
auto matcher() const auto matcher() const { return match::name("reshape")(match::arg(0)(match::name("dot"))); }
{
return match::name("reshape")(match::arg(0)(match::name("dot")));
}
static bool is_batched_unsqueeze(instruction_ref ins) static bool is_batched_unsqueeze(instruction_ref ins)
{ {
auto input = ins->inputs().front()->get_shape().lens(); auto input = ins->inputs().front()->get_shape().lens();
auto output = ins->get_shape().lens(); auto output = ins->get_shape().lens();
if (output.size() <= input.size()) if(output.size() <= input.size())
return false; return false;
if (not std::equal(input.end() - 2, input.end(), output.end() - 2, output.end())) if(not std::equal(input.end() - 2, input.end(), output.end() - 2, output.end()))
return false; return false;
return true; return true;
} }
static operation make_reshape(std::vector<std::size_t> batches, instruction_ref ins) static operation make_reshape(std::vector<std::size_t> batches, instruction_ref ins)
{ {
batches.insert(batches.end(), ins->get_shape().lens().end() - 2, ins->get_shape().lens().end()); batches.insert(
batches.end(), ins->get_shape().lens().end() - 2, ins->get_shape().lens().end());
return make_op("reshape", {{"dims", batches}}); return make_op("reshape", {{"dims", batches}});
} }
...@@ -827,14 +825,18 @@ struct find_reshape_gemm ...@@ -827,14 +825,18 @@ struct find_reshape_gemm
auto dot_ins = reshape_ins->inputs().front(); auto dot_ins = reshape_ins->inputs().front();
// TODO: Put this in the matcher // TODO: Put this in the matcher
if (not is_batched_unsqueeze(reshape_ins)) if(not is_batched_unsqueeze(reshape_ins))
return; return;
std::vector<std::size_t> batches; std::vector<std::size_t> batches;
std::copy(reshape_ins->get_shape().lens().begin(), reshape_ins->get_shape().lens().end() - 2, std::back_inserter(batches)); std::copy(reshape_ins->get_shape().lens().begin(),
reshape_ins->get_shape().lens().end() - 2,
auto input0 = m.insert_instruction(dot_ins, make_reshape(batches, dot_ins->inputs()[0]), dot_ins->inputs()[0]); std::back_inserter(batches));
auto input1 = m.insert_instruction(dot_ins, make_reshape(batches, dot_ins->inputs()[1]), dot_ins->inputs()[1]);
auto input0 = m.insert_instruction(
dot_ins, make_reshape(batches, dot_ins->inputs()[0]), dot_ins->inputs()[0]);
auto input1 = m.insert_instruction(
dot_ins, make_reshape(batches, dot_ins->inputs()[1]), dot_ins->inputs()[1]);
m.replace_instruction(dot_ins, make_op("dot"), input0, input1); m.replace_instruction(dot_ins, make_op("dot"), input0, input1);
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment