Commit 26574e25 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

update rewrite_gemm

parent 95b5efeb
...@@ -46,8 +46,9 @@ void rewrite_gemm::apply(module& m) const ...@@ -46,8 +46,9 @@ void rewrite_gemm::apply(module& m) const
auto inputs = ins->inputs(); auto inputs = ins->inputs();
auto in0 = inputs.at(0); auto in0 = inputs.at(0);
if(in0->get_shape().lens().at(0) != 1) // only batch size = 1 if(in0->get_shape().lens().size() > 2)
continue; if(in0->get_shape().lens().at(0) != 1) // only batch size = 1
continue;
auto in_size = in0->get_shape().lens().size(); auto in_size = in0->get_shape().lens().size();
if(in_size == 4 and in0->get_shape().lens().at(1) != 1) if(in_size == 4 and in0->get_shape().lens().at(1) != 1)
{ {
...@@ -65,11 +66,11 @@ void rewrite_gemm::apply(module& m) const ...@@ -65,11 +66,11 @@ void rewrite_gemm::apply(module& m) const
new_lens0.insert(new_lens0.begin(), ones.begin(), ones.end()); new_lens0.insert(new_lens0.begin(), ones.begin(), ones.end());
new_lens1.insert(new_lens1.begin(), ones.begin(), ones.end()); new_lens1.insert(new_lens1.begin(), ones.begin(), ones.end());
if(not in0->get_shape().standard()) if(not in0->get_shape().standard())
in0 = m.insert_instruction(in0, make_op("contiguous"), in0); in0 = m.insert_instruction(ins, make_op("contiguous"), in0);
in0 = m.insert_instruction(in0, make_op("reshape", {{"dims", new_lens0}}), in0); in0 = m.insert_instruction(ins, make_op("reshape", {{"dims", new_lens0}}), in0);
if(not in1->get_shape().standard()) if(not in1->get_shape().standard())
in1 = m.insert_instruction(in1, make_op("contiguous"), in1); in1 = m.insert_instruction(ins, make_op("contiguous"), in1);
in1 = m.insert_instruction(in1, make_op("reshape", {{"dims", new_lens1}}), in1); in1 = m.insert_instruction(ins, make_op("reshape", {{"dims", new_lens1}}), in1);
} }
auto in0_transposed = auto in0_transposed =
...@@ -77,19 +78,20 @@ void rewrite_gemm::apply(module& m) const ...@@ -77,19 +78,20 @@ void rewrite_gemm::apply(module& m) const
auto in1_transposed = auto in1_transposed =
m.insert_instruction(ins, make_op("transpose", {{"permutation", {3, 2, 1, 0}}}), in1); m.insert_instruction(ins, make_op("transpose", {{"permutation", {3, 2, 1, 0}}}), in1);
auto conv = make_op("convolution"); auto conv = m.insert_instruction(ins, make_op("convolution"), {in0_transposed, in1_transposed});
auto conv_out = m.replace_instruction(ins, conv, {in0_transposed, in1_transposed});
auto conv_transpose = m.insert_instruction( auto conv_transpose = m.insert_instruction(
std::next(conv_out), make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), conv_out); ins, make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), conv);
auto out_lens = conv_transpose->get_shape().lens(); auto out_lens = conv_transpose->get_shape().lens();
auto conv_transpose_out = conv_transpose;
if(out_lens.size() != in_size) if(out_lens.size() != in_size)
{ {
out_lens.erase(out_lens.begin(), out_lens.begin() + (out_lens.size() - in_size)); out_lens.erase(out_lens.begin(), out_lens.begin() + (out_lens.size() - in_size));
m.insert_instruction(std::next(conv_transpose), conv_transpose_out = m.insert_instruction(ins,
make_op("reshape", {{"dims", out_lens}}), make_op("reshape", {{"dims", out_lens}}),
conv_transpose); conv_transpose);
} }
m.replace_instruction(ins, conv_transpose_out);
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment