Commit 46f750ea authored by Scott Thornton's avatar Scott Thornton
Browse files

clang formatting

parent 4d555bcb
...@@ -379,25 +379,26 @@ struct broadcast ...@@ -379,25 +379,26 @@ struct broadcast
auto shape1_lens = shape1.lens(); auto shape1_lens = shape1.lens();
auto shape0_strides = shape0.lens(); auto shape0_strides = shape0.lens();
auto shape1_strides = shape1.lens(); auto shape1_strides = shape1.lens();
if (std::all_of(shape0_lens.cbegin(), if(std::all_of(shape0_lens.cbegin(), shape1_lens.cend(), [&](auto x) { return x == 1; }))
shape1_lens.cend(),
[&](auto x) { return x == 1; }))
{ {
if (axis != 0) RTG_THROW("when broadcasting tensor of size 1, axis should be 0"); if(axis != 0)
RTG_THROW("when broadcasting tensor of size 1, axis should be 0");
std::vector<size_t> bcast_shape_lens = shape0_lens; std::vector<size_t> bcast_shape_lens = shape0_lens;
std::vector<size_t> bcast_shape_strides(bcast_shape_lens.size(), 0); std::vector<size_t> bcast_shape_strides(bcast_shape_lens.size(), 0);
return {t, bcast_shape_lens, bcast_shape_strides}; return {t, bcast_shape_lens, bcast_shape_strides};
} }
else else
{ {
for (size_t i = 0; i < shape1_lens.size(); i++) for(size_t i = 0; i < shape1_lens.size(); i++)
{ {
if (shape0_lens[i+axis] != shape1_lens[i]) RTG_THROW("when broadcasting success sizes must match"); if(shape0_lens[i + axis] != shape1_lens[i])
RTG_THROW("when broadcasting success sizes must match");
} }
std::vector<size_t> bcast_shape_lens = shape0_lens; std::vector<size_t> bcast_shape_lens = shape0_lens;
std::vector<size_t> bcast_shape_strides(bcast_shape_lens.size(), 0); std::vector<size_t> bcast_shape_strides(bcast_shape_lens.size(), 0);
for (size_t i = 0; i < shape1_strides.size(); i++) { for(size_t i = 0; i < shape1_strides.size(); i++)
bcast_shape_strides[i+axis] = shape1_strides[i]; {
bcast_shape_strides[i + axis] = shape1_strides[i];
} }
return {t, bcast_shape_lens, bcast_shape_strides}; return {t, bcast_shape_lens, bcast_shape_strides};
} }
......
...@@ -110,12 +110,14 @@ struct onnx_parser ...@@ -110,12 +110,14 @@ struct onnx_parser
return prog.add_literal(v); return prog.add_literal(v);
}); });
add_op("Add", [this](attribute_map attributes, std::vector<rtg::instruction_ref> args) { add_op("Add", [this](attribute_map attributes, std::vector<rtg::instruction_ref> args) {
if (contains(attributes, "broadcast")) if(contains(attributes, "broadcast"))
{ {
uint64_t broadcast = parse_value(attributes.at("broadcast")).at<uint64_t>(); uint64_t broadcast = parse_value(attributes.at("broadcast")).at<uint64_t>();
if (broadcast != 0) { if(broadcast != 0)
uint64_t axis = (contains(attributes, "axis")) ? {
parse_value(attributes.at("axis")).at<uint64_t>() : 0; uint64_t axis = (contains(attributes, "axis"))
? parse_value(attributes.at("axis")).at<uint64_t>()
: 0;
auto l = prog.add_instruction(rtg::broadcast{axis}, args); auto l = prog.add_instruction(rtg::broadcast{axis}, args);
return prog.add_instruction(rtg::add{}, args[0], l); return prog.add_instruction(rtg::add{}, args[0], l);
} }
......
...@@ -291,42 +291,39 @@ struct add_with_broadcast ...@@ -291,42 +291,39 @@ struct add_with_broadcast
size_t ndims = output_shape.lens().size(); size_t ndims = output_shape.lens().size();
argument result{output_shape}; argument result{output_shape};
visit_all(result, args[0], args[1])([&](auto output, auto input0, auto input1) { visit_all(result, args[0], args[1])([&](auto output, auto input0, auto input1) {
if (ndims == 0) if(ndims == 0)
{ {
output(0) = input0(0) + input1(0); output(0) = input0(0) + input1(0);
} }
if (ndims == 1) if(ndims == 1)
{ {
for (size_t i = 0; i < output_shape.lens()[0]; i++) for(size_t i = 0; i < output_shape.lens()[0]; i++)
{ {
output(i) = input0(i) + input1(i); output(i) = input0(i) + input1(i);
} }
} }
else if (ndims == 2) else if(ndims == 2)
{ {
dfor(output_shape.lens()[0], dfor(output_shape.lens()[0],
output_shape.lens()[1])( output_shape.lens()[1])([&](std::size_t i0, std::size_t i1) {
[&](std::size_t i0, std::size_t i1) { output(i0, i1) = input0(i0, i1) + input1(i0, i1);
output(i0,i1) = input0(i0,i1) + input1(i0,i1);
}); });
} }
else if (ndims == 3) else if(ndims == 3)
{ {
dfor(output_shape.lens()[0], dfor(output_shape.lens()[0], output_shape.lens()[1], output_shape.lens()[2])(
output_shape.lens()[1],
output_shape.lens()[2])(
[&](std::size_t i0, std::size_t i1, std::size_t i2) { [&](std::size_t i0, std::size_t i1, std::size_t i2) {
output(i0,i1,i2) = input0(i0,i1,i2) + input1(i0,i1,i2); output(i0, i1, i2) = input0(i0, i1, i2) + input1(i0, i1, i2);
}); });
} }
else if (ndims == 4) else if(ndims == 4)
{ {
dfor(output_shape.lens()[0], dfor(output_shape.lens()[0],
output_shape.lens()[1], output_shape.lens()[1],
output_shape.lens()[2], output_shape.lens()[2],
output_shape.lens()[3])( output_shape.lens()[3])(
[&](std::size_t i0, std::size_t i1, std::size_t i2, std::size_t i3) { [&](std::size_t i0, std::size_t i1, std::size_t i2, std::size_t i3) {
output(i0,i1,i2,i3) = input0(i0,i1,i2,i3) + input1(i0,i1,i2,i3); output(i0, i1, i2, i3) = input0(i0, i1, i2, i3) + input1(i0, i1, i2, i3);
}); });
} }
else else
...@@ -542,7 +539,7 @@ struct cpu_apply ...@@ -542,7 +539,7 @@ struct cpu_apply
void apply_add(instruction_ref ins) void apply_add(instruction_ref ins)
{ {
auto&& op = any_cast<add>(ins->op); auto&& op = any_cast<add>(ins->op);
//prog->replace_instruction(ins, cpu_binary<add_op>{}, ins->arguments); // prog->replace_instruction(ins, cpu_binary<add_op>{}, ins->arguments);
prog->replace_instruction(ins, add_with_broadcast{op}, ins->arguments); prog->replace_instruction(ins, add_with_broadcast{op}, ins->arguments);
} }
......
...@@ -9,23 +9,27 @@ ...@@ -9,23 +9,27 @@
void fred() void fred()
{ {
size_t axis = 1; size_t axis = 1;
rtg::shape shape0{rtg::shape::float_type, {2,4,3,4}}; rtg::shape shape0{rtg::shape::float_type, {2, 4, 3, 4}};
rtg::shape shape1{rtg::shape::float_type, {4,3}}; rtg::shape shape1{rtg::shape::float_type, {4, 3}};
std::vector<size_t> shape0_lens = shape0.lens(); std::vector<size_t> shape0_lens = shape0.lens();
std::vector<size_t> shape1_lens = shape1.lens(); std::vector<size_t> shape1_lens = shape1.lens();
std::vector<size_t> shape0_strides = shape0.strides(); std::vector<size_t> shape0_strides = shape0.strides();
std::vector<size_t> shape1_strides = shape1.strides(); std::vector<size_t> shape1_strides = shape1.strides();
for (size_t i = 0; i < shape1.lens().size(); i++) { for(size_t i = 0; i < shape1.lens().size(); i++)
assert(shape0_lens[i+axis] == shape1_lens[i]); {
assert(shape0_lens[i + axis] == shape1_lens[i]);
} }
std::vector<size_t> bcast_shape_lens = shape0_lens; std::vector<size_t> bcast_shape_lens = shape0_lens;
std::vector<size_t> bcast_shape_strides(bcast_shape_lens.size(), 0); std::vector<size_t> bcast_shape_strides(bcast_shape_lens.size(), 0);
for (size_t i = 0; i < shape1_strides.size(); i++) { for(size_t i = 0; i < shape1_strides.size(); i++)
bcast_shape_strides[i+axis] = shape1_strides[i]; {
bcast_shape_strides[i + axis] = shape1_strides[i];
} }
for (auto x : bcast_shape_lens) std::cout << x << " "; for(auto x : bcast_shape_lens)
std::cout << x << " ";
std::cout << "\n"; std::cout << "\n";
for (auto x : bcast_shape_strides) std::cout << x << " "; for(auto x : bcast_shape_strides)
std::cout << x << " ";
std::cout << "\n"; std::cout << "\n";
} }
...@@ -90,7 +94,7 @@ void add_test() ...@@ -90,7 +94,7 @@ void add_test()
rtg::program p; rtg::program p;
rtg::shape s{rtg::shape::float_type, {3}}; rtg::shape s{rtg::shape::float_type, {3}};
auto l1 = p.add_literal(rtg::literal{s, {-1, 0, 1}}); auto l1 = p.add_literal(rtg::literal{s, {-1, 0, 1}});
auto l2 = p.add_literal(rtg::literal{s, { 1, 2, 3}}); auto l2 = p.add_literal(rtg::literal{s, {1, 2, 3}});
p.add_instruction(rtg::add{}, l1, l2); p.add_instruction(rtg::add{}, l1, l2);
p.compile(rtg::cpu::cpu_target{}); p.compile(rtg::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
...@@ -105,7 +109,7 @@ void sub_test() ...@@ -105,7 +109,7 @@ void sub_test()
rtg::program p; rtg::program p;
rtg::shape s{rtg::shape::float_type, {3}}; rtg::shape s{rtg::shape::float_type, {3}};
auto l1 = p.add_literal(rtg::literal{s, {-1, 0, 1}}); auto l1 = p.add_literal(rtg::literal{s, {-1, 0, 1}});
auto l2 = p.add_literal(rtg::literal{s, { 1, 2, 3}}); auto l2 = p.add_literal(rtg::literal{s, {1, 2, 3}});
p.add_instruction(rtg::sub{}, l1, l2); p.add_instruction(rtg::sub{}, l1, l2);
p.compile(rtg::cpu::cpu_target{}); p.compile(rtg::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
...@@ -120,7 +124,7 @@ void mul_test() ...@@ -120,7 +124,7 @@ void mul_test()
rtg::program p; rtg::program p;
rtg::shape s{rtg::shape::float_type, {3}}; rtg::shape s{rtg::shape::float_type, {3}};
auto l1 = p.add_literal(rtg::literal{s, {-1, 0, 1}}); auto l1 = p.add_literal(rtg::literal{s, {-1, 0, 1}});
auto l2 = p.add_literal(rtg::literal{s, { 1, 2, 3}}); auto l2 = p.add_literal(rtg::literal{s, {1, 2, 3}});
p.add_instruction(rtg::mul{}, l1, l2); p.add_instruction(rtg::mul{}, l1, l2);
p.compile(rtg::cpu::cpu_target{}); p.compile(rtg::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
...@@ -135,7 +139,7 @@ void div_test() ...@@ -135,7 +139,7 @@ void div_test()
rtg::program p; rtg::program p;
rtg::shape s{rtg::shape::float_type, {3}}; rtg::shape s{rtg::shape::float_type, {3}};
auto l1 = p.add_literal(rtg::literal{s, {-1.0f, 0.5f, 1.0f}}); auto l1 = p.add_literal(rtg::literal{s, {-1.0f, 0.5f, 1.0f}});
auto l2 = p.add_literal(rtg::literal{s, { 1.0f, 2.0f, 4.0f}}); auto l2 = p.add_literal(rtg::literal{s, {1.0f, 2.0f, 4.0f}});
p.add_instruction(rtg::div{}, l1, l2); p.add_instruction(rtg::div{}, l1, l2);
p.compile(rtg::cpu::cpu_target{}); p.compile(rtg::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment