Commit 0369e974 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

Merge branch 'batch_report' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into mi100_opts

parents 3a474fca d70fd0df
......@@ -111,4 +111,33 @@ TEST_CASE(const_scalar)
EXPECT(m1 == m2);
}
TEST_CASE(const_dot)
{
migraphx::module m1;
{
migraphx::shape s{migraphx::shape::float_type, {2, 2}};
std::vector<float> vec = {1.0f, 2.0f, 1.0f, 2.0f};
auto l = m1.add_literal(migraphx::literal(s, vec));
auto dl = m1.add_instruction(migraphx::make_op("dot"), l, l);
auto x = m1.add_parameter("x", s);
auto r = m1.add_instruction(migraphx::make_op("add"), dl, x);
m1.add_return({r});
}
run_pass(m1);
migraphx::module m2;
{
migraphx::shape s{migraphx::shape::float_type, {2, 2}};
std::vector<float> vec = {3.0f, 6.0f, 3.0f, 6.0f};
auto x = m2.add_parameter("x", s);
auto l = m2.add_literal(migraphx::literal(s, vec));
auto r = m2.add_instruction(migraphx::make_op("add"), l, x);
m2.add_return({r});
}
EXPECT(m1 == m2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -174,6 +174,8 @@ def create_backend_test(testname=None, target_device=None):
backend_test.include(r'.*test_reduce.*')
backend_test.include(r'.*test_ReLU*')
backend_test.include(r'.*test_relu.*')
backend_test.include(r'.*test_RoiAlign*')
backend_test.include(r'.*test_roialign.*')
backend_test.include(r'.*test_scatter.*')
backend_test.include(r'.*test_Scatter.*')
backend_test.include(r'.*test_selu.*')
......
......@@ -53,6 +53,22 @@ def test_neg_int64():
print(r)
def test_nonzero():
p = migraphx.parse_onnx("nonzero_dynamic_test.onnx")
print(p)
print("Compiling ...")
p.compile(migraphx.get_target("gpu"))
print(p)
params = {}
shapes = p.get_parameter_shapes()
params["data"] = np.array([1, 1, 0, 1]).reshape(
shapes["data"].lens()).astype(np.bool)
r = p.run(params)
print(r)
def test_fp16_imagescaler():
p = migraphx.parse_onnx("imagescaler_half_test.onnx")
print(p)
......@@ -98,3 +114,4 @@ test_sub_uint64()
test_neg_int64()
test_fp16_imagescaler()
test_if_pl()
test_nonzero()
......@@ -6,6 +6,7 @@
#include <migraphx/generate.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/quantize_int8.hpp>
#include <migraphx/quantize_fp16.hpp>
......@@ -431,7 +432,8 @@ TEST_CASE(op_capture)
auto pb = mm->add_parameter("b", s2);
auto pc = mm->add_parameter("c", s2);
auto pa = mm->add_instruction(migraphx::make_op("add"), p1, p2);
auto ps = mm->add_instruction(migraphx::make_op("dot"), pa, pb, pc);
auto ps =
migraphx::add_apply_alpha_beta(*mm, {pa, pb, pc}, migraphx::make_op("dot"), 1.0f, 1.0f);
mm->add_instruction(migraphx::make_op("dot"), pa, ps);
return p;
......@@ -450,10 +452,10 @@ TEST_CASE(op_capture)
auto pa = mm->add_instruction(migraphx::make_op("add"), p1, p2);
auto opa = mm->add_instruction(migraphx::make_op("capture", {{"ins_index", 0}}), pa);
auto opb = mm->add_instruction(migraphx::make_op("capture", {{"ins_index", 1}}), pb);
auto opc = mm->add_instruction(migraphx::make_op("capture", {{"ins_index", 2}}), pc);
auto ps = mm->add_instruction(migraphx::make_op("dot"), opa, opb, opc);
auto opm = mm->add_instruction(migraphx::make_op("capture", {{"ins_index", 3}}), pa);
auto ops = mm->add_instruction(migraphx::make_op("capture", {{"ins_index", 4}}), ps);
auto ps = migraphx::add_apply_alpha_beta(
*mm, {opa, opb, pc}, migraphx::make_op("dot"), 1.0f, 1.0f);
auto opm = mm->add_instruction(migraphx::make_op("capture", {{"ins_index", 2}}), pa);
auto ops = mm->add_instruction(migraphx::make_op("capture", {{"ins_index", 3}}), ps);
mm->add_instruction(migraphx::make_op("dot"), opm, ops);
return p;
......@@ -556,10 +558,8 @@ TEST_CASE(dot_float)
migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
auto pc = mm->add_parameter("c", sc);
auto r = mm->add_instruction(
migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), pa, pb, pc);
auto r = migraphx::add_apply_alpha_beta(*mm, {pa, pb}, migraphx::make_op("dot"));
mm->add_return({r});
return p;
......@@ -573,7 +573,6 @@ TEST_CASE(dot_float)
migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
auto pc = mm->add_parameter("c", sc);
auto zp_a = mm->add_literal(static_cast<int8_t>(0));
auto scale_a = mm->add_literal(10.0f);
scale_a = mm->add_instruction(
......@@ -592,16 +591,7 @@ TEST_CASE(dot_float)
auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b);
auto dqb = mm->add_instruction(migraphx::make_op("dequantizelinear"), qb, scale_b, zp_b);
auto zp_c = mm->add_literal(static_cast<int8_t>(100));
auto scale_c = mm->add_literal(10.0f);
scale_c = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sc.lens()}}), scale_c);
zp_c = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sc.lens()}}),
zp_c);
auto qc = mm->add_instruction(migraphx::make_op("quantizelinear"), pc, scale_c, zp_c);
auto dqc = mm->add_instruction(migraphx::make_op("dequantizelinear"), qc, scale_c, zp_c);
auto r = mm->add_instruction(
migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), dqa, dqb, dqc);
auto r = migraphx::add_apply_alpha_beta(*mm, {dqa, dqb}, migraphx::make_op("dot"));
mm->add_return({r});
return p;
......@@ -613,9 +603,8 @@ TEST_CASE(dot_float)
migraphx::shape sa{migraphx::shape::float_type, {2, 16}};
migraphx::shape sb{migraphx::shape::float_type, {16, 8}};
migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
mm->add_parameter("c", sc);
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
auto zp = mm->add_literal(static_cast<int8_t>(0));
auto scale = mm->add_literal(10.0f);
auto scale_a = mm->add_instruction(
......@@ -628,8 +617,7 @@ TEST_CASE(dot_float)
auto zp_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), zp);
auto quant_b = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b);
auto quant = mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), quant_a, quant_b);
auto quant = mm->add_instruction(migraphx::make_op("quant_dot"), quant_a, quant_b);
std::vector<float> vec(sc.elements(), 100.0f);
auto dc = mm->add_literal(100.0f);
auto mdc =
......@@ -649,6 +637,7 @@ TEST_CASE(dot_float)
p,
{migraphx::quantize_int8_pass{{"dot"}, quant_params}, migraphx::dead_code_elimination{}});
auto qp = create_int8_quantized_prog();
EXPECT(p == qp);
optimize_prog_int8(p);
......@@ -665,8 +654,7 @@ TEST_CASE(dot_double_2args)
migraphx::shape sb{migraphx::shape::double_type, {16, 8}};
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
auto r = mm->add_instruction(
migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), pa, pb);
auto r = migraphx::add_apply_alpha_beta(*mm, {pa, pb}, migraphx::make_op("dot"));
mm->add_return({r});
return p;
......@@ -696,8 +684,7 @@ TEST_CASE(dot_double_2args)
zp_b);
auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b);
auto dqb = mm->add_instruction(migraphx::make_op("dequantizelinear"), qb, scale_b, zp_b);
auto r = mm->add_instruction(
migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), dqa, dqb);
auto r = migraphx::add_apply_alpha_beta(*mm, {dqa, dqb}, migraphx::make_op("dot"));
mm->add_return({r});
return p;
};
......@@ -722,9 +709,8 @@ TEST_CASE(dot_double_2args)
migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), scale_b);
auto zp_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), zp);
auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b);
auto qdot = mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), qa, qb);
auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b);
auto qdot = mm->add_instruction(migraphx::make_op("quant_dot"), qa, qb);
auto scale = mm->add_literal(50.0);
scale = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}), scale);
......@@ -753,8 +739,7 @@ TEST_CASE(dot_half_1arg)
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::half_type, {9, 9}};
auto x = mm->add_parameter("x", s);
auto r =
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), x, x);
auto r = mm->add_instruction(migraphx::make_op("dot"), x, x);
mm->add_return({r});
return p;
......@@ -782,8 +767,7 @@ TEST_CASE(dot_half_1arg)
zp_b);
auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), x, scale_b, zp_b);
auto dqb = mm->add_instruction(migraphx::make_op("dequantizelinear"), qb, scale_b, zp_b);
auto r = mm->add_instruction(
migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), dqa, dqb);
auto r = mm->add_instruction(migraphx::make_op("dot"), dqa, dqb);
mm->add_return({r});
return p;
};
......@@ -800,10 +784,8 @@ TEST_CASE(dot_half_1arg)
scale);
zp =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), zp);
auto qx = mm->add_instruction(migraphx::make_op("quantizelinear"), x, scale, zp);
auto qdot = mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), qx, qx);
auto qx = mm->add_instruction(migraphx::make_op("quantizelinear"), x, scale, zp);
auto qdot = mm->add_instruction(migraphx::make_op("quant_dot"), qx, qx);
auto dq_scale = mm->add_literal(migraphx::literal({sa.type()}, {100.0}));
dq_scale = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}),
......@@ -1055,9 +1037,9 @@ TEST_CASE(int8_quantization_dot)
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
auto pc = mm->add_parameter("c", sc);
auto r = mm->add_instruction(migraphx::make_op("dot"), pa, pb, pc);
auto r =
migraphx::add_apply_alpha_beta(*mm, {pa, pb, pc}, migraphx::make_op("dot"), 1.0f, 1.0f);
mm->add_return({r});
return p;
};
......@@ -1075,7 +1057,7 @@ TEST_CASE(int8_quantization_dot)
std::vector<float> no_quant_result;
run_prog(p, ref_t, m, no_quant_result);
EXPECT(migraphx::verify_range(quant_result, no_quant_result));
EXPECT(migraphx::verify_range(quant_result, no_quant_result, 30000));
}
}
......@@ -1142,8 +1124,7 @@ TEST_CASE(int8_subgraph)
auto w = mm->add_parameter("w", sw);
auto* then_mod = p.create_module("If_6_if");
auto out1 = then_mod->add_instruction(
migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), a, b);
auto out1 = migraphx::add_apply_alpha_beta(*then_mod, {a, b}, migraphx::make_op("dot"));
then_mod->add_return({out1});
auto* else_mod = p.create_module("If_6_else");
......@@ -1181,11 +1162,10 @@ TEST_CASE(int8_subgraph)
migraphx::make_op("multibroadcast", {{"out_lens", sy.lens()}}), s1);
auto zpb = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sy.lens()}}), zp1);
auto qb = then_mod->add_instruction(migraphx::make_op("quantizelinear"), b, sb, zpb);
auto qdot =
then_mod->add_instruction(migraphx::make_op("quant_dot", {{"beta", 0}}), qa, qb);
auto so = then_mod->add_literal(100.0f);
so = then_mod->add_instruction(
auto qb = then_mod->add_instruction(migraphx::make_op("quantizelinear"), b, sb, zpb);
auto qdot = then_mod->add_instruction(migraphx::make_op("quant_dot"), qa, qb);
auto so = then_mod->add_literal(100.0f);
so = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sout.lens()}}), so);
auto r = then_mod->add_instruction(migraphx::make_op("dequantizelinear"), qdot, so);
then_mod->add_return({r});
......@@ -1251,7 +1231,8 @@ TEST_CASE(test_op_capture)
auto pb = mm->add_literal(s2, d2);
auto pc = mm->add_literal(s2, d2);
auto pa = mm->add_instruction(migraphx::make_op("add"), p1, p2);
auto ps = mm->add_instruction(migraphx::make_op("dot"), pa, pb, pc);
auto ps =
migraphx::add_apply_alpha_beta(*mm, {pa, pb, pc}, migraphx::make_op("dot"), 1.0f, 1.0f);
mm->add_instruction(migraphx::make_op("dot"), pa, ps);
auto calc = [](std::size_t, const std::vector<migraphx::argument>&) {};
......
......@@ -6,6 +6,7 @@
#include <migraphx/verify.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include "test.hpp"
#include <migraphx/half.hpp>
......@@ -211,7 +212,11 @@ TEST_CASE(gemm_mutli_dim_2_beta0)
auto l3 = mm->add_literal(migraphx::literal{m3_shape, m3});
float alpha = 1.0f;
float beta = 0.0f;
mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3);
migraphx::add_apply_alpha_beta(*mm,
std::vector<migraphx::instruction_ref>{l1, l2, l3},
migraphx::make_op("dot"),
alpha,
beta);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> m;
......@@ -274,7 +279,11 @@ TEST_CASE(gemm_beta_0)
float alpha = 1.0f;
float beta = 0.0f;
mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3);
migraphx::add_apply_alpha_beta(*mm,
std::vector<migraphx::instruction_ref>{l1, l2, l3},
migraphx::make_op("dot"),
alpha,
beta);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> m;
......@@ -359,13 +368,13 @@ TEST_CASE(gemm_mutli_dim1_2_3)
0.49759611, 0.10021662, 0.00592602, 0.90862000};
migraphx::shape m3_shape{migraphx::shape::float_type, {2, 3, 2, 2}};
auto l1 = mm->add_literal(migraphx::literal{m1_shape, m1});
auto l2 = mm->add_literal(migraphx::literal{m2_shape, m2});
auto l3 = mm->add_literal(migraphx::literal{m3_shape, m3});
float alpha = 0.35;
float beta = 0.41;
auto m12_alpha =
mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, m1});
auto l2 = mm->add_literal(migraphx::literal{m2_shape, m2});
auto l3 = mm->add_literal(migraphx::literal{m3_shape, m3});
float alpha = 0.35;
float beta = 0.41;
auto m12_alpha = migraphx::add_apply_alpha_beta(
*mm, std::vector<migraphx::instruction_ref>{l1, l2}, migraphx::make_op("dot"), alpha);
auto l_beta = mm->add_literal(beta);
auto b_beta = mm->add_instruction(
migraphx::make_op("scalar", {{"scalar_bcst_dims", m12_alpha->get_shape().lens()}}), l_beta);
......@@ -418,7 +427,11 @@ TEST_CASE(gemm_mutli_3args)
auto l3 = mm->add_literal(migraphx::literal{m3_shape, m3});
float alpha = 0.35;
float beta = 0.41;
mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3);
migraphx::add_apply_alpha_beta(*mm,
std::vector<migraphx::instruction_ref>{l1, l2, l3},
migraphx::make_op("dot"),
alpha,
beta);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> m;
......@@ -479,7 +492,7 @@ TEST_CASE(gemm_3args)
auto bl = mm->add_literal(migraphx::literal{b_shape, b});
migraphx::shape c_shape{migraphx::shape::float_type, {3, 3}};
auto cl = mm->add_literal(migraphx::literal{c_shape, c});
mm->add_instruction(migraphx::make_op("dot"), al, bl, cl);
migraphx::add_apply_alpha_beta(*mm, {al, bl, cl}, migraphx::make_op("dot"), 1.0f, 1.0f);
std::vector<float> gold = {-1.60947,
0.703083,
-5.46156,
......@@ -561,7 +574,8 @@ TEST_CASE(matmul_vv_inner_product)
auto ual = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), al);
auto ubl = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), bl);
float alpha = 0.32f;
mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}}), ual, ubl);
migraphx::add_apply_alpha_beta(
*mm, std::vector<migraphx::instruction_ref>{ual, ubl}, migraphx::make_op("dot"), alpha);
std::vector<float> gold = {-0.4590752};
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
......@@ -634,7 +648,8 @@ TEST_CASE(matmul_vm)
migraphx::shape b_shape{migraphx::shape::float_type, {8, 5}};
auto bl = mm->add_literal(migraphx::literal{b_shape, b});
float alpha = 0.5f;
mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}}), ual, bl);
migraphx::add_apply_alpha_beta(
*mm, std::vector<migraphx::instruction_ref>{ual, bl}, migraphx::make_op("dot"), alpha);
std::vector<float> gold = {-1.89056, -1.70003, -1.0986, -1.65724, -1.90163};
p.compile(migraphx::ref::target{});
......@@ -718,7 +733,8 @@ TEST_CASE(matmul_vm)
migraphx::make_op("multibroadcast", {{"out_lens", {3, 1, 6}}}), ual);
migraphx::shape b_shape{migraphx::shape::float_type, {3, 6, 4}};
auto bl = mm->add_literal(migraphx::literal{b_shape, b});
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 0.21f}}), bual, bl);
migraphx::add_apply_alpha_beta(
*mm, std::vector<migraphx::instruction_ref>{bual, bl}, migraphx::make_op("dot"), 0.21f);
std::vector<float> gold = {0.25812,
-0.247582,
0.480051,
......@@ -805,7 +821,8 @@ TEST_CASE(matmul_mv)
auto bl = mm->add_literal(migraphx::literal{b_shape, b});
auto ubl = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), bl);
float alpha = 0.3f;
mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}}), al, ubl);
migraphx::add_apply_alpha_beta(
*mm, std::vector<migraphx::instruction_ref>{al, ubl}, migraphx::make_op("dot"), alpha);
std::vector<float> gold = {0.395946, 0.357067, -0.588187};
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
......@@ -1337,7 +1354,8 @@ TEST_CASE(quant_dot_2args_general)
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 2}}), l1, tl2);
migraphx::add_apply_alpha_beta(*mm, {l1, tl2}, migraphx::make_op("quant_dot"), 2);
std::vector<int> gold = {
28, 76, 124, 172, 220, 76, 252, 428, 604, 780, 124, 428, 732, 1036, 1340};
......@@ -1366,7 +1384,7 @@ TEST_CASE(quant_dot_2args_general)
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 3}, {"beta", 2}}), tl1, tl2);
migraphx::add_apply_alpha_beta(*mm, {tl1, tl2}, migraphx::make_op("quant_dot"), 3);
std::vector<int> gold = {
126, 342, 558, 774, 990, 144, 408, 672, 936, 1200, 162, 474, 786, 1098, 1410};
......@@ -1398,7 +1416,7 @@ TEST_CASE(quant_dot_3args_general)
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3});
mm->add_instruction(migraphx::make_op("quant_dot"), l1, l2, l3);
migraphx::add_apply_alpha_beta(*mm, {l1, l2, l3}, migraphx::make_op("quant_dot"), 1, 1);
std::vector<int> gold = {
982, 1011, 1040, 1069, 1098, 1127, 1156, 2557, 2650, 2743, 2836, 2929, 3022, 3115};
......@@ -1426,9 +1444,7 @@ TEST_CASE(quant_dot_3args_general)
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3});
mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), l1, l2, l3);
mm->add_instruction(migraphx::make_op("quant_dot"), l1, l2);
std::vector<int> gold = {
70, 76, 82, 88, 94, 190, 212, 234, 256, 278, 310, 348, 386, 424, 462};
......@@ -1459,8 +1475,7 @@ TEST_CASE(quant_dot_3args_general)
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3});
mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 3}}), tl1, l2, l3);
migraphx::add_apply_alpha_beta(*mm, {tl1, l2, l3}, migraphx::make_op("quant_dot"), 1, 3);
std::vector<int> gold = {
1966, 2025, 2084, 2143, 2202, 2261, 2320, 2183, 2250, 2317, 2384, 2451, 2518, 2585};
......@@ -1491,8 +1506,7 @@ TEST_CASE(quant_dot_3args_general)
auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3});
mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 2}, {"beta", 3}}), l1, tl2, l3);
migraphx::add_apply_alpha_beta(*mm, {l1, tl2, l3}, migraphx::make_op("quant_dot"), 2, 3);
std::vector<int> gold = {
286, 737, 1188, 1639, 2090, 2541, 2992, 755, 2230, 3705, 5180, 6655, 8130, 9605};
......@@ -1525,8 +1539,7 @@ TEST_CASE(quant_dot_3args_general)
auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3});
mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 3}, {"beta", 2}}), tl1, tl2, l3);
migraphx::add_apply_alpha_beta(*mm, {tl1, tl2, l3}, migraphx::make_op("quant_dot"), 3, 2);
std::vector<int> gold = {
844, 2190, 3536, 4882, 6228, 7574, 8920, 942, 2480, 4018, 5556, 7094, 8632, 10170};
......@@ -1558,8 +1571,7 @@ TEST_CASE(quant_dot_3args_batch)
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3});
mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 2}}), l1, l2, l3);
migraphx::add_apply_alpha_beta(*mm, {l1, l2, l3}, migraphx::make_op("quant_dot"), 1, 2);
std::vector<int> gold = {
102, 110, 118, 126, 134, 142, 150, 284, 308, 332, 356, 380,
......@@ -1596,8 +1608,7 @@ TEST_CASE(quant_dot_3args_batch)
auto tl2 = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2);
auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3});
mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 2}, {"beta", 3}}), tl1, tl2, l3);
migraphx::add_apply_alpha_beta(*mm, {tl1, tl2, l3}, migraphx::make_op("quant_dot"), 2, 3);
std::vector<int> gold = {
90, 237, 384, 531, 678, 825, 120, 299, 478, 657, 836, 1015,
......
......@@ -1838,8 +1838,6 @@ TEST_CASE(if_param_test)
auto res = p.eval(m).back();
std::vector<float> ret;
res.visit([&](auto v) { ret.assign(v.begin(), v.end()); });
std::cout << std::endl;
return ret;
};
......@@ -2756,6 +2754,94 @@ TEST_CASE(neg_test)
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(nms_not_center_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape boxes_s{migraphx::shape::float_type, {1, 6, 4}};
std::vector<float> boxes_vec = {1.0, 1.0, 0.0, 0.0, 0.0, 0.1, 1.0, 1.1,
0.0, 0.9, 1.0, -0.1, 0.0, 10.0, 1.0, 11.0,
1.0, 10.1, 0.0, 11.1, 1.0, 101.0, 0.0, 100.0};
migraphx::shape scores_s{migraphx::shape::float_type, {1, 1, 6}};
std::vector<float> scores_vec = {0.9, 0.75, 0.6, 0.95, 0.5, 0.3};
auto boxes_l = mm->add_literal(migraphx::literal(boxes_s, boxes_vec));
auto scores_l = mm->add_literal(migraphx::literal(scores_s, scores_vec));
auto max_out_l = mm->add_literal(int64_t{4});
auto iou_threshold = mm->add_literal(0.5f);
auto score_threshold = mm->add_literal(0.0f);
auto r = mm->add_instruction(migraphx::make_op("nonmaxsuppression"),
boxes_l,
scores_l,
max_out_l,
iou_threshold,
score_threshold);
mm->add_return({r});
p.compile(migraphx::ref::target{});
auto output = p.eval({}).back();
std::vector<int64_t> result;
output.visit([&](auto out) { result.assign(out.begin(), out.end()); });
std::cout << "output = " << output << std::endl;
std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0};
EXPECT(migraphx::verify_range(result, gold));
}
TEST_CASE(nms_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape boxes_s{migraphx::shape::float_type, {1, 6, 4}};
std::vector<float> boxes_vec = {0.5, 0.5, 1.0, 1.0, 0.5, 0.6, 1.0, 1.0, 0.5, 0.4, 1.0, 1.0,
0.5, 10.5, 1.0, 1.0, 0.5, 10.6, 1.0, 1.0, 0.5, 100.5, 1.0, 1.0};
migraphx::shape scores_s{migraphx::shape::float_type, {1, 1, 6}};
std::vector<float> scores_vec = {0.9, 0.75, 0.6, 0.95, 0.5, 0.3};
auto boxes_l = mm->add_literal(migraphx::literal(boxes_s, boxes_vec));
auto scores_l = mm->add_literal(migraphx::literal(scores_s, scores_vec));
auto max_out_l = mm->add_literal(int64_t{4});
auto iou_threshold = mm->add_literal(0.5f);
auto score_threshold = mm->add_literal(0.0f);
auto r = mm->add_instruction(migraphx::make_op("nonmaxsuppression", {{"center_point_box", 1}}),
boxes_l,
scores_l,
max_out_l,
iou_threshold,
score_threshold);
mm->add_return({r});
p.compile(migraphx::ref::target{});
auto output = p.eval({}).back();
std::vector<int64_t> result;
output.visit([&](auto out) { result.assign(out.begin(), out.end()); });
std::cout << "output = " << output << std::endl;
std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0};
EXPECT(migraphx::verify_range(result, gold));
}
TEST_CASE(nonzero_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 2, 3}};
std::vector<float> data = {
1.0f, 1.3f, 0.0f, -1.2f, 0.0f, -100.f, 200.f, 0.0f, 0.1f, 0.2f, 0.0f, 0.5f};
auto input = mm->add_literal(migraphx::literal(s, data));
auto ret = mm->add_instruction(migraphx::make_op("nonzero"), input);
mm->add_return({ret});
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<int64_t> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<int64_t> gold = {0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
1, 1, 0, 0, 0, 0, 0, 1, 0, 2, 0, 2, 0, 2, 0, 0, 0, 0};
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(not_test)
{
// int32
......@@ -2845,6 +2931,26 @@ TEST_CASE(pad_test_lowest_half)
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(pointwise_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3}};
auto l1 = mm->add_literal(migraphx::literal{s, {-1, 0, 1}});
auto l2 = mm->add_literal(migraphx::literal{s, {1, 2, 3}});
auto* pm = p.create_module("pointwise");
auto x1 = pm->add_parameter("x1", {migraphx::shape::float_type});
auto x2 = pm->add_parameter("x2", {migraphx::shape::float_type});
pm->add_instruction(migraphx::make_op("add"), x1, x2);
mm->add_instruction(migraphx::make_op("pointwise"), {l1, l2}, {pm});
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 2, 4};
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(pow_test)
{
migraphx::program p;
......@@ -3812,6 +3918,170 @@ TEST_CASE(reverse_test_axis10)
EXPECT(migraphx::verify_range(results_vector, target_data));
}
TEST_CASE(roialign_out_of_bound_test)
{
auto create_program = [](const std::string& trans_mode = "half_pixel") {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape x_s{migraphx::shape::float_type, {1, 1, 10, 10}};
std::vector<float> x_vec = {
0.2764, 0.7150, 0.1958, 0.3416, 0.4638, 0.0259, 0.2963, 0.6518, 0.4856, 0.7250,
0.9637, 0.0895, 0.2919, 0.6753, 0.0234, 0.6132, 0.8085, 0.5324, 0.8992, 0.4467,
0.3265, 0.8479, 0.9698, 0.2471, 0.9336, 0.1878, 0.4766, 0.4308, 0.3400, 0.2162,
0.0206, 0.1720, 0.2155, 0.4394, 0.0653, 0.3406, 0.7724, 0.3921, 0.2541, 0.5799,
0.4062, 0.2194, 0.4473, 0.4687, 0.7109, 0.9327, 0.9815, 0.6320, 0.1728, 0.6119,
0.3097, 0.1283, 0.4984, 0.5068, 0.4279, 0.0173, 0.4388, 0.0430, 0.4671, 0.7119,
0.1011, 0.8477, 0.4726, 0.1777, 0.9923, 0.4042, 0.1869, 0.7795, 0.9946, 0.9689,
0.1366, 0.3671, 0.7011, 0.6234, 0.9867, 0.5585, 0.6985, 0.5609, 0.8788, 0.9928,
0.5697, 0.8511, 0.6711, 0.9406, 0.8751, 0.7496, 0.1650, 0.1049, 0.1559, 0.2514,
0.7012, 0.4056, 0.7879, 0.3461, 0.0415, 0.2998, 0.5094, 0.3727, 0.5482, 0.0502};
migraphx::shape roi_s{migraphx::shape::float_type, {3, 4}};
std::vector<float> roi_vec = {0, 0, 9.99, 9.99, 0, 5, 4, 9, 5, 5, 9.9, 9.9};
migraphx::shape ind_s{migraphx::shape::int64_type, {3}};
std::vector<int64_t> ind_vec = {0, 0, 0};
auto x = mm->add_literal(migraphx::literal(x_s, x_vec));
auto roi = mm->add_literal(migraphx::literal(roi_s, roi_vec));
auto ind = mm->add_literal(migraphx::literal(ind_s, ind_vec));
auto r =
mm->add_instruction(migraphx::make_op("roialign",
{{"coordinate_transformation_mode", trans_mode},
{"spatial_scale", 5.0},
{"output_height", 1},
{"output_width", 1},
{"sampling_ratio", 1}}),
x,
roi,
ind);
mm->add_return({r});
return p;
};
{
auto p = create_program("output_half_pixel");
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.0f, 0.0f, 0.0f};
EXPECT(migraphx::verify_range(results_vector, gold));
}
}
TEST_CASE(roialign_test)
{
auto create_program = [](const std::string& trans_mode = "half_pixel",
const std::string& pooling_mode = "avg",
int64_t sampling_ratio = 2) {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape x_s{migraphx::shape::float_type, {1, 1, 10, 10}};
std::vector<float> x_vec = {
0.2764, 0.7150, 0.1958, 0.3416, 0.4638, 0.0259, 0.2963, 0.6518, 0.4856, 0.7250,
0.9637, 0.0895, 0.2919, 0.6753, 0.0234, 0.6132, 0.8085, 0.5324, 0.8992, 0.4467,
0.3265, 0.8479, 0.9698, 0.2471, 0.9336, 0.1878, 0.4766, 0.4308, 0.3400, 0.2162,
0.0206, 0.1720, 0.2155, 0.4394, 0.0653, 0.3406, 0.7724, 0.3921, 0.2541, 0.5799,
0.4062, 0.2194, 0.4473, 0.4687, 0.7109, 0.9327, 0.9815, 0.6320, 0.1728, 0.6119,
0.3097, 0.1283, 0.4984, 0.5068, 0.4279, 0.0173, 0.4388, 0.0430, 0.4671, 0.7119,
0.1011, 0.8477, 0.4726, 0.1777, 0.9923, 0.4042, 0.1869, 0.7795, 0.9946, 0.9689,
0.1366, 0.3671, 0.7011, 0.6234, 0.9867, 0.5585, 0.6985, 0.5609, 0.8788, 0.9928,
0.5697, 0.8511, 0.6711, 0.9406, 0.8751, 0.7496, 0.1650, 0.1049, 0.1559, 0.2514,
0.7012, 0.4056, 0.7879, 0.3461, 0.0415, 0.2998, 0.5094, 0.3727, 0.5482, 0.0502};
migraphx::shape roi_s{migraphx::shape::float_type, {3, 4}};
std::vector<float> roi_vec = {0, 0, 9, 9, 0, 5, 4, 9, 5, 5, 9, 9};
migraphx::shape ind_s{migraphx::shape::int64_type, {3}};
std::vector<int64_t> ind_vec = {0, 0, 0};
auto x = mm->add_literal(migraphx::literal(x_s, x_vec));
auto roi = mm->add_literal(migraphx::literal(roi_s, roi_vec));
auto ind = mm->add_literal(migraphx::literal(ind_s, ind_vec));
auto r =
mm->add_instruction(migraphx::make_op("roialign",
{{"coordinate_transformation_mode", trans_mode},
{"spatial_scale", 1.0},
{"output_height", 5},
{"output_width", 5},
{"sampling_ratio", sampling_ratio},
{"mode", pooling_mode}}),
x,
roi,
ind);
mm->add_return({r});
return p;
};
{
auto p = create_program();
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {
0.466421425, 0.446552634, 0.340521216, 0.568848491, 0.606780827, 0.371379346,
0.429571986, 0.383519977, 0.556241512, 0.351050019, 0.27680251, 0.488286227,
0.522200167, 0.552770197, 0.417057365, 0.471240699, 0.4844096, 0.690457463,
0.492039412, 0.877398551, 0.623889625, 0.712461948, 0.628926516, 0.335504025,
0.349469036, 0.302179992, 0.43046391, 0.469585985, 0.39774403, 0.542259991,
0.365552008, 0.704923987, 0.516481996, 0.317131996, 0.701444089, 0.291239977,
0.505897999, 0.647610962, 0.623489916, 0.829879999, 0.591567993, 0.738860011,
0.704825997, 0.837148011, 0.889315963, 0.622680008, 0.615276039, 0.709713995,
0.615356028, 0.458524048, 0.238451958, 0.337952018, 0.371693879, 0.609999895,
0.760059953, 0.376724035, 0.378532052, 0.71468991, 0.924308002, 0.972783983,
0.574903965, 0.582623959, 0.570936024, 0.761904061, 0.876998067, 0.535508037,
0.256580025, 0.214098021, 0.279604018, 0.360000014, 0.436488032, 0.350427985,
0.288755983, 0.366139978, 0.234920025};
EXPECT(migraphx::verify_range(results_vector, gold));
}
{
auto p = create_program("output_half_pixel");
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {
0.517783, 0.343411, 0.322905, 0.447362, 0.634375, 0.40308, 0.536647, 0.442791,
0.486144, 0.402313, 0.251194, 0.400154, 0.515524, 0.695369, 0.346537, 0.33504,
0.460099, 0.588069, 0.343863, 0.684932, 0.49319, 0.714058, 0.821744, 0.471935,
0.403946, 0.306955, 0.218678, 0.33369, 0.488001, 0.486962, 0.18709, 0.49142,
0.55611, 0.419167, 0.368608, 0.143278, 0.460835, 0.597125, 0.53096, 0.498207,
0.278818, 0.438569, 0.6022, 0.700038, 0.752436, 0.577385, 0.702383, 0.725097,
0.733754, 0.816304, 0.23933, 0.407514, 0.337893, 0.252521, 0.474335, 0.367075,
0.270168, 0.41051, 0.64189, 0.830777, 0.55564, 0.454295, 0.55645, 0.75015,
0.929997, 0.66257, 0.561664, 0.481275, 0.495449, 0.666306, 0.663573, 0.372107,
0.205603, 0.192776, 0.247849};
EXPECT(migraphx::verify_range(results_vector, gold));
}
{
auto p = create_program("output_half_pixel", "max", 0);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {
0.819145, 0.373103, 0.258302, 0.515419, 0.726104, 0.540536, 0.545512, 0.38511,
0.376545, 0.274635, 0.22341, 0.184511, 0.230843, 0.404869, 0.29546, 0.540409,
0.265838, 0.409324, 0.213915, 0.708654, 0.687264, 0.580821, 0.461283, 0.462879,
0.709632, 0.27873, 0.083619, 0.22428, 0.313992, 0.410508, 0.0929099, 0.415373,
0.296695, 0.231574, 0.136836, 0.0683, 0.296695, 0.211925, 0.245385, 0.28053,
0.17091, 0.179879, 0.245385, 0.343539, 0.392742, 0.51273, 0.536193, 0.382995,
0.422793, 0.761886, 0.0839429, 0.276444, 0.19746, 0.126117, 0.378351, 0.254646,
0.092148, 0.272825, 0.381955, 0.626599, 0.251325, 0.244475, 0.194875, 0.272825,
0.44757, 0.351855, 0.342265, 0.244475, 0.274841, 0.553644, 0.607176, 0.202392,
0.07425, 0.066087, 0.126279};
EXPECT(migraphx::verify_range(results_vector, gold));
}
}
TEST_CASE(round_test)
{
migraphx::program p;
......
......@@ -75,7 +75,7 @@ struct test_loop_op
ins_out_shapes.push_back({out_s.type(), lens});
}
return migraphx::shape(ins_out_shapes);
return {ins_out_shapes};
}
struct test_loop : public migraphx::op::loop::ref_loop
......
......@@ -10,6 +10,7 @@
#include <migraphx/generate.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/apply_alpha_beta.hpp>
bool is_convolution(const migraphx::instruction& ins) { return ins.name() == "convolution"; }
bool is_dot(const migraphx::instruction& ins) { return ins.name() == "dot"; }
......@@ -127,12 +128,11 @@ TEST_CASE(dot)
auto scale = m1.add_literal(0.5f);
auto zero = m1.add_literal(std::int8_t{0});
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
auto dot =
m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), d1, d2);
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
auto dot = m1.add_instruction(migraphx::make_op("dot"), d1, d2);
m1.add_return({dot});
}
......@@ -144,11 +144,10 @@ TEST_CASE(dot)
auto zero = m2.add_literal(std::int8_t{0});
auto scale1 = m2.add_literal(0.25f);
auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale, zero);
auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale, zero);
auto dot =
m2.add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), q1, q2);
auto d3 = add_quantize_op(m2, "dequantizelinear", dot, scale1);
auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale, zero);
auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale, zero);
auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2);
auto d3 = add_quantize_op(m2, "dequantizelinear", dot, scale1);
m2.add_return({d3});
}
......@@ -168,22 +167,19 @@ TEST_CASE(dot_non_zero_point)
auto scale = m1.add_literal(0.5f);
auto zero = m1.add_literal(std::int8_t{1});
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
auto dot =
m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), d1, d2);
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
auto dot = m1.add_instruction(migraphx::make_op("dot"), d1, d2);
m1.add_return({dot});
}
migraphx::module m2;
{
auto t1 = m2.add_parameter("t1", sh1);
auto t2 = m2.add_parameter("t2", sh2);
auto dot =
m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), t1, t2);
auto t1 = m2.add_parameter("t1", sh1);
auto t2 = m2.add_parameter("t2", sh2);
auto dot = m2.add_instruction(migraphx::make_op("dot"), t1, t2);
m2.add_return({dot});
}
......@@ -203,22 +199,19 @@ TEST_CASE(dot_uint8)
auto scale = m1.add_literal(0.5f);
auto zero = m1.add_literal(std::uint8_t{0});
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
auto dot =
m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), d1, d2);
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
auto dot = m1.add_instruction(migraphx::make_op("dot"), d1, d2);
m1.add_return({dot});
}
migraphx::module m2;
{
auto t1 = m2.add_parameter("t1", sh1);
auto t2 = m2.add_parameter("t2", sh2);
auto dot =
m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), t1, t2);
auto t1 = m2.add_parameter("t1", sh1);
auto t2 = m2.add_parameter("t2", sh2);
auto dot = m2.add_instruction(migraphx::make_op("dot"), t1, t2);
m2.add_return({dot});
}
......@@ -240,12 +233,11 @@ TEST_CASE(dot_add)
auto scale = m1.add_literal(0.5f);
auto zero = m1.add_literal(std::int8_t{0});
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
auto dot =
m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), d1, d2);
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
auto dot = m1.add_instruction(migraphx::make_op("dot"), d1, d2);
auto q3 = add_quantize_op(m1, "quantizelinear", dot, scale, zero);
auto d3 = add_quantize_op(m1, "dequantizelinear", q3, scale, zero);
auto add = m1.add_instruction(migraphx::make_op("add"), d3, ab);
......@@ -261,10 +253,9 @@ TEST_CASE(dot_add)
auto zero = m2.add_literal(std::int8_t{0});
auto scale1 = m2.add_literal(0.25f);
auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale, zero);
auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale, zero);
auto dot =
m2.add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), q1, q2);
auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale, zero);
auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale, zero);
auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2);
auto d3 = add_quantize_op(m2, "dequantizelinear", dot, scale1);
auto add = m2.add_instruction(migraphx::make_op("add"), d3, ab);
m2.add_return({add});
......@@ -471,21 +462,20 @@ TEST_CASE(conv_pooling_dot)
d1);
auto bc1 = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2);
auto a1 = m1.add_instruction(migraphx::make_op("add"), c1, bc1);
auto ap = m1.add_instruction(migraphx::make_op("pooling",
auto a1 = m1.add_instruction(migraphx::make_op("add"), c1, bc1);
auto ap = m1.add_instruction(migraphx::make_op("pooling",
{{"mode", "average"},
{"padding", {0, 0, 0, 0}},
{"stride", {1, 1}},
{"lengths", {7, 7}},
{"ceil_mode", 0}}),
a1);
auto fl = m1.add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), ap);
auto q4 = add_quantize_op(m1, "quantizelinear", fl, scale, zero);
auto d8 = add_quantize_op(m1, "dequantizelinear", q4, scale, zero);
auto dot =
m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), d8, d4);
auto q5 = add_quantize_op(m1, "quantizelinear", dot, scale, zero);
auto d9 = add_quantize_op(m1, "dequantizelinear", q5, scale, zero);
auto fl = m1.add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), ap);
auto q4 = add_quantize_op(m1, "quantizelinear", fl, scale, zero);
auto d8 = add_quantize_op(m1, "dequantizelinear", q4, scale, zero);
auto dot = m1.add_instruction(migraphx::make_op("dot"), d8, d4);
auto q5 = add_quantize_op(m1, "quantizelinear", dot, scale, zero);
auto d9 = add_quantize_op(m1, "dequantizelinear", q5, scale, zero);
auto mb1 =
m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 1000}}}), d3);
auto a2 = m1.add_instruction(migraphx::make_op("add"), d9, mb1);
......@@ -518,19 +508,18 @@ TEST_CASE(conv_pooling_dot)
auto d5 = add_quantize_op(m2, "dequantizelinear", c1, scale1);
auto bc1 = m2.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2);
auto a1 = m2.add_instruction(migraphx::make_op("add"), d5, bc1);
auto ap = m2.add_instruction(migraphx::make_op("pooling",
auto a1 = m2.add_instruction(migraphx::make_op("add"), d5, bc1);
auto ap = m2.add_instruction(migraphx::make_op("pooling",
{{"mode", "average"},
{"padding", {0, 0, 0, 0}},
{"stride", {1, 1}},
{"lengths", {7, 7}},
{"ceil_mode", 0}}),
a1);
auto fl = m2.add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), ap);
auto q4 = add_quantize_op(m2, "quantizelinear", fl, scale, zero);
auto dot =
m2.add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), q4, db);
auto d9 = add_quantize_op(m2, "dequantizelinear", dot, scale2);
auto fl = m2.add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), ap);
auto q4 = add_quantize_op(m2, "quantizelinear", fl, scale, zero);
auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q4, db);
auto d9 = add_quantize_op(m2, "dequantizelinear", dot, scale2);
auto mb1 =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 1000}}}), d3);
auto a2 = m2.add_instruction(migraphx::make_op("add"), d9, mb1);
......@@ -575,25 +564,24 @@ TEST_CASE(mobilenet_snippet)
d1);
auto bc1 = mm.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2);
auto a1 = mm.add_instruction(migraphx::make_op("add"), c1, bc1);
auto q2 = add_quantize_op(mm, "quantizelinear", a1, scale, zero);
auto d6 = add_quantize_op(mm, "dequantizelinear", q2, scale, zero);
auto ap = mm.add_instruction(migraphx::make_op("pooling",
auto a1 = mm.add_instruction(migraphx::make_op("add"), c1, bc1);
auto q2 = add_quantize_op(mm, "quantizelinear", a1, scale, zero);
auto d6 = add_quantize_op(mm, "dequantizelinear", q2, scale, zero);
auto ap = mm.add_instruction(migraphx::make_op("pooling",
{{"mode", "average"},
{"padding", {0, 0, 0, 0}},
{"stride", {1, 1}},
{"lengths", {7, 7}},
{"ceil_mode", 0}}),
d6);
auto q3 = add_quantize_op(mm, "quantizelinear", ap, scale, zero);
auto d7 = add_quantize_op(mm, "dequantizelinear", q3, scale, zero);
auto rs = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {1, -1}}}), d7);
auto q4 = add_quantize_op(mm, "quantizelinear", rs, scale, zero);
auto d8 = add_quantize_op(mm, "dequantizelinear", q4, scale, zero);
auto dot =
mm.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), d8, d4);
auto q5 = add_quantize_op(mm, "quantizelinear", dot, scale, zero);
auto d9 = add_quantize_op(mm, "dequantizelinear", q5, scale, zero);
auto q3 = add_quantize_op(mm, "quantizelinear", ap, scale, zero);
auto d7 = add_quantize_op(mm, "dequantizelinear", q3, scale, zero);
auto rs = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {1, -1}}}), d7);
auto q4 = add_quantize_op(mm, "quantizelinear", rs, scale, zero);
auto d8 = add_quantize_op(mm, "dequantizelinear", q4, scale, zero);
auto dot = mm.add_instruction(migraphx::make_op("dot"), d8, d4);
auto q5 = add_quantize_op(mm, "quantizelinear", dot, scale, zero);
auto d9 = add_quantize_op(mm, "dequantizelinear", q5, scale, zero);
auto mb1 =
mm.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 1000}}}), d3);
auto a2 = mm.add_instruction(migraphx::make_op("add"), d9, mb1);
......@@ -699,12 +687,11 @@ TEST_CASE(dot_correctness)
auto scale_b = m1->add_literal(0.5f);
auto zero = m1->add_literal(std::int8_t{0});
auto q1 = add_quantize_op(*m1, "quantizelinear", a, scale_a, zero);
auto d1 = add_quantize_op(*m1, "dequantizelinear", q1, scale_a, zero);
auto q2 = add_quantize_op(*m1, "quantizelinear", b, scale_b, zero);
auto d2 = add_quantize_op(*m1, "dequantizelinear", q2, scale_b, zero);
auto dot =
m1->add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), d1, d2);
auto q1 = add_quantize_op(*m1, "quantizelinear", a, scale_a, zero);
auto d1 = add_quantize_op(*m1, "dequantizelinear", q1, scale_a, zero);
auto q2 = add_quantize_op(*m1, "quantizelinear", b, scale_b, zero);
auto d2 = add_quantize_op(*m1, "dequantizelinear", q2, scale_b, zero);
auto dot = m1->add_instruction(migraphx::make_op("dot"), d1, d2);
m1->add_return({dot});
run_pass(*m1);
......@@ -715,8 +702,7 @@ TEST_CASE(dot_correctness)
auto* m2 = p2.get_main_module();
auto a = m2->add_parameter("a", sh1);
auto b = m2->add_parameter("b", sh2);
auto dot = m2->add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), a, b);
auto dot = m2->add_instruction(migraphx::make_op("dot"), a, b);
m2->add_return({dot});
}
......
......@@ -945,4 +945,177 @@ TEST_CASE(reshape_cont_nonpw)
EXPECT(m1 == create_module());
}
TEST_CASE(transpose_contiguous_reshape_unary)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::float_type, {2, 8, 5, 5}});
auto reshape_ins1 =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), x);
auto transpose_ins = m1.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), reshape_ins1);
auto cont_ins = m1.add_instruction(migraphx::make_op("contiguous"), transpose_ins);
auto reshape_ins2 =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 10, 10}}}), cont_ins);
auto relu = m1.add_instruction(migraphx::make_op("relu"), reshape_ins2);
m1.add_instruction(pass_op{}, relu);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 8, 5, 5}});
auto reshape_ins1 =
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), x);
auto transpose_ins = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), reshape_ins1);
auto relu = m2.add_instruction(migraphx::make_op("relu"), transpose_ins);
auto cont_ins = m2.add_instruction(migraphx::make_op("contiguous"), relu);
auto reshape_ins2 =
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 10, 10}}}), cont_ins);
m2.add_instruction(pass_op{}, reshape_ins2);
}
EXPECT(m1 == m2);
}
TEST_CASE(transpose_contiguous_squeeze_unary)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::float_type, {2, 8, 1, 5}});
auto transpose_ins =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x);
auto cont_ins = m1.add_instruction(migraphx::make_op("contiguous"), transpose_ins);
auto sq_ins = m1.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), cont_ins);
auto rsqrt = m1.add_instruction(migraphx::make_op("rsqrt"), sq_ins);
m1.add_instruction(pass_op{}, rsqrt);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 8, 1, 5}});
auto transpose_ins =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x);
auto rsqrt = m2.add_instruction(migraphx::make_op("rsqrt"), transpose_ins);
auto cont_ins = m2.add_instruction(migraphx::make_op("contiguous"), rsqrt);
auto sq_ins = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), cont_ins);
m2.add_instruction(pass_op{}, sq_ins);
}
EXPECT(m1 == m2);
}
TEST_CASE(transpose_contiguous_unsqueeze_unary)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::float_type, {2, 8, 5, 5}});
auto transpose_ins =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x);
auto cont_ins = m1.add_instruction(migraphx::make_op("contiguous"), transpose_ins);
auto unsq_ins =
m1.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), cont_ins);
auto round = m1.add_instruction(migraphx::make_op("round"), unsq_ins);
m1.add_instruction(pass_op{}, round);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 8, 5, 5}});
auto transpose_ins =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x);
auto round = m2.add_instruction(migraphx::make_op("round"), transpose_ins);
auto cont_ins = m2.add_instruction(migraphx::make_op("contiguous"), round);
auto unsq_ins =
m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), cont_ins);
m2.add_instruction(pass_op{}, unsq_ins);
}
EXPECT(m1 == m2);
}
TEST_CASE(transpose_contiguous_reshape_binary_packed)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::float_type, {2, 128, 28, 28}});
auto w1 = m1.add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 1, 1}}));
auto conv1 = m1.add_instruction(
migraphx::make_op("convolution",
{{"padding", {0, 0}}, {"stride", {1, 1}}, {"dilation", {1, 1}}}),
x,
w1); // (2, 256, 28, 28)
auto w2 = m1.add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {512, 256, 1, 1}}));
auto conv2 = m1.add_instruction(
migraphx::make_op("convolution",
{{"padding", {0, 0}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}),
conv1,
w2); // (2, 512, 14, 14)
auto conv2_rsp1 = m1.add_instruction(
migraphx::make_op("reshape", {{"dims", {2, 2, 2, 128, 14, 14}}}), conv2);
auto conv2_trans = m1.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), conv2_rsp1);
auto conv2_cont = m1.add_instruction(migraphx::make_op("contiguous"), conv2_trans);
auto conv2_rsp2 = m1.add_instruction(
migraphx::make_op("reshape", {{"dims", {2, 128, 28, 28}}}), conv2_cont);
auto add_ins = m1.add_instruction(migraphx::make_op("add"), conv2_rsp2, x);
m1.add_instruction(pass_op{}, add_ins);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 128, 28, 28}});
auto w1 = m2.add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 1, 1}}));
auto conv1 = m2.add_instruction(
migraphx::make_op("convolution",
{{"padding", {0, 0}}, {"stride", {1, 1}}, {"dilation", {1, 1}}}),
x,
w1); // (2, 256, 28, 28)
auto w2 = m2.add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {512, 256, 1, 1}}));
auto conv2 = m2.add_instruction(
migraphx::make_op("convolution",
{{"padding", {0, 0}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}),
conv1,
w2); // (2, 512, 14, 14)
auto conv2_rsp = m2.add_instruction(
migraphx::make_op("reshape", {{"dims", {2, 2, 2, 128, 14, 14}}}), conv2);
auto conv2_trans = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), conv2_rsp);
auto x_rsp =
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 128, 14, 2, 14, 2}}}), x);
auto add_ins = m2.add_instruction(migraphx::make_op("add"), conv2_trans, x_rsp);
auto add_rsp =
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 128, 28, 28}}}), add_ins);
m2.add_instruction(pass_op{}, add_rsp);
}
EXPECT(m1 == m2);
}
TEST_CASE(transpose_contiguous_reshape_binary_broadcast)
{
migraphx::module m1;
{
migraphx::shape sx{migraphx::shape::float_type, {4}};
migraphx::shape sy{migraphx::shape::float_type, {2, 6, 2, 2}};
auto x = m1.add_parameter("x", sx);
auto y = m1.add_parameter("y", sy);
auto x_brcst = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 4, 6}}}), x);
auto y_trans =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), y);
auto y_cont = m1.add_instruction(migraphx::make_op("contiguous"), y_trans);
auto y_rsp =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 4, 6}}}), y_cont);
auto r = m1.add_instruction(migraphx::make_op("add"), y_rsp, x_brcst);
m1.add_return({r});
}
migraphx::module m2 = m1;
run_pass(m1);
EXPECT(m1 == m2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
......@@ -21,8 +22,7 @@ struct batch_quant_dot_1 : verify_program<batch_quant_dot_1>
auto tl2 = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2);
auto l3 = mm->add_parameter("c", m3_shape);
mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 3}, {"beta", 2}}), tl1, tl2, l3);
migraphx::add_apply_alpha_beta(*mm, {tl1, tl2, l3}, migraphx::make_op("quant_dot"), 3, 2);
return p;
}
};
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
......@@ -17,8 +18,7 @@ struct batch_quant_dot_2 : verify_program<batch_quant_dot_2>
auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape);
auto l3 = mm->add_parameter("c", m3_shape);
mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 3}}), l1, l2, l3);
migraphx::add_apply_alpha_beta(*mm, {l1, l2, l3}, migraphx::make_op("quant_dot"), 1, 3);
return p;
}
};
......@@ -15,7 +15,7 @@ struct batch_quant_dot_3 : verify_program<batch_quant_dot_3>
auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape);
mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 3}}), l1, l2);
mm->add_instruction(migraphx::make_op("quant_dot"), l1, l2);
return p;
}
};
......@@ -19,7 +19,7 @@ struct batch_quant_dot_4 : verify_program<batch_quant_dot_4>
migraphx::make_op("transpose", {{"permutation", {3, 0, 1, 2}}}), l1);
auto tl2 = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {3, 1, 2, 0}}}), l2);
mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 3}}), tl1, tl2);
mm->add_instruction(migraphx::make_op("quant_dot"), tl1, tl2);
return p;
}
};
......@@ -21,7 +21,7 @@ struct batch_quant_dot_5 : verify_program<batch_quant_dot_5>
auto tl2 = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2);
auto sl2 = mm->add_instruction(migraphx::make_op("add"), tl2, tl2);
mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}}), sl1, sl2);
mm->add_instruction(migraphx::make_op("quant_dot"), sl1, sl2);
return p;
}
};
#include <migraphx/apply_alpha_beta.hpp>
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
......@@ -17,8 +18,7 @@ struct gemm_2args_vv : verify_program<gemm_2args_vv>
auto l2 = mm->add_parameter("2", m2_shape);
auto ul2 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l2);
float alpha = 0.23f;
auto res = mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}}), ul1, ul2);
auto res = migraphx::add_apply_alpha_beta(*mm, {ul1, ul2}, migraphx::make_op("dot"), alpha);
auto sres = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), res);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), sres);
......
#include <migraphx/apply_alpha_beta.hpp>
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
......@@ -19,9 +20,7 @@ struct gemm_multi_3args : verify_program<gemm_multi_3args>
auto l3 = mm->add_parameter("3", m3_shape);
float alpha = 0.35;
float beta = 0.41;
mm->add_instruction(
migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3);
migraphx::add_apply_alpha_beta(*mm, {l1, l2, l3}, migraphx::make_op("dot"), alpha, beta);
return p;
}
};
......@@ -3,7 +3,7 @@
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/apply_alpha_beta.hpp>
struct gemm_multi_3args_alpha0 : verify_program<gemm_multi_3args_alpha0>
{
migraphx::program create_program() const
......@@ -19,9 +19,7 @@ struct gemm_multi_3args_alpha0 : verify_program<gemm_multi_3args_alpha0>
float alpha = 0.0f;
float beta = 1.0f;
mm->add_instruction(
migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3);
migraphx::add_apply_alpha_beta(*mm, {l1, l2, l3}, migraphx::make_op("dot"), alpha, beta);
return p;
}
};
#include <migraphx/apply_alpha_beta.hpp>
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
......@@ -19,9 +20,7 @@ struct gemm_multi_3args_beta0 : verify_program<gemm_multi_3args_beta0>
float alpha = 1.0f;
float beta = 0.0f;
mm->add_instruction(
migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3);
migraphx::add_apply_alpha_beta(*mm, {l1, l2, l3}, migraphx::make_op("dot"), alpha, beta);
return p;
}
};
#include <migraphx/apply_alpha_beta.hpp>
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
......@@ -19,9 +20,7 @@ struct gemm_multi_3args_c25 : verify_program<gemm_multi_3args_c25>
auto l3 = mm->add_parameter("3", m3_shape);
float alpha = 0.35;
float beta = 0.41;
mm->add_instruction(
migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3);
migraphx::add_apply_alpha_beta(*mm, {l1, l2, l3}, migraphx::make_op("dot"), alpha, beta);
return p;
}
};
#include <migraphx/apply_alpha_beta.hpp>
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
......@@ -19,8 +20,7 @@ struct gemm_multi_transpose : verify_program<gemm_multi_transpose>
float alpha = 1.0f;
float beta = 1.0f;
mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, tl2);
migraphx::add_apply_alpha_beta(*mm, {l1, tl2}, migraphx::make_op("dot"), alpha, beta);
return p;
}
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment