Commit dd033c75 authored by Paul's avatar Paul
Browse files

Merge branch 'develop' into mlir-c

parents 50f87a87 8829d6ab
......@@ -111,4 +111,33 @@ TEST_CASE(const_scalar)
EXPECT(m1 == m2);
}
TEST_CASE(const_dot)
{
migraphx::module m1;
{
migraphx::shape s{migraphx::shape::float_type, {2, 2}};
std::vector<float> vec = {1.0f, 2.0f, 1.0f, 2.0f};
auto l = m1.add_literal(migraphx::literal(s, vec));
auto dl = m1.add_instruction(migraphx::make_op("dot"), l, l);
auto x = m1.add_parameter("x", s);
auto r = m1.add_instruction(migraphx::make_op("add"), dl, x);
m1.add_return({r});
}
run_pass(m1);
migraphx::module m2;
{
migraphx::shape s{migraphx::shape::float_type, {2, 2}};
std::vector<float> vec = {3.0f, 6.0f, 3.0f, 6.0f};
auto x = m2.add_parameter("x", s);
auto l = m2.add_literal(migraphx::literal(s, vec));
auto r = m2.add_instruction(migraphx::make_op("add"), l, x);
m2.add_return({r});
}
EXPECT(m1 == m2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -136,6 +136,8 @@ def create_backend_test(testname=None, target_device=None):
backend_test.include(r'.*test_mean.*')
backend_test.include(r'.*test_min.*')
backend_test.include(r'.*test_mul.*')
backend_test.include(r'.*test_multinomial.*')
backend_test.include(r'.*test_Multinomial.*')
backend_test.include(r'.*test_neg.*')
backend_test.include(r'.*test_not.*')
backend_test.include(r'.*test_operator_addmm.*')
......@@ -253,10 +255,6 @@ def create_backend_test(testname=None, target_device=None):
backend_test.exclude(r'test_constantofshape_float_ones_cpu')
backend_test.exclude(r'test_constantofshape_int_shape_zero_cpu')
backend_test.exclude(r'test_constantofshape_int_zeros_cpu')
backend_test.exclude(r'test_depthtospace_crd_mode_cpu')
backend_test.exclude(r'test_depthtospace_crd_mode_example_cpu')
backend_test.exclude(r'test_depthtospace_dcr_mode_cpu')
backend_test.exclude(r'test_depthtospace_example_cpu')
backend_test.exclude(r'test_expand_dim_changed_cpu')
backend_test.exclude(r'test_expand_dim_unchanged_cpu')
backend_test.exclude(r'test_expand_shape_model1_cpu')
......
......@@ -53,6 +53,22 @@ def test_neg_int64():
print(r)
def test_nonzero():
p = migraphx.parse_onnx("nonzero_dynamic_test.onnx")
print(p)
print("Compiling ...")
p.compile(migraphx.get_target("gpu"))
print(p)
params = {}
shapes = p.get_parameter_shapes()
params["data"] = np.array([1, 1, 0, 1]).reshape(
shapes["data"].lens()).astype(np.bool)
r = p.run(params)
print(r)
def test_fp16_imagescaler():
p = migraphx.parse_onnx("imagescaler_half_test.onnx")
print(p)
......@@ -98,3 +114,4 @@ test_sub_uint64()
test_neg_int64()
test_fp16_imagescaler()
test_if_pl()
test_nonzero()
......@@ -6,6 +6,7 @@
#include <migraphx/generate.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/quantize_int8.hpp>
#include <migraphx/quantize_fp16.hpp>
......@@ -431,7 +432,8 @@ TEST_CASE(op_capture)
auto pb = mm->add_parameter("b", s2);
auto pc = mm->add_parameter("c", s2);
auto pa = mm->add_instruction(migraphx::make_op("add"), p1, p2);
auto ps = mm->add_instruction(migraphx::make_op("dot"), pa, pb, pc);
auto ps =
migraphx::add_apply_alpha_beta(*mm, {pa, pb, pc}, migraphx::make_op("dot"), 1.0f, 1.0f);
mm->add_instruction(migraphx::make_op("dot"), pa, ps);
return p;
......@@ -450,10 +452,10 @@ TEST_CASE(op_capture)
auto pa = mm->add_instruction(migraphx::make_op("add"), p1, p2);
auto opa = mm->add_instruction(migraphx::make_op("capture", {{"ins_index", 0}}), pa);
auto opb = mm->add_instruction(migraphx::make_op("capture", {{"ins_index", 1}}), pb);
auto opc = mm->add_instruction(migraphx::make_op("capture", {{"ins_index", 2}}), pc);
auto ps = mm->add_instruction(migraphx::make_op("dot"), opa, opb, opc);
auto opm = mm->add_instruction(migraphx::make_op("capture", {{"ins_index", 3}}), pa);
auto ops = mm->add_instruction(migraphx::make_op("capture", {{"ins_index", 4}}), ps);
auto ps = migraphx::add_apply_alpha_beta(
*mm, {opa, opb, pc}, migraphx::make_op("dot"), 1.0f, 1.0f);
auto opm = mm->add_instruction(migraphx::make_op("capture", {{"ins_index", 2}}), pa);
auto ops = mm->add_instruction(migraphx::make_op("capture", {{"ins_index", 3}}), ps);
mm->add_instruction(migraphx::make_op("dot"), opm, ops);
return p;
......@@ -556,10 +558,8 @@ TEST_CASE(dot_float)
migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
auto pc = mm->add_parameter("c", sc);
auto r = mm->add_instruction(
migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), pa, pb, pc);
auto r = migraphx::add_apply_alpha_beta(*mm, {pa, pb}, migraphx::make_op("dot"));
mm->add_return({r});
return p;
......@@ -573,7 +573,6 @@ TEST_CASE(dot_float)
migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
auto pc = mm->add_parameter("c", sc);
auto zp_a = mm->add_literal(static_cast<int8_t>(0));
auto scale_a = mm->add_literal(10.0f);
scale_a = mm->add_instruction(
......@@ -592,16 +591,7 @@ TEST_CASE(dot_float)
auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b);
auto dqb = mm->add_instruction(migraphx::make_op("dequantizelinear"), qb, scale_b, zp_b);
auto zp_c = mm->add_literal(static_cast<int8_t>(100));
auto scale_c = mm->add_literal(10.0f);
scale_c = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sc.lens()}}), scale_c);
zp_c = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sc.lens()}}),
zp_c);
auto qc = mm->add_instruction(migraphx::make_op("quantizelinear"), pc, scale_c, zp_c);
auto dqc = mm->add_instruction(migraphx::make_op("dequantizelinear"), qc, scale_c, zp_c);
auto r = mm->add_instruction(
migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), dqa, dqb, dqc);
auto r = migraphx::add_apply_alpha_beta(*mm, {dqa, dqb}, migraphx::make_op("dot"));
mm->add_return({r});
return p;
......@@ -613,9 +603,8 @@ TEST_CASE(dot_float)
migraphx::shape sa{migraphx::shape::float_type, {2, 16}};
migraphx::shape sb{migraphx::shape::float_type, {16, 8}};
migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
mm->add_parameter("c", sc);
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
auto zp = mm->add_literal(static_cast<int8_t>(0));
auto scale = mm->add_literal(10.0f);
auto scale_a = mm->add_instruction(
......@@ -628,8 +617,7 @@ TEST_CASE(dot_float)
auto zp_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), zp);
auto quant_b = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b);
auto quant = mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), quant_a, quant_b);
auto quant = mm->add_instruction(migraphx::make_op("quant_dot"), quant_a, quant_b);
std::vector<float> vec(sc.elements(), 100.0f);
auto dc = mm->add_literal(100.0f);
auto mdc =
......@@ -649,6 +637,7 @@ TEST_CASE(dot_float)
p,
{migraphx::quantize_int8_pass{{"dot"}, quant_params}, migraphx::dead_code_elimination{}});
auto qp = create_int8_quantized_prog();
EXPECT(p == qp);
optimize_prog_int8(p);
......@@ -665,8 +654,7 @@ TEST_CASE(dot_double_2args)
migraphx::shape sb{migraphx::shape::double_type, {16, 8}};
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
auto r = mm->add_instruction(
migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), pa, pb);
auto r = migraphx::add_apply_alpha_beta(*mm, {pa, pb}, migraphx::make_op("dot"));
mm->add_return({r});
return p;
......@@ -696,8 +684,7 @@ TEST_CASE(dot_double_2args)
zp_b);
auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b);
auto dqb = mm->add_instruction(migraphx::make_op("dequantizelinear"), qb, scale_b, zp_b);
auto r = mm->add_instruction(
migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), dqa, dqb);
auto r = migraphx::add_apply_alpha_beta(*mm, {dqa, dqb}, migraphx::make_op("dot"));
mm->add_return({r});
return p;
};
......@@ -722,9 +709,8 @@ TEST_CASE(dot_double_2args)
migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), scale_b);
auto zp_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), zp);
auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b);
auto qdot = mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), qa, qb);
auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b);
auto qdot = mm->add_instruction(migraphx::make_op("quant_dot"), qa, qb);
auto scale = mm->add_literal(50.0);
scale = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}), scale);
......@@ -753,8 +739,7 @@ TEST_CASE(dot_half_1arg)
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::half_type, {9, 9}};
auto x = mm->add_parameter("x", s);
auto r =
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), x, x);
auto r = mm->add_instruction(migraphx::make_op("dot"), x, x);
mm->add_return({r});
return p;
......@@ -782,8 +767,7 @@ TEST_CASE(dot_half_1arg)
zp_b);
auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), x, scale_b, zp_b);
auto dqb = mm->add_instruction(migraphx::make_op("dequantizelinear"), qb, scale_b, zp_b);
auto r = mm->add_instruction(
migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), dqa, dqb);
auto r = mm->add_instruction(migraphx::make_op("dot"), dqa, dqb);
mm->add_return({r});
return p;
};
......@@ -800,10 +784,8 @@ TEST_CASE(dot_half_1arg)
scale);
zp =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), zp);
auto qx = mm->add_instruction(migraphx::make_op("quantizelinear"), x, scale, zp);
auto qdot = mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), qx, qx);
auto qx = mm->add_instruction(migraphx::make_op("quantizelinear"), x, scale, zp);
auto qdot = mm->add_instruction(migraphx::make_op("quant_dot"), qx, qx);
auto dq_scale = mm->add_literal(migraphx::literal({sa.type()}, {100.0}));
dq_scale = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}),
......@@ -1055,9 +1037,9 @@ TEST_CASE(int8_quantization_dot)
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
auto pc = mm->add_parameter("c", sc);
auto r = mm->add_instruction(migraphx::make_op("dot"), pa, pb, pc);
auto r =
migraphx::add_apply_alpha_beta(*mm, {pa, pb, pc}, migraphx::make_op("dot"), 1.0f, 1.0f);
mm->add_return({r});
return p;
};
......@@ -1075,7 +1057,7 @@ TEST_CASE(int8_quantization_dot)
std::vector<float> no_quant_result;
run_prog(p, ref_t, m, no_quant_result);
EXPECT(migraphx::verify_range(quant_result, no_quant_result));
EXPECT(migraphx::verify_range(quant_result, no_quant_result, 30000));
}
}
......@@ -1142,8 +1124,7 @@ TEST_CASE(int8_subgraph)
auto w = mm->add_parameter("w", sw);
auto* then_mod = p.create_module("If_6_if");
auto out1 = then_mod->add_instruction(
migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), a, b);
auto out1 = migraphx::add_apply_alpha_beta(*then_mod, {a, b}, migraphx::make_op("dot"));
then_mod->add_return({out1});
auto* else_mod = p.create_module("If_6_else");
......@@ -1181,11 +1162,10 @@ TEST_CASE(int8_subgraph)
migraphx::make_op("multibroadcast", {{"out_lens", sy.lens()}}), s1);
auto zpb = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sy.lens()}}), zp1);
auto qb = then_mod->add_instruction(migraphx::make_op("quantizelinear"), b, sb, zpb);
auto qdot =
then_mod->add_instruction(migraphx::make_op("quant_dot", {{"beta", 0}}), qa, qb);
auto so = then_mod->add_literal(100.0f);
so = then_mod->add_instruction(
auto qb = then_mod->add_instruction(migraphx::make_op("quantizelinear"), b, sb, zpb);
auto qdot = then_mod->add_instruction(migraphx::make_op("quant_dot"), qa, qb);
auto so = then_mod->add_literal(100.0f);
so = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sout.lens()}}), so);
auto r = then_mod->add_instruction(migraphx::make_op("dequantizelinear"), qdot, so);
then_mod->add_return({r});
......@@ -1251,7 +1231,8 @@ TEST_CASE(test_op_capture)
auto pb = mm->add_literal(s2, d2);
auto pc = mm->add_literal(s2, d2);
auto pa = mm->add_instruction(migraphx::make_op("add"), p1, p2);
auto ps = mm->add_instruction(migraphx::make_op("dot"), pa, pb, pc);
auto ps =
migraphx::add_apply_alpha_beta(*mm, {pa, pb, pc}, migraphx::make_op("dot"), 1.0f, 1.0f);
mm->add_instruction(migraphx::make_op("dot"), pa, ps);
auto calc = [](std::size_t, const std::vector<migraphx::argument>&) {};
......
......@@ -6,6 +6,7 @@
#include <migraphx/verify.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include "test.hpp"
#include <migraphx/half.hpp>
......@@ -211,7 +212,11 @@ TEST_CASE(gemm_mutli_dim_2_beta0)
auto l3 = mm->add_literal(migraphx::literal{m3_shape, m3});
float alpha = 1.0f;
float beta = 0.0f;
mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3);
migraphx::add_apply_alpha_beta(*mm,
std::vector<migraphx::instruction_ref>{l1, l2, l3},
migraphx::make_op("dot"),
alpha,
beta);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> m;
......@@ -274,7 +279,11 @@ TEST_CASE(gemm_beta_0)
float alpha = 1.0f;
float beta = 0.0f;
mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3);
migraphx::add_apply_alpha_beta(*mm,
std::vector<migraphx::instruction_ref>{l1, l2, l3},
migraphx::make_op("dot"),
alpha,
beta);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> m;
......@@ -359,13 +368,13 @@ TEST_CASE(gemm_mutli_dim1_2_3)
0.49759611, 0.10021662, 0.00592602, 0.90862000};
migraphx::shape m3_shape{migraphx::shape::float_type, {2, 3, 2, 2}};
auto l1 = mm->add_literal(migraphx::literal{m1_shape, m1});
auto l2 = mm->add_literal(migraphx::literal{m2_shape, m2});
auto l3 = mm->add_literal(migraphx::literal{m3_shape, m3});
float alpha = 0.35;
float beta = 0.41;
auto m12_alpha =
mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, m1});
auto l2 = mm->add_literal(migraphx::literal{m2_shape, m2});
auto l3 = mm->add_literal(migraphx::literal{m3_shape, m3});
float alpha = 0.35;
float beta = 0.41;
auto m12_alpha = migraphx::add_apply_alpha_beta(
*mm, std::vector<migraphx::instruction_ref>{l1, l2}, migraphx::make_op("dot"), alpha);
auto l_beta = mm->add_literal(beta);
auto b_beta = mm->add_instruction(
migraphx::make_op("scalar", {{"scalar_bcst_dims", m12_alpha->get_shape().lens()}}), l_beta);
......@@ -418,7 +427,11 @@ TEST_CASE(gemm_mutli_3args)
auto l3 = mm->add_literal(migraphx::literal{m3_shape, m3});
float alpha = 0.35;
float beta = 0.41;
mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3);
migraphx::add_apply_alpha_beta(*mm,
std::vector<migraphx::instruction_ref>{l1, l2, l3},
migraphx::make_op("dot"),
alpha,
beta);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> m;
......@@ -479,7 +492,7 @@ TEST_CASE(gemm_3args)
auto bl = mm->add_literal(migraphx::literal{b_shape, b});
migraphx::shape c_shape{migraphx::shape::float_type, {3, 3}};
auto cl = mm->add_literal(migraphx::literal{c_shape, c});
mm->add_instruction(migraphx::make_op("dot"), al, bl, cl);
migraphx::add_apply_alpha_beta(*mm, {al, bl, cl}, migraphx::make_op("dot"), 1.0f, 1.0f);
std::vector<float> gold = {-1.60947,
0.703083,
-5.46156,
......@@ -561,7 +574,8 @@ TEST_CASE(matmul_vv_inner_product)
auto ual = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), al);
auto ubl = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), bl);
float alpha = 0.32f;
mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}}), ual, ubl);
migraphx::add_apply_alpha_beta(
*mm, std::vector<migraphx::instruction_ref>{ual, ubl}, migraphx::make_op("dot"), alpha);
std::vector<float> gold = {-0.4590752};
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
......@@ -634,7 +648,8 @@ TEST_CASE(matmul_vm)
migraphx::shape b_shape{migraphx::shape::float_type, {8, 5}};
auto bl = mm->add_literal(migraphx::literal{b_shape, b});
float alpha = 0.5f;
mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}}), ual, bl);
migraphx::add_apply_alpha_beta(
*mm, std::vector<migraphx::instruction_ref>{ual, bl}, migraphx::make_op("dot"), alpha);
std::vector<float> gold = {-1.89056, -1.70003, -1.0986, -1.65724, -1.90163};
p.compile(migraphx::ref::target{});
......@@ -718,7 +733,8 @@ TEST_CASE(matmul_vm)
migraphx::make_op("multibroadcast", {{"out_lens", {3, 1, 6}}}), ual);
migraphx::shape b_shape{migraphx::shape::float_type, {3, 6, 4}};
auto bl = mm->add_literal(migraphx::literal{b_shape, b});
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 0.21f}}), bual, bl);
migraphx::add_apply_alpha_beta(
*mm, std::vector<migraphx::instruction_ref>{bual, bl}, migraphx::make_op("dot"), 0.21f);
std::vector<float> gold = {0.25812,
-0.247582,
0.480051,
......@@ -805,7 +821,8 @@ TEST_CASE(matmul_mv)
auto bl = mm->add_literal(migraphx::literal{b_shape, b});
auto ubl = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), bl);
float alpha = 0.3f;
mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}}), al, ubl);
migraphx::add_apply_alpha_beta(
*mm, std::vector<migraphx::instruction_ref>{al, ubl}, migraphx::make_op("dot"), alpha);
std::vector<float> gold = {0.395946, 0.357067, -0.588187};
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
......@@ -1337,7 +1354,8 @@ TEST_CASE(quant_dot_2args_general)
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 2}}), l1, tl2);
migraphx::add_apply_alpha_beta(*mm, {l1, tl2}, migraphx::make_op("quant_dot"), 2);
std::vector<int> gold = {
28, 76, 124, 172, 220, 76, 252, 428, 604, 780, 124, 428, 732, 1036, 1340};
......@@ -1366,7 +1384,7 @@ TEST_CASE(quant_dot_2args_general)
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 3}, {"beta", 2}}), tl1, tl2);
migraphx::add_apply_alpha_beta(*mm, {tl1, tl2}, migraphx::make_op("quant_dot"), 3);
std::vector<int> gold = {
126, 342, 558, 774, 990, 144, 408, 672, 936, 1200, 162, 474, 786, 1098, 1410};
......@@ -1398,7 +1416,7 @@ TEST_CASE(quant_dot_3args_general)
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3});
mm->add_instruction(migraphx::make_op("quant_dot"), l1, l2, l3);
migraphx::add_apply_alpha_beta(*mm, {l1, l2, l3}, migraphx::make_op("quant_dot"), 1, 1);
std::vector<int> gold = {
982, 1011, 1040, 1069, 1098, 1127, 1156, 2557, 2650, 2743, 2836, 2929, 3022, 3115};
......@@ -1426,9 +1444,7 @@ TEST_CASE(quant_dot_3args_general)
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3});
mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), l1, l2, l3);
mm->add_instruction(migraphx::make_op("quant_dot"), l1, l2);
std::vector<int> gold = {
70, 76, 82, 88, 94, 190, 212, 234, 256, 278, 310, 348, 386, 424, 462};
......@@ -1459,8 +1475,7 @@ TEST_CASE(quant_dot_3args_general)
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3});
mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 3}}), tl1, l2, l3);
migraphx::add_apply_alpha_beta(*mm, {tl1, l2, l3}, migraphx::make_op("quant_dot"), 1, 3);
std::vector<int> gold = {
1966, 2025, 2084, 2143, 2202, 2261, 2320, 2183, 2250, 2317, 2384, 2451, 2518, 2585};
......@@ -1491,8 +1506,7 @@ TEST_CASE(quant_dot_3args_general)
auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3});
mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 2}, {"beta", 3}}), l1, tl2, l3);
migraphx::add_apply_alpha_beta(*mm, {l1, tl2, l3}, migraphx::make_op("quant_dot"), 2, 3);
std::vector<int> gold = {
286, 737, 1188, 1639, 2090, 2541, 2992, 755, 2230, 3705, 5180, 6655, 8130, 9605};
......@@ -1525,8 +1539,7 @@ TEST_CASE(quant_dot_3args_general)
auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3});
mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 3}, {"beta", 2}}), tl1, tl2, l3);
migraphx::add_apply_alpha_beta(*mm, {tl1, tl2, l3}, migraphx::make_op("quant_dot"), 3, 2);
std::vector<int> gold = {
844, 2190, 3536, 4882, 6228, 7574, 8920, 942, 2480, 4018, 5556, 7094, 8632, 10170};
......@@ -1558,8 +1571,7 @@ TEST_CASE(quant_dot_3args_batch)
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3});
mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 2}}), l1, l2, l3);
migraphx::add_apply_alpha_beta(*mm, {l1, l2, l3}, migraphx::make_op("quant_dot"), 1, 2);
std::vector<int> gold = {
102, 110, 118, 126, 134, 142, 150, 284, 308, 332, 356, 380,
......@@ -1596,8 +1608,7 @@ TEST_CASE(quant_dot_3args_batch)
auto tl2 = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2);
auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3});
mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 2}, {"beta", 3}}), tl1, tl2, l3);
migraphx::add_apply_alpha_beta(*mm, {tl1, tl2, l3}, migraphx::make_op("quant_dot"), 2, 3);
std::vector<int> gold = {
90, 237, 384, 531, 678, 825, 120, 299, 478, 657, 836, 1015,
......
#include <iostream>
#include <vector>
#include <cmath>
#include <random>
#include <migraphx/literal.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/batch_norm_inference.hpp>
......@@ -2687,6 +2688,56 @@ TEST_CASE(mul_test)
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(multinomial_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
size_t sample_size = 100000;
float seed = 0.0f;
std::mt19937 gen(seed);
std::uniform_real_distribution<> dis(0.0, 1.0);
std::vector<float> rand_samples(sample_size);
std::generate(rand_samples.begin(), rand_samples.end(), [&]() { return dis(gen); });
migraphx::shape rs{migraphx::shape::float_type, {1, sample_size}};
auto rs_lit = mm->add_literal(migraphx::literal{rs, rand_samples});
migraphx::shape s{migraphx::shape::float_type, {1, 5}};
std::vector<int> dist{15, 25, 15, 25, 20};
std::vector<float> data(5);
std::transform(dist.begin(), dist.end(), data.begin(), [&](auto d) { return std::log(d); });
auto input = mm->add_literal(migraphx::literal(s, data));
auto maxes = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {1}}}), input);
auto mb_maxes =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 5}}}), maxes);
auto cdf = mm->add_instruction(migraphx::make_op("sub"), input, mb_maxes);
cdf = mm->add_instruction(migraphx::make_op("exp"), cdf);
cdf = mm->add_instruction(
migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}), cdf);
mm->add_instruction(migraphx::make_op("multinomial"), cdf, rs_lit);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<int32_t> result_vec(sample_size);
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
std::vector<int> res_dist(5, 0);
for(auto& r : result_vec)
res_dist[r]++;
auto dist_sum = std::accumulate(dist.begin(), dist.end(), 0);
auto res_dist_sum = std::accumulate(res_dist.begin(), res_dist.end(), 0);
std::vector<float> norm(5);
std::vector<float> res_norm(5);
std::transform(dist.begin(), dist.end(), norm.begin(), [&](auto n) {
return static_cast<double>(n) / dist_sum;
});
std::transform(res_dist.begin(), res_dist.end(), res_norm.begin(), [&](auto n) {
return static_cast<double>(n) / res_dist_sum;
});
EXPECT(migraphx::verify_range(norm, res_norm, 100000));
}
TEST_CASE(neg_test)
{
migraphx::program p;
......@@ -2705,6 +2756,26 @@ TEST_CASE(neg_test)
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(nonzero_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 2, 3}};
std::vector<float> data = {
1.0f, 1.3f, 0.0f, -1.2f, 0.0f, -100.f, 200.f, 0.0f, 0.1f, 0.2f, 0.0f, 0.5f};
auto input = mm->add_literal(migraphx::literal(s, data));
auto ret = mm->add_instruction(migraphx::make_op("nonzero"), input);
mm->add_return({ret});
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::cout << "result = " << result << std::endl;
std::vector<int64_t> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<int64_t> gold = {0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
1, 1, 0, 0, 0, 0, 0, 1, 0, 2, 0, 2, 0, 2, 0, 0, 0, 0};
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(not_test)
{
// int32
......
......@@ -10,6 +10,7 @@
#include <migraphx/generate.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/apply_alpha_beta.hpp>
bool is_convolution(const migraphx::instruction& ins) { return ins.name() == "convolution"; }
bool is_dot(const migraphx::instruction& ins) { return ins.name() == "dot"; }
......@@ -127,12 +128,11 @@ TEST_CASE(dot)
auto scale = m1.add_literal(0.5f);
auto zero = m1.add_literal(std::int8_t{0});
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
auto dot =
m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), d1, d2);
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
auto dot = m1.add_instruction(migraphx::make_op("dot"), d1, d2);
m1.add_return({dot});
}
......@@ -144,11 +144,10 @@ TEST_CASE(dot)
auto zero = m2.add_literal(std::int8_t{0});
auto scale1 = m2.add_literal(0.25f);
auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale, zero);
auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale, zero);
auto dot =
m2.add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), q1, q2);
auto d3 = add_quantize_op(m2, "dequantizelinear", dot, scale1);
auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale, zero);
auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale, zero);
auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2);
auto d3 = add_quantize_op(m2, "dequantizelinear", dot, scale1);
m2.add_return({d3});
}
......@@ -168,22 +167,19 @@ TEST_CASE(dot_non_zero_point)
auto scale = m1.add_literal(0.5f);
auto zero = m1.add_literal(std::int8_t{1});
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
auto dot =
m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), d1, d2);
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
auto dot = m1.add_instruction(migraphx::make_op("dot"), d1, d2);
m1.add_return({dot});
}
migraphx::module m2;
{
auto t1 = m2.add_parameter("t1", sh1);
auto t2 = m2.add_parameter("t2", sh2);
auto dot =
m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), t1, t2);
auto t1 = m2.add_parameter("t1", sh1);
auto t2 = m2.add_parameter("t2", sh2);
auto dot = m2.add_instruction(migraphx::make_op("dot"), t1, t2);
m2.add_return({dot});
}
......@@ -203,22 +199,19 @@ TEST_CASE(dot_uint8)
auto scale = m1.add_literal(0.5f);
auto zero = m1.add_literal(std::uint8_t{0});
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
auto dot =
m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), d1, d2);
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
auto dot = m1.add_instruction(migraphx::make_op("dot"), d1, d2);
m1.add_return({dot});
}
migraphx::module m2;
{
auto t1 = m2.add_parameter("t1", sh1);
auto t2 = m2.add_parameter("t2", sh2);
auto dot =
m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), t1, t2);
auto t1 = m2.add_parameter("t1", sh1);
auto t2 = m2.add_parameter("t2", sh2);
auto dot = m2.add_instruction(migraphx::make_op("dot"), t1, t2);
m2.add_return({dot});
}
......@@ -240,12 +233,11 @@ TEST_CASE(dot_add)
auto scale = m1.add_literal(0.5f);
auto zero = m1.add_literal(std::int8_t{0});
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
auto dot =
m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), d1, d2);
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
auto dot = m1.add_instruction(migraphx::make_op("dot"), d1, d2);
auto q3 = add_quantize_op(m1, "quantizelinear", dot, scale, zero);
auto d3 = add_quantize_op(m1, "dequantizelinear", q3, scale, zero);
auto add = m1.add_instruction(migraphx::make_op("add"), d3, ab);
......@@ -261,10 +253,9 @@ TEST_CASE(dot_add)
auto zero = m2.add_literal(std::int8_t{0});
auto scale1 = m2.add_literal(0.25f);
auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale, zero);
auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale, zero);
auto dot =
m2.add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), q1, q2);
auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale, zero);
auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale, zero);
auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2);
auto d3 = add_quantize_op(m2, "dequantizelinear", dot, scale1);
auto add = m2.add_instruction(migraphx::make_op("add"), d3, ab);
m2.add_return({add});
......@@ -471,21 +462,20 @@ TEST_CASE(conv_pooling_dot)
d1);
auto bc1 = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2);
auto a1 = m1.add_instruction(migraphx::make_op("add"), c1, bc1);
auto ap = m1.add_instruction(migraphx::make_op("pooling",
auto a1 = m1.add_instruction(migraphx::make_op("add"), c1, bc1);
auto ap = m1.add_instruction(migraphx::make_op("pooling",
{{"mode", "average"},
{"padding", {0, 0, 0, 0}},
{"stride", {1, 1}},
{"lengths", {7, 7}},
{"ceil_mode", 0}}),
a1);
auto fl = m1.add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), ap);
auto q4 = add_quantize_op(m1, "quantizelinear", fl, scale, zero);
auto d8 = add_quantize_op(m1, "dequantizelinear", q4, scale, zero);
auto dot =
m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), d8, d4);
auto q5 = add_quantize_op(m1, "quantizelinear", dot, scale, zero);
auto d9 = add_quantize_op(m1, "dequantizelinear", q5, scale, zero);
auto fl = m1.add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), ap);
auto q4 = add_quantize_op(m1, "quantizelinear", fl, scale, zero);
auto d8 = add_quantize_op(m1, "dequantizelinear", q4, scale, zero);
auto dot = m1.add_instruction(migraphx::make_op("dot"), d8, d4);
auto q5 = add_quantize_op(m1, "quantizelinear", dot, scale, zero);
auto d9 = add_quantize_op(m1, "dequantizelinear", q5, scale, zero);
auto mb1 =
m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 1000}}}), d3);
auto a2 = m1.add_instruction(migraphx::make_op("add"), d9, mb1);
......@@ -518,19 +508,18 @@ TEST_CASE(conv_pooling_dot)
auto d5 = add_quantize_op(m2, "dequantizelinear", c1, scale1);
auto bc1 = m2.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2);
auto a1 = m2.add_instruction(migraphx::make_op("add"), d5, bc1);
auto ap = m2.add_instruction(migraphx::make_op("pooling",
auto a1 = m2.add_instruction(migraphx::make_op("add"), d5, bc1);
auto ap = m2.add_instruction(migraphx::make_op("pooling",
{{"mode", "average"},
{"padding", {0, 0, 0, 0}},
{"stride", {1, 1}},
{"lengths", {7, 7}},
{"ceil_mode", 0}}),
a1);
auto fl = m2.add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), ap);
auto q4 = add_quantize_op(m2, "quantizelinear", fl, scale, zero);
auto dot =
m2.add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), q4, db);
auto d9 = add_quantize_op(m2, "dequantizelinear", dot, scale2);
auto fl = m2.add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), ap);
auto q4 = add_quantize_op(m2, "quantizelinear", fl, scale, zero);
auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q4, db);
auto d9 = add_quantize_op(m2, "dequantizelinear", dot, scale2);
auto mb1 =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 1000}}}), d3);
auto a2 = m2.add_instruction(migraphx::make_op("add"), d9, mb1);
......@@ -575,25 +564,24 @@ TEST_CASE(mobilenet_snippet)
d1);
auto bc1 = mm.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2);
auto a1 = mm.add_instruction(migraphx::make_op("add"), c1, bc1);
auto q2 = add_quantize_op(mm, "quantizelinear", a1, scale, zero);
auto d6 = add_quantize_op(mm, "dequantizelinear", q2, scale, zero);
auto ap = mm.add_instruction(migraphx::make_op("pooling",
auto a1 = mm.add_instruction(migraphx::make_op("add"), c1, bc1);
auto q2 = add_quantize_op(mm, "quantizelinear", a1, scale, zero);
auto d6 = add_quantize_op(mm, "dequantizelinear", q2, scale, zero);
auto ap = mm.add_instruction(migraphx::make_op("pooling",
{{"mode", "average"},
{"padding", {0, 0, 0, 0}},
{"stride", {1, 1}},
{"lengths", {7, 7}},
{"ceil_mode", 0}}),
d6);
auto q3 = add_quantize_op(mm, "quantizelinear", ap, scale, zero);
auto d7 = add_quantize_op(mm, "dequantizelinear", q3, scale, zero);
auto rs = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {1, -1}}}), d7);
auto q4 = add_quantize_op(mm, "quantizelinear", rs, scale, zero);
auto d8 = add_quantize_op(mm, "dequantizelinear", q4, scale, zero);
auto dot =
mm.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), d8, d4);
auto q5 = add_quantize_op(mm, "quantizelinear", dot, scale, zero);
auto d9 = add_quantize_op(mm, "dequantizelinear", q5, scale, zero);
auto q3 = add_quantize_op(mm, "quantizelinear", ap, scale, zero);
auto d7 = add_quantize_op(mm, "dequantizelinear", q3, scale, zero);
auto rs = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {1, -1}}}), d7);
auto q4 = add_quantize_op(mm, "quantizelinear", rs, scale, zero);
auto d8 = add_quantize_op(mm, "dequantizelinear", q4, scale, zero);
auto dot = mm.add_instruction(migraphx::make_op("dot"), d8, d4);
auto q5 = add_quantize_op(mm, "quantizelinear", dot, scale, zero);
auto d9 = add_quantize_op(mm, "dequantizelinear", q5, scale, zero);
auto mb1 =
mm.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 1000}}}), d3);
auto a2 = mm.add_instruction(migraphx::make_op("add"), d9, mb1);
......@@ -699,12 +687,11 @@ TEST_CASE(dot_correctness)
auto scale_b = m1->add_literal(0.5f);
auto zero = m1->add_literal(std::int8_t{0});
auto q1 = add_quantize_op(*m1, "quantizelinear", a, scale_a, zero);
auto d1 = add_quantize_op(*m1, "dequantizelinear", q1, scale_a, zero);
auto q2 = add_quantize_op(*m1, "quantizelinear", b, scale_b, zero);
auto d2 = add_quantize_op(*m1, "dequantizelinear", q2, scale_b, zero);
auto dot =
m1->add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), d1, d2);
auto q1 = add_quantize_op(*m1, "quantizelinear", a, scale_a, zero);
auto d1 = add_quantize_op(*m1, "dequantizelinear", q1, scale_a, zero);
auto q2 = add_quantize_op(*m1, "quantizelinear", b, scale_b, zero);
auto d2 = add_quantize_op(*m1, "dequantizelinear", q2, scale_b, zero);
auto dot = m1->add_instruction(migraphx::make_op("dot"), d1, d2);
m1->add_return({dot});
run_pass(*m1);
......@@ -715,8 +702,7 @@ TEST_CASE(dot_correctness)
auto* m2 = p2.get_main_module();
auto a = m2->add_parameter("a", sh1);
auto b = m2->add_parameter("b", sh2);
auto dot = m2->add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), a, b);
auto dot = m2->add_instruction(migraphx::make_op("dot"), a, b);
m2->add_return({dot});
}
......
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
......@@ -21,8 +22,7 @@ struct batch_quant_dot_1 : verify_program<batch_quant_dot_1>
auto tl2 = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2);
auto l3 = mm->add_parameter("c", m3_shape);
mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 3}, {"beta", 2}}), tl1, tl2, l3);
migraphx::add_apply_alpha_beta(*mm, {tl1, tl2, l3}, migraphx::make_op("quant_dot"), 3, 2);
return p;
}
};
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
......@@ -17,8 +18,7 @@ struct batch_quant_dot_2 : verify_program<batch_quant_dot_2>
auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape);
auto l3 = mm->add_parameter("c", m3_shape);
mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 3}}), l1, l2, l3);
migraphx::add_apply_alpha_beta(*mm, {l1, l2, l3}, migraphx::make_op("quant_dot"), 1, 3);
return p;
}
};
......@@ -15,7 +15,7 @@ struct batch_quant_dot_3 : verify_program<batch_quant_dot_3>
auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape);
mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 3}}), l1, l2);
mm->add_instruction(migraphx::make_op("quant_dot"), l1, l2);
return p;
}
};
......@@ -19,7 +19,7 @@ struct batch_quant_dot_4 : verify_program<batch_quant_dot_4>
migraphx::make_op("transpose", {{"permutation", {3, 0, 1, 2}}}), l1);
auto tl2 = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {3, 1, 2, 0}}}), l2);
mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 3}}), tl1, tl2);
mm->add_instruction(migraphx::make_op("quant_dot"), tl1, tl2);
return p;
}
};
......@@ -21,7 +21,7 @@ struct batch_quant_dot_5 : verify_program<batch_quant_dot_5>
auto tl2 = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2);
auto sl2 = mm->add_instruction(migraphx::make_op("add"), tl2, tl2);
mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}}), sl1, sl2);
mm->add_instruction(migraphx::make_op("quant_dot"), sl1, sl2);
return p;
}
};
#include <migraphx/apply_alpha_beta.hpp>
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
......@@ -17,8 +18,7 @@ struct gemm_2args_vv : verify_program<gemm_2args_vv>
auto l2 = mm->add_parameter("2", m2_shape);
auto ul2 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l2);
float alpha = 0.23f;
auto res = mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}}), ul1, ul2);
auto res = migraphx::add_apply_alpha_beta(*mm, {ul1, ul2}, migraphx::make_op("dot"), alpha);
auto sres = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), res);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), sres);
......
#include <migraphx/apply_alpha_beta.hpp>
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
......@@ -19,9 +20,7 @@ struct gemm_multi_3args : verify_program<gemm_multi_3args>
auto l3 = mm->add_parameter("3", m3_shape);
float alpha = 0.35;
float beta = 0.41;
mm->add_instruction(
migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3);
migraphx::add_apply_alpha_beta(*mm, {l1, l2, l3}, migraphx::make_op("dot"), alpha, beta);
return p;
}
};
......@@ -3,7 +3,7 @@
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/apply_alpha_beta.hpp>
struct gemm_multi_3args_alpha0 : verify_program<gemm_multi_3args_alpha0>
{
migraphx::program create_program() const
......@@ -19,9 +19,7 @@ struct gemm_multi_3args_alpha0 : verify_program<gemm_multi_3args_alpha0>
float alpha = 0.0f;
float beta = 1.0f;
mm->add_instruction(
migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3);
migraphx::add_apply_alpha_beta(*mm, {l1, l2, l3}, migraphx::make_op("dot"), alpha, beta);
return p;
}
};
#include <migraphx/apply_alpha_beta.hpp>
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
......@@ -19,9 +20,7 @@ struct gemm_multi_3args_beta0 : verify_program<gemm_multi_3args_beta0>
float alpha = 1.0f;
float beta = 0.0f;
mm->add_instruction(
migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3);
migraphx::add_apply_alpha_beta(*mm, {l1, l2, l3}, migraphx::make_op("dot"), alpha, beta);
return p;
}
};
#include <migraphx/apply_alpha_beta.hpp>
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
......@@ -19,9 +20,7 @@ struct gemm_multi_3args_c25 : verify_program<gemm_multi_3args_c25>
auto l3 = mm->add_parameter("3", m3_shape);
float alpha = 0.35;
float beta = 0.41;
mm->add_instruction(
migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3);
migraphx::add_apply_alpha_beta(*mm, {l1, l2, l3}, migraphx::make_op("dot"), alpha, beta);
return p;
}
};
#include <migraphx/apply_alpha_beta.hpp>
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
......@@ -19,8 +20,7 @@ struct gemm_multi_transpose : verify_program<gemm_multi_transpose>
float alpha = 1.0f;
float beta = 1.0f;
mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, tl2);
migraphx::add_apply_alpha_beta(*mm, {l1, tl2}, migraphx::make_op("dot"), alpha, beta);
return p;
}
};
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
......@@ -17,7 +18,7 @@ struct quant_dot_3args_1 : verify_program<quant_dot_3args_1>
auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape);
auto l3 = mm->add_parameter("c", m3_shape);
mm->add_instruction(migraphx::make_op("quant_dot"), l1, l2, l3);
migraphx::add_apply_alpha_beta(*mm, {l1, l2, l3}, migraphx::make_op("quant_dot"), 1, 1);
return p;
}
};
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
......@@ -19,8 +20,7 @@ struct quant_dot_3args_2 : verify_program<quant_dot_3args_2>
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto l2 = mm->add_parameter("b", m2_shape);
auto l3 = mm->add_parameter("c", m3_shape);
mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 3}}), tl1, l2, l3);
migraphx::add_apply_alpha_beta(*mm, {tl1, l2, l3}, migraphx::make_op("quant_dot"), 1, 3);
return p;
}
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment