Commit 44614513 authored by Khalique's avatar Khalique
Browse files

testing decomposition with beta != 0

parent 436b459e
...@@ -409,7 +409,7 @@ struct onnx_parser ...@@ -409,7 +409,7 @@ struct onnx_parser
} }
if(contains(attributes, "beta")) if(contains(attributes, "beta"))
{ {
alpha = parse_value(attributes.at("beta")).at<float>(); beta = parse_value(attributes.at("beta")).at<float>();
} }
if(contains(attributes, "transA")) if(contains(attributes, "transA"))
{ {
...@@ -424,10 +424,14 @@ struct onnx_parser ...@@ -424,10 +424,14 @@ struct onnx_parser
auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1]; auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1];
if(args.size() == 3) if(args.size() == 3)
{ {
uint64_t axis = 1; if (beta != 0.f)
auto l3 = prog.add_instruction(op::dot{alpha, beta}, l1, l2); {
auto l4 = prog.add_instruction(op::broadcast{axis, l3->get_shape()}, args[2]); auto l3 = prog.add_instruction(op::dot{alpha}, l1, l2);
return prog.add_instruction(op::add{}, l3, l4); auto beta_val = prog.add_literal(beta);
auto l4 = prog.add_instruction(op::scalar{args[2]->get_shape()}, beta_val);
auto l5 = prog.add_instruction(op::mul{}, args[2], l4);
return add_broadcastable_binary_op(l3, l5, op::add{});
}
} }
return prog.add_instruction(op::dot{alpha, beta}, l1, l2); return prog.add_instruction(op::dot{alpha, beta}, l1, l2);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment