"host/vscode:/vscode.git/clone" did not exist on "12dfba3d03f402c051e2129fa21f33264f4d26e5"
Commit 69178b49 authored by Khalique Ahmed
Browse files

update rewrite gemm

parent 035a04eb
...@@ -45,10 +45,31 @@ void rewrite_gemm::apply(module& m) const ...@@ -45,10 +45,31 @@ void rewrite_gemm::apply(module& m) const
continue; continue;
auto inputs = ins->inputs(); auto inputs = ins->inputs();
auto in0 = inputs.at(0); auto in0 = inputs.at(0);
if(in0->get_shape().lens().at(0) != 1) // only batch size = 1 if(in0->get_shape().lens().at(0) != 1) // only batch size = 1
continue; continue;
auto in_size = in0->get_shape().lens().size();
if(in_size == 4 and in0->get_shape().lens().at(1) != 1)
{
continue;
}
auto in1 = inputs.at(1); auto in1 = inputs.at(1);
if(in_size < 4)
{
std::vector<size_t> new_lens0(in0->get_shape().lens().begin(), in0->get_shape().lens().end());
std::vector<size_t> new_lens1(in1->get_shape().lens().begin(), in1->get_shape().lens().end());
std::vector<size_t> ones(4-in_size, 1);
new_lens0.insert(new_lens0.begin(), ones.begin(), ones.end());
new_lens1.insert(new_lens1.begin(), ones.begin(), ones.end());
if(not in0->get_shape().standard())
in0 = m.insert_instruction(in0, make_op("contiguous"), in0);
in0 = m.insert_instruction(in0, make_op("reshape", {{"dims", new_lens0}}), in0);
if(not in1->get_shape().standard())
in1 = m.insert_instruction(in1, make_op("contiguous"), in1);
in1 = m.insert_instruction(in1, make_op("reshape", {{"dims", new_lens1}}), in1);
}
auto in0_transposed = auto in0_transposed =
m.insert_instruction(ins, make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), in0); m.insert_instruction(ins, make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), in0);
auto in1_transposed = auto in1_transposed =
...@@ -57,8 +78,14 @@ void rewrite_gemm::apply(module& m) const ...@@ -57,8 +78,14 @@ void rewrite_gemm::apply(module& m) const
auto conv = make_op("convolution"); auto conv = make_op("convolution");
auto conv_out = m.replace_instruction(ins, conv, {in0_transposed, in1_transposed}); auto conv_out = m.replace_instruction(ins, conv, {in0_transposed, in1_transposed});
auto conv_transpose = auto conv_transpose =
m.add_instruction(make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), conv_out); m.insert_instruction(std::next(conv_out), make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), conv_out);
// m.insert_instruction(std::next(conv_transpose), make_op("unsqueeze"));
auto out_lens = conv_transpose->get_shape().lens();
if(out_lens.size() != in_size)
{
out_lens.erase(out_lens.begin(), out_lens.begin()+(out_lens.size()-in_size));
m.insert_instruction(std::next(conv_transpose), make_op("reshape", {{"dims", out_lens}}), conv_transpose);
}
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment