Commit d976ecfe authored by Paul
Browse files

Add missing transpose

parent 8724e73f
......@@ -299,11 +299,19 @@ struct find_dot_mul
}
auto broadcast_v = d_ins->get_operator().to_value();
broadcast_v["out_lens"] = c_ins->get_shape().lens();
auto c_lens = c_ins->get_shape().lens();
std::vector<int64_t> permutation(c_lens.size());
std::iota(permutation.begin(), permutation.end(), 0);
if(c_ins == b_ins)
{
std::swap(permutation.back(), permutation[permutation.size() - 2]);
c_lens = reorder_dims(c_lens, permutation);
}
broadcast_v["out_lens"] = c_lens;
auto db_ins =
m.insert_instruction(ins, make_op(d_ins->name(), broadcast_v), d_ins->inputs());
auto cd_ins = m.insert_instruction(ins, make_op("mul"), c_ins, db_ins);
auto db_transpose_ins = m.insert_instruction(ins, make_op("transpose", {{"permutation", permutation}}), db_ins);
auto cd_ins = m.insert_instruction(ins, make_op("mul"), c_ins, db_transpose_ins);
if(c_ins == b_ins)
{
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment