Unverified Commit 0d2606bb authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Change attributes names to be more consistent and reflect better meaning (#916)

* rename broadcast and multibroadcast output_lens attribute to out_lens attribute, and change tests and source code to reflect the same

* change the reshape attribute from dims to out_lens

* change transpose attribute's name from dims to perm to reflect better meaning

* use permutation instead of perm for transpose

clang formaating

* use dims instead of out_lens for reshape

clang formatting
parent d8a2a933
...@@ -93,9 +93,9 @@ instruction_ref insert_quant_ins(module& modl, ...@@ -93,9 +93,9 @@ instruction_ref insert_quant_ins(module& modl,
auto max_clip = modl.add_literal(127.0f); auto max_clip = modl.add_literal(127.0f);
auto min_clip = modl.add_literal(-128.0f); auto min_clip = modl.add_literal(-128.0f);
max_clip = modl.insert_instruction( max_clip = modl.insert_instruction(
insert_loc, make_op("multibroadcast", {{"output_lens", rounded_lens}}), max_clip); insert_loc, make_op("multibroadcast", {{"out_lens", rounded_lens}}), max_clip);
min_clip = modl.insert_instruction( min_clip = modl.insert_instruction(
insert_loc, make_op("multibroadcast", {{"output_lens", rounded_lens}}), min_clip); insert_loc, make_op("multibroadcast", {{"out_lens", rounded_lens}}), min_clip);
auto clipped_ins = auto clipped_ins =
modl.insert_instruction(insert_loc, make_op("clip"), rounded_ins, min_clip, max_clip); modl.insert_instruction(insert_loc, make_op("clip"), rounded_ins, min_clip, max_clip);
quant_ins = modl.insert_instruction( quant_ins = modl.insert_instruction(
......
...@@ -241,11 +241,11 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward, ...@@ -241,11 +241,11 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
// squeeze and transpose w // squeeze and transpose w
std::vector<int64_t> perm{1, 0}; std::vector<int64_t> perm{1, 0};
auto sw = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w); auto sw = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
auto tran_sw = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), sw); auto tran_sw = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sw);
// squeeze and transpose r // squeeze and transpose r
auto sr = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r); auto sr = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
auto tran_sr = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), sr); auto tran_sr = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sr);
// initial hidden state // initial hidden state
auto sih = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih); auto sih = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
...@@ -263,7 +263,7 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward, ...@@ -263,7 +263,7 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
ins, make_op("slice", {{"axes", {0}}, {"starts", {hs}}, {"ends", {2 * hs}}}), sbias); ins, make_op("slice", {{"axes", {0}}, {"starts", {hs}}, {"ends", {2 * hs}}}), sbias);
auto wrb = prog.insert_instruction(ins, make_op("add"), wb, rb); auto wrb = prog.insert_instruction(ins, make_op("add"), wb, rb);
bb = prog.insert_instruction( bb = prog.insert_instruction(
ins, make_op("broadcast", {{"axis", 1}, {"dims", sih_lens}}), wrb); ins, make_op("broadcast", {{"axis", 1}, {"out_lens", sih_lens}}), wrb);
} }
instruction_ref hidden_out = prog.end(); instruction_ref hidden_out = prog.end();
...@@ -565,17 +565,17 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward, ...@@ -565,17 +565,17 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
// w matrix squeeze to 2-dim and do a transpose // w matrix squeeze to 2-dim and do a transpose
std::vector<int64_t> perm{1, 0}; std::vector<int64_t> perm{1, 0};
auto sw = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w); auto sw = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
auto tw = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), sw); auto tw = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sw);
// r slide to two part, zr and h // r slide to two part, zr and h
auto sr = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r); auto sr = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
auto rzr = prog.insert_instruction( auto rzr = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {2 * hs}}}), sr); ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {2 * hs}}}), sr);
auto trzr = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), rzr); auto trzr = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), rzr);
auto rh = prog.insert_instruction( auto rh = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}), sr); ins, make_op("slice", {{"axes", {0}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}), sr);
auto trh = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), rh); auto trh = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), rh);
// initial states // initial states
auto sih = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih); auto sih = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
...@@ -592,7 +592,7 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward, ...@@ -592,7 +592,7 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {3 * hs}}}), sbias); ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {3 * hs}}}), sbias);
bwb = prog.insert_instruction( bwb = prog.insert_instruction(
ins, ins,
make_op("broadcast", {{"axis", 1}, {"dims", {bs, static_cast<size_t>(3 * hs)}}}), make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, static_cast<size_t>(3 * hs)}}}),
wb); wb);
auto rb_zr = prog.insert_instruction( auto rb_zr = prog.insert_instruction(
...@@ -605,11 +605,11 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward, ...@@ -605,11 +605,11 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
sbias); sbias);
brb_zr = prog.insert_instruction( brb_zr = prog.insert_instruction(
ins, ins,
make_op("broadcast", {{"axis", 1}, {"dims", {bs, static_cast<size_t>(2 * hs)}}}), make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, static_cast<size_t>(2 * hs)}}}),
rb_zr); rb_zr);
brb_h = prog.insert_instruction( brb_h = prog.insert_instruction(
ins, ins,
make_op("broadcast", {{"axis", 1}, {"dims", {bs, static_cast<size_t>(hs)}}}), make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, static_cast<size_t>(hs)}}}),
rb_h); rb_h);
} }
...@@ -1038,11 +1038,11 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -1038,11 +1038,11 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
std::vector<int64_t> perm{1, 0}; std::vector<int64_t> perm{1, 0};
// w matrix, squeeze and transpose // w matrix, squeeze and transpose
auto sw = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w); auto sw = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
auto tsw = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), sw); auto tsw = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sw);
// r matrix, squeeze and transpose // r matrix, squeeze and transpose
auto sr = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r); auto sr = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
auto tsr = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), sr); auto tsr = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sr);
// initial hidden state // initial hidden state
auto sih = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih); auto sih = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
...@@ -1067,7 +1067,7 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -1067,7 +1067,7 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
wrb = prog.insert_instruction( wrb = prog.insert_instruction(
ins, ins,
make_op("broadcast", {{"axis", 1}, {"dims", {bs, 4 * static_cast<size_t>(hs)}}}), make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, 4 * static_cast<size_t>(hs)}}}),
ub_wrb); ub_wrb);
} }
...@@ -1081,17 +1081,17 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -1081,17 +1081,17 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
auto pphi = prog.insert_instruction( auto pphi = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {hs}}}), spph); ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {hs}}}), spph);
pphi_brcst = prog.insert_instruction( pphi_brcst = prog.insert_instruction(
ins, make_op("broadcast", {{"axis", 1}, {"dims", ic_lens}}), pphi); ins, make_op("broadcast", {{"axis", 1}, {"out_lens", ic_lens}}), pphi);
auto ppho = prog.insert_instruction( auto ppho = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {hs}}, {"ends", {2 * hs}}}), spph); ins, make_op("slice", {{"axes", {0}}, {"starts", {hs}}, {"ends", {2 * hs}}}), spph);
ppho_brcst = prog.insert_instruction( ppho_brcst = prog.insert_instruction(
ins, make_op("broadcast", {{"axis", 1}, {"dims", ic_lens}}), ppho); ins, make_op("broadcast", {{"axis", 1}, {"out_lens", ic_lens}}), ppho);
auto pphf = prog.insert_instruction( auto pphf = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}), spph); ins, make_op("slice", {{"axes", {0}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}), spph);
pphf_brcst = prog.insert_instruction( pphf_brcst = prog.insert_instruction(
ins, make_op("broadcast", {{"axis", 1}, {"dims", ic_lens}}), pphf); ins, make_op("broadcast", {{"axis", 1}, {"out_lens", ic_lens}}), pphf);
} }
long seq_len = static_cast<long>(get_seq_len(prog, seq, seq_lens)); long seq_len = static_cast<long>(get_seq_len(prog, seq, seq_lens));
......
...@@ -55,7 +55,7 @@ struct find_mul_conv ...@@ -55,7 +55,7 @@ struct find_mul_conv
auto new_a = p.insert_instruction( auto new_a = p.insert_instruction(
ins, ins,
make_op("broadcast", {{"axis", 0}, {"dims", w_ins->get_shape().lens()}}), make_op("broadcast", {{"axis", 0}, {"out_lens", w_ins->get_shape().lens()}}),
a_ins->inputs().front()); a_ins->inputs().front());
auto new_mul = p.insert_instruction(ins, make_op("mul"), new_a, w_ins); auto new_mul = p.insert_instruction(ins, make_op("mul"), new_a, w_ins);
auto new_conv = p.insert_instruction( auto new_conv = p.insert_instruction(
...@@ -120,7 +120,7 @@ struct find_mul_slice_conv ...@@ -120,7 +120,7 @@ struct find_mul_slice_conv
auto new_a = p.insert_instruction( auto new_a = p.insert_instruction(
ins, ins,
make_op("broadcast", {{"axis", 0}, {"dims", slice_w_ins->get_shape().lens()}}), make_op("broadcast", {{"axis", 0}, {"out_lens", slice_w_ins->get_shape().lens()}}),
a_ins->inputs().front()); a_ins->inputs().front());
auto new_mul = p.insert_instruction(ins, make_op("mul"), new_a, slice_w_ins); auto new_mul = p.insert_instruction(ins, make_op("mul"), new_a, slice_w_ins);
...@@ -989,8 +989,8 @@ struct find_split_transpose ...@@ -989,8 +989,8 @@ struct find_split_transpose
} }
// insert an transpose instruction // insert an transpose instruction
auto tr = auto tr = p.insert_instruction(
p.insert_instruction(std::next(input), make_op("transpose", {{"dims", perm}}), input); std::next(input), make_op("transpose", {{"permutation", perm}}), input);
// compute the axis in the slice // compute the axis in the slice
auto axis = any_cast<op::slice>(slc->get_operator()).axes.front(); auto axis = any_cast<op::slice>(slc->get_operator()).axes.front();
......
...@@ -96,7 +96,7 @@ struct match_find_quantizable_ops ...@@ -96,7 +96,7 @@ struct match_find_quantizable_ops
auto lens = dq->get_shape().lens(); auto lens = dq->get_shape().lens();
auto scale_mb = auto scale_mb =
m.insert_instruction(qop, make_op("multibroadcast", {{"output_lens", lens}}), dq_scale); m.insert_instruction(qop, make_op("multibroadcast", {{"out_lens", lens}}), dq_scale);
dq = m.insert_instruction(qop, make_op("dequantizelinear"), dq, scale_mb); dq = m.insert_instruction(qop, make_op("dequantizelinear"), dq, scale_mb);
m.replace_instruction(qop, dq); m.replace_instruction(qop, dq);
} }
......
...@@ -153,7 +153,8 @@ struct find_transpose ...@@ -153,7 +153,8 @@ struct find_transpose
} }
else else
{ {
p.replace_instruction(ins, make_op("transpose", {{"dims", dims}}), t->inputs().front()); p.replace_instruction(
ins, make_op("transpose", {{"permutation", dims}}), t->inputs().front());
} }
} }
}; };
...@@ -278,10 +279,12 @@ struct find_concat_transpose ...@@ -278,10 +279,12 @@ struct find_concat_transpose
std::vector<instruction_ref> inputs; std::vector<instruction_ref> inputs;
std::transform( std::transform(
ins->inputs().begin(), ins->inputs().end(), std::back_inserter(inputs), [&](auto i) { ins->inputs().begin(), ins->inputs().end(), std::back_inserter(inputs), [&](auto i) {
return p.insert_instruction(ins, make_op("transpose", {{"dims", permutation}}), i); return p.insert_instruction(
ins, make_op("transpose", {{"permutation", permutation}}), i);
}); });
auto concat = p.insert_instruction(ins, op, inputs); auto concat = p.insert_instruction(ins, op, inputs);
auto t = p.insert_instruction(ins, make_op("transpose", {{"dims", ipermutation}}), concat); auto t = p.insert_instruction(
ins, make_op("transpose", {{"permutation", ipermutation}}), concat);
assert(ins->get_shape().lens() == t->get_shape().lens()); assert(ins->get_shape().lens() == t->get_shape().lens());
p.replace_instruction(ins, t); p.replace_instruction(ins, t);
} }
...@@ -418,7 +421,7 @@ struct find_resize ...@@ -418,7 +421,7 @@ struct find_resize
auto rsp_data = p.insert_instruction( auto rsp_data = p.insert_instruction(
ins_rsp, migraphx::make_op("reshape", {{"dims", in_dims}}), in_rsp); ins_rsp, migraphx::make_op("reshape", {{"dims", in_dims}}), in_rsp);
auto mb_rsp = p.insert_instruction( auto mb_rsp = p.insert_instruction(
ins_rsp, migraphx::make_op("multibroadcast", {{"output_lens", out_dims}}), rsp_data); ins_rsp, migraphx::make_op("multibroadcast", {{"out_lens", out_dims}}), rsp_data);
auto std_mb = p.insert_instruction(ins, migraphx::make_op("contiguous"), mb_rsp); auto std_mb = p.insert_instruction(ins, migraphx::make_op("contiguous"), mb_rsp);
std::vector<int64_t> rsp_dims(out_lens.begin(), out_lens.end()); std::vector<int64_t> rsp_dims(out_lens.begin(), out_lens.end());
p.replace_instruction(ins, migraphx::make_op("reshape", {{"dims", rsp_dims}}), std_mb); p.replace_instruction(ins, migraphx::make_op("reshape", {{"dims", rsp_dims}}), std_mb);
......
...@@ -55,7 +55,8 @@ static std::vector<instruction_ref> pad_inputs(module& m, instruction_ref ins) ...@@ -55,7 +55,8 @@ static std::vector<instruction_ref> pad_inputs(module& m, instruction_ref ins)
auto t_in = in0->inputs().front(); auto t_in = in0->inputs().front();
auto p_in = pad_ins(m, t_in, offset); auto p_in = pad_ins(m, t_in, offset);
auto dims = val.at("dims").to_vector<int64_t>(); auto dims = val.at("dims").to_vector<int64_t>();
auto r_in = m.insert_instruction(ins, make_op("transpose", {{"dims", dims}}), p_in); auto r_in =
m.insert_instruction(ins, make_op("transpose", {{"permutation", dims}}), p_in);
ret_inputs.push_back(r_in); ret_inputs.push_back(r_in);
} }
else else
...@@ -85,7 +86,8 @@ static std::vector<instruction_ref> pad_inputs(module& m, instruction_ref ins) ...@@ -85,7 +86,8 @@ static std::vector<instruction_ref> pad_inputs(module& m, instruction_ref ins)
auto t_in = in1->inputs().front(); auto t_in = in1->inputs().front();
auto p_in = pad_ins(m, t_in, offset); auto p_in = pad_ins(m, t_in, offset);
auto dims = val.at("dims").to_vector<int64_t>(); auto dims = val.at("dims").to_vector<int64_t>();
auto r_in = m.insert_instruction(ins, make_op("transpose", {{"dims", dims}}), p_in); auto r_in =
m.insert_instruction(ins, make_op("transpose", {{"permutation", dims}}), p_in);
ret_inputs.push_back(r_in); ret_inputs.push_back(r_in);
} }
else else
......
...@@ -20,7 +20,8 @@ struct parse_biasadd : op_parser<parse_biasadd> ...@@ -20,7 +20,8 @@ struct parse_biasadd : op_parser<parse_biasadd>
uint64_t axis = 1; // assume output of previous layer is in NCHW (broadcast on channel) uint64_t axis = 1; // assume output of previous layer is in NCHW (broadcast on channel)
auto l0 = info.add_instruction( auto l0 = info.add_instruction(
make_op("broadcast", {{"axis", axis}, {"dims", args[0]->get_shape().lens()}}), args[1]); make_op("broadcast", {{"axis", axis}, {"out_lens", args[0]->get_shape().lens()}}),
args[1]);
return info.add_instruction(make_op("add"), args[0], l0); return info.add_instruction(make_op("add"), args[0], l0);
} }
}; };
......
...@@ -46,10 +46,12 @@ struct parse_matmul : op_parser<parse_matmul> ...@@ -46,10 +46,12 @@ struct parse_matmul : op_parser<parse_matmul>
// swap the last two elements // swap the last two elements
std::iter_swap(perm.end() - 1, perm.end() - 2); std::iter_swap(perm.end() - 1, perm.end() - 2);
auto l1 = (transa) ? info.add_instruction(make_op("transpose", {{"dims", perm}}), args[0]) auto l1 = (transa)
: args[0]; ? info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[0])
auto l2 = (transb) ? info.add_instruction(make_op("transpose", {{"dims", perm}}), args[1]) : args[0];
: args[1]; auto l2 = (transb)
? info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[1])
: args[1];
return info.add_instruction(make_op("dot"), l1, l2); return info.add_instruction(make_op("dot"), l1, l2);
} }
......
...@@ -23,9 +23,9 @@ struct parse_relu6 : op_parser<parse_relu6> ...@@ -23,9 +23,9 @@ struct parse_relu6 : op_parser<parse_relu6>
auto max_val = info.add_literal(6.0f); auto max_val = info.add_literal(6.0f);
min_val = min_val =
info.add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}), min_val); info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}), min_val);
max_val = max_val =
info.add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}), max_val); info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}), max_val);
return info.add_instruction(make_op("clip"), args.front(), min_val, max_val); return info.add_instruction(make_op("clip"), args.front(), min_val, max_val);
} }
}; };
......
...@@ -20,7 +20,7 @@ struct parse_transpose : op_parser<parse_transpose> ...@@ -20,7 +20,7 @@ struct parse_transpose : op_parser<parse_transpose>
auto perm = args[1]->eval().get<int32_t>().to_vector(); auto perm = args[1]->eval().get<int32_t>().to_vector();
std::vector<int64_t> dims(perm.begin(), perm.end()); std::vector<int64_t> dims(perm.begin(), perm.end());
return info.add_instruction(make_op("transpose", {{"dims", dims}}), args.front()); return info.add_instruction(make_op("transpose", {{"permutation", dims}}), args.front());
} }
}; };
......
...@@ -35,20 +35,20 @@ bool tf_parser::should_transpose(instruction_ref ins) const ...@@ -35,20 +35,20 @@ bool tf_parser::should_transpose(instruction_ref ins) const
instruction_ref tf_parser::to_nhwc(instruction_ref ins) const instruction_ref tf_parser::to_nhwc(instruction_ref ins) const
{ {
if(should_transpose(ins)) if(should_transpose(ins))
return mm->add_instruction(make_op("transpose", {{"dims", {0, 2, 3, 1}}}), ins); return mm->add_instruction(make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), ins);
return ins; return ins;
} }
instruction_ref tf_parser::to_nchw(instruction_ref ins) const instruction_ref tf_parser::to_nchw(instruction_ref ins) const
{ {
if(should_transpose(ins)) if(should_transpose(ins))
return mm->add_instruction(make_op("transpose", {{"dims", {0, 3, 1, 2}}}), ins); return mm->add_instruction(make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), ins);
return ins; return ins;
} }
instruction_ref tf_parser::to_kcxy(instruction_ref ins) const instruction_ref tf_parser::to_kcxy(instruction_ref ins) const
{ {
return mm->add_instruction(make_op("transpose", {{"dims", {3, 2, 0, 1}}}), ins); return mm->add_instruction(make_op("transpose", {{"permutation", {3, 2, 0, 1}}}), ins);
} }
std::vector<instruction_ref> tf_parser::to_nchw(const std::vector<instruction_ref>& args) const std::vector<instruction_ref> tf_parser::to_nchw(const std::vector<instruction_ref>& args) const
......
...@@ -40,7 +40,7 @@ TEST_CASE(after_literal_transpose) ...@@ -40,7 +40,7 @@ TEST_CASE(after_literal_transpose)
auto l = m.add_literal(get_2x2()); auto l = m.add_literal(get_2x2());
EXPECT(m.get_output_shapes().back().standard()); EXPECT(m.get_output_shapes().back().standard());
EXPECT(not m.get_output_shapes().back().transposed()); EXPECT(not m.get_output_shapes().back().transposed());
auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l); auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l);
m.add_instruction(pass_op{}, t); m.add_instruction(pass_op{}, t);
EXPECT(not m.get_output_shapes().back().standard()); EXPECT(not m.get_output_shapes().back().standard());
EXPECT(m.get_output_shapes().back().transposed()); EXPECT(m.get_output_shapes().back().transposed());
...@@ -58,7 +58,7 @@ TEST_CASE(after_literal_broadcast) ...@@ -58,7 +58,7 @@ TEST_CASE(after_literal_broadcast)
EXPECT(m.get_output_shapes().back().standard()); EXPECT(m.get_output_shapes().back().standard());
EXPECT(not m.get_output_shapes().back().broadcasted()); EXPECT(not m.get_output_shapes().back().broadcasted());
auto b = m.add_instruction( auto b = m.add_instruction(
migraphx::make_op("broadcast", {{"axis", 0}, {"dims", l1->get_shape().lens()}}), l2); migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", l1->get_shape().lens()}}), l2);
m.add_instruction(pass_op{}, b); m.add_instruction(pass_op{}, b);
EXPECT(not m.get_output_shapes().back().standard()); EXPECT(not m.get_output_shapes().back().standard());
EXPECT(m.get_output_shapes().back().broadcasted()); EXPECT(m.get_output_shapes().back().broadcasted());
...@@ -74,7 +74,7 @@ TEST_CASE(after_param_transpose) ...@@ -74,7 +74,7 @@ TEST_CASE(after_param_transpose)
auto l = m.add_parameter("2x2", {migraphx::shape::float_type, {2, 2}}); auto l = m.add_parameter("2x2", {migraphx::shape::float_type, {2, 2}});
EXPECT(m.get_output_shapes().back().standard()); EXPECT(m.get_output_shapes().back().standard());
EXPECT(not m.get_output_shapes().back().transposed()); EXPECT(not m.get_output_shapes().back().transposed());
auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l); auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l);
m.add_instruction(pass_op{}, t); m.add_instruction(pass_op{}, t);
EXPECT(not m.get_output_shapes().back().standard()); EXPECT(not m.get_output_shapes().back().standard());
EXPECT(m.get_output_shapes().back().transposed()); EXPECT(m.get_output_shapes().back().transposed());
...@@ -92,7 +92,7 @@ TEST_CASE(after_param_broadcast) ...@@ -92,7 +92,7 @@ TEST_CASE(after_param_broadcast)
EXPECT(m.get_output_shapes().back().standard()); EXPECT(m.get_output_shapes().back().standard());
EXPECT(not m.get_output_shapes().back().broadcasted()); EXPECT(not m.get_output_shapes().back().broadcasted());
auto b = m.add_instruction( auto b = m.add_instruction(
migraphx::make_op("broadcast", {{"axis", 0}, {"dims", l1->get_shape().lens()}}), l2); migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", l1->get_shape().lens()}}), l2);
m.add_instruction(pass_op{}, b); m.add_instruction(pass_op{}, b);
EXPECT(not m.get_output_shapes().back().standard()); EXPECT(not m.get_output_shapes().back().standard());
EXPECT(m.get_output_shapes().back().broadcasted()); EXPECT(m.get_output_shapes().back().broadcasted());
......
...@@ -50,8 +50,8 @@ TEST_CASE(dot_add_beta_float) ...@@ -50,8 +50,8 @@ TEST_CASE(dot_add_beta_float)
auto dot = m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y); auto dot = m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y);
auto beta = auto beta =
m2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {0.5}}); m2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {0.5}});
auto beta_broadcast = m2.add_instruction( auto beta_broadcast =
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2}}}), beta); m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2}}}), beta);
auto mul = m2.add_instruction(migraphx::make_op("mul"), z, beta_broadcast); auto mul = m2.add_instruction(migraphx::make_op("mul"), z, beta_broadcast);
auto add = m2.add_instruction(migraphx::make_op("add"), dot, mul); auto add = m2.add_instruction(migraphx::make_op("add"), dot, mul);
m2.add_instruction(migraphx::make_op("identity"), add); m2.add_instruction(migraphx::make_op("identity"), add);
...@@ -79,8 +79,8 @@ TEST_CASE(dot_add_beta_half) ...@@ -79,8 +79,8 @@ TEST_CASE(dot_add_beta_half)
auto dot = m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y); auto dot = m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y);
auto beta = auto beta =
m2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::half_type}, {0.5}}); m2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::half_type}, {0.5}});
auto beta_broadcast = m2.add_instruction( auto beta_broadcast =
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2}}}), beta); m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2}}}), beta);
auto mul = m2.add_instruction(migraphx::make_op("mul"), z, beta_broadcast); auto mul = m2.add_instruction(migraphx::make_op("mul"), z, beta_broadcast);
auto add = m2.add_instruction(migraphx::make_op("add"), dot, mul); auto add = m2.add_instruction(migraphx::make_op("add"), dot, mul);
m2.add_instruction(migraphx::make_op("identity"), add); m2.add_instruction(migraphx::make_op("identity"), add);
...@@ -108,8 +108,8 @@ TEST_CASE(dot_add_beta_double) ...@@ -108,8 +108,8 @@ TEST_CASE(dot_add_beta_double)
auto dot = m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y); auto dot = m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y);
auto beta = auto beta =
m2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::double_type}, {0.5}}); m2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::double_type}, {0.5}});
auto beta_broadcast = m2.add_instruction( auto beta_broadcast =
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2}}}), beta); m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2}}}), beta);
auto mul = m2.add_instruction(migraphx::make_op("mul"), z, beta_broadcast); auto mul = m2.add_instruction(migraphx::make_op("mul"), z, beta_broadcast);
auto add = m2.add_instruction(migraphx::make_op("add"), dot, mul); auto add = m2.add_instruction(migraphx::make_op("add"), dot, mul);
m2.add_instruction(migraphx::make_op("identity"), add); m2.add_instruction(migraphx::make_op("identity"), add);
...@@ -137,8 +137,8 @@ TEST_CASE(dot_add_beta_int) ...@@ -137,8 +137,8 @@ TEST_CASE(dot_add_beta_int)
auto dot = m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y); auto dot = m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y);
auto beta = auto beta =
m2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {0.5}}); m2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {0.5}});
auto beta_broadcast = m2.add_instruction( auto beta_broadcast =
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2}}}), beta); m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2}}}), beta);
auto mul = m2.add_instruction(migraphx::make_op("mul"), z, beta_broadcast); auto mul = m2.add_instruction(migraphx::make_op("mul"), z, beta_broadcast);
auto add = m2.add_instruction(migraphx::make_op("add"), dot, mul); auto add = m2.add_instruction(migraphx::make_op("add"), dot, mul);
m2.add_instruction(migraphx::make_op("identity"), add); m2.add_instruction(migraphx::make_op("identity"), add);
......
...@@ -17,7 +17,7 @@ TEST_CASE(standard_op) ...@@ -17,7 +17,7 @@ TEST_CASE(standard_op)
migraphx::module m; migraphx::module m;
auto l = m.add_parameter("x", {migraphx::shape::float_type, {2, 2}}); auto l = m.add_parameter("x", {migraphx::shape::float_type, {2, 2}});
auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l); auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l);
auto c = m.add_instruction(migraphx::make_op("contiguous"), t); auto c = m.add_instruction(migraphx::make_op("contiguous"), t);
m.add_instruction(pass_standard_op{}, c); m.add_instruction(pass_standard_op{}, c);
auto count = std::distance(m.begin(), m.end()); auto count = std::distance(m.begin(), m.end());
...@@ -30,7 +30,7 @@ TEST_CASE(standard_op_const) ...@@ -30,7 +30,7 @@ TEST_CASE(standard_op_const)
migraphx::module m; migraphx::module m;
auto l = m.add_literal(get_2x2()); auto l = m.add_literal(get_2x2());
auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l); auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l);
auto c = m.add_instruction(migraphx::make_op("contiguous"), t); auto c = m.add_instruction(migraphx::make_op("contiguous"), t);
m.add_instruction(pass_standard_op{}, c); m.add_instruction(pass_standard_op{}, c);
run_pass(m); run_pass(m);
...@@ -42,7 +42,7 @@ TEST_CASE(non_standard_op) ...@@ -42,7 +42,7 @@ TEST_CASE(non_standard_op)
migraphx::module m; migraphx::module m;
auto l = m.add_parameter("x", {migraphx::shape::float_type, {2, 2}}); auto l = m.add_parameter("x", {migraphx::shape::float_type, {2, 2}});
auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l); auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l);
auto c = m.add_instruction(migraphx::make_op("contiguous"), t); auto c = m.add_instruction(migraphx::make_op("contiguous"), t);
m.add_instruction(pass_op{}, c); m.add_instruction(pass_op{}, c);
auto count = std::distance(m.begin(), m.end()); auto count = std::distance(m.begin(), m.end());
...@@ -55,7 +55,7 @@ TEST_CASE(non_standard_op_const) ...@@ -55,7 +55,7 @@ TEST_CASE(non_standard_op_const)
migraphx::module m; migraphx::module m;
auto l = m.add_literal(get_2x2()); auto l = m.add_literal(get_2x2());
auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l); auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l);
auto c = m.add_instruction(migraphx::make_op("contiguous"), t); auto c = m.add_instruction(migraphx::make_op("contiguous"), t);
m.add_instruction(pass_op{}, c); m.add_instruction(pass_op{}, c);
run_pass(m); run_pass(m);
...@@ -67,7 +67,7 @@ TEST_CASE(transpose_gem) ...@@ -67,7 +67,7 @@ TEST_CASE(transpose_gem)
migraphx::module m; migraphx::module m;
auto l = m.add_literal(get_2x2()); auto l = m.add_literal(get_2x2());
auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l); auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l);
auto c = m.add_instruction(migraphx::make_op("contiguous"), t); auto c = m.add_instruction(migraphx::make_op("contiguous"), t);
auto ic = m.add_instruction(migraphx::make_op("identity"), c); auto ic = m.add_instruction(migraphx::make_op("identity"), c);
m.add_instruction(migraphx::make_op("dot"), ic, l); m.add_instruction(migraphx::make_op("dot"), ic, l);
...@@ -81,7 +81,7 @@ TEST_CASE(transpose_standard_op) ...@@ -81,7 +81,7 @@ TEST_CASE(transpose_standard_op)
migraphx::module m; migraphx::module m;
auto l = m.add_parameter("x", {migraphx::shape::float_type, {2, 2}}); auto l = m.add_parameter("x", {migraphx::shape::float_type, {2, 2}});
auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l); auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l);
auto c = m.add_instruction(migraphx::make_op("contiguous"), t); auto c = m.add_instruction(migraphx::make_op("contiguous"), t);
auto sn = m.add_instruction(migraphx::make_op("sin"), c); auto sn = m.add_instruction(migraphx::make_op("sin"), c);
m.add_instruction(pass_standard_op{}, sn); m.add_instruction(pass_standard_op{}, sn);
...@@ -95,7 +95,7 @@ TEST_CASE(transpose_standard_op_const) ...@@ -95,7 +95,7 @@ TEST_CASE(transpose_standard_op_const)
migraphx::module m; migraphx::module m;
auto l = m.add_literal(get_2x2()); auto l = m.add_literal(get_2x2());
auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l); auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l);
auto c = m.add_instruction(migraphx::make_op("contiguous"), t); auto c = m.add_instruction(migraphx::make_op("contiguous"), t);
auto sn = m.add_instruction(migraphx::make_op("sin"), c); auto sn = m.add_instruction(migraphx::make_op("sin"), c);
m.add_instruction(pass_standard_op{}, sn); m.add_instruction(pass_standard_op{}, sn);
...@@ -123,7 +123,7 @@ TEST_CASE(non_standard_return_input) ...@@ -123,7 +123,7 @@ TEST_CASE(non_standard_return_input)
migraphx::module m; migraphx::module m;
auto l = m.add_literal(get_2x2()); auto l = m.add_literal(get_2x2());
auto tl = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l); auto tl = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l);
auto c = m.add_instruction(migraphx::make_op("contiguous"), tl); auto c = m.add_instruction(migraphx::make_op("contiguous"), tl);
m.add_return({c}); m.add_return({c});
auto count = std::distance(m.begin(), m.end()); auto count = std::distance(m.begin(), m.end());
......
...@@ -100,11 +100,13 @@ TEST_CASE(quant_dot_trans) ...@@ -100,11 +100,13 @@ TEST_CASE(quant_dot_trans)
migraphx::shape s1{migraphx::shape::int8_type, {3, 2, 8, 5}}; migraphx::shape s1{migraphx::shape::int8_type, {3, 2, 8, 5}};
migraphx::shape s2{migraphx::shape::int8_type, {3, 2, 7, 8}}; migraphx::shape s2{migraphx::shape::int8_type, {3, 2, 7, 8}};
auto l1 = m.add_parameter("a", s1); auto l1 = m.add_parameter("a", s1);
auto tl1 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l1); auto tl1 =
auto l2 = m.add_parameter("b", s2); m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l1);
auto tl2 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l2); auto l2 = m.add_parameter("b", s2);
auto r = m.add_instruction( auto tl2 =
m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2);
auto r = m.add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 3}, {"beta", 2}}), tl1, tl2); migraphx::make_op("quant_dot", {{"alpha", 3}, {"beta", 2}}), tl1, tl2);
m.add_return({r}); m.add_return({r});
return m; return m;
...@@ -120,13 +122,15 @@ TEST_CASE(quant_dot_trans) ...@@ -120,13 +122,15 @@ TEST_CASE(quant_dot_trans)
auto l2 = m.add_parameter("b", s2); auto l2 = m.add_parameter("b", s2);
auto output = m.add_parameter("test:#output_0", s3); auto output = m.add_parameter("test:#output_0", s3);
auto tl1 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l1); auto tl1 =
m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l1);
migraphx::shape ts1{migraphx::shape::int8_type, {3, 2, 5, 8}}; migraphx::shape ts1{migraphx::shape::int8_type, {3, 2, 5, 8}};
auto alloca = m.add_instruction( auto alloca = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts1)}})); migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts1)}}));
auto conta = m.add_instruction(migraphx::make_op("gpu::contiguous"), tl1, alloca); auto conta = m.add_instruction(migraphx::make_op("gpu::contiguous"), tl1, alloca);
auto tl2 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l2); auto tl2 =
m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2);
migraphx::shape ts2{migraphx::shape::int8_type, {3, 2, 8, 7}}; migraphx::shape ts2{migraphx::shape::int8_type, {3, 2, 8, 7}};
auto allocb = m.add_instruction( auto allocb = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts2)}})); migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts2)}}));
...@@ -245,11 +249,13 @@ TEST_CASE(quant_dot_trans_pad) ...@@ -245,11 +249,13 @@ TEST_CASE(quant_dot_trans_pad)
migraphx::shape s1{migraphx::shape::int8_type, {3, 2, 9, 5}}; migraphx::shape s1{migraphx::shape::int8_type, {3, 2, 9, 5}};
migraphx::shape s2{migraphx::shape::int8_type, {3, 2, 7, 9}}; migraphx::shape s2{migraphx::shape::int8_type, {3, 2, 7, 9}};
auto l1 = m.add_parameter("a", s1); auto l1 = m.add_parameter("a", s1);
auto tl1 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l1); auto tl1 =
auto l2 = m.add_parameter("b", s2); m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l1);
auto tl2 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l2); auto l2 = m.add_parameter("b", s2);
auto r = m.add_instruction( auto tl2 =
m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2);
auto r = m.add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 3}, {"beta", 2}}), tl1, tl2); migraphx::make_op("quant_dot", {{"alpha", 3}, {"beta", 2}}), tl1, tl2);
m.add_return({r}); m.add_return({r});
return m; return m;
...@@ -267,7 +273,8 @@ TEST_CASE(quant_dot_trans_pad) ...@@ -267,7 +273,8 @@ TEST_CASE(quant_dot_trans_pad)
auto l2 = m.add_parameter("b", s2); auto l2 = m.add_parameter("b", s2);
auto output = m.add_parameter("test:#output_0", s3); auto output = m.add_parameter("test:#output_0", s3);
auto tl1 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l1); auto tl1 =
m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l1);
migraphx::shape ts1{migraphx::shape::int8_type, {3, 2, 5, 9}}; migraphx::shape ts1{migraphx::shape::int8_type, {3, 2, 5, 9}};
auto ta = m.add_instruction( auto ta = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts1)}})); migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts1)}}));
...@@ -287,7 +294,8 @@ TEST_CASE(quant_dot_trans_pad) ...@@ -287,7 +294,8 @@ TEST_CASE(quant_dot_trans_pad)
pta); pta);
} }
auto tl2 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l2); auto tl2 =
m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2);
migraphx::shape ts2{migraphx::shape::int8_type, {3, 2, 9, 7}}; migraphx::shape ts2{migraphx::shape::int8_type, {3, 2, 9, 7}};
auto tb = m.add_instruction( auto tb = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts2)}})); migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts2)}}));
......
...@@ -364,19 +364,19 @@ TEST_CASE(inline_tuple_true_test) ...@@ -364,19 +364,19 @@ TEST_CASE(inline_tuple_true_test)
auto* then_mod = p.create_module("If_6_if"); auto* then_mod = p.create_module("If_6_if");
auto m1 = then_mod->add_instruction( auto m1 = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {1, 4}}}), l1); migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), l1);
auto add0 = then_mod->add_instruction(migraphx::make_op("add"), x, m1); auto add0 = then_mod->add_instruction(migraphx::make_op("add"), x, m1);
auto m2 = then_mod->add_instruction( auto m2 = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {3, 4}}}), l2); migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), l2);
auto mul0 = then_mod->add_instruction(migraphx::make_op("mul"), y, m2); auto mul0 = then_mod->add_instruction(migraphx::make_op("mul"), y, m2);
then_mod->add_return({add0, mul0}); then_mod->add_return({add0, mul0});
auto* else_mod = p.create_module("If_6_else"); auto* else_mod = p.create_module("If_6_else");
auto me1 = else_mod->add_instruction( auto me1 = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {1, 4}}}), l3); migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), l3);
auto mul1 = else_mod->add_instruction(migraphx::make_op("mul"), x, me1); auto mul1 = else_mod->add_instruction(migraphx::make_op("mul"), x, me1);
auto me2 = else_mod->add_instruction( auto me2 = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {3, 4}}}), l3); migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), l3);
auto add1 = else_mod->add_instruction(migraphx::make_op("add"), y, me2); auto add1 = else_mod->add_instruction(migraphx::make_op("add"), y, me2);
else_mod->add_return({mul1, add1}); else_mod->add_return({mul1, add1});
...@@ -401,10 +401,10 @@ TEST_CASE(inline_tuple_true_test) ...@@ -401,10 +401,10 @@ TEST_CASE(inline_tuple_true_test)
auto y = mm->add_parameter("y", sy); auto y = mm->add_parameter("y", sy);
auto m1 = auto m1 =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {1, 4}}}), l1); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), l1);
auto add = mm->add_instruction(migraphx::make_op("add"), x, m1); auto add = mm->add_instruction(migraphx::make_op("add"), x, m1);
auto m2 = auto m2 =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {3, 4}}}), l2); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), l2);
auto mul = mm->add_instruction(migraphx::make_op("mul"), y, m2); auto mul = mm->add_instruction(migraphx::make_op("mul"), y, m2);
mm->add_return({add, mul}); mm->add_return({add, mul});
...@@ -434,19 +434,19 @@ TEST_CASE(inline_tuple_false_test) ...@@ -434,19 +434,19 @@ TEST_CASE(inline_tuple_false_test)
auto* then_mod = p.create_module("If_6_if"); auto* then_mod = p.create_module("If_6_if");
auto m1 = then_mod->add_instruction( auto m1 = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {1, 4}}}), l1); migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), l1);
auto add0 = then_mod->add_instruction(migraphx::make_op("add"), x, m1); auto add0 = then_mod->add_instruction(migraphx::make_op("add"), x, m1);
auto m2 = then_mod->add_instruction( auto m2 = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {3, 4}}}), l2); migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), l2);
auto mul0 = then_mod->add_instruction(migraphx::make_op("mul"), y, m2); auto mul0 = then_mod->add_instruction(migraphx::make_op("mul"), y, m2);
then_mod->add_return({add0, mul0}); then_mod->add_return({add0, mul0});
auto* else_mod = p.create_module("If_6_else"); auto* else_mod = p.create_module("If_6_else");
auto me1 = else_mod->add_instruction( auto me1 = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {1, 4}}}), l3); migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), l3);
auto mul1 = else_mod->add_instruction(migraphx::make_op("mul"), x, me1); auto mul1 = else_mod->add_instruction(migraphx::make_op("mul"), x, me1);
auto me2 = else_mod->add_instruction( auto me2 = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {3, 4}}}), l3); migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), l3);
auto add1 = else_mod->add_instruction(migraphx::make_op("add"), y, me2); auto add1 = else_mod->add_instruction(migraphx::make_op("add"), y, me2);
else_mod->add_return({mul1, add1}); else_mod->add_return({mul1, add1});
...@@ -473,10 +473,10 @@ TEST_CASE(inline_tuple_false_test) ...@@ -473,10 +473,10 @@ TEST_CASE(inline_tuple_false_test)
auto y = mm->add_parameter("y", sy); auto y = mm->add_parameter("y", sy);
auto m1 = auto m1 =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {1, 4}}}), l3); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), l3);
auto mul = mm->add_instruction(migraphx::make_op("mul"), x, m1); auto mul = mm->add_instruction(migraphx::make_op("mul"), x, m1);
auto m2 = auto m2 =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {3, 4}}}), l3); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), l3);
auto add = mm->add_instruction(migraphx::make_op("add"), y, m2); auto add = mm->add_instruction(migraphx::make_op("add"), y, m2);
mm->add_return({mul, add}); mm->add_return({mul, add});
......
...@@ -74,7 +74,7 @@ TEST_CASE(add_bcast_test) ...@@ -74,7 +74,7 @@ TEST_CASE(add_bcast_test)
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}}); auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}});
auto l2 = mm->add_instruction( auto l2 = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", l0->get_shape().lens()}}), l1); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", l0->get_shape().lens()}}), l1);
mm->add_instruction(migraphx::make_op("add"), l0, l2); mm->add_instruction(migraphx::make_op("add"), l0, l2);
auto prog = optimize_onnx("add_bcast_test.onnx"); auto prog = optimize_onnx("add_bcast_test.onnx");
...@@ -102,8 +102,8 @@ TEST_CASE(add_scalar_test) ...@@ -102,8 +102,8 @@ TEST_CASE(add_scalar_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::uint8_type, {2, 3, 4, 5}}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::uint8_type, {2, 3, 4, 5}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::uint8_type}); auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::uint8_type});
auto m1 = mm->add_instruction( auto m1 =
migraphx::make_op("multibroadcast", {{"output_lens", {2, 3, 4, 5}}}), l1); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 4, 5}}}), l1);
auto r = mm->add_instruction(migraphx::make_op("add"), l0, m1); auto r = mm->add_instruction(migraphx::make_op("add"), l0, m1);
mm->add_return({r}); mm->add_return({r});
auto prog = migraphx::parse_onnx("add_scalar_test.onnx"); auto prog = migraphx::parse_onnx("add_scalar_test.onnx");
...@@ -373,9 +373,9 @@ TEST_CASE(clip_test) ...@@ -373,9 +373,9 @@ TEST_CASE(clip_test)
auto min_val = mm->add_literal(0.0f); auto min_val = mm->add_literal(0.0f);
auto max_val = mm->add_literal(6.0f); auto max_val = mm->add_literal(6.0f);
min_val = min_val =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {3}}}), min_val); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3}}}), min_val);
max_val = max_val =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {3}}}), max_val); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3}}}), max_val);
mm->add_instruction(migraphx::make_op("clip"), l0, min_val, max_val); mm->add_instruction(migraphx::make_op("clip"), l0, min_val, max_val);
auto prog = optimize_onnx("clip_test.onnx"); auto prog = optimize_onnx("clip_test.onnx");
...@@ -390,7 +390,7 @@ TEST_CASE(clip_test_op11_max_only) ...@@ -390,7 +390,7 @@ TEST_CASE(clip_test_op11_max_only)
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}});
mm->add_instruction(migraphx::make_op("undefined")); mm->add_instruction(migraphx::make_op("undefined"));
max_val = max_val =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {3}}}), max_val); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3}}}), max_val);
auto r = mm->add_instruction(migraphx::make_op("min"), l0, max_val); auto r = mm->add_instruction(migraphx::make_op("min"), l0, max_val);
mm->add_return({r}); mm->add_return({r});
...@@ -407,9 +407,9 @@ TEST_CASE(clip_test_op11) ...@@ -407,9 +407,9 @@ TEST_CASE(clip_test_op11)
auto max_val = mm->add_literal(6.0f); auto max_val = mm->add_literal(6.0f);
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}});
min_val = min_val =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {3}}}), min_val); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3}}}), min_val);
max_val = max_val =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {3}}}), max_val); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3}}}), max_val);
mm->add_instruction(migraphx::make_op("clip"), l0, min_val, max_val); mm->add_instruction(migraphx::make_op("clip"), l0, min_val, max_val);
auto prog = optimize_onnx("clip_test_op11.onnx"); auto prog = optimize_onnx("clip_test_op11.onnx");
...@@ -423,7 +423,7 @@ TEST_CASE(clip_test_op11_min_only) ...@@ -423,7 +423,7 @@ TEST_CASE(clip_test_op11_min_only)
auto min_val = mm->add_literal(0.0f); auto min_val = mm->add_literal(0.0f);
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}});
min_val = min_val =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {3}}}), min_val); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3}}}), min_val);
mm->add_instruction(migraphx::make_op("max"), l0, min_val); mm->add_instruction(migraphx::make_op("max"), l0, min_val);
auto prog = optimize_onnx("clip_test_op11_min_only.onnx"); auto prog = optimize_onnx("clip_test_op11_min_only.onnx");
...@@ -638,7 +638,7 @@ TEST_CASE(conv_bias_test) ...@@ -638,7 +638,7 @@ TEST_CASE(conv_bias_test)
uint64_t axis = 1; uint64_t axis = 1;
auto l3 = mm->add_instruction(migraphx::make_op("convolution"), l0, l1); auto l3 = mm->add_instruction(migraphx::make_op("convolution"), l0, l1);
auto l4 = mm->add_instruction( auto l4 = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", axis}, {"dims", l3->get_shape().lens()}}), l2); migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", l3->get_shape().lens()}}), l2);
mm->add_instruction(migraphx::make_op("add"), l3, l4); mm->add_instruction(migraphx::make_op("add"), l3, l4);
auto prog = optimize_onnx("conv_bias_test.onnx"); auto prog = optimize_onnx("conv_bias_test.onnx");
...@@ -661,7 +661,7 @@ TEST_CASE(conv_bn_relu_maxpool_test) ...@@ -661,7 +661,7 @@ TEST_CASE(conv_bn_relu_maxpool_test)
auto l3 = auto l3 =
mm->add_instruction(migraphx::make_op("convolution", {{"padding", {0, 0, 0, 0}}}), l0, l1); mm->add_instruction(migraphx::make_op("convolution", {{"padding", {0, 0, 0, 0}}}), l0, l1);
auto l4 = mm->add_instruction( auto l4 = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", axis}, {"dims", l3->get_shape().lens()}}), l2); migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", l3->get_shape().lens()}}), l2);
auto l5 = mm->add_instruction(migraphx::make_op("add"), l3, l4); auto l5 = mm->add_instruction(migraphx::make_op("add"), l3, l4);
auto l6 = mm->add_instruction( auto l6 = mm->add_instruction(
migraphx::make_op("batch_norm_inference", {{"epsilon", 1.0e-5f}}), l5, p3, p4, p5, p6); migraphx::make_op("batch_norm_inference", {{"epsilon", 1.0e-5f}}), l5, p3, p4, p5, p6);
...@@ -687,7 +687,7 @@ TEST_CASE(conv_relu_maxpool_test) ...@@ -687,7 +687,7 @@ TEST_CASE(conv_relu_maxpool_test)
auto l3 = auto l3 =
mm->add_instruction(migraphx::make_op("convolution", {{"padding", {0, 0, 0, 0}}}), l0, l1); mm->add_instruction(migraphx::make_op("convolution", {{"padding", {0, 0, 0, 0}}}), l0, l1);
auto l4 = mm->add_instruction( auto l4 = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", axis}, {"dims", l3->get_shape().lens()}}), l2); migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", l3->get_shape().lens()}}), l2);
auto l5 = mm->add_instruction(migraphx::make_op("add"), l3, l4); auto l5 = mm->add_instruction(migraphx::make_op("add"), l3, l4);
auto l6 = mm->add_instruction(migraphx::make_op("relu"), l5); auto l6 = mm->add_instruction(migraphx::make_op("relu"), l5);
mm->add_instruction( mm->add_instruction(
...@@ -711,7 +711,7 @@ TEST_CASE(conv_relu_maxpool_x2_test) ...@@ -711,7 +711,7 @@ TEST_CASE(conv_relu_maxpool_x2_test)
auto l3 = auto l3 =
mm->add_instruction(migraphx::make_op("convolution", {{"padding", {0, 0, 0, 0}}}), l0, l1); mm->add_instruction(migraphx::make_op("convolution", {{"padding", {0, 0, 0, 0}}}), l0, l1);
auto l4 = mm->add_instruction( auto l4 = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", axis}, {"dims", l3->get_shape().lens()}}), l2); migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", l3->get_shape().lens()}}), l2);
auto l5 = mm->add_instruction(migraphx::make_op("add"), l3, l4); auto l5 = mm->add_instruction(migraphx::make_op("add"), l3, l4);
auto l6 = mm->add_instruction(migraphx::make_op("relu"), l5); auto l6 = mm->add_instruction(migraphx::make_op("relu"), l5);
auto l7 = mm->add_instruction( auto l7 = mm->add_instruction(
...@@ -725,7 +725,8 @@ TEST_CASE(conv_relu_maxpool_x2_test) ...@@ -725,7 +725,8 @@ TEST_CASE(conv_relu_maxpool_x2_test)
auto l10 = auto l10 =
mm->add_instruction(migraphx::make_op("convolution", {{"padding", {0, 0, 0, 0}}}), l7, l8); mm->add_instruction(migraphx::make_op("convolution", {{"padding", {0, 0, 0, 0}}}), l7, l8);
auto l11 = mm->add_instruction( auto l11 = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", axis}, {"dims", l10->get_shape().lens()}}), l9); migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", l10->get_shape().lens()}}),
l9);
auto l12 = mm->add_instruction(migraphx::make_op("add"), l10, l11); auto l12 = mm->add_instruction(migraphx::make_op("add"), l10, l11);
auto l13 = mm->add_instruction(migraphx::make_op("relu"), l12); auto l13 = mm->add_instruction(migraphx::make_op("relu"), l12);
mm->add_instruction( mm->add_instruction(
...@@ -749,7 +750,7 @@ TEST_CASE(convinteger_bias_test) ...@@ -749,7 +750,7 @@ TEST_CASE(convinteger_bias_test)
uint64_t axis = 1; uint64_t axis = 1;
auto l3 = mm->add_instruction(migraphx::make_op("quant_convolution"), l0, l1); auto l3 = mm->add_instruction(migraphx::make_op("quant_convolution"), l0, l1);
auto l4 = mm->add_instruction( auto l4 = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", axis}, {"dims", l3->get_shape().lens()}}), l2); migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", l3->get_shape().lens()}}), l2);
mm->add_instruction(migraphx::make_op("add"), l3, l4); mm->add_instruction(migraphx::make_op("add"), l3, l4);
auto prog = optimize_onnx("convinteger_bias_test.onnx"); auto prog = optimize_onnx("convinteger_bias_test.onnx");
...@@ -801,7 +802,7 @@ TEST_CASE(deconv_bias_test) ...@@ -801,7 +802,7 @@ TEST_CASE(deconv_bias_test)
uint64_t axis = 1; uint64_t axis = 1;
auto l3 = mm->add_instruction(migraphx::make_op("deconvolution"), l0, l1); auto l3 = mm->add_instruction(migraphx::make_op("deconvolution"), l0, l1);
auto l4 = mm->add_instruction( auto l4 = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", axis}, {"dims", l3->get_shape().lens()}}), l2); migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", l3->get_shape().lens()}}), l2);
mm->add_instruction(migraphx::make_op("add"), l3, l4); mm->add_instruction(migraphx::make_op("add"), l3, l4);
auto prog = optimize_onnx("deconv_bias_test.onnx"); auto prog = optimize_onnx("deconv_bias_test.onnx");
...@@ -923,7 +924,7 @@ TEST_CASE(dequantizelinear_test) ...@@ -923,7 +924,7 @@ TEST_CASE(dequantizelinear_test)
auto l0 = mm->add_parameter("0", {migraphx::shape::int8_type, {5}}); auto l0 = mm->add_parameter("0", {migraphx::shape::int8_type, {5}});
auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1}}); auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1}});
auto l1_mbcast = auto l1_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {5}}}), l1); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5}}}), l1);
auto dequant = mm->add_instruction( auto dequant = mm->add_instruction(
migraphx::make_op("convert", migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}), {{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
...@@ -942,9 +943,9 @@ TEST_CASE(dequantizelinear_zero_point_test) ...@@ -942,9 +943,9 @@ TEST_CASE(dequantizelinear_zero_point_test)
auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1}}); auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1}});
auto l2 = mm->add_parameter("2", {migraphx::shape::int8_type, {1}}); auto l2 = mm->add_parameter("2", {migraphx::shape::int8_type, {1}});
auto l1_mbcast = auto l1_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {5}}}), l1); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5}}}), l1);
auto l2_mbcast = auto l2_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {5}}}), l2); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5}}}), l2);
l2_mbcast = mm->add_instruction( l2_mbcast = mm->add_instruction(
migraphx::make_op("convert", migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}), {{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
...@@ -971,9 +972,9 @@ migraphx::program make_dequantizelinear_axis_prog() ...@@ -971,9 +972,9 @@ migraphx::program make_dequantizelinear_axis_prog()
auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {5}}); auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {5}});
auto l2 = mm->add_parameter("2", {migraphx::shape::int8_type, {5}}); auto l2 = mm->add_parameter("2", {migraphx::shape::int8_type, {5}});
auto l1_bcast = mm->add_instruction( auto l1_bcast = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", axis}, {"dims", input_lens}}), l1); migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", input_lens}}), l1);
auto l2_bcast = mm->add_instruction( auto l2_bcast = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", axis}, {"dims", input_lens}}), l2); migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", input_lens}}), l2);
l2_bcast = mm->add_instruction( l2_bcast = mm->add_instruction(
migraphx::make_op("convert", migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}), {{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
...@@ -1129,8 +1130,7 @@ TEST_CASE(expand_test) ...@@ -1129,8 +1130,7 @@ TEST_CASE(expand_test)
auto param = mm->add_parameter("x", s); auto param = mm->add_parameter("x", s);
migraphx::shape ss(migraphx::shape::int32_type, {4}); migraphx::shape ss(migraphx::shape::int32_type, {4});
mm->add_literal(migraphx::literal(ss, {2, 3, 4, 5})); mm->add_literal(migraphx::literal(ss, {2, 3, 4, 5}));
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {2, 3, 4, 5}}}), mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 4, 5}}}), param);
param);
auto prog = optimize_onnx("expand_test.onnx"); auto prog = optimize_onnx("expand_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
...@@ -1150,7 +1150,7 @@ migraphx::program create_external_data_prog() ...@@ -1150,7 +1150,7 @@ migraphx::program create_external_data_prog()
auto conv = mm->add_instruction( auto conv = mm->add_instruction(
migraphx::make_op("convolution", {{"padding", {0, 0, 0, 0}}}), param, weights); migraphx::make_op("convolution", {{"padding", {0, 0, 0, 0}}}), param, weights);
auto bias_bcast = mm->add_instruction( auto bias_bcast = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 10, 214, 214}}}), bias); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 10, 214, 214}}}), bias);
mm->add_instruction(migraphx::make_op("add"), conv, bias_bcast); mm->add_instruction(migraphx::make_op("add"), conv, bias_bcast);
return p; return p;
} }
...@@ -1188,8 +1188,9 @@ TEST_CASE(flatten_nonstd_test) ...@@ -1188,8 +1188,9 @@ TEST_CASE(flatten_nonstd_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 5, 4}}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 5, 4}});
auto l1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l0); auto l1 =
auto l2 = mm->add_instruction(migraphx::make_op("contiguous"), l1); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l0);
auto l2 = mm->add_instruction(migraphx::make_op("contiguous"), l1);
mm->add_instruction(migraphx::make_op("flatten", {{"axis", 2}}), l2); mm->add_instruction(migraphx::make_op("flatten", {{"axis", 2}}), l2);
auto l3 = mm->add_instruction(migraphx::make_op("contiguous"), l1); auto l3 = mm->add_instruction(migraphx::make_op("contiguous"), l1);
mm->add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), l3); mm->add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), l3);
...@@ -1240,7 +1241,7 @@ TEST_CASE(gather_elements_axis0_test) ...@@ -1240,7 +1241,7 @@ TEST_CASE(gather_elements_axis0_test)
auto rsp_data = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {12}}}), data); auto rsp_data = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {12}}}), data);
auto lbst_stride = mm->add_instruction( auto lbst_stride = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", ind_s.lens()}}), l_stride); migraphx::make_op("multibroadcast", {{"out_lens", ind_s.lens()}}), l_stride);
auto axis_delta = mm->add_instruction(migraphx::make_op("sub"), indices, l_ind_axis_indices); auto axis_delta = mm->add_instruction(migraphx::make_op("sub"), indices, l_ind_axis_indices);
auto mul_delta = mm->add_instruction(migraphx::make_op("mul"), axis_delta, lbst_stride); auto mul_delta = mm->add_instruction(migraphx::make_op("mul"), axis_delta, lbst_stride);
auto ind = mm->add_instruction(migraphx::make_op("add"), l_data_indices, mul_delta); auto ind = mm->add_instruction(migraphx::make_op("add"), l_data_indices, mul_delta);
...@@ -1269,7 +1270,7 @@ TEST_CASE(gather_elements_axis1_test) ...@@ -1269,7 +1270,7 @@ TEST_CASE(gather_elements_axis1_test)
auto rsp_data = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {12}}}), data); auto rsp_data = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {12}}}), data);
auto lbst_stride = mm->add_instruction( auto lbst_stride = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", ind_s.lens()}}), l_stride); migraphx::make_op("multibroadcast", {{"out_lens", ind_s.lens()}}), l_stride);
auto axis_delta = mm->add_instruction(migraphx::make_op("sub"), indices, l_ind_axis_indices); auto axis_delta = mm->add_instruction(migraphx::make_op("sub"), indices, l_ind_axis_indices);
auto mul_delta = mm->add_instruction(migraphx::make_op("mul"), axis_delta, lbst_stride); auto mul_delta = mm->add_instruction(migraphx::make_op("mul"), axis_delta, lbst_stride);
auto ind = mm->add_instruction(migraphx::make_op("add"), l_data_indices, mul_delta); auto ind = mm->add_instruction(migraphx::make_op("add"), l_data_indices, mul_delta);
...@@ -1292,16 +1293,16 @@ TEST_CASE(gemm_test) ...@@ -1292,16 +1293,16 @@ TEST_CASE(gemm_test)
auto beta = 2.0f; auto beta = 2.0f;
auto a_l = mm->add_literal(alpha); auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0}); auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), t_a); t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), t_a);
auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l1); auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto dot = auto dot =
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), t_a, t1); mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), t_a, t1);
auto b_l = mm->add_literal(beta); auto b_l = mm->add_literal(beta);
auto l2_b = auto l2_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {7, 11}}}), l2); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {7, 11}}}), l2);
auto b_b = mm->add_instruction( auto b_b = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", l2_b->get_shape().lens()}}), b_l); migraphx::make_op("multibroadcast", {{"out_lens", l2_b->get_shape().lens()}}), b_l);
auto l2_bb = mm->add_instruction(migraphx::make_op("mul"), l2_b, b_b); auto l2_bb = mm->add_instruction(migraphx::make_op("mul"), l2_b, b_b);
mm->add_instruction(migraphx::make_op("add"), dot, l2_bb); mm->add_instruction(migraphx::make_op("add"), dot, l2_bb);
...@@ -1320,13 +1321,13 @@ TEST_CASE(gemm_ex_test) ...@@ -1320,13 +1321,13 @@ TEST_CASE(gemm_ex_test)
auto beta = 0.8f; auto beta = 0.8f;
auto a_l = mm->add_literal(alpha); auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0}); auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), t_a); t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), t_a);
auto dot = auto dot =
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), t_a, l1); mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), t_a, l1);
auto b_l = mm->add_literal(beta); auto b_l = mm->add_literal(beta);
auto b_b = mm->add_instruction( auto b_b = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", l2->get_shape().lens()}}), b_l); migraphx::make_op("multibroadcast", {{"out_lens", l2->get_shape().lens()}}), b_l);
auto l2_b = mm->add_instruction(migraphx::make_op("mul"), l2, b_b); auto l2_b = mm->add_instruction(migraphx::make_op("mul"), l2, b_b);
mm->add_instruction(migraphx::make_op("add"), dot, l2_b); mm->add_instruction(migraphx::make_op("add"), dot, l2_b);
...@@ -1346,15 +1347,15 @@ TEST_CASE(gemm_ex_brcst_test) ...@@ -1346,15 +1347,15 @@ TEST_CASE(gemm_ex_brcst_test)
auto beta = 0.8f; auto beta = 0.8f;
auto a_l = mm->add_literal(alpha); auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0}); auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), t_a); t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), t_a);
auto dot = auto dot =
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), t_a, l1); mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), t_a, l1);
auto b_l = mm->add_literal(beta); auto b_l = mm->add_literal(beta);
auto l2_b = auto l2_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", out_lens}}), l2); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", out_lens}}), l2);
auto b_b = mm->add_instruction( auto b_b = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", l2_b->get_shape().lens()}}), b_l); migraphx::make_op("multibroadcast", {{"out_lens", l2_b->get_shape().lens()}}), b_l);
auto l2_bb = mm->add_instruction(migraphx::make_op("mul"), l2_b, b_b); auto l2_bb = mm->add_instruction(migraphx::make_op("mul"), l2_b, b_b);
mm->add_instruction(migraphx::make_op("add"), dot, l2_bb); mm->add_instruction(migraphx::make_op("add"), dot, l2_bb);
...@@ -1375,16 +1376,15 @@ TEST_CASE(gemm_half_test) ...@@ -1375,16 +1376,15 @@ TEST_CASE(gemm_half_test)
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0}); auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction( t_a = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), t_a); migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), t_a);
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), t_a); t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), t_a);
std::vector<std::size_t> lens = {1, 1, 6, 7}; std::vector<std::size_t> lens = {1, 1, 6, 7};
auto dot = auto dot =
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), t_a, l1); mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), t_a, l1);
l2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), l2); l2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), l2);
l2 = mm->add_instruction( l2 = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), l2); migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), l2);
auto b_l = mm->add_literal(beta); auto b_l = mm->add_literal(beta);
auto b_b = auto b_b = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), b_l);
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), b_l);
auto l2_b = mm->add_instruction(migraphx::make_op("mul"), l2, b_b); auto l2_b = mm->add_instruction(migraphx::make_op("mul"), l2, b_b);
l2_b = mm->add_instruction( l2_b = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), l2_b); migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), l2_b);
...@@ -1670,20 +1670,20 @@ TEST_CASE(if_tuple_test) ...@@ -1670,20 +1670,20 @@ TEST_CASE(if_tuple_test)
auto y = mm->add_parameter("y", sy); auto y = mm->add_parameter("y", sy);
auto* then_mod = p.create_module("If_6_if"); auto* then_mod = p.create_module("If_6_if");
auto m1 = then_mod->add_instruction( auto m1 =
migraphx::make_op("multibroadcast", {{"output_lens", {1, 4}}}), l1); then_mod->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), l1);
auto add0 = then_mod->add_instruction(migraphx::make_op("add"), x, m1); auto add0 = then_mod->add_instruction(migraphx::make_op("add"), x, m1);
auto m2 = then_mod->add_instruction( auto m2 =
migraphx::make_op("multibroadcast", {{"output_lens", {3, 4}}}), l2); then_mod->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), l2);
auto mul0 = then_mod->add_instruction(migraphx::make_op("mul"), y, m2); auto mul0 = then_mod->add_instruction(migraphx::make_op("mul"), y, m2);
then_mod->add_return({add0, mul0}); then_mod->add_return({add0, mul0});
auto* else_mod = p.create_module("If_6_else"); auto* else_mod = p.create_module("If_6_else");
auto me1 = else_mod->add_instruction( auto me1 =
migraphx::make_op("multibroadcast", {{"output_lens", {1, 4}}}), l3); else_mod->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), l3);
auto mul1 = else_mod->add_instruction(migraphx::make_op("mul"), x, me1); auto mul1 = else_mod->add_instruction(migraphx::make_op("mul"), x, me1);
auto me2 = else_mod->add_instruction( auto me2 =
migraphx::make_op("multibroadcast", {{"output_lens", {3, 4}}}), l3); else_mod->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), l3);
auto add1 = else_mod->add_instruction(migraphx::make_op("add"), y, me2); auto add1 = else_mod->add_instruction(migraphx::make_op("add"), y, me2);
else_mod->add_return({mul1, add1}); else_mod->add_return({mul1, add1});
...@@ -1709,7 +1709,7 @@ TEST_CASE(imagescaler_test) ...@@ -1709,7 +1709,7 @@ TEST_CASE(imagescaler_test)
migraphx::make_op("scalar", {{"scalar_bcst_dims", s.lens()}}), scale_val); migraphx::make_op("scalar", {{"scalar_bcst_dims", s.lens()}}), scale_val);
auto img_scaled = mm->add_instruction(migraphx::make_op("mul"), l0, scaled_tensor); auto img_scaled = mm->add_instruction(migraphx::make_op("mul"), l0, scaled_tensor);
auto bias_bcast = mm->add_instruction( auto bias_bcast = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", s.lens()}}), bias_vals); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", s.lens()}}), bias_vals);
mm->add_instruction(migraphx::make_op("add"), img_scaled, bias_bcast); mm->add_instruction(migraphx::make_op("add"), img_scaled, bias_bcast);
auto prog = optimize_onnx("imagescaler_test.onnx"); auto prog = optimize_onnx("imagescaler_test.onnx");
...@@ -1731,7 +1731,7 @@ TEST_CASE(imagescaler_half_test) ...@@ -1731,7 +1731,7 @@ TEST_CASE(imagescaler_half_test)
migraphx::make_op("scalar", {{"scalar_bcst_dims", s.lens()}}), scale_val); migraphx::make_op("scalar", {{"scalar_bcst_dims", s.lens()}}), scale_val);
auto img_scaled = mm->add_instruction(migraphx::make_op("mul"), l0, scaled_tensor); auto img_scaled = mm->add_instruction(migraphx::make_op("mul"), l0, scaled_tensor);
auto bias_bcast = mm->add_instruction( auto bias_bcast = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", s.lens()}}), bias_vals); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", s.lens()}}), bias_vals);
mm->add_instruction(migraphx::make_op("add"), img_scaled, bias_bcast); mm->add_instruction(migraphx::make_op("add"), img_scaled, bias_bcast);
auto prog = optimize_onnx("imagescaler_half_test.onnx"); auto prog = optimize_onnx("imagescaler_half_test.onnx");
...@@ -1745,8 +1745,8 @@ TEST_CASE(implicit_add_bcast_test) ...@@ -1745,8 +1745,8 @@ TEST_CASE(implicit_add_bcast_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4, 1}}); auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4, 1}});
auto l3 = mm->add_instruction( auto l3 =
migraphx::make_op("multibroadcast", {{"output_lens", {2, 3, 4, 5}}}), l1); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 4, 5}}}), l1);
mm->add_instruction(migraphx::make_op("add"), l0, l3); mm->add_instruction(migraphx::make_op("add"), l0, l3);
auto prog = optimize_onnx("implicit_add_bcast_test.onnx"); auto prog = optimize_onnx("implicit_add_bcast_test.onnx");
...@@ -1760,8 +1760,8 @@ TEST_CASE(implicit_add_bcast_user_input_shape_test) ...@@ -1760,8 +1760,8 @@ TEST_CASE(implicit_add_bcast_user_input_shape_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 5, 1}}); auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 5, 1}});
auto l3 = mm->add_instruction( auto l3 =
migraphx::make_op("multibroadcast", {{"output_lens", {3, 4, 5, 6}}}), l1); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 4, 5, 6}}}), l1);
auto r = mm->add_instruction(migraphx::make_op("add"), l0, l3); auto r = mm->add_instruction(migraphx::make_op("add"), l0, l3);
mm->add_return({r}); mm->add_return({r});
...@@ -1779,8 +1779,8 @@ TEST_CASE(implicit_pow_bcast_test) ...@@ -1779,8 +1779,8 @@ TEST_CASE(implicit_pow_bcast_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4, 1}}); auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4, 1}});
auto l3 = mm->add_instruction( auto l3 =
migraphx::make_op("multibroadcast", {{"output_lens", {2, 3, 4, 5}}}), l1); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 4, 5}}}), l1);
mm->add_instruction(migraphx::make_op("pow"), l0, l3); mm->add_instruction(migraphx::make_op("pow"), l0, l3);
auto prog = optimize_onnx("implicit_pow_bcast_test.onnx"); auto prog = optimize_onnx("implicit_pow_bcast_test.onnx");
...@@ -1794,8 +1794,8 @@ TEST_CASE(implicit_sub_bcast_test) ...@@ -1794,8 +1794,8 @@ TEST_CASE(implicit_sub_bcast_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::uint64_type, {2, 3, 4, 5}}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::uint64_type, {2, 3, 4, 5}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::uint64_type, {4, 5}}); auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::uint64_type, {4, 5}});
auto l3 = mm->add_instruction( auto l3 =
migraphx::make_op("multibroadcast", {{"output_lens", {2, 3, 4, 5}}}), l1); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 4, 5}}}), l1);
mm->add_instruction(migraphx::make_op("sub"), l0, l3); mm->add_instruction(migraphx::make_op("sub"), l0, l3);
auto prog = optimize_onnx("implicit_sub_bcast_test.onnx"); auto prog = optimize_onnx("implicit_sub_bcast_test.onnx");
...@@ -1831,22 +1831,22 @@ TEST_CASE(instance_norm_test) ...@@ -1831,22 +1831,22 @@ TEST_CASE(instance_norm_test)
auto mean = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2, 3}}}), x); auto mean = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2, 3}}}), x);
auto mean_bcast = auto mean_bcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", dims}}), mean); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), mean);
auto l0 = mm->add_instruction(migraphx::make_op("sqdiff"), x, mean_bcast); auto l0 = mm->add_instruction(migraphx::make_op("sqdiff"), x, mean_bcast);
auto variance = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2, 3}}}), l0); auto variance = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2, 3}}}), l0);
auto l1 = mm->add_instruction(migraphx::make_op("sub"), x, mean_bcast); auto l1 = mm->add_instruction(migraphx::make_op("sub"), x, mean_bcast);
auto epsilon_literal = mm->add_literal(1e-5f); auto epsilon_literal = mm->add_literal(1e-5f);
auto epsilon_bcast = mm->add_instruction( auto epsilon_bcast = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", dims}}), epsilon_literal); migraphx::make_op("multibroadcast", {{"out_lens", dims}}), epsilon_literal);
auto variance_bcast = auto variance_bcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", dims}}), variance); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), variance);
auto l2 = mm->add_instruction(migraphx::make_op("add"), variance_bcast, epsilon_bcast); auto l2 = mm->add_instruction(migraphx::make_op("add"), variance_bcast, epsilon_bcast);
auto l3 = mm->add_instruction(migraphx::make_op("rsqrt"), l2); auto l3 = mm->add_instruction(migraphx::make_op("rsqrt"), l2);
auto l4 = mm->add_instruction(migraphx::make_op("mul"), l1, l3); auto l4 = mm->add_instruction(migraphx::make_op("mul"), l1, l3);
auto scale_bcast = auto scale_bcast = mm->add_instruction(
mm->add_instruction(migraphx::make_op("broadcast", {{"axis", 1}, {"dims", dims}}), scale); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), scale);
auto bias_bcast = auto bias_bcast = mm->add_instruction(
mm->add_instruction(migraphx::make_op("broadcast", {{"axis", 1}, {"dims", dims}}), bias); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), bias);
auto l5 = mm->add_instruction(migraphx::make_op("mul"), l4, scale_bcast); auto l5 = mm->add_instruction(migraphx::make_op("mul"), l4, scale_bcast);
mm->add_instruction(migraphx::make_op("add"), l5, bias_bcast); mm->add_instruction(migraphx::make_op("add"), l5, bias_bcast);
...@@ -1944,7 +1944,7 @@ TEST_CASE(logical_and_bcast_test) ...@@ -1944,7 +1944,7 @@ TEST_CASE(logical_and_bcast_test)
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::bool_type, {2, 3, 4, 5}}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::bool_type, {2, 3, 4, 5}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::bool_type, {4, 5}}); auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::bool_type, {4, 5}});
auto l2 = mm->add_instruction( auto l2 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", l0->get_shape().lens()}}), l1); migraphx::make_op("multibroadcast", {{"out_lens", l0->get_shape().lens()}}), l1);
auto ret = mm->add_instruction(migraphx::make_op("logical_and"), l0, l2); auto ret = mm->add_instruction(migraphx::make_op("logical_and"), l0, l2);
mm->add_return({ret}); mm->add_return({ret});
...@@ -1974,7 +1974,7 @@ TEST_CASE(logical_xor_bcast_test) ...@@ -1974,7 +1974,7 @@ TEST_CASE(logical_xor_bcast_test)
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::bool_type, {2, 3, 4, 5}}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::bool_type, {2, 3, 4, 5}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::bool_type, {4, 1}}); auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::bool_type, {4, 1}});
auto l2 = mm->add_instruction( auto l2 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", l0->get_shape().lens()}}), l1); migraphx::make_op("multibroadcast", {{"out_lens", l0->get_shape().lens()}}), l1);
auto ret = mm->add_instruction(migraphx::make_op("logical_xor"), l0, l2); auto ret = mm->add_instruction(migraphx::make_op("logical_xor"), l0, l2);
mm->add_return({ret}); mm->add_return({ret});
...@@ -2033,9 +2033,9 @@ TEST_CASE(matmul_bmbm_test) ...@@ -2033,9 +2033,9 @@ TEST_CASE(matmul_bmbm_test)
auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 6, 7}}); auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 6, 7}});
auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {5, 2, 1, 7, 8}}); auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {5, 2, 1, 7, 8}});
auto bl0 = mm->add_instruction( auto bl0 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {5, 2, 3, 6, 7}}}), l0); migraphx::make_op("multibroadcast", {{"out_lens", {5, 2, 3, 6, 7}}}), l0);
auto bl1 = mm->add_instruction( auto bl1 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {5, 2, 3, 7, 8}}}), l1); migraphx::make_op("multibroadcast", {{"out_lens", {5, 2, 3, 7, 8}}}), l1);
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), bl0, bl1); mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), bl0, bl1);
auto prog = optimize_onnx("matmul_bmbm_test.onnx"); auto prog = optimize_onnx("matmul_bmbm_test.onnx");
...@@ -2051,7 +2051,7 @@ TEST_CASE(matmul_bmv_test) ...@@ -2051,7 +2051,7 @@ TEST_CASE(matmul_bmv_test)
auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {7}}); auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {7}});
auto sl1 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l1); auto sl1 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l1);
auto bsl1 = auto bsl1 =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {3, 7, 1}}}), sl1); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 7, 1}}}), sl1);
auto res = auto res =
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), l0, bsl1); mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), l0, bsl1);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), res); mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), res);
...@@ -2085,7 +2085,7 @@ TEST_CASE(matmul_vbm_test) ...@@ -2085,7 +2085,7 @@ TEST_CASE(matmul_vbm_test)
auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {5, 7, 8}}); auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {5, 7, 8}});
auto sl0 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l0); auto sl0 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l0);
auto bsl0 = auto bsl0 =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {5, 1, 7}}}), sl0); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5, 1, 7}}}), sl0);
auto res = auto res =
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), bsl0, l1); mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), bsl0, l1);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), res); mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), res);
...@@ -2298,17 +2298,17 @@ TEST_CASE(onehot_test) ...@@ -2298,17 +2298,17 @@ TEST_CASE(onehot_test)
std::vector<float> data_dep{1, 0, 0, 0, 1, 0, 0, 0, 1}; std::vector<float> data_dep{1, 0, 0, 0, 1, 0, 0, 0, 1};
auto l_dep = mm->add_literal(migraphx::literal(s_dep, data_dep)); auto l_dep = mm->add_literal(migraphx::literal(s_dep, data_dep));
auto gather_out = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), l_dep, l_ind); auto gather_out = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), l_dep, l_ind);
auto tr_out = auto tr_out = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 1}}}),
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {2, 0, 1}}}), gather_out); gather_out);
auto off_val = mm->add_instruction( auto off_val = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), l_val); migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), l_val);
auto on_val = mm->add_instruction( auto on_val = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), l_val); migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), l_val);
auto diff = mm->add_instruction(migraphx::make_op("sub"), on_val, off_val); auto diff = mm->add_instruction(migraphx::make_op("sub"), on_val, off_val);
auto mb_off_val = mm->add_instruction( auto mb_off_val = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {3, 5, 2}}}), off_val); migraphx::make_op("multibroadcast", {{"out_lens", {3, 5, 2}}}), off_val);
auto mb_diff = mm->add_instruction( auto mb_diff =
migraphx::make_op("multibroadcast", {{"output_lens", {3, 5, 2}}}), diff); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 5, 2}}}), diff);
auto mul = mm->add_instruction(migraphx::make_op("mul"), tr_out, mb_diff); auto mul = mm->add_instruction(migraphx::make_op("mul"), tr_out, mb_diff);
auto r = mm->add_instruction(migraphx::make_op("add"), mul, mb_off_val); auto r = mm->add_instruction(migraphx::make_op("add"), mul, mb_off_val);
mm->add_return({r}); mm->add_return({r});
...@@ -2457,7 +2457,7 @@ TEST_CASE(prelu_brcst_test) ...@@ -2457,7 +2457,7 @@ TEST_CASE(prelu_brcst_test)
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 5}}); auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 5}});
auto bl1 = mm->add_instruction( auto bl1 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", l0->get_shape().lens()}}), l1); migraphx::make_op("multibroadcast", {{"out_lens", l0->get_shape().lens()}}), l1);
auto ret = mm->add_instruction(migraphx::make_op("prelu"), l0, bl1); auto ret = mm->add_instruction(migraphx::make_op("prelu"), l0, bl1);
mm->add_return({ret}); mm->add_return({ret});
...@@ -2473,7 +2473,7 @@ TEST_CASE(quantizelinear_test) ...@@ -2473,7 +2473,7 @@ TEST_CASE(quantizelinear_test)
auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {5}}); auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {5}});
auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1}}); auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1}});
auto l1_mbcast = auto l1_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {5}}}), l1); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5}}}), l1);
auto div = mm->add_instruction(migraphx::make_op("div"), l0, l1_mbcast); auto div = mm->add_instruction(migraphx::make_op("div"), l0, l1_mbcast);
auto round = mm->add_instruction(migraphx::make_op("round"), div); auto round = mm->add_instruction(migraphx::make_op("round"), div);
auto s = round->get_shape(); auto s = round->get_shape();
...@@ -2498,7 +2498,7 @@ TEST_CASE(quantizelinear_int32_test) ...@@ -2498,7 +2498,7 @@ TEST_CASE(quantizelinear_int32_test)
auto l0 = mm->add_parameter("0", {migraphx::shape::int32_type, {5}}); auto l0 = mm->add_parameter("0", {migraphx::shape::int32_type, {5}});
auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1}}); auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1}});
auto l1_mbcast = auto l1_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {5}}}), l1); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5}}}), l1);
l0 = mm->add_instruction( l0 = mm->add_instruction(
migraphx::make_op("convert", migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}), {{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
...@@ -2528,11 +2528,11 @@ TEST_CASE(quantizelinear_zero_point_test) ...@@ -2528,11 +2528,11 @@ TEST_CASE(quantizelinear_zero_point_test)
auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1}}); auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1}});
auto l2 = mm->add_parameter("2", {migraphx::shape::int8_type, {1}}); auto l2 = mm->add_parameter("2", {migraphx::shape::int8_type, {1}});
auto l1_mbcast = auto l1_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {5}}}), l1); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5}}}), l1);
auto div = mm->add_instruction(migraphx::make_op("div"), l0, l1_mbcast); auto div = mm->add_instruction(migraphx::make_op("div"), l0, l1_mbcast);
auto round = mm->add_instruction(migraphx::make_op("round"), div); auto round = mm->add_instruction(migraphx::make_op("round"), div);
auto l2_mbcast = auto l2_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {5}}}), l2); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5}}}), l2);
l2_mbcast = mm->add_instruction( l2_mbcast = mm->add_instruction(
migraphx::make_op("convert", migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}), {{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
...@@ -2564,12 +2564,12 @@ migraphx::program make_quantizelinear_axis_prog() ...@@ -2564,12 +2564,12 @@ migraphx::program make_quantizelinear_axis_prog()
auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {5}}); auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {5}});
auto l2 = mm->add_parameter("2", {migraphx::shape::int8_type, {5}}); auto l2 = mm->add_parameter("2", {migraphx::shape::int8_type, {5}});
auto l1_bcast = mm->add_instruction( auto l1_bcast = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", axis}, {"dims", input_lens}}), l1); migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", input_lens}}), l1);
auto div = mm->add_instruction(migraphx::make_op("div"), l0, l1_bcast); auto div = mm->add_instruction(migraphx::make_op("div"), l0, l1_bcast);
auto round = mm->add_instruction(migraphx::make_op("round"), div); auto round = mm->add_instruction(migraphx::make_op("round"), div);
auto l2_bcast = mm->add_instruction( auto l2_bcast = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", axis}, {"dims", input_lens}}), l2); migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", input_lens}}), l2);
l2_bcast = mm->add_instruction( l2_bcast = mm->add_instruction(
migraphx::make_op("convert", migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}), {{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
...@@ -2852,8 +2852,9 @@ TEST_CASE(reshape_non_standard_test) ...@@ -2852,8 +2852,9 @@ TEST_CASE(reshape_non_standard_test)
migraphx::op::reshape op; migraphx::op::reshape op;
std::vector<int64_t> reshape_dims{4, 3, 2}; std::vector<int64_t> reshape_dims{4, 3, 2};
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4}}; migraphx::shape s{migraphx::shape::float_type, {2, 3, 4}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto tran_x = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 1}}}), x); auto tran_x =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), x);
auto cont_x = mm->add_instruction(migraphx::make_op("contiguous"), tran_x); auto cont_x = mm->add_instruction(migraphx::make_op("contiguous"), tran_x);
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4, 3, 2}}}), cont_x); mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4, 3, 2}}}), cont_x);
auto prog = optimize_onnx("reshape_non_standard_test.onnx"); auto prog = optimize_onnx("reshape_non_standard_test.onnx");
...@@ -3025,7 +3026,8 @@ TEST_CASE(resize_nonstd_input_test) ...@@ -3025,7 +3026,8 @@ TEST_CASE(resize_nonstd_input_test)
std::vector<int> ind = {0, 4}; std::vector<int> ind = {0, 4};
auto li = mm->add_literal(migraphx::literal(si, ind)); auto li = mm->add_literal(migraphx::literal(si, ind));
auto tx = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), inx); auto tx =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), inx);
mm->add_instruction(migraphx::make_op("undefined")); mm->add_instruction(migraphx::make_op("undefined"));
auto tx_cont = mm->add_instruction(migraphx::make_op("contiguous"), tx); auto tx_cont = mm->add_instruction(migraphx::make_op("contiguous"), tx);
...@@ -3319,12 +3321,10 @@ TEST_CASE(selu_test) ...@@ -3319,12 +3321,10 @@ TEST_CASE(selu_test)
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
migraphx::shape ls{migraphx::shape::double_type, {1}}; migraphx::shape ls{migraphx::shape::double_type, {1}};
auto la = mm->add_literal({ls, {0.3}}); auto la = mm->add_literal({ls, {0.3}});
auto lg = mm->add_literal({ls, {0.25}}); auto lg = mm->add_literal({ls, {0.25}});
auto mbla = auto mbla = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), la);
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), la); auto mblg = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), lg);
auto mblg =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), lg);
auto sign_x = mm->add_instruction(migraphx::make_op("sign"), x); auto sign_x = mm->add_instruction(migraphx::make_op("sign"), x);
auto exp_x = mm->add_instruction(migraphx::make_op("exp"), x); auto exp_x = mm->add_instruction(migraphx::make_op("exp"), x);
...@@ -3651,7 +3651,7 @@ TEST_CASE(sub_bcast_test) ...@@ -3651,7 +3651,7 @@ TEST_CASE(sub_bcast_test)
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}}); auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}});
auto l2 = mm->add_instruction( auto l2 = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", l0->get_shape().lens()}}), l1); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", l0->get_shape().lens()}}), l1);
mm->add_instruction(migraphx::make_op("sub"), l0, l2); mm->add_instruction(migraphx::make_op("sub"), l0, l2);
auto prog = optimize_onnx("sub_bcast_test.onnx"); auto prog = optimize_onnx("sub_bcast_test.onnx");
...@@ -3665,8 +3665,8 @@ TEST_CASE(sub_scalar_test) ...@@ -3665,8 +3665,8 @@ TEST_CASE(sub_scalar_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {1}}); auto l1 = mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {1}});
auto m1 = mm->add_instruction( auto m1 =
migraphx::make_op("multibroadcast", {{"output_lens", {2, 3, 4, 5}}}), l1); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 4, 5}}}), l1);
mm->add_instruction(migraphx::make_op("sub"), l0, m1); mm->add_instruction(migraphx::make_op("sub"), l0, m1);
auto prog = optimize_onnx("sub_scalar_test.onnx"); auto prog = optimize_onnx("sub_scalar_test.onnx");
...@@ -3816,7 +3816,7 @@ TEST_CASE(transpose_test) ...@@ -3816,7 +3816,7 @@ TEST_CASE(transpose_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto input = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}}); auto input = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
std::vector<int64_t> perm{0, 3, 1, 2}; std::vector<int64_t> perm{0, 3, 1, 2};
mm->add_instruction(migraphx::make_op("transpose", {{"dims", perm}}), input); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), input);
auto prog = optimize_onnx("transpose_test.onnx"); auto prog = optimize_onnx("transpose_test.onnx");
...@@ -3841,9 +3841,9 @@ TEST_CASE(transpose_gather_test) ...@@ -3841,9 +3841,9 @@ TEST_CASE(transpose_gather_test)
auto ind = auto ind =
mm->add_parameter("indices", migraphx::shape{migraphx::shape::int32_type, {2, 4, 3, 5}}); mm->add_parameter("indices", migraphx::shape{migraphx::shape::int32_type, {2, 4, 3, 5}});
auto tr_data = auto tr_data =
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 1, 3}}}), data); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), data);
auto tr_ind = auto tr_ind =
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 1, 3}}}), ind); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), ind);
int axis = 1; int axis = 1;
mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}),
make_contiguous(tr_data), make_contiguous(tr_data),
...@@ -3973,11 +3973,11 @@ TEST_CASE(where_test) ...@@ -3973,11 +3973,11 @@ TEST_CASE(where_test)
{{"target_type", migraphx::to_value(migraphx::shape::int32_type)}}), {{"target_type", migraphx::to_value(migraphx::shape::int32_type)}}),
lc); lc);
auto lccm = mm->add_instruction( auto lccm = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2, 2, 2}}}), int_c); migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 2, 2}}}), int_c);
auto lxm = mm->add_instruction( auto lxm =
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2, 2, 2}}}), lx); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 2, 2}}}), lx);
auto lym = mm->add_instruction( auto lym =
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2, 2, 2}}}), ly); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 2, 2}}}), ly);
auto concat_data = mm->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), lym, lxm); auto concat_data = mm->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), lym, lxm);
auto rsp_data = auto rsp_data =
......
...@@ -74,7 +74,7 @@ TEST_CASE(broadcast) ...@@ -74,7 +74,7 @@ TEST_CASE(broadcast)
std::vector<std::size_t> lens{1, 1}; std::vector<std::size_t> lens{1, 1};
migraphx::shape input{migraphx::shape::float_type, {1}, {0}}; migraphx::shape input{migraphx::shape::float_type, {1}, {0}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1}, {0, 0}}, expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1}, {0, 0}},
migraphx::make_op("broadcast", {{"axis", 0}, {"dims", lens}}), migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", lens}}),
input); input);
} }
...@@ -94,14 +94,14 @@ TEST_CASE(broadcast) ...@@ -94,14 +94,14 @@ TEST_CASE(broadcast)
std::vector<std::size_t> lens{3, 2, 4, 3}; std::vector<std::size_t> lens{3, 2, 4, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 3}}; migraphx::shape input{migraphx::shape::float_type, {4, 3}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 2, 4, 3}, {0, 0, 3, 1}}, expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 2, 4, 3}, {0, 0, 3, 1}},
migraphx::make_op("broadcast", {{"axis", 2}, {"dims", lens}}), migraphx::make_op("broadcast", {{"axis", 2}, {"out_lens", lens}}),
input); input);
} }
{ {
std::vector<std::size_t> lens{3, 2, 4, 3}; std::vector<std::size_t> lens{3, 2, 4, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 4}}; migraphx::shape input{migraphx::shape::float_type, {4, 4}};
throws_shape(migraphx::make_op("broadcast", {{"axis", 2}, {"dims", lens}}), input); throws_shape(migraphx::make_op("broadcast", {{"axis", 2}, {"out_lens", lens}}), input);
} }
} }
...@@ -953,70 +953,70 @@ TEST_CASE(multibroadcast) ...@@ -953,70 +953,70 @@ TEST_CASE(multibroadcast)
std::vector<std::size_t> lens{4, 2, 5, 3}; std::vector<std::size_t> lens{4, 2, 5, 3};
migraphx::shape input{migraphx::shape::float_type, {2, 1, 3}}; migraphx::shape input{migraphx::shape::float_type, {2, 1, 3}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 3, 0, 1}}, expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 3, 0, 1}},
migraphx::make_op("multibroadcast", {{"output_lens", lens}}), migraphx::make_op("multibroadcast", {{"out_lens", lens}}),
input); input);
} }
{ {
std::vector<std::size_t> lens{4, 2, 5, 3}; std::vector<std::size_t> lens{4, 2, 5, 3};
migraphx::shape input{migraphx::shape::float_type, {2, 1, 1}}; migraphx::shape input{migraphx::shape::float_type, {2, 1, 1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 1, 0, 0}}, expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 1, 0, 0}},
migraphx::make_op("multibroadcast", {{"output_lens", lens}}), migraphx::make_op("multibroadcast", {{"out_lens", lens}}),
input); input);
} }
{ {
std::vector<std::size_t> lens{4, 2, 5, 3}; std::vector<std::size_t> lens{4, 2, 5, 3};
migraphx::shape input{migraphx::shape::float_type, {5, 1}}; migraphx::shape input{migraphx::shape::float_type, {5, 1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 0, 1, 0}}, expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 0, 1, 0}},
migraphx::make_op("multibroadcast", {{"output_lens", lens}}), migraphx::make_op("multibroadcast", {{"out_lens", lens}}),
input); input);
} }
{ {
std::vector<std::size_t> lens{4, 2, 5, 3}; std::vector<std::size_t> lens{4, 2, 5, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 1, 1, 1}}; migraphx::shape input{migraphx::shape::float_type, {4, 1, 1, 1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {1, 0, 0, 0}}, expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {1, 0, 0, 0}},
migraphx::make_op("multibroadcast", {{"output_lens", lens}}), migraphx::make_op("multibroadcast", {{"out_lens", lens}}),
input); input);
} }
{ {
std::vector<std::size_t> lens{4, 2, 5, 3}; std::vector<std::size_t> lens{4, 2, 5, 3};
migraphx::shape input{migraphx::shape::float_type, {3}}; migraphx::shape input{migraphx::shape::float_type, {3}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 0, 0, 1}}, expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 0, 0, 1}},
migraphx::make_op("multibroadcast", {{"output_lens", lens}}), migraphx::make_op("multibroadcast", {{"out_lens", lens}}),
input); input);
} }
{ {
std::vector<std::size_t> lens{4, 4, 1, 3}; std::vector<std::size_t> lens{4, 4, 1, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 1, 3}}; migraphx::shape input{migraphx::shape::float_type, {4, 1, 3}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 3, 3, 1}}, expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 3, 3, 1}},
migraphx::make_op("multibroadcast", {{"output_lens", lens}}), migraphx::make_op("multibroadcast", {{"out_lens", lens}}),
input); input);
} }
{ {
std::vector<std::size_t> lens{4, 1, 1, 3}; std::vector<std::size_t> lens{4, 1, 1, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 1, 1, 1}}; migraphx::shape input{migraphx::shape::float_type, {4, 1, 1, 1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {1, 1, 1, 0}}, expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {1, 1, 1, 0}},
migraphx::make_op("multibroadcast", {{"output_lens", lens}}), migraphx::make_op("multibroadcast", {{"out_lens", lens}}),
input); input);
} }
{ {
std::vector<std::size_t> lens{4, 1, 3}; std::vector<std::size_t> lens{4, 1, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 1, 1, 1}}; migraphx::shape input{migraphx::shape::float_type, {4, 1, 1, 1}};
throws_shape(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), input); throws_shape(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), input);
} }
{ {
std::vector<std::size_t> lens{4, 1, 3}; std::vector<std::size_t> lens{4, 1, 3};
migraphx::shape input{migraphx::shape::float_type, {}}; migraphx::shape input{migraphx::shape::float_type, {}};
throws_shape(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), input); throws_shape(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), input);
} }
{ {
std::vector<std::size_t> lens{2, 3, 4, 5}; std::vector<std::size_t> lens{2, 3, 4, 5};
migraphx::shape input{migraphx::shape::float_type, {3, 4}}; migraphx::shape input{migraphx::shape::float_type, {3, 4}};
throws_shape(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), input); throws_shape(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), input);
} }
{ {
std::vector<std::size_t> lens{2, 3, 4, 5}; std::vector<std::size_t> lens{2, 3, 4, 5};
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4}};
throws_shape(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), input); throws_shape(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), input);
} }
} }
...@@ -1558,10 +1558,10 @@ TEST_CASE(transpose_shape) ...@@ -1558,10 +1558,10 @@ TEST_CASE(transpose_shape)
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 2}}; migraphx::shape input{migraphx::shape::float_type, {2, 2}};
migraphx::shape output{migraphx::shape::float_type, {2, 2}, {1, 2}}; migraphx::shape output{migraphx::shape::float_type, {2, 2}, {1, 2}};
expect_shape(input, migraphx::make_op("transpose", {{"dims", {0, 1}}}), input); expect_shape(input, migraphx::make_op("transpose", {{"permutation", {0, 1}}}), input);
expect_shape(output, migraphx::make_op("transpose", {{"dims", {1, 0}}}), input); expect_shape(output, migraphx::make_op("transpose", {{"permutation", {1, 0}}}), input);
expect_shape(output, migraphx::make_op("transpose"), input); expect_shape(output, migraphx::make_op("transpose"), input);
throws_shape(migraphx::make_op("transpose", {{"dims", {1, 2}}}), input); throws_shape(migraphx::make_op("transpose", {{"permutation", {1, 2}}}), input);
} }
TEST_CASE(step_test) TEST_CASE(step_test)
......
...@@ -25,10 +25,10 @@ create_clip_op(migraphx::program& p, float max, float min, migraphx::instruction ...@@ -25,10 +25,10 @@ create_clip_op(migraphx::program& p, float max, float min, migraphx::instruction
auto input_lens = input->get_shape().lens(); auto input_lens = input->get_shape().lens();
auto max_val = mm->add_literal(max); auto max_val = mm->add_literal(max);
auto min_val = mm->add_literal(min); auto min_val = mm->add_literal(min);
max_val = mm->add_instruction( max_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
migraphx::make_op("multibroadcast", {{"output_lens", input_lens}}), max_val); max_val);
min_val = mm->add_instruction( min_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
migraphx::make_op("multibroadcast", {{"output_lens", input_lens}}), min_val); min_val);
return mm->add_instruction(migraphx::make_op("clip"), input, min_val, max_val); return mm->add_instruction(migraphx::make_op("clip"), input, min_val, max_val);
} }
...@@ -43,9 +43,9 @@ migraphx::instruction_ref create_clip_op(migraphx::instruction_ref insert_loc, ...@@ -43,9 +43,9 @@ migraphx::instruction_ref create_clip_op(migraphx::instruction_ref insert_loc,
auto max_val = mm->add_literal(max); auto max_val = mm->add_literal(max);
auto min_val = mm->add_literal(min); auto min_val = mm->add_literal(min);
max_val = mm->insert_instruction( max_val = mm->insert_instruction(
insert_loc, migraphx::make_op("multibroadcast", {{"output_lens", input_lens}}), max_val); insert_loc, migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), max_val);
min_val = mm->insert_instruction( min_val = mm->insert_instruction(
insert_loc, migraphx::make_op("multibroadcast", {{"output_lens", input_lens}}), min_val); insert_loc, migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), min_val);
return mm->insert_instruction(insert_loc, migraphx::make_op("clip"), input, min_val, max_val); return mm->insert_instruction(insert_loc, migraphx::make_op("clip"), input, min_val, max_val);
} }
......
...@@ -668,7 +668,7 @@ TEST_CASE(matmul_vm) ...@@ -668,7 +668,7 @@ TEST_CASE(matmul_vm)
auto al = mm->add_literal(migraphx::literal{a_shape, a}); auto al = mm->add_literal(migraphx::literal{a_shape, a});
auto ual = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), al); auto ual = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), al);
auto bual = mm->add_instruction( auto bual = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {3, 1, 6}}}), ual); migraphx::make_op("multibroadcast", {{"out_lens", {3, 1, 6}}}), ual);
migraphx::shape b_shape{migraphx::shape::float_type, {3, 6, 4}}; migraphx::shape b_shape{migraphx::shape::float_type, {3, 6, 4}};
auto bl = mm->add_literal(migraphx::literal{b_shape, b}); auto bl = mm->add_literal(migraphx::literal{b_shape, b});
mm->add_instruction(migraphx::make_op("dot"), bual, bl); mm->add_instruction(migraphx::make_op("dot"), bual, bl);
...@@ -715,7 +715,7 @@ TEST_CASE(matmul_vm) ...@@ -715,7 +715,7 @@ TEST_CASE(matmul_vm)
auto al = mm->add_literal(migraphx::literal{a_shape, a}); auto al = mm->add_literal(migraphx::literal{a_shape, a});
auto ual = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), al); auto ual = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), al);
auto bual = mm->add_instruction( auto bual = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {3, 1, 6}}}), ual); migraphx::make_op("multibroadcast", {{"out_lens", {3, 1, 6}}}), ual);
migraphx::shape b_shape{migraphx::shape::float_type, {3, 6, 4}}; migraphx::shape b_shape{migraphx::shape::float_type, {3, 6, 4}};
auto bl = mm->add_literal(migraphx::literal{b_shape, b}); auto bl = mm->add_literal(migraphx::literal{b_shape, b});
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 0.21f}}), bual, bl); mm->add_instruction(migraphx::make_op("dot", {{"alpha", 0.21f}}), bual, bl);
...@@ -837,7 +837,7 @@ TEST_CASE(matmul_mv) ...@@ -837,7 +837,7 @@ TEST_CASE(matmul_mv)
auto bl = mm->add_literal(migraphx::literal{b_shape, b}); auto bl = mm->add_literal(migraphx::literal{b_shape, b});
auto ubl = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), bl); auto ubl = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), bl);
auto bubl = mm->add_instruction( auto bubl = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2, 5, 1}}}), ubl); migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 5, 1}}}), ubl);
mm->add_instruction(migraphx::make_op("dot"), al, bubl); mm->add_instruction(migraphx::make_op("dot"), al, bubl);
std::vector<float> gold = {-0.792717, std::vector<float> gold = {-0.792717,
6.33595, 6.33595,
...@@ -897,7 +897,7 @@ TEST_CASE(matmul_mm1) ...@@ -897,7 +897,7 @@ TEST_CASE(matmul_mm1)
migraphx::shape b_shape{migraphx::shape::float_type, {5, 3}}; migraphx::shape b_shape{migraphx::shape::float_type, {5, 3}};
auto bl = mm->add_literal(migraphx::literal{b_shape, b}); auto bl = mm->add_literal(migraphx::literal{b_shape, b});
auto bbl = mm->add_instruction( auto bbl = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2, 5, 3}}}), bl); migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 5, 3}}}), bl);
mm->add_instruction(migraphx::make_op("dot"), al, bbl); mm->add_instruction(migraphx::make_op("dot"), al, bbl);
std::vector<float> gold = {-0.386828, 0.187735, -0.22822, -0.148057, 2.015, -2.56938, std::vector<float> gold = {-0.386828, 0.187735, -0.22822, -0.148057, 2.015, -2.56938,
-0.782212, 1.9459, 0.927426, -2.44907, 2.40531, 2.30232, -0.782212, 1.9459, 0.927426, -2.44907, 2.40531, 2.30232,
...@@ -946,7 +946,7 @@ TEST_CASE(matmul_mm1) ...@@ -946,7 +946,7 @@ TEST_CASE(matmul_mm1)
migraphx::shape a_shape{migraphx::shape::float_type, {3, 4}}; migraphx::shape a_shape{migraphx::shape::float_type, {3, 4}};
auto al = mm->add_literal(migraphx::literal{a_shape, a}); auto al = mm->add_literal(migraphx::literal{a_shape, a});
auto bal = mm->add_instruction( auto bal = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 3, 3, 4}}}), al); migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 3, 4}}}), al);
migraphx::shape b_shape{migraphx::shape::float_type, {2, 3, 4, 3}}; migraphx::shape b_shape{migraphx::shape::float_type, {2, 3, 4, 3}};
auto bl = mm->add_literal(migraphx::literal{b_shape, b}); auto bl = mm->add_literal(migraphx::literal{b_shape, b});
mm->add_instruction(migraphx::make_op("dot"), bal, bl); mm->add_instruction(migraphx::make_op("dot"), bal, bl);
...@@ -994,7 +994,7 @@ TEST_CASE(matmul_mm2) ...@@ -994,7 +994,7 @@ TEST_CASE(matmul_mm2)
migraphx::shape b_shape{migraphx::shape::float_type, {2, 1, 5, 3}}; migraphx::shape b_shape{migraphx::shape::float_type, {2, 1, 5, 3}};
auto bl = mm->add_literal(migraphx::literal{b_shape, b}); auto bl = mm->add_literal(migraphx::literal{b_shape, b});
auto bbl = mm->add_instruction( auto bbl = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2, 5, 3}}}), bl); migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 5, 3}}}), bl);
std::vector<float> gold = { std::vector<float> gold = {
0.70574512, -2.80915314, -1.57644969, 1.75415381, -3.13303087, -1.00150259, 0.70574512, -2.80915314, -1.57644969, 1.75415381, -3.13303087, -1.00150259,
-0.18675123, -0.23349122, -0.12357225, 0.82911538, 1.37473744, -1.11709934, -0.18675123, -0.23349122, -0.12357225, 0.82911538, 1.37473744, -1.11709934,
...@@ -1030,11 +1030,11 @@ TEST_CASE(matmul_mm2) ...@@ -1030,11 +1030,11 @@ TEST_CASE(matmul_mm2)
migraphx::shape a_shape{migraphx::shape::float_type, {1, 2, 3, 5}}; migraphx::shape a_shape{migraphx::shape::float_type, {1, 2, 3, 5}};
auto al = mm->add_literal(migraphx::literal{a_shape, a}); auto al = mm->add_literal(migraphx::literal{a_shape, a});
auto bal = mm->add_instruction( auto bal = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2, 3, 5}}}), al); migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 3, 5}}}), al);
migraphx::shape b_shape{migraphx::shape::float_type, {2, 1, 5, 3}}; migraphx::shape b_shape{migraphx::shape::float_type, {2, 1, 5, 3}};
auto bl = mm->add_literal(migraphx::literal{b_shape, b}); auto bl = mm->add_literal(migraphx::literal{b_shape, b});
auto bbl = mm->add_instruction( auto bbl = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2, 5, 3}}}), bl); migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 5, 3}}}), bl);
mm->add_instruction(migraphx::make_op("dot"), bal, bbl); mm->add_instruction(migraphx::make_op("dot"), bal, bbl);
std::vector<float> gold = { std::vector<float> gold = {
1.64924590e+00, 2.84575831e+00, 1.07340773e+00, 2.19817080e-01, -1.87873283e+00, 1.64924590e+00, 2.84575831e+00, 1.07340773e+00, 2.19817080e-01, -1.87873283e+00,
...@@ -1132,7 +1132,7 @@ TEST_CASE(matmul_mm2) ...@@ -1132,7 +1132,7 @@ TEST_CASE(matmul_mm2)
migraphx::shape b_shape{migraphx::shape::float_type, {2, 4, 5}}; migraphx::shape b_shape{migraphx::shape::float_type, {2, 4, 5}};
auto bl = mm->add_literal(migraphx::literal{b_shape, b}); auto bl = mm->add_literal(migraphx::literal{b_shape, b});
auto bbl = mm->add_instruction( auto bbl = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2, 4, 5}}}), bl); migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 4, 5}}}), bl);
mm->add_instruction(migraphx::make_op("dot"), al, bbl); mm->add_instruction(migraphx::make_op("dot"), al, bbl);
std::vector<float> gold = { std::vector<float> gold = {
-1.08585245, 0.39575611, 0.33947977, -0.86339678, 1.50710753, 0.05646156, -1.08585245, 0.39575611, 0.33947977, -0.86339678, 1.50710753, 0.05646156,
...@@ -1192,9 +1192,10 @@ TEST_CASE(quant_dot_2args_multi4) ...@@ -1192,9 +1192,10 @@ TEST_CASE(quant_dot_2args_multi4)
std::iota(data1.begin(), data1.end(), 0); std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0); std::iota(data2.begin(), data2.end(), 0);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1}); auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto tl1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l1); auto tl1 =
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2}); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
mm->add_instruction(migraphx::make_op("quant_dot"), tl1, l2); mm->add_instruction(migraphx::make_op("quant_dot"), tl1, l2);
std::vector<int> gold = {448, 472, 496, 520, 544, 568, 592, 616, 496, 524, 552, std::vector<int> gold = {448, 472, 496, 520, 544, 568, 592, 616, 496, 524, 552,
...@@ -1219,9 +1220,10 @@ TEST_CASE(quant_dot_2args_multi4) ...@@ -1219,9 +1220,10 @@ TEST_CASE(quant_dot_2args_multi4)
std::iota(data1.begin(), data1.end(), 0); std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0); std::iota(data2.begin(), data2.end(), 0);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1}); auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2}); auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto tl2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l2); auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
mm->add_instruction(migraphx::make_op("quant_dot"), l1, tl2); mm->add_instruction(migraphx::make_op("quant_dot"), l1, tl2);
std::vector<int> gold = {14, 38, 62, 86, 110, 134, 158, 182, 38, 126, 214, std::vector<int> gold = {14, 38, 62, 86, 110, 134, 158, 182, 38, 126, 214,
...@@ -1246,10 +1248,12 @@ TEST_CASE(quant_dot_2args_multi4) ...@@ -1246,10 +1248,12 @@ TEST_CASE(quant_dot_2args_multi4)
std::iota(data1.begin(), data1.end(), 0); std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0); std::iota(data2.begin(), data2.end(), 0);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1}); auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto tl1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l1); auto tl1 =
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2}); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto tl2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l2); auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
mm->add_instruction(migraphx::make_op("quant_dot"), tl1, tl2); mm->add_instruction(migraphx::make_op("quant_dot"), tl1, tl2);
std::vector<int> gold = {56, 152, 248, 344, 440, 536, 632, 728, 62, 174, 286, std::vector<int> gold = {56, 152, 248, 344, 440, 536, 632, 728, 62, 174, 286,
...@@ -1302,9 +1306,10 @@ TEST_CASE(quant_dot_2args_general) ...@@ -1302,9 +1306,10 @@ TEST_CASE(quant_dot_2args_general)
std::iota(data1.begin(), data1.end(), 0); std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0); std::iota(data2.begin(), data2.end(), 0);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1}); auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto tl1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l1); auto tl1 =
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2}); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
mm->add_instruction(migraphx::make_op("quant_dot"), tl1, l2); mm->add_instruction(migraphx::make_op("quant_dot"), tl1, l2);
std::vector<int> gold = { std::vector<int> gold = {
...@@ -1328,9 +1333,10 @@ TEST_CASE(quant_dot_2args_general) ...@@ -1328,9 +1333,10 @@ TEST_CASE(quant_dot_2args_general)
std::iota(data1.begin(), data1.end(), 0); std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0); std::iota(data2.begin(), data2.end(), 0);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1}); auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2}); auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto tl2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l2); auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 2}}), l1, tl2); mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 2}}), l1, tl2);
std::vector<int> gold = { std::vector<int> gold = {
...@@ -1354,10 +1360,12 @@ TEST_CASE(quant_dot_2args_general) ...@@ -1354,10 +1360,12 @@ TEST_CASE(quant_dot_2args_general)
std::iota(data1.begin(), data1.end(), 0); std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0); std::iota(data2.begin(), data2.end(), 0);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1}); auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto tl1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l1); auto tl1 =
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2}); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto tl2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l2); auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 3}, {"beta", 2}}), tl1, tl2); mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 3}, {"beta", 2}}), tl1, tl2);
std::vector<int> gold = { std::vector<int> gold = {
...@@ -1446,10 +1454,11 @@ TEST_CASE(quant_dot_3args_general) ...@@ -1446,10 +1454,11 @@ TEST_CASE(quant_dot_3args_general)
std::iota(data2.begin(), data2.end(), 0); std::iota(data2.begin(), data2.end(), 0);
std::iota(data3.begin(), data3.end(), 2); std::iota(data3.begin(), data3.end(), 2);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1}); auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto tl1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l1); auto tl1 =
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2}); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3}); auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3});
mm->add_instruction( mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 3}}), tl1, l2, l3); migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 3}}), tl1, l2, l3);
...@@ -1477,10 +1486,11 @@ TEST_CASE(quant_dot_3args_general) ...@@ -1477,10 +1486,11 @@ TEST_CASE(quant_dot_3args_general)
std::iota(data2.begin(), data2.end(), 0); std::iota(data2.begin(), data2.end(), 0);
std::iota(data3.begin(), data3.end(), 2); std::iota(data3.begin(), data3.end(), 2);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1}); auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2}); auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto tl2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l2); auto tl2 =
auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3}); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3});
mm->add_instruction( mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 2}, {"beta", 3}}), l1, tl2, l3); migraphx::make_op("quant_dot", {{"alpha", 2}, {"beta", 3}}), l1, tl2, l3);
...@@ -1508,11 +1518,13 @@ TEST_CASE(quant_dot_3args_general) ...@@ -1508,11 +1518,13 @@ TEST_CASE(quant_dot_3args_general)
std::iota(data2.begin(), data2.end(), 0); std::iota(data2.begin(), data2.end(), 0);
std::iota(data3.begin(), data3.end(), 2); std::iota(data3.begin(), data3.end(), 2);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1}); auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto tl1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l1); auto tl1 =
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2}); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto tl2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l2); auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3}); auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3});
mm->add_instruction( mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 3}, {"beta", 2}}), tl1, tl2, l3); migraphx::make_op("quant_dot", {{"alpha", 3}, {"beta", 2}}), tl1, tl2, l3);
...@@ -1577,12 +1589,12 @@ TEST_CASE(quant_dot_3args_batch) ...@@ -1577,12 +1589,12 @@ TEST_CASE(quant_dot_3args_batch)
std::iota(data2.begin(), data2.end(), 0); std::iota(data2.begin(), data2.end(), 0);
std::iota(data3.begin(), data3.end(), 2); std::iota(data3.begin(), data3.end(), 2);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1}); auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto tl1 = auto tl1 = mm->add_instruction(
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l1); migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l1);
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2}); auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto tl2 = auto tl2 = mm->add_instruction(
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l2); migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2);
auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3}); auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3});
mm->add_instruction( mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 2}, {"beta", 3}}), tl1, tl2, l3); migraphx::make_op("quant_dot", {{"alpha", 2}, {"beta", 3}}), tl1, tl2, l3);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment