Unverified Commit eacf042e authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Remove Alpha Beta from onnx gemm parsing (#874)

* gemm_test_workign

clang_formatting

tests passing

clang formatting

look for beta not equal to one

* make_use of broadcastable_binary_op

clang formatting

* make use of common_op

clang formatting

* move transposes after multiplication

clang formatting

fix transpose

formatting

fix cpp check

foramtting

* fix parsing conditions and ci fails
parent 3282e01a
......@@ -42,13 +42,23 @@ struct parse_gemm : op_parser<parse_gemm>
// swap the last two elements
std::swap(*perm.rbegin(), *(perm.rbegin() + 1));
auto l1 = (transa) ? info.add_instruction(make_op("transpose", {{"dims", perm}}), args[0])
: args[0];
auto l1 = args[0];
if(alpha != 1.0f)
{
auto alpha_literal = info.add_literal(alpha);
auto alpha_l1 = info.add_broadcastable_binary_op("mul", alpha_literal, l1);
l1 = info.add_instruction(make_op("convert", {{"target_type", l1->get_shape().type()}}),
alpha_l1);
}
l1 = (transa) ? info.add_instruction(make_op("transpose", {{"dims", perm}}), l1) : l1;
auto l2 = (transb) ? info.add_instruction(make_op("transpose", {{"dims", perm}}), args[1])
: args[1];
if(args.size() == 3)
{
if(beta != 0.f && args[2]->get_shape().elements() > 0)
if(beta != 0.0f && args[2]->get_shape().elements() > 0)
{
auto out_lens = l1->get_shape().lens();
out_lens.back() = l2->get_shape().lens().back();
......@@ -59,12 +69,17 @@ struct parse_gemm : op_parser<parse_gemm>
l3 = info.add_instruction(
make_op("multibroadcast", {{"output_lens", out_lens}}), args[2]);
}
auto beta_literal = info.add_literal(beta);
auto beta_broadcast = info.add_instruction(
make_op("multibroadcast", {{"output_lens", out_lens}}), beta_literal);
l3 = info.add_instruction(make_op("mul"), l3, beta_broadcast);
return info.add_instruction(
make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3);
make_op("dot", {{"alpha", 1.0f}, {"beta", 1.0f}}), l1, l2, l3);
}
}
return info.add_instruction(make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2);
return info.add_instruction(make_op("dot", {{"alpha", 1.0f}, {"beta", 1.0f}}), l1, l2);
}
};
......
#include <iostream>
#include <fstream>
#include <vector>
#include <migraphx/common.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
......@@ -1268,19 +1269,28 @@ TEST_CASE(gather_elements_axis1_test)
TEST_CASE(gemm_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 7}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {11, 5}});
auto l2 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type});
auto t0 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l0);
auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l1);
auto bl2 =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {7, 11}}}), l2);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 7}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {11, 5}});
auto l2 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type});
auto alpha = 2.f;
auto beta = 2.0f;
mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), t0, t1, bl2);
auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", l1->get_shape().type()}}), t_a);
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), t_a);
auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l1);
auto b_l = mm->add_literal(beta);
auto l2_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {7, 11}}}), l2);
auto b_b = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", l2_b->get_shape().lens()}}), b_l);
auto l2_bb = mm->add_instruction(migraphx::make_op("mul"), l2_b, b_b);
mm->add_instruction(
migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 1.0f}}), t_a, t1, l2_bb);
auto prog = optimize_onnx("gemm_test.onnx");
EXPECT(p == prog);
}
......@@ -1291,10 +1301,21 @@ TEST_CASE(gemm_ex_test)
auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 8, 6}});
auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 1, 8, 7}});
auto l2 = mm->add_parameter("3", migraphx::shape{migraphx::shape::float_type, {1, 1, 6, 7}});
auto t0 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l0);
auto alpha = 0.5f;
auto beta = 0.8f;
mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), t0, l1, l2);
auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", l1->get_shape().type()}}), t_a);
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), t_a);
auto b_l = mm->add_literal(beta);
auto b_b = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", l2->get_shape().lens()}}), b_l);
auto l2_b = mm->add_instruction(migraphx::make_op("mul"), l2, b_b);
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 1.0f}}), t_a, l1, l2_b);
auto prog = optimize_onnx("gemm_ex_test.onnx");
EXPECT(p == prog);
......@@ -1307,13 +1328,25 @@ TEST_CASE(gemm_ex_brcst_test)
auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 6}});
auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 7}});
auto l2 = mm->add_parameter("3", migraphx::shape{migraphx::shape::float_type, {1, 1, 6, 1}});
auto t0 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l0);
std::vector<std::size_t> out_lens{1, 1, 6, 7};
auto t2 =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", out_lens}}), l2);
auto alpha = 0.5f;
auto beta = 0.8f;
mm->add_instruction(migraphx::make_op("dot", {{"alpha", alpha}, {"beta", beta}}), t0, l1, t2);
auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", l1->get_shape().type()}}), t_a);
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), t_a);
auto b_l = mm->add_literal(beta);
auto l2_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", out_lens}}), l2);
auto b_b = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", l2_b->get_shape().lens()}}), b_l);
auto l2_bb = mm->add_instruction(migraphx::make_op("mul"), l2_b, b_b);
mm->add_instruction(
migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 1.0f}}), t_a, l1, l2_bb);
auto prog = optimize_onnx("gemm_ex_brcst_test.onnx");
EXPECT(p == prog);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment