Unverified Commit 9e43cb8b authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Remove alpha and beta attributes from dot operator (#945)

This PR aims to remove alpha and beta attributes from dot operator completely.

Previously dot operator was defined as C = alpha * A . B + beta * C where * is scalar multiplication and . is dot product or matrix multiplication depending on dimension of the inputs.

Aim is to have the definition of dot operator as C = A . B without having alpha or beta.

In order to achieve the same effect as alpha and beta (1) it multiplies the one of the inputs to the dot operator with alpha value. (2) if beta is present then, multiplies the C with beta and then adds into the output from step 1.
parent 31dc067e
...@@ -7,140 +7,55 @@ ...@@ -7,140 +7,55 @@
void run_pass(migraphx::module& m) { migraphx::run_passes(m, {migraphx::decompose{}}); } void run_pass(migraphx::module& m) { migraphx::run_passes(m, {migraphx::decompose{}}); }
TEST_CASE(dot_add) TEST_CASE(quant_dot_add)
{ {
migraphx::module m1; migraphx::module m1;
{ {
auto x = m1.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto x = m1.add_parameter("x", migraphx::shape{migraphx::shape::int8_type, {2, 2}});
auto y = m1.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto y = m1.add_parameter("y", migraphx::shape{migraphx::shape::int8_type, {2, 2}});
auto z = m1.add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}}); auto z = m1.add_parameter("z", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
auto dot = m1.add_instruction(migraphx::make_op("dot"), x, y, z); auto q_dot = m1.add_instruction(migraphx::make_op("quant_dot"), x, y, z);
m1.add_instruction(migraphx::make_op("identity"), dot); m1.add_instruction(migraphx::make_op("identity"), q_dot);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto y = m2.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto z = m2.add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto dot = m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y);
auto add = m2.add_instruction(migraphx::make_op("add"), dot, z);
m2.add_instruction(migraphx::make_op("identity"), add);
}
EXPECT(m1 == m2);
}
TEST_CASE(dot_add_beta_float)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto y = m1.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto z = m1.add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto dot =
m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1.0}, {"beta", 0.5}}), x, y, z);
m1.add_instruction(migraphx::make_op("identity"), dot);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto y = m2.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto z = m2.add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto dot = m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y);
auto beta =
m2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {0.5}});
auto beta_broadcast =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2}}}), beta);
auto mul = m2.add_instruction(migraphx::make_op("mul"), z, beta_broadcast);
auto add = m2.add_instruction(migraphx::make_op("add"), dot, mul);
m2.add_instruction(migraphx::make_op("identity"), add);
}
EXPECT(m1 == m2);
}
TEST_CASE(dot_add_beta_half)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto y = m1.add_parameter("y", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto z = m1.add_parameter("z", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto dot =
m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1.0}, {"beta", 0.5}}), x, y, z);
m1.add_instruction(migraphx::make_op("identity"), dot);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto y = m2.add_parameter("y", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto z = m2.add_parameter("z", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto dot = m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y);
auto beta =
m2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::half_type}, {0.5}});
auto beta_broadcast =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2}}}), beta);
auto mul = m2.add_instruction(migraphx::make_op("mul"), z, beta_broadcast);
auto add = m2.add_instruction(migraphx::make_op("add"), dot, mul);
m2.add_instruction(migraphx::make_op("identity"), add);
}
EXPECT(m1 == m2);
}
TEST_CASE(dot_add_beta_double)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto y = m1.add_parameter("y", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto z = m1.add_parameter("z", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto dot =
m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1.0}, {"beta", 0.5}}), x, y, z);
m1.add_instruction(migraphx::make_op("identity"), dot);
} }
run_pass(m1); run_pass(m1);
migraphx::module m2; migraphx::module m2;
{ {
auto x = m2.add_parameter("x", migraphx::shape{migraphx::shape::double_type, {2, 2}}); auto x = m2.add_parameter("x", migraphx::shape{migraphx::shape::int8_type, {2, 2}});
auto y = m2.add_parameter("y", migraphx::shape{migraphx::shape::double_type, {2, 2}}); auto y = m2.add_parameter("y", migraphx::shape{migraphx::shape::int8_type, {2, 2}});
auto z = m2.add_parameter("z", migraphx::shape{migraphx::shape::double_type, {2, 2}}); auto z = m2.add_parameter("z", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
auto dot = m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y); auto q_dot =
auto beta = m2.add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), x, y);
m2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::double_type}, {0.5}}); auto add = m2.add_instruction(migraphx::make_op("add"), q_dot, z);
auto beta_broadcast =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2}}}), beta);
auto mul = m2.add_instruction(migraphx::make_op("mul"), z, beta_broadcast);
auto add = m2.add_instruction(migraphx::make_op("add"), dot, mul);
m2.add_instruction(migraphx::make_op("identity"), add); m2.add_instruction(migraphx::make_op("identity"), add);
} }
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
TEST_CASE(dot_add_beta_int) TEST_CASE(quant_dot_add_beta)
{ {
migraphx::module m1; migraphx::module m1;
{ {
auto x = m1.add_parameter("x", migraphx::shape{migraphx::shape::int32_type, {2, 2}}); auto x = m1.add_parameter("x", migraphx::shape{migraphx::shape::int8_type, {2, 2}});
auto y = m1.add_parameter("y", migraphx::shape{migraphx::shape::int32_type, {2, 2}}); auto y = m1.add_parameter("y", migraphx::shape{migraphx::shape::int8_type, {2, 2}});
auto z = m1.add_parameter("z", migraphx::shape{migraphx::shape::int32_type, {2, 2}}); auto z = m1.add_parameter("z", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
auto dot = auto q_dot = m1.add_instruction(
m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1.0}, {"beta", 0.5}}), x, y, z); migraphx::make_op("quant_dot", {{"alpha", 1.0}, {"beta", 2}}), x, y, z);
m1.add_instruction(migraphx::make_op("identity"), dot); m1.add_instruction(migraphx::make_op("identity"), q_dot);
} }
run_pass(m1); run_pass(m1);
migraphx::module m2; migraphx::module m2;
{ {
auto x = m2.add_parameter("x", migraphx::shape{migraphx::shape::int32_type, {2, 2}}); auto x = m2.add_parameter("x", migraphx::shape{migraphx::shape::int8_type, {2, 2}});
auto y = m2.add_parameter("y", migraphx::shape{migraphx::shape::int32_type, {2, 2}}); auto y = m2.add_parameter("y", migraphx::shape{migraphx::shape::int8_type, {2, 2}});
auto z = m2.add_parameter("z", migraphx::shape{migraphx::shape::int32_type, {2, 2}}); auto z = m2.add_parameter("z", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
auto dot = m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y); auto q_dot =
m2.add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), x, y);
auto beta = auto beta =
m2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {0.5}}); m2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {2}});
auto beta_broadcast = auto beta_broadcast =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2}}}), beta); m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2}}}), beta);
auto mul = m2.add_instruction(migraphx::make_op("mul"), z, beta_broadcast); auto mul = m2.add_instruction(migraphx::make_op("mul"), z, beta_broadcast);
auto add = m2.add_instruction(migraphx::make_op("add"), dot, mul); auto add = m2.add_instruction(migraphx::make_op("add"), q_dot, mul);
m2.add_instruction(migraphx::make_op("identity"), add); m2.add_instruction(migraphx::make_op("identity"), add);
} }
EXPECT(m1 == m2); EXPECT(m1 == m2);
......
#include "migraphx/instruction.hpp"
#include <migraphx/common.hpp>
#include <basic_ops.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp>
TEST_CASE(dot_apply_alpha_beta_half)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto y = m1.add_parameter("y", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto z = m1.add_parameter("z", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto dot_res = migraphx::insert_dot_apply_alpha_beta(m1, m1.end(), {x, y, z}, 3, 2);
m1.add_instruction(migraphx::make_op("identity"), dot_res);
}
migraphx::module m2;
{
auto ht = migraphx::shape::half_type;
auto ft = migraphx::shape::float_type;
auto x = m2.add_parameter("x", migraphx::shape{ht, {2, 2}});
auto y = m2.add_parameter("y", migraphx::shape{ht, {2, 2}});
auto z = m2.add_parameter("z", migraphx::shape{ht, {2, 2}});
auto alpha_literal = m2.add_literal(3.0f);
auto alpha_broadcast = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}),
alpha_literal);
auto x_float = m2.add_instruction(migraphx::make_op("convert", {{"target_type", ft}}), x);
auto x_alpha_float = m2.add_instruction(migraphx::make_op("mul"), alpha_broadcast, x_float);
auto x_half =
m2.add_instruction(migraphx::make_op("convert", {{"target_type", ht}}), x_alpha_float);
auto dot_res = m2.add_instruction(migraphx::make_op("dot"), x_half, y);
auto beta_literal = m2.add_literal(2.0f);
auto z_float = m2.add_instruction(migraphx::make_op("convert", {{"target_type", ft}}), z);
auto beta_broadcast = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", z->get_shape().lens()}}),
beta_literal);
auto z_beta_float = m2.add_instruction(migraphx::make_op("mul"), z_float, beta_broadcast);
auto z_beta_half =
m2.add_instruction(migraphx::make_op("convert", {{"target_type", ht}}), z_beta_float);
auto z_add = m2.add_instruction(migraphx::make_op("add"), dot_res, z_beta_half);
m2.add_instruction(migraphx::make_op("identity"), z_add);
}
EXPECT(m1 == m2);
}
TEST_CASE(dot_apply_alpha_beta_double)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto y = m1.add_parameter("y", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto z = m1.add_parameter("z", migraphx::shape{migraphx::shape::double_type, {2, 1}});
auto dot_res = migraphx::add_dot_apply_alpha_beta(m1, {x, y, z}, 3, 2);
m1.add_instruction(migraphx::make_op("identity"), dot_res);
}
migraphx::module m2;
{
auto dt = migraphx::shape::double_type;
auto x = m2.add_parameter("x", migraphx::shape{dt, {2, 2}});
auto y = m2.add_parameter("y", migraphx::shape{dt, {2, 2}});
auto z = m2.add_parameter("z", migraphx::shape{dt, {2, 1}});
auto alpha_literal = m2.add_literal(3.0f);
auto alpha_broadcast = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}),
alpha_literal);
auto alpha_double = m2.add_instruction(migraphx::make_op("convert", {{"target_type", dt}}),
alpha_broadcast);
auto x_alpha_double = m2.add_instruction(migraphx::make_op("mul"), alpha_double, x);
auto dot_res = m2.add_instruction(migraphx::make_op("dot"), x_alpha_double, y);
auto z_broadcast =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2}}}), z);
auto beta_literal = m2.add_literal(2.0f);
auto beta_broadcast = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", z_broadcast->get_shape().lens()}}),
beta_literal);
auto beta_double =
m2.add_instruction(migraphx::make_op("convert", {{"target_type", dt}}), beta_broadcast);
auto z_beta_double = m2.add_instruction(migraphx::make_op("mul"), z_broadcast, beta_double);
auto z_add = m2.add_instruction(migraphx::make_op("add"), dot_res, z_beta_double);
m2.add_instruction(migraphx::make_op("identity"), z_add);
}
EXPECT(m1 == m2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -1295,9 +1295,7 @@ TEST_CASE(gemm_test) ...@@ -1295,9 +1295,7 @@ TEST_CASE(gemm_test)
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0}); auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), t_a); t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), t_a);
auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1); auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto dot = add_dot_apply_alpha_beta(*mm, {t_a, t1}, 1.0f, 0.0f);
auto dot =
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), t_a, t1);
auto b_l = mm->add_literal(beta); auto b_l = mm->add_literal(beta);
auto l2_b = auto l2_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {7, 11}}}), l2); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {7, 11}}}), l2);
...@@ -1322,9 +1320,7 @@ TEST_CASE(gemm_ex_test) ...@@ -1322,9 +1320,7 @@ TEST_CASE(gemm_ex_test)
auto a_l = mm->add_literal(alpha); auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0}); auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), t_a); t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), t_a);
auto dot = add_dot_apply_alpha_beta(*mm, {t_a, l1}, 1.0f, 0.0f);
auto dot =
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), t_a, l1);
auto b_l = mm->add_literal(beta); auto b_l = mm->add_literal(beta);
auto b_b = mm->add_instruction( auto b_b = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", l2->get_shape().lens()}}), b_l); migraphx::make_op("multibroadcast", {{"out_lens", l2->get_shape().lens()}}), b_l);
...@@ -1348,9 +1344,7 @@ TEST_CASE(gemm_ex_brcst_test) ...@@ -1348,9 +1344,7 @@ TEST_CASE(gemm_ex_brcst_test)
auto a_l = mm->add_literal(alpha); auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0}); auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), t_a); t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), t_a);
auto dot = add_dot_apply_alpha_beta(*mm, {t_a, l1}, 1.0f, 0.0f);
auto dot =
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), t_a, l1);
auto b_l = mm->add_literal(beta); auto b_l = mm->add_literal(beta);
auto l2_b = auto l2_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", out_lens}}), l2); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", out_lens}}), l2);
...@@ -1378,8 +1372,7 @@ TEST_CASE(gemm_half_test) ...@@ -1378,8 +1372,7 @@ TEST_CASE(gemm_half_test)
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), t_a); migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), t_a);
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), t_a); t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), t_a);
std::vector<std::size_t> lens = {1, 1, 6, 7}; std::vector<std::size_t> lens = {1, 1, 6, 7};
auto dot = auto dot = add_dot_apply_alpha_beta(*mm, {t_a, l1}, 1.0f, 0.0f);
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), t_a, l1);
l2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), l2); l2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), l2);
l2 = mm->add_instruction( l2 = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), l2); migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), l2);
...@@ -1810,8 +1803,7 @@ TEST_CASE(initializer_not_an_input) ...@@ -1810,8 +1803,7 @@ TEST_CASE(initializer_not_an_input)
std::vector<float> w = {1, 2, 3, 4, 5, 6, 7, 8}; std::vector<float> w = {1, 2, 3, 4, 5, 6, 7, 8};
auto l1 = mm->add_literal(migraphx::literal({migraphx::shape::float_type, {2, 4}}, w)); auto l1 = mm->add_literal(migraphx::literal({migraphx::shape::float_type, {2, 4}}, w));
auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {5, 2}}); auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {5, 2}});
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), l0, l1); add_dot_apply_alpha_beta(*mm, {l0, l1}, 1.0f, 0.0f);
auto prog = optimize_onnx("initializer_not_an_input.onnx"); auto prog = optimize_onnx("initializer_not_an_input.onnx");
EXPECT(p == prog); EXPECT(p == prog);
...@@ -2112,8 +2104,7 @@ TEST_CASE(matmul_bmbm_test) ...@@ -2112,8 +2104,7 @@ TEST_CASE(matmul_bmbm_test)
migraphx::make_op("multibroadcast", {{"out_lens", {5, 2, 3, 6, 7}}}), l0); migraphx::make_op("multibroadcast", {{"out_lens", {5, 2, 3, 6, 7}}}), l0);
auto bl1 = mm->add_instruction( auto bl1 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {5, 2, 3, 7, 8}}}), l1); migraphx::make_op("multibroadcast", {{"out_lens", {5, 2, 3, 7, 8}}}), l1);
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), bl0, bl1); add_dot_apply_alpha_beta(*mm, {bl0, bl1}, 1.0f, 0.0f);
auto prog = optimize_onnx("matmul_bmbm_test.onnx"); auto prog = optimize_onnx("matmul_bmbm_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
...@@ -2128,8 +2119,7 @@ TEST_CASE(matmul_bmv_test) ...@@ -2128,8 +2119,7 @@ TEST_CASE(matmul_bmv_test)
auto sl1 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l1); auto sl1 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l1);
auto bsl1 = auto bsl1 =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 7, 1}}}), sl1); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 7, 1}}}), sl1);
auto res = auto res = add_dot_apply_alpha_beta(*mm, {l0, bsl1}, 1.0f, 0.0f);
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), l0, bsl1);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), res); mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), res);
auto prog = optimize_onnx("matmul_bmv_test.onnx"); auto prog = optimize_onnx("matmul_bmv_test.onnx");
...@@ -2144,8 +2134,7 @@ TEST_CASE(matmul_mv_test) ...@@ -2144,8 +2134,7 @@ TEST_CASE(matmul_mv_test)
auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {6, 7}}); auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {6, 7}});
auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {7}}); auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {7}});
auto sl1 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l1); auto sl1 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l1);
auto res = auto res = add_dot_apply_alpha_beta(*mm, {l0, sl1}, 1.0f, 0.0f);
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), l0, sl1);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), res); mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), res);
auto prog = optimize_onnx("matmul_mv_test.onnx"); auto prog = optimize_onnx("matmul_mv_test.onnx");
...@@ -2162,8 +2151,7 @@ TEST_CASE(matmul_vbm_test) ...@@ -2162,8 +2151,7 @@ TEST_CASE(matmul_vbm_test)
auto sl0 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l0); auto sl0 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l0);
auto bsl0 = auto bsl0 =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5, 1, 7}}}), sl0); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5, 1, 7}}}), sl0);
auto res = auto res = add_dot_apply_alpha_beta(*mm, {bsl0, l1}, 1.0f, 0.0f);
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), bsl0, l1);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), res); mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), res);
auto prog = optimize_onnx("matmul_vbm_test.onnx"); auto prog = optimize_onnx("matmul_vbm_test.onnx");
...@@ -2178,8 +2166,7 @@ TEST_CASE(matmul_vm_test) ...@@ -2178,8 +2166,7 @@ TEST_CASE(matmul_vm_test)
auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {7}}); auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {7}});
auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {7, 8}}); auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {7, 8}});
auto sl0 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l0); auto sl0 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l0);
auto res = auto res = add_dot_apply_alpha_beta(*mm, {sl0, l1}, 1.0f, 0.0f);
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), sl0, l1);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), res); mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), res);
auto prog = optimize_onnx("matmul_vm_test.onnx"); auto prog = optimize_onnx("matmul_vm_test.onnx");
...@@ -2195,8 +2182,7 @@ TEST_CASE(matmul_vv_test) ...@@ -2195,8 +2182,7 @@ TEST_CASE(matmul_vv_test)
auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {7}}); auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {7}});
auto sl0 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l0); auto sl0 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l0);
auto sl1 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l1); auto sl1 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l1);
auto res = auto res = add_dot_apply_alpha_beta(*mm, {sl0, sl1}, 1.0f, 0.0f);
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), sl0, sl1);
auto sr0 = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), res); auto sr0 = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), res);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), sr0); mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), sr0);
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/ref/target.hpp> #include <migraphx/ref/target.hpp>
#include <sstream> #include <sstream>
#include <migraphx/common.hpp>
#include "test.hpp" #include "test.hpp"
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
...@@ -160,9 +161,7 @@ TEST_CASE(program_copy) ...@@ -160,9 +161,7 @@ TEST_CASE(program_copy)
auto para1 = mm1->add_parameter("m1", s1); auto para1 = mm1->add_parameter("m1", s1);
auto para2 = mm1->add_parameter("m2", s2); auto para2 = mm1->add_parameter("m2", s2);
auto para3 = mm1->add_parameter("m3", s3); auto para3 = mm1->add_parameter("m3", s3);
mm1->add_instruction( migraphx::add_dot_apply_alpha_beta(*mm1, {para1, para2, para3}, 0.31f, 0.28f);
migraphx::make_op("dot", {{"alpha", 0.31f}, {"beta", 0.28f}}), para1, para2, para3);
migraphx::program p2{}; migraphx::program p2{};
p2 = p1; p2 = p1;
EXPECT(p2 == p1); EXPECT(p2 == p1);
......
...@@ -111,4 +111,33 @@ TEST_CASE(const_scalar) ...@@ -111,4 +111,33 @@ TEST_CASE(const_scalar)
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
TEST_CASE(const_dot)
{
migraphx::module m1;
{
migraphx::shape s{migraphx::shape::float_type, {2, 2}};
std::vector<float> vec = {1.0f, 2.0f, 1.0f, 2.0f};
auto l = m1.add_literal(migraphx::literal(s, vec));
auto dl = m1.add_instruction(migraphx::make_op("dot"), l, l);
auto x = m1.add_parameter("x", s);
auto r = m1.add_instruction(migraphx::make_op("add"), dl, x);
m1.add_return({r});
}
run_pass(m1);
migraphx::module m2;
{
migraphx::shape s{migraphx::shape::float_type, {2, 2}};
std::vector<float> vec = {3.0f, 6.0f, 3.0f, 6.0f};
auto x = m2.add_parameter("x", s);
auto l = m2.add_literal(migraphx::literal(s, vec));
auto r = m2.add_instruction(migraphx::make_op("add"), l, x);
m2.add_return({r});
}
EXPECT(m1 == m2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/ref/target.hpp> #include <migraphx/ref/target.hpp>
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
#include <migraphx/common.hpp>
#include <migraphx/quantization.hpp> #include <migraphx/quantization.hpp>
#include <migraphx/quantize_int8.hpp> #include <migraphx/quantize_int8.hpp>
#include <migraphx/quantize_fp16.hpp> #include <migraphx/quantize_fp16.hpp>
...@@ -556,10 +557,8 @@ TEST_CASE(dot_float) ...@@ -556,10 +557,8 @@ TEST_CASE(dot_float)
migraphx::shape sc{migraphx::shape::float_type, {2, 8}}; migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
auto pa = mm->add_parameter("a", sa); auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb); auto pb = mm->add_parameter("b", sb);
auto pc = mm->add_parameter("c", sc);
auto r = mm->add_instruction( auto r = migraphx::add_dot_apply_alpha_beta(*mm, {pa, pb});
migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), pa, pb, pc);
mm->add_return({r}); mm->add_return({r});
return p; return p;
...@@ -573,7 +572,6 @@ TEST_CASE(dot_float) ...@@ -573,7 +572,6 @@ TEST_CASE(dot_float)
migraphx::shape sc{migraphx::shape::float_type, {2, 8}}; migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
auto pa = mm->add_parameter("a", sa); auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb); auto pb = mm->add_parameter("b", sb);
auto pc = mm->add_parameter("c", sc);
auto zp_a = mm->add_literal(static_cast<int8_t>(0)); auto zp_a = mm->add_literal(static_cast<int8_t>(0));
auto scale_a = mm->add_literal(10.0f); auto scale_a = mm->add_literal(10.0f);
scale_a = mm->add_instruction( scale_a = mm->add_instruction(
...@@ -592,16 +590,7 @@ TEST_CASE(dot_float) ...@@ -592,16 +590,7 @@ TEST_CASE(dot_float)
auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b); auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b);
auto dqb = mm->add_instruction(migraphx::make_op("dequantizelinear"), qb, scale_b, zp_b); auto dqb = mm->add_instruction(migraphx::make_op("dequantizelinear"), qb, scale_b, zp_b);
auto zp_c = mm->add_literal(static_cast<int8_t>(100)); auto r = migraphx::add_dot_apply_alpha_beta(*mm, {dqa, dqb});
auto scale_c = mm->add_literal(10.0f);
scale_c = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sc.lens()}}), scale_c);
zp_c = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sc.lens()}}),
zp_c);
auto qc = mm->add_instruction(migraphx::make_op("quantizelinear"), pc, scale_c, zp_c);
auto dqc = mm->add_instruction(migraphx::make_op("dequantizelinear"), qc, scale_c, zp_c);
auto r = mm->add_instruction(
migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), dqa, dqb, dqc);
mm->add_return({r}); mm->add_return({r});
return p; return p;
...@@ -615,7 +604,6 @@ TEST_CASE(dot_float) ...@@ -615,7 +604,6 @@ TEST_CASE(dot_float)
migraphx::shape sc{migraphx::shape::float_type, {2, 8}}; migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
auto pa = mm->add_parameter("a", sa); auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb); auto pb = mm->add_parameter("b", sb);
mm->add_parameter("c", sc);
auto zp = mm->add_literal(static_cast<int8_t>(0)); auto zp = mm->add_literal(static_cast<int8_t>(0));
auto scale = mm->add_literal(10.0f); auto scale = mm->add_literal(10.0f);
auto scale_a = mm->add_instruction( auto scale_a = mm->add_instruction(
...@@ -649,6 +637,7 @@ TEST_CASE(dot_float) ...@@ -649,6 +637,7 @@ TEST_CASE(dot_float)
p, p,
{migraphx::quantize_int8_pass{{"dot"}, quant_params}, migraphx::dead_code_elimination{}}); {migraphx::quantize_int8_pass{{"dot"}, quant_params}, migraphx::dead_code_elimination{}});
auto qp = create_int8_quantized_prog(); auto qp = create_int8_quantized_prog();
EXPECT(p == qp); EXPECT(p == qp);
optimize_prog_int8(p); optimize_prog_int8(p);
...@@ -665,8 +654,7 @@ TEST_CASE(dot_double_2args) ...@@ -665,8 +654,7 @@ TEST_CASE(dot_double_2args)
migraphx::shape sb{migraphx::shape::double_type, {16, 8}}; migraphx::shape sb{migraphx::shape::double_type, {16, 8}};
auto pa = mm->add_parameter("a", sa); auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb); auto pb = mm->add_parameter("b", sb);
auto r = mm->add_instruction( auto r = migraphx::add_dot_apply_alpha_beta(*mm, {pa, pb});
migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), pa, pb);
mm->add_return({r}); mm->add_return({r});
return p; return p;
...@@ -696,8 +684,7 @@ TEST_CASE(dot_double_2args) ...@@ -696,8 +684,7 @@ TEST_CASE(dot_double_2args)
zp_b); zp_b);
auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b); auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b);
auto dqb = mm->add_instruction(migraphx::make_op("dequantizelinear"), qb, scale_b, zp_b); auto dqb = mm->add_instruction(migraphx::make_op("dequantizelinear"), qb, scale_b, zp_b);
auto r = mm->add_instruction( auto r = migraphx::add_dot_apply_alpha_beta(*mm, {dqa, dqb});
migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), dqa, dqb);
mm->add_return({r}); mm->add_return({r});
return p; return p;
}; };
...@@ -753,8 +740,7 @@ TEST_CASE(dot_half_1arg) ...@@ -753,8 +740,7 @@ TEST_CASE(dot_half_1arg)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::half_type, {9, 9}}; migraphx::shape s{migraphx::shape::half_type, {9, 9}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto r = auto r = migraphx::add_dot_apply_alpha_beta(*mm, {x, x});
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), x, x);
mm->add_return({r}); mm->add_return({r});
return p; return p;
...@@ -782,8 +768,7 @@ TEST_CASE(dot_half_1arg) ...@@ -782,8 +768,7 @@ TEST_CASE(dot_half_1arg)
zp_b); zp_b);
auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), x, scale_b, zp_b); auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), x, scale_b, zp_b);
auto dqb = mm->add_instruction(migraphx::make_op("dequantizelinear"), qb, scale_b, zp_b); auto dqb = mm->add_instruction(migraphx::make_op("dequantizelinear"), qb, scale_b, zp_b);
auto r = mm->add_instruction( auto r = migraphx::add_dot_apply_alpha_beta(*mm, {dqa, dqb});
migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), dqa, dqb);
mm->add_return({r}); mm->add_return({r});
return p; return p;
}; };
...@@ -1055,9 +1040,8 @@ TEST_CASE(int8_quantization_dot) ...@@ -1055,9 +1040,8 @@ TEST_CASE(int8_quantization_dot)
auto pa = mm->add_parameter("a", sa); auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb); auto pb = mm->add_parameter("b", sb);
auto pc = mm->add_parameter("c", sc); auto pc = mm->add_parameter("c", sc);
auto r = mm->add_instruction(migraphx::make_op("dot"), pa, pb, pc); auto r = migraphx::add_dot_apply_alpha_beta(*mm, {pa, pb, pc}, 1, 1);
mm->add_return({r}); mm->add_return({r});
return p; return p;
}; };
...@@ -1075,7 +1059,7 @@ TEST_CASE(int8_quantization_dot) ...@@ -1075,7 +1059,7 @@ TEST_CASE(int8_quantization_dot)
std::vector<float> no_quant_result; std::vector<float> no_quant_result;
run_prog(p, ref_t, m, no_quant_result); run_prog(p, ref_t, m, no_quant_result);
EXPECT(migraphx::verify_range(quant_result, no_quant_result)); EXPECT(migraphx::verify_range(quant_result, no_quant_result, 30000));
} }
} }
...@@ -1142,8 +1126,7 @@ TEST_CASE(int8_subgraph) ...@@ -1142,8 +1126,7 @@ TEST_CASE(int8_subgraph)
auto w = mm->add_parameter("w", sw); auto w = mm->add_parameter("w", sw);
auto* then_mod = p.create_module("If_6_if"); auto* then_mod = p.create_module("If_6_if");
auto out1 = then_mod->add_instruction( auto out1 = migraphx::add_dot_apply_alpha_beta(*then_mod, {a, b});
migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), a, b);
then_mod->add_return({out1}); then_mod->add_return({out1});
auto* else_mod = p.create_module("If_6_else"); auto* else_mod = p.create_module("If_6_else");
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/onnx.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/common.hpp>
#include "test.hpp" #include "test.hpp"
#include <migraphx/half.hpp> #include <migraphx/half.hpp>
...@@ -211,7 +212,8 @@ TEST_CASE(gemm_mutli_dim_2_beta0) ...@@ -211,7 +212,8 @@ TEST_CASE(gemm_mutli_dim_2_beta0)
auto l3 = mm->add_literal(migraphx::literal{m3_shape, m3}); auto l3 = mm->add_literal(migraphx::literal{m3_shape, m3});
float alpha = 1.0f; float alpha = 1.0f;
float beta = 0.0f; float beta = 0.0f;
mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3); migraphx::add_dot_apply_alpha_beta(
*mm, std::vector<migraphx::instruction_ref>{l1, l2, l3}, alpha, beta);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
...@@ -274,7 +276,8 @@ TEST_CASE(gemm_beta_0) ...@@ -274,7 +276,8 @@ TEST_CASE(gemm_beta_0)
float alpha = 1.0f; float alpha = 1.0f;
float beta = 0.0f; float beta = 0.0f;
mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3); migraphx::add_dot_apply_alpha_beta(
*mm, std::vector<migraphx::instruction_ref>{l1, l2, l3}, alpha, beta);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
...@@ -364,8 +367,8 @@ TEST_CASE(gemm_mutli_dim1_2_3) ...@@ -364,8 +367,8 @@ TEST_CASE(gemm_mutli_dim1_2_3)
auto l3 = mm->add_literal(migraphx::literal{m3_shape, m3}); auto l3 = mm->add_literal(migraphx::literal{m3_shape, m3});
float alpha = 0.35; float alpha = 0.35;
float beta = 0.41; float beta = 0.41;
auto m12_alpha = auto m12_alpha = migraphx::add_dot_apply_alpha_beta(
mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2); *mm, std::vector<migraphx::instruction_ref>{l1, l2}, alpha);
auto l_beta = mm->add_literal(beta); auto l_beta = mm->add_literal(beta);
auto b_beta = mm->add_instruction( auto b_beta = mm->add_instruction(
migraphx::make_op("scalar", {{"scalar_bcst_dims", m12_alpha->get_shape().lens()}}), l_beta); migraphx::make_op("scalar", {{"scalar_bcst_dims", m12_alpha->get_shape().lens()}}), l_beta);
...@@ -418,7 +421,8 @@ TEST_CASE(gemm_mutli_3args) ...@@ -418,7 +421,8 @@ TEST_CASE(gemm_mutli_3args)
auto l3 = mm->add_literal(migraphx::literal{m3_shape, m3}); auto l3 = mm->add_literal(migraphx::literal{m3_shape, m3});
float alpha = 0.35; float alpha = 0.35;
float beta = 0.41; float beta = 0.41;
mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3); migraphx::add_dot_apply_alpha_beta(
*mm, std::vector<migraphx::instruction_ref>{l1, l2, l3}, alpha, beta);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
...@@ -479,7 +483,7 @@ TEST_CASE(gemm_3args) ...@@ -479,7 +483,7 @@ TEST_CASE(gemm_3args)
auto bl = mm->add_literal(migraphx::literal{b_shape, b}); auto bl = mm->add_literal(migraphx::literal{b_shape, b});
migraphx::shape c_shape{migraphx::shape::float_type, {3, 3}}; migraphx::shape c_shape{migraphx::shape::float_type, {3, 3}};
auto cl = mm->add_literal(migraphx::literal{c_shape, c}); auto cl = mm->add_literal(migraphx::literal{c_shape, c});
mm->add_instruction(migraphx::make_op("dot"), al, bl, cl); migraphx::add_dot_apply_alpha_beta(*mm, {al, bl, cl}, 1, 1);
std::vector<float> gold = {-1.60947, std::vector<float> gold = {-1.60947,
0.703083, 0.703083,
-5.46156, -5.46156,
...@@ -561,7 +565,8 @@ TEST_CASE(matmul_vv_inner_product) ...@@ -561,7 +565,8 @@ TEST_CASE(matmul_vv_inner_product)
auto ual = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), al); auto ual = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), al);
auto ubl = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), bl); auto ubl = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), bl);
float alpha = 0.32f; float alpha = 0.32f;
mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}}), ual, ubl); migraphx::add_dot_apply_alpha_beta(
*mm, std::vector<migraphx::instruction_ref>{ual, ubl}, alpha);
std::vector<float> gold = {-0.4590752}; std::vector<float> gold = {-0.4590752};
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
...@@ -634,7 +639,8 @@ TEST_CASE(matmul_vm) ...@@ -634,7 +639,8 @@ TEST_CASE(matmul_vm)
migraphx::shape b_shape{migraphx::shape::float_type, {8, 5}}; migraphx::shape b_shape{migraphx::shape::float_type, {8, 5}};
auto bl = mm->add_literal(migraphx::literal{b_shape, b}); auto bl = mm->add_literal(migraphx::literal{b_shape, b});
float alpha = 0.5f; float alpha = 0.5f;
mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}}), ual, bl); migraphx::add_dot_apply_alpha_beta(
*mm, std::vector<migraphx::instruction_ref>{ual, bl}, alpha);
std::vector<float> gold = {-1.89056, -1.70003, -1.0986, -1.65724, -1.90163}; std::vector<float> gold = {-1.89056, -1.70003, -1.0986, -1.65724, -1.90163};
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
...@@ -718,7 +724,8 @@ TEST_CASE(matmul_vm) ...@@ -718,7 +724,8 @@ TEST_CASE(matmul_vm)
migraphx::make_op("multibroadcast", {{"out_lens", {3, 1, 6}}}), ual); migraphx::make_op("multibroadcast", {{"out_lens", {3, 1, 6}}}), ual);
migraphx::shape b_shape{migraphx::shape::float_type, {3, 6, 4}}; migraphx::shape b_shape{migraphx::shape::float_type, {3, 6, 4}};
auto bl = mm->add_literal(migraphx::literal{b_shape, b}); auto bl = mm->add_literal(migraphx::literal{b_shape, b});
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 0.21f}}), bual, bl); migraphx::add_dot_apply_alpha_beta(
*mm, std::vector<migraphx::instruction_ref>{bual, bl}, 0.21f);
std::vector<float> gold = {0.25812, std::vector<float> gold = {0.25812,
-0.247582, -0.247582,
0.480051, 0.480051,
...@@ -805,7 +812,8 @@ TEST_CASE(matmul_mv) ...@@ -805,7 +812,8 @@ TEST_CASE(matmul_mv)
auto bl = mm->add_literal(migraphx::literal{b_shape, b}); auto bl = mm->add_literal(migraphx::literal{b_shape, b});
auto ubl = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), bl); auto ubl = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), bl);
float alpha = 0.3f; float alpha = 0.3f;
mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}}), al, ubl); migraphx::add_dot_apply_alpha_beta(
*mm, std::vector<migraphx::instruction_ref>{al, ubl}, alpha);
std::vector<float> gold = {0.395946, 0.357067, -0.588187}; std::vector<float> gold = {0.395946, 0.357067, -0.588187};
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
......
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
#include <migraphx/ref/target.hpp> #include <migraphx/ref/target.hpp>
#include <migraphx/common.hpp>
bool is_convolution(const migraphx::instruction& ins) { return ins.name() == "convolution"; } bool is_convolution(const migraphx::instruction& ins) { return ins.name() == "convolution"; }
bool is_dot(const migraphx::instruction& ins) { return ins.name() == "dot"; } bool is_dot(const migraphx::instruction& ins) { return ins.name() == "dot"; }
...@@ -131,8 +132,7 @@ TEST_CASE(dot) ...@@ -131,8 +132,7 @@ TEST_CASE(dot)
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero); auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero); auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero); auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
auto dot = auto dot = migraphx::add_dot_apply_alpha_beta(m1, {d1, d2});
m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), d1, d2);
m1.add_return({dot}); m1.add_return({dot});
} }
...@@ -172,8 +172,7 @@ TEST_CASE(dot_non_zero_point) ...@@ -172,8 +172,7 @@ TEST_CASE(dot_non_zero_point)
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero); auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero); auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero); auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
auto dot = auto dot = migraphx::add_dot_apply_alpha_beta(m1, {d1, d2});
m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), d1, d2);
m1.add_return({dot}); m1.add_return({dot});
} }
...@@ -181,9 +180,7 @@ TEST_CASE(dot_non_zero_point) ...@@ -181,9 +180,7 @@ TEST_CASE(dot_non_zero_point)
{ {
auto t1 = m2.add_parameter("t1", sh1); auto t1 = m2.add_parameter("t1", sh1);
auto t2 = m2.add_parameter("t2", sh2); auto t2 = m2.add_parameter("t2", sh2);
auto dot = migraphx::add_dot_apply_alpha_beta(m2, {t1, t2});
auto dot =
m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), t1, t2);
m2.add_return({dot}); m2.add_return({dot});
} }
...@@ -207,8 +204,7 @@ TEST_CASE(dot_uint8) ...@@ -207,8 +204,7 @@ TEST_CASE(dot_uint8)
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero); auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero); auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero); auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
auto dot = auto dot = migraphx::add_dot_apply_alpha_beta(m1, {d1, d2});
m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), d1, d2);
m1.add_return({dot}); m1.add_return({dot});
} }
...@@ -216,9 +212,7 @@ TEST_CASE(dot_uint8) ...@@ -216,9 +212,7 @@ TEST_CASE(dot_uint8)
{ {
auto t1 = m2.add_parameter("t1", sh1); auto t1 = m2.add_parameter("t1", sh1);
auto t2 = m2.add_parameter("t2", sh2); auto t2 = m2.add_parameter("t2", sh2);
auto dot = migraphx::add_dot_apply_alpha_beta(m2, {t1, t2});
auto dot =
m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), t1, t2);
m2.add_return({dot}); m2.add_return({dot});
} }
...@@ -244,8 +238,7 @@ TEST_CASE(dot_add) ...@@ -244,8 +238,7 @@ TEST_CASE(dot_add)
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero); auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero); auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero); auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
auto dot = auto dot = migraphx::add_dot_apply_alpha_beta(m1, {d1, d2});
m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), d1, d2);
auto q3 = add_quantize_op(m1, "quantizelinear", dot, scale, zero); auto q3 = add_quantize_op(m1, "quantizelinear", dot, scale, zero);
auto d3 = add_quantize_op(m1, "dequantizelinear", q3, scale, zero); auto d3 = add_quantize_op(m1, "dequantizelinear", q3, scale, zero);
auto add = m1.add_instruction(migraphx::make_op("add"), d3, ab); auto add = m1.add_instruction(migraphx::make_op("add"), d3, ab);
...@@ -482,8 +475,7 @@ TEST_CASE(conv_pooling_dot) ...@@ -482,8 +475,7 @@ TEST_CASE(conv_pooling_dot)
auto fl = m1.add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), ap); auto fl = m1.add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), ap);
auto q4 = add_quantize_op(m1, "quantizelinear", fl, scale, zero); auto q4 = add_quantize_op(m1, "quantizelinear", fl, scale, zero);
auto d8 = add_quantize_op(m1, "dequantizelinear", q4, scale, zero); auto d8 = add_quantize_op(m1, "dequantizelinear", q4, scale, zero);
auto dot = auto dot = migraphx::add_dot_apply_alpha_beta(m1, {d8, d4});
m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), d8, d4);
auto q5 = add_quantize_op(m1, "quantizelinear", dot, scale, zero); auto q5 = add_quantize_op(m1, "quantizelinear", dot, scale, zero);
auto d9 = add_quantize_op(m1, "dequantizelinear", q5, scale, zero); auto d9 = add_quantize_op(m1, "dequantizelinear", q5, scale, zero);
auto mb1 = auto mb1 =
...@@ -590,8 +582,7 @@ TEST_CASE(mobilenet_snippet) ...@@ -590,8 +582,7 @@ TEST_CASE(mobilenet_snippet)
auto rs = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {1, -1}}}), d7); auto rs = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {1, -1}}}), d7);
auto q4 = add_quantize_op(mm, "quantizelinear", rs, scale, zero); auto q4 = add_quantize_op(mm, "quantizelinear", rs, scale, zero);
auto d8 = add_quantize_op(mm, "dequantizelinear", q4, scale, zero); auto d8 = add_quantize_op(mm, "dequantizelinear", q4, scale, zero);
auto dot = auto dot = migraphx::add_dot_apply_alpha_beta(mm, {d8, d4});
mm.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), d8, d4);
auto q5 = add_quantize_op(mm, "quantizelinear", dot, scale, zero); auto q5 = add_quantize_op(mm, "quantizelinear", dot, scale, zero);
auto d9 = add_quantize_op(mm, "dequantizelinear", q5, scale, zero); auto d9 = add_quantize_op(mm, "dequantizelinear", q5, scale, zero);
auto mb1 = auto mb1 =
...@@ -703,8 +694,7 @@ TEST_CASE(dot_correctness) ...@@ -703,8 +694,7 @@ TEST_CASE(dot_correctness)
auto d1 = add_quantize_op(*m1, "dequantizelinear", q1, scale_a, zero); auto d1 = add_quantize_op(*m1, "dequantizelinear", q1, scale_a, zero);
auto q2 = add_quantize_op(*m1, "quantizelinear", b, scale_b, zero); auto q2 = add_quantize_op(*m1, "quantizelinear", b, scale_b, zero);
auto d2 = add_quantize_op(*m1, "dequantizelinear", q2, scale_b, zero); auto d2 = add_quantize_op(*m1, "dequantizelinear", q2, scale_b, zero);
auto dot = auto dot = migraphx::add_dot_apply_alpha_beta(*m1, {d1, d2});
m1->add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), d1, d2);
m1->add_return({dot}); m1->add_return({dot});
run_pass(*m1); run_pass(*m1);
...@@ -715,8 +705,7 @@ TEST_CASE(dot_correctness) ...@@ -715,8 +705,7 @@ TEST_CASE(dot_correctness)
auto* m2 = p2.get_main_module(); auto* m2 = p2.get_main_module();
auto a = m2->add_parameter("a", sh1); auto a = m2->add_parameter("a", sh1);
auto b = m2->add_parameter("b", sh2); auto b = m2->add_parameter("b", sh2);
auto dot = migraphx::add_dot_apply_alpha_beta(*m2, {a, b});
auto dot = m2->add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), a, b);
m2->add_return({dot}); m2->add_return({dot});
} }
......
#include "migraphx/common.hpp"
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
...@@ -17,8 +18,7 @@ struct gemm_2args_vv : verify_program<gemm_2args_vv> ...@@ -17,8 +18,7 @@ struct gemm_2args_vv : verify_program<gemm_2args_vv>
auto l2 = mm->add_parameter("2", m2_shape); auto l2 = mm->add_parameter("2", m2_shape);
auto ul2 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l2); auto ul2 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l2);
float alpha = 0.23f; float alpha = 0.23f;
auto res = migraphx::add_dot_apply_alpha_beta(*mm, {ul1, ul2}, alpha);
auto res = mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}}), ul1, ul2);
auto sres = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), res); auto sres = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), res);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), sres); mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), sres);
......
#include "migraphx/common.hpp"
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
...@@ -19,9 +20,7 @@ struct gemm_multi_3args : verify_program<gemm_multi_3args> ...@@ -19,9 +20,7 @@ struct gemm_multi_3args : verify_program<gemm_multi_3args>
auto l3 = mm->add_parameter("3", m3_shape); auto l3 = mm->add_parameter("3", m3_shape);
float alpha = 0.35; float alpha = 0.35;
float beta = 0.41; float beta = 0.41;
mm->add_instruction( migraphx::add_dot_apply_alpha_beta(*mm, {l1, l2, l3}, alpha, beta);
migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3);
return p; return p;
} }
}; };
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/common.hpp>
struct gemm_multi_3args_alpha0 : verify_program<gemm_multi_3args_alpha0> struct gemm_multi_3args_alpha0 : verify_program<gemm_multi_3args_alpha0>
{ {
migraphx::program create_program() const migraphx::program create_program() const
...@@ -19,9 +19,7 @@ struct gemm_multi_3args_alpha0 : verify_program<gemm_multi_3args_alpha0> ...@@ -19,9 +19,7 @@ struct gemm_multi_3args_alpha0 : verify_program<gemm_multi_3args_alpha0>
float alpha = 0.0f; float alpha = 0.0f;
float beta = 1.0f; float beta = 1.0f;
mm->add_instruction( migraphx::add_dot_apply_alpha_beta(*mm, {l1, l2, l3}, alpha, beta);
migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3);
return p; return p;
} }
}; };
#include "migraphx/common.hpp"
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
...@@ -19,9 +20,7 @@ struct gemm_multi_3args_beta0 : verify_program<gemm_multi_3args_beta0> ...@@ -19,9 +20,7 @@ struct gemm_multi_3args_beta0 : verify_program<gemm_multi_3args_beta0>
float alpha = 1.0f; float alpha = 1.0f;
float beta = 0.0f; float beta = 0.0f;
mm->add_instruction( migraphx::add_dot_apply_alpha_beta(*mm, {l1, l2, l3}, alpha, beta);
migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3);
return p; return p;
} }
}; };
#include "migraphx/common.hpp"
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
...@@ -19,9 +20,7 @@ struct gemm_multi_3args_c25 : verify_program<gemm_multi_3args_c25> ...@@ -19,9 +20,7 @@ struct gemm_multi_3args_c25 : verify_program<gemm_multi_3args_c25>
auto l3 = mm->add_parameter("3", m3_shape); auto l3 = mm->add_parameter("3", m3_shape);
float alpha = 0.35; float alpha = 0.35;
float beta = 0.41; float beta = 0.41;
mm->add_instruction( migraphx::add_dot_apply_alpha_beta(*mm, {l1, l2, l3}, alpha, beta);
migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3);
return p; return p;
} }
}; };
#include "migraphx/common.hpp"
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
...@@ -19,8 +20,7 @@ struct gemm_multi_transpose : verify_program<gemm_multi_transpose> ...@@ -19,8 +20,7 @@ struct gemm_multi_transpose : verify_program<gemm_multi_transpose>
float alpha = 1.0f; float alpha = 1.0f;
float beta = 1.0f; float beta = 1.0f;
mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, tl2); migraphx::add_dot_apply_alpha_beta(*mm, {l1, tl2}, alpha, beta);
return p; return p;
} }
}; };
#include "migraphx/common.hpp"
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
...@@ -12,13 +13,12 @@ struct test_gemm_copy : verify_program<test_gemm_copy> ...@@ -12,13 +13,12 @@ struct test_gemm_copy : verify_program<test_gemm_copy>
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape sa{migraphx::shape::float_type, {2, 16}}; migraphx::shape sa{migraphx::shape::float_type, {2, 16}};
migraphx::shape sb{migraphx::shape::float_type, {16, 8}}; migraphx::shape sb{migraphx::shape::float_type, {16, 8}};
migraphx::shape sc{migraphx::shape::float_type, {2, 8}}; migraphx::shape sc{migraphx::shape::float_type, {1, 8}};
auto pa = mm->add_parameter("a", sa); auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb); auto pb = mm->add_parameter("b", sb);
auto pc = mm->add_parameter("c", sc); auto pc = mm->add_parameter("c", sc);
auto dr = mm->add_instruction(migraphx::make_op("dot"), pa, pb, pc); auto dr = migraphx::add_dot_apply_alpha_beta(*mm, {pa, pb, pc}, 1, 1);
mm->add_instruction(migraphx::make_op("add"), dr, dr); mm->add_instruction(migraphx::make_op("add"), dr, dr);
return p; return p;
} }
}; };
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment