Unverified Commit 985f58b0 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Revert "Remove alpha and beta attributes from dot operator (#945)" (#957)

This reverts commit 9e43cb8b.
parent 9e43cb8b
......@@ -7,55 +7,140 @@
void run_pass(migraphx::module& m) { migraphx::run_passes(m, {migraphx::decompose{}}); }
TEST_CASE(quant_dot_add)
TEST_CASE(dot_add)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", migraphx::shape{migraphx::shape::int8_type, {2, 2}});
auto y = m1.add_parameter("y", migraphx::shape{migraphx::shape::int8_type, {2, 2}});
auto z = m1.add_parameter("z", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
auto q_dot = m1.add_instruction(migraphx::make_op("quant_dot"), x, y, z);
m1.add_instruction(migraphx::make_op("identity"), q_dot);
auto x = m1.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto y = m1.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto z = m1.add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto dot = m1.add_instruction(migraphx::make_op("dot"), x, y, z);
m1.add_instruction(migraphx::make_op("identity"), dot);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", migraphx::shape{migraphx::shape::int8_type, {2, 2}});
auto y = m2.add_parameter("y", migraphx::shape{migraphx::shape::int8_type, {2, 2}});
auto z = m2.add_parameter("z", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
auto q_dot =
m2.add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), x, y);
auto add = m2.add_instruction(migraphx::make_op("add"), q_dot, z);
auto x = m2.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto y = m2.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto z = m2.add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto dot = m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y);
auto add = m2.add_instruction(migraphx::make_op("add"), dot, z);
m2.add_instruction(migraphx::make_op("identity"), add);
}
EXPECT(m1 == m2);
}
TEST_CASE(quant_dot_add_beta)
TEST_CASE(dot_add_beta_float)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", migraphx::shape{migraphx::shape::int8_type, {2, 2}});
auto y = m1.add_parameter("y", migraphx::shape{migraphx::shape::int8_type, {2, 2}});
auto z = m1.add_parameter("z", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
auto q_dot = m1.add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 1.0}, {"beta", 2}}), x, y, z);
m1.add_instruction(migraphx::make_op("identity"), q_dot);
auto x = m1.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto y = m1.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto z = m1.add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto dot =
m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1.0}, {"beta", 0.5}}), x, y, z);
m1.add_instruction(migraphx::make_op("identity"), dot);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", migraphx::shape{migraphx::shape::int8_type, {2, 2}});
auto y = m2.add_parameter("y", migraphx::shape{migraphx::shape::int8_type, {2, 2}});
auto z = m2.add_parameter("z", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
auto q_dot =
m2.add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), x, y);
auto x = m2.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto y = m2.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto z = m2.add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto dot = m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y);
auto beta =
m2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {2}});
m2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {0.5}});
auto beta_broadcast =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2}}}), beta);
auto mul = m2.add_instruction(migraphx::make_op("mul"), z, beta_broadcast);
auto add = m2.add_instruction(migraphx::make_op("add"), q_dot, mul);
auto add = m2.add_instruction(migraphx::make_op("add"), dot, mul);
m2.add_instruction(migraphx::make_op("identity"), add);
}
EXPECT(m1 == m2);
}
TEST_CASE(dot_add_beta_half)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto y = m1.add_parameter("y", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto z = m1.add_parameter("z", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto dot =
m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1.0}, {"beta", 0.5}}), x, y, z);
m1.add_instruction(migraphx::make_op("identity"), dot);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto y = m2.add_parameter("y", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto z = m2.add_parameter("z", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto dot = m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y);
auto beta =
m2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::half_type}, {0.5}});
auto beta_broadcast =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2}}}), beta);
auto mul = m2.add_instruction(migraphx::make_op("mul"), z, beta_broadcast);
auto add = m2.add_instruction(migraphx::make_op("add"), dot, mul);
m2.add_instruction(migraphx::make_op("identity"), add);
}
EXPECT(m1 == m2);
}
TEST_CASE(dot_add_beta_double)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto y = m1.add_parameter("y", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto z = m1.add_parameter("z", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto dot =
m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1.0}, {"beta", 0.5}}), x, y, z);
m1.add_instruction(migraphx::make_op("identity"), dot);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto y = m2.add_parameter("y", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto z = m2.add_parameter("z", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto dot = m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y);
auto beta =
m2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::double_type}, {0.5}});
auto beta_broadcast =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2}}}), beta);
auto mul = m2.add_instruction(migraphx::make_op("mul"), z, beta_broadcast);
auto add = m2.add_instruction(migraphx::make_op("add"), dot, mul);
m2.add_instruction(migraphx::make_op("identity"), add);
}
EXPECT(m1 == m2);
}
TEST_CASE(dot_add_beta_int)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
auto y = m1.add_parameter("y", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
auto z = m1.add_parameter("z", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
auto dot =
m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1.0}, {"beta", 0.5}}), x, y, z);
m1.add_instruction(migraphx::make_op("identity"), dot);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
auto y = m2.add_parameter("y", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
auto z = m2.add_parameter("z", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
auto dot = m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y);
auto beta =
m2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {0.5}});
auto beta_broadcast =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2}}}), beta);
auto mul = m2.add_instruction(migraphx::make_op("mul"), z, beta_broadcast);
auto add = m2.add_instruction(migraphx::make_op("add"), dot, mul);
m2.add_instruction(migraphx::make_op("identity"), add);
}
EXPECT(m1 == m2);
......
#include "migraphx/instruction.hpp"
#include <migraphx/common.hpp>
#include <basic_ops.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp>
TEST_CASE(dot_apply_alpha_beta_half)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto y = m1.add_parameter("y", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto z = m1.add_parameter("z", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto dot_res = migraphx::insert_dot_apply_alpha_beta(m1, m1.end(), {x, y, z}, 3, 2);
m1.add_instruction(migraphx::make_op("identity"), dot_res);
}
migraphx::module m2;
{
auto ht = migraphx::shape::half_type;
auto ft = migraphx::shape::float_type;
auto x = m2.add_parameter("x", migraphx::shape{ht, {2, 2}});
auto y = m2.add_parameter("y", migraphx::shape{ht, {2, 2}});
auto z = m2.add_parameter("z", migraphx::shape{ht, {2, 2}});
auto alpha_literal = m2.add_literal(3.0f);
auto alpha_broadcast = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}),
alpha_literal);
auto x_float = m2.add_instruction(migraphx::make_op("convert", {{"target_type", ft}}), x);
auto x_alpha_float = m2.add_instruction(migraphx::make_op("mul"), alpha_broadcast, x_float);
auto x_half =
m2.add_instruction(migraphx::make_op("convert", {{"target_type", ht}}), x_alpha_float);
auto dot_res = m2.add_instruction(migraphx::make_op("dot"), x_half, y);
auto beta_literal = m2.add_literal(2.0f);
auto z_float = m2.add_instruction(migraphx::make_op("convert", {{"target_type", ft}}), z);
auto beta_broadcast = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", z->get_shape().lens()}}),
beta_literal);
auto z_beta_float = m2.add_instruction(migraphx::make_op("mul"), z_float, beta_broadcast);
auto z_beta_half =
m2.add_instruction(migraphx::make_op("convert", {{"target_type", ht}}), z_beta_float);
auto z_add = m2.add_instruction(migraphx::make_op("add"), dot_res, z_beta_half);
m2.add_instruction(migraphx::make_op("identity"), z_add);
}
EXPECT(m1 == m2);
}
TEST_CASE(dot_apply_alpha_beta_double)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto y = m1.add_parameter("y", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto z = m1.add_parameter("z", migraphx::shape{migraphx::shape::double_type, {2, 1}});
auto dot_res = migraphx::add_dot_apply_alpha_beta(m1, {x, y, z}, 3, 2);
m1.add_instruction(migraphx::make_op("identity"), dot_res);
}
migraphx::module m2;
{
auto dt = migraphx::shape::double_type;
auto x = m2.add_parameter("x", migraphx::shape{dt, {2, 2}});
auto y = m2.add_parameter("y", migraphx::shape{dt, {2, 2}});
auto z = m2.add_parameter("z", migraphx::shape{dt, {2, 1}});
auto alpha_literal = m2.add_literal(3.0f);
auto alpha_broadcast = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}),
alpha_literal);
auto alpha_double = m2.add_instruction(migraphx::make_op("convert", {{"target_type", dt}}),
alpha_broadcast);
auto x_alpha_double = m2.add_instruction(migraphx::make_op("mul"), alpha_double, x);
auto dot_res = m2.add_instruction(migraphx::make_op("dot"), x_alpha_double, y);
auto z_broadcast =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2}}}), z);
auto beta_literal = m2.add_literal(2.0f);
auto beta_broadcast = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", z_broadcast->get_shape().lens()}}),
beta_literal);
auto beta_double =
m2.add_instruction(migraphx::make_op("convert", {{"target_type", dt}}), beta_broadcast);
auto z_beta_double = m2.add_instruction(migraphx::make_op("mul"), z_broadcast, beta_double);
auto z_add = m2.add_instruction(migraphx::make_op("add"), dot_res, z_beta_double);
m2.add_instruction(migraphx::make_op("identity"), z_add);
}
EXPECT(m1 == m2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -1293,9 +1293,11 @@ TEST_CASE(gemm_test)
auto beta = 2.0f;
auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), t_a);
auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto dot = add_dot_apply_alpha_beta(*mm, {t_a, t1}, 1.0f, 0.0f);
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), t_a);
auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto dot =
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), t_a, t1);
auto b_l = mm->add_literal(beta);
auto l2_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {7, 11}}}), l2);
......@@ -1320,7 +1322,9 @@ TEST_CASE(gemm_ex_test)
auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), t_a);
auto dot = add_dot_apply_alpha_beta(*mm, {t_a, l1}, 1.0f, 0.0f);
auto dot =
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), t_a, l1);
auto b_l = mm->add_literal(beta);
auto b_b = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", l2->get_shape().lens()}}), b_l);
......@@ -1344,7 +1348,9 @@ TEST_CASE(gemm_ex_brcst_test)
auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), t_a);
auto dot = add_dot_apply_alpha_beta(*mm, {t_a, l1}, 1.0f, 0.0f);
auto dot =
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), t_a, l1);
auto b_l = mm->add_literal(beta);
auto l2_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", out_lens}}), l2);
......@@ -1372,7 +1378,8 @@ TEST_CASE(gemm_half_test)
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), t_a);
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), t_a);
std::vector<std::size_t> lens = {1, 1, 6, 7};
auto dot = add_dot_apply_alpha_beta(*mm, {t_a, l1}, 1.0f, 0.0f);
auto dot =
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), t_a, l1);
l2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), l2);
l2 = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), l2);
......@@ -1803,7 +1810,8 @@ TEST_CASE(initializer_not_an_input)
std::vector<float> w = {1, 2, 3, 4, 5, 6, 7, 8};
auto l1 = mm->add_literal(migraphx::literal({migraphx::shape::float_type, {2, 4}}, w));
auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {5, 2}});
add_dot_apply_alpha_beta(*mm, {l0, l1}, 1.0f, 0.0f);
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), l0, l1);
auto prog = optimize_onnx("initializer_not_an_input.onnx");
EXPECT(p == prog);
......@@ -2104,7 +2112,8 @@ TEST_CASE(matmul_bmbm_test)
migraphx::make_op("multibroadcast", {{"out_lens", {5, 2, 3, 6, 7}}}), l0);
auto bl1 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {5, 2, 3, 7, 8}}}), l1);
add_dot_apply_alpha_beta(*mm, {bl0, bl1}, 1.0f, 0.0f);
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), bl0, bl1);
auto prog = optimize_onnx("matmul_bmbm_test.onnx");
EXPECT(p == prog);
......@@ -2119,7 +2128,8 @@ TEST_CASE(matmul_bmv_test)
auto sl1 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l1);
auto bsl1 =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 7, 1}}}), sl1);
auto res = add_dot_apply_alpha_beta(*mm, {l0, bsl1}, 1.0f, 0.0f);
auto res =
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), l0, bsl1);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), res);
auto prog = optimize_onnx("matmul_bmv_test.onnx");
......@@ -2134,7 +2144,8 @@ TEST_CASE(matmul_mv_test)
auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {6, 7}});
auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {7}});
auto sl1 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l1);
auto res = add_dot_apply_alpha_beta(*mm, {l0, sl1}, 1.0f, 0.0f);
auto res =
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), l0, sl1);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), res);
auto prog = optimize_onnx("matmul_mv_test.onnx");
......@@ -2151,7 +2162,8 @@ TEST_CASE(matmul_vbm_test)
auto sl0 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l0);
auto bsl0 =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5, 1, 7}}}), sl0);
auto res = add_dot_apply_alpha_beta(*mm, {bsl0, l1}, 1.0f, 0.0f);
auto res =
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), bsl0, l1);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), res);
auto prog = optimize_onnx("matmul_vbm_test.onnx");
......@@ -2166,7 +2178,8 @@ TEST_CASE(matmul_vm_test)
auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {7}});
auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {7, 8}});
auto sl0 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l0);
auto res = add_dot_apply_alpha_beta(*mm, {sl0, l1}, 1.0f, 0.0f);
auto res =
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), sl0, l1);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), res);
auto prog = optimize_onnx("matmul_vm_test.onnx");
......@@ -2182,7 +2195,8 @@ TEST_CASE(matmul_vv_test)
auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {7}});
auto sl0 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l0);
auto sl1 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l1);
auto res = add_dot_apply_alpha_beta(*mm, {sl0, sl1}, 1.0f, 0.0f);
auto res =
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), sl0, sl1);
auto sr0 = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), res);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), sr0);
......
......@@ -4,7 +4,6 @@
#include <migraphx/instruction.hpp>
#include <migraphx/ref/target.hpp>
#include <sstream>
#include <migraphx/common.hpp>
#include "test.hpp"
#include <migraphx/make_op.hpp>
......@@ -161,7 +160,9 @@ TEST_CASE(program_copy)
auto para1 = mm1->add_parameter("m1", s1);
auto para2 = mm1->add_parameter("m2", s2);
auto para3 = mm1->add_parameter("m3", s3);
migraphx::add_dot_apply_alpha_beta(*mm1, {para1, para2, para3}, 0.31f, 0.28f);
mm1->add_instruction(
migraphx::make_op("dot", {{"alpha", 0.31f}, {"beta", 0.28f}}), para1, para2, para3);
migraphx::program p2{};
p2 = p1;
EXPECT(p2 == p1);
......
......@@ -111,33 +111,4 @@ TEST_CASE(const_scalar)
EXPECT(m1 == m2);
}
TEST_CASE(const_dot)
{
migraphx::module m1;
{
migraphx::shape s{migraphx::shape::float_type, {2, 2}};
std::vector<float> vec = {1.0f, 2.0f, 1.0f, 2.0f};
auto l = m1.add_literal(migraphx::literal(s, vec));
auto dl = m1.add_instruction(migraphx::make_op("dot"), l, l);
auto x = m1.add_parameter("x", s);
auto r = m1.add_instruction(migraphx::make_op("add"), dl, x);
m1.add_return({r});
}
run_pass(m1);
migraphx::module m2;
{
migraphx::shape s{migraphx::shape::float_type, {2, 2}};
std::vector<float> vec = {3.0f, 6.0f, 3.0f, 6.0f};
auto x = m2.add_parameter("x", s);
auto l = m2.add_literal(migraphx::literal(s, vec));
auto r = m2.add_instruction(migraphx::make_op("add"), l, x);
m2.add_return({r});
}
EXPECT(m1 == m2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -6,7 +6,6 @@
#include <migraphx/generate.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/common.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/quantize_int8.hpp>
#include <migraphx/quantize_fp16.hpp>
......@@ -557,8 +556,10 @@ TEST_CASE(dot_float)
migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
auto pc = mm->add_parameter("c", sc);
auto r = migraphx::add_dot_apply_alpha_beta(*mm, {pa, pb});
auto r = mm->add_instruction(
migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), pa, pb, pc);
mm->add_return({r});
return p;
......@@ -572,6 +573,7 @@ TEST_CASE(dot_float)
migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
auto pc = mm->add_parameter("c", sc);
auto zp_a = mm->add_literal(static_cast<int8_t>(0));
auto scale_a = mm->add_literal(10.0f);
scale_a = mm->add_instruction(
......@@ -590,7 +592,16 @@ TEST_CASE(dot_float)
auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b);
auto dqb = mm->add_instruction(migraphx::make_op("dequantizelinear"), qb, scale_b, zp_b);
auto r = migraphx::add_dot_apply_alpha_beta(*mm, {dqa, dqb});
auto zp_c = mm->add_literal(static_cast<int8_t>(100));
auto scale_c = mm->add_literal(10.0f);
scale_c = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sc.lens()}}), scale_c);
zp_c = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sc.lens()}}),
zp_c);
auto qc = mm->add_instruction(migraphx::make_op("quantizelinear"), pc, scale_c, zp_c);
auto dqc = mm->add_instruction(migraphx::make_op("dequantizelinear"), qc, scale_c, zp_c);
auto r = mm->add_instruction(
migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), dqa, dqb, dqc);
mm->add_return({r});
return p;
......@@ -602,8 +613,9 @@ TEST_CASE(dot_float)
migraphx::shape sa{migraphx::shape::float_type, {2, 16}};
migraphx::shape sb{migraphx::shape::float_type, {16, 8}};
migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
mm->add_parameter("c", sc);
auto zp = mm->add_literal(static_cast<int8_t>(0));
auto scale = mm->add_literal(10.0f);
auto scale_a = mm->add_instruction(
......@@ -637,7 +649,6 @@ TEST_CASE(dot_float)
p,
{migraphx::quantize_int8_pass{{"dot"}, quant_params}, migraphx::dead_code_elimination{}});
auto qp = create_int8_quantized_prog();
EXPECT(p == qp);
optimize_prog_int8(p);
......@@ -654,7 +665,8 @@ TEST_CASE(dot_double_2args)
migraphx::shape sb{migraphx::shape::double_type, {16, 8}};
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
auto r = migraphx::add_dot_apply_alpha_beta(*mm, {pa, pb});
auto r = mm->add_instruction(
migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), pa, pb);
mm->add_return({r});
return p;
......@@ -684,7 +696,8 @@ TEST_CASE(dot_double_2args)
zp_b);
auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b);
auto dqb = mm->add_instruction(migraphx::make_op("dequantizelinear"), qb, scale_b, zp_b);
auto r = migraphx::add_dot_apply_alpha_beta(*mm, {dqa, dqb});
auto r = mm->add_instruction(
migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), dqa, dqb);
mm->add_return({r});
return p;
};
......@@ -740,7 +753,8 @@ TEST_CASE(dot_half_1arg)
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::half_type, {9, 9}};
auto x = mm->add_parameter("x", s);
auto r = migraphx::add_dot_apply_alpha_beta(*mm, {x, x});
auto r =
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), x, x);
mm->add_return({r});
return p;
......@@ -768,7 +782,8 @@ TEST_CASE(dot_half_1arg)
zp_b);
auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), x, scale_b, zp_b);
auto dqb = mm->add_instruction(migraphx::make_op("dequantizelinear"), qb, scale_b, zp_b);
auto r = migraphx::add_dot_apply_alpha_beta(*mm, {dqa, dqb});
auto r = mm->add_instruction(
migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), dqa, dqb);
mm->add_return({r});
return p;
};
......@@ -1040,8 +1055,9 @@ TEST_CASE(int8_quantization_dot)
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
auto pc = mm->add_parameter("c", sc);
auto r = migraphx::add_dot_apply_alpha_beta(*mm, {pa, pb, pc}, 1, 1);
auto r = mm->add_instruction(migraphx::make_op("dot"), pa, pb, pc);
mm->add_return({r});
return p;
};
......@@ -1059,7 +1075,7 @@ TEST_CASE(int8_quantization_dot)
std::vector<float> no_quant_result;
run_prog(p, ref_t, m, no_quant_result);
EXPECT(migraphx::verify_range(quant_result, no_quant_result, 30000));
EXPECT(migraphx::verify_range(quant_result, no_quant_result));
}
}
......@@ -1126,7 +1142,8 @@ TEST_CASE(int8_subgraph)
auto w = mm->add_parameter("w", sw);
auto* then_mod = p.create_module("If_6_if");
auto out1 = migraphx::add_dot_apply_alpha_beta(*then_mod, {a, b});
auto out1 = then_mod->add_instruction(
migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), a, b);
then_mod->add_return({out1});
auto* else_mod = p.create_module("If_6_else");
......
......@@ -6,7 +6,6 @@
#include <migraphx/verify.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/common.hpp>
#include "test.hpp"
#include <migraphx/half.hpp>
......@@ -212,8 +211,7 @@ TEST_CASE(gemm_mutli_dim_2_beta0)
auto l3 = mm->add_literal(migraphx::literal{m3_shape, m3});
float alpha = 1.0f;
float beta = 0.0f;
migraphx::add_dot_apply_alpha_beta(
*mm, std::vector<migraphx::instruction_ref>{l1, l2, l3}, alpha, beta);
mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> m;
......@@ -276,8 +274,7 @@ TEST_CASE(gemm_beta_0)
float alpha = 1.0f;
float beta = 0.0f;
migraphx::add_dot_apply_alpha_beta(
*mm, std::vector<migraphx::instruction_ref>{l1, l2, l3}, alpha, beta);
mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> m;
......@@ -362,13 +359,13 @@ TEST_CASE(gemm_mutli_dim1_2_3)
0.49759611, 0.10021662, 0.00592602, 0.90862000};
migraphx::shape m3_shape{migraphx::shape::float_type, {2, 3, 2, 2}};
auto l1 = mm->add_literal(migraphx::literal{m1_shape, m1});
auto l2 = mm->add_literal(migraphx::literal{m2_shape, m2});
auto l3 = mm->add_literal(migraphx::literal{m3_shape, m3});
float alpha = 0.35;
float beta = 0.41;
auto m12_alpha = migraphx::add_dot_apply_alpha_beta(
*mm, std::vector<migraphx::instruction_ref>{l1, l2}, alpha);
auto l1 = mm->add_literal(migraphx::literal{m1_shape, m1});
auto l2 = mm->add_literal(migraphx::literal{m2_shape, m2});
auto l3 = mm->add_literal(migraphx::literal{m3_shape, m3});
float alpha = 0.35;
float beta = 0.41;
auto m12_alpha =
mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2);
auto l_beta = mm->add_literal(beta);
auto b_beta = mm->add_instruction(
migraphx::make_op("scalar", {{"scalar_bcst_dims", m12_alpha->get_shape().lens()}}), l_beta);
......@@ -421,8 +418,7 @@ TEST_CASE(gemm_mutli_3args)
auto l3 = mm->add_literal(migraphx::literal{m3_shape, m3});
float alpha = 0.35;
float beta = 0.41;
migraphx::add_dot_apply_alpha_beta(
*mm, std::vector<migraphx::instruction_ref>{l1, l2, l3}, alpha, beta);
mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> m;
......@@ -483,7 +479,7 @@ TEST_CASE(gemm_3args)
auto bl = mm->add_literal(migraphx::literal{b_shape, b});
migraphx::shape c_shape{migraphx::shape::float_type, {3, 3}};
auto cl = mm->add_literal(migraphx::literal{c_shape, c});
migraphx::add_dot_apply_alpha_beta(*mm, {al, bl, cl}, 1, 1);
mm->add_instruction(migraphx::make_op("dot"), al, bl, cl);
std::vector<float> gold = {-1.60947,
0.703083,
-5.46156,
......@@ -565,8 +561,7 @@ TEST_CASE(matmul_vv_inner_product)
auto ual = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), al);
auto ubl = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), bl);
float alpha = 0.32f;
migraphx::add_dot_apply_alpha_beta(
*mm, std::vector<migraphx::instruction_ref>{ual, ubl}, alpha);
mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}}), ual, ubl);
std::vector<float> gold = {-0.4590752};
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
......@@ -639,8 +634,7 @@ TEST_CASE(matmul_vm)
migraphx::shape b_shape{migraphx::shape::float_type, {8, 5}};
auto bl = mm->add_literal(migraphx::literal{b_shape, b});
float alpha = 0.5f;
migraphx::add_dot_apply_alpha_beta(
*mm, std::vector<migraphx::instruction_ref>{ual, bl}, alpha);
mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}}), ual, bl);
std::vector<float> gold = {-1.89056, -1.70003, -1.0986, -1.65724, -1.90163};
p.compile(migraphx::ref::target{});
......@@ -724,8 +718,7 @@ TEST_CASE(matmul_vm)
migraphx::make_op("multibroadcast", {{"out_lens", {3, 1, 6}}}), ual);
migraphx::shape b_shape{migraphx::shape::float_type, {3, 6, 4}};
auto bl = mm->add_literal(migraphx::literal{b_shape, b});
migraphx::add_dot_apply_alpha_beta(
*mm, std::vector<migraphx::instruction_ref>{bual, bl}, 0.21f);
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 0.21f}}), bual, bl);
std::vector<float> gold = {0.25812,
-0.247582,
0.480051,
......@@ -812,8 +805,7 @@ TEST_CASE(matmul_mv)
auto bl = mm->add_literal(migraphx::literal{b_shape, b});
auto ubl = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), bl);
float alpha = 0.3f;
migraphx::add_dot_apply_alpha_beta(
*mm, std::vector<migraphx::instruction_ref>{al, ubl}, alpha);
mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}}), al, ubl);
std::vector<float> gold = {0.395946, 0.357067, -0.588187};
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
......
......@@ -10,7 +10,6 @@
#include <migraphx/generate.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/common.hpp>
bool is_convolution(const migraphx::instruction& ins) { return ins.name() == "convolution"; }
bool is_dot(const migraphx::instruction& ins) { return ins.name() == "dot"; }
......@@ -128,11 +127,12 @@ TEST_CASE(dot)
auto scale = m1.add_literal(0.5f);
auto zero = m1.add_literal(std::int8_t{0});
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
auto dot = migraphx::add_dot_apply_alpha_beta(m1, {d1, d2});
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
auto dot =
m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), d1, d2);
m1.add_return({dot});
}
......@@ -168,19 +168,22 @@ TEST_CASE(dot_non_zero_point)
auto scale = m1.add_literal(0.5f);
auto zero = m1.add_literal(std::int8_t{1});
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
auto dot = migraphx::add_dot_apply_alpha_beta(m1, {d1, d2});
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
auto dot =
m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), d1, d2);
m1.add_return({dot});
}
migraphx::module m2;
{
auto t1 = m2.add_parameter("t1", sh1);
auto t2 = m2.add_parameter("t2", sh2);
auto dot = migraphx::add_dot_apply_alpha_beta(m2, {t1, t2});
auto t1 = m2.add_parameter("t1", sh1);
auto t2 = m2.add_parameter("t2", sh2);
auto dot =
m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), t1, t2);
m2.add_return({dot});
}
......@@ -200,19 +203,22 @@ TEST_CASE(dot_uint8)
auto scale = m1.add_literal(0.5f);
auto zero = m1.add_literal(std::uint8_t{0});
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
auto dot = migraphx::add_dot_apply_alpha_beta(m1, {d1, d2});
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
auto dot =
m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), d1, d2);
m1.add_return({dot});
}
migraphx::module m2;
{
auto t1 = m2.add_parameter("t1", sh1);
auto t2 = m2.add_parameter("t2", sh2);
auto dot = migraphx::add_dot_apply_alpha_beta(m2, {t1, t2});
auto t1 = m2.add_parameter("t1", sh1);
auto t2 = m2.add_parameter("t2", sh2);
auto dot =
m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), t1, t2);
m2.add_return({dot});
}
......@@ -234,11 +240,12 @@ TEST_CASE(dot_add)
auto scale = m1.add_literal(0.5f);
auto zero = m1.add_literal(std::int8_t{0});
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
auto dot = migraphx::add_dot_apply_alpha_beta(m1, {d1, d2});
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
auto dot =
m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), d1, d2);
auto q3 = add_quantize_op(m1, "quantizelinear", dot, scale, zero);
auto d3 = add_quantize_op(m1, "dequantizelinear", q3, scale, zero);
auto add = m1.add_instruction(migraphx::make_op("add"), d3, ab);
......@@ -464,20 +471,21 @@ TEST_CASE(conv_pooling_dot)
d1);
auto bc1 = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2);
auto a1 = m1.add_instruction(migraphx::make_op("add"), c1, bc1);
auto ap = m1.add_instruction(migraphx::make_op("pooling",
auto a1 = m1.add_instruction(migraphx::make_op("add"), c1, bc1);
auto ap = m1.add_instruction(migraphx::make_op("pooling",
{{"mode", "average"},
{"padding", {0, 0, 0, 0}},
{"stride", {1, 1}},
{"lengths", {7, 7}},
{"ceil_mode", 0}}),
a1);
auto fl = m1.add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), ap);
auto q4 = add_quantize_op(m1, "quantizelinear", fl, scale, zero);
auto d8 = add_quantize_op(m1, "dequantizelinear", q4, scale, zero);
auto dot = migraphx::add_dot_apply_alpha_beta(m1, {d8, d4});
auto q5 = add_quantize_op(m1, "quantizelinear", dot, scale, zero);
auto d9 = add_quantize_op(m1, "dequantizelinear", q5, scale, zero);
auto fl = m1.add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), ap);
auto q4 = add_quantize_op(m1, "quantizelinear", fl, scale, zero);
auto d8 = add_quantize_op(m1, "dequantizelinear", q4, scale, zero);
auto dot =
m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), d8, d4);
auto q5 = add_quantize_op(m1, "quantizelinear", dot, scale, zero);
auto d9 = add_quantize_op(m1, "dequantizelinear", q5, scale, zero);
auto mb1 =
m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 1000}}}), d3);
auto a2 = m1.add_instruction(migraphx::make_op("add"), d9, mb1);
......@@ -567,24 +575,25 @@ TEST_CASE(mobilenet_snippet)
d1);
auto bc1 = mm.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2);
auto a1 = mm.add_instruction(migraphx::make_op("add"), c1, bc1);
auto q2 = add_quantize_op(mm, "quantizelinear", a1, scale, zero);
auto d6 = add_quantize_op(mm, "dequantizelinear", q2, scale, zero);
auto ap = mm.add_instruction(migraphx::make_op("pooling",
auto a1 = mm.add_instruction(migraphx::make_op("add"), c1, bc1);
auto q2 = add_quantize_op(mm, "quantizelinear", a1, scale, zero);
auto d6 = add_quantize_op(mm, "dequantizelinear", q2, scale, zero);
auto ap = mm.add_instruction(migraphx::make_op("pooling",
{{"mode", "average"},
{"padding", {0, 0, 0, 0}},
{"stride", {1, 1}},
{"lengths", {7, 7}},
{"ceil_mode", 0}}),
d6);
auto q3 = add_quantize_op(mm, "quantizelinear", ap, scale, zero);
auto d7 = add_quantize_op(mm, "dequantizelinear", q3, scale, zero);
auto rs = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {1, -1}}}), d7);
auto q4 = add_quantize_op(mm, "quantizelinear", rs, scale, zero);
auto d8 = add_quantize_op(mm, "dequantizelinear", q4, scale, zero);
auto dot = migraphx::add_dot_apply_alpha_beta(mm, {d8, d4});
auto q5 = add_quantize_op(mm, "quantizelinear", dot, scale, zero);
auto d9 = add_quantize_op(mm, "dequantizelinear", q5, scale, zero);
auto q3 = add_quantize_op(mm, "quantizelinear", ap, scale, zero);
auto d7 = add_quantize_op(mm, "dequantizelinear", q3, scale, zero);
auto rs = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {1, -1}}}), d7);
auto q4 = add_quantize_op(mm, "quantizelinear", rs, scale, zero);
auto d8 = add_quantize_op(mm, "dequantizelinear", q4, scale, zero);
auto dot =
mm.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), d8, d4);
auto q5 = add_quantize_op(mm, "quantizelinear", dot, scale, zero);
auto d9 = add_quantize_op(mm, "dequantizelinear", q5, scale, zero);
auto mb1 =
mm.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 1000}}}), d3);
auto a2 = mm.add_instruction(migraphx::make_op("add"), d9, mb1);
......@@ -690,11 +699,12 @@ TEST_CASE(dot_correctness)
auto scale_b = m1->add_literal(0.5f);
auto zero = m1->add_literal(std::int8_t{0});
auto q1 = add_quantize_op(*m1, "quantizelinear", a, scale_a, zero);
auto d1 = add_quantize_op(*m1, "dequantizelinear", q1, scale_a, zero);
auto q2 = add_quantize_op(*m1, "quantizelinear", b, scale_b, zero);
auto d2 = add_quantize_op(*m1, "dequantizelinear", q2, scale_b, zero);
auto dot = migraphx::add_dot_apply_alpha_beta(*m1, {d1, d2});
auto q1 = add_quantize_op(*m1, "quantizelinear", a, scale_a, zero);
auto d1 = add_quantize_op(*m1, "dequantizelinear", q1, scale_a, zero);
auto q2 = add_quantize_op(*m1, "quantizelinear", b, scale_b, zero);
auto d2 = add_quantize_op(*m1, "dequantizelinear", q2, scale_b, zero);
auto dot =
m1->add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), d1, d2);
m1->add_return({dot});
run_pass(*m1);
......@@ -705,7 +715,8 @@ TEST_CASE(dot_correctness)
auto* m2 = p2.get_main_module();
auto a = m2->add_parameter("a", sh1);
auto b = m2->add_parameter("b", sh2);
auto dot = migraphx::add_dot_apply_alpha_beta(*m2, {a, b});
auto dot = m2->add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), a, b);
m2->add_return({dot});
}
......
#include "migraphx/common.hpp"
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
......@@ -18,8 +17,9 @@ struct gemm_2args_vv : verify_program<gemm_2args_vv>
auto l2 = mm->add_parameter("2", m2_shape);
auto ul2 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l2);
float alpha = 0.23f;
auto res = migraphx::add_dot_apply_alpha_beta(*mm, {ul1, ul2}, alpha);
auto sres = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), res);
auto res = mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}}), ul1, ul2);
auto sres = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), res);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), sres);
return p;
......
#include "migraphx/common.hpp"
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
......@@ -20,7 +19,9 @@ struct gemm_multi_3args : verify_program<gemm_multi_3args>
auto l3 = mm->add_parameter("3", m3_shape);
float alpha = 0.35;
float beta = 0.41;
migraphx::add_dot_apply_alpha_beta(*mm, {l1, l2, l3}, alpha, beta);
mm->add_instruction(
migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3);
return p;
}
};
......@@ -3,7 +3,7 @@
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/common.hpp>
struct gemm_multi_3args_alpha0 : verify_program<gemm_multi_3args_alpha0>
{
migraphx::program create_program() const
......@@ -19,7 +19,9 @@ struct gemm_multi_3args_alpha0 : verify_program<gemm_multi_3args_alpha0>
float alpha = 0.0f;
float beta = 1.0f;
migraphx::add_dot_apply_alpha_beta(*mm, {l1, l2, l3}, alpha, beta);
mm->add_instruction(
migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3);
return p;
}
};
#include "migraphx/common.hpp"
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
......@@ -20,7 +19,9 @@ struct gemm_multi_3args_beta0 : verify_program<gemm_multi_3args_beta0>
float alpha = 1.0f;
float beta = 0.0f;
migraphx::add_dot_apply_alpha_beta(*mm, {l1, l2, l3}, alpha, beta);
mm->add_instruction(
migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3);
return p;
}
};
#include "migraphx/common.hpp"
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
......@@ -20,7 +19,9 @@ struct gemm_multi_3args_c25 : verify_program<gemm_multi_3args_c25>
auto l3 = mm->add_parameter("3", m3_shape);
float alpha = 0.35;
float beta = 0.41;
migraphx::add_dot_apply_alpha_beta(*mm, {l1, l2, l3}, alpha, beta);
mm->add_instruction(
migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3);
return p;
}
};
#include "migraphx/common.hpp"
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
......@@ -20,7 +19,8 @@ struct gemm_multi_transpose : verify_program<gemm_multi_transpose>
float alpha = 1.0f;
float beta = 1.0f;
migraphx::add_dot_apply_alpha_beta(*mm, {l1, tl2}, alpha, beta);
mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, tl2);
return p;
}
};
#include "migraphx/common.hpp"
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
......@@ -13,12 +12,13 @@ struct test_gemm_copy : verify_program<test_gemm_copy>
auto* mm = p.get_main_module();
migraphx::shape sa{migraphx::shape::float_type, {2, 16}};
migraphx::shape sb{migraphx::shape::float_type, {16, 8}};
migraphx::shape sc{migraphx::shape::float_type, {1, 8}};
migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
auto pc = mm->add_parameter("c", sc);
auto dr = migraphx::add_dot_apply_alpha_beta(*mm, {pa, pb, pc}, 1, 1);
auto dr = mm->add_instruction(migraphx::make_op("dot"), pa, pb, pc);
mm->add_instruction(migraphx::make_op("add"), dr, dr);
return p;
}
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment