Unverified Commit 0d2606bb authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Change attributes names to be more consistent and reflect better meaning (#916)

* rename broadcast and multibroadcast output_lens attribute to out_lens attribute, and change tests and source code to reflect the same

* change the reshape attribute from dims to out_lens

* change transpose attribute's name from dims to perm to reflect better meaning

* use permutation instead of perm for transpose

clang formaating

* use dims instead of out_lens for reshape

clang formatting
parent d8a2a933
...@@ -93,9 +93,9 @@ instruction_ref insert_quant_ins(module& modl, ...@@ -93,9 +93,9 @@ instruction_ref insert_quant_ins(module& modl,
auto max_clip = modl.add_literal(127.0f); auto max_clip = modl.add_literal(127.0f);
auto min_clip = modl.add_literal(-128.0f); auto min_clip = modl.add_literal(-128.0f);
max_clip = modl.insert_instruction( max_clip = modl.insert_instruction(
insert_loc, make_op("multibroadcast", {{"output_lens", rounded_lens}}), max_clip); insert_loc, make_op("multibroadcast", {{"out_lens", rounded_lens}}), max_clip);
min_clip = modl.insert_instruction( min_clip = modl.insert_instruction(
insert_loc, make_op("multibroadcast", {{"output_lens", rounded_lens}}), min_clip); insert_loc, make_op("multibroadcast", {{"out_lens", rounded_lens}}), min_clip);
auto clipped_ins = auto clipped_ins =
modl.insert_instruction(insert_loc, make_op("clip"), rounded_ins, min_clip, max_clip); modl.insert_instruction(insert_loc, make_op("clip"), rounded_ins, min_clip, max_clip);
quant_ins = modl.insert_instruction( quant_ins = modl.insert_instruction(
......
...@@ -241,11 +241,11 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward, ...@@ -241,11 +241,11 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
// squeeze and transpose w // squeeze and transpose w
std::vector<int64_t> perm{1, 0}; std::vector<int64_t> perm{1, 0};
auto sw = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w); auto sw = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
auto tran_sw = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), sw); auto tran_sw = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sw);
// squeeze and transpose r // squeeze and transpose r
auto sr = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r); auto sr = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
auto tran_sr = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), sr); auto tran_sr = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sr);
// initial hidden state // initial hidden state
auto sih = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih); auto sih = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
...@@ -263,7 +263,7 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward, ...@@ -263,7 +263,7 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
ins, make_op("slice", {{"axes", {0}}, {"starts", {hs}}, {"ends", {2 * hs}}}), sbias); ins, make_op("slice", {{"axes", {0}}, {"starts", {hs}}, {"ends", {2 * hs}}}), sbias);
auto wrb = prog.insert_instruction(ins, make_op("add"), wb, rb); auto wrb = prog.insert_instruction(ins, make_op("add"), wb, rb);
bb = prog.insert_instruction( bb = prog.insert_instruction(
ins, make_op("broadcast", {{"axis", 1}, {"dims", sih_lens}}), wrb); ins, make_op("broadcast", {{"axis", 1}, {"out_lens", sih_lens}}), wrb);
} }
instruction_ref hidden_out = prog.end(); instruction_ref hidden_out = prog.end();
...@@ -565,17 +565,17 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward, ...@@ -565,17 +565,17 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
// w matrix squeeze to 2-dim and do a transpose // w matrix squeeze to 2-dim and do a transpose
std::vector<int64_t> perm{1, 0}; std::vector<int64_t> perm{1, 0};
auto sw = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w); auto sw = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
auto tw = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), sw); auto tw = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sw);
// r slide to two part, zr and h // r slide to two part, zr and h
auto sr = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r); auto sr = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
auto rzr = prog.insert_instruction( auto rzr = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {2 * hs}}}), sr); ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {2 * hs}}}), sr);
auto trzr = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), rzr); auto trzr = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), rzr);
auto rh = prog.insert_instruction( auto rh = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}), sr); ins, make_op("slice", {{"axes", {0}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}), sr);
auto trh = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), rh); auto trh = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), rh);
// initial states // initial states
auto sih = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih); auto sih = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
...@@ -592,7 +592,7 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward, ...@@ -592,7 +592,7 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {3 * hs}}}), sbias); ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {3 * hs}}}), sbias);
bwb = prog.insert_instruction( bwb = prog.insert_instruction(
ins, ins,
make_op("broadcast", {{"axis", 1}, {"dims", {bs, static_cast<size_t>(3 * hs)}}}), make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, static_cast<size_t>(3 * hs)}}}),
wb); wb);
auto rb_zr = prog.insert_instruction( auto rb_zr = prog.insert_instruction(
...@@ -605,11 +605,11 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward, ...@@ -605,11 +605,11 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
sbias); sbias);
brb_zr = prog.insert_instruction( brb_zr = prog.insert_instruction(
ins, ins,
make_op("broadcast", {{"axis", 1}, {"dims", {bs, static_cast<size_t>(2 * hs)}}}), make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, static_cast<size_t>(2 * hs)}}}),
rb_zr); rb_zr);
brb_h = prog.insert_instruction( brb_h = prog.insert_instruction(
ins, ins,
make_op("broadcast", {{"axis", 1}, {"dims", {bs, static_cast<size_t>(hs)}}}), make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, static_cast<size_t>(hs)}}}),
rb_h); rb_h);
} }
...@@ -1038,11 +1038,11 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -1038,11 +1038,11 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
std::vector<int64_t> perm{1, 0}; std::vector<int64_t> perm{1, 0};
// w matrix, squeeze and transpose // w matrix, squeeze and transpose
auto sw = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w); auto sw = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
auto tsw = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), sw); auto tsw = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sw);
// r matrix, squeeze and transpose // r matrix, squeeze and transpose
auto sr = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r); auto sr = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
auto tsr = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), sr); auto tsr = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sr);
// initial hidden state // initial hidden state
auto sih = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih); auto sih = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
...@@ -1067,7 +1067,7 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -1067,7 +1067,7 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
wrb = prog.insert_instruction( wrb = prog.insert_instruction(
ins, ins,
make_op("broadcast", {{"axis", 1}, {"dims", {bs, 4 * static_cast<size_t>(hs)}}}), make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, 4 * static_cast<size_t>(hs)}}}),
ub_wrb); ub_wrb);
} }
...@@ -1081,17 +1081,17 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -1081,17 +1081,17 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
auto pphi = prog.insert_instruction( auto pphi = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {hs}}}), spph); ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {hs}}}), spph);
pphi_brcst = prog.insert_instruction( pphi_brcst = prog.insert_instruction(
ins, make_op("broadcast", {{"axis", 1}, {"dims", ic_lens}}), pphi); ins, make_op("broadcast", {{"axis", 1}, {"out_lens", ic_lens}}), pphi);
auto ppho = prog.insert_instruction( auto ppho = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {hs}}, {"ends", {2 * hs}}}), spph); ins, make_op("slice", {{"axes", {0}}, {"starts", {hs}}, {"ends", {2 * hs}}}), spph);
ppho_brcst = prog.insert_instruction( ppho_brcst = prog.insert_instruction(
ins, make_op("broadcast", {{"axis", 1}, {"dims", ic_lens}}), ppho); ins, make_op("broadcast", {{"axis", 1}, {"out_lens", ic_lens}}), ppho);
auto pphf = prog.insert_instruction( auto pphf = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}), spph); ins, make_op("slice", {{"axes", {0}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}), spph);
pphf_brcst = prog.insert_instruction( pphf_brcst = prog.insert_instruction(
ins, make_op("broadcast", {{"axis", 1}, {"dims", ic_lens}}), pphf); ins, make_op("broadcast", {{"axis", 1}, {"out_lens", ic_lens}}), pphf);
} }
long seq_len = static_cast<long>(get_seq_len(prog, seq, seq_lens)); long seq_len = static_cast<long>(get_seq_len(prog, seq, seq_lens));
......
...@@ -55,7 +55,7 @@ struct find_mul_conv ...@@ -55,7 +55,7 @@ struct find_mul_conv
auto new_a = p.insert_instruction( auto new_a = p.insert_instruction(
ins, ins,
make_op("broadcast", {{"axis", 0}, {"dims", w_ins->get_shape().lens()}}), make_op("broadcast", {{"axis", 0}, {"out_lens", w_ins->get_shape().lens()}}),
a_ins->inputs().front()); a_ins->inputs().front());
auto new_mul = p.insert_instruction(ins, make_op("mul"), new_a, w_ins); auto new_mul = p.insert_instruction(ins, make_op("mul"), new_a, w_ins);
auto new_conv = p.insert_instruction( auto new_conv = p.insert_instruction(
...@@ -120,7 +120,7 @@ struct find_mul_slice_conv ...@@ -120,7 +120,7 @@ struct find_mul_slice_conv
auto new_a = p.insert_instruction( auto new_a = p.insert_instruction(
ins, ins,
make_op("broadcast", {{"axis", 0}, {"dims", slice_w_ins->get_shape().lens()}}), make_op("broadcast", {{"axis", 0}, {"out_lens", slice_w_ins->get_shape().lens()}}),
a_ins->inputs().front()); a_ins->inputs().front());
auto new_mul = p.insert_instruction(ins, make_op("mul"), new_a, slice_w_ins); auto new_mul = p.insert_instruction(ins, make_op("mul"), new_a, slice_w_ins);
...@@ -989,8 +989,8 @@ struct find_split_transpose ...@@ -989,8 +989,8 @@ struct find_split_transpose
} }
// insert an transpose instruction // insert an transpose instruction
auto tr = auto tr = p.insert_instruction(
p.insert_instruction(std::next(input), make_op("transpose", {{"dims", perm}}), input); std::next(input), make_op("transpose", {{"permutation", perm}}), input);
// compute the axis in the slice // compute the axis in the slice
auto axis = any_cast<op::slice>(slc->get_operator()).axes.front(); auto axis = any_cast<op::slice>(slc->get_operator()).axes.front();
......
...@@ -96,7 +96,7 @@ struct match_find_quantizable_ops ...@@ -96,7 +96,7 @@ struct match_find_quantizable_ops
auto lens = dq->get_shape().lens(); auto lens = dq->get_shape().lens();
auto scale_mb = auto scale_mb =
m.insert_instruction(qop, make_op("multibroadcast", {{"output_lens", lens}}), dq_scale); m.insert_instruction(qop, make_op("multibroadcast", {{"out_lens", lens}}), dq_scale);
dq = m.insert_instruction(qop, make_op("dequantizelinear"), dq, scale_mb); dq = m.insert_instruction(qop, make_op("dequantizelinear"), dq, scale_mb);
m.replace_instruction(qop, dq); m.replace_instruction(qop, dq);
} }
......
...@@ -153,7 +153,8 @@ struct find_transpose ...@@ -153,7 +153,8 @@ struct find_transpose
} }
else else
{ {
p.replace_instruction(ins, make_op("transpose", {{"dims", dims}}), t->inputs().front()); p.replace_instruction(
ins, make_op("transpose", {{"permutation", dims}}), t->inputs().front());
} }
} }
}; };
...@@ -278,10 +279,12 @@ struct find_concat_transpose ...@@ -278,10 +279,12 @@ struct find_concat_transpose
std::vector<instruction_ref> inputs; std::vector<instruction_ref> inputs;
std::transform( std::transform(
ins->inputs().begin(), ins->inputs().end(), std::back_inserter(inputs), [&](auto i) { ins->inputs().begin(), ins->inputs().end(), std::back_inserter(inputs), [&](auto i) {
return p.insert_instruction(ins, make_op("transpose", {{"dims", permutation}}), i); return p.insert_instruction(
ins, make_op("transpose", {{"permutation", permutation}}), i);
}); });
auto concat = p.insert_instruction(ins, op, inputs); auto concat = p.insert_instruction(ins, op, inputs);
auto t = p.insert_instruction(ins, make_op("transpose", {{"dims", ipermutation}}), concat); auto t = p.insert_instruction(
ins, make_op("transpose", {{"permutation", ipermutation}}), concat);
assert(ins->get_shape().lens() == t->get_shape().lens()); assert(ins->get_shape().lens() == t->get_shape().lens());
p.replace_instruction(ins, t); p.replace_instruction(ins, t);
} }
...@@ -418,7 +421,7 @@ struct find_resize ...@@ -418,7 +421,7 @@ struct find_resize
auto rsp_data = p.insert_instruction( auto rsp_data = p.insert_instruction(
ins_rsp, migraphx::make_op("reshape", {{"dims", in_dims}}), in_rsp); ins_rsp, migraphx::make_op("reshape", {{"dims", in_dims}}), in_rsp);
auto mb_rsp = p.insert_instruction( auto mb_rsp = p.insert_instruction(
ins_rsp, migraphx::make_op("multibroadcast", {{"output_lens", out_dims}}), rsp_data); ins_rsp, migraphx::make_op("multibroadcast", {{"out_lens", out_dims}}), rsp_data);
auto std_mb = p.insert_instruction(ins, migraphx::make_op("contiguous"), mb_rsp); auto std_mb = p.insert_instruction(ins, migraphx::make_op("contiguous"), mb_rsp);
std::vector<int64_t> rsp_dims(out_lens.begin(), out_lens.end()); std::vector<int64_t> rsp_dims(out_lens.begin(), out_lens.end());
p.replace_instruction(ins, migraphx::make_op("reshape", {{"dims", rsp_dims}}), std_mb); p.replace_instruction(ins, migraphx::make_op("reshape", {{"dims", rsp_dims}}), std_mb);
......
...@@ -55,7 +55,8 @@ static std::vector<instruction_ref> pad_inputs(module& m, instruction_ref ins) ...@@ -55,7 +55,8 @@ static std::vector<instruction_ref> pad_inputs(module& m, instruction_ref ins)
auto t_in = in0->inputs().front(); auto t_in = in0->inputs().front();
auto p_in = pad_ins(m, t_in, offset); auto p_in = pad_ins(m, t_in, offset);
auto dims = val.at("dims").to_vector<int64_t>(); auto dims = val.at("dims").to_vector<int64_t>();
auto r_in = m.insert_instruction(ins, make_op("transpose", {{"dims", dims}}), p_in); auto r_in =
m.insert_instruction(ins, make_op("transpose", {{"permutation", dims}}), p_in);
ret_inputs.push_back(r_in); ret_inputs.push_back(r_in);
} }
else else
...@@ -85,7 +86,8 @@ static std::vector<instruction_ref> pad_inputs(module& m, instruction_ref ins) ...@@ -85,7 +86,8 @@ static std::vector<instruction_ref> pad_inputs(module& m, instruction_ref ins)
auto t_in = in1->inputs().front(); auto t_in = in1->inputs().front();
auto p_in = pad_ins(m, t_in, offset); auto p_in = pad_ins(m, t_in, offset);
auto dims = val.at("dims").to_vector<int64_t>(); auto dims = val.at("dims").to_vector<int64_t>();
auto r_in = m.insert_instruction(ins, make_op("transpose", {{"dims", dims}}), p_in); auto r_in =
m.insert_instruction(ins, make_op("transpose", {{"permutation", dims}}), p_in);
ret_inputs.push_back(r_in); ret_inputs.push_back(r_in);
} }
else else
......
...@@ -20,7 +20,8 @@ struct parse_biasadd : op_parser<parse_biasadd> ...@@ -20,7 +20,8 @@ struct parse_biasadd : op_parser<parse_biasadd>
uint64_t axis = 1; // assume output of previous layer is in NCHW (broadcast on channel) uint64_t axis = 1; // assume output of previous layer is in NCHW (broadcast on channel)
auto l0 = info.add_instruction( auto l0 = info.add_instruction(
make_op("broadcast", {{"axis", axis}, {"dims", args[0]->get_shape().lens()}}), args[1]); make_op("broadcast", {{"axis", axis}, {"out_lens", args[0]->get_shape().lens()}}),
args[1]);
return info.add_instruction(make_op("add"), args[0], l0); return info.add_instruction(make_op("add"), args[0], l0);
} }
}; };
......
...@@ -46,9 +46,11 @@ struct parse_matmul : op_parser<parse_matmul> ...@@ -46,9 +46,11 @@ struct parse_matmul : op_parser<parse_matmul>
// swap the last two elements // swap the last two elements
std::iter_swap(perm.end() - 1, perm.end() - 2); std::iter_swap(perm.end() - 1, perm.end() - 2);
auto l1 = (transa) ? info.add_instruction(make_op("transpose", {{"dims", perm}}), args[0]) auto l1 = (transa)
? info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[0])
: args[0]; : args[0];
auto l2 = (transb) ? info.add_instruction(make_op("transpose", {{"dims", perm}}), args[1]) auto l2 = (transb)
? info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[1])
: args[1]; : args[1];
return info.add_instruction(make_op("dot"), l1, l2); return info.add_instruction(make_op("dot"), l1, l2);
......
...@@ -23,9 +23,9 @@ struct parse_relu6 : op_parser<parse_relu6> ...@@ -23,9 +23,9 @@ struct parse_relu6 : op_parser<parse_relu6>
auto max_val = info.add_literal(6.0f); auto max_val = info.add_literal(6.0f);
min_val = min_val =
info.add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}), min_val); info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}), min_val);
max_val = max_val =
info.add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}), max_val); info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}), max_val);
return info.add_instruction(make_op("clip"), args.front(), min_val, max_val); return info.add_instruction(make_op("clip"), args.front(), min_val, max_val);
} }
}; };
......
...@@ -20,7 +20,7 @@ struct parse_transpose : op_parser<parse_transpose> ...@@ -20,7 +20,7 @@ struct parse_transpose : op_parser<parse_transpose>
auto perm = args[1]->eval().get<int32_t>().to_vector(); auto perm = args[1]->eval().get<int32_t>().to_vector();
std::vector<int64_t> dims(perm.begin(), perm.end()); std::vector<int64_t> dims(perm.begin(), perm.end());
return info.add_instruction(make_op("transpose", {{"dims", dims}}), args.front()); return info.add_instruction(make_op("transpose", {{"permutation", dims}}), args.front());
} }
}; };
......
...@@ -35,20 +35,20 @@ bool tf_parser::should_transpose(instruction_ref ins) const ...@@ -35,20 +35,20 @@ bool tf_parser::should_transpose(instruction_ref ins) const
instruction_ref tf_parser::to_nhwc(instruction_ref ins) const instruction_ref tf_parser::to_nhwc(instruction_ref ins) const
{ {
if(should_transpose(ins)) if(should_transpose(ins))
return mm->add_instruction(make_op("transpose", {{"dims", {0, 2, 3, 1}}}), ins); return mm->add_instruction(make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), ins);
return ins; return ins;
} }
instruction_ref tf_parser::to_nchw(instruction_ref ins) const instruction_ref tf_parser::to_nchw(instruction_ref ins) const
{ {
if(should_transpose(ins)) if(should_transpose(ins))
return mm->add_instruction(make_op("transpose", {{"dims", {0, 3, 1, 2}}}), ins); return mm->add_instruction(make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), ins);
return ins; return ins;
} }
instruction_ref tf_parser::to_kcxy(instruction_ref ins) const instruction_ref tf_parser::to_kcxy(instruction_ref ins) const
{ {
return mm->add_instruction(make_op("transpose", {{"dims", {3, 2, 0, 1}}}), ins); return mm->add_instruction(make_op("transpose", {{"permutation", {3, 2, 0, 1}}}), ins);
} }
std::vector<instruction_ref> tf_parser::to_nchw(const std::vector<instruction_ref>& args) const std::vector<instruction_ref> tf_parser::to_nchw(const std::vector<instruction_ref>& args) const
......
...@@ -40,7 +40,7 @@ TEST_CASE(after_literal_transpose) ...@@ -40,7 +40,7 @@ TEST_CASE(after_literal_transpose)
auto l = m.add_literal(get_2x2()); auto l = m.add_literal(get_2x2());
EXPECT(m.get_output_shapes().back().standard()); EXPECT(m.get_output_shapes().back().standard());
EXPECT(not m.get_output_shapes().back().transposed()); EXPECT(not m.get_output_shapes().back().transposed());
auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l); auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l);
m.add_instruction(pass_op{}, t); m.add_instruction(pass_op{}, t);
EXPECT(not m.get_output_shapes().back().standard()); EXPECT(not m.get_output_shapes().back().standard());
EXPECT(m.get_output_shapes().back().transposed()); EXPECT(m.get_output_shapes().back().transposed());
...@@ -58,7 +58,7 @@ TEST_CASE(after_literal_broadcast) ...@@ -58,7 +58,7 @@ TEST_CASE(after_literal_broadcast)
EXPECT(m.get_output_shapes().back().standard()); EXPECT(m.get_output_shapes().back().standard());
EXPECT(not m.get_output_shapes().back().broadcasted()); EXPECT(not m.get_output_shapes().back().broadcasted());
auto b = m.add_instruction( auto b = m.add_instruction(
migraphx::make_op("broadcast", {{"axis", 0}, {"dims", l1->get_shape().lens()}}), l2); migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", l1->get_shape().lens()}}), l2);
m.add_instruction(pass_op{}, b); m.add_instruction(pass_op{}, b);
EXPECT(not m.get_output_shapes().back().standard()); EXPECT(not m.get_output_shapes().back().standard());
EXPECT(m.get_output_shapes().back().broadcasted()); EXPECT(m.get_output_shapes().back().broadcasted());
...@@ -74,7 +74,7 @@ TEST_CASE(after_param_transpose) ...@@ -74,7 +74,7 @@ TEST_CASE(after_param_transpose)
auto l = m.add_parameter("2x2", {migraphx::shape::float_type, {2, 2}}); auto l = m.add_parameter("2x2", {migraphx::shape::float_type, {2, 2}});
EXPECT(m.get_output_shapes().back().standard()); EXPECT(m.get_output_shapes().back().standard());
EXPECT(not m.get_output_shapes().back().transposed()); EXPECT(not m.get_output_shapes().back().transposed());
auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l); auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l);
m.add_instruction(pass_op{}, t); m.add_instruction(pass_op{}, t);
EXPECT(not m.get_output_shapes().back().standard()); EXPECT(not m.get_output_shapes().back().standard());
EXPECT(m.get_output_shapes().back().transposed()); EXPECT(m.get_output_shapes().back().transposed());
...@@ -92,7 +92,7 @@ TEST_CASE(after_param_broadcast) ...@@ -92,7 +92,7 @@ TEST_CASE(after_param_broadcast)
EXPECT(m.get_output_shapes().back().standard()); EXPECT(m.get_output_shapes().back().standard());
EXPECT(not m.get_output_shapes().back().broadcasted()); EXPECT(not m.get_output_shapes().back().broadcasted());
auto b = m.add_instruction( auto b = m.add_instruction(
migraphx::make_op("broadcast", {{"axis", 0}, {"dims", l1->get_shape().lens()}}), l2); migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", l1->get_shape().lens()}}), l2);
m.add_instruction(pass_op{}, b); m.add_instruction(pass_op{}, b);
EXPECT(not m.get_output_shapes().back().standard()); EXPECT(not m.get_output_shapes().back().standard());
EXPECT(m.get_output_shapes().back().broadcasted()); EXPECT(m.get_output_shapes().back().broadcasted());
......
...@@ -50,8 +50,8 @@ TEST_CASE(dot_add_beta_float) ...@@ -50,8 +50,8 @@ TEST_CASE(dot_add_beta_float)
auto dot = m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y); auto dot = m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y);
auto beta = auto beta =
m2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {0.5}}); m2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {0.5}});
auto beta_broadcast = m2.add_instruction( auto beta_broadcast =
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2}}}), beta); m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2}}}), beta);
auto mul = m2.add_instruction(migraphx::make_op("mul"), z, beta_broadcast); auto mul = m2.add_instruction(migraphx::make_op("mul"), z, beta_broadcast);
auto add = m2.add_instruction(migraphx::make_op("add"), dot, mul); auto add = m2.add_instruction(migraphx::make_op("add"), dot, mul);
m2.add_instruction(migraphx::make_op("identity"), add); m2.add_instruction(migraphx::make_op("identity"), add);
...@@ -79,8 +79,8 @@ TEST_CASE(dot_add_beta_half) ...@@ -79,8 +79,8 @@ TEST_CASE(dot_add_beta_half)
auto dot = m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y); auto dot = m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y);
auto beta = auto beta =
m2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::half_type}, {0.5}}); m2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::half_type}, {0.5}});
auto beta_broadcast = m2.add_instruction( auto beta_broadcast =
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2}}}), beta); m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2}}}), beta);
auto mul = m2.add_instruction(migraphx::make_op("mul"), z, beta_broadcast); auto mul = m2.add_instruction(migraphx::make_op("mul"), z, beta_broadcast);
auto add = m2.add_instruction(migraphx::make_op("add"), dot, mul); auto add = m2.add_instruction(migraphx::make_op("add"), dot, mul);
m2.add_instruction(migraphx::make_op("identity"), add); m2.add_instruction(migraphx::make_op("identity"), add);
...@@ -108,8 +108,8 @@ TEST_CASE(dot_add_beta_double) ...@@ -108,8 +108,8 @@ TEST_CASE(dot_add_beta_double)
auto dot = m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y); auto dot = m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y);
auto beta = auto beta =
m2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::double_type}, {0.5}}); m2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::double_type}, {0.5}});
auto beta_broadcast = m2.add_instruction( auto beta_broadcast =
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2}}}), beta); m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2}}}), beta);
auto mul = m2.add_instruction(migraphx::make_op("mul"), z, beta_broadcast); auto mul = m2.add_instruction(migraphx::make_op("mul"), z, beta_broadcast);
auto add = m2.add_instruction(migraphx::make_op("add"), dot, mul); auto add = m2.add_instruction(migraphx::make_op("add"), dot, mul);
m2.add_instruction(migraphx::make_op("identity"), add); m2.add_instruction(migraphx::make_op("identity"), add);
...@@ -137,8 +137,8 @@ TEST_CASE(dot_add_beta_int) ...@@ -137,8 +137,8 @@ TEST_CASE(dot_add_beta_int)
auto dot = m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y); auto dot = m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y);
auto beta = auto beta =
m2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {0.5}}); m2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {0.5}});
auto beta_broadcast = m2.add_instruction( auto beta_broadcast =
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2}}}), beta); m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2}}}), beta);
auto mul = m2.add_instruction(migraphx::make_op("mul"), z, beta_broadcast); auto mul = m2.add_instruction(migraphx::make_op("mul"), z, beta_broadcast);
auto add = m2.add_instruction(migraphx::make_op("add"), dot, mul); auto add = m2.add_instruction(migraphx::make_op("add"), dot, mul);
m2.add_instruction(migraphx::make_op("identity"), add); m2.add_instruction(migraphx::make_op("identity"), add);
......
...@@ -17,7 +17,7 @@ TEST_CASE(standard_op) ...@@ -17,7 +17,7 @@ TEST_CASE(standard_op)
migraphx::module m; migraphx::module m;
auto l = m.add_parameter("x", {migraphx::shape::float_type, {2, 2}}); auto l = m.add_parameter("x", {migraphx::shape::float_type, {2, 2}});
auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l); auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l);
auto c = m.add_instruction(migraphx::make_op("contiguous"), t); auto c = m.add_instruction(migraphx::make_op("contiguous"), t);
m.add_instruction(pass_standard_op{}, c); m.add_instruction(pass_standard_op{}, c);
auto count = std::distance(m.begin(), m.end()); auto count = std::distance(m.begin(), m.end());
...@@ -30,7 +30,7 @@ TEST_CASE(standard_op_const) ...@@ -30,7 +30,7 @@ TEST_CASE(standard_op_const)
migraphx::module m; migraphx::module m;
auto l = m.add_literal(get_2x2()); auto l = m.add_literal(get_2x2());
auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l); auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l);
auto c = m.add_instruction(migraphx::make_op("contiguous"), t); auto c = m.add_instruction(migraphx::make_op("contiguous"), t);
m.add_instruction(pass_standard_op{}, c); m.add_instruction(pass_standard_op{}, c);
run_pass(m); run_pass(m);
...@@ -42,7 +42,7 @@ TEST_CASE(non_standard_op) ...@@ -42,7 +42,7 @@ TEST_CASE(non_standard_op)
migraphx::module m; migraphx::module m;
auto l = m.add_parameter("x", {migraphx::shape::float_type, {2, 2}}); auto l = m.add_parameter("x", {migraphx::shape::float_type, {2, 2}});
auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l); auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l);
auto c = m.add_instruction(migraphx::make_op("contiguous"), t); auto c = m.add_instruction(migraphx::make_op("contiguous"), t);
m.add_instruction(pass_op{}, c); m.add_instruction(pass_op{}, c);
auto count = std::distance(m.begin(), m.end()); auto count = std::distance(m.begin(), m.end());
...@@ -55,7 +55,7 @@ TEST_CASE(non_standard_op_const) ...@@ -55,7 +55,7 @@ TEST_CASE(non_standard_op_const)
migraphx::module m; migraphx::module m;
auto l = m.add_literal(get_2x2()); auto l = m.add_literal(get_2x2());
auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l); auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l);
auto c = m.add_instruction(migraphx::make_op("contiguous"), t); auto c = m.add_instruction(migraphx::make_op("contiguous"), t);
m.add_instruction(pass_op{}, c); m.add_instruction(pass_op{}, c);
run_pass(m); run_pass(m);
...@@ -67,7 +67,7 @@ TEST_CASE(transpose_gem) ...@@ -67,7 +67,7 @@ TEST_CASE(transpose_gem)
migraphx::module m; migraphx::module m;
auto l = m.add_literal(get_2x2()); auto l = m.add_literal(get_2x2());
auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l); auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l);
auto c = m.add_instruction(migraphx::make_op("contiguous"), t); auto c = m.add_instruction(migraphx::make_op("contiguous"), t);
auto ic = m.add_instruction(migraphx::make_op("identity"), c); auto ic = m.add_instruction(migraphx::make_op("identity"), c);
m.add_instruction(migraphx::make_op("dot"), ic, l); m.add_instruction(migraphx::make_op("dot"), ic, l);
...@@ -81,7 +81,7 @@ TEST_CASE(transpose_standard_op) ...@@ -81,7 +81,7 @@ TEST_CASE(transpose_standard_op)
migraphx::module m; migraphx::module m;
auto l = m.add_parameter("x", {migraphx::shape::float_type, {2, 2}}); auto l = m.add_parameter("x", {migraphx::shape::float_type, {2, 2}});
auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l); auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l);
auto c = m.add_instruction(migraphx::make_op("contiguous"), t); auto c = m.add_instruction(migraphx::make_op("contiguous"), t);
auto sn = m.add_instruction(migraphx::make_op("sin"), c); auto sn = m.add_instruction(migraphx::make_op("sin"), c);
m.add_instruction(pass_standard_op{}, sn); m.add_instruction(pass_standard_op{}, sn);
...@@ -95,7 +95,7 @@ TEST_CASE(transpose_standard_op_const) ...@@ -95,7 +95,7 @@ TEST_CASE(transpose_standard_op_const)
migraphx::module m; migraphx::module m;
auto l = m.add_literal(get_2x2()); auto l = m.add_literal(get_2x2());
auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l); auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l);
auto c = m.add_instruction(migraphx::make_op("contiguous"), t); auto c = m.add_instruction(migraphx::make_op("contiguous"), t);
auto sn = m.add_instruction(migraphx::make_op("sin"), c); auto sn = m.add_instruction(migraphx::make_op("sin"), c);
m.add_instruction(pass_standard_op{}, sn); m.add_instruction(pass_standard_op{}, sn);
...@@ -123,7 +123,7 @@ TEST_CASE(non_standard_return_input) ...@@ -123,7 +123,7 @@ TEST_CASE(non_standard_return_input)
migraphx::module m; migraphx::module m;
auto l = m.add_literal(get_2x2()); auto l = m.add_literal(get_2x2());
auto tl = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l); auto tl = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l);
auto c = m.add_instruction(migraphx::make_op("contiguous"), tl); auto c = m.add_instruction(migraphx::make_op("contiguous"), tl);
m.add_return({c}); m.add_return({c});
auto count = std::distance(m.begin(), m.end()); auto count = std::distance(m.begin(), m.end());
......
...@@ -101,9 +101,11 @@ TEST_CASE(quant_dot_trans) ...@@ -101,9 +101,11 @@ TEST_CASE(quant_dot_trans)
migraphx::shape s2{migraphx::shape::int8_type, {3, 2, 7, 8}}; migraphx::shape s2{migraphx::shape::int8_type, {3, 2, 7, 8}};
auto l1 = m.add_parameter("a", s1); auto l1 = m.add_parameter("a", s1);
auto tl1 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l1); auto tl1 =
m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l1);
auto l2 = m.add_parameter("b", s2); auto l2 = m.add_parameter("b", s2);
auto tl2 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l2); auto tl2 =
m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2);
auto r = m.add_instruction( auto r = m.add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 3}, {"beta", 2}}), tl1, tl2); migraphx::make_op("quant_dot", {{"alpha", 3}, {"beta", 2}}), tl1, tl2);
m.add_return({r}); m.add_return({r});
...@@ -120,13 +122,15 @@ TEST_CASE(quant_dot_trans) ...@@ -120,13 +122,15 @@ TEST_CASE(quant_dot_trans)
auto l2 = m.add_parameter("b", s2); auto l2 = m.add_parameter("b", s2);
auto output = m.add_parameter("test:#output_0", s3); auto output = m.add_parameter("test:#output_0", s3);
auto tl1 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l1); auto tl1 =
m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l1);
migraphx::shape ts1{migraphx::shape::int8_type, {3, 2, 5, 8}}; migraphx::shape ts1{migraphx::shape::int8_type, {3, 2, 5, 8}};
auto alloca = m.add_instruction( auto alloca = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts1)}})); migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts1)}}));
auto conta = m.add_instruction(migraphx::make_op("gpu::contiguous"), tl1, alloca); auto conta = m.add_instruction(migraphx::make_op("gpu::contiguous"), tl1, alloca);
auto tl2 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l2); auto tl2 =
m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2);
migraphx::shape ts2{migraphx::shape::int8_type, {3, 2, 8, 7}}; migraphx::shape ts2{migraphx::shape::int8_type, {3, 2, 8, 7}};
auto allocb = m.add_instruction( auto allocb = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts2)}})); migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts2)}}));
...@@ -246,9 +250,11 @@ TEST_CASE(quant_dot_trans_pad) ...@@ -246,9 +250,11 @@ TEST_CASE(quant_dot_trans_pad)
migraphx::shape s2{migraphx::shape::int8_type, {3, 2, 7, 9}}; migraphx::shape s2{migraphx::shape::int8_type, {3, 2, 7, 9}};
auto l1 = m.add_parameter("a", s1); auto l1 = m.add_parameter("a", s1);
auto tl1 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l1); auto tl1 =
m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l1);
auto l2 = m.add_parameter("b", s2); auto l2 = m.add_parameter("b", s2);
auto tl2 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l2); auto tl2 =
m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2);
auto r = m.add_instruction( auto r = m.add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 3}, {"beta", 2}}), tl1, tl2); migraphx::make_op("quant_dot", {{"alpha", 3}, {"beta", 2}}), tl1, tl2);
m.add_return({r}); m.add_return({r});
...@@ -267,7 +273,8 @@ TEST_CASE(quant_dot_trans_pad) ...@@ -267,7 +273,8 @@ TEST_CASE(quant_dot_trans_pad)
auto l2 = m.add_parameter("b", s2); auto l2 = m.add_parameter("b", s2);
auto output = m.add_parameter("test:#output_0", s3); auto output = m.add_parameter("test:#output_0", s3);
auto tl1 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l1); auto tl1 =
m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l1);
migraphx::shape ts1{migraphx::shape::int8_type, {3, 2, 5, 9}}; migraphx::shape ts1{migraphx::shape::int8_type, {3, 2, 5, 9}};
auto ta = m.add_instruction( auto ta = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts1)}})); migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts1)}}));
...@@ -287,7 +294,8 @@ TEST_CASE(quant_dot_trans_pad) ...@@ -287,7 +294,8 @@ TEST_CASE(quant_dot_trans_pad)
pta); pta);
} }
auto tl2 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l2); auto tl2 =
m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2);
migraphx::shape ts2{migraphx::shape::int8_type, {3, 2, 9, 7}}; migraphx::shape ts2{migraphx::shape::int8_type, {3, 2, 9, 7}};
auto tb = m.add_instruction( auto tb = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts2)}})); migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts2)}}));
......
...@@ -364,19 +364,19 @@ TEST_CASE(inline_tuple_true_test) ...@@ -364,19 +364,19 @@ TEST_CASE(inline_tuple_true_test)
auto* then_mod = p.create_module("If_6_if"); auto* then_mod = p.create_module("If_6_if");
auto m1 = then_mod->add_instruction( auto m1 = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {1, 4}}}), l1); migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), l1);
auto add0 = then_mod->add_instruction(migraphx::make_op("add"), x, m1); auto add0 = then_mod->add_instruction(migraphx::make_op("add"), x, m1);
auto m2 = then_mod->add_instruction( auto m2 = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {3, 4}}}), l2); migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), l2);
auto mul0 = then_mod->add_instruction(migraphx::make_op("mul"), y, m2); auto mul0 = then_mod->add_instruction(migraphx::make_op("mul"), y, m2);
then_mod->add_return({add0, mul0}); then_mod->add_return({add0, mul0});
auto* else_mod = p.create_module("If_6_else"); auto* else_mod = p.create_module("If_6_else");
auto me1 = else_mod->add_instruction( auto me1 = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {1, 4}}}), l3); migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), l3);
auto mul1 = else_mod->add_instruction(migraphx::make_op("mul"), x, me1); auto mul1 = else_mod->add_instruction(migraphx::make_op("mul"), x, me1);
auto me2 = else_mod->add_instruction( auto me2 = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {3, 4}}}), l3); migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), l3);
auto add1 = else_mod->add_instruction(migraphx::make_op("add"), y, me2); auto add1 = else_mod->add_instruction(migraphx::make_op("add"), y, me2);
else_mod->add_return({mul1, add1}); else_mod->add_return({mul1, add1});
...@@ -401,10 +401,10 @@ TEST_CASE(inline_tuple_true_test) ...@@ -401,10 +401,10 @@ TEST_CASE(inline_tuple_true_test)
auto y = mm->add_parameter("y", sy); auto y = mm->add_parameter("y", sy);
auto m1 = auto m1 =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {1, 4}}}), l1); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), l1);
auto add = mm->add_instruction(migraphx::make_op("add"), x, m1); auto add = mm->add_instruction(migraphx::make_op("add"), x, m1);
auto m2 = auto m2 =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {3, 4}}}), l2); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), l2);
auto mul = mm->add_instruction(migraphx::make_op("mul"), y, m2); auto mul = mm->add_instruction(migraphx::make_op("mul"), y, m2);
mm->add_return({add, mul}); mm->add_return({add, mul});
...@@ -434,19 +434,19 @@ TEST_CASE(inline_tuple_false_test) ...@@ -434,19 +434,19 @@ TEST_CASE(inline_tuple_false_test)
auto* then_mod = p.create_module("If_6_if"); auto* then_mod = p.create_module("If_6_if");
auto m1 = then_mod->add_instruction( auto m1 = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {1, 4}}}), l1); migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), l1);
auto add0 = then_mod->add_instruction(migraphx::make_op("add"), x, m1); auto add0 = then_mod->add_instruction(migraphx::make_op("add"), x, m1);
auto m2 = then_mod->add_instruction( auto m2 = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {3, 4}}}), l2); migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), l2);
auto mul0 = then_mod->add_instruction(migraphx::make_op("mul"), y, m2); auto mul0 = then_mod->add_instruction(migraphx::make_op("mul"), y, m2);
then_mod->add_return({add0, mul0}); then_mod->add_return({add0, mul0});
auto* else_mod = p.create_module("If_6_else"); auto* else_mod = p.create_module("If_6_else");
auto me1 = else_mod->add_instruction( auto me1 = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {1, 4}}}), l3); migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), l3);
auto mul1 = else_mod->add_instruction(migraphx::make_op("mul"), x, me1); auto mul1 = else_mod->add_instruction(migraphx::make_op("mul"), x, me1);
auto me2 = else_mod->add_instruction( auto me2 = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {3, 4}}}), l3); migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), l3);
auto add1 = else_mod->add_instruction(migraphx::make_op("add"), y, me2); auto add1 = else_mod->add_instruction(migraphx::make_op("add"), y, me2);
else_mod->add_return({mul1, add1}); else_mod->add_return({mul1, add1});
...@@ -473,10 +473,10 @@ TEST_CASE(inline_tuple_false_test) ...@@ -473,10 +473,10 @@ TEST_CASE(inline_tuple_false_test)
auto y = mm->add_parameter("y", sy); auto y = mm->add_parameter("y", sy);
auto m1 = auto m1 =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {1, 4}}}), l3); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), l3);
auto mul = mm->add_instruction(migraphx::make_op("mul"), x, m1); auto mul = mm->add_instruction(migraphx::make_op("mul"), x, m1);
auto m2 = auto m2 =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {3, 4}}}), l3); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), l3);
auto add = mm->add_instruction(migraphx::make_op("add"), y, m2); auto add = mm->add_instruction(migraphx::make_op("add"), y, m2);
mm->add_return({mul, add}); mm->add_return({mul, add});
......
This diff is collapsed.
...@@ -74,7 +74,7 @@ TEST_CASE(broadcast) ...@@ -74,7 +74,7 @@ TEST_CASE(broadcast)
std::vector<std::size_t> lens{1, 1}; std::vector<std::size_t> lens{1, 1};
migraphx::shape input{migraphx::shape::float_type, {1}, {0}}; migraphx::shape input{migraphx::shape::float_type, {1}, {0}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1}, {0, 0}}, expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1}, {0, 0}},
migraphx::make_op("broadcast", {{"axis", 0}, {"dims", lens}}), migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", lens}}),
input); input);
} }
...@@ -94,14 +94,14 @@ TEST_CASE(broadcast) ...@@ -94,14 +94,14 @@ TEST_CASE(broadcast)
std::vector<std::size_t> lens{3, 2, 4, 3}; std::vector<std::size_t> lens{3, 2, 4, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 3}}; migraphx::shape input{migraphx::shape::float_type, {4, 3}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 2, 4, 3}, {0, 0, 3, 1}}, expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 2, 4, 3}, {0, 0, 3, 1}},
migraphx::make_op("broadcast", {{"axis", 2}, {"dims", lens}}), migraphx::make_op("broadcast", {{"axis", 2}, {"out_lens", lens}}),
input); input);
} }
{ {
std::vector<std::size_t> lens{3, 2, 4, 3}; std::vector<std::size_t> lens{3, 2, 4, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 4}}; migraphx::shape input{migraphx::shape::float_type, {4, 4}};
throws_shape(migraphx::make_op("broadcast", {{"axis", 2}, {"dims", lens}}), input); throws_shape(migraphx::make_op("broadcast", {{"axis", 2}, {"out_lens", lens}}), input);
} }
} }
...@@ -953,70 +953,70 @@ TEST_CASE(multibroadcast) ...@@ -953,70 +953,70 @@ TEST_CASE(multibroadcast)
std::vector<std::size_t> lens{4, 2, 5, 3}; std::vector<std::size_t> lens{4, 2, 5, 3};
migraphx::shape input{migraphx::shape::float_type, {2, 1, 3}}; migraphx::shape input{migraphx::shape::float_type, {2, 1, 3}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 3, 0, 1}}, expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 3, 0, 1}},
migraphx::make_op("multibroadcast", {{"output_lens", lens}}), migraphx::make_op("multibroadcast", {{"out_lens", lens}}),
input); input);
} }
{ {
std::vector<std::size_t> lens{4, 2, 5, 3}; std::vector<std::size_t> lens{4, 2, 5, 3};
migraphx::shape input{migraphx::shape::float_type, {2, 1, 1}}; migraphx::shape input{migraphx::shape::float_type, {2, 1, 1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 1, 0, 0}}, expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 1, 0, 0}},
migraphx::make_op("multibroadcast", {{"output_lens", lens}}), migraphx::make_op("multibroadcast", {{"out_lens", lens}}),
input); input);
} }
{ {
std::vector<std::size_t> lens{4, 2, 5, 3}; std::vector<std::size_t> lens{4, 2, 5, 3};
migraphx::shape input{migraphx::shape::float_type, {5, 1}}; migraphx::shape input{migraphx::shape::float_type, {5, 1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 0, 1, 0}}, expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 0, 1, 0}},
migraphx::make_op("multibroadcast", {{"output_lens", lens}}), migraphx::make_op("multibroadcast", {{"out_lens", lens}}),
input); input);
} }
{ {
std::vector<std::size_t> lens{4, 2, 5, 3}; std::vector<std::size_t> lens{4, 2, 5, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 1, 1, 1}}; migraphx::shape input{migraphx::shape::float_type, {4, 1, 1, 1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {1, 0, 0, 0}}, expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {1, 0, 0, 0}},
migraphx::make_op("multibroadcast", {{"output_lens", lens}}), migraphx::make_op("multibroadcast", {{"out_lens", lens}}),
input); input);
} }
{ {
std::vector<std::size_t> lens{4, 2, 5, 3}; std::vector<std::size_t> lens{4, 2, 5, 3};
migraphx::shape input{migraphx::shape::float_type, {3}}; migraphx::shape input{migraphx::shape::float_type, {3}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 0, 0, 1}}, expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 0, 0, 1}},
migraphx::make_op("multibroadcast", {{"output_lens", lens}}), migraphx::make_op("multibroadcast", {{"out_lens", lens}}),
input); input);
} }
{ {
std::vector<std::size_t> lens{4, 4, 1, 3}; std::vector<std::size_t> lens{4, 4, 1, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 1, 3}}; migraphx::shape input{migraphx::shape::float_type, {4, 1, 3}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 3, 3, 1}}, expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 3, 3, 1}},
migraphx::make_op("multibroadcast", {{"output_lens", lens}}), migraphx::make_op("multibroadcast", {{"out_lens", lens}}),
input); input);
} }
{ {
std::vector<std::size_t> lens{4, 1, 1, 3}; std::vector<std::size_t> lens{4, 1, 1, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 1, 1, 1}}; migraphx::shape input{migraphx::shape::float_type, {4, 1, 1, 1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {1, 1, 1, 0}}, expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {1, 1, 1, 0}},
migraphx::make_op("multibroadcast", {{"output_lens", lens}}), migraphx::make_op("multibroadcast", {{"out_lens", lens}}),
input); input);
} }
{ {
std::vector<std::size_t> lens{4, 1, 3}; std::vector<std::size_t> lens{4, 1, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 1, 1, 1}}; migraphx::shape input{migraphx::shape::float_type, {4, 1, 1, 1}};
throws_shape(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), input); throws_shape(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), input);
} }
{ {
std::vector<std::size_t> lens{4, 1, 3}; std::vector<std::size_t> lens{4, 1, 3};
migraphx::shape input{migraphx::shape::float_type, {}}; migraphx::shape input{migraphx::shape::float_type, {}};
throws_shape(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), input); throws_shape(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), input);
} }
{ {
std::vector<std::size_t> lens{2, 3, 4, 5}; std::vector<std::size_t> lens{2, 3, 4, 5};
migraphx::shape input{migraphx::shape::float_type, {3, 4}}; migraphx::shape input{migraphx::shape::float_type, {3, 4}};
throws_shape(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), input); throws_shape(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), input);
} }
{ {
std::vector<std::size_t> lens{2, 3, 4, 5}; std::vector<std::size_t> lens{2, 3, 4, 5};
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4}};
throws_shape(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), input); throws_shape(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), input);
} }
} }
...@@ -1558,10 +1558,10 @@ TEST_CASE(transpose_shape) ...@@ -1558,10 +1558,10 @@ TEST_CASE(transpose_shape)
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 2}}; migraphx::shape input{migraphx::shape::float_type, {2, 2}};
migraphx::shape output{migraphx::shape::float_type, {2, 2}, {1, 2}}; migraphx::shape output{migraphx::shape::float_type, {2, 2}, {1, 2}};
expect_shape(input, migraphx::make_op("transpose", {{"dims", {0, 1}}}), input); expect_shape(input, migraphx::make_op("transpose", {{"permutation", {0, 1}}}), input);
expect_shape(output, migraphx::make_op("transpose", {{"dims", {1, 0}}}), input); expect_shape(output, migraphx::make_op("transpose", {{"permutation", {1, 0}}}), input);
expect_shape(output, migraphx::make_op("transpose"), input); expect_shape(output, migraphx::make_op("transpose"), input);
throws_shape(migraphx::make_op("transpose", {{"dims", {1, 2}}}), input); throws_shape(migraphx::make_op("transpose", {{"permutation", {1, 2}}}), input);
} }
TEST_CASE(step_test) TEST_CASE(step_test)
......
...@@ -25,10 +25,10 @@ create_clip_op(migraphx::program& p, float max, float min, migraphx::instruction ...@@ -25,10 +25,10 @@ create_clip_op(migraphx::program& p, float max, float min, migraphx::instruction
auto input_lens = input->get_shape().lens(); auto input_lens = input->get_shape().lens();
auto max_val = mm->add_literal(max); auto max_val = mm->add_literal(max);
auto min_val = mm->add_literal(min); auto min_val = mm->add_literal(min);
max_val = mm->add_instruction( max_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
migraphx::make_op("multibroadcast", {{"output_lens", input_lens}}), max_val); max_val);
min_val = mm->add_instruction( min_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
migraphx::make_op("multibroadcast", {{"output_lens", input_lens}}), min_val); min_val);
return mm->add_instruction(migraphx::make_op("clip"), input, min_val, max_val); return mm->add_instruction(migraphx::make_op("clip"), input, min_val, max_val);
} }
...@@ -43,9 +43,9 @@ migraphx::instruction_ref create_clip_op(migraphx::instruction_ref insert_loc, ...@@ -43,9 +43,9 @@ migraphx::instruction_ref create_clip_op(migraphx::instruction_ref insert_loc,
auto max_val = mm->add_literal(max); auto max_val = mm->add_literal(max);
auto min_val = mm->add_literal(min); auto min_val = mm->add_literal(min);
max_val = mm->insert_instruction( max_val = mm->insert_instruction(
insert_loc, migraphx::make_op("multibroadcast", {{"output_lens", input_lens}}), max_val); insert_loc, migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), max_val);
min_val = mm->insert_instruction( min_val = mm->insert_instruction(
insert_loc, migraphx::make_op("multibroadcast", {{"output_lens", input_lens}}), min_val); insert_loc, migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), min_val);
return mm->insert_instruction(insert_loc, migraphx::make_op("clip"), input, min_val, max_val); return mm->insert_instruction(insert_loc, migraphx::make_op("clip"), input, min_val, max_val);
} }
......
...@@ -668,7 +668,7 @@ TEST_CASE(matmul_vm) ...@@ -668,7 +668,7 @@ TEST_CASE(matmul_vm)
auto al = mm->add_literal(migraphx::literal{a_shape, a}); auto al = mm->add_literal(migraphx::literal{a_shape, a});
auto ual = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), al); auto ual = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), al);
auto bual = mm->add_instruction( auto bual = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {3, 1, 6}}}), ual); migraphx::make_op("multibroadcast", {{"out_lens", {3, 1, 6}}}), ual);
migraphx::shape b_shape{migraphx::shape::float_type, {3, 6, 4}}; migraphx::shape b_shape{migraphx::shape::float_type, {3, 6, 4}};
auto bl = mm->add_literal(migraphx::literal{b_shape, b}); auto bl = mm->add_literal(migraphx::literal{b_shape, b});
mm->add_instruction(migraphx::make_op("dot"), bual, bl); mm->add_instruction(migraphx::make_op("dot"), bual, bl);
...@@ -715,7 +715,7 @@ TEST_CASE(matmul_vm) ...@@ -715,7 +715,7 @@ TEST_CASE(matmul_vm)
auto al = mm->add_literal(migraphx::literal{a_shape, a}); auto al = mm->add_literal(migraphx::literal{a_shape, a});
auto ual = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), al); auto ual = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), al);
auto bual = mm->add_instruction( auto bual = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {3, 1, 6}}}), ual); migraphx::make_op("multibroadcast", {{"out_lens", {3, 1, 6}}}), ual);
migraphx::shape b_shape{migraphx::shape::float_type, {3, 6, 4}}; migraphx::shape b_shape{migraphx::shape::float_type, {3, 6, 4}};
auto bl = mm->add_literal(migraphx::literal{b_shape, b}); auto bl = mm->add_literal(migraphx::literal{b_shape, b});
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 0.21f}}), bual, bl); mm->add_instruction(migraphx::make_op("dot", {{"alpha", 0.21f}}), bual, bl);
...@@ -837,7 +837,7 @@ TEST_CASE(matmul_mv) ...@@ -837,7 +837,7 @@ TEST_CASE(matmul_mv)
auto bl = mm->add_literal(migraphx::literal{b_shape, b}); auto bl = mm->add_literal(migraphx::literal{b_shape, b});
auto ubl = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), bl); auto ubl = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), bl);
auto bubl = mm->add_instruction( auto bubl = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2, 5, 1}}}), ubl); migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 5, 1}}}), ubl);
mm->add_instruction(migraphx::make_op("dot"), al, bubl); mm->add_instruction(migraphx::make_op("dot"), al, bubl);
std::vector<float> gold = {-0.792717, std::vector<float> gold = {-0.792717,
6.33595, 6.33595,
...@@ -897,7 +897,7 @@ TEST_CASE(matmul_mm1) ...@@ -897,7 +897,7 @@ TEST_CASE(matmul_mm1)
migraphx::shape b_shape{migraphx::shape::float_type, {5, 3}}; migraphx::shape b_shape{migraphx::shape::float_type, {5, 3}};
auto bl = mm->add_literal(migraphx::literal{b_shape, b}); auto bl = mm->add_literal(migraphx::literal{b_shape, b});
auto bbl = mm->add_instruction( auto bbl = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2, 5, 3}}}), bl); migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 5, 3}}}), bl);
mm->add_instruction(migraphx::make_op("dot"), al, bbl); mm->add_instruction(migraphx::make_op("dot"), al, bbl);
std::vector<float> gold = {-0.386828, 0.187735, -0.22822, -0.148057, 2.015, -2.56938, std::vector<float> gold = {-0.386828, 0.187735, -0.22822, -0.148057, 2.015, -2.56938,
-0.782212, 1.9459, 0.927426, -2.44907, 2.40531, 2.30232, -0.782212, 1.9459, 0.927426, -2.44907, 2.40531, 2.30232,
...@@ -946,7 +946,7 @@ TEST_CASE(matmul_mm1) ...@@ -946,7 +946,7 @@ TEST_CASE(matmul_mm1)
migraphx::shape a_shape{migraphx::shape::float_type, {3, 4}}; migraphx::shape a_shape{migraphx::shape::float_type, {3, 4}};
auto al = mm->add_literal(migraphx::literal{a_shape, a}); auto al = mm->add_literal(migraphx::literal{a_shape, a});
auto bal = mm->add_instruction( auto bal = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 3, 3, 4}}}), al); migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 3, 4}}}), al);
migraphx::shape b_shape{migraphx::shape::float_type, {2, 3, 4, 3}}; migraphx::shape b_shape{migraphx::shape::float_type, {2, 3, 4, 3}};
auto bl = mm->add_literal(migraphx::literal{b_shape, b}); auto bl = mm->add_literal(migraphx::literal{b_shape, b});
mm->add_instruction(migraphx::make_op("dot"), bal, bl); mm->add_instruction(migraphx::make_op("dot"), bal, bl);
...@@ -994,7 +994,7 @@ TEST_CASE(matmul_mm2) ...@@ -994,7 +994,7 @@ TEST_CASE(matmul_mm2)
migraphx::shape b_shape{migraphx::shape::float_type, {2, 1, 5, 3}}; migraphx::shape b_shape{migraphx::shape::float_type, {2, 1, 5, 3}};
auto bl = mm->add_literal(migraphx::literal{b_shape, b}); auto bl = mm->add_literal(migraphx::literal{b_shape, b});
auto bbl = mm->add_instruction( auto bbl = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2, 5, 3}}}), bl); migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 5, 3}}}), bl);
std::vector<float> gold = { std::vector<float> gold = {
0.70574512, -2.80915314, -1.57644969, 1.75415381, -3.13303087, -1.00150259, 0.70574512, -2.80915314, -1.57644969, 1.75415381, -3.13303087, -1.00150259,
-0.18675123, -0.23349122, -0.12357225, 0.82911538, 1.37473744, -1.11709934, -0.18675123, -0.23349122, -0.12357225, 0.82911538, 1.37473744, -1.11709934,
...@@ -1030,11 +1030,11 @@ TEST_CASE(matmul_mm2) ...@@ -1030,11 +1030,11 @@ TEST_CASE(matmul_mm2)
migraphx::shape a_shape{migraphx::shape::float_type, {1, 2, 3, 5}}; migraphx::shape a_shape{migraphx::shape::float_type, {1, 2, 3, 5}};
auto al = mm->add_literal(migraphx::literal{a_shape, a}); auto al = mm->add_literal(migraphx::literal{a_shape, a});
auto bal = mm->add_instruction( auto bal = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2, 3, 5}}}), al); migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 3, 5}}}), al);
migraphx::shape b_shape{migraphx::shape::float_type, {2, 1, 5, 3}}; migraphx::shape b_shape{migraphx::shape::float_type, {2, 1, 5, 3}};
auto bl = mm->add_literal(migraphx::literal{b_shape, b}); auto bl = mm->add_literal(migraphx::literal{b_shape, b});
auto bbl = mm->add_instruction( auto bbl = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2, 5, 3}}}), bl); migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 5, 3}}}), bl);
mm->add_instruction(migraphx::make_op("dot"), bal, bbl); mm->add_instruction(migraphx::make_op("dot"), bal, bbl);
std::vector<float> gold = { std::vector<float> gold = {
1.64924590e+00, 2.84575831e+00, 1.07340773e+00, 2.19817080e-01, -1.87873283e+00, 1.64924590e+00, 2.84575831e+00, 1.07340773e+00, 2.19817080e-01, -1.87873283e+00,
...@@ -1132,7 +1132,7 @@ TEST_CASE(matmul_mm2) ...@@ -1132,7 +1132,7 @@ TEST_CASE(matmul_mm2)
migraphx::shape b_shape{migraphx::shape::float_type, {2, 4, 5}}; migraphx::shape b_shape{migraphx::shape::float_type, {2, 4, 5}};
auto bl = mm->add_literal(migraphx::literal{b_shape, b}); auto bl = mm->add_literal(migraphx::literal{b_shape, b});
auto bbl = mm->add_instruction( auto bbl = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2, 4, 5}}}), bl); migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 4, 5}}}), bl);
mm->add_instruction(migraphx::make_op("dot"), al, bbl); mm->add_instruction(migraphx::make_op("dot"), al, bbl);
std::vector<float> gold = { std::vector<float> gold = {
-1.08585245, 0.39575611, 0.33947977, -0.86339678, 1.50710753, 0.05646156, -1.08585245, 0.39575611, 0.33947977, -0.86339678, 1.50710753, 0.05646156,
...@@ -1193,7 +1193,8 @@ TEST_CASE(quant_dot_2args_multi4) ...@@ -1193,7 +1193,8 @@ TEST_CASE(quant_dot_2args_multi4)
std::iota(data2.begin(), data2.end(), 0); std::iota(data2.begin(), data2.end(), 0);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1}); auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto tl1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l1); auto tl1 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2}); auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
mm->add_instruction(migraphx::make_op("quant_dot"), tl1, l2); mm->add_instruction(migraphx::make_op("quant_dot"), tl1, l2);
...@@ -1221,7 +1222,8 @@ TEST_CASE(quant_dot_2args_multi4) ...@@ -1221,7 +1222,8 @@ TEST_CASE(quant_dot_2args_multi4)
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1}); auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2}); auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto tl2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l2); auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
mm->add_instruction(migraphx::make_op("quant_dot"), l1, tl2); mm->add_instruction(migraphx::make_op("quant_dot"), l1, tl2);
std::vector<int> gold = {14, 38, 62, 86, 110, 134, 158, 182, 38, 126, 214, std::vector<int> gold = {14, 38, 62, 86, 110, 134, 158, 182, 38, 126, 214,
...@@ -1247,9 +1249,11 @@ TEST_CASE(quant_dot_2args_multi4) ...@@ -1247,9 +1249,11 @@ TEST_CASE(quant_dot_2args_multi4)
std::iota(data2.begin(), data2.end(), 0); std::iota(data2.begin(), data2.end(), 0);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1}); auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto tl1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l1); auto tl1 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2}); auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto tl2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l2); auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
mm->add_instruction(migraphx::make_op("quant_dot"), tl1, tl2); mm->add_instruction(migraphx::make_op("quant_dot"), tl1, tl2);
std::vector<int> gold = {56, 152, 248, 344, 440, 536, 632, 728, 62, 174, 286, std::vector<int> gold = {56, 152, 248, 344, 440, 536, 632, 728, 62, 174, 286,
...@@ -1303,7 +1307,8 @@ TEST_CASE(quant_dot_2args_general) ...@@ -1303,7 +1307,8 @@ TEST_CASE(quant_dot_2args_general)
std::iota(data2.begin(), data2.end(), 0); std::iota(data2.begin(), data2.end(), 0);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1}); auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto tl1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l1); auto tl1 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2}); auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
mm->add_instruction(migraphx::make_op("quant_dot"), tl1, l2); mm->add_instruction(migraphx::make_op("quant_dot"), tl1, l2);
...@@ -1330,7 +1335,8 @@ TEST_CASE(quant_dot_2args_general) ...@@ -1330,7 +1335,8 @@ TEST_CASE(quant_dot_2args_general)
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1}); auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2}); auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto tl2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l2); auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 2}}), l1, tl2); mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 2}}), l1, tl2);
std::vector<int> gold = { std::vector<int> gold = {
...@@ -1355,9 +1361,11 @@ TEST_CASE(quant_dot_2args_general) ...@@ -1355,9 +1361,11 @@ TEST_CASE(quant_dot_2args_general)
std::iota(data2.begin(), data2.end(), 0); std::iota(data2.begin(), data2.end(), 0);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1}); auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto tl1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l1); auto tl1 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2}); auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto tl2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l2); auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 3}, {"beta", 2}}), tl1, tl2); mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 3}, {"beta", 2}}), tl1, tl2);
std::vector<int> gold = { std::vector<int> gold = {
...@@ -1447,7 +1455,8 @@ TEST_CASE(quant_dot_3args_general) ...@@ -1447,7 +1455,8 @@ TEST_CASE(quant_dot_3args_general)
std::iota(data3.begin(), data3.end(), 2); std::iota(data3.begin(), data3.end(), 2);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1}); auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto tl1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l1); auto tl1 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2}); auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3}); auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3});
mm->add_instruction( mm->add_instruction(
...@@ -1479,7 +1488,8 @@ TEST_CASE(quant_dot_3args_general) ...@@ -1479,7 +1488,8 @@ TEST_CASE(quant_dot_3args_general)
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1}); auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2}); auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto tl2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l2); auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3}); auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3});
mm->add_instruction( mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 2}, {"beta", 3}}), l1, tl2, l3); migraphx::make_op("quant_dot", {{"alpha", 2}, {"beta", 3}}), l1, tl2, l3);
...@@ -1509,9 +1519,11 @@ TEST_CASE(quant_dot_3args_general) ...@@ -1509,9 +1519,11 @@ TEST_CASE(quant_dot_3args_general)
std::iota(data3.begin(), data3.end(), 2); std::iota(data3.begin(), data3.end(), 2);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1}); auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto tl1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l1); auto tl1 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2}); auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto tl2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l2); auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3}); auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3});
mm->add_instruction( mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 3}, {"beta", 2}}), tl1, tl2, l3); migraphx::make_op("quant_dot", {{"alpha", 3}, {"beta", 2}}), tl1, tl2, l3);
...@@ -1578,11 +1590,11 @@ TEST_CASE(quant_dot_3args_batch) ...@@ -1578,11 +1590,11 @@ TEST_CASE(quant_dot_3args_batch)
std::iota(data3.begin(), data3.end(), 2); std::iota(data3.begin(), data3.end(), 2);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1}); auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto tl1 = auto tl1 = mm->add_instruction(
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l1); migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l1);
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2}); auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto tl2 = auto tl2 = mm->add_instruction(
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l2); migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2);
auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3}); auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3});
mm->add_instruction( mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 2}, {"beta", 3}}), tl1, tl2, l3); migraphx::make_op("quant_dot", {{"alpha", 2}, {"beta", 3}}), tl1, tl2, l3);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment