"git@developer.sourcefind.cn:chenzk/alphafold2_jax.git" did not exist on "665ebc301340b34efb85c72503fbbe8315b3a0f4"
Commit f1ab5ed2 authored by Khalique's avatar Khalique
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into pad_op

parents 66fa0083 3e9358c2
...@@ -107,14 +107,12 @@ struct onnx_parser ...@@ -107,14 +107,12 @@ struct onnx_parser
ops.emplace(name, [this, x](attribute_map attributes, std::vector<instruction_ref> args) { ops.emplace(name, [this, x](attribute_map attributes, std::vector<instruction_ref> args) {
if(args.size() != 2) if(args.size() != 2)
MIGRAPHX_THROW("binary operators should have 2 operands"); MIGRAPHX_THROW("binary operators should have 2 operands");
if(contains(attributes, "broadcast")) if(contains(attributes, "broadcast") and contains(attributes, "axis"))
{ {
uint64_t broadcasted = parse_value(attributes.at("broadcast")).at<uint64_t>(); uint64_t broadcasted = parse_value(attributes.at("broadcast")).at<uint64_t>();
if(broadcasted != 0) if(broadcasted != 0)
{ {
uint64_t axis = (contains(attributes, "axis")) uint64_t axis = parse_value(attributes.at("axis")).at<uint64_t>();
? parse_value(attributes.at("axis")).at<uint64_t>()
: 0;
auto l = auto l =
prog.add_instruction(op::broadcast{axis, args[0]->get_shape()}, args[1]); prog.add_instruction(op::broadcast{axis, args[0]->get_shape()}, args[1]);
return prog.add_instruction(x, args[0], l); return prog.add_instruction(x, args[0], l);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment