Commit dd033c75 authored by Paul's avatar Paul
Browse files

Merge branch 'develop' into mlir-c

parents 50f87a87 8829d6ab
...@@ -111,4 +111,33 @@ TEST_CASE(const_scalar) ...@@ -111,4 +111,33 @@ TEST_CASE(const_scalar)
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
TEST_CASE(const_dot)
{
migraphx::module m1;
{
migraphx::shape s{migraphx::shape::float_type, {2, 2}};
std::vector<float> vec = {1.0f, 2.0f, 1.0f, 2.0f};
auto l = m1.add_literal(migraphx::literal(s, vec));
auto dl = m1.add_instruction(migraphx::make_op("dot"), l, l);
auto x = m1.add_parameter("x", s);
auto r = m1.add_instruction(migraphx::make_op("add"), dl, x);
m1.add_return({r});
}
run_pass(m1);
migraphx::module m2;
{
migraphx::shape s{migraphx::shape::float_type, {2, 2}};
std::vector<float> vec = {3.0f, 6.0f, 3.0f, 6.0f};
auto x = m2.add_parameter("x", s);
auto l = m2.add_literal(migraphx::literal(s, vec));
auto r = m2.add_instruction(migraphx::make_op("add"), l, x);
m2.add_return({r});
}
EXPECT(m1 == m2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -136,6 +136,8 @@ def create_backend_test(testname=None, target_device=None): ...@@ -136,6 +136,8 @@ def create_backend_test(testname=None, target_device=None):
backend_test.include(r'.*test_mean.*') backend_test.include(r'.*test_mean.*')
backend_test.include(r'.*test_min.*') backend_test.include(r'.*test_min.*')
backend_test.include(r'.*test_mul.*') backend_test.include(r'.*test_mul.*')
backend_test.include(r'.*test_multinomial.*')
backend_test.include(r'.*test_Multinomial.*')
backend_test.include(r'.*test_neg.*') backend_test.include(r'.*test_neg.*')
backend_test.include(r'.*test_not.*') backend_test.include(r'.*test_not.*')
backend_test.include(r'.*test_operator_addmm.*') backend_test.include(r'.*test_operator_addmm.*')
...@@ -253,10 +255,6 @@ def create_backend_test(testname=None, target_device=None): ...@@ -253,10 +255,6 @@ def create_backend_test(testname=None, target_device=None):
backend_test.exclude(r'test_constantofshape_float_ones_cpu') backend_test.exclude(r'test_constantofshape_float_ones_cpu')
backend_test.exclude(r'test_constantofshape_int_shape_zero_cpu') backend_test.exclude(r'test_constantofshape_int_shape_zero_cpu')
backend_test.exclude(r'test_constantofshape_int_zeros_cpu') backend_test.exclude(r'test_constantofshape_int_zeros_cpu')
backend_test.exclude(r'test_depthtospace_crd_mode_cpu')
backend_test.exclude(r'test_depthtospace_crd_mode_example_cpu')
backend_test.exclude(r'test_depthtospace_dcr_mode_cpu')
backend_test.exclude(r'test_depthtospace_example_cpu')
backend_test.exclude(r'test_expand_dim_changed_cpu') backend_test.exclude(r'test_expand_dim_changed_cpu')
backend_test.exclude(r'test_expand_dim_unchanged_cpu') backend_test.exclude(r'test_expand_dim_unchanged_cpu')
backend_test.exclude(r'test_expand_shape_model1_cpu') backend_test.exclude(r'test_expand_shape_model1_cpu')
......
...@@ -53,6 +53,22 @@ def test_neg_int64(): ...@@ -53,6 +53,22 @@ def test_neg_int64():
print(r) print(r)
def test_nonzero():
p = migraphx.parse_onnx("nonzero_dynamic_test.onnx")
print(p)
print("Compiling ...")
p.compile(migraphx.get_target("gpu"))
print(p)
params = {}
shapes = p.get_parameter_shapes()
params["data"] = np.array([1, 1, 0, 1]).reshape(
shapes["data"].lens()).astype(np.bool)
r = p.run(params)
print(r)
def test_fp16_imagescaler(): def test_fp16_imagescaler():
p = migraphx.parse_onnx("imagescaler_half_test.onnx") p = migraphx.parse_onnx("imagescaler_half_test.onnx")
print(p) print(p)
...@@ -98,3 +114,4 @@ test_sub_uint64() ...@@ -98,3 +114,4 @@ test_sub_uint64()
test_neg_int64() test_neg_int64()
test_fp16_imagescaler() test_fp16_imagescaler()
test_if_pl() test_if_pl()
test_nonzero()
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/ref/target.hpp> #include <migraphx/ref/target.hpp>
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/quantization.hpp> #include <migraphx/quantization.hpp>
#include <migraphx/quantize_int8.hpp> #include <migraphx/quantize_int8.hpp>
#include <migraphx/quantize_fp16.hpp> #include <migraphx/quantize_fp16.hpp>
...@@ -431,7 +432,8 @@ TEST_CASE(op_capture) ...@@ -431,7 +432,8 @@ TEST_CASE(op_capture)
auto pb = mm->add_parameter("b", s2); auto pb = mm->add_parameter("b", s2);
auto pc = mm->add_parameter("c", s2); auto pc = mm->add_parameter("c", s2);
auto pa = mm->add_instruction(migraphx::make_op("add"), p1, p2); auto pa = mm->add_instruction(migraphx::make_op("add"), p1, p2);
auto ps = mm->add_instruction(migraphx::make_op("dot"), pa, pb, pc); auto ps =
migraphx::add_apply_alpha_beta(*mm, {pa, pb, pc}, migraphx::make_op("dot"), 1.0f, 1.0f);
mm->add_instruction(migraphx::make_op("dot"), pa, ps); mm->add_instruction(migraphx::make_op("dot"), pa, ps);
return p; return p;
...@@ -450,10 +452,10 @@ TEST_CASE(op_capture) ...@@ -450,10 +452,10 @@ TEST_CASE(op_capture)
auto pa = mm->add_instruction(migraphx::make_op("add"), p1, p2); auto pa = mm->add_instruction(migraphx::make_op("add"), p1, p2);
auto opa = mm->add_instruction(migraphx::make_op("capture", {{"ins_index", 0}}), pa); auto opa = mm->add_instruction(migraphx::make_op("capture", {{"ins_index", 0}}), pa);
auto opb = mm->add_instruction(migraphx::make_op("capture", {{"ins_index", 1}}), pb); auto opb = mm->add_instruction(migraphx::make_op("capture", {{"ins_index", 1}}), pb);
auto opc = mm->add_instruction(migraphx::make_op("capture", {{"ins_index", 2}}), pc); auto ps = migraphx::add_apply_alpha_beta(
auto ps = mm->add_instruction(migraphx::make_op("dot"), opa, opb, opc); *mm, {opa, opb, pc}, migraphx::make_op("dot"), 1.0f, 1.0f);
auto opm = mm->add_instruction(migraphx::make_op("capture", {{"ins_index", 3}}), pa); auto opm = mm->add_instruction(migraphx::make_op("capture", {{"ins_index", 2}}), pa);
auto ops = mm->add_instruction(migraphx::make_op("capture", {{"ins_index", 4}}), ps); auto ops = mm->add_instruction(migraphx::make_op("capture", {{"ins_index", 3}}), ps);
mm->add_instruction(migraphx::make_op("dot"), opm, ops); mm->add_instruction(migraphx::make_op("dot"), opm, ops);
return p; return p;
...@@ -556,10 +558,8 @@ TEST_CASE(dot_float) ...@@ -556,10 +558,8 @@ TEST_CASE(dot_float)
migraphx::shape sc{migraphx::shape::float_type, {2, 8}}; migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
auto pa = mm->add_parameter("a", sa); auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb); auto pb = mm->add_parameter("b", sb);
auto pc = mm->add_parameter("c", sc);
auto r = mm->add_instruction( auto r = migraphx::add_apply_alpha_beta(*mm, {pa, pb}, migraphx::make_op("dot"));
migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), pa, pb, pc);
mm->add_return({r}); mm->add_return({r});
return p; return p;
...@@ -573,7 +573,6 @@ TEST_CASE(dot_float) ...@@ -573,7 +573,6 @@ TEST_CASE(dot_float)
migraphx::shape sc{migraphx::shape::float_type, {2, 8}}; migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
auto pa = mm->add_parameter("a", sa); auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb); auto pb = mm->add_parameter("b", sb);
auto pc = mm->add_parameter("c", sc);
auto zp_a = mm->add_literal(static_cast<int8_t>(0)); auto zp_a = mm->add_literal(static_cast<int8_t>(0));
auto scale_a = mm->add_literal(10.0f); auto scale_a = mm->add_literal(10.0f);
scale_a = mm->add_instruction( scale_a = mm->add_instruction(
...@@ -592,16 +591,7 @@ TEST_CASE(dot_float) ...@@ -592,16 +591,7 @@ TEST_CASE(dot_float)
auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b); auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b);
auto dqb = mm->add_instruction(migraphx::make_op("dequantizelinear"), qb, scale_b, zp_b); auto dqb = mm->add_instruction(migraphx::make_op("dequantizelinear"), qb, scale_b, zp_b);
auto zp_c = mm->add_literal(static_cast<int8_t>(100)); auto r = migraphx::add_apply_alpha_beta(*mm, {dqa, dqb}, migraphx::make_op("dot"));
auto scale_c = mm->add_literal(10.0f);
scale_c = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sc.lens()}}), scale_c);
zp_c = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sc.lens()}}),
zp_c);
auto qc = mm->add_instruction(migraphx::make_op("quantizelinear"), pc, scale_c, zp_c);
auto dqc = mm->add_instruction(migraphx::make_op("dequantizelinear"), qc, scale_c, zp_c);
auto r = mm->add_instruction(
migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), dqa, dqb, dqc);
mm->add_return({r}); mm->add_return({r});
return p; return p;
...@@ -613,9 +603,8 @@ TEST_CASE(dot_float) ...@@ -613,9 +603,8 @@ TEST_CASE(dot_float)
migraphx::shape sa{migraphx::shape::float_type, {2, 16}}; migraphx::shape sa{migraphx::shape::float_type, {2, 16}};
migraphx::shape sb{migraphx::shape::float_type, {16, 8}}; migraphx::shape sb{migraphx::shape::float_type, {16, 8}};
migraphx::shape sc{migraphx::shape::float_type, {2, 8}}; migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
auto pa = mm->add_parameter("a", sa); auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb); auto pb = mm->add_parameter("b", sb);
mm->add_parameter("c", sc);
auto zp = mm->add_literal(static_cast<int8_t>(0)); auto zp = mm->add_literal(static_cast<int8_t>(0));
auto scale = mm->add_literal(10.0f); auto scale = mm->add_literal(10.0f);
auto scale_a = mm->add_instruction( auto scale_a = mm->add_instruction(
...@@ -628,8 +617,7 @@ TEST_CASE(dot_float) ...@@ -628,8 +617,7 @@ TEST_CASE(dot_float)
auto zp_b = auto zp_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), zp); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), zp);
auto quant_b = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b); auto quant_b = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b);
auto quant = mm->add_instruction( auto quant = mm->add_instruction(migraphx::make_op("quant_dot"), quant_a, quant_b);
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), quant_a, quant_b);
std::vector<float> vec(sc.elements(), 100.0f); std::vector<float> vec(sc.elements(), 100.0f);
auto dc = mm->add_literal(100.0f); auto dc = mm->add_literal(100.0f);
auto mdc = auto mdc =
...@@ -649,6 +637,7 @@ TEST_CASE(dot_float) ...@@ -649,6 +637,7 @@ TEST_CASE(dot_float)
p, p,
{migraphx::quantize_int8_pass{{"dot"}, quant_params}, migraphx::dead_code_elimination{}}); {migraphx::quantize_int8_pass{{"dot"}, quant_params}, migraphx::dead_code_elimination{}});
auto qp = create_int8_quantized_prog(); auto qp = create_int8_quantized_prog();
EXPECT(p == qp); EXPECT(p == qp);
optimize_prog_int8(p); optimize_prog_int8(p);
...@@ -665,8 +654,7 @@ TEST_CASE(dot_double_2args) ...@@ -665,8 +654,7 @@ TEST_CASE(dot_double_2args)
migraphx::shape sb{migraphx::shape::double_type, {16, 8}}; migraphx::shape sb{migraphx::shape::double_type, {16, 8}};
auto pa = mm->add_parameter("a", sa); auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb); auto pb = mm->add_parameter("b", sb);
auto r = mm->add_instruction( auto r = migraphx::add_apply_alpha_beta(*mm, {pa, pb}, migraphx::make_op("dot"));
migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), pa, pb);
mm->add_return({r}); mm->add_return({r});
return p; return p;
...@@ -696,8 +684,7 @@ TEST_CASE(dot_double_2args) ...@@ -696,8 +684,7 @@ TEST_CASE(dot_double_2args)
zp_b); zp_b);
auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b); auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b);
auto dqb = mm->add_instruction(migraphx::make_op("dequantizelinear"), qb, scale_b, zp_b); auto dqb = mm->add_instruction(migraphx::make_op("dequantizelinear"), qb, scale_b, zp_b);
auto r = mm->add_instruction( auto r = migraphx::add_apply_alpha_beta(*mm, {dqa, dqb}, migraphx::make_op("dot"));
migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), dqa, dqb);
mm->add_return({r}); mm->add_return({r});
return p; return p;
}; };
...@@ -722,9 +709,8 @@ TEST_CASE(dot_double_2args) ...@@ -722,9 +709,8 @@ TEST_CASE(dot_double_2args)
migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), scale_b); migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), scale_b);
auto zp_b = auto zp_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), zp); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), zp);
auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b); auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b);
auto qdot = mm->add_instruction( auto qdot = mm->add_instruction(migraphx::make_op("quant_dot"), qa, qb);
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), qa, qb);
auto scale = mm->add_literal(50.0); auto scale = mm->add_literal(50.0);
scale = mm->add_instruction( scale = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}), scale); migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}), scale);
...@@ -753,8 +739,7 @@ TEST_CASE(dot_half_1arg) ...@@ -753,8 +739,7 @@ TEST_CASE(dot_half_1arg)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::half_type, {9, 9}}; migraphx::shape s{migraphx::shape::half_type, {9, 9}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto r = auto r = mm->add_instruction(migraphx::make_op("dot"), x, x);
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), x, x);
mm->add_return({r}); mm->add_return({r});
return p; return p;
...@@ -782,8 +767,7 @@ TEST_CASE(dot_half_1arg) ...@@ -782,8 +767,7 @@ TEST_CASE(dot_half_1arg)
zp_b); zp_b);
auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), x, scale_b, zp_b); auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), x, scale_b, zp_b);
auto dqb = mm->add_instruction(migraphx::make_op("dequantizelinear"), qb, scale_b, zp_b); auto dqb = mm->add_instruction(migraphx::make_op("dequantizelinear"), qb, scale_b, zp_b);
auto r = mm->add_instruction( auto r = mm->add_instruction(migraphx::make_op("dot"), dqa, dqb);
migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), dqa, dqb);
mm->add_return({r}); mm->add_return({r});
return p; return p;
}; };
...@@ -800,10 +784,8 @@ TEST_CASE(dot_half_1arg) ...@@ -800,10 +784,8 @@ TEST_CASE(dot_half_1arg)
scale); scale);
zp = zp =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), zp); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), zp);
auto qx = mm->add_instruction(migraphx::make_op("quantizelinear"), x, scale, zp); auto qx = mm->add_instruction(migraphx::make_op("quantizelinear"), x, scale, zp);
auto qdot = mm->add_instruction( auto qdot = mm->add_instruction(migraphx::make_op("quant_dot"), qx, qx);
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), qx, qx);
auto dq_scale = mm->add_literal(migraphx::literal({sa.type()}, {100.0})); auto dq_scale = mm->add_literal(migraphx::literal({sa.type()}, {100.0}));
dq_scale = mm->add_instruction( dq_scale = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}), migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}),
...@@ -1055,9 +1037,9 @@ TEST_CASE(int8_quantization_dot) ...@@ -1055,9 +1037,9 @@ TEST_CASE(int8_quantization_dot)
auto pa = mm->add_parameter("a", sa); auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb); auto pb = mm->add_parameter("b", sb);
auto pc = mm->add_parameter("c", sc); auto pc = mm->add_parameter("c", sc);
auto r = mm->add_instruction(migraphx::make_op("dot"), pa, pb, pc); auto r =
migraphx::add_apply_alpha_beta(*mm, {pa, pb, pc}, migraphx::make_op("dot"), 1.0f, 1.0f);
mm->add_return({r}); mm->add_return({r});
return p; return p;
}; };
...@@ -1075,7 +1057,7 @@ TEST_CASE(int8_quantization_dot) ...@@ -1075,7 +1057,7 @@ TEST_CASE(int8_quantization_dot)
std::vector<float> no_quant_result; std::vector<float> no_quant_result;
run_prog(p, ref_t, m, no_quant_result); run_prog(p, ref_t, m, no_quant_result);
EXPECT(migraphx::verify_range(quant_result, no_quant_result)); EXPECT(migraphx::verify_range(quant_result, no_quant_result, 30000));
} }
} }
...@@ -1142,8 +1124,7 @@ TEST_CASE(int8_subgraph) ...@@ -1142,8 +1124,7 @@ TEST_CASE(int8_subgraph)
auto w = mm->add_parameter("w", sw); auto w = mm->add_parameter("w", sw);
auto* then_mod = p.create_module("If_6_if"); auto* then_mod = p.create_module("If_6_if");
auto out1 = then_mod->add_instruction( auto out1 = migraphx::add_apply_alpha_beta(*then_mod, {a, b}, migraphx::make_op("dot"));
migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), a, b);
then_mod->add_return({out1}); then_mod->add_return({out1});
auto* else_mod = p.create_module("If_6_else"); auto* else_mod = p.create_module("If_6_else");
...@@ -1181,11 +1162,10 @@ TEST_CASE(int8_subgraph) ...@@ -1181,11 +1162,10 @@ TEST_CASE(int8_subgraph)
migraphx::make_op("multibroadcast", {{"out_lens", sy.lens()}}), s1); migraphx::make_op("multibroadcast", {{"out_lens", sy.lens()}}), s1);
auto zpb = then_mod->add_instruction( auto zpb = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sy.lens()}}), zp1); migraphx::make_op("multibroadcast", {{"out_lens", sy.lens()}}), zp1);
auto qb = then_mod->add_instruction(migraphx::make_op("quantizelinear"), b, sb, zpb); auto qb = then_mod->add_instruction(migraphx::make_op("quantizelinear"), b, sb, zpb);
auto qdot = auto qdot = then_mod->add_instruction(migraphx::make_op("quant_dot"), qa, qb);
then_mod->add_instruction(migraphx::make_op("quant_dot", {{"beta", 0}}), qa, qb); auto so = then_mod->add_literal(100.0f);
auto so = then_mod->add_literal(100.0f); so = then_mod->add_instruction(
so = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sout.lens()}}), so); migraphx::make_op("multibroadcast", {{"out_lens", sout.lens()}}), so);
auto r = then_mod->add_instruction(migraphx::make_op("dequantizelinear"), qdot, so); auto r = then_mod->add_instruction(migraphx::make_op("dequantizelinear"), qdot, so);
then_mod->add_return({r}); then_mod->add_return({r});
...@@ -1251,7 +1231,8 @@ TEST_CASE(test_op_capture) ...@@ -1251,7 +1231,8 @@ TEST_CASE(test_op_capture)
auto pb = mm->add_literal(s2, d2); auto pb = mm->add_literal(s2, d2);
auto pc = mm->add_literal(s2, d2); auto pc = mm->add_literal(s2, d2);
auto pa = mm->add_instruction(migraphx::make_op("add"), p1, p2); auto pa = mm->add_instruction(migraphx::make_op("add"), p1, p2);
auto ps = mm->add_instruction(migraphx::make_op("dot"), pa, pb, pc); auto ps =
migraphx::add_apply_alpha_beta(*mm, {pa, pb, pc}, migraphx::make_op("dot"), 1.0f, 1.0f);
mm->add_instruction(migraphx::make_op("dot"), pa, ps); mm->add_instruction(migraphx::make_op("dot"), pa, ps);
auto calc = [](std::size_t, const std::vector<migraphx::argument>&) {}; auto calc = [](std::size_t, const std::vector<migraphx::argument>&) {};
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/onnx.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include "test.hpp" #include "test.hpp"
#include <migraphx/half.hpp> #include <migraphx/half.hpp>
...@@ -211,7 +212,11 @@ TEST_CASE(gemm_mutli_dim_2_beta0) ...@@ -211,7 +212,11 @@ TEST_CASE(gemm_mutli_dim_2_beta0)
auto l3 = mm->add_literal(migraphx::literal{m3_shape, m3}); auto l3 = mm->add_literal(migraphx::literal{m3_shape, m3});
float alpha = 1.0f; float alpha = 1.0f;
float beta = 0.0f; float beta = 0.0f;
mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3); migraphx::add_apply_alpha_beta(*mm,
std::vector<migraphx::instruction_ref>{l1, l2, l3},
migraphx::make_op("dot"),
alpha,
beta);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
...@@ -274,7 +279,11 @@ TEST_CASE(gemm_beta_0) ...@@ -274,7 +279,11 @@ TEST_CASE(gemm_beta_0)
float alpha = 1.0f; float alpha = 1.0f;
float beta = 0.0f; float beta = 0.0f;
mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3); migraphx::add_apply_alpha_beta(*mm,
std::vector<migraphx::instruction_ref>{l1, l2, l3},
migraphx::make_op("dot"),
alpha,
beta);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
...@@ -359,13 +368,13 @@ TEST_CASE(gemm_mutli_dim1_2_3) ...@@ -359,13 +368,13 @@ TEST_CASE(gemm_mutli_dim1_2_3)
0.49759611, 0.10021662, 0.00592602, 0.90862000}; 0.49759611, 0.10021662, 0.00592602, 0.90862000};
migraphx::shape m3_shape{migraphx::shape::float_type, {2, 3, 2, 2}}; migraphx::shape m3_shape{migraphx::shape::float_type, {2, 3, 2, 2}};
auto l1 = mm->add_literal(migraphx::literal{m1_shape, m1}); auto l1 = mm->add_literal(migraphx::literal{m1_shape, m1});
auto l2 = mm->add_literal(migraphx::literal{m2_shape, m2}); auto l2 = mm->add_literal(migraphx::literal{m2_shape, m2});
auto l3 = mm->add_literal(migraphx::literal{m3_shape, m3}); auto l3 = mm->add_literal(migraphx::literal{m3_shape, m3});
float alpha = 0.35; float alpha = 0.35;
float beta = 0.41; float beta = 0.41;
auto m12_alpha = auto m12_alpha = migraphx::add_apply_alpha_beta(
mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2); *mm, std::vector<migraphx::instruction_ref>{l1, l2}, migraphx::make_op("dot"), alpha);
auto l_beta = mm->add_literal(beta); auto l_beta = mm->add_literal(beta);
auto b_beta = mm->add_instruction( auto b_beta = mm->add_instruction(
migraphx::make_op("scalar", {{"scalar_bcst_dims", m12_alpha->get_shape().lens()}}), l_beta); migraphx::make_op("scalar", {{"scalar_bcst_dims", m12_alpha->get_shape().lens()}}), l_beta);
...@@ -418,7 +427,11 @@ TEST_CASE(gemm_mutli_3args) ...@@ -418,7 +427,11 @@ TEST_CASE(gemm_mutli_3args)
auto l3 = mm->add_literal(migraphx::literal{m3_shape, m3}); auto l3 = mm->add_literal(migraphx::literal{m3_shape, m3});
float alpha = 0.35; float alpha = 0.35;
float beta = 0.41; float beta = 0.41;
mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3); migraphx::add_apply_alpha_beta(*mm,
std::vector<migraphx::instruction_ref>{l1, l2, l3},
migraphx::make_op("dot"),
alpha,
beta);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
...@@ -479,7 +492,7 @@ TEST_CASE(gemm_3args) ...@@ -479,7 +492,7 @@ TEST_CASE(gemm_3args)
auto bl = mm->add_literal(migraphx::literal{b_shape, b}); auto bl = mm->add_literal(migraphx::literal{b_shape, b});
migraphx::shape c_shape{migraphx::shape::float_type, {3, 3}}; migraphx::shape c_shape{migraphx::shape::float_type, {3, 3}};
auto cl = mm->add_literal(migraphx::literal{c_shape, c}); auto cl = mm->add_literal(migraphx::literal{c_shape, c});
mm->add_instruction(migraphx::make_op("dot"), al, bl, cl); migraphx::add_apply_alpha_beta(*mm, {al, bl, cl}, migraphx::make_op("dot"), 1.0f, 1.0f);
std::vector<float> gold = {-1.60947, std::vector<float> gold = {-1.60947,
0.703083, 0.703083,
-5.46156, -5.46156,
...@@ -561,7 +574,8 @@ TEST_CASE(matmul_vv_inner_product) ...@@ -561,7 +574,8 @@ TEST_CASE(matmul_vv_inner_product)
auto ual = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), al); auto ual = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), al);
auto ubl = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), bl); auto ubl = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), bl);
float alpha = 0.32f; float alpha = 0.32f;
mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}}), ual, ubl); migraphx::add_apply_alpha_beta(
*mm, std::vector<migraphx::instruction_ref>{ual, ubl}, migraphx::make_op("dot"), alpha);
std::vector<float> gold = {-0.4590752}; std::vector<float> gold = {-0.4590752};
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
...@@ -634,7 +648,8 @@ TEST_CASE(matmul_vm) ...@@ -634,7 +648,8 @@ TEST_CASE(matmul_vm)
migraphx::shape b_shape{migraphx::shape::float_type, {8, 5}}; migraphx::shape b_shape{migraphx::shape::float_type, {8, 5}};
auto bl = mm->add_literal(migraphx::literal{b_shape, b}); auto bl = mm->add_literal(migraphx::literal{b_shape, b});
float alpha = 0.5f; float alpha = 0.5f;
mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}}), ual, bl); migraphx::add_apply_alpha_beta(
*mm, std::vector<migraphx::instruction_ref>{ual, bl}, migraphx::make_op("dot"), alpha);
std::vector<float> gold = {-1.89056, -1.70003, -1.0986, -1.65724, -1.90163}; std::vector<float> gold = {-1.89056, -1.70003, -1.0986, -1.65724, -1.90163};
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
...@@ -718,7 +733,8 @@ TEST_CASE(matmul_vm) ...@@ -718,7 +733,8 @@ TEST_CASE(matmul_vm)
migraphx::make_op("multibroadcast", {{"out_lens", {3, 1, 6}}}), ual); migraphx::make_op("multibroadcast", {{"out_lens", {3, 1, 6}}}), ual);
migraphx::shape b_shape{migraphx::shape::float_type, {3, 6, 4}}; migraphx::shape b_shape{migraphx::shape::float_type, {3, 6, 4}};
auto bl = mm->add_literal(migraphx::literal{b_shape, b}); auto bl = mm->add_literal(migraphx::literal{b_shape, b});
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 0.21f}}), bual, bl); migraphx::add_apply_alpha_beta(
*mm, std::vector<migraphx::instruction_ref>{bual, bl}, migraphx::make_op("dot"), 0.21f);
std::vector<float> gold = {0.25812, std::vector<float> gold = {0.25812,
-0.247582, -0.247582,
0.480051, 0.480051,
...@@ -805,7 +821,8 @@ TEST_CASE(matmul_mv) ...@@ -805,7 +821,8 @@ TEST_CASE(matmul_mv)
auto bl = mm->add_literal(migraphx::literal{b_shape, b}); auto bl = mm->add_literal(migraphx::literal{b_shape, b});
auto ubl = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), bl); auto ubl = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), bl);
float alpha = 0.3f; float alpha = 0.3f;
mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}}), al, ubl); migraphx::add_apply_alpha_beta(
*mm, std::vector<migraphx::instruction_ref>{al, ubl}, migraphx::make_op("dot"), alpha);
std::vector<float> gold = {0.395946, 0.357067, -0.588187}; std::vector<float> gold = {0.395946, 0.357067, -0.588187};
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
...@@ -1337,7 +1354,8 @@ TEST_CASE(quant_dot_2args_general) ...@@ -1337,7 +1354,8 @@ TEST_CASE(quant_dot_2args_general)
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2}); auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto tl2 = auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 2}}), l1, tl2);
migraphx::add_apply_alpha_beta(*mm, {l1, tl2}, migraphx::make_op("quant_dot"), 2);
std::vector<int> gold = { std::vector<int> gold = {
28, 76, 124, 172, 220, 76, 252, 428, 604, 780, 124, 428, 732, 1036, 1340}; 28, 76, 124, 172, 220, 76, 252, 428, 604, 780, 124, 428, 732, 1036, 1340};
...@@ -1366,7 +1384,7 @@ TEST_CASE(quant_dot_2args_general) ...@@ -1366,7 +1384,7 @@ TEST_CASE(quant_dot_2args_general)
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2}); auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto tl2 = auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 3}, {"beta", 2}}), tl1, tl2); migraphx::add_apply_alpha_beta(*mm, {tl1, tl2}, migraphx::make_op("quant_dot"), 3);
std::vector<int> gold = { std::vector<int> gold = {
126, 342, 558, 774, 990, 144, 408, 672, 936, 1200, 162, 474, 786, 1098, 1410}; 126, 342, 558, 774, 990, 144, 408, 672, 936, 1200, 162, 474, 786, 1098, 1410};
...@@ -1398,7 +1416,7 @@ TEST_CASE(quant_dot_3args_general) ...@@ -1398,7 +1416,7 @@ TEST_CASE(quant_dot_3args_general)
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1}); auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2}); auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3}); auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3});
mm->add_instruction(migraphx::make_op("quant_dot"), l1, l2, l3); migraphx::add_apply_alpha_beta(*mm, {l1, l2, l3}, migraphx::make_op("quant_dot"), 1, 1);
std::vector<int> gold = { std::vector<int> gold = {
982, 1011, 1040, 1069, 1098, 1127, 1156, 2557, 2650, 2743, 2836, 2929, 3022, 3115}; 982, 1011, 1040, 1069, 1098, 1127, 1156, 2557, 2650, 2743, 2836, 2929, 3022, 3115};
...@@ -1426,9 +1444,7 @@ TEST_CASE(quant_dot_3args_general) ...@@ -1426,9 +1444,7 @@ TEST_CASE(quant_dot_3args_general)
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1}); auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2}); auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3}); mm->add_instruction(migraphx::make_op("quant_dot"), l1, l2);
mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), l1, l2, l3);
std::vector<int> gold = { std::vector<int> gold = {
70, 76, 82, 88, 94, 190, 212, 234, 256, 278, 310, 348, 386, 424, 462}; 70, 76, 82, 88, 94, 190, 212, 234, 256, 278, 310, 348, 386, 424, 462};
...@@ -1459,8 +1475,7 @@ TEST_CASE(quant_dot_3args_general) ...@@ -1459,8 +1475,7 @@ TEST_CASE(quant_dot_3args_general)
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2}); auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3}); auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3});
mm->add_instruction( migraphx::add_apply_alpha_beta(*mm, {tl1, l2, l3}, migraphx::make_op("quant_dot"), 1, 3);
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 3}}), tl1, l2, l3);
std::vector<int> gold = { std::vector<int> gold = {
1966, 2025, 2084, 2143, 2202, 2261, 2320, 2183, 2250, 2317, 2384, 2451, 2518, 2585}; 1966, 2025, 2084, 2143, 2202, 2261, 2320, 2183, 2250, 2317, 2384, 2451, 2518, 2585};
...@@ -1491,8 +1506,7 @@ TEST_CASE(quant_dot_3args_general) ...@@ -1491,8 +1506,7 @@ TEST_CASE(quant_dot_3args_general)
auto tl2 = auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3}); auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3});
mm->add_instruction( migraphx::add_apply_alpha_beta(*mm, {l1, tl2, l3}, migraphx::make_op("quant_dot"), 2, 3);
migraphx::make_op("quant_dot", {{"alpha", 2}, {"beta", 3}}), l1, tl2, l3);
std::vector<int> gold = { std::vector<int> gold = {
286, 737, 1188, 1639, 2090, 2541, 2992, 755, 2230, 3705, 5180, 6655, 8130, 9605}; 286, 737, 1188, 1639, 2090, 2541, 2992, 755, 2230, 3705, 5180, 6655, 8130, 9605};
...@@ -1525,8 +1539,7 @@ TEST_CASE(quant_dot_3args_general) ...@@ -1525,8 +1539,7 @@ TEST_CASE(quant_dot_3args_general)
auto tl2 = auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3}); auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3});
mm->add_instruction( migraphx::add_apply_alpha_beta(*mm, {tl1, tl2, l3}, migraphx::make_op("quant_dot"), 3, 2);
migraphx::make_op("quant_dot", {{"alpha", 3}, {"beta", 2}}), tl1, tl2, l3);
std::vector<int> gold = { std::vector<int> gold = {
844, 2190, 3536, 4882, 6228, 7574, 8920, 942, 2480, 4018, 5556, 7094, 8632, 10170}; 844, 2190, 3536, 4882, 6228, 7574, 8920, 942, 2480, 4018, 5556, 7094, 8632, 10170};
...@@ -1558,8 +1571,7 @@ TEST_CASE(quant_dot_3args_batch) ...@@ -1558,8 +1571,7 @@ TEST_CASE(quant_dot_3args_batch)
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1}); auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2}); auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3}); auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3});
mm->add_instruction( migraphx::add_apply_alpha_beta(*mm, {l1, l2, l3}, migraphx::make_op("quant_dot"), 1, 2);
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 2}}), l1, l2, l3);
std::vector<int> gold = { std::vector<int> gold = {
102, 110, 118, 126, 134, 142, 150, 284, 308, 332, 356, 380, 102, 110, 118, 126, 134, 142, 150, 284, 308, 332, 356, 380,
...@@ -1596,8 +1608,7 @@ TEST_CASE(quant_dot_3args_batch) ...@@ -1596,8 +1608,7 @@ TEST_CASE(quant_dot_3args_batch)
auto tl2 = mm->add_instruction( auto tl2 = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2); migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2);
auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3}); auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3});
mm->add_instruction( migraphx::add_apply_alpha_beta(*mm, {tl1, tl2, l3}, migraphx::make_op("quant_dot"), 2, 3);
migraphx::make_op("quant_dot", {{"alpha", 2}, {"beta", 3}}), tl1, tl2, l3);
std::vector<int> gold = { std::vector<int> gold = {
90, 237, 384, 531, 678, 825, 120, 299, 478, 657, 836, 1015, 90, 237, 384, 531, 678, 825, 120, 299, 478, 657, 836, 1015,
......
#include <iostream> #include <iostream>
#include <vector> #include <vector>
#include <cmath> #include <cmath>
#include <random>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/op/pooling.hpp> #include <migraphx/op/pooling.hpp>
#include <migraphx/op/batch_norm_inference.hpp> #include <migraphx/op/batch_norm_inference.hpp>
...@@ -2687,6 +2688,56 @@ TEST_CASE(mul_test) ...@@ -2687,6 +2688,56 @@ TEST_CASE(mul_test)
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(multinomial_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
size_t sample_size = 100000;
float seed = 0.0f;
std::mt19937 gen(seed);
std::uniform_real_distribution<> dis(0.0, 1.0);
std::vector<float> rand_samples(sample_size);
std::generate(rand_samples.begin(), rand_samples.end(), [&]() { return dis(gen); });
migraphx::shape rs{migraphx::shape::float_type, {1, sample_size}};
auto rs_lit = mm->add_literal(migraphx::literal{rs, rand_samples});
migraphx::shape s{migraphx::shape::float_type, {1, 5}};
std::vector<int> dist{15, 25, 15, 25, 20};
std::vector<float> data(5);
std::transform(dist.begin(), dist.end(), data.begin(), [&](auto d) { return std::log(d); });
auto input = mm->add_literal(migraphx::literal(s, data));
auto maxes = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {1}}}), input);
auto mb_maxes =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 5}}}), maxes);
auto cdf = mm->add_instruction(migraphx::make_op("sub"), input, mb_maxes);
cdf = mm->add_instruction(migraphx::make_op("exp"), cdf);
cdf = mm->add_instruction(
migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}), cdf);
mm->add_instruction(migraphx::make_op("multinomial"), cdf, rs_lit);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<int32_t> result_vec(sample_size);
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
std::vector<int> res_dist(5, 0);
for(auto& r : result_vec)
res_dist[r]++;
auto dist_sum = std::accumulate(dist.begin(), dist.end(), 0);
auto res_dist_sum = std::accumulate(res_dist.begin(), res_dist.end(), 0);
std::vector<float> norm(5);
std::vector<float> res_norm(5);
std::transform(dist.begin(), dist.end(), norm.begin(), [&](auto n) {
return static_cast<double>(n) / dist_sum;
});
std::transform(res_dist.begin(), res_dist.end(), res_norm.begin(), [&](auto n) {
return static_cast<double>(n) / res_dist_sum;
});
EXPECT(migraphx::verify_range(norm, res_norm, 100000));
}
TEST_CASE(neg_test) TEST_CASE(neg_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -2705,6 +2756,26 @@ TEST_CASE(neg_test) ...@@ -2705,6 +2756,26 @@ TEST_CASE(neg_test)
EXPECT(migraphx::verify_range(result_vector, gold)); EXPECT(migraphx::verify_range(result_vector, gold));
} }
TEST_CASE(nonzero_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 2, 3}};
std::vector<float> data = {
1.0f, 1.3f, 0.0f, -1.2f, 0.0f, -100.f, 200.f, 0.0f, 0.1f, 0.2f, 0.0f, 0.5f};
auto input = mm->add_literal(migraphx::literal(s, data));
auto ret = mm->add_instruction(migraphx::make_op("nonzero"), input);
mm->add_return({ret});
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::cout << "result = " << result << std::endl;
std::vector<int64_t> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<int64_t> gold = {0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
1, 1, 0, 0, 0, 0, 0, 1, 0, 2, 0, 2, 0, 2, 0, 0, 0, 0};
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(not_test) TEST_CASE(not_test)
{ {
// int32 // int32
......
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
#include <migraphx/ref/target.hpp> #include <migraphx/ref/target.hpp>
#include <migraphx/apply_alpha_beta.hpp>
bool is_convolution(const migraphx::instruction& ins) { return ins.name() == "convolution"; } bool is_convolution(const migraphx::instruction& ins) { return ins.name() == "convolution"; }
bool is_dot(const migraphx::instruction& ins) { return ins.name() == "dot"; } bool is_dot(const migraphx::instruction& ins) { return ins.name() == "dot"; }
...@@ -127,12 +128,11 @@ TEST_CASE(dot) ...@@ -127,12 +128,11 @@ TEST_CASE(dot)
auto scale = m1.add_literal(0.5f); auto scale = m1.add_literal(0.5f);
auto zero = m1.add_literal(std::int8_t{0}); auto zero = m1.add_literal(std::int8_t{0});
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero); auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero); auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero); auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero); auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
auto dot = auto dot = m1.add_instruction(migraphx::make_op("dot"), d1, d2);
m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), d1, d2);
m1.add_return({dot}); m1.add_return({dot});
} }
...@@ -144,11 +144,10 @@ TEST_CASE(dot) ...@@ -144,11 +144,10 @@ TEST_CASE(dot)
auto zero = m2.add_literal(std::int8_t{0}); auto zero = m2.add_literal(std::int8_t{0});
auto scale1 = m2.add_literal(0.25f); auto scale1 = m2.add_literal(0.25f);
auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale, zero); auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale, zero);
auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale, zero); auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale, zero);
auto dot = auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2);
m2.add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), q1, q2); auto d3 = add_quantize_op(m2, "dequantizelinear", dot, scale1);
auto d3 = add_quantize_op(m2, "dequantizelinear", dot, scale1);
m2.add_return({d3}); m2.add_return({d3});
} }
...@@ -168,22 +167,19 @@ TEST_CASE(dot_non_zero_point) ...@@ -168,22 +167,19 @@ TEST_CASE(dot_non_zero_point)
auto scale = m1.add_literal(0.5f); auto scale = m1.add_literal(0.5f);
auto zero = m1.add_literal(std::int8_t{1}); auto zero = m1.add_literal(std::int8_t{1});
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero); auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero); auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero); auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero); auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
auto dot = auto dot = m1.add_instruction(migraphx::make_op("dot"), d1, d2);
m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), d1, d2);
m1.add_return({dot}); m1.add_return({dot});
} }
migraphx::module m2; migraphx::module m2;
{ {
auto t1 = m2.add_parameter("t1", sh1); auto t1 = m2.add_parameter("t1", sh1);
auto t2 = m2.add_parameter("t2", sh2); auto t2 = m2.add_parameter("t2", sh2);
auto dot = m2.add_instruction(migraphx::make_op("dot"), t1, t2);
auto dot =
m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), t1, t2);
m2.add_return({dot}); m2.add_return({dot});
} }
...@@ -203,22 +199,19 @@ TEST_CASE(dot_uint8) ...@@ -203,22 +199,19 @@ TEST_CASE(dot_uint8)
auto scale = m1.add_literal(0.5f); auto scale = m1.add_literal(0.5f);
auto zero = m1.add_literal(std::uint8_t{0}); auto zero = m1.add_literal(std::uint8_t{0});
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero); auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero); auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero); auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero); auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
auto dot = auto dot = m1.add_instruction(migraphx::make_op("dot"), d1, d2);
m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), d1, d2);
m1.add_return({dot}); m1.add_return({dot});
} }
migraphx::module m2; migraphx::module m2;
{ {
auto t1 = m2.add_parameter("t1", sh1); auto t1 = m2.add_parameter("t1", sh1);
auto t2 = m2.add_parameter("t2", sh2); auto t2 = m2.add_parameter("t2", sh2);
auto dot = m2.add_instruction(migraphx::make_op("dot"), t1, t2);
auto dot =
m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), t1, t2);
m2.add_return({dot}); m2.add_return({dot});
} }
...@@ -240,12 +233,11 @@ TEST_CASE(dot_add) ...@@ -240,12 +233,11 @@ TEST_CASE(dot_add)
auto scale = m1.add_literal(0.5f); auto scale = m1.add_literal(0.5f);
auto zero = m1.add_literal(std::int8_t{0}); auto zero = m1.add_literal(std::int8_t{0});
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero); auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero); auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero); auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero); auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
auto dot = auto dot = m1.add_instruction(migraphx::make_op("dot"), d1, d2);
m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), d1, d2);
auto q3 = add_quantize_op(m1, "quantizelinear", dot, scale, zero); auto q3 = add_quantize_op(m1, "quantizelinear", dot, scale, zero);
auto d3 = add_quantize_op(m1, "dequantizelinear", q3, scale, zero); auto d3 = add_quantize_op(m1, "dequantizelinear", q3, scale, zero);
auto add = m1.add_instruction(migraphx::make_op("add"), d3, ab); auto add = m1.add_instruction(migraphx::make_op("add"), d3, ab);
...@@ -261,10 +253,9 @@ TEST_CASE(dot_add) ...@@ -261,10 +253,9 @@ TEST_CASE(dot_add)
auto zero = m2.add_literal(std::int8_t{0}); auto zero = m2.add_literal(std::int8_t{0});
auto scale1 = m2.add_literal(0.25f); auto scale1 = m2.add_literal(0.25f);
auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale, zero); auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale, zero);
auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale, zero); auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale, zero);
auto dot = auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2);
m2.add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), q1, q2);
auto d3 = add_quantize_op(m2, "dequantizelinear", dot, scale1); auto d3 = add_quantize_op(m2, "dequantizelinear", dot, scale1);
auto add = m2.add_instruction(migraphx::make_op("add"), d3, ab); auto add = m2.add_instruction(migraphx::make_op("add"), d3, ab);
m2.add_return({add}); m2.add_return({add});
...@@ -471,21 +462,20 @@ TEST_CASE(conv_pooling_dot) ...@@ -471,21 +462,20 @@ TEST_CASE(conv_pooling_dot)
d1); d1);
auto bc1 = m1.add_instruction( auto bc1 = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2);
auto a1 = m1.add_instruction(migraphx::make_op("add"), c1, bc1); auto a1 = m1.add_instruction(migraphx::make_op("add"), c1, bc1);
auto ap = m1.add_instruction(migraphx::make_op("pooling", auto ap = m1.add_instruction(migraphx::make_op("pooling",
{{"mode", "average"}, {{"mode", "average"},
{"padding", {0, 0, 0, 0}}, {"padding", {0, 0, 0, 0}},
{"stride", {1, 1}}, {"stride", {1, 1}},
{"lengths", {7, 7}}, {"lengths", {7, 7}},
{"ceil_mode", 0}}), {"ceil_mode", 0}}),
a1); a1);
auto fl = m1.add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), ap); auto fl = m1.add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), ap);
auto q4 = add_quantize_op(m1, "quantizelinear", fl, scale, zero); auto q4 = add_quantize_op(m1, "quantizelinear", fl, scale, zero);
auto d8 = add_quantize_op(m1, "dequantizelinear", q4, scale, zero); auto d8 = add_quantize_op(m1, "dequantizelinear", q4, scale, zero);
auto dot = auto dot = m1.add_instruction(migraphx::make_op("dot"), d8, d4);
m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), d8, d4); auto q5 = add_quantize_op(m1, "quantizelinear", dot, scale, zero);
auto q5 = add_quantize_op(m1, "quantizelinear", dot, scale, zero); auto d9 = add_quantize_op(m1, "dequantizelinear", q5, scale, zero);
auto d9 = add_quantize_op(m1, "dequantizelinear", q5, scale, zero);
auto mb1 = auto mb1 =
m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 1000}}}), d3); m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 1000}}}), d3);
auto a2 = m1.add_instruction(migraphx::make_op("add"), d9, mb1); auto a2 = m1.add_instruction(migraphx::make_op("add"), d9, mb1);
...@@ -518,19 +508,18 @@ TEST_CASE(conv_pooling_dot) ...@@ -518,19 +508,18 @@ TEST_CASE(conv_pooling_dot)
auto d5 = add_quantize_op(m2, "dequantizelinear", c1, scale1); auto d5 = add_quantize_op(m2, "dequantizelinear", c1, scale1);
auto bc1 = m2.add_instruction( auto bc1 = m2.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2);
auto a1 = m2.add_instruction(migraphx::make_op("add"), d5, bc1); auto a1 = m2.add_instruction(migraphx::make_op("add"), d5, bc1);
auto ap = m2.add_instruction(migraphx::make_op("pooling", auto ap = m2.add_instruction(migraphx::make_op("pooling",
{{"mode", "average"}, {{"mode", "average"},
{"padding", {0, 0, 0, 0}}, {"padding", {0, 0, 0, 0}},
{"stride", {1, 1}}, {"stride", {1, 1}},
{"lengths", {7, 7}}, {"lengths", {7, 7}},
{"ceil_mode", 0}}), {"ceil_mode", 0}}),
a1); a1);
auto fl = m2.add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), ap); auto fl = m2.add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), ap);
auto q4 = add_quantize_op(m2, "quantizelinear", fl, scale, zero); auto q4 = add_quantize_op(m2, "quantizelinear", fl, scale, zero);
auto dot = auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q4, db);
m2.add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), q4, db); auto d9 = add_quantize_op(m2, "dequantizelinear", dot, scale2);
auto d9 = add_quantize_op(m2, "dequantizelinear", dot, scale2);
auto mb1 = auto mb1 =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 1000}}}), d3); m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 1000}}}), d3);
auto a2 = m2.add_instruction(migraphx::make_op("add"), d9, mb1); auto a2 = m2.add_instruction(migraphx::make_op("add"), d9, mb1);
...@@ -575,25 +564,24 @@ TEST_CASE(mobilenet_snippet) ...@@ -575,25 +564,24 @@ TEST_CASE(mobilenet_snippet)
d1); d1);
auto bc1 = mm.add_instruction( auto bc1 = mm.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2);
auto a1 = mm.add_instruction(migraphx::make_op("add"), c1, bc1); auto a1 = mm.add_instruction(migraphx::make_op("add"), c1, bc1);
auto q2 = add_quantize_op(mm, "quantizelinear", a1, scale, zero); auto q2 = add_quantize_op(mm, "quantizelinear", a1, scale, zero);
auto d6 = add_quantize_op(mm, "dequantizelinear", q2, scale, zero); auto d6 = add_quantize_op(mm, "dequantizelinear", q2, scale, zero);
auto ap = mm.add_instruction(migraphx::make_op("pooling", auto ap = mm.add_instruction(migraphx::make_op("pooling",
{{"mode", "average"}, {{"mode", "average"},
{"padding", {0, 0, 0, 0}}, {"padding", {0, 0, 0, 0}},
{"stride", {1, 1}}, {"stride", {1, 1}},
{"lengths", {7, 7}}, {"lengths", {7, 7}},
{"ceil_mode", 0}}), {"ceil_mode", 0}}),
d6); d6);
auto q3 = add_quantize_op(mm, "quantizelinear", ap, scale, zero); auto q3 = add_quantize_op(mm, "quantizelinear", ap, scale, zero);
auto d7 = add_quantize_op(mm, "dequantizelinear", q3, scale, zero); auto d7 = add_quantize_op(mm, "dequantizelinear", q3, scale, zero);
auto rs = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {1, -1}}}), d7); auto rs = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {1, -1}}}), d7);
auto q4 = add_quantize_op(mm, "quantizelinear", rs, scale, zero); auto q4 = add_quantize_op(mm, "quantizelinear", rs, scale, zero);
auto d8 = add_quantize_op(mm, "dequantizelinear", q4, scale, zero); auto d8 = add_quantize_op(mm, "dequantizelinear", q4, scale, zero);
auto dot = auto dot = mm.add_instruction(migraphx::make_op("dot"), d8, d4);
mm.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), d8, d4); auto q5 = add_quantize_op(mm, "quantizelinear", dot, scale, zero);
auto q5 = add_quantize_op(mm, "quantizelinear", dot, scale, zero); auto d9 = add_quantize_op(mm, "dequantizelinear", q5, scale, zero);
auto d9 = add_quantize_op(mm, "dequantizelinear", q5, scale, zero);
auto mb1 = auto mb1 =
mm.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 1000}}}), d3); mm.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 1000}}}), d3);
auto a2 = mm.add_instruction(migraphx::make_op("add"), d9, mb1); auto a2 = mm.add_instruction(migraphx::make_op("add"), d9, mb1);
...@@ -699,12 +687,11 @@ TEST_CASE(dot_correctness) ...@@ -699,12 +687,11 @@ TEST_CASE(dot_correctness)
auto scale_b = m1->add_literal(0.5f); auto scale_b = m1->add_literal(0.5f);
auto zero = m1->add_literal(std::int8_t{0}); auto zero = m1->add_literal(std::int8_t{0});
auto q1 = add_quantize_op(*m1, "quantizelinear", a, scale_a, zero); auto q1 = add_quantize_op(*m1, "quantizelinear", a, scale_a, zero);
auto d1 = add_quantize_op(*m1, "dequantizelinear", q1, scale_a, zero); auto d1 = add_quantize_op(*m1, "dequantizelinear", q1, scale_a, zero);
auto q2 = add_quantize_op(*m1, "quantizelinear", b, scale_b, zero); auto q2 = add_quantize_op(*m1, "quantizelinear", b, scale_b, zero);
auto d2 = add_quantize_op(*m1, "dequantizelinear", q2, scale_b, zero); auto d2 = add_quantize_op(*m1, "dequantizelinear", q2, scale_b, zero);
auto dot = auto dot = m1->add_instruction(migraphx::make_op("dot"), d1, d2);
m1->add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), d1, d2);
m1->add_return({dot}); m1->add_return({dot});
run_pass(*m1); run_pass(*m1);
...@@ -715,8 +702,7 @@ TEST_CASE(dot_correctness) ...@@ -715,8 +702,7 @@ TEST_CASE(dot_correctness)
auto* m2 = p2.get_main_module(); auto* m2 = p2.get_main_module();
auto a = m2->add_parameter("a", sh1); auto a = m2->add_parameter("a", sh1);
auto b = m2->add_parameter("b", sh2); auto b = m2->add_parameter("b", sh2);
auto dot = m2->add_instruction(migraphx::make_op("dot"), a, b);
auto dot = m2->add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), a, b);
m2->add_return({dot}); m2->add_return({dot});
} }
......
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
...@@ -21,8 +22,7 @@ struct batch_quant_dot_1 : verify_program<batch_quant_dot_1> ...@@ -21,8 +22,7 @@ struct batch_quant_dot_1 : verify_program<batch_quant_dot_1>
auto tl2 = mm->add_instruction( auto tl2 = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2); migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2);
auto l3 = mm->add_parameter("c", m3_shape); auto l3 = mm->add_parameter("c", m3_shape);
mm->add_instruction( migraphx::add_apply_alpha_beta(*mm, {tl1, tl2, l3}, migraphx::make_op("quant_dot"), 3, 2);
migraphx::make_op("quant_dot", {{"alpha", 3}, {"beta", 2}}), tl1, tl2, l3);
return p; return p;
} }
}; };
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
...@@ -17,8 +18,7 @@ struct batch_quant_dot_2 : verify_program<batch_quant_dot_2> ...@@ -17,8 +18,7 @@ struct batch_quant_dot_2 : verify_program<batch_quant_dot_2>
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
auto l3 = mm->add_parameter("c", m3_shape); auto l3 = mm->add_parameter("c", m3_shape);
mm->add_instruction( migraphx::add_apply_alpha_beta(*mm, {l1, l2, l3}, migraphx::make_op("quant_dot"), 1, 3);
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 3}}), l1, l2, l3);
return p; return p;
} }
}; };
...@@ -15,7 +15,7 @@ struct batch_quant_dot_3 : verify_program<batch_quant_dot_3> ...@@ -15,7 +15,7 @@ struct batch_quant_dot_3 : verify_program<batch_quant_dot_3>
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 3}}), l1, l2); mm->add_instruction(migraphx::make_op("quant_dot"), l1, l2);
return p; return p;
} }
}; };
...@@ -19,7 +19,7 @@ struct batch_quant_dot_4 : verify_program<batch_quant_dot_4> ...@@ -19,7 +19,7 @@ struct batch_quant_dot_4 : verify_program<batch_quant_dot_4>
migraphx::make_op("transpose", {{"permutation", {3, 0, 1, 2}}}), l1); migraphx::make_op("transpose", {{"permutation", {3, 0, 1, 2}}}), l1);
auto tl2 = mm->add_instruction( auto tl2 = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {3, 1, 2, 0}}}), l2); migraphx::make_op("transpose", {{"permutation", {3, 1, 2, 0}}}), l2);
mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 3}}), tl1, tl2); mm->add_instruction(migraphx::make_op("quant_dot"), tl1, tl2);
return p; return p;
} }
}; };
...@@ -21,7 +21,7 @@ struct batch_quant_dot_5 : verify_program<batch_quant_dot_5> ...@@ -21,7 +21,7 @@ struct batch_quant_dot_5 : verify_program<batch_quant_dot_5>
auto tl2 = mm->add_instruction( auto tl2 = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2); migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2);
auto sl2 = mm->add_instruction(migraphx::make_op("add"), tl2, tl2); auto sl2 = mm->add_instruction(migraphx::make_op("add"), tl2, tl2);
mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}}), sl1, sl2); mm->add_instruction(migraphx::make_op("quant_dot"), sl1, sl2);
return p; return p;
} }
}; };
#include <migraphx/apply_alpha_beta.hpp>
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
...@@ -17,8 +18,7 @@ struct gemm_2args_vv : verify_program<gemm_2args_vv> ...@@ -17,8 +18,7 @@ struct gemm_2args_vv : verify_program<gemm_2args_vv>
auto l2 = mm->add_parameter("2", m2_shape); auto l2 = mm->add_parameter("2", m2_shape);
auto ul2 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l2); auto ul2 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l2);
float alpha = 0.23f; float alpha = 0.23f;
auto res = migraphx::add_apply_alpha_beta(*mm, {ul1, ul2}, migraphx::make_op("dot"), alpha);
auto res = mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}}), ul1, ul2);
auto sres = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), res); auto sres = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), res);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), sres); mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), sres);
......
#include <migraphx/apply_alpha_beta.hpp>
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
...@@ -19,9 +20,7 @@ struct gemm_multi_3args : verify_program<gemm_multi_3args> ...@@ -19,9 +20,7 @@ struct gemm_multi_3args : verify_program<gemm_multi_3args>
auto l3 = mm->add_parameter("3", m3_shape); auto l3 = mm->add_parameter("3", m3_shape);
float alpha = 0.35; float alpha = 0.35;
float beta = 0.41; float beta = 0.41;
mm->add_instruction( migraphx::add_apply_alpha_beta(*mm, {l1, l2, l3}, migraphx::make_op("dot"), alpha, beta);
migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3);
return p; return p;
} }
}; };
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/apply_alpha_beta.hpp>
struct gemm_multi_3args_alpha0 : verify_program<gemm_multi_3args_alpha0> struct gemm_multi_3args_alpha0 : verify_program<gemm_multi_3args_alpha0>
{ {
migraphx::program create_program() const migraphx::program create_program() const
...@@ -19,9 +19,7 @@ struct gemm_multi_3args_alpha0 : verify_program<gemm_multi_3args_alpha0> ...@@ -19,9 +19,7 @@ struct gemm_multi_3args_alpha0 : verify_program<gemm_multi_3args_alpha0>
float alpha = 0.0f; float alpha = 0.0f;
float beta = 1.0f; float beta = 1.0f;
mm->add_instruction( migraphx::add_apply_alpha_beta(*mm, {l1, l2, l3}, migraphx::make_op("dot"), alpha, beta);
migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3);
return p; return p;
} }
}; };
#include <migraphx/apply_alpha_beta.hpp>
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
...@@ -19,9 +20,7 @@ struct gemm_multi_3args_beta0 : verify_program<gemm_multi_3args_beta0> ...@@ -19,9 +20,7 @@ struct gemm_multi_3args_beta0 : verify_program<gemm_multi_3args_beta0>
float alpha = 1.0f; float alpha = 1.0f;
float beta = 0.0f; float beta = 0.0f;
mm->add_instruction( migraphx::add_apply_alpha_beta(*mm, {l1, l2, l3}, migraphx::make_op("dot"), alpha, beta);
migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3);
return p; return p;
} }
}; };
#include <migraphx/apply_alpha_beta.hpp>
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
...@@ -19,9 +20,7 @@ struct gemm_multi_3args_c25 : verify_program<gemm_multi_3args_c25> ...@@ -19,9 +20,7 @@ struct gemm_multi_3args_c25 : verify_program<gemm_multi_3args_c25>
auto l3 = mm->add_parameter("3", m3_shape); auto l3 = mm->add_parameter("3", m3_shape);
float alpha = 0.35; float alpha = 0.35;
float beta = 0.41; float beta = 0.41;
mm->add_instruction( migraphx::add_apply_alpha_beta(*mm, {l1, l2, l3}, migraphx::make_op("dot"), alpha, beta);
migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3);
return p; return p;
} }
}; };
#include <migraphx/apply_alpha_beta.hpp>
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
...@@ -19,8 +20,7 @@ struct gemm_multi_transpose : verify_program<gemm_multi_transpose> ...@@ -19,8 +20,7 @@ struct gemm_multi_transpose : verify_program<gemm_multi_transpose>
float alpha = 1.0f; float alpha = 1.0f;
float beta = 1.0f; float beta = 1.0f;
mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, tl2); migraphx::add_apply_alpha_beta(*mm, {l1, tl2}, migraphx::make_op("dot"), alpha, beta);
return p; return p;
} }
}; };
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
...@@ -17,7 +18,7 @@ struct quant_dot_3args_1 : verify_program<quant_dot_3args_1> ...@@ -17,7 +18,7 @@ struct quant_dot_3args_1 : verify_program<quant_dot_3args_1>
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
auto l3 = mm->add_parameter("c", m3_shape); auto l3 = mm->add_parameter("c", m3_shape);
mm->add_instruction(migraphx::make_op("quant_dot"), l1, l2, l3); migraphx::add_apply_alpha_beta(*mm, {l1, l2, l3}, migraphx::make_op("quant_dot"), 1, 1);
return p; return p;
} }
}; };
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
...@@ -19,8 +20,7 @@ struct quant_dot_3args_2 : verify_program<quant_dot_3args_2> ...@@ -19,8 +20,7 @@ struct quant_dot_3args_2 : verify_program<quant_dot_3args_2>
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
auto l3 = mm->add_parameter("c", m3_shape); auto l3 = mm->add_parameter("c", m3_shape);
mm->add_instruction( migraphx::add_apply_alpha_beta(*mm, {tl1, l2, l3}, migraphx::make_op("quant_dot"), 1, 3);
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 3}}), tl1, l2, l3);
return p; return p;
} }
}; };
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment