Commit 1a9924c0 authored by turneram

Formatting

parent 540e262f
@@ -42,70 +42,77 @@ void rewrite_batched_gemms::apply(module& m) const
     {
         if(ins->name() != "dot")
             continue;
-        //std::cout << "Rewrite Batched GEMMS" << std::endl;
-        //ins->debug_print();
-        //m.debug_print();
-        //return;
+        // std::cout << "Rewrite Batched GEMMS" << std::endl;
+        // ins->debug_print();
+        // m.debug_print();
+        // return;
         auto inputs = ins->inputs();
         auto a_mat = inputs.front();
         auto b_mat = inputs.at(1); //.back()?
         auto a_lens = a_mat->get_shape().lens();
         auto b_lens = b_mat->get_shape().lens();
-        if (a_lens.size() > 2)
+        if(a_lens.size() > 2)
         {
             auto batch_size = std::accumulate(
                 a_lens.rbegin() + 2, a_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
-            auto reshape_a = m.insert_instruction(ins, make_op("reshape", {{"dims", {batch_size * a_lens[a_lens.size() - 2], a_lens.back()}}}), a_mat);
-            //reshape_a->debug_print();
-            //std::cout << b_mat->get_operator().name() << std::endl;
+            auto reshape_a = m.insert_instruction(
+                ins,
+                make_op("reshape",
+                        {{"dims", {batch_size * a_lens[a_lens.size() - 2], a_lens.back()}}}),
+                a_mat);
+            // reshape_a->debug_print();
+            // std::cout << b_mat->get_operator().name() << std::endl;
             instruction_ref unbc_b;
-            if (b_mat->get_operator().name() == "concat")
+            if(b_mat->get_operator().name() == "concat")
             {
                 auto concat_inputs = b_mat->inputs();
                 std::vector<instruction_ref> concat_lits;
                 int concat_axis = 1;
                 bool return_early = false;
-                for (auto c : concat_inputs)
+                for(auto c : concat_inputs)
                 {
-                    if (c->get_operator().name() == "contiguous")
+                    if(c->get_operator().name() == "contiguous")
                         c = c->inputs().front();
-                    //std::cout << c->get_operator().name() << ", " << c->get_shape() << ", " << c->get_shape().broadcasted() <<std::endl;
-                    if (c->get_shape().broadcasted())
+                    // std::cout << c->get_operator().name() << ", " << c->get_shape() << ", " <<
+                    // c->get_shape().broadcasted() <<std::endl;
+                    if(c->get_shape().broadcasted())
                     {
-                        //std::cout << c->inputs().front()->get_operator() <<std::endl;
+                        // std::cout << c->inputs().front()->get_operator() <<std::endl;
                         auto lit = c->inputs().front();
                         auto lit_dims = lit->get_shape().lens().size();
-                        if (lit_dims > 2)
+                        if(lit_dims > 2)
                             return_early = true;
                         concat_axis = lit_dims - 1;
                         concat_lits.push_back(lit);
                     }
                 }
-                if (return_early)
+                if(return_early)
                     continue;
-                unbc_b = m.insert_instruction(ins, make_op("concat", {{"axis", concat_axis}}), concat_lits);
+                unbc_b = m.insert_instruction(
+                    ins, make_op("concat", {{"axis", concat_axis}}), concat_lits);
             }
-            else if (b_mat->get_operator().name() == "contiguous")
+            else if(b_mat->get_operator().name() == "contiguous")
            {
-                //std::cout << "Contiguous B" <<std::endl;
-                //b_mat->debug_print();
+                // std::cout << "Contiguous B" <<std::endl;
+                // b_mat->debug_print();
                 auto b_input = b_mat->inputs().front();
-                //std::cout << b_input->get_operator().name() << ", " << b_input->get_shape().broadcasted() << ", " << b_input->can_eval() << std::endl;
-                if (b_input->get_shape().broadcasted())
+                // std::cout << b_input->get_operator().name() << ", " <<
+                // b_input->get_shape().broadcasted() << ", " << b_input->can_eval() << std::endl;
+                if(b_input->get_shape().broadcasted())
                 {
                     auto lit = b_input->inputs().front();
                     auto lit_dims = lit->get_shape().lens().size();
-                    if (lit_dims > 2)
+                    if(lit_dims > 2)
                         continue;
                     unbc_b = lit;
-                    //unbc_b->debug_print();
+                    // unbc_b->debug_print();
                 }
                 else
                     continue;
             }
             else
             {
-                //std::cout << "Else" << std::endl;
+                // std::cout << "Else" << std::endl;
                 continue;
             }
             auto new_dot = m.insert_instruction(ins, make_op("dot"), reshape_a, unbc_b);
@@ -113,24 +120,25 @@ void rewrite_batched_gemms::apply(module& m) const
             out_lens.pop_back();
             out_lens.push_back(b_lens.back());
-            //std::cout << std::next(ins)->get_operator().name() << std::endl;
+            // std::cout << std::next(ins)->get_operator().name() << std::endl;
             auto next_ins = std::next(ins);
-            if (next_ins->get_operator().name() == "add")
+            if(next_ins->get_operator().name() == "add")
             {
-                auto add_in = next_ins->inputs().back() == ins ? next_ins->inputs().front() : next_ins->inputs().back();
-                //add_in->debug_print();
-                auto reshape_add = m.insert_instruction(next_ins, make_op("reshape", {{"dims", {batch_size * a_lens[a_lens.size() - 2], b_lens.back()}}}), add_in);
+                auto add_in = next_ins->inputs().back() == ins ? next_ins->inputs().front()
+                                                               : next_ins->inputs().back();
+                // add_in->debug_print();
+                auto reshape_add = m.insert_instruction(
+                    next_ins,
+                    make_op("reshape",
+                            {{"dims", {batch_size * a_lens[a_lens.size() - 2], b_lens.back()}}}),
+                    add_in);
                 new_dot = m.replace_instruction(next_ins, make_op("add"), reshape_add, new_dot);
             }
-            //std::cout << "here" <<std::endl;
+            // std::cout << "here" <<std::endl;
             m.replace_instruction(ins, make_op("reshape", {{"dims", out_lens}}), new_dot);
         }
-        //m.debug_print();
+        // m.debug_print();
     }
 }
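Note on what the reformatted pass does (context only, not part of the diff above): when the A operand of a dot has more than two dimensions, the pass folds every leading dimension into the row count, performs a single 2-D dot against the un-broadcast B literal, and reshapes the result back to the batched output shape. A minimal standalone sketch of that shape arithmetic, mirroring the batch_size and reshape expressions in the code above (the helper name flatten_gemm_dims is hypothetical):

#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

// Collapse [d0, ..., dk, M, K] into the 2-D dims {d0*...*dk*M, K} used by the
// "reshape" inserted before the dot; same expression as in the pass above.
std::vector<std::size_t> flatten_gemm_dims(const std::vector<std::size_t>& a_lens)
{
    // Product of every dimension except the trailing [M, K] pair.
    auto batch_size = std::accumulate(a_lens.rbegin() + 2,
                                      a_lens.rend(),
                                      std::size_t{1},
                                      std::multiplies<std::size_t>());
    return {batch_size * a_lens[a_lens.size() - 2], a_lens.back()};
}

// Example: a_lens = {4, 8, 64, 32} -> {4 * 8 * 64, 32} = {2048, 32}.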