Commit 129eda46 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent a03db463
...@@ -43,9 +43,8 @@ struct broadcast ...@@ -43,9 +43,8 @@ struct broadcast
std::vector<size_t> bcast_strides(broadcast_lens.size(), 0); std::vector<size_t> bcast_strides(broadcast_lens.size(), 0);
if(std::all_of(broadcast_lens.cbegin(), broadcast_lens.cend(), [&](auto x) { if(std::all_of(
return x == 1; broadcast_lens.cbegin(), broadcast_lens.cend(), [&](auto x) { return x == 1; }))
}))
{ {
if(axis != 0) if(axis != 0)
MIGRAPHX_THROW("when broadcasting tensor of size 1, axis should be 0"); MIGRAPHX_THROW("when broadcasting tensor of size 1, axis should be 0");
...@@ -54,8 +53,7 @@ struct broadcast ...@@ -54,8 +53,7 @@ struct broadcast
else else
{ {
assert(broadcast_lens.size() - axis >= input.lens().size()); assert(broadcast_lens.size() - axis >= input.lens().size());
if(!std::equal( if(!std::equal(input.lens().begin(), input.lens().end(), broadcast_lens.begin() + axis))
input.lens().begin(), input.lens().end(), broadcast_lens.begin() + axis))
MIGRAPHX_THROW("when broadcasting success sizes must match"); MIGRAPHX_THROW("when broadcasting success sizes must match");
std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin() + axis); std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin() + axis);
return {t, broadcast_lens, std::move(bcast_strides)}; return {t, broadcast_lens, std::move(bcast_strides)};
......
...@@ -141,8 +141,8 @@ struct onnx_parser ...@@ -141,8 +141,8 @@ struct onnx_parser
if(broadcasted != 0) if(broadcasted != 0)
{ {
uint64_t axis = parse_value(attributes.at("axis")).at<uint64_t>(); uint64_t axis = parse_value(attributes.at("axis")).at<uint64_t>();
auto l = auto l = prog.add_instruction(op::broadcast{axis, args[0]->get_shape().lens()},
prog.add_instruction(op::broadcast{axis, args[0]->get_shape().lens()}, args[1]); args[1]);
return prog.add_instruction(x, args[0], l); return prog.add_instruction(x, args[0], l);
} }
return prog.add_instruction(x, args); return prog.add_instruction(x, args);
...@@ -679,7 +679,8 @@ struct onnx_parser ...@@ -679,7 +679,8 @@ struct onnx_parser
auto scale_tensor = prog.add_instruction(migraphx::op::scalar{input_shape}, scale_val); auto scale_tensor = prog.add_instruction(migraphx::op::scalar{input_shape}, scale_val);
auto img_scaled = prog.add_instruction(migraphx::op::mul{}, args.front(), scale_tensor); auto img_scaled = prog.add_instruction(migraphx::op::mul{}, args.front(), scale_tensor);
auto bias_bcast = prog.add_instruction(migraphx::op::broadcast{1, input_shape.lens()}, bias_vals); auto bias_bcast =
prog.add_instruction(migraphx::op::broadcast{1, input_shape.lens()}, bias_vals);
return prog.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast); return prog.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment