Commit ab236eec authored by turneram's avatar turneram
Browse files

Formatting

parent b1126322
...@@ -816,14 +816,18 @@ struct find_conv_dot_horiz_fusion ...@@ -816,14 +816,18 @@ struct find_conv_dot_horiz_fusion
auto batch_size = input->get_shape().lens().front(); auto batch_size = input->get_shape().lens().front();
auto sequence_length = input->get_shape().lens().at(1); auto sequence_length = input->get_shape().lens().at(1);
auto hidden_size = input->get_shape().lens().at(2); auto hidden_size = input->get_shape().lens().at(2);
auto reshape = m.insert_instruction(std::next(input), make_op("reshape", {{"dims", {batch_size * sequence_length, hidden_size}}}), input); auto reshape = m.insert_instruction(
std::next(input),
make_op("reshape", {{"dims", {batch_size * sequence_length, hidden_size}}}),
input);
auto fused = m.insert_instruction(std::next(reshape), op, reshape, concat); auto fused = m.insert_instruction(std::next(reshape), op, reshape, concat);
int64_t offset = 0; int64_t offset = 0;
for(auto arg : range(start, last)) for(auto arg : range(start, last))
{ {
fused = m.insert_instruction( fused = m.insert_instruction(
std::next(fused), std::next(fused),
make_op("reshape", {{"dims", {batch_size, sequence_length, hidden_size*3}}}), make_op("reshape",
{{"dims", {batch_size, sequence_length, hidden_size * 3}}}),
fused); fused);
int64_t len = arg->get_shape().lens()[axis]; int64_t len = arg->get_shape().lens()[axis];
m.replace_instruction( m.replace_instruction(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment