Commit 5e22d800 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

formatting

parent 69178b49
...@@ -57,9 +57,11 @@ void rewrite_gemm::apply(module& m) const ...@@ -57,9 +57,11 @@ void rewrite_gemm::apply(module& m) const
if(in_size < 4) if(in_size < 4)
{ {
std::vector<size_t> new_lens0(in0->get_shape().lens().begin(), in0->get_shape().lens().end()); std::vector<size_t> new_lens0(in0->get_shape().lens().begin(),
std::vector<size_t> new_lens1(in1->get_shape().lens().begin(), in1->get_shape().lens().end()); in0->get_shape().lens().end());
std::vector<size_t> ones(4-in_size, 1); std::vector<size_t> new_lens1(in1->get_shape().lens().begin(),
in1->get_shape().lens().end());
std::vector<size_t> ones(4 - in_size, 1);
new_lens0.insert(new_lens0.begin(), ones.begin(), ones.end()); new_lens0.insert(new_lens0.begin(), ones.begin(), ones.end());
new_lens1.insert(new_lens1.begin(), ones.begin(), ones.end()); new_lens1.insert(new_lens1.begin(), ones.begin(), ones.end());
if(not in0->get_shape().standard()) if(not in0->get_shape().standard())
...@@ -75,16 +77,18 @@ void rewrite_gemm::apply(module& m) const ...@@ -75,16 +77,18 @@ void rewrite_gemm::apply(module& m) const
auto in1_transposed = auto in1_transposed =
m.insert_instruction(ins, make_op("transpose", {{"permutation", {3, 2, 1, 0}}}), in1); m.insert_instruction(ins, make_op("transpose", {{"permutation", {3, 2, 1, 0}}}), in1);
auto conv = make_op("convolution"); auto conv = make_op("convolution");
auto conv_out = m.replace_instruction(ins, conv, {in0_transposed, in1_transposed}); auto conv_out = m.replace_instruction(ins, conv, {in0_transposed, in1_transposed});
auto conv_transpose = auto conv_transpose = m.insert_instruction(
m.insert_instruction(std::next(conv_out), make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), conv_out); std::next(conv_out), make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), conv_out);
auto out_lens = conv_transpose->get_shape().lens(); auto out_lens = conv_transpose->get_shape().lens();
if(out_lens.size() != in_size) if(out_lens.size() != in_size)
{ {
out_lens.erase(out_lens.begin(), out_lens.begin()+(out_lens.size()-in_size)); out_lens.erase(out_lens.begin(), out_lens.begin() + (out_lens.size() - in_size));
m.insert_instruction(std::next(conv_transpose), make_op("reshape", {{"dims", out_lens}}), conv_transpose); m.insert_instruction(std::next(conv_transpose),
make_op("reshape", {{"dims", out_lens}}),
conv_transpose);
} }
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment