Commit 5a2b89fc authored by turneram's avatar turneram
Browse files

Only apply to bs>1; wip fuse dot add

parent ab236eec
......@@ -803,42 +803,59 @@ struct find_conv_dot_horiz_fusion
int concat_axis = 0;
if(name == "dot")
{
axis = int(args.front()->get_shape().lens().size() - 1);
concat_axis = axis - 1;
for(auto& arg : args)
std::cout << ins->get_shape().lens().front() << std::endl;
if(ins->get_shape().lens().front() > 1)
{
arg = arg->inputs().front();
m.move_instructions(arg, input);
axis = int(args.front()->get_shape().lens().size() - 1);
concat_axis = axis - 1;
for(auto& arg : args)
{
arg = arg->inputs().front();
m.move_instructions(arg, input);
}
// TODO: Check if axes match
auto concat =
m.insert_instruction(input, make_op("concat", {{"axis", concat_axis}}), args);
auto batch_size = input->get_shape().lens().front();
auto sequence_length = input->get_shape().lens().at(1);
auto hidden_size = input->get_shape().lens().at(2);
auto reshape = m.insert_instruction(
std::next(input),
make_op("reshape", {{"dims", {batch_size * sequence_length, hidden_size}}}),
input);
auto fused = m.insert_instruction(std::next(reshape), op, reshape, concat);
int64_t offset = 0;
std::vector<instruction_ref> add_args;
instruction_ref next_ins;
for(auto arg : range(start, last))
{
auto aarg = std::next(std::next(arg));
while(aarg->name() != "add")
aarg = std::next(aarg);
aarg->debug_print();
add_args.push_back(aarg->inputs().front());
fused = m.insert_instruction(
std::next(fused),
make_op("reshape",
{{"dims", {batch_size, sequence_length, hidden_size * 3}}}),
fused);
int64_t len = arg->get_shape().lens()[axis];
m.replace_instruction(
arg,
make_op("slice",
{{"axes", {axis}}, {"starts", {offset}}, {"ends", {offset + len}}}),
fused);
offset += len;
}
m.debug_print();
return;
}
// TODO: Check if axes match
auto concat =
m.insert_instruction(input, make_op("concat", {{"axis", concat_axis}}), args);
auto batch_size = input->get_shape().lens().front();
auto sequence_length = input->get_shape().lens().at(1);
auto hidden_size = input->get_shape().lens().at(2);
auto reshape = m.insert_instruction(
std::next(input),
make_op("reshape", {{"dims", {batch_size * sequence_length, hidden_size}}}),
input);
auto fused = m.insert_instruction(std::next(reshape), op, reshape, concat);
int64_t offset = 0;
for(auto arg : range(start, last))
else
{
fused = m.insert_instruction(
std::next(fused),
make_op("reshape",
{{"dims", {batch_size, sequence_length, hidden_size * 3}}}),
fused);
int64_t len = arg->get_shape().lens()[axis];
m.replace_instruction(
arg,
make_op("slice",
{{"axes", {axis}}, {"starts", {offset}}, {"ends", {offset + len}}}),
fused);
offset += len;
axis = int(args.front()->get_shape().lens().size() - 1);
concat_axis = axis;
}
return;
}
for(auto arg : args)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment