Unverified Commit 21193e87 authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Remove alpha and beta from `dot` and `quant_dot` (#961)

Previously dot operator was defined as C = alpha * A . B + beta * C where * is scalar multiplication and . is dot product or matrix multiplication depending on dimension of the inputs.

Aim is to have the definition of dot operator as C = A . B without having alpha or beta.

In order to achieve the same effect as alpha and beta (1) it multiplies the one of the inputs to the dot operator with alpha value. (2) if beta is present then, multiplies the C with beta and then adds into the output from step 1.
parent 87978f03
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
#include <migraphx/auto_contiguous.hpp> #include <migraphx/auto_contiguous.hpp>
#include <migraphx/check_context.hpp> #include <migraphx/check_context.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/decompose.hpp>
#include <migraphx/eliminate_allocation.hpp> #include <migraphx/eliminate_allocation.hpp>
#include <migraphx/eliminate_common_subexpression.hpp> #include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/eliminate_concat.hpp> #include <migraphx/eliminate_concat.hpp>
...@@ -17,7 +16,6 @@ ...@@ -17,7 +16,6 @@
#include <migraphx/preallocate_param.hpp> #include <migraphx/preallocate_param.hpp>
#include <migraphx/propagate_constant.hpp> #include <migraphx/propagate_constant.hpp>
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/remap.hpp>
#include <migraphx/rewrite_batchnorm.hpp> #include <migraphx/rewrite_batchnorm.hpp>
#include <migraphx/rewrite_pooling.hpp> #include <migraphx/rewrite_pooling.hpp>
#include <migraphx/rewrite_quantization.hpp> #include <migraphx/rewrite_quantization.hpp>
...@@ -59,7 +57,6 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -59,7 +57,6 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
return return
{ {
normalize_ops{}, normalize_ops{},
decompose{},
dead_code_elimination{}, dead_code_elimination{},
simplify_qdq{}, simplify_qdq{},
rewrite_quantization{}, rewrite_quantization{},
......
...@@ -518,42 +518,12 @@ struct ref_gemm ...@@ -518,42 +518,12 @@ struct ref_gemm
return migraphx::reflect(self.op, f); return migraphx::reflect(self.op, f);
} }
std::string name() const { return "ref::dot"; } std::string name() const { return "ref::dot"; }
shape compute_shape(const std::vector<shape>& inputs) const shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
{
if(inputs.size() == 3)
{
auto c_shape = inputs.at(2);
check_shapes{{c_shape}, *this}.not_broadcasted();
}
return op.compute_shape(inputs);
}
argument compute(context&, const shape& output_shape, std::vector<argument> args) const argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
// 3 inputs, it is alpha * A * B + beta * C, then migemm(result, args[0], args[1], 1.0f, 0.0f);
// A and B are matrices, and C is of the same shape as A * B
if(args.size() == 3)
{
// no need to consider the value of args[2]
if(op.beta == 0.0f)
{
result.visit([&](auto output) { std::fill(output.begin(), output.end(), 0); });
}
else
{
visit_all(result, args[2])([&](auto output, auto input) {
std::copy(input.begin(), input.end(), output.begin());
});
}
migemm(result, args[0], args[1], op.alpha, op.beta);
return result;
}
// 2 input arguments
migemm(result, args[0], args[1], op.alpha, 0.0f);
return result; return result;
} }
...@@ -571,22 +541,11 @@ struct ref_quant_gemm ...@@ -571,22 +541,11 @@ struct ref_quant_gemm
} }
std::string name() const { return "ref::quant_dot"; } std::string name() const { return "ref::quant_dot"; }
shape compute_shape(const std::vector<shape>& inputs) const shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
{
if(inputs.size() == 3)
{
auto c_shape = inputs.at(2);
check_shapes{{c_shape}, *this}.not_broadcasted();
}
return op.compute_shape(inputs);
}
argument compute(context&, const shape& output_shape, std::vector<argument> args) const argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
// 3 inputs, it is alpha * A * B + beta * C, then
// A and B are matrices, and C is of the same shape to A * B
// first, convert the args[0] and args[1] from int8_t to int32_t // first, convert the args[0] and args[1] from int8_t to int32_t
argument arg_0{{shape::int32_type, {args.at(0).get_shape().lens()}}}; argument arg_0{{shape::int32_type, {args.at(0).get_shape().lens()}}};
argument arg_1{{shape::int32_type, {args.at(1).get_shape().lens()}}}; argument arg_1{{shape::int32_type, {args.at(1).get_shape().lens()}}};
...@@ -600,27 +559,7 @@ struct ref_quant_gemm ...@@ -600,27 +559,7 @@ struct ref_quant_gemm
[&](auto input) { std::copy(input.begin(), input.end(), output.begin()); }); [&](auto input) { std::copy(input.begin(), input.end(), output.begin()); });
}); });
if(args.size() == 3) migemm(result, arg_0, arg_1, int32_t{1}, int32_t{0});
{
// no need to consider the value of args[2]
if(op.beta == 0)
{
result.visit([&](auto output) { std::fill(output.begin(), output.end(), 0); });
}
else
{
visit_all(result, args[2])([&](auto output, auto input) {
std::copy(input.begin(), input.end(), output.begin());
});
}
migemm(result, arg_0, arg_1, op.alpha, op.beta);
return result;
}
// 2 input arguments
migemm(result, arg_0, arg_1, op.alpha, int32_t{0});
return result; return result;
} }
......
#include <migraphx/decompose.hpp>
#include <migraphx/pass_manager.hpp>
#include <basic_ops.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp>
void run_pass(migraphx::module& m) { migraphx::run_passes(m, {migraphx::decompose{}}); }
TEST_CASE(dot_add)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto y = m1.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto z = m1.add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto dot = m1.add_instruction(migraphx::make_op("dot"), x, y, z);
m1.add_instruction(migraphx::make_op("identity"), dot);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto y = m2.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto z = m2.add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto dot = m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y);
auto add = m2.add_instruction(migraphx::make_op("add"), dot, z);
m2.add_instruction(migraphx::make_op("identity"), add);
}
EXPECT(m1 == m2);
}
TEST_CASE(dot_add_beta_float)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto y = m1.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto z = m1.add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto dot =
m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1.0}, {"beta", 0.5}}), x, y, z);
m1.add_instruction(migraphx::make_op("identity"), dot);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto y = m2.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto z = m2.add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto dot = m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y);
auto beta =
m2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {0.5}});
auto beta_broadcast =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2}}}), beta);
auto mul = m2.add_instruction(migraphx::make_op("mul"), z, beta_broadcast);
auto add = m2.add_instruction(migraphx::make_op("add"), dot, mul);
m2.add_instruction(migraphx::make_op("identity"), add);
}
EXPECT(m1 == m2);
}
TEST_CASE(dot_add_beta_half)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto y = m1.add_parameter("y", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto z = m1.add_parameter("z", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto dot =
m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1.0}, {"beta", 0.5}}), x, y, z);
m1.add_instruction(migraphx::make_op("identity"), dot);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto y = m2.add_parameter("y", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto z = m2.add_parameter("z", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto dot = m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y);
auto beta =
m2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::half_type}, {0.5}});
auto beta_broadcast =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2}}}), beta);
auto mul = m2.add_instruction(migraphx::make_op("mul"), z, beta_broadcast);
auto add = m2.add_instruction(migraphx::make_op("add"), dot, mul);
m2.add_instruction(migraphx::make_op("identity"), add);
}
EXPECT(m1 == m2);
}
TEST_CASE(dot_add_beta_double)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto y = m1.add_parameter("y", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto z = m1.add_parameter("z", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto dot =
m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1.0}, {"beta", 0.5}}), x, y, z);
m1.add_instruction(migraphx::make_op("identity"), dot);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto y = m2.add_parameter("y", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto z = m2.add_parameter("z", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto dot = m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y);
auto beta =
m2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::double_type}, {0.5}});
auto beta_broadcast =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2}}}), beta);
auto mul = m2.add_instruction(migraphx::make_op("mul"), z, beta_broadcast);
auto add = m2.add_instruction(migraphx::make_op("add"), dot, mul);
m2.add_instruction(migraphx::make_op("identity"), add);
}
EXPECT(m1 == m2);
}
TEST_CASE(dot_add_beta_int)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
auto y = m1.add_parameter("y", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
auto z = m1.add_parameter("z", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
auto dot =
m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1.0}, {"beta", 0.5}}), x, y, z);
m1.add_instruction(migraphx::make_op("identity"), dot);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
auto y = m2.add_parameter("y", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
auto z = m2.add_parameter("z", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
auto dot = m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y);
auto beta =
m2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {0.5}});
auto beta_broadcast =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2}}}), beta);
auto mul = m2.add_instruction(migraphx::make_op("mul"), z, beta_broadcast);
auto add = m2.add_instruction(migraphx::make_op("add"), dot, mul);
m2.add_instruction(migraphx::make_op("identity"), add);
}
EXPECT(m1 == m2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <cstdint>
#include <migraphx/instruction.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <basic_ops.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp>
TEST_CASE(dot_apply_alpha_beta_half)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto y = m1.add_parameter("y", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto z = m1.add_parameter("z", migraphx::shape{migraphx::shape::half_type, {2, 2}});
auto dot_res = migraphx::insert_apply_alpha_beta(
m1, m1.end(), {x, y, z}, migraphx::make_op("dot"), 3.0f, 2.0f);
m1.add_instruction(migraphx::make_op("identity"), dot_res);
}
migraphx::module m2;
{
auto ht = migraphx::shape::half_type;
auto ft = migraphx::shape::float_type;
auto x = m2.add_parameter("x", migraphx::shape{ht, {2, 2}});
auto y = m2.add_parameter("y", migraphx::shape{ht, {2, 2}});
auto z = m2.add_parameter("z", migraphx::shape{ht, {2, 2}});
auto alpha_literal = m2.add_literal(3.0f);
auto alpha_broadcast = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}),
alpha_literal);
auto x_float = m2.add_instruction(migraphx::make_op("convert", {{"target_type", ft}}), x);
auto x_alpha_float = m2.add_instruction(migraphx::make_op("mul"), alpha_broadcast, x_float);
auto x_half =
m2.add_instruction(migraphx::make_op("convert", {{"target_type", ht}}), x_alpha_float);
auto dot_res = m2.add_instruction(migraphx::make_op("dot"), x_half, y);
auto beta_literal = m2.add_literal(2.0f);
auto z_float = m2.add_instruction(migraphx::make_op("convert", {{"target_type", ft}}), z);
auto beta_broadcast = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", z->get_shape().lens()}}),
beta_literal);
auto z_beta_float = m2.add_instruction(migraphx::make_op("mul"), z_float, beta_broadcast);
auto z_beta_half =
m2.add_instruction(migraphx::make_op("convert", {{"target_type", ht}}), z_beta_float);
auto z_add = m2.add_instruction(migraphx::make_op("add"), dot_res, z_beta_half);
m2.add_instruction(migraphx::make_op("identity"), z_add);
}
EXPECT(m1 == m2);
}
TEST_CASE(dot_apply_alpha_beta_double)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto y = m1.add_parameter("y", migraphx::shape{migraphx::shape::double_type, {2, 2}});
auto z = m1.add_parameter("z", migraphx::shape{migraphx::shape::double_type, {2, 1}});
auto dot_res =
migraphx::add_apply_alpha_beta(m1, {x, y, z}, migraphx::make_op("dot"), 3.0f, 2.0f);
m1.add_instruction(migraphx::make_op("identity"), dot_res);
}
migraphx::module m2;
{
auto dt = migraphx::shape::double_type;
auto x = m2.add_parameter("x", migraphx::shape{dt, {2, 2}});
auto y = m2.add_parameter("y", migraphx::shape{dt, {2, 2}});
auto z = m2.add_parameter("z", migraphx::shape{dt, {2, 1}});
auto alpha_literal = m2.add_literal(3.0f);
auto alpha_broadcast = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}),
alpha_literal);
auto alpha_double = m2.add_instruction(migraphx::make_op("convert", {{"target_type", dt}}),
alpha_broadcast);
auto x_alpha_double = m2.add_instruction(migraphx::make_op("mul"), alpha_double, x);
auto dot_res = m2.add_instruction(migraphx::make_op("dot"), x_alpha_double, y);
auto z_broadcast =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 2}}}), z);
auto beta_literal = m2.add_literal(2.0f);
auto beta_broadcast = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", z_broadcast->get_shape().lens()}}),
beta_literal);
auto beta_double =
m2.add_instruction(migraphx::make_op("convert", {{"target_type", dt}}), beta_broadcast);
auto z_beta_double = m2.add_instruction(migraphx::make_op("mul"), z_broadcast, beta_double);
auto z_add = m2.add_instruction(migraphx::make_op("add"), dot_res, z_beta_double);
m2.add_instruction(migraphx::make_op("identity"), z_add);
}
EXPECT(m1 == m2);
}
TEST_CASE(quant_dot_apply_alpha_beta)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", migraphx::shape{migraphx::shape::int8_type, {2, 2}});
auto y = m1.add_parameter("y", migraphx::shape{migraphx::shape::int8_type, {2, 2}});
auto z = m1.add_parameter("z", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
auto dot_res = migraphx::insert_apply_alpha_beta(m1,
m1.end(),
{x, y, z},
migraphx::make_op("quant_dot"),
migraphx::literal{int32_t{3}},
migraphx::literal{int32_t{2}});
m1.add_instruction(migraphx::make_op("identity"), dot_res);
}
migraphx::module m2;
{
auto i8 = migraphx::shape::int8_type;
auto i32 = migraphx::shape::int32_type;
auto x = m2.add_parameter("x", migraphx::shape{i8, {2, 2}});
auto y = m2.add_parameter("y", migraphx::shape{i8, {2, 2}});
auto z = m2.add_parameter("z", migraphx::shape{i32, {2, 2}});
auto alpha_literal = m2.add_literal(int32_t(3));
auto alpha_broadcast = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}),
alpha_literal);
auto x_i32 = m2.add_instruction(migraphx::make_op("convert", {{"target_type", i32}}), x);
auto x_alpha_i32 = m2.add_instruction(migraphx::make_op("mul"), alpha_broadcast, x_i32);
auto x_i8 =
m2.add_instruction(migraphx::make_op("convert", {{"target_type", i8}}), x_alpha_i32);
auto dot_res = m2.add_instruction(migraphx::make_op("quant_dot"), x_i8, y);
auto beta_literal = m2.add_literal(int32_t(2));
auto beta_broadcast = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", z->get_shape().lens()}}),
beta_literal);
auto z_beta_i32 = m2.add_instruction(migraphx::make_op("mul"), z, beta_broadcast);
auto z_add = m2.add_instruction(migraphx::make_op("add"), dot_res, z_beta_i32);
m2.add_instruction(migraphx::make_op("identity"), z_add);
}
EXPECT(m1 == m2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/lowering.hpp> #include <migraphx/gpu/lowering.hpp>
#include <migraphx/gpu/target.hpp> #include <migraphx/gpu/target.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/adjust_allocation.hpp> #include <migraphx/adjust_allocation.hpp>
#include <migraphx/gpu/pack_int8_args.hpp> #include <migraphx/gpu/pack_int8_args.hpp>
#include <migraphx/gpu/rocblas.hpp> #include <migraphx/gpu/rocblas.hpp>
...@@ -48,7 +49,8 @@ TEST_CASE(quant_dot) ...@@ -48,7 +49,8 @@ TEST_CASE(quant_dot)
auto l1 = m.add_parameter("a", m1_shape); auto l1 = m.add_parameter("a", m1_shape);
auto l2 = m.add_parameter("b", m2_shape); auto l2 = m.add_parameter("b", m2_shape);
auto l3 = m.add_parameter("c", m3_shape); auto l3 = m.add_parameter("c", m3_shape);
auto r = m.add_instruction(migraphx::make_op("quant_dot"), l1, l2, l3); auto r =
migraphx::add_apply_alpha_beta(m, {l1, l2, l3}, migraphx::make_op("quant_dot"), 1, 1);
m.add_return({r}); m.add_return({r});
return m; return m;
}; };
...@@ -59,12 +61,14 @@ TEST_CASE(quant_dot) ...@@ -59,12 +61,14 @@ TEST_CASE(quant_dot)
migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 7}}; migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 7}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {5, 7}}; migraphx::shape m3_shape{migraphx::shape::int32_type, {5, 7}};
auto l1 = m.add_parameter("a", m1_shape); auto l1 = m.add_parameter("a", m1_shape);
auto l2 = m.add_parameter("b", m2_shape); auto l2 = m.add_parameter("b", m2_shape);
auto l3 = m.add_parameter("c", m3_shape); auto l3 = m.add_parameter("c", m3_shape);
auto output = m.add_parameter("test:#output_0", m3_shape); auto beta = m.add_literal(1);
auto output = m.add_parameter("test:#output_0", m3_shape);
auto gemm_alloc = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(m3_shape)}}));
auto cout = m.add_instruction(migraphx::make_op("hip::copy"), l3, output);
auto packa = l2; auto packa = l2;
if(int8_x4) if(int8_x4)
{ {
...@@ -72,14 +76,24 @@ TEST_CASE(quant_dot) ...@@ -72,14 +76,24 @@ TEST_CASE(quant_dot)
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(m2_shape)}})); migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(m2_shape)}}));
packa = m.add_instruction(migraphx::make_op("gpu::int8_gemm_pack_a"), l2, alloc); packa = m.add_instruction(migraphx::make_op("gpu::int8_gemm_pack_a"), l2, alloc);
} }
auto gemm = m.add_instruction( auto gemm =
migraphx::make_op("gpu::quant_gemm", m.add_instruction(migraphx::make_op("gpu::quant_gemm", {{"int8_x4_format", int8_x4}}),
{{"alpha", 1}, {"beta", 1}, {"int8_x4_format", int8_x4}}), l1,
l1, packa,
packa, gemm_alloc);
cout,
cout); auto beta_broadcast = m.add_instruction(
m.add_return({gemm}); migraphx::make_op("multibroadcast", {{"out_lens", m3_shape.lens()}}), beta);
auto beta_alloc = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(m3_shape)}}));
auto beta_contiguous =
m.add_instruction(migraphx::make_op("gpu::contiguous"), beta_broadcast, beta_alloc);
auto mul_alloc = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(m3_shape)}}));
auto m3_beta =
m.add_instruction(migraphx::make_op("gpu::mul"), l3, beta_contiguous, mul_alloc);
auto gemm_add = m.add_instruction(migraphx::make_op("gpu::add"), gemm, m3_beta, output);
m.add_return({gemm_add});
return m; return m;
}; };
...@@ -89,7 +103,6 @@ TEST_CASE(quant_dot) ...@@ -89,7 +103,6 @@ TEST_CASE(quant_dot)
bool flag = get_int8_x4_format(); bool flag = get_int8_x4_format();
auto m2 = create_optimized_int8_x4(flag); auto m2 = create_optimized_int8_x4(flag);
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
...@@ -106,8 +119,7 @@ TEST_CASE(quant_dot_trans) ...@@ -106,8 +119,7 @@ TEST_CASE(quant_dot_trans)
auto l2 = m.add_parameter("b", s2); auto l2 = m.add_parameter("b", s2);
auto tl2 = auto tl2 =
m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2); m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2);
auto r = m.add_instruction( auto r = migraphx::add_apply_alpha_beta(m, {tl1, tl2}, migraphx::make_op("quant_dot"), 3);
migraphx::make_op("quant_dot", {{"alpha", 3}, {"beta", 2}}), tl1, tl2);
m.add_return({r}); m.add_return({r});
return m; return m;
}; };
...@@ -120,6 +132,7 @@ TEST_CASE(quant_dot_trans) ...@@ -120,6 +132,7 @@ TEST_CASE(quant_dot_trans)
auto l1 = m.add_parameter("a", s1); auto l1 = m.add_parameter("a", s1);
auto l2 = m.add_parameter("b", s2); auto l2 = m.add_parameter("b", s2);
auto alpha = m.add_literal(3);
auto output = m.add_parameter("test:#output_0", s3); auto output = m.add_parameter("test:#output_0", s3);
auto tl1 = auto tl1 =
...@@ -136,6 +149,34 @@ TEST_CASE(quant_dot_trans) ...@@ -136,6 +149,34 @@ TEST_CASE(quant_dot_trans)
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts2)}})); migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts2)}}));
auto contb = m.add_instruction(migraphx::make_op("gpu::contiguous"), tl2, allocb); auto contb = m.add_instruction(migraphx::make_op("gpu::contiguous"), tl2, allocb);
auto alpha_broadcast = m.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", conta->get_shape().lens()}}), alpha);
auto alpha_alloc = m.add_instruction(migraphx::make_op(
"hip::allocate",
{{"shape",
migraphx::to_value(migraphx::shape(migraphx::shape::int32_type, {3, 2, 5, 8}))}}));
auto alpha_contiguous =
m.add_instruction(migraphx::make_op("gpu::contiguous"), alpha_broadcast, alpha_alloc);
// alpha = int32 and tl1 = int8, convert tl1 to int32 for multiplication and then convert
// back result to int8
auto tl1_convert_alloc = m.add_instruction(migraphx::make_op(
"hip::allocate", {{"shape", migraphx::to_value(alpha_contiguous->get_shape())}}));
auto tl1_convert = m.add_instruction(
migraphx::make_op("gpu::convert", {{"target_type", alpha->get_shape().type()}}),
conta,
tl1_convert_alloc);
auto mul_alloc = m.add_instruction(migraphx::make_op(
"hip::allocate", {{"shape", migraphx::to_value(tl1_convert->get_shape())}}));
auto tl1_alpha_int32 = m.add_instruction(
migraphx::make_op("gpu::mul"), alpha_contiguous, tl1_convert, mul_alloc);
// convert mul_res to int8
auto tl1_alpha_int8_alloc = m.add_instruction(migraphx::make_op(
"hip::allocate", {{"shape", migraphx::to_value(conta->get_shape())}}));
auto tl1_alpha_int8 = m.add_instruction(
migraphx::make_op("gpu::convert", {{"target_type", conta->get_shape().type()}}),
tl1_alpha_int32,
tl1_alpha_int8_alloc);
auto packb = contb; auto packb = contb;
if(int8_x4) if(int8_x4)
{ {
...@@ -143,12 +184,12 @@ TEST_CASE(quant_dot_trans) ...@@ -143,12 +184,12 @@ TEST_CASE(quant_dot_trans)
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts2)}})); migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts2)}}));
packb = m.add_instruction(migraphx::make_op("gpu::int8_gemm_pack_a"), contb, allocpb); packb = m.add_instruction(migraphx::make_op("gpu::int8_gemm_pack_a"), contb, allocpb);
} }
auto gemm = m.add_instruction(
migraphx::make_op("gpu::quant_gemm", auto gemm =
{{"alpha", 3}, {"beta", 0}, {"int8_x4_format", int8_x4}}), m.add_instruction(migraphx::make_op("gpu::quant_gemm", {{"int8_x4_format", int8_x4}}),
conta, tl1_alpha_int8,
packb, packb,
output); output);
m.add_return({gemm}); m.add_return({gemm});
return m; return m;
...@@ -174,7 +215,8 @@ TEST_CASE(quant_dot_pad) ...@@ -174,7 +215,8 @@ TEST_CASE(quant_dot_pad)
auto l1 = m.add_parameter("a", s1); auto l1 = m.add_parameter("a", s1);
auto l2 = m.add_parameter("b", s2); auto l2 = m.add_parameter("b", s2);
auto l3 = m.add_parameter("c", s3); auto l3 = m.add_parameter("c", s3);
auto r = m.add_instruction(migraphx::make_op("quant_dot"), l1, l2, l3); auto r =
migraphx::add_apply_alpha_beta(m, {l1, l2, l3}, migraphx::make_op("quant_dot"), 1, 1);
m.add_return({r}); m.add_return({r});
return m; return m;
}; };
...@@ -190,6 +232,7 @@ TEST_CASE(quant_dot_pad) ...@@ -190,6 +232,7 @@ TEST_CASE(quant_dot_pad)
auto l1 = m.add_parameter("a", s1); auto l1 = m.add_parameter("a", s1);
auto l2 = m.add_parameter("b", s2); auto l2 = m.add_parameter("b", s2);
auto l3 = m.add_parameter("c", s3); auto l3 = m.add_parameter("c", s3);
auto beta = m.add_literal(1);
auto output = m.add_parameter("test:#output_0", s3); auto output = m.add_parameter("test:#output_0", s3);
auto pl1 = l1; auto pl1 = l1;
...@@ -213,7 +256,9 @@ TEST_CASE(quant_dot_pad) ...@@ -213,7 +256,9 @@ TEST_CASE(quant_dot_pad)
po2); po2);
} }
auto cout = m.add_instruction(migraphx::make_op("hip::copy"), l3, output); auto gemm_alloc = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(s3)}}));
if(int8_x4) if(int8_x4)
{ {
auto alloc = m.add_instruction( auto alloc = m.add_instruction(
...@@ -221,15 +266,24 @@ TEST_CASE(quant_dot_pad) ...@@ -221,15 +266,24 @@ TEST_CASE(quant_dot_pad)
packa = m.add_instruction(migraphx::make_op("gpu::int8_gemm_pack_a"), pl2, alloc); packa = m.add_instruction(migraphx::make_op("gpu::int8_gemm_pack_a"), pl2, alloc);
} }
auto gemm = m.add_instruction( auto gemm =
migraphx::make_op("gpu::quant_gemm", m.add_instruction(migraphx::make_op("gpu::quant_gemm", {{"int8_x4_format", int8_x4}}),
{{"alpha", 1}, {"beta", 1}, {"int8_x4_format", int8_x4}}), pl1,
pl1, packa,
packa, gemm_alloc);
cout,
cout); auto beta_broadcast =
m.add_return({gemm}); m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s3.lens()}}), beta);
auto beta_alloc = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(s3)}}));
auto beta_contiguous =
m.add_instruction(migraphx::make_op("gpu::contiguous"), beta_broadcast, beta_alloc);
auto mul_alloc = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(s3)}}));
auto m3_beta =
m.add_instruction(migraphx::make_op("gpu::mul"), l3, beta_contiguous, mul_alloc);
auto gemm_add = m.add_instruction(migraphx::make_op("gpu::add"), gemm, m3_beta, output);
m.add_return({gemm_add});
return m; return m;
}; };
...@@ -255,8 +309,7 @@ TEST_CASE(quant_dot_trans_pad) ...@@ -255,8 +309,7 @@ TEST_CASE(quant_dot_trans_pad)
auto l2 = m.add_parameter("b", s2); auto l2 = m.add_parameter("b", s2);
auto tl2 = auto tl2 =
m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2); m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2);
auto r = m.add_instruction( auto r = migraphx::add_apply_alpha_beta(m, {tl1, tl2}, migraphx::make_op("quant_dot"), 3);
migraphx::make_op("quant_dot", {{"alpha", 3}, {"beta", 2}}), tl1, tl2);
m.add_return({r}); m.add_return({r});
return m; return m;
}; };
...@@ -271,6 +324,7 @@ TEST_CASE(quant_dot_trans_pad) ...@@ -271,6 +324,7 @@ TEST_CASE(quant_dot_trans_pad)
auto l1 = m.add_parameter("a", s1); auto l1 = m.add_parameter("a", s1);
auto l2 = m.add_parameter("b", s2); auto l2 = m.add_parameter("b", s2);
auto alpha = m.add_literal(3);
auto output = m.add_parameter("test:#output_0", s3); auto output = m.add_parameter("test:#output_0", s3);
auto tl1 = auto tl1 =
...@@ -278,27 +332,14 @@ TEST_CASE(quant_dot_trans_pad) ...@@ -278,27 +332,14 @@ TEST_CASE(quant_dot_trans_pad)
migraphx::shape ts1{migraphx::shape::int8_type, {3, 2, 5, 9}}; migraphx::shape ts1{migraphx::shape::int8_type, {3, 2, 5, 9}};
auto ta = m.add_instruction( auto ta = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts1)}})); migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts1)}}));
migraphx::instruction_ref pta{};
if(int8_x4)
{
pta = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ps1)}}));
}
auto conta = m.add_instruction(migraphx::make_op("gpu::contiguous"), tl1, ta); auto conta = m.add_instruction(migraphx::make_op("gpu::contiguous"), tl1, ta);
auto pa = conta;
if(int8_x4)
{
pa = m.add_instruction(
migraphx::make_op("gpu::pad", {{"mode", 0}, {"pads", {0, 0, 0, 3, 0, 0, 0, 0}}}),
conta,
pta);
}
auto tl2 = auto tl2 =
m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2); m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2);
migraphx::shape ts2{migraphx::shape::int8_type, {3, 2, 9, 7}}; migraphx::shape ts2{migraphx::shape::int8_type, {3, 2, 9, 7}};
auto tb = m.add_instruction( auto tb = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts2)}})); migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts2)}}));
migraphx::instruction_ref ptb{}; migraphx::instruction_ref ptb{};
if(int8_x4) if(int8_x4)
{ {
...@@ -306,24 +347,72 @@ TEST_CASE(quant_dot_trans_pad) ...@@ -306,24 +347,72 @@ TEST_CASE(quant_dot_trans_pad)
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ps2)}})); migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ps2)}}));
} }
auto contb = m.add_instruction(migraphx::make_op("gpu::contiguous"), tl2, tb); auto contb = m.add_instruction(migraphx::make_op("gpu::contiguous"), tl2, tb);
auto packb = contb; auto pb = contb;
if(int8_x4) if(int8_x4)
{ {
auto pb = m.add_instruction( pb = m.add_instruction(
migraphx::make_op("gpu::pad", {{"mode", 0}, {"pads", {0, 0, 3, 0, 0, 0, 0, 0}}}), migraphx::make_op("gpu::pad", {{"mode", 0}, {"pads", {0, 0, 3, 0, 0, 0, 0, 0}}}),
contb, contb,
ptb); ptb);
}
auto alpha_broadcast = m.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", conta->get_shape().lens()}}), alpha);
auto alpha_alloc = m.add_instruction(
migraphx::make_op("hip::allocate",
{{"shape",
migraphx::to_value(migraphx::shape(migraphx::shape::int32_type,
conta->get_shape().lens()))}}));
auto alpha_contiguous =
m.add_instruction(migraphx::make_op("gpu::contiguous"), alpha_broadcast, alpha_alloc);
// alpha = int32 and tl1 = int8, convert tl1 to int32 for multiplication and then convert
// back result to int8
auto tl1_convert_alloc = m.add_instruction(migraphx::make_op(
"hip::allocate", {{"shape", migraphx::to_value(alpha_contiguous->get_shape())}}));
auto tl1_convert = m.add_instruction(
migraphx::make_op("gpu::convert", {{"target_type", alpha->get_shape().type()}}),
conta,
tl1_convert_alloc);
auto mul_alloc = m.add_instruction(migraphx::make_op(
"hip::allocate", {{"shape", migraphx::to_value(tl1_convert->get_shape())}}));
auto tl1_alpha_int32 = m.add_instruction(
migraphx::make_op("gpu::mul"), alpha_contiguous, tl1_convert, mul_alloc);
// convert mul_res to int8
auto tl1_alpha_int8_alloc = m.add_instruction(migraphx::make_op(
"hip::allocate", {{"shape", migraphx::to_value(conta->get_shape())}}));
migraphx::instruction_ref pta{};
if(int8_x4)
{
pta = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ps1)}}));
}
auto tl1_alpha_int8 = m.add_instruction(
migraphx::make_op("gpu::convert", {{"target_type", conta->get_shape().type()}}),
tl1_alpha_int32,
tl1_alpha_int8_alloc);
auto pa = tl1_alpha_int8;
if(int8_x4)
{
pa = m.add_instruction(
migraphx::make_op("gpu::pad", {{"mode", 0}, {"pads", {0, 0, 0, 3, 0, 0, 0, 0}}}),
tl1_alpha_int8,
pta);
}
auto packb = pb;
if(int8_x4)
{
auto allocpb = m.add_instruction( auto allocpb = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ps2)}})); migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ps2)}}));
packb = m.add_instruction(migraphx::make_op("gpu::int8_gemm_pack_a"), pb, allocpb); packb = m.add_instruction(migraphx::make_op("gpu::int8_gemm_pack_a"), pb, allocpb);
} }
auto gemm = m.add_instruction( auto gemm = m.add_instruction(
migraphx::make_op("gpu::quant_gemm", migraphx::make_op("gpu::quant_gemm", {{"int8_x4_format", int8_x4}}), pa, packb, output);
{{"alpha", 3}, {"beta", 0}, {"int8_x4_format", int8_x4}}),
pa,
packb,
output);
m.add_return({gemm}); m.add_return({gemm});
return m; return m;
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <vector> #include <vector>
#include <random> #include <random>
#include <migraphx/common.hpp> #include <migraphx/common.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
...@@ -1340,11 +1341,9 @@ TEST_CASE(gemm_test) ...@@ -1340,11 +1341,9 @@ TEST_CASE(gemm_test)
auto beta = 2.0f; auto beta = 2.0f;
auto a_l = mm->add_literal(alpha); auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0}); auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), t_a); t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), t_a);
auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1); auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto dot = migraphx::add_apply_alpha_beta(*mm, {t_a, t1}, migraphx::make_op("dot"), 1.0f, 0.0f);
auto dot =
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), t_a, t1);
auto b_l = mm->add_literal(beta); auto b_l = mm->add_literal(beta);
auto l2_b = auto l2_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {7, 11}}}), l2); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {7, 11}}}), l2);
...@@ -1369,9 +1368,7 @@ TEST_CASE(gemm_ex_test) ...@@ -1369,9 +1368,7 @@ TEST_CASE(gemm_ex_test)
auto a_l = mm->add_literal(alpha); auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0}); auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), t_a); t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), t_a);
auto dot = migraphx::add_apply_alpha_beta(*mm, {t_a, l1}, migraphx::make_op("dot"), 1.0f, 0.0f);
auto dot =
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), t_a, l1);
auto b_l = mm->add_literal(beta); auto b_l = mm->add_literal(beta);
auto b_b = mm->add_instruction( auto b_b = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", l2->get_shape().lens()}}), b_l); migraphx::make_op("multibroadcast", {{"out_lens", l2->get_shape().lens()}}), b_l);
...@@ -1395,9 +1392,7 @@ TEST_CASE(gemm_ex_brcst_test) ...@@ -1395,9 +1392,7 @@ TEST_CASE(gemm_ex_brcst_test)
auto a_l = mm->add_literal(alpha); auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0}); auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), t_a); t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), t_a);
auto dot = migraphx::add_apply_alpha_beta(*mm, {t_a, l1}, migraphx::make_op("dot"), 1.0f, 0.0f);
auto dot =
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), t_a, l1);
auto b_l = mm->add_literal(beta); auto b_l = mm->add_literal(beta);
auto l2_b = auto l2_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", out_lens}}), l2); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", out_lens}}), l2);
...@@ -1425,10 +1420,9 @@ TEST_CASE(gemm_half_test) ...@@ -1425,10 +1420,9 @@ TEST_CASE(gemm_half_test)
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), t_a); migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), t_a);
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), t_a); t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), t_a);
std::vector<std::size_t> lens = {1, 1, 6, 7}; std::vector<std::size_t> lens = {1, 1, 6, 7};
auto dot = auto dot = migraphx::add_apply_alpha_beta(*mm, {t_a, l1}, migraphx::make_op("dot"), 1.0f, 0.0f);
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), t_a, l1); l2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), l2);
l2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), l2); l2 = mm->add_instruction(
l2 = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), l2); migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), l2);
auto b_l = mm->add_literal(beta); auto b_l = mm->add_literal(beta);
auto b_b = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), b_l); auto b_b = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), b_l);
...@@ -1857,8 +1851,7 @@ TEST_CASE(initializer_not_an_input) ...@@ -1857,8 +1851,7 @@ TEST_CASE(initializer_not_an_input)
std::vector<float> w = {1, 2, 3, 4, 5, 6, 7, 8}; std::vector<float> w = {1, 2, 3, 4, 5, 6, 7, 8};
auto l1 = mm->add_literal(migraphx::literal({migraphx::shape::float_type, {2, 4}}, w)); auto l1 = mm->add_literal(migraphx::literal({migraphx::shape::float_type, {2, 4}}, w));
auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {5, 2}}); auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {5, 2}});
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), l0, l1); migraphx::add_apply_alpha_beta(*mm, {l0, l1}, migraphx::make_op("dot"), 1.0f, 0.0f);
auto prog = optimize_onnx("initializer_not_an_input.onnx"); auto prog = optimize_onnx("initializer_not_an_input.onnx");
EXPECT(p == prog); EXPECT(p == prog);
...@@ -2159,8 +2152,7 @@ TEST_CASE(matmul_bmbm_test) ...@@ -2159,8 +2152,7 @@ TEST_CASE(matmul_bmbm_test)
migraphx::make_op("multibroadcast", {{"out_lens", {5, 2, 3, 6, 7}}}), l0); migraphx::make_op("multibroadcast", {{"out_lens", {5, 2, 3, 6, 7}}}), l0);
auto bl1 = mm->add_instruction( auto bl1 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {5, 2, 3, 7, 8}}}), l1); migraphx::make_op("multibroadcast", {{"out_lens", {5, 2, 3, 7, 8}}}), l1);
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), bl0, bl1); migraphx::add_apply_alpha_beta(*mm, {bl0, bl1}, migraphx::make_op("dot"), 1.0f, 0.0f);
auto prog = optimize_onnx("matmul_bmbm_test.onnx"); auto prog = optimize_onnx("matmul_bmbm_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
...@@ -2176,7 +2168,7 @@ TEST_CASE(matmul_bmv_test) ...@@ -2176,7 +2168,7 @@ TEST_CASE(matmul_bmv_test)
auto bsl1 = auto bsl1 =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 7, 1}}}), sl1); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 7, 1}}}), sl1);
auto res = auto res =
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), l0, bsl1); migraphx::add_apply_alpha_beta(*mm, {l0, bsl1}, migraphx::make_op("dot"), 1.0f, 0.0f);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), res); mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), res);
auto prog = optimize_onnx("matmul_bmv_test.onnx"); auto prog = optimize_onnx("matmul_bmv_test.onnx");
...@@ -2191,8 +2183,7 @@ TEST_CASE(matmul_mv_test) ...@@ -2191,8 +2183,7 @@ TEST_CASE(matmul_mv_test)
auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {6, 7}}); auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {6, 7}});
auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {7}}); auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {7}});
auto sl1 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l1); auto sl1 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l1);
auto res = auto res = migraphx::add_apply_alpha_beta(*mm, {l0, sl1}, migraphx::make_op("dot"), 1.0f, 0.0f);
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), l0, sl1);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), res); mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), res);
auto prog = optimize_onnx("matmul_mv_test.onnx"); auto prog = optimize_onnx("matmul_mv_test.onnx");
...@@ -2210,7 +2201,7 @@ TEST_CASE(matmul_vbm_test) ...@@ -2210,7 +2201,7 @@ TEST_CASE(matmul_vbm_test)
auto bsl0 = auto bsl0 =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5, 1, 7}}}), sl0); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5, 1, 7}}}), sl0);
auto res = auto res =
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), bsl0, l1); migraphx::add_apply_alpha_beta(*mm, {bsl0, l1}, migraphx::make_op("dot"), 1.0f, 0.0f);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), res); mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), res);
auto prog = optimize_onnx("matmul_vbm_test.onnx"); auto prog = optimize_onnx("matmul_vbm_test.onnx");
...@@ -2225,8 +2216,7 @@ TEST_CASE(matmul_vm_test) ...@@ -2225,8 +2216,7 @@ TEST_CASE(matmul_vm_test)
auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {7}}); auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {7}});
auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {7, 8}}); auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {7, 8}});
auto sl0 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l0); auto sl0 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l0);
auto res = auto res = migraphx::add_apply_alpha_beta(*mm, {sl0, l1}, migraphx::make_op("dot"), 1.0f, 0.0f);
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), sl0, l1);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), res); mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), res);
auto prog = optimize_onnx("matmul_vm_test.onnx"); auto prog = optimize_onnx("matmul_vm_test.onnx");
...@@ -2243,7 +2233,7 @@ TEST_CASE(matmul_vv_test) ...@@ -2243,7 +2233,7 @@ TEST_CASE(matmul_vv_test)
auto sl0 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l0); auto sl0 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l0);
auto sl1 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l1); auto sl1 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l1);
auto res = auto res =
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), sl0, sl1); migraphx::add_apply_alpha_beta(*mm, {sl0, sl1}, migraphx::make_op("dot"), 1.0f, 0.0f);
auto sr0 = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), res); auto sr0 = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), res);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), sr0); mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), sr0);
...@@ -2258,7 +2248,7 @@ TEST_CASE(matmulinteger_test) ...@@ -2258,7 +2248,7 @@ TEST_CASE(matmulinteger_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::int8_type, {3, 6, 16}}); auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::int8_type, {3, 6, 16}});
auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::int8_type, {3, 16, 8}}); auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::int8_type, {3, 16, 8}});
mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), l0, l1); mm->add_instruction(migraphx::make_op("quant_dot"), l0, l1);
auto prog = optimize_onnx("matmulinteger_test.onnx"); auto prog = optimize_onnx("matmulinteger_test.onnx");
......
...@@ -312,87 +312,38 @@ TEST_CASE(gemm) ...@@ -312,87 +312,38 @@ TEST_CASE(gemm)
{ {
{ {
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}}; migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}}; migraphx::shape s_m2{migraphx::shape::float_type, {10, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {1}}; throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {1, 1}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {8}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {4, 1}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
} }
{ {
migraphx::shape s_m1{migraphx::shape::float_type, {4, 6}}; migraphx::shape s_m1{migraphx::shape::float_type, {4, 6}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}}; migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {4, 8}}; throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {4}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
} }
{ {
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}}; migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}}; migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {4, 8}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 8}}, expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 8}},
migraphx::make_op("dot"), migraphx::make_op("dot"),
s_m1, s_m1,
s_m2, s_m2);
s_m3);
} }
{ {
migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}}; migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}}; migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {1, 4, 8}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 4, 8}}, expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 4, 8}},
migraphx::make_op("dot"), migraphx::make_op("dot"),
s_m1, s_m1,
s_m2, s_m2);
s_m3);
} }
{ {
migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 6}}; migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 6}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}}; migraphx::shape s_m2{migraphx::shape::float_type, {2, 5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {1, 4, 8}}; throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {4, 8}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2, s_m3);
} }
} }
...@@ -1115,7 +1066,7 @@ TEST_CASE(quant_dot_2args) ...@@ -1115,7 +1066,7 @@ TEST_CASE(quant_dot_2args)
migraphx::shape s_m1{migraphx::shape::int8_type, {3, 8}}; migraphx::shape s_m1{migraphx::shape::int8_type, {3, 8}};
migraphx::shape s_m2{migraphx::shape::int8_type, {8, 7}}; migraphx::shape s_m2{migraphx::shape::int8_type, {8, 7}};
expect_shape(migraphx::shape{migraphx::shape::int32_type, {3, 7}}, expect_shape(migraphx::shape{migraphx::shape::int32_type, {3, 7}},
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), migraphx::make_op("quant_dot"),
s_m1, s_m1,
s_m2); s_m2);
} }
...@@ -1127,27 +1078,6 @@ TEST_CASE(quant_dot_2args) ...@@ -1127,27 +1078,6 @@ TEST_CASE(quant_dot_2args)
} }
} }
TEST_CASE(quant_dot_3args)
{
{
migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}};
migraphx::shape s_m2{migraphx::shape::int8_type, {4, 8}};
migraphx::shape s_m3{migraphx::shape::int32_type, {2, 8}};
expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 8}},
migraphx::make_op("quant_dot"),
s_m1,
s_m2,
s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::int8_type, {2, 4}};
migraphx::shape s_m2{migraphx::shape::int8_type, {4, 8}};
migraphx::shape s_m3{migraphx::shape::int8_type, {2, 8}};
throws_shape(migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 2}}), s_m1, s_m2, s_m3);
}
}
template <class T> template <class T>
void test_reduce_ops() void test_reduce_ops()
{ {
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/ref/target.hpp> #include <migraphx/ref/target.hpp>
#include <sstream> #include <sstream>
#include <migraphx/apply_alpha_beta.hpp>
#include "test.hpp" #include "test.hpp"
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
...@@ -160,9 +161,8 @@ TEST_CASE(program_copy) ...@@ -160,9 +161,8 @@ TEST_CASE(program_copy)
auto para1 = mm1->add_parameter("m1", s1); auto para1 = mm1->add_parameter("m1", s1);
auto para2 = mm1->add_parameter("m2", s2); auto para2 = mm1->add_parameter("m2", s2);
auto para3 = mm1->add_parameter("m3", s3); auto para3 = mm1->add_parameter("m3", s3);
mm1->add_instruction( migraphx::add_apply_alpha_beta(
migraphx::make_op("dot", {{"alpha", 0.31f}, {"beta", 0.28f}}), para1, para2, para3); *mm1, {para1, para2, para3}, migraphx::make_op("dot"), 0.31f, 0.28f);
migraphx::program p2{}; migraphx::program p2{};
p2 = p1; p2 = p1;
EXPECT(p2 == p1); EXPECT(p2 == p1);
......
...@@ -111,4 +111,33 @@ TEST_CASE(const_scalar) ...@@ -111,4 +111,33 @@ TEST_CASE(const_scalar)
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
TEST_CASE(const_dot)
{
migraphx::module m1;
{
migraphx::shape s{migraphx::shape::float_type, {2, 2}};
std::vector<float> vec = {1.0f, 2.0f, 1.0f, 2.0f};
auto l = m1.add_literal(migraphx::literal(s, vec));
auto dl = m1.add_instruction(migraphx::make_op("dot"), l, l);
auto x = m1.add_parameter("x", s);
auto r = m1.add_instruction(migraphx::make_op("add"), dl, x);
m1.add_return({r});
}
run_pass(m1);
migraphx::module m2;
{
migraphx::shape s{migraphx::shape::float_type, {2, 2}};
std::vector<float> vec = {3.0f, 6.0f, 3.0f, 6.0f};
auto x = m2.add_parameter("x", s);
auto l = m2.add_literal(migraphx::literal(s, vec));
auto r = m2.add_instruction(migraphx::make_op("add"), l, x);
m2.add_return({r});
}
EXPECT(m1 == m2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/ref/target.hpp> #include <migraphx/ref/target.hpp>
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/quantization.hpp> #include <migraphx/quantization.hpp>
#include <migraphx/quantize_int8.hpp> #include <migraphx/quantize_int8.hpp>
#include <migraphx/quantize_fp16.hpp> #include <migraphx/quantize_fp16.hpp>
...@@ -431,7 +432,8 @@ TEST_CASE(op_capture) ...@@ -431,7 +432,8 @@ TEST_CASE(op_capture)
auto pb = mm->add_parameter("b", s2); auto pb = mm->add_parameter("b", s2);
auto pc = mm->add_parameter("c", s2); auto pc = mm->add_parameter("c", s2);
auto pa = mm->add_instruction(migraphx::make_op("add"), p1, p2); auto pa = mm->add_instruction(migraphx::make_op("add"), p1, p2);
auto ps = mm->add_instruction(migraphx::make_op("dot"), pa, pb, pc); auto ps =
migraphx::add_apply_alpha_beta(*mm, {pa, pb, pc}, migraphx::make_op("dot"), 1.0f, 1.0f);
mm->add_instruction(migraphx::make_op("dot"), pa, ps); mm->add_instruction(migraphx::make_op("dot"), pa, ps);
return p; return p;
...@@ -450,10 +452,10 @@ TEST_CASE(op_capture) ...@@ -450,10 +452,10 @@ TEST_CASE(op_capture)
auto pa = mm->add_instruction(migraphx::make_op("add"), p1, p2); auto pa = mm->add_instruction(migraphx::make_op("add"), p1, p2);
auto opa = mm->add_instruction(migraphx::make_op("capture", {{"ins_index", 0}}), pa); auto opa = mm->add_instruction(migraphx::make_op("capture", {{"ins_index", 0}}), pa);
auto opb = mm->add_instruction(migraphx::make_op("capture", {{"ins_index", 1}}), pb); auto opb = mm->add_instruction(migraphx::make_op("capture", {{"ins_index", 1}}), pb);
auto opc = mm->add_instruction(migraphx::make_op("capture", {{"ins_index", 2}}), pc); auto ps = migraphx::add_apply_alpha_beta(
auto ps = mm->add_instruction(migraphx::make_op("dot"), opa, opb, opc); *mm, {opa, opb, pc}, migraphx::make_op("dot"), 1.0f, 1.0f);
auto opm = mm->add_instruction(migraphx::make_op("capture", {{"ins_index", 3}}), pa); auto opm = mm->add_instruction(migraphx::make_op("capture", {{"ins_index", 2}}), pa);
auto ops = mm->add_instruction(migraphx::make_op("capture", {{"ins_index", 4}}), ps); auto ops = mm->add_instruction(migraphx::make_op("capture", {{"ins_index", 3}}), ps);
mm->add_instruction(migraphx::make_op("dot"), opm, ops); mm->add_instruction(migraphx::make_op("dot"), opm, ops);
return p; return p;
...@@ -556,10 +558,8 @@ TEST_CASE(dot_float) ...@@ -556,10 +558,8 @@ TEST_CASE(dot_float)
migraphx::shape sc{migraphx::shape::float_type, {2, 8}}; migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
auto pa = mm->add_parameter("a", sa); auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb); auto pb = mm->add_parameter("b", sb);
auto pc = mm->add_parameter("c", sc);
auto r = mm->add_instruction( auto r = migraphx::add_apply_alpha_beta(*mm, {pa, pb}, migraphx::make_op("dot"));
migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), pa, pb, pc);
mm->add_return({r}); mm->add_return({r});
return p; return p;
...@@ -573,7 +573,6 @@ TEST_CASE(dot_float) ...@@ -573,7 +573,6 @@ TEST_CASE(dot_float)
migraphx::shape sc{migraphx::shape::float_type, {2, 8}}; migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
auto pa = mm->add_parameter("a", sa); auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb); auto pb = mm->add_parameter("b", sb);
auto pc = mm->add_parameter("c", sc);
auto zp_a = mm->add_literal(static_cast<int8_t>(0)); auto zp_a = mm->add_literal(static_cast<int8_t>(0));
auto scale_a = mm->add_literal(10.0f); auto scale_a = mm->add_literal(10.0f);
scale_a = mm->add_instruction( scale_a = mm->add_instruction(
...@@ -592,16 +591,7 @@ TEST_CASE(dot_float) ...@@ -592,16 +591,7 @@ TEST_CASE(dot_float)
auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b); auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b);
auto dqb = mm->add_instruction(migraphx::make_op("dequantizelinear"), qb, scale_b, zp_b); auto dqb = mm->add_instruction(migraphx::make_op("dequantizelinear"), qb, scale_b, zp_b);
auto zp_c = mm->add_literal(static_cast<int8_t>(100)); auto r = migraphx::add_apply_alpha_beta(*mm, {dqa, dqb}, migraphx::make_op("dot"));
auto scale_c = mm->add_literal(10.0f);
scale_c = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sc.lens()}}), scale_c);
zp_c = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sc.lens()}}),
zp_c);
auto qc = mm->add_instruction(migraphx::make_op("quantizelinear"), pc, scale_c, zp_c);
auto dqc = mm->add_instruction(migraphx::make_op("dequantizelinear"), qc, scale_c, zp_c);
auto r = mm->add_instruction(
migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), dqa, dqb, dqc);
mm->add_return({r}); mm->add_return({r});
return p; return p;
...@@ -613,9 +603,8 @@ TEST_CASE(dot_float) ...@@ -613,9 +603,8 @@ TEST_CASE(dot_float)
migraphx::shape sa{migraphx::shape::float_type, {2, 16}}; migraphx::shape sa{migraphx::shape::float_type, {2, 16}};
migraphx::shape sb{migraphx::shape::float_type, {16, 8}}; migraphx::shape sb{migraphx::shape::float_type, {16, 8}};
migraphx::shape sc{migraphx::shape::float_type, {2, 8}}; migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
auto pa = mm->add_parameter("a", sa); auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb); auto pb = mm->add_parameter("b", sb);
mm->add_parameter("c", sc);
auto zp = mm->add_literal(static_cast<int8_t>(0)); auto zp = mm->add_literal(static_cast<int8_t>(0));
auto scale = mm->add_literal(10.0f); auto scale = mm->add_literal(10.0f);
auto scale_a = mm->add_instruction( auto scale_a = mm->add_instruction(
...@@ -628,8 +617,7 @@ TEST_CASE(dot_float) ...@@ -628,8 +617,7 @@ TEST_CASE(dot_float)
auto zp_b = auto zp_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), zp); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), zp);
auto quant_b = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b); auto quant_b = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b);
auto quant = mm->add_instruction( auto quant = mm->add_instruction(migraphx::make_op("quant_dot"), quant_a, quant_b);
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), quant_a, quant_b);
std::vector<float> vec(sc.elements(), 100.0f); std::vector<float> vec(sc.elements(), 100.0f);
auto dc = mm->add_literal(100.0f); auto dc = mm->add_literal(100.0f);
auto mdc = auto mdc =
...@@ -649,6 +637,7 @@ TEST_CASE(dot_float) ...@@ -649,6 +637,7 @@ TEST_CASE(dot_float)
p, p,
{migraphx::quantize_int8_pass{{"dot"}, quant_params}, migraphx::dead_code_elimination{}}); {migraphx::quantize_int8_pass{{"dot"}, quant_params}, migraphx::dead_code_elimination{}});
auto qp = create_int8_quantized_prog(); auto qp = create_int8_quantized_prog();
EXPECT(p == qp); EXPECT(p == qp);
optimize_prog_int8(p); optimize_prog_int8(p);
...@@ -665,8 +654,7 @@ TEST_CASE(dot_double_2args) ...@@ -665,8 +654,7 @@ TEST_CASE(dot_double_2args)
migraphx::shape sb{migraphx::shape::double_type, {16, 8}}; migraphx::shape sb{migraphx::shape::double_type, {16, 8}};
auto pa = mm->add_parameter("a", sa); auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb); auto pb = mm->add_parameter("b", sb);
auto r = mm->add_instruction( auto r = migraphx::add_apply_alpha_beta(*mm, {pa, pb}, migraphx::make_op("dot"));
migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), pa, pb);
mm->add_return({r}); mm->add_return({r});
return p; return p;
...@@ -696,8 +684,7 @@ TEST_CASE(dot_double_2args) ...@@ -696,8 +684,7 @@ TEST_CASE(dot_double_2args)
zp_b); zp_b);
auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b); auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b);
auto dqb = mm->add_instruction(migraphx::make_op("dequantizelinear"), qb, scale_b, zp_b); auto dqb = mm->add_instruction(migraphx::make_op("dequantizelinear"), qb, scale_b, zp_b);
auto r = mm->add_instruction( auto r = migraphx::add_apply_alpha_beta(*mm, {dqa, dqb}, migraphx::make_op("dot"));
migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), dqa, dqb);
mm->add_return({r}); mm->add_return({r});
return p; return p;
}; };
...@@ -722,9 +709,8 @@ TEST_CASE(dot_double_2args) ...@@ -722,9 +709,8 @@ TEST_CASE(dot_double_2args)
migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), scale_b); migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), scale_b);
auto zp_b = auto zp_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), zp); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), zp);
auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b); auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b);
auto qdot = mm->add_instruction( auto qdot = mm->add_instruction(migraphx::make_op("quant_dot"), qa, qb);
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), qa, qb);
auto scale = mm->add_literal(50.0); auto scale = mm->add_literal(50.0);
scale = mm->add_instruction( scale = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}), scale); migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}), scale);
...@@ -753,8 +739,7 @@ TEST_CASE(dot_half_1arg) ...@@ -753,8 +739,7 @@ TEST_CASE(dot_half_1arg)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::half_type, {9, 9}}; migraphx::shape s{migraphx::shape::half_type, {9, 9}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto r = auto r = mm->add_instruction(migraphx::make_op("dot"), x, x);
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), x, x);
mm->add_return({r}); mm->add_return({r});
return p; return p;
...@@ -782,8 +767,7 @@ TEST_CASE(dot_half_1arg) ...@@ -782,8 +767,7 @@ TEST_CASE(dot_half_1arg)
zp_b); zp_b);
auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), x, scale_b, zp_b); auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), x, scale_b, zp_b);
auto dqb = mm->add_instruction(migraphx::make_op("dequantizelinear"), qb, scale_b, zp_b); auto dqb = mm->add_instruction(migraphx::make_op("dequantizelinear"), qb, scale_b, zp_b);
auto r = mm->add_instruction( auto r = mm->add_instruction(migraphx::make_op("dot"), dqa, dqb);
migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), dqa, dqb);
mm->add_return({r}); mm->add_return({r});
return p; return p;
}; };
...@@ -800,10 +784,8 @@ TEST_CASE(dot_half_1arg) ...@@ -800,10 +784,8 @@ TEST_CASE(dot_half_1arg)
scale); scale);
zp = zp =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), zp); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), zp);
auto qx = mm->add_instruction(migraphx::make_op("quantizelinear"), x, scale, zp); auto qx = mm->add_instruction(migraphx::make_op("quantizelinear"), x, scale, zp);
auto qdot = mm->add_instruction( auto qdot = mm->add_instruction(migraphx::make_op("quant_dot"), qx, qx);
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), qx, qx);
auto dq_scale = mm->add_literal(migraphx::literal({sa.type()}, {100.0})); auto dq_scale = mm->add_literal(migraphx::literal({sa.type()}, {100.0}));
dq_scale = mm->add_instruction( dq_scale = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}), migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}),
...@@ -1055,9 +1037,9 @@ TEST_CASE(int8_quantization_dot) ...@@ -1055,9 +1037,9 @@ TEST_CASE(int8_quantization_dot)
auto pa = mm->add_parameter("a", sa); auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb); auto pb = mm->add_parameter("b", sb);
auto pc = mm->add_parameter("c", sc); auto pc = mm->add_parameter("c", sc);
auto r = mm->add_instruction(migraphx::make_op("dot"), pa, pb, pc); auto r =
migraphx::add_apply_alpha_beta(*mm, {pa, pb, pc}, migraphx::make_op("dot"), 1.0f, 1.0f);
mm->add_return({r}); mm->add_return({r});
return p; return p;
}; };
...@@ -1075,7 +1057,7 @@ TEST_CASE(int8_quantization_dot) ...@@ -1075,7 +1057,7 @@ TEST_CASE(int8_quantization_dot)
std::vector<float> no_quant_result; std::vector<float> no_quant_result;
run_prog(p, ref_t, m, no_quant_result); run_prog(p, ref_t, m, no_quant_result);
EXPECT(migraphx::verify_range(quant_result, no_quant_result)); EXPECT(migraphx::verify_range(quant_result, no_quant_result, 30000));
} }
} }
...@@ -1142,8 +1124,7 @@ TEST_CASE(int8_subgraph) ...@@ -1142,8 +1124,7 @@ TEST_CASE(int8_subgraph)
auto w = mm->add_parameter("w", sw); auto w = mm->add_parameter("w", sw);
auto* then_mod = p.create_module("If_6_if"); auto* then_mod = p.create_module("If_6_if");
auto out1 = then_mod->add_instruction( auto out1 = migraphx::add_apply_alpha_beta(*then_mod, {a, b}, migraphx::make_op("dot"));
migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), a, b);
then_mod->add_return({out1}); then_mod->add_return({out1});
auto* else_mod = p.create_module("If_6_else"); auto* else_mod = p.create_module("If_6_else");
...@@ -1181,11 +1162,10 @@ TEST_CASE(int8_subgraph) ...@@ -1181,11 +1162,10 @@ TEST_CASE(int8_subgraph)
migraphx::make_op("multibroadcast", {{"out_lens", sy.lens()}}), s1); migraphx::make_op("multibroadcast", {{"out_lens", sy.lens()}}), s1);
auto zpb = then_mod->add_instruction( auto zpb = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sy.lens()}}), zp1); migraphx::make_op("multibroadcast", {{"out_lens", sy.lens()}}), zp1);
auto qb = then_mod->add_instruction(migraphx::make_op("quantizelinear"), b, sb, zpb); auto qb = then_mod->add_instruction(migraphx::make_op("quantizelinear"), b, sb, zpb);
auto qdot = auto qdot = then_mod->add_instruction(migraphx::make_op("quant_dot"), qa, qb);
then_mod->add_instruction(migraphx::make_op("quant_dot", {{"beta", 0}}), qa, qb); auto so = then_mod->add_literal(100.0f);
auto so = then_mod->add_literal(100.0f); so = then_mod->add_instruction(
so = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sout.lens()}}), so); migraphx::make_op("multibroadcast", {{"out_lens", sout.lens()}}), so);
auto r = then_mod->add_instruction(migraphx::make_op("dequantizelinear"), qdot, so); auto r = then_mod->add_instruction(migraphx::make_op("dequantizelinear"), qdot, so);
then_mod->add_return({r}); then_mod->add_return({r});
...@@ -1251,7 +1231,8 @@ TEST_CASE(test_op_capture) ...@@ -1251,7 +1231,8 @@ TEST_CASE(test_op_capture)
auto pb = mm->add_literal(s2, d2); auto pb = mm->add_literal(s2, d2);
auto pc = mm->add_literal(s2, d2); auto pc = mm->add_literal(s2, d2);
auto pa = mm->add_instruction(migraphx::make_op("add"), p1, p2); auto pa = mm->add_instruction(migraphx::make_op("add"), p1, p2);
auto ps = mm->add_instruction(migraphx::make_op("dot"), pa, pb, pc); auto ps =
migraphx::add_apply_alpha_beta(*mm, {pa, pb, pc}, migraphx::make_op("dot"), 1.0f, 1.0f);
mm->add_instruction(migraphx::make_op("dot"), pa, ps); mm->add_instruction(migraphx::make_op("dot"), pa, ps);
auto calc = [](std::size_t, const std::vector<migraphx::argument>&) {}; auto calc = [](std::size_t, const std::vector<migraphx::argument>&) {};
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/onnx.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include "test.hpp" #include "test.hpp"
#include <migraphx/half.hpp> #include <migraphx/half.hpp>
...@@ -211,7 +212,11 @@ TEST_CASE(gemm_mutli_dim_2_beta0) ...@@ -211,7 +212,11 @@ TEST_CASE(gemm_mutli_dim_2_beta0)
auto l3 = mm->add_literal(migraphx::literal{m3_shape, m3}); auto l3 = mm->add_literal(migraphx::literal{m3_shape, m3});
float alpha = 1.0f; float alpha = 1.0f;
float beta = 0.0f; float beta = 0.0f;
mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3); migraphx::add_apply_alpha_beta(*mm,
std::vector<migraphx::instruction_ref>{l1, l2, l3},
migraphx::make_op("dot"),
alpha,
beta);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
...@@ -274,7 +279,11 @@ TEST_CASE(gemm_beta_0) ...@@ -274,7 +279,11 @@ TEST_CASE(gemm_beta_0)
float alpha = 1.0f; float alpha = 1.0f;
float beta = 0.0f; float beta = 0.0f;
mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3); migraphx::add_apply_alpha_beta(*mm,
std::vector<migraphx::instruction_ref>{l1, l2, l3},
migraphx::make_op("dot"),
alpha,
beta);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
...@@ -359,13 +368,13 @@ TEST_CASE(gemm_mutli_dim1_2_3) ...@@ -359,13 +368,13 @@ TEST_CASE(gemm_mutli_dim1_2_3)
0.49759611, 0.10021662, 0.00592602, 0.90862000}; 0.49759611, 0.10021662, 0.00592602, 0.90862000};
migraphx::shape m3_shape{migraphx::shape::float_type, {2, 3, 2, 2}}; migraphx::shape m3_shape{migraphx::shape::float_type, {2, 3, 2, 2}};
auto l1 = mm->add_literal(migraphx::literal{m1_shape, m1}); auto l1 = mm->add_literal(migraphx::literal{m1_shape, m1});
auto l2 = mm->add_literal(migraphx::literal{m2_shape, m2}); auto l2 = mm->add_literal(migraphx::literal{m2_shape, m2});
auto l3 = mm->add_literal(migraphx::literal{m3_shape, m3}); auto l3 = mm->add_literal(migraphx::literal{m3_shape, m3});
float alpha = 0.35; float alpha = 0.35;
float beta = 0.41; float beta = 0.41;
auto m12_alpha = auto m12_alpha = migraphx::add_apply_alpha_beta(
mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2); *mm, std::vector<migraphx::instruction_ref>{l1, l2}, migraphx::make_op("dot"), alpha);
auto l_beta = mm->add_literal(beta); auto l_beta = mm->add_literal(beta);
auto b_beta = mm->add_instruction( auto b_beta = mm->add_instruction(
migraphx::make_op("scalar", {{"scalar_bcst_dims", m12_alpha->get_shape().lens()}}), l_beta); migraphx::make_op("scalar", {{"scalar_bcst_dims", m12_alpha->get_shape().lens()}}), l_beta);
...@@ -418,7 +427,11 @@ TEST_CASE(gemm_mutli_3args) ...@@ -418,7 +427,11 @@ TEST_CASE(gemm_mutli_3args)
auto l3 = mm->add_literal(migraphx::literal{m3_shape, m3}); auto l3 = mm->add_literal(migraphx::literal{m3_shape, m3});
float alpha = 0.35; float alpha = 0.35;
float beta = 0.41; float beta = 0.41;
mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3); migraphx::add_apply_alpha_beta(*mm,
std::vector<migraphx::instruction_ref>{l1, l2, l3},
migraphx::make_op("dot"),
alpha,
beta);
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
...@@ -479,7 +492,7 @@ TEST_CASE(gemm_3args) ...@@ -479,7 +492,7 @@ TEST_CASE(gemm_3args)
auto bl = mm->add_literal(migraphx::literal{b_shape, b}); auto bl = mm->add_literal(migraphx::literal{b_shape, b});
migraphx::shape c_shape{migraphx::shape::float_type, {3, 3}}; migraphx::shape c_shape{migraphx::shape::float_type, {3, 3}};
auto cl = mm->add_literal(migraphx::literal{c_shape, c}); auto cl = mm->add_literal(migraphx::literal{c_shape, c});
mm->add_instruction(migraphx::make_op("dot"), al, bl, cl); migraphx::add_apply_alpha_beta(*mm, {al, bl, cl}, migraphx::make_op("dot"), 1.0f, 1.0f);
std::vector<float> gold = {-1.60947, std::vector<float> gold = {-1.60947,
0.703083, 0.703083,
-5.46156, -5.46156,
...@@ -561,7 +574,8 @@ TEST_CASE(matmul_vv_inner_product) ...@@ -561,7 +574,8 @@ TEST_CASE(matmul_vv_inner_product)
auto ual = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), al); auto ual = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), al);
auto ubl = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), bl); auto ubl = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), bl);
float alpha = 0.32f; float alpha = 0.32f;
mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}}), ual, ubl); migraphx::add_apply_alpha_beta(
*mm, std::vector<migraphx::instruction_ref>{ual, ubl}, migraphx::make_op("dot"), alpha);
std::vector<float> gold = {-0.4590752}; std::vector<float> gold = {-0.4590752};
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
...@@ -634,7 +648,8 @@ TEST_CASE(matmul_vm) ...@@ -634,7 +648,8 @@ TEST_CASE(matmul_vm)
migraphx::shape b_shape{migraphx::shape::float_type, {8, 5}}; migraphx::shape b_shape{migraphx::shape::float_type, {8, 5}};
auto bl = mm->add_literal(migraphx::literal{b_shape, b}); auto bl = mm->add_literal(migraphx::literal{b_shape, b});
float alpha = 0.5f; float alpha = 0.5f;
mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}}), ual, bl); migraphx::add_apply_alpha_beta(
*mm, std::vector<migraphx::instruction_ref>{ual, bl}, migraphx::make_op("dot"), alpha);
std::vector<float> gold = {-1.89056, -1.70003, -1.0986, -1.65724, -1.90163}; std::vector<float> gold = {-1.89056, -1.70003, -1.0986, -1.65724, -1.90163};
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
...@@ -718,7 +733,8 @@ TEST_CASE(matmul_vm) ...@@ -718,7 +733,8 @@ TEST_CASE(matmul_vm)
migraphx::make_op("multibroadcast", {{"out_lens", {3, 1, 6}}}), ual); migraphx::make_op("multibroadcast", {{"out_lens", {3, 1, 6}}}), ual);
migraphx::shape b_shape{migraphx::shape::float_type, {3, 6, 4}}; migraphx::shape b_shape{migraphx::shape::float_type, {3, 6, 4}};
auto bl = mm->add_literal(migraphx::literal{b_shape, b}); auto bl = mm->add_literal(migraphx::literal{b_shape, b});
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 0.21f}}), bual, bl); migraphx::add_apply_alpha_beta(
*mm, std::vector<migraphx::instruction_ref>{bual, bl}, migraphx::make_op("dot"), 0.21f);
std::vector<float> gold = {0.25812, std::vector<float> gold = {0.25812,
-0.247582, -0.247582,
0.480051, 0.480051,
...@@ -805,7 +821,8 @@ TEST_CASE(matmul_mv) ...@@ -805,7 +821,8 @@ TEST_CASE(matmul_mv)
auto bl = mm->add_literal(migraphx::literal{b_shape, b}); auto bl = mm->add_literal(migraphx::literal{b_shape, b});
auto ubl = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), bl); auto ubl = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), bl);
float alpha = 0.3f; float alpha = 0.3f;
mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}}), al, ubl); migraphx::add_apply_alpha_beta(
*mm, std::vector<migraphx::instruction_ref>{al, ubl}, migraphx::make_op("dot"), alpha);
std::vector<float> gold = {0.395946, 0.357067, -0.588187}; std::vector<float> gold = {0.395946, 0.357067, -0.588187};
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
...@@ -1337,7 +1354,8 @@ TEST_CASE(quant_dot_2args_general) ...@@ -1337,7 +1354,8 @@ TEST_CASE(quant_dot_2args_general)
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2}); auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto tl2 = auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 2}}), l1, tl2);
migraphx::add_apply_alpha_beta(*mm, {l1, tl2}, migraphx::make_op("quant_dot"), 2);
std::vector<int> gold = { std::vector<int> gold = {
28, 76, 124, 172, 220, 76, 252, 428, 604, 780, 124, 428, 732, 1036, 1340}; 28, 76, 124, 172, 220, 76, 252, 428, 604, 780, 124, 428, 732, 1036, 1340};
...@@ -1366,7 +1384,7 @@ TEST_CASE(quant_dot_2args_general) ...@@ -1366,7 +1384,7 @@ TEST_CASE(quant_dot_2args_general)
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2}); auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto tl2 = auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 3}, {"beta", 2}}), tl1, tl2); migraphx::add_apply_alpha_beta(*mm, {tl1, tl2}, migraphx::make_op("quant_dot"), 3);
std::vector<int> gold = { std::vector<int> gold = {
126, 342, 558, 774, 990, 144, 408, 672, 936, 1200, 162, 474, 786, 1098, 1410}; 126, 342, 558, 774, 990, 144, 408, 672, 936, 1200, 162, 474, 786, 1098, 1410};
...@@ -1398,7 +1416,7 @@ TEST_CASE(quant_dot_3args_general) ...@@ -1398,7 +1416,7 @@ TEST_CASE(quant_dot_3args_general)
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1}); auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2}); auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3}); auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3});
mm->add_instruction(migraphx::make_op("quant_dot"), l1, l2, l3); migraphx::add_apply_alpha_beta(*mm, {l1, l2, l3}, migraphx::make_op("quant_dot"), 1, 1);
std::vector<int> gold = { std::vector<int> gold = {
982, 1011, 1040, 1069, 1098, 1127, 1156, 2557, 2650, 2743, 2836, 2929, 3022, 3115}; 982, 1011, 1040, 1069, 1098, 1127, 1156, 2557, 2650, 2743, 2836, 2929, 3022, 3115};
...@@ -1426,9 +1444,7 @@ TEST_CASE(quant_dot_3args_general) ...@@ -1426,9 +1444,7 @@ TEST_CASE(quant_dot_3args_general)
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1}); auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2}); auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3}); mm->add_instruction(migraphx::make_op("quant_dot"), l1, l2);
mm->add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), l1, l2, l3);
std::vector<int> gold = { std::vector<int> gold = {
70, 76, 82, 88, 94, 190, 212, 234, 256, 278, 310, 348, 386, 424, 462}; 70, 76, 82, 88, 94, 190, 212, 234, 256, 278, 310, 348, 386, 424, 462};
...@@ -1459,8 +1475,7 @@ TEST_CASE(quant_dot_3args_general) ...@@ -1459,8 +1475,7 @@ TEST_CASE(quant_dot_3args_general)
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2}); auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3}); auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3});
mm->add_instruction( migraphx::add_apply_alpha_beta(*mm, {tl1, l2, l3}, migraphx::make_op("quant_dot"), 1, 3);
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 3}}), tl1, l2, l3);
std::vector<int> gold = { std::vector<int> gold = {
1966, 2025, 2084, 2143, 2202, 2261, 2320, 2183, 2250, 2317, 2384, 2451, 2518, 2585}; 1966, 2025, 2084, 2143, 2202, 2261, 2320, 2183, 2250, 2317, 2384, 2451, 2518, 2585};
...@@ -1491,8 +1506,7 @@ TEST_CASE(quant_dot_3args_general) ...@@ -1491,8 +1506,7 @@ TEST_CASE(quant_dot_3args_general)
auto tl2 = auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3}); auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3});
mm->add_instruction( migraphx::add_apply_alpha_beta(*mm, {l1, tl2, l3}, migraphx::make_op("quant_dot"), 2, 3);
migraphx::make_op("quant_dot", {{"alpha", 2}, {"beta", 3}}), l1, tl2, l3);
std::vector<int> gold = { std::vector<int> gold = {
286, 737, 1188, 1639, 2090, 2541, 2992, 755, 2230, 3705, 5180, 6655, 8130, 9605}; 286, 737, 1188, 1639, 2090, 2541, 2992, 755, 2230, 3705, 5180, 6655, 8130, 9605};
...@@ -1525,8 +1539,7 @@ TEST_CASE(quant_dot_3args_general) ...@@ -1525,8 +1539,7 @@ TEST_CASE(quant_dot_3args_general)
auto tl2 = auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3}); auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3});
mm->add_instruction( migraphx::add_apply_alpha_beta(*mm, {tl1, tl2, l3}, migraphx::make_op("quant_dot"), 3, 2);
migraphx::make_op("quant_dot", {{"alpha", 3}, {"beta", 2}}), tl1, tl2, l3);
std::vector<int> gold = { std::vector<int> gold = {
844, 2190, 3536, 4882, 6228, 7574, 8920, 942, 2480, 4018, 5556, 7094, 8632, 10170}; 844, 2190, 3536, 4882, 6228, 7574, 8920, 942, 2480, 4018, 5556, 7094, 8632, 10170};
...@@ -1558,8 +1571,7 @@ TEST_CASE(quant_dot_3args_batch) ...@@ -1558,8 +1571,7 @@ TEST_CASE(quant_dot_3args_batch)
auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1}); auto l1 = mm->add_literal(migraphx::literal{m1_shape, data1});
auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2}); auto l2 = mm->add_literal(migraphx::literal{m2_shape, data2});
auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3}); auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3});
mm->add_instruction( migraphx::add_apply_alpha_beta(*mm, {l1, l2, l3}, migraphx::make_op("quant_dot"), 1, 2);
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 2}}), l1, l2, l3);
std::vector<int> gold = { std::vector<int> gold = {
102, 110, 118, 126, 134, 142, 150, 284, 308, 332, 356, 380, 102, 110, 118, 126, 134, 142, 150, 284, 308, 332, 356, 380,
...@@ -1596,8 +1608,7 @@ TEST_CASE(quant_dot_3args_batch) ...@@ -1596,8 +1608,7 @@ TEST_CASE(quant_dot_3args_batch)
auto tl2 = mm->add_instruction( auto tl2 = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2); migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2);
auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3}); auto l3 = mm->add_literal(migraphx::literal{m3_shape, data3});
mm->add_instruction( migraphx::add_apply_alpha_beta(*mm, {tl1, tl2, l3}, migraphx::make_op("quant_dot"), 2, 3);
migraphx::make_op("quant_dot", {{"alpha", 2}, {"beta", 3}}), tl1, tl2, l3);
std::vector<int> gold = { std::vector<int> gold = {
90, 237, 384, 531, 678, 825, 120, 299, 478, 657, 836, 1015, 90, 237, 384, 531, 678, 825, 120, 299, 478, 657, 836, 1015,
......
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
#include <migraphx/ref/target.hpp> #include <migraphx/ref/target.hpp>
#include <migraphx/apply_alpha_beta.hpp>
bool is_convolution(const migraphx::instruction& ins) { return ins.name() == "convolution"; } bool is_convolution(const migraphx::instruction& ins) { return ins.name() == "convolution"; }
bool is_dot(const migraphx::instruction& ins) { return ins.name() == "dot"; } bool is_dot(const migraphx::instruction& ins) { return ins.name() == "dot"; }
...@@ -127,12 +128,11 @@ TEST_CASE(dot) ...@@ -127,12 +128,11 @@ TEST_CASE(dot)
auto scale = m1.add_literal(0.5f); auto scale = m1.add_literal(0.5f);
auto zero = m1.add_literal(std::int8_t{0}); auto zero = m1.add_literal(std::int8_t{0});
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero); auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero); auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero); auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero); auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
auto dot = auto dot = m1.add_instruction(migraphx::make_op("dot"), d1, d2);
m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), d1, d2);
m1.add_return({dot}); m1.add_return({dot});
} }
...@@ -144,11 +144,10 @@ TEST_CASE(dot) ...@@ -144,11 +144,10 @@ TEST_CASE(dot)
auto zero = m2.add_literal(std::int8_t{0}); auto zero = m2.add_literal(std::int8_t{0});
auto scale1 = m2.add_literal(0.25f); auto scale1 = m2.add_literal(0.25f);
auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale, zero); auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale, zero);
auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale, zero); auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale, zero);
auto dot = auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2);
m2.add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), q1, q2); auto d3 = add_quantize_op(m2, "dequantizelinear", dot, scale1);
auto d3 = add_quantize_op(m2, "dequantizelinear", dot, scale1);
m2.add_return({d3}); m2.add_return({d3});
} }
...@@ -168,22 +167,19 @@ TEST_CASE(dot_non_zero_point) ...@@ -168,22 +167,19 @@ TEST_CASE(dot_non_zero_point)
auto scale = m1.add_literal(0.5f); auto scale = m1.add_literal(0.5f);
auto zero = m1.add_literal(std::int8_t{1}); auto zero = m1.add_literal(std::int8_t{1});
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero); auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero); auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero); auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero); auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
auto dot = auto dot = m1.add_instruction(migraphx::make_op("dot"), d1, d2);
m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), d1, d2);
m1.add_return({dot}); m1.add_return({dot});
} }
migraphx::module m2; migraphx::module m2;
{ {
auto t1 = m2.add_parameter("t1", sh1); auto t1 = m2.add_parameter("t1", sh1);
auto t2 = m2.add_parameter("t2", sh2); auto t2 = m2.add_parameter("t2", sh2);
auto dot = m2.add_instruction(migraphx::make_op("dot"), t1, t2);
auto dot =
m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), t1, t2);
m2.add_return({dot}); m2.add_return({dot});
} }
...@@ -203,22 +199,19 @@ TEST_CASE(dot_uint8) ...@@ -203,22 +199,19 @@ TEST_CASE(dot_uint8)
auto scale = m1.add_literal(0.5f); auto scale = m1.add_literal(0.5f);
auto zero = m1.add_literal(std::uint8_t{0}); auto zero = m1.add_literal(std::uint8_t{0});
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero); auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero); auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero); auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero); auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
auto dot = auto dot = m1.add_instruction(migraphx::make_op("dot"), d1, d2);
m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), d1, d2);
m1.add_return({dot}); m1.add_return({dot});
} }
migraphx::module m2; migraphx::module m2;
{ {
auto t1 = m2.add_parameter("t1", sh1); auto t1 = m2.add_parameter("t1", sh1);
auto t2 = m2.add_parameter("t2", sh2); auto t2 = m2.add_parameter("t2", sh2);
auto dot = m2.add_instruction(migraphx::make_op("dot"), t1, t2);
auto dot =
m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), t1, t2);
m2.add_return({dot}); m2.add_return({dot});
} }
...@@ -240,12 +233,11 @@ TEST_CASE(dot_add) ...@@ -240,12 +233,11 @@ TEST_CASE(dot_add)
auto scale = m1.add_literal(0.5f); auto scale = m1.add_literal(0.5f);
auto zero = m1.add_literal(std::int8_t{0}); auto zero = m1.add_literal(std::int8_t{0});
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero); auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero); auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero); auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero); auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
auto dot = auto dot = m1.add_instruction(migraphx::make_op("dot"), d1, d2);
m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), d1, d2);
auto q3 = add_quantize_op(m1, "quantizelinear", dot, scale, zero); auto q3 = add_quantize_op(m1, "quantizelinear", dot, scale, zero);
auto d3 = add_quantize_op(m1, "dequantizelinear", q3, scale, zero); auto d3 = add_quantize_op(m1, "dequantizelinear", q3, scale, zero);
auto add = m1.add_instruction(migraphx::make_op("add"), d3, ab); auto add = m1.add_instruction(migraphx::make_op("add"), d3, ab);
...@@ -261,10 +253,9 @@ TEST_CASE(dot_add) ...@@ -261,10 +253,9 @@ TEST_CASE(dot_add)
auto zero = m2.add_literal(std::int8_t{0}); auto zero = m2.add_literal(std::int8_t{0});
auto scale1 = m2.add_literal(0.25f); auto scale1 = m2.add_literal(0.25f);
auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale, zero); auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale, zero);
auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale, zero); auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale, zero);
auto dot = auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2);
m2.add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), q1, q2);
auto d3 = add_quantize_op(m2, "dequantizelinear", dot, scale1); auto d3 = add_quantize_op(m2, "dequantizelinear", dot, scale1);
auto add = m2.add_instruction(migraphx::make_op("add"), d3, ab); auto add = m2.add_instruction(migraphx::make_op("add"), d3, ab);
m2.add_return({add}); m2.add_return({add});
...@@ -471,21 +462,20 @@ TEST_CASE(conv_pooling_dot) ...@@ -471,21 +462,20 @@ TEST_CASE(conv_pooling_dot)
d1); d1);
auto bc1 = m1.add_instruction( auto bc1 = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2);
auto a1 = m1.add_instruction(migraphx::make_op("add"), c1, bc1); auto a1 = m1.add_instruction(migraphx::make_op("add"), c1, bc1);
auto ap = m1.add_instruction(migraphx::make_op("pooling", auto ap = m1.add_instruction(migraphx::make_op("pooling",
{{"mode", "average"}, {{"mode", "average"},
{"padding", {0, 0, 0, 0}}, {"padding", {0, 0, 0, 0}},
{"stride", {1, 1}}, {"stride", {1, 1}},
{"lengths", {7, 7}}, {"lengths", {7, 7}},
{"ceil_mode", 0}}), {"ceil_mode", 0}}),
a1); a1);
auto fl = m1.add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), ap); auto fl = m1.add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), ap);
auto q4 = add_quantize_op(m1, "quantizelinear", fl, scale, zero); auto q4 = add_quantize_op(m1, "quantizelinear", fl, scale, zero);
auto d8 = add_quantize_op(m1, "dequantizelinear", q4, scale, zero); auto d8 = add_quantize_op(m1, "dequantizelinear", q4, scale, zero);
auto dot = auto dot = m1.add_instruction(migraphx::make_op("dot"), d8, d4);
m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), d8, d4); auto q5 = add_quantize_op(m1, "quantizelinear", dot, scale, zero);
auto q5 = add_quantize_op(m1, "quantizelinear", dot, scale, zero); auto d9 = add_quantize_op(m1, "dequantizelinear", q5, scale, zero);
auto d9 = add_quantize_op(m1, "dequantizelinear", q5, scale, zero);
auto mb1 = auto mb1 =
m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 1000}}}), d3); m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 1000}}}), d3);
auto a2 = m1.add_instruction(migraphx::make_op("add"), d9, mb1); auto a2 = m1.add_instruction(migraphx::make_op("add"), d9, mb1);
...@@ -518,19 +508,18 @@ TEST_CASE(conv_pooling_dot) ...@@ -518,19 +508,18 @@ TEST_CASE(conv_pooling_dot)
auto d5 = add_quantize_op(m2, "dequantizelinear", c1, scale1); auto d5 = add_quantize_op(m2, "dequantizelinear", c1, scale1);
auto bc1 = m2.add_instruction( auto bc1 = m2.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2);
auto a1 = m2.add_instruction(migraphx::make_op("add"), d5, bc1); auto a1 = m2.add_instruction(migraphx::make_op("add"), d5, bc1);
auto ap = m2.add_instruction(migraphx::make_op("pooling", auto ap = m2.add_instruction(migraphx::make_op("pooling",
{{"mode", "average"}, {{"mode", "average"},
{"padding", {0, 0, 0, 0}}, {"padding", {0, 0, 0, 0}},
{"stride", {1, 1}}, {"stride", {1, 1}},
{"lengths", {7, 7}}, {"lengths", {7, 7}},
{"ceil_mode", 0}}), {"ceil_mode", 0}}),
a1); a1);
auto fl = m2.add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), ap); auto fl = m2.add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), ap);
auto q4 = add_quantize_op(m2, "quantizelinear", fl, scale, zero); auto q4 = add_quantize_op(m2, "quantizelinear", fl, scale, zero);
auto dot = auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q4, db);
m2.add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), q4, db); auto d9 = add_quantize_op(m2, "dequantizelinear", dot, scale2);
auto d9 = add_quantize_op(m2, "dequantizelinear", dot, scale2);
auto mb1 = auto mb1 =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 1000}}}), d3); m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 1000}}}), d3);
auto a2 = m2.add_instruction(migraphx::make_op("add"), d9, mb1); auto a2 = m2.add_instruction(migraphx::make_op("add"), d9, mb1);
...@@ -575,25 +564,24 @@ TEST_CASE(mobilenet_snippet) ...@@ -575,25 +564,24 @@ TEST_CASE(mobilenet_snippet)
d1); d1);
auto bc1 = mm.add_instruction( auto bc1 = mm.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2);
auto a1 = mm.add_instruction(migraphx::make_op("add"), c1, bc1); auto a1 = mm.add_instruction(migraphx::make_op("add"), c1, bc1);
auto q2 = add_quantize_op(mm, "quantizelinear", a1, scale, zero); auto q2 = add_quantize_op(mm, "quantizelinear", a1, scale, zero);
auto d6 = add_quantize_op(mm, "dequantizelinear", q2, scale, zero); auto d6 = add_quantize_op(mm, "dequantizelinear", q2, scale, zero);
auto ap = mm.add_instruction(migraphx::make_op("pooling", auto ap = mm.add_instruction(migraphx::make_op("pooling",
{{"mode", "average"}, {{"mode", "average"},
{"padding", {0, 0, 0, 0}}, {"padding", {0, 0, 0, 0}},
{"stride", {1, 1}}, {"stride", {1, 1}},
{"lengths", {7, 7}}, {"lengths", {7, 7}},
{"ceil_mode", 0}}), {"ceil_mode", 0}}),
d6); d6);
auto q3 = add_quantize_op(mm, "quantizelinear", ap, scale, zero); auto q3 = add_quantize_op(mm, "quantizelinear", ap, scale, zero);
auto d7 = add_quantize_op(mm, "dequantizelinear", q3, scale, zero); auto d7 = add_quantize_op(mm, "dequantizelinear", q3, scale, zero);
auto rs = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {1, -1}}}), d7); auto rs = mm.add_instruction(migraphx::make_op("reshape", {{"dims", {1, -1}}}), d7);
auto q4 = add_quantize_op(mm, "quantizelinear", rs, scale, zero); auto q4 = add_quantize_op(mm, "quantizelinear", rs, scale, zero);
auto d8 = add_quantize_op(mm, "dequantizelinear", q4, scale, zero); auto d8 = add_quantize_op(mm, "dequantizelinear", q4, scale, zero);
auto dot = auto dot = mm.add_instruction(migraphx::make_op("dot"), d8, d4);
mm.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), d8, d4); auto q5 = add_quantize_op(mm, "quantizelinear", dot, scale, zero);
auto q5 = add_quantize_op(mm, "quantizelinear", dot, scale, zero); auto d9 = add_quantize_op(mm, "dequantizelinear", q5, scale, zero);
auto d9 = add_quantize_op(mm, "dequantizelinear", q5, scale, zero);
auto mb1 = auto mb1 =
mm.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 1000}}}), d3); mm.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 1000}}}), d3);
auto a2 = mm.add_instruction(migraphx::make_op("add"), d9, mb1); auto a2 = mm.add_instruction(migraphx::make_op("add"), d9, mb1);
...@@ -699,12 +687,11 @@ TEST_CASE(dot_correctness) ...@@ -699,12 +687,11 @@ TEST_CASE(dot_correctness)
auto scale_b = m1->add_literal(0.5f); auto scale_b = m1->add_literal(0.5f);
auto zero = m1->add_literal(std::int8_t{0}); auto zero = m1->add_literal(std::int8_t{0});
auto q1 = add_quantize_op(*m1, "quantizelinear", a, scale_a, zero); auto q1 = add_quantize_op(*m1, "quantizelinear", a, scale_a, zero);
auto d1 = add_quantize_op(*m1, "dequantizelinear", q1, scale_a, zero); auto d1 = add_quantize_op(*m1, "dequantizelinear", q1, scale_a, zero);
auto q2 = add_quantize_op(*m1, "quantizelinear", b, scale_b, zero); auto q2 = add_quantize_op(*m1, "quantizelinear", b, scale_b, zero);
auto d2 = add_quantize_op(*m1, "dequantizelinear", q2, scale_b, zero); auto d2 = add_quantize_op(*m1, "dequantizelinear", q2, scale_b, zero);
auto dot = auto dot = m1->add_instruction(migraphx::make_op("dot"), d1, d2);
m1->add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), d1, d2);
m1->add_return({dot}); m1->add_return({dot});
run_pass(*m1); run_pass(*m1);
...@@ -715,8 +702,7 @@ TEST_CASE(dot_correctness) ...@@ -715,8 +702,7 @@ TEST_CASE(dot_correctness)
auto* m2 = p2.get_main_module(); auto* m2 = p2.get_main_module();
auto a = m2->add_parameter("a", sh1); auto a = m2->add_parameter("a", sh1);
auto b = m2->add_parameter("b", sh2); auto b = m2->add_parameter("b", sh2);
auto dot = m2->add_instruction(migraphx::make_op("dot"), a, b);
auto dot = m2->add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), a, b);
m2->add_return({dot}); m2->add_return({dot});
} }
......
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
...@@ -21,8 +22,7 @@ struct batch_quant_dot_1 : verify_program<batch_quant_dot_1> ...@@ -21,8 +22,7 @@ struct batch_quant_dot_1 : verify_program<batch_quant_dot_1>
auto tl2 = mm->add_instruction( auto tl2 = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2); migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2);
auto l3 = mm->add_parameter("c", m3_shape); auto l3 = mm->add_parameter("c", m3_shape);
mm->add_instruction( migraphx::add_apply_alpha_beta(*mm, {tl1, tl2, l3}, migraphx::make_op("quant_dot"), 3, 2);
migraphx::make_op("quant_dot", {{"alpha", 3}, {"beta", 2}}), tl1, tl2, l3);
return p; return p;
} }
}; };
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
...@@ -17,8 +18,7 @@ struct batch_quant_dot_2 : verify_program<batch_quant_dot_2> ...@@ -17,8 +18,7 @@ struct batch_quant_dot_2 : verify_program<batch_quant_dot_2>
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
auto l3 = mm->add_parameter("c", m3_shape); auto l3 = mm->add_parameter("c", m3_shape);
mm->add_instruction( migraphx::add_apply_alpha_beta(*mm, {l1, l2, l3}, migraphx::make_op("quant_dot"), 1, 3);
migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 3}}), l1, l2, l3);
return p; return p;
} }
}; };
...@@ -15,7 +15,7 @@ struct batch_quant_dot_3 : verify_program<batch_quant_dot_3> ...@@ -15,7 +15,7 @@ struct batch_quant_dot_3 : verify_program<batch_quant_dot_3>
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 3}}), l1, l2); mm->add_instruction(migraphx::make_op("quant_dot"), l1, l2);
return p; return p;
} }
}; };
...@@ -19,7 +19,7 @@ struct batch_quant_dot_4 : verify_program<batch_quant_dot_4> ...@@ -19,7 +19,7 @@ struct batch_quant_dot_4 : verify_program<batch_quant_dot_4>
migraphx::make_op("transpose", {{"permutation", {3, 0, 1, 2}}}), l1); migraphx::make_op("transpose", {{"permutation", {3, 0, 1, 2}}}), l1);
auto tl2 = mm->add_instruction( auto tl2 = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {3, 1, 2, 0}}}), l2); migraphx::make_op("transpose", {{"permutation", {3, 1, 2, 0}}}), l2);
mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 3}}), tl1, tl2); mm->add_instruction(migraphx::make_op("quant_dot"), tl1, tl2);
return p; return p;
} }
}; };
...@@ -21,7 +21,7 @@ struct batch_quant_dot_5 : verify_program<batch_quant_dot_5> ...@@ -21,7 +21,7 @@ struct batch_quant_dot_5 : verify_program<batch_quant_dot_5>
auto tl2 = mm->add_instruction( auto tl2 = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2); migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2);
auto sl2 = mm->add_instruction(migraphx::make_op("add"), tl2, tl2); auto sl2 = mm->add_instruction(migraphx::make_op("add"), tl2, tl2);
mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}}), sl1, sl2); mm->add_instruction(migraphx::make_op("quant_dot"), sl1, sl2);
return p; return p;
} }
}; };
#include <migraphx/apply_alpha_beta.hpp>
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
...@@ -17,8 +18,7 @@ struct gemm_2args_vv : verify_program<gemm_2args_vv> ...@@ -17,8 +18,7 @@ struct gemm_2args_vv : verify_program<gemm_2args_vv>
auto l2 = mm->add_parameter("2", m2_shape); auto l2 = mm->add_parameter("2", m2_shape);
auto ul2 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l2); auto ul2 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l2);
float alpha = 0.23f; float alpha = 0.23f;
auto res = migraphx::add_apply_alpha_beta(*mm, {ul1, ul2}, migraphx::make_op("dot"), alpha);
auto res = mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}}), ul1, ul2);
auto sres = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), res); auto sres = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), res);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), sres); mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), sres);
......
#include <migraphx/apply_alpha_beta.hpp>
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
...@@ -19,9 +20,7 @@ struct gemm_multi_3args : verify_program<gemm_multi_3args> ...@@ -19,9 +20,7 @@ struct gemm_multi_3args : verify_program<gemm_multi_3args>
auto l3 = mm->add_parameter("3", m3_shape); auto l3 = mm->add_parameter("3", m3_shape);
float alpha = 0.35; float alpha = 0.35;
float beta = 0.41; float beta = 0.41;
mm->add_instruction( migraphx::add_apply_alpha_beta(*mm, {l1, l2, l3}, migraphx::make_op("dot"), alpha, beta);
migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3);
return p; return p;
} }
}; };
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/apply_alpha_beta.hpp>
struct gemm_multi_3args_alpha0 : verify_program<gemm_multi_3args_alpha0> struct gemm_multi_3args_alpha0 : verify_program<gemm_multi_3args_alpha0>
{ {
migraphx::program create_program() const migraphx::program create_program() const
...@@ -19,9 +19,7 @@ struct gemm_multi_3args_alpha0 : verify_program<gemm_multi_3args_alpha0> ...@@ -19,9 +19,7 @@ struct gemm_multi_3args_alpha0 : verify_program<gemm_multi_3args_alpha0>
float alpha = 0.0f; float alpha = 0.0f;
float beta = 1.0f; float beta = 1.0f;
mm->add_instruction( migraphx::add_apply_alpha_beta(*mm, {l1, l2, l3}, migraphx::make_op("dot"), alpha, beta);
migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3);
return p; return p;
} }
}; };
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment