Commit 8b023ae3 authored by Khalique's avatar Khalique
Browse files

fix beta val

parent ef0fe6e6
...@@ -400,7 +400,7 @@ struct onnx_parser ...@@ -400,7 +400,7 @@ struct onnx_parser
parse_gemm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_gemm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{ {
float alpha = 1.0f; float alpha = 1.0f;
float beta = 0.0f; float beta = 1.0f;
bool transa = false; bool transa = false;
bool transb = false; bool transb = false;
if(contains(attributes, "alpha")) if(contains(attributes, "alpha"))
...@@ -427,10 +427,14 @@ struct onnx_parser ...@@ -427,10 +427,14 @@ struct onnx_parser
if(beta != 0.f) if(beta != 0.f)
{ {
auto l3 = prog.add_instruction(op::dot{alpha}, l1, l2); auto l3 = prog.add_instruction(op::dot{alpha}, l1, l2);
auto beta_val = prog.add_literal(beta); auto l4 = args[2];
auto l4 = prog.add_instruction(op::scalar{args[2]->get_shape()}, beta_val); if(beta == 1.f)
auto l5 = prog.add_instruction(op::mul{}, args[2], l4); {
return add_broadcastable_binary_op(l3, l5, op::add{}); auto beta_val = prog.add_literal(beta);
auto l5 = prog.add_instruction(op::scalar{args[2]->get_shape()}, beta_val);
l4 = prog.add_instruction(op::mul{}, args[2], l5);
}
return add_broadcastable_binary_op(l3, l4, op::add{});
} }
} }
return prog.add_instruction(op::dot{alpha, beta}, l1, l2); return prog.add_instruction(op::dot{alpha, beta}, l1, l2);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment