Commit 26574e25 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

update rewrite_gemm

parent 95b5efeb
......@@ -46,8 +46,9 @@ void rewrite_gemm::apply(module& m) const
auto inputs = ins->inputs();
auto in0 = inputs.at(0);
if(in0->get_shape().lens().at(0) != 1) // only batch size = 1
continue;
if(in0->get_shape().lens().size() > 2)
if(in0->get_shape().lens().at(0) != 1) // only batch size = 1
continue;
auto in_size = in0->get_shape().lens().size();
if(in_size == 4 and in0->get_shape().lens().at(1) != 1)
{
......@@ -65,11 +66,11 @@ void rewrite_gemm::apply(module& m) const
new_lens0.insert(new_lens0.begin(), ones.begin(), ones.end());
new_lens1.insert(new_lens1.begin(), ones.begin(), ones.end());
if(not in0->get_shape().standard())
in0 = m.insert_instruction(in0, make_op("contiguous"), in0);
in0 = m.insert_instruction(in0, make_op("reshape", {{"dims", new_lens0}}), in0);
in0 = m.insert_instruction(ins, make_op("contiguous"), in0);
in0 = m.insert_instruction(ins, make_op("reshape", {{"dims", new_lens0}}), in0);
if(not in1->get_shape().standard())
in1 = m.insert_instruction(in1, make_op("contiguous"), in1);
in1 = m.insert_instruction(in1, make_op("reshape", {{"dims", new_lens1}}), in1);
in1 = m.insert_instruction(ins, make_op("contiguous"), in1);
in1 = m.insert_instruction(ins, make_op("reshape", {{"dims", new_lens1}}), in1);
}
auto in0_transposed =
......@@ -77,19 +78,20 @@ void rewrite_gemm::apply(module& m) const
auto in1_transposed =
m.insert_instruction(ins, make_op("transpose", {{"permutation", {3, 2, 1, 0}}}), in1);
auto conv = make_op("convolution");
auto conv_out = m.replace_instruction(ins, conv, {in0_transposed, in1_transposed});
auto conv = m.insert_instruction(ins, make_op("convolution"), {in0_transposed, in1_transposed});
auto conv_transpose = m.insert_instruction(
std::next(conv_out), make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), conv_out);
ins, make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), conv);
auto out_lens = conv_transpose->get_shape().lens();
auto conv_transpose_out = conv_transpose;
if(out_lens.size() != in_size)
{
out_lens.erase(out_lens.begin(), out_lens.begin() + (out_lens.size() - in_size));
m.insert_instruction(std::next(conv_transpose),
conv_transpose_out = m.insert_instruction(ins,
make_op("reshape", {{"dims", out_lens}}),
conv_transpose);
}
m.replace_instruction(ins, conv_transpose_out);
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment