Unverified Commit 0d2606bb authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Change attributes names to be more consistent and reflect better meaning (#916)

* rename broadcast and multibroadcast output_lens attribute to out_lens attribute, and change tests and source code to reflect the same

* change the reshape attribute from dims to out_lens

* change transpose attribute's name from dims to perm to reflect better meaning

* use permutation instead of perm for transpose

clang formaating

* use dims instead of out_lens for reshape

clang formatting
parent d8a2a933
......@@ -93,9 +93,9 @@ instruction_ref insert_quant_ins(module& modl,
auto max_clip = modl.add_literal(127.0f);
auto min_clip = modl.add_literal(-128.0f);
max_clip = modl.insert_instruction(
insert_loc, make_op("multibroadcast", {{"output_lens", rounded_lens}}), max_clip);
insert_loc, make_op("multibroadcast", {{"out_lens", rounded_lens}}), max_clip);
min_clip = modl.insert_instruction(
insert_loc, make_op("multibroadcast", {{"output_lens", rounded_lens}}), min_clip);
insert_loc, make_op("multibroadcast", {{"out_lens", rounded_lens}}), min_clip);
auto clipped_ins =
modl.insert_instruction(insert_loc, make_op("clip"), rounded_ins, min_clip, max_clip);
quant_ins = modl.insert_instruction(
......
......@@ -241,11 +241,11 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
// squeeze and transpose w
std::vector<int64_t> perm{1, 0};
auto sw = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
auto tran_sw = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), sw);
auto tran_sw = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sw);
// squeeze and transpose r
auto sr = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
auto tran_sr = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), sr);
auto tran_sr = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sr);
// initial hidden state
auto sih = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
......@@ -263,7 +263,7 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
ins, make_op("slice", {{"axes", {0}}, {"starts", {hs}}, {"ends", {2 * hs}}}), sbias);
auto wrb = prog.insert_instruction(ins, make_op("add"), wb, rb);
bb = prog.insert_instruction(
ins, make_op("broadcast", {{"axis", 1}, {"dims", sih_lens}}), wrb);
ins, make_op("broadcast", {{"axis", 1}, {"out_lens", sih_lens}}), wrb);
}
instruction_ref hidden_out = prog.end();
......@@ -565,17 +565,17 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
// w matrix squeeze to 2-dim and do a transpose
std::vector<int64_t> perm{1, 0};
auto sw = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
auto tw = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), sw);
auto tw = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sw);
// r slide to two part, zr and h
auto sr = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
auto rzr = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {2 * hs}}}), sr);
auto trzr = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), rzr);
auto trzr = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), rzr);
auto rh = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}), sr);
auto trh = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), rh);
auto trh = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), rh);
// initial states
auto sih = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
......@@ -592,7 +592,7 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {3 * hs}}}), sbias);
bwb = prog.insert_instruction(
ins,
make_op("broadcast", {{"axis", 1}, {"dims", {bs, static_cast<size_t>(3 * hs)}}}),
make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, static_cast<size_t>(3 * hs)}}}),
wb);
auto rb_zr = prog.insert_instruction(
......@@ -605,11 +605,11 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
sbias);
brb_zr = prog.insert_instruction(
ins,
make_op("broadcast", {{"axis", 1}, {"dims", {bs, static_cast<size_t>(2 * hs)}}}),
make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, static_cast<size_t>(2 * hs)}}}),
rb_zr);
brb_h = prog.insert_instruction(
ins,
make_op("broadcast", {{"axis", 1}, {"dims", {bs, static_cast<size_t>(hs)}}}),
make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, static_cast<size_t>(hs)}}}),
rb_h);
}
......@@ -1038,11 +1038,11 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
std::vector<int64_t> perm{1, 0};
// w matrix, squeeze and transpose
auto sw = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
auto tsw = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), sw);
auto tsw = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sw);
// r matrix, squeeze and transpose
auto sr = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
auto tsr = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), sr);
auto tsr = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sr);
// initial hidden state
auto sih = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
......@@ -1067,7 +1067,7 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
wrb = prog.insert_instruction(
ins,
make_op("broadcast", {{"axis", 1}, {"dims", {bs, 4 * static_cast<size_t>(hs)}}}),
make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, 4 * static_cast<size_t>(hs)}}}),
ub_wrb);
}
......@@ -1081,17 +1081,17 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
auto pphi = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {hs}}}), spph);
pphi_brcst = prog.insert_instruction(
ins, make_op("broadcast", {{"axis", 1}, {"dims", ic_lens}}), pphi);
ins, make_op("broadcast", {{"axis", 1}, {"out_lens", ic_lens}}), pphi);
auto ppho = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {hs}}, {"ends", {2 * hs}}}), spph);
ppho_brcst = prog.insert_instruction(
ins, make_op("broadcast", {{"axis", 1}, {"dims", ic_lens}}), ppho);
ins, make_op("broadcast", {{"axis", 1}, {"out_lens", ic_lens}}), ppho);
auto pphf = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}), spph);
pphf_brcst = prog.insert_instruction(
ins, make_op("broadcast", {{"axis", 1}, {"dims", ic_lens}}), pphf);
ins, make_op("broadcast", {{"axis", 1}, {"out_lens", ic_lens}}), pphf);
}
long seq_len = static_cast<long>(get_seq_len(prog, seq, seq_lens));
......
......@@ -55,7 +55,7 @@ struct find_mul_conv
auto new_a = p.insert_instruction(
ins,
make_op("broadcast", {{"axis", 0}, {"dims", w_ins->get_shape().lens()}}),
make_op("broadcast", {{"axis", 0}, {"out_lens", w_ins->get_shape().lens()}}),
a_ins->inputs().front());
auto new_mul = p.insert_instruction(ins, make_op("mul"), new_a, w_ins);
auto new_conv = p.insert_instruction(
......@@ -120,7 +120,7 @@ struct find_mul_slice_conv
auto new_a = p.insert_instruction(
ins,
make_op("broadcast", {{"axis", 0}, {"dims", slice_w_ins->get_shape().lens()}}),
make_op("broadcast", {{"axis", 0}, {"out_lens", slice_w_ins->get_shape().lens()}}),
a_ins->inputs().front());
auto new_mul = p.insert_instruction(ins, make_op("mul"), new_a, slice_w_ins);
......@@ -989,8 +989,8 @@ struct find_split_transpose
}
// insert an transpose instruction
auto tr =
p.insert_instruction(std::next(input), make_op("transpose", {{"dims", perm}}), input);
auto tr = p.insert_instruction(
std::next(input), make_op("transpose", {{"permutation", perm}}), input);
// compute the axis in the slice
auto axis = any_cast<op::slice>(slc->get_operator()).axes.front();
......
......@@ -96,7 +96,7 @@ struct match_find_quantizable_ops
auto lens = dq->get_shape().lens();
auto scale_mb =
m.insert_instruction(qop, make_op("multibroadcast", {{"output_lens", lens}}), dq_scale);
m.insert_instruction(qop, make_op("multibroadcast", {{"out_lens", lens}}), dq_scale);
dq = m.insert_instruction(qop, make_op("dequantizelinear"), dq, scale_mb);
m.replace_instruction(qop, dq);
}
......
......@@ -153,7 +153,8 @@ struct find_transpose
}
else
{
p.replace_instruction(ins, make_op("transpose", {{"dims", dims}}), t->inputs().front());
p.replace_instruction(
ins, make_op("transpose", {{"permutation", dims}}), t->inputs().front());
}
}
};
......@@ -278,10 +279,12 @@ struct find_concat_transpose
std::vector<instruction_ref> inputs;
std::transform(
ins->inputs().begin(), ins->inputs().end(), std::back_inserter(inputs), [&](auto i) {
return p.insert_instruction(ins, make_op("transpose", {{"dims", permutation}}), i);
return p.insert_instruction(
ins, make_op("transpose", {{"permutation", permutation}}), i);
});
auto concat = p.insert_instruction(ins, op, inputs);
auto t = p.insert_instruction(ins, make_op("transpose", {{"dims", ipermutation}}), concat);
auto t = p.insert_instruction(
ins, make_op("transpose", {{"permutation", ipermutation}}), concat);
assert(ins->get_shape().lens() == t->get_shape().lens());
p.replace_instruction(ins, t);
}
......@@ -418,7 +421,7 @@ struct find_resize
auto rsp_data = p.insert_instruction(
ins_rsp, migraphx::make_op("reshape", {{"dims", in_dims}}), in_rsp);
auto mb_rsp = p.insert_instruction(
ins_rsp, migraphx::make_op("multibroadcast", {{"output_lens", out_dims}}), rsp_data);
ins_rsp, migraphx::make_op("multibroadcast", {{"out_lens", out_dims}}), rsp_data);
auto std_mb = p.insert_instruction(ins, migraphx::make_op("contiguous"), mb_rsp);
std::vector<int64_t> rsp_dims(out_lens.begin(), out_lens.end());
p.replace_instruction(ins, migraphx::make_op("reshape", {{"dims", rsp_dims}}), std_mb);
......
......@@ -55,7 +55,8 @@ static std::vector<instruction_ref> pad_inputs(module& m, instruction_ref ins)
auto t_in = in0->inputs().front();
auto p_in = pad_ins(m, t_in, offset);
auto dims = val.at("dims").to_vector<int64_t>();
auto r_in = m.insert_instruction(ins, make_op("transpose", {{"dims", dims}}), p_in);
auto r_in =
m.insert_instruction(ins, make_op("transpose", {{"permutation", dims}}), p_in);
ret_inputs.push_back(r_in);
}
else
......@@ -85,7 +86,8 @@ static std::vector<instruction_ref> pad_inputs(module& m, instruction_ref ins)
auto t_in = in1->inputs().front();
auto p_in = pad_ins(m, t_in, offset);
auto dims = val.at("dims").to_vector<int64_t>();
auto r_in = m.insert_instruction(ins, make_op("transpose", {{"dims", dims}}), p_in);
auto r_in =
m.insert_instruction(ins, make_op("transpose", {{"permutation", dims}}), p_in);
ret_inputs.push_back(r_in);
}
else
......
......@@ -20,7 +20,8 @@ struct parse_biasadd : op_parser<parse_biasadd>
uint64_t axis = 1; // assume output of previous layer is in NCHW (broadcast on channel)
auto l0 = info.add_instruction(
make_op("broadcast", {{"axis", axis}, {"dims", args[0]->get_shape().lens()}}), args[1]);
make_op("broadcast", {{"axis", axis}, {"out_lens", args[0]->get_shape().lens()}}),
args[1]);
return info.add_instruction(make_op("add"), args[0], l0);
}
};
......
......@@ -46,10 +46,12 @@ struct parse_matmul : op_parser<parse_matmul>
// swap the last two elements
std::iter_swap(perm.end() - 1, perm.end() - 2);
auto l1 = (transa) ? info.add_instruction(make_op("transpose", {{"dims", perm}}), args[0])
: args[0];
auto l2 = (transb) ? info.add_instruction(make_op("transpose", {{"dims", perm}}), args[1])
: args[1];
auto l1 = (transa)
? info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[0])
: args[0];
auto l2 = (transb)
? info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[1])
: args[1];
return info.add_instruction(make_op("dot"), l1, l2);
}
......
......@@ -23,9 +23,9 @@ struct parse_relu6 : op_parser<parse_relu6>
auto max_val = info.add_literal(6.0f);
min_val =
info.add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}), min_val);
info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}), min_val);
max_val =
info.add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}), max_val);
info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}), max_val);
return info.add_instruction(make_op("clip"), args.front(), min_val, max_val);
}
};
......
......@@ -20,7 +20,7 @@ struct parse_transpose : op_parser<parse_transpose>
auto perm = args[1]->eval().get<int32_t>().to_vector();
std::vector<int64_t> dims(perm.begin(), perm.end());
return info.add_instruction(make_op("transpose", {{"dims", dims}}), args.front());
return info.add_instruction(make_op("transpose", {{"permutation", dims}}), args.front());
}
};
......
......@@ -35,20 +35,20 @@ bool tf_parser::should_transpose(instruction_ref ins) const
instruction_ref tf_parser::to_nhwc(instruction_ref ins) const
{
if(should_transpose(ins))
return mm->add_instruction(make_op("transpose", {{"dims", {0, 2, 3, 1}}}), ins);
return mm->add_instruction(make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), ins);
return ins;
}
instruction_ref tf_parser::to_nchw(instruction_ref ins) const
{
if(should_transpose(ins))
return mm->add_instruction(make_op("transpose", {{"dims", {0, 3, 1, 2}}}), ins);
return mm->add_instruction(make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), ins);
return ins;
}
instruction_ref tf_parser::to_kcxy(instruction_ref ins) const
{
return mm->add_instruction(make_op("transpose", {{"dims", {3, 2, 0, 1}}}), ins);
return mm->add_instruction(make_op("transpose", {{"permutation", {3, 2, 0, 1}}}), ins);
}
std::vector<instruction_ref> tf_parser::to_nchw(const std::vector<instruction_ref>& args) const
......
......@@ -40,7 +40,7 @@ TEST_CASE(after_literal_transpose)
auto l = m.add_literal(get_2x2());
EXPECT(m.get_output_shapes().back().standard());
EXPECT(not m.get_output_shapes().back().transposed());
auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l);
m.add_instruction(pass_op{}, t);
EXPECT(not m.get_output_shapes().back().standard());
EXPECT(m.get_output_shapes().back().transposed());
......@@ -58,7 +58,7 @@ TEST_CASE(after_literal_broadcast)
EXPECT(m.get_output_shapes().back().standard());
EXPECT(not m.get_output_shapes().back().broadcasted());
auto b = m.add_instruction(
migraphx::make_op("broadcast", {{"axis", 0}, {"dims", l1->get_shape().lens()}}), l2);
migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", l1->get_shape().lens()}}), l2);
m.add_instruction(pass_op{}, b);
EXPECT(not m.get_output_shapes().back().standard());
EXPECT(m.get_output_shapes().back().broadcasted());
......@@ -74,7 +74,7 @@ TEST_CASE(after_param_transpose)
auto l = m.add_parameter("2x2", {migraphx::shape::float_type, {2, 2}});
EXPECT(m.get_output_shapes().back().standard());
EXPECT(not m.get_output_shapes().back().transposed());
auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l);
m.add_instruction(pass_op{}, t);
EXPECT(not m.get_output_shapes().back().standard());
EXPECT(m.get_output_shapes().back().transposed());
......@@ -92,7 +92,7 @@ TEST_CASE(after_param_broadcast)
EXPECT(m.get_output_shapes().back().standard());
EXPECT(not m.get_output_shapes().back().broadcasted());
auto b = m.add_instruction(
migraphx::make_op("broadcast", {{"axis", 0}, {"dims", l1->get_shape().lens()}}), l2);
migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", l1->get_shape().lens()}}), l2);
m.add_instruction(pass_op{}, b);
EXPECT(not m.get_output_shapes().back().standard());
EXPECT(m.get_output_shapes().back().broadcasted());
......
......@@ -50,8 +50,8 @@ TEST_CASE(dot_add_beta_float)
auto dot = m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y);
auto beta =
m2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {0.5}});
auto beta_broadcast = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2}}}), beta);
auto beta_broadcast =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2}}}), beta);
auto mul = m2.add_instruction(migraphx::make_op("mul"), z, beta_broadcast);
auto add = m2.add_instruction(migraphx::make_op("add"), dot, mul);
m2.add_instruction(migraphx::make_op("identity"), add);
......@@ -79,8 +79,8 @@ TEST_CASE(dot_add_beta_half)
auto dot = m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y);
auto beta =
m2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::half_type}, {0.5}});
auto beta_broadcast = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2}}}), beta);
auto beta_broadcast =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2}}}), beta);
auto mul = m2.add_instruction(migraphx::make_op("mul"), z, beta_broadcast);
auto add = m2.add_instruction(migraphx::make_op("add"), dot, mul);
m2.add_instruction(migraphx::make_op("identity"), add);
......@@ -108,8 +108,8 @@ TEST_CASE(dot_add_beta_double)
auto dot = m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y);
auto beta =
m2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::double_type}, {0.5}});
auto beta_broadcast = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2}}}), beta);
auto beta_broadcast =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2}}}), beta);
auto mul = m2.add_instruction(migraphx::make_op("mul"), z, beta_broadcast);
auto add = m2.add_instruction(migraphx::make_op("add"), dot, mul);
m2.add_instruction(migraphx::make_op("identity"), add);
......@@ -137,8 +137,8 @@ TEST_CASE(dot_add_beta_int)
auto dot = m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y);
auto beta =
m2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {0.5}});
auto beta_broadcast = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2}}}), beta);
auto beta_broadcast =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2}}}), beta);
auto mul = m2.add_instruction(migraphx::make_op("mul"), z, beta_broadcast);
auto add = m2.add_instruction(migraphx::make_op("add"), dot, mul);
m2.add_instruction(migraphx::make_op("identity"), add);
......
......@@ -17,7 +17,7 @@ TEST_CASE(standard_op)
migraphx::module m;
auto l = m.add_parameter("x", {migraphx::shape::float_type, {2, 2}});
auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l);
auto c = m.add_instruction(migraphx::make_op("contiguous"), t);
m.add_instruction(pass_standard_op{}, c);
auto count = std::distance(m.begin(), m.end());
......@@ -30,7 +30,7 @@ TEST_CASE(standard_op_const)
migraphx::module m;
auto l = m.add_literal(get_2x2());
auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l);
auto c = m.add_instruction(migraphx::make_op("contiguous"), t);
m.add_instruction(pass_standard_op{}, c);
run_pass(m);
......@@ -42,7 +42,7 @@ TEST_CASE(non_standard_op)
migraphx::module m;
auto l = m.add_parameter("x", {migraphx::shape::float_type, {2, 2}});
auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l);
auto c = m.add_instruction(migraphx::make_op("contiguous"), t);
m.add_instruction(pass_op{}, c);
auto count = std::distance(m.begin(), m.end());
......@@ -55,7 +55,7 @@ TEST_CASE(non_standard_op_const)
migraphx::module m;
auto l = m.add_literal(get_2x2());
auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l);
auto c = m.add_instruction(migraphx::make_op("contiguous"), t);
m.add_instruction(pass_op{}, c);
run_pass(m);
......@@ -67,7 +67,7 @@ TEST_CASE(transpose_gem)
migraphx::module m;
auto l = m.add_literal(get_2x2());
auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l);
auto c = m.add_instruction(migraphx::make_op("contiguous"), t);
auto ic = m.add_instruction(migraphx::make_op("identity"), c);
m.add_instruction(migraphx::make_op("dot"), ic, l);
......@@ -81,7 +81,7 @@ TEST_CASE(transpose_standard_op)
migraphx::module m;
auto l = m.add_parameter("x", {migraphx::shape::float_type, {2, 2}});
auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l);
auto c = m.add_instruction(migraphx::make_op("contiguous"), t);
auto sn = m.add_instruction(migraphx::make_op("sin"), c);
m.add_instruction(pass_standard_op{}, sn);
......@@ -95,7 +95,7 @@ TEST_CASE(transpose_standard_op_const)
migraphx::module m;
auto l = m.add_literal(get_2x2());
auto t = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l);
auto c = m.add_instruction(migraphx::make_op("contiguous"), t);
auto sn = m.add_instruction(migraphx::make_op("sin"), c);
m.add_instruction(pass_standard_op{}, sn);
......@@ -123,7 +123,7 @@ TEST_CASE(non_standard_return_input)
migraphx::module m;
auto l = m.add_literal(get_2x2());
auto tl = m.add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
auto tl = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l);
auto c = m.add_instruction(migraphx::make_op("contiguous"), tl);
m.add_return({c});
auto count = std::distance(m.begin(), m.end());
......
......@@ -100,11 +100,13 @@ TEST_CASE(quant_dot_trans)
migraphx::shape s1{migraphx::shape::int8_type, {3, 2, 8, 5}};
migraphx::shape s2{migraphx::shape::int8_type, {3, 2, 7, 8}};
auto l1 = m.add_parameter("a", s1);
auto tl1 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l1);
auto l2 = m.add_parameter("b", s2);
auto tl2 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l2);
auto r = m.add_instruction(
auto l1 = m.add_parameter("a", s1);
auto tl1 =
m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l1);
auto l2 = m.add_parameter("b", s2);
auto tl2 =
m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2);
auto r = m.add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 3}, {"beta", 2}}), tl1, tl2);
m.add_return({r});
return m;
......@@ -120,13 +122,15 @@ TEST_CASE(quant_dot_trans)
auto l2 = m.add_parameter("b", s2);
auto output = m.add_parameter("test:#output_0", s3);
auto tl1 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l1);
auto tl1 =
m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l1);
migraphx::shape ts1{migraphx::shape::int8_type, {3, 2, 5, 8}};
auto alloca = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts1)}}));
auto conta = m.add_instruction(migraphx::make_op("gpu::contiguous"), tl1, alloca);
auto tl2 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l2);
auto tl2 =
m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2);
migraphx::shape ts2{migraphx::shape::int8_type, {3, 2, 8, 7}};
auto allocb = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts2)}}));
......@@ -245,11 +249,13 @@ TEST_CASE(quant_dot_trans_pad)
migraphx::shape s1{migraphx::shape::int8_type, {3, 2, 9, 5}};
migraphx::shape s2{migraphx::shape::int8_type, {3, 2, 7, 9}};
auto l1 = m.add_parameter("a", s1);
auto tl1 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l1);
auto l2 = m.add_parameter("b", s2);
auto tl2 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l2);
auto r = m.add_instruction(
auto l1 = m.add_parameter("a", s1);
auto tl1 =
m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l1);
auto l2 = m.add_parameter("b", s2);
auto tl2 =
m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2);
auto r = m.add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 3}, {"beta", 2}}), tl1, tl2);
m.add_return({r});
return m;
......@@ -267,7 +273,8 @@ TEST_CASE(quant_dot_trans_pad)
auto l2 = m.add_parameter("b", s2);
auto output = m.add_parameter("test:#output_0", s3);
auto tl1 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l1);
auto tl1 =
m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l1);
migraphx::shape ts1{migraphx::shape::int8_type, {3, 2, 5, 9}};
auto ta = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts1)}}));
......@@ -287,7 +294,8 @@ TEST_CASE(quant_dot_trans_pad)
pta);
}
auto tl2 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l2);
auto tl2 =
m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2);
migraphx::shape ts2{migraphx::shape::int8_type, {3, 2, 9, 7}};
auto tb = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts2)}}));
......
......@@ -364,19 +364,19 @@ TEST_CASE(inline_tuple_true_test)
auto* then_mod = p.create_module("If_6_if");
auto m1 = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {1, 4}}}), l1);
migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), l1);
auto add0 = then_mod->add_instruction(migraphx::make_op("add"), x, m1);
auto m2 = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {3, 4}}}), l2);
migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), l2);
auto mul0 = then_mod->add_instruction(migraphx::make_op("mul"), y, m2);
then_mod->add_return({add0, mul0});
auto* else_mod = p.create_module("If_6_else");
auto me1 = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {1, 4}}}), l3);
migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), l3);
auto mul1 = else_mod->add_instruction(migraphx::make_op("mul"), x, me1);
auto me2 = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {3, 4}}}), l3);
migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), l3);
auto add1 = else_mod->add_instruction(migraphx::make_op("add"), y, me2);
else_mod->add_return({mul1, add1});
......@@ -401,10 +401,10 @@ TEST_CASE(inline_tuple_true_test)
auto y = mm->add_parameter("y", sy);
auto m1 =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {1, 4}}}), l1);
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), l1);
auto add = mm->add_instruction(migraphx::make_op("add"), x, m1);
auto m2 =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {3, 4}}}), l2);
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), l2);
auto mul = mm->add_instruction(migraphx::make_op("mul"), y, m2);
mm->add_return({add, mul});
......@@ -434,19 +434,19 @@ TEST_CASE(inline_tuple_false_test)
auto* then_mod = p.create_module("If_6_if");
auto m1 = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {1, 4}}}), l1);
migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), l1);
auto add0 = then_mod->add_instruction(migraphx::make_op("add"), x, m1);
auto m2 = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {3, 4}}}), l2);
migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), l2);
auto mul0 = then_mod->add_instruction(migraphx::make_op("mul"), y, m2);
then_mod->add_return({add0, mul0});
auto* else_mod = p.create_module("If_6_else");
auto me1 = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {1, 4}}}), l3);
migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), l3);
auto mul1 = else_mod->add_instruction(migraphx::make_op("mul"), x, me1);
auto me2 = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {3, 4}}}), l3);
migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), l3);
auto add1 = else_mod->add_instruction(migraphx::make_op("add"), y, me2);
else_mod->add_return({mul1, add1});
......@@ -473,10 +473,10 @@ TEST_CASE(inline_tuple_false_test)
auto y = mm->add_parameter("y", sy);
auto m1 =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {1, 4}}}), l3);
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), l3);
auto mul = mm->add_instruction(migraphx::make_op("mul"), x, m1);
auto m2 =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {3, 4}}}), l3);
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), l3);
auto add = mm->add_instruction(migraphx::make_op("add"), y, m2);
mm->add_return({mul, add});
......
This diff is collapsed.
......@@ -74,7 +74,7 @@ TEST_CASE(broadcast)
std::vector<std::size_t> lens{1, 1};
migraphx::shape input{migraphx::shape::float_type, {1}, {0}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1}, {0, 0}},
migraphx::make_op("broadcast", {{"axis", 0}, {"dims", lens}}),
migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", lens}}),
input);
}
......@@ -94,14 +94,14 @@ TEST_CASE(broadcast)
std::vector<std::size_t> lens{3, 2, 4, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 3}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 2, 4, 3}, {0, 0, 3, 1}},
migraphx::make_op("broadcast", {{"axis", 2}, {"dims", lens}}),
migraphx::make_op("broadcast", {{"axis", 2}, {"out_lens", lens}}),
input);
}
{
std::vector<std::size_t> lens{3, 2, 4, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 4}};
throws_shape(migraphx::make_op("broadcast", {{"axis", 2}, {"dims", lens}}), input);
throws_shape(migraphx::make_op("broadcast", {{"axis", 2}, {"out_lens", lens}}), input);
}
}
......@@ -953,70 +953,70 @@ TEST_CASE(multibroadcast)
std::vector<std::size_t> lens{4, 2, 5, 3};
migraphx::shape input{migraphx::shape::float_type, {2, 1, 3}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 3, 0, 1}},
migraphx::make_op("multibroadcast", {{"output_lens", lens}}),
migraphx::make_op("multibroadcast", {{"out_lens", lens}}),
input);
}
{
std::vector<std::size_t> lens{4, 2, 5, 3};
migraphx::shape input{migraphx::shape::float_type, {2, 1, 1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 1, 0, 0}},
migraphx::make_op("multibroadcast", {{"output_lens", lens}}),
migraphx::make_op("multibroadcast", {{"out_lens", lens}}),
input);
}
{
std::vector<std::size_t> lens{4, 2, 5, 3};
migraphx::shape input{migraphx::shape::float_type, {5, 1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 0, 1, 0}},
migraphx::make_op("multibroadcast", {{"output_lens", lens}}),
migraphx::make_op("multibroadcast", {{"out_lens", lens}}),
input);
}
{
std::vector<std::size_t> lens{4, 2, 5, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 1, 1, 1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {1, 0, 0, 0}},
migraphx::make_op("multibroadcast", {{"output_lens", lens}}),
migraphx::make_op("multibroadcast", {{"out_lens", lens}}),
input);
}
{
std::vector<std::size_t> lens{4, 2, 5, 3};
migraphx::shape input{migraphx::shape::float_type, {3}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 0, 0, 1}},
migraphx::make_op("multibroadcast", {{"output_lens", lens}}),
migraphx::make_op("multibroadcast", {{"out_lens", lens}}),
input);
}
{
std::vector<std::size_t> lens{4, 4, 1, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 1, 3}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 3, 3, 1}},
migraphx::make_op("multibroadcast", {{"output_lens", lens}}),
migraphx::make_op("multibroadcast", {{"out_lens", lens}}),
input);
}
{
std::vector<std::size_t> lens{4, 1, 1, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 1, 1, 1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {1, 1, 1, 0}},
migraphx::make_op("multibroadcast", {{"output_lens", lens}}),
migraphx::make_op("multibroadcast", {{"out_lens", lens}}),
input);
}
{
std::vector<std::size_t> lens{4, 1, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 1, 1, 1}};
throws_shape(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), input);
throws_shape(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), input);
}
{
std::vector<std::size_t> lens{4, 1, 3};
migraphx::shape input{migraphx::shape::float_type, {}};
throws_shape(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), input);
throws_shape(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), input);
}
{
std::vector<std::size_t> lens{2, 3, 4, 5};
migraphx::shape input{migraphx::shape::float_type, {3, 4}};
throws_shape(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), input);
throws_shape(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), input);
}
{
std::vector<std::size_t> lens{2, 3, 4, 5};
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4}};
throws_shape(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), input);
throws_shape(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), input);
}
}
......@@ -1558,10 +1558,10 @@ TEST_CASE(transpose_shape)
{
migraphx::shape input{migraphx::shape::float_type, {2, 2}};
migraphx::shape output{migraphx::shape::float_type, {2, 2}, {1, 2}};
expect_shape(input, migraphx::make_op("transpose", {{"dims", {0, 1}}}), input);
expect_shape(output, migraphx::make_op("transpose", {{"dims", {1, 0}}}), input);
expect_shape(input, migraphx::make_op("transpose", {{"permutation", {0, 1}}}), input);
expect_shape(output, migraphx::make_op("transpose", {{"permutation", {1, 0}}}), input);
expect_shape(output, migraphx::make_op("transpose"), input);
throws_shape(migraphx::make_op("transpose", {{"dims", {1, 2}}}), input);
throws_shape(migraphx::make_op("transpose", {{"permutation", {1, 2}}}), input);
}
TEST_CASE(step_test)
......
......@@ -25,10 +25,10 @@ create_clip_op(migraphx::program& p, float max, float min, migraphx::instruction
auto input_lens = input->get_shape().lens();
auto max_val = mm->add_literal(max);
auto min_val = mm->add_literal(min);
max_val = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", input_lens}}), max_val);
min_val = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", input_lens}}), min_val);
max_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
max_val);
min_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
min_val);
return mm->add_instruction(migraphx::make_op("clip"), input, min_val, max_val);
}
......@@ -43,9 +43,9 @@ migraphx::instruction_ref create_clip_op(migraphx::instruction_ref insert_loc,
auto max_val = mm->add_literal(max);
auto min_val = mm->add_literal(min);
max_val = mm->insert_instruction(
insert_loc, migraphx::make_op("multibroadcast", {{"output_lens", input_lens}}), max_val);
insert_loc, migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), max_val);
min_val = mm->insert_instruction(
insert_loc, migraphx::make_op("multibroadcast", {{"output_lens", input_lens}}), min_val);
insert_loc, migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), min_val);
return mm->insert_instruction(insert_loc, migraphx::make_op("clip"), input, min_val, max_val);
}
......
......@@ -668,7 +668,7 @@ TEST_CASE(matmul_vm)
auto al = mm->add_literal(migraphx::literal{a_shape, a});
auto ual = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), al);
auto bual = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {3, 1, 6}}}), ual);
migraphx::make_op("multibroadcast", {{"out_lens", {3, 1, 6}}}), ual);
migraphx::shape b_shape{migraphx::shape::float_type, {3, 6, 4}};
auto bl = mm->add_literal(migraphx::literal{b_shape, b});
mm->add_instruction(migraphx::make_op("dot"), bual, bl);
......@@ -715,7 +715,7 @@ TEST_CASE(matmul_vm)
auto al = mm->add_literal(migraphx::literal{a_shape, a});
auto ual = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), al);
auto bual = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {3, 1, 6}}}), ual);
migraphx::make_op("multibroadcast", {{"out_lens", {3, 1, 6}}}), ual);
migraphx::shape b_shape{migraphx::shape::float_type, {3, 6, 4}};
auto bl = mm->add_literal(migraphx::literal{b_shape, b});
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 0.21f}}), bual, bl);
......@@ -837,7 +837,7 @@ TEST_CASE(matmul_mv)
auto bl = mm->add_literal(migraphx::literal{b_shape, b});
auto ubl = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), bl);
auto bubl = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2, 5, 1}}}), ubl);
migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 5, 1}}}), ubl);
mm->add_instruction(migraphx::make_op("dot"), al, bubl);
std::vector<float> gold = {-0.792717,
6.33595,
......@@ -897,7 +897,7 @@ TEST_CASE(matmul_mm1)
migraphx::shape b_shape{migraphx::shape::float_type, {5, 3}};
auto bl = mm->add_literal(migraphx::literal{b_shape, b});
auto bbl = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2, 5, 3}}}), bl);
migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 5, 3}}}), bl);
mm->add_instruction(migraphx::make_op("dot"), al, bbl);
std::vector<float> gold = {-0.386828, 0.187735, -0.22822, -0.148057, 2.015, -2.56938,
-0.782212, 1.9459, 0.927426, -2.44907, 2.40531, 2.30232,
......@@ -946,7 +946,7 @@ TEST_CASE(matmul_mm1)
migraphx::shape a_shape{migraphx::shape::float_type, {3, 4}};
auto al = mm->add_literal(migraphx::literal{a_shape, a});
auto bal = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 3, 3, 4}}}), al);
migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 3, 4}}}), al);
migraphx::shape b_shape{migraphx::shape::float_type, {2, 3, 4, 3}};
auto bl = mm->add_literal(migraphx::literal{b_shape, b});
mm->add_instruction(migraphx::make_op("dot"), bal, bl);
......@@ -994,7 +994,7 @@ TEST_CASE(matmul_mm2)
migraphx::shape b_shape{migraphx::shape::float_type, {2, 1, 5, 3}};
auto bl = mm->add_literal(migraphx::literal{b_shape, b});
auto bbl = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2, 5, 3}}}), bl);
migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 5, 3}}}), bl);
std::vector<float> gold = {
0.70574512, -2.80915314, -1.57644969, 1.75415381, -3.13303087, -1.00150259,
-0.18675123, -0.23349122, -0.12357225, 0.82911538, 1.37473744, -1.11709934,
......@@ -1030,11 +1030,11 @@ TEST_CASE(matmul_mm2)
migraphx::shape a_shape{migraphx::shape::float_type, {1, 2, 3, 5}};
auto al = mm->add_literal(migraphx::literal{a_shape, a});
auto bal = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2, 3, 5}}}), al);
migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 3, 5}}}), al);
migraphx::shape b_shape{migraphx::shape::float_type, {2, 1, 5, 3}};
auto bl = mm->add_literal(migraphx::literal{b_shape, b});
auto bbl = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2, 5, 3}}}), bl);
migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 5, 3}}}), bl);
mm->add_instruction(migraphx::make_op("dot"), bal, bbl);
std::vector<float> gold = {
1.64924590e+00, 2.84575831e+00, 1.07340773e+00, 2.19817080e-01, -1.87873283e+00,
......@@ -1132,7 +1132,7 @@ TEST_CASE(matmul_mm2)
migraphx::shape b_shape{migraphx::shape::float_type, {2, 4, 5}};
auto bl = mm->add_literal(migraphx::literal{b_shape, b});
auto bbl = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 2, 4, 5}}}), bl);
migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 4, 5}}}), bl);
mm->add_instruction(migraphx::make_op("dot"), al, bbl);
std::vector<float> gold = {
-1.08585245, 0.39575611, 0.33947977, -0.86339678, 1.50710753, 0.05646156,
......@@ -1192,9 +1192,10 @@ TEST_CASE(quant_dot_2args_multi4)
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto tl1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l1);
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto tl1 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
mm->add_instruction(migraphx::make_op("quant_dot"), tl1, l2);
std::vector<int> gold = {448, 472, 496, 520, 544, 568, 592, 616, 496, 524, 552,
......@@ -1219,9 +1220,10 @@ TEST_CASE(quant_dot_2args_multi4)
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto tl2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l2);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
mm->add_instruction(migraphx::make_op("quant_dot"), l1, tl2);
std::vector<int> gold = {14, 38, 62, 86, 110, 134, 158, 182, 38, 126, 214,
......@@ -1246,10 +1248,12 @@ TEST_CASE(quant_dot_2args_multi4)
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto tl1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l1);
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto tl2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l2);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto tl1 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
mm->add_instruction(migraphx::make_op("quant_dot"), tl1, tl2);
std::vector<int> gold = {56, 152, 248, 344, 440, 536, 632, 728, 62, 174, 286,
......@@ -1302,9 +1306,10 @@ TEST_CASE(quant_dot_2args_general)
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto tl1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l1);
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto tl1 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
mm->add_instruction(migraphx::make_op("quant_dot"), tl1, l2);
std::vector<int> gold = {
......@@ -1328,9 +1333,10 @@ TEST_CASE(quant_dot_2args_general)
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto tl2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l2);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 2}}), l1, tl2);
std::vector<int> gold = {
......@@ -1354,10 +1360,12 @@ TEST_CASE(quant_dot_2args_general)
std::iota(data1.begin(), data1.end(), 0);
std::iota(data2.begin(), data2.end(), 0);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto tl1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l1);
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto tl2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l2);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto tl1 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 3}, {"beta", 2}}), tl1, tl2);
std::vector<int> gold = {
......@@ -1446,10 +1454,11 @@ TEST_CASE(quant_dot_3args_general)
std::iota(data2.begin(), data2.end(), 0);
std::iota(data3.begin(), data3.end(), 2);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto tl1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l1);
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3});
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto tl1 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3});
mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 3}}), tl1, l2, l3);
......@@ -1477,10 +1486,11 @@ TEST_CASE(quant_dot_3args_general)
std::iota(data2.begin(), data2.end(), 0);
std::iota(data3.begin(), data3.end(), 2);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto tl2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l2);
auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3});
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3});
mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 2}, {"beta", 3}}), l1, tl2, l3);
......@@ -1508,11 +1518,13 @@ TEST_CASE(quant_dot_3args_general)
std::iota(data2.begin(), data2.end(), 0);
std::iota(data3.begin(), data3.end(), 2);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto tl1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l1);
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto tl2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l2);
auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3});
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto tl1 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3});
mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 3}, {"beta", 2}}), tl1, tl2, l3);
......@@ -1577,12 +1589,12 @@ TEST_CASE(quant_dot_3args_batch)
std::iota(data2.begin(), data2.end(), 0);
std::iota(data3.begin(), data3.end(), 2);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto tl1 =
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l1);
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l2);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto tl1 = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l1);
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto tl2 = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2);
auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3});
mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 2}, {"beta", 3}}), tl1, tl2, l3);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment