Unverified Commit 027e1fa7 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Merge pull request #149 from ROCmSoftwarePlatform/gemm_beta

Gemm beta value parsing fix
parents a3bdb08f c33b8a63
...@@ -150,7 +150,7 @@ struct onnx_parser ...@@ -150,7 +150,7 @@ struct onnx_parser
if(s0->size() > s1->size()) if(s0->size() > s1->size())
std::swap(s0, s1); std::swap(s0, s1);
std::vector<std::size_t> output_lens(s1->size()); std::vector<std::size_t> output_lens(*s1);
auto offset = s1->size() - s0->size(); auto offset = s1->size() - s0->size();
std::transform(s0->begin(), std::transform(s0->begin(),
s0->end(), s0->end(),
...@@ -384,7 +384,7 @@ struct onnx_parser ...@@ -384,7 +384,7 @@ struct onnx_parser
parse_gemm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_gemm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{ {
float alpha = 1.0f; float alpha = 1.0f;
float beta = 0.0f; float beta = 1.0f;
bool transa = false; bool transa = false;
bool transb = false; bool transb = false;
if(contains(attributes, "alpha")) if(contains(attributes, "alpha"))
...@@ -408,10 +408,20 @@ struct onnx_parser ...@@ -408,10 +408,20 @@ struct onnx_parser
auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1]; auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1];
if(args.size() == 3) if(args.size() == 3)
{ {
uint64_t axis = 1; if(beta != 0.f)
auto l3 = prog.add_instruction(op::dot{alpha, beta}, l1, l2); {
auto l4 = prog.add_instruction(op::broadcast{axis, l3->get_shape()}, args[2]); auto l3 = prog.add_instruction(op::dot{alpha}, l1, l2);
return prog.add_instruction(op::add{}, l3, l4); auto l4 = args[2];
if(l4->get_shape().scalar()) // ignore args[2] (no C value added to alpha*A*B)
return l3;
if(beta != 1.f)
{
auto beta_val = prog.add_literal(beta);
auto l5 = prog.add_instruction(op::scalar{args[2]->get_shape()}, beta_val);
l4 = prog.add_instruction(op::mul{}, args[2], l5);
}
return add_broadcastable_binary_op(l3, l4, op::add{});
}
} }
return prog.add_instruction(op::dot{alpha, beta}, l1, l2); return prog.add_instruction(op::dot{alpha, beta}, l1, l2);
} }
......
...@@ -351,8 +351,8 @@ TEST_CASE(implicit_bcast_test) ...@@ -351,8 +351,8 @@ TEST_CASE(implicit_bcast_test)
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}}); auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}});
auto l2 = p.add_instruction(migraphx::op::multibroadcast{{0, 0, 4, 5}}, l0); auto l2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0);
auto l3 = p.add_instruction(migraphx::op::multibroadcast{{0, 0, 4, 5}}, l1); auto l3 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::add{}, l2, l3); p.add_instruction(migraphx::op::add{}, l2, l3);
auto prog = migraphx::parse_onnx("implicit_bcast_test.onnx"); auto prog = migraphx::parse_onnx("implicit_bcast_test.onnx");
...@@ -460,12 +460,11 @@ TEST_CASE(gemm_test) ...@@ -460,12 +460,11 @@ TEST_CASE(gemm_test)
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 7}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 7}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {11, 5}}); auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {11, 5}});
auto l2 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {}}); p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {}});
auto t0 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l0); auto t0 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l0);
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1); auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1);
auto d0 = p.add_instruction(migraphx::op::dot{2, 2}, t0, t1); auto alpha = 2.f;
auto b0 = p.add_instruction(migraphx::op::broadcast{1, d0->get_shape()}, l2); p.add_instruction(migraphx::op::dot{alpha}, t0, t1);
p.add_instruction(migraphx::op::add{}, d0, b0);
auto prog = migraphx::parse_onnx("gemm_test.onnx"); auto prog = migraphx::parse_onnx("gemm_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
...@@ -477,8 +476,8 @@ TEST_CASE(add_scalar_test) ...@@ -477,8 +476,8 @@ TEST_CASE(add_scalar_test)
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = auto l1 =
p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type, {1}}, {1}}); p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type, {1}}, {1}});
auto m0 = p.add_instruction(migraphx::op::multibroadcast{{0, 0, 0, 5}}, l0); auto m0 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0);
auto m1 = p.add_instruction(migraphx::op::multibroadcast{{0, 0, 0, 5}}, l1); auto m1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::add{}, m0, m1); p.add_instruction(migraphx::op::add{}, m0, m1);
auto prog = migraphx::parse_onnx("add_scalar_test.onnx"); auto prog = migraphx::parse_onnx("add_scalar_test.onnx");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment