Commit 1a9924c0 authored by turneram
Browse files

Formatting

parent 540e262f
......@@ -42,70 +42,77 @@ void rewrite_batched_gemms::apply(module& m) const
{
if(ins->name() != "dot")
continue;
//std::cout << "Rewrite Batched GEMMS" << std::endl;
//ins->debug_print();
//m.debug_print();
//return;
// std::cout << "Rewrite Batched GEMMS" << std::endl;
// ins->debug_print();
// m.debug_print();
// return;
auto inputs = ins->inputs();
auto a_mat = inputs.front();
auto b_mat = inputs.at(1); //.back()?
auto a_lens = a_mat->get_shape().lens();
auto b_lens = b_mat->get_shape().lens();
if (a_lens.size() > 2)
if(a_lens.size() > 2)
{
auto batch_size = std::accumulate(
a_lens.rbegin() + 2, a_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
auto reshape_a = m.insert_instruction(ins, make_op("reshape", {{"dims", {batch_size * a_lens[a_lens.size() - 2], a_lens.back()}}}), a_mat);
//reshape_a->debug_print();
//std::cout << b_mat->get_operator().name() << std::endl;
auto reshape_a = m.insert_instruction(
ins,
make_op("reshape",
{{"dims", {batch_size * a_lens[a_lens.size() - 2], a_lens.back()}}}),
a_mat);
// reshape_a->debug_print();
// std::cout << b_mat->get_operator().name() << std::endl;
instruction_ref unbc_b;
if (b_mat->get_operator().name() == "concat")
if(b_mat->get_operator().name() == "concat")
{
auto concat_inputs = b_mat->inputs();
std::vector<instruction_ref> concat_lits;
int concat_axis = 1;
bool return_early = false;
for (auto c : concat_inputs)
for(auto c : concat_inputs)
{
if (c->get_operator().name() == "contiguous")
if(c->get_operator().name() == "contiguous")
c = c->inputs().front();
//std::cout << c->get_operator().name() << ", " << c->get_shape() << ", " << c->get_shape().broadcasted() <<std::endl;
if (c->get_shape().broadcasted())
// std::cout << c->get_operator().name() << ", " << c->get_shape() << ", " <<
// c->get_shape().broadcasted() <<std::endl;
if(c->get_shape().broadcasted())
{
//std::cout << c->inputs().front()->get_operator() <<std::endl;
// std::cout << c->inputs().front()->get_operator() <<std::endl;
auto lit = c->inputs().front();
auto lit_dims = lit->get_shape().lens().size();
if (lit_dims > 2)
if(lit_dims > 2)
return_early = true;
concat_axis = lit_dims - 1;
concat_lits.push_back(lit);
}
}
if (return_early)
if(return_early)
continue;
unbc_b = m.insert_instruction(ins, make_op("concat", {{"axis", concat_axis}}), concat_lits);
unbc_b = m.insert_instruction(
ins, make_op("concat", {{"axis", concat_axis}}), concat_lits);
}
else if (b_mat->get_operator().name() == "contiguous")
else if(b_mat->get_operator().name() == "contiguous")
{
//std::cout << "Contiguous B" <<std::endl;
//b_mat->debug_print();
// std::cout << "Contiguous B" <<std::endl;
// b_mat->debug_print();
auto b_input = b_mat->inputs().front();
//std::cout << b_input->get_operator().name() << ", " << b_input->get_shape().broadcasted() << ", " << b_input->can_eval() << std::endl;
if (b_input->get_shape().broadcasted())
// std::cout << b_input->get_operator().name() << ", " <<
// b_input->get_shape().broadcasted() << ", " << b_input->can_eval() << std::endl;
if(b_input->get_shape().broadcasted())
{
auto lit = b_input->inputs().front();
auto lit_dims = lit->get_shape().lens().size();
if (lit_dims > 2)
if(lit_dims > 2)
continue;
unbc_b = lit;
//unbc_b->debug_print();
// unbc_b->debug_print();
}
else
continue;
}
else
{
//std::cout << "Else" << std::endl;
// std::cout << "Else" << std::endl;
continue;
}
auto new_dot = m.insert_instruction(ins, make_op("dot"), reshape_a, unbc_b);
......@@ -113,24 +120,25 @@ void rewrite_batched_gemms::apply(module& m) const
out_lens.pop_back();
out_lens.push_back(b_lens.back());
//std::cout << std::next(ins)->get_operator().name() << std::endl;
// std::cout << std::next(ins)->get_operator().name() << std::endl;
auto next_ins = std::next(ins);
if (next_ins->get_operator().name() == "add")
if(next_ins->get_operator().name() == "add")
{
auto add_in = next_ins->inputs().back() == ins ? next_ins->inputs().front() : next_ins->inputs().back();
//add_in->debug_print();
auto reshape_add = m.insert_instruction(next_ins, make_op("reshape", {{"dims", {batch_size * a_lens[a_lens.size() - 2], b_lens.back()}}}), add_in);
auto add_in = next_ins->inputs().back() == ins ? next_ins->inputs().front()
: next_ins->inputs().back();
// add_in->debug_print();
auto reshape_add = m.insert_instruction(
next_ins,
make_op("reshape",
{{"dims", {batch_size * a_lens[a_lens.size() - 2], b_lens.back()}}}),
add_in);
new_dot = m.replace_instruction(next_ins, make_op("add"), reshape_add, new_dot);
}
//std::cout << "here" <<std::endl;
// std::cout << "here" <<std::endl;
m.replace_instruction(ins, make_op("reshape", {{"dims", out_lens}}), new_dot);
}
//m.debug_print();
// m.debug_print();
}
}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment