"mmdet3d/vscode:/vscode.git/clone" did not exist on "9b77c74929c7c7983de9dcb491a4a26ea01cda66"
Commit 46f750ea authored by Scott Thornton's avatar Scott Thornton
Browse files

clang formatting

parent 4d555bcb
......@@ -379,33 +379,34 @@ struct broadcast
auto shape1_lens = shape1.lens();
auto shape0_strides = shape0.lens();
auto shape1_strides = shape1.lens();
if (std::all_of(shape0_lens.cbegin(),
shape1_lens.cend(),
[&](auto x) { return x == 1; }))
if(std::all_of(shape0_lens.cbegin(), shape1_lens.cend(), [&](auto x) { return x == 1; }))
{
if (axis != 0) RTG_THROW("when broadcasting tensor of size 1, axis should be 0");
if(axis != 0)
RTG_THROW("when broadcasting tensor of size 1, axis should be 0");
std::vector<size_t> bcast_shape_lens = shape0_lens;
std::vector<size_t> bcast_shape_strides(bcast_shape_lens.size(), 0);
return {t, bcast_shape_lens, bcast_shape_strides};
}
else
{
for (size_t i = 0; i < shape1_lens.size(); i++)
for(size_t i = 0; i < shape1_lens.size(); i++)
{
if (shape0_lens[i+axis] != shape1_lens[i]) RTG_THROW("when broadcasting success sizes must match");
if(shape0_lens[i + axis] != shape1_lens[i])
RTG_THROW("when broadcasting success sizes must match");
}
std::vector<size_t> bcast_shape_lens = shape0_lens;
std::vector<size_t> bcast_shape_strides(bcast_shape_lens.size(), 0);
for (size_t i = 0; i < shape1_strides.size(); i++) {
bcast_shape_strides[i+axis] = shape1_strides[i];
for(size_t i = 0; i < shape1_strides.size(); i++)
{
bcast_shape_strides[i + axis] = shape1_strides[i];
}
return {t, bcast_shape_lens, bcast_shape_strides};
}
}
// Materializes the broadcast "for free": the input's data buffer is
// forwarded unchanged and only reinterpreted under the broadcasted
// output shape (the zero strides computed above make the same elements
// visible along the broadcast dimensions without copying).
// NOTE(review): the diff view contained two interleaved copies of this
// method (pre- and post-clang-format); this is the single deduplicated
// post-commit version.
argument compute(shape output_shape, std::vector<argument> args) const
{
    return {output_shape, std::move(args.at(1).data)};
}
};
struct binary
......
......@@ -110,15 +110,17 @@ struct onnx_parser
return prog.add_literal(v);
});
// Parser hook for the ONNX "Add" node. With the legacy `broadcast`
// attribute set (non-zero), the second argument is broadcast against the
// first starting at `axis` (default 0) before the add; otherwise a plain
// element-wise add is emitted.
// NOTE(review): reconstructed from a diff view that interleaved the old
// and clang-formatted copies of these lines (including a duplicated
// closing brace); this is the deduplicated post-commit statement.
add_op("Add", [this](attribute_map attributes, std::vector<rtg::instruction_ref> args) {
    if(contains(attributes, "broadcast"))
    {
        uint64_t broadcast = parse_value(attributes.at("broadcast")).at<uint64_t>();
        if(broadcast != 0)
        {
            uint64_t axis = (contains(attributes, "axis"))
                                ? parse_value(attributes.at("axis")).at<uint64_t>()
                                : 0;
            auto l = prog.add_instruction(rtg::broadcast{axis}, args);
            return prog.add_instruction(rtg::add{}, args[0], l);
        }
    }
    return prog.add_instruction(rtg::add{}, args);
});
......
......@@ -291,47 +291,44 @@ struct add_with_broadcast
size_t ndims = output_shape.lens().size();
argument result{output_shape};
visit_all(result, args[0], args[1])([&](auto output, auto input0, auto input1) {
if (ndims == 0)
if(ndims == 0)
{
output(0) = input0(0) + input1(0);
}
if (ndims == 1)
if(ndims == 1)
{
for (size_t i = 0; i < output_shape.lens()[0]; i++)
for(size_t i = 0; i < output_shape.lens()[0]; i++)
{
output(i) = input0(i) + input1(i);
}
}
else if (ndims == 2)
else if(ndims == 2)
{
dfor(output_shape.lens()[0],
output_shape.lens()[1])(
[&](std::size_t i0, std::size_t i1) {
output(i0,i1) = input0(i0,i1) + input1(i0,i1);
output_shape.lens()[1])([&](std::size_t i0, std::size_t i1) {
output(i0, i1) = input0(i0, i1) + input1(i0, i1);
});
}
else if (ndims == 3)
else if(ndims == 3)
{
dfor(output_shape.lens()[0],
output_shape.lens()[1],
output_shape.lens()[2])(
dfor(output_shape.lens()[0], output_shape.lens()[1], output_shape.lens()[2])(
[&](std::size_t i0, std::size_t i1, std::size_t i2) {
output(i0,i1,i2) = input0(i0,i1,i2) + input1(i0,i1,i2);
});
output(i0, i1, i2) = input0(i0, i1, i2) + input1(i0, i1, i2);
});
}
else if (ndims == 4)
else if(ndims == 4)
{
dfor(output_shape.lens()[0],
output_shape.lens()[1],
output_shape.lens()[2],
output_shape.lens()[3])(
[&](std::size_t i0, std::size_t i1, std::size_t i2, std::size_t i3) {
output(i0,i1,i2,i3) = input0(i0,i1,i2,i3) + input1(i0,i1,i2,i3);
});
output(i0, i1, i2, i3) = input0(i0, i1, i2, i3) + input1(i0, i1, i2, i3);
});
}
else
{
RTG_THROW("current not support tensors with ndim > 4");
RTG_THROW("current not support tensors with ndim > 4");
}
});
return result;
......@@ -542,7 +539,7 @@ struct cpu_apply
// Lowers a generic `add` instruction to the CPU implementation that also
// understands broadcast arguments, replacing it in-place in the program.
void apply_add(instruction_ref ins)
{
    auto&& op = any_cast<add>(ins->op);
    // Previous non-broadcast lowering, kept for reference:
    // prog->replace_instruction(ins, cpu_binary<add_op>{}, ins->arguments);
    prog->replace_instruction(ins, add_with_broadcast{op}, ins->arguments);
}
......
......@@ -9,23 +9,27 @@
// Scratch/demo routine: derives the lens/strides vectors that a
// {4, 3} tensor acquires when broadcast into a {2, 4, 3, 4} tensor
// starting at axis 1, and prints both vectors to stdout.
// NOTE(review): reconstructed from a diff view that interleaved the old
// and clang-formatted copies of most lines; this is the deduplicated
// post-commit version. The unused local `shape0_strides` was dropped.
void fred()
{
    size_t axis = 1;
    rtg::shape shape0{rtg::shape::float_type, {2, 4, 3, 4}};
    rtg::shape shape1{rtg::shape::float_type, {4, 3}};
    std::vector<size_t> shape0_lens    = shape0.lens();
    std::vector<size_t> shape1_lens    = shape1.lens();
    std::vector<size_t> shape1_strides = shape1.strides();
    // The broadcast dims of shape1 must line up with shape0 at `axis`.
    for(size_t i = 0; i < shape1.lens().size(); i++)
    {
        assert(shape0_lens[i + axis] == shape1_lens[i]);
    }
    // Broadcast result: full output lens; stride 0 everywhere except the
    // dims supplied by shape1, so the broadcast elements repeat for free.
    std::vector<size_t> bcast_shape_lens = shape0_lens;
    std::vector<size_t> bcast_shape_strides(bcast_shape_lens.size(), 0);
    for(size_t i = 0; i < shape1_strides.size(); i++)
    {
        bcast_shape_strides[i + axis] = shape1_strides[i];
    }
    for(auto x : bcast_shape_lens)
        std::cout << x << " ";
    std::cout << "\n";
    for(auto x : bcast_shape_strides)
        std::cout << x << " ";
    std::cout << "\n";
}
......@@ -90,7 +94,7 @@ void add_test()
rtg::program p;
rtg::shape s{rtg::shape::float_type, {3}};
auto l1 = p.add_literal(rtg::literal{s, {-1, 0, 1}});
auto l2 = p.add_literal(rtg::literal{s, { 1, 2, 3}});
auto l2 = p.add_literal(rtg::literal{s, {1, 2, 3}});
p.add_instruction(rtg::add{}, l1, l2);
p.compile(rtg::cpu::cpu_target{});
auto result = p.eval({});
......@@ -105,7 +109,7 @@ void sub_test()
rtg::program p;
rtg::shape s{rtg::shape::float_type, {3}};
auto l1 = p.add_literal(rtg::literal{s, {-1, 0, 1}});
auto l2 = p.add_literal(rtg::literal{s, { 1, 2, 3}});
auto l2 = p.add_literal(rtg::literal{s, {1, 2, 3}});
p.add_instruction(rtg::sub{}, l1, l2);
p.compile(rtg::cpu::cpu_target{});
auto result = p.eval({});
......@@ -120,7 +124,7 @@ void mul_test()
rtg::program p;
rtg::shape s{rtg::shape::float_type, {3}};
auto l1 = p.add_literal(rtg::literal{s, {-1, 0, 1}});
auto l2 = p.add_literal(rtg::literal{s, { 1, 2, 3}});
auto l2 = p.add_literal(rtg::literal{s, {1, 2, 3}});
p.add_instruction(rtg::mul{}, l1, l2);
p.compile(rtg::cpu::cpu_target{});
auto result = p.eval({});
......@@ -135,7 +139,7 @@ void div_test()
rtg::program p;
rtg::shape s{rtg::shape::float_type, {3}};
auto l1 = p.add_literal(rtg::literal{s, {-1.0f, 0.5f, 1.0f}});
auto l2 = p.add_literal(rtg::literal{s, { 1.0f, 2.0f, 4.0f}});
auto l2 = p.add_literal(rtg::literal{s, {1.0f, 2.0f, 4.0f}});
p.add_instruction(rtg::div{}, l1, l2);
p.compile(rtg::cpu::cpu_target{});
auto result = p.eval({});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment