Commit e77913c5 authored by Paul
Browse files

Use shape_for_each

parent 2da3e1d0
......@@ -3,6 +3,7 @@
#include <rtg/instruction.hpp>
#include <rtg/dfor.hpp>
#include <rtg/operators.hpp>
#include <rtg/shape_for_each.hpp>
namespace rtg {
namespace cpu {
......@@ -434,60 +435,6 @@ struct softmax2d
}
};
/// CPU operator that adds two arguments element by element, addressing both
/// inputs with the multi-dimensional coordinates of the output.
///
/// Supports outputs of rank 0 through 4; any higher rank raises via RTG_THROW.
/// NOTE(review): the name suggests the inputs may be broadcast views whose
/// strides map several output coordinates onto one element -- confirm with
/// the callers that build the arguments.
struct add_with_broadcast
{
    /// The wrapped `add` operator; shape computation is delegated to it.
    add op;

    /// Name under which this operator is reported.
    std::string name() const { return "add_with_broadcast"; }

    /// Forwards to the wrapped operator's compute_shape.
    shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); }

    /// Computes `args[0] + args[1]` into a freshly constructed argument of
    /// shape `output_shape`.
    ///
    /// @param output_shape shape of the result; its lens() drive the iteration.
    /// @param args         operands; only args[0] and args[1] are read.
    /// @return the argument holding the element-wise sums.
    argument compute(shape output_shape, std::vector<argument> args) const
    {
        const auto& lens        = output_shape.lens();
        const std::size_t ndims = lens.size();
        argument result{output_shape};
        visit_all(result, args[0], args[1])([&](auto output, auto input0, auto input1) {
            // A single switch guarantees exactly one branch runs. The previous
            // `if(ndims == 0)` was not chained to the rest, so a rank-0 tensor
            // was computed and then still reached the throwing fallback.
            switch(ndims)
            {
            case 0: output(0) = input0(0) + input1(0); break;
            case 1:
                for(std::size_t i = 0; i < lens[0]; i++)
                {
                    output(i) = input0(i) + input1(i);
                }
                break;
            case 2:
                dfor(lens[0], lens[1])([&](std::size_t i0, std::size_t i1) {
                    output(i0, i1) = input0(i0, i1) + input1(i0, i1);
                });
                break;
            case 3:
                dfor(lens[0], lens[1], lens[2])(
                    [&](std::size_t i0, std::size_t i1, std::size_t i2) {
                        output(i0, i1, i2) = input0(i0, i1, i2) + input1(i0, i1, i2);
                    });
                break;
            case 4:
                dfor(lens[0], lens[1], lens[2], lens[3])(
                    [&](std::size_t i0, std::size_t i1, std::size_t i2, std::size_t i3) {
                        output(i0, i1, i2, i3) =
                            input0(i0, i1, i2, i3) + input1(i0, i1, i2, i3);
                    });
                break;
            default: RTG_THROW("current not support tensors with ndim > 4");
            }
        });
        return result;
    }
};
struct add_op
{
std::string name() const { return "add"; }
......@@ -534,7 +481,13 @@ struct cpu_binary
{
argument result{output_shape};
visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
std::transform(input1.begin(), input1.end(), input2.begin(), output.begin(), op.fcn());
if(input1.get_shape().packed() and input2.get_shape().packed()) {
std::transform(input1.begin(), input1.end(), input2.begin(), output.begin(), op.fcn());
} else {
shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) = op.fcn()(input1(idx.begin(), idx.end()), input2(idx.begin(), idx.end()));
});
}
});
return result;
}
......@@ -573,6 +526,7 @@ struct cpu_apply
apply_map["sin"] = simple_op<cpu_unary<sin_op>>();
apply_map["cos"] = simple_op<cpu_unary<cos_op>>();
apply_map["tan"] = simple_op<cpu_unary<tan_op>>();
apply_map["add"] = simple_op<cpu_binary<add_op>>();
apply_map["sub"] = simple_op<cpu_binary<sub_op>>();
apply_map["mul"] = simple_op<cpu_binary<mul_op>>();
apply_map["div"] = simple_op<cpu_binary<div_op>>();
......@@ -593,10 +547,6 @@ struct cpu_apply
{
apply_pooling(it);
}
else if(it->op.name() == "add")
{
apply_add(it);
}
else if(apply_map.count(it->op.name()) > 0)
{
apply_map.at(it->op.name())(it);
......@@ -632,13 +582,6 @@ struct cpu_apply
else if(op.mode == "average")
prog->replace_instruction(ins, cpu_pooling<avg_pool>{op}, ins->arguments);
}
// Swaps an instruction for the CPU add_with_broadcast operator, reusing the
// instruction's own `add` operator and its existing arguments.
// (A cpu_binary<add_op> replacement was previously tried here and left disabled.)
void apply_add(instruction_ref ins)
{
    auto&& add_operator = any_cast<add>(ins->op);
    prog->replace_instruction(ins, add_with_broadcast{add_operator}, ins->arguments);
}
};
// Identifier string for this target.
std::string cpu_target::name() const
{
    return "cpu";
}
......
......@@ -91,13 +91,6 @@ void broadcast_test()
p.compile(rtg::cpu::cpu_target{});
auto result = p.eval({});
std::vector<int32_t> results_vector(4);
// result.visit([&](auto output) {
// EXPECT(output(0,0) == -2);
// EXPECT(output(0,1) == -2);
// EXPECT(output(1,0) == -3);
// EXPECT(output(1,1) == -3);
// });
}
void add_broadcast_test()
{
......@@ -113,6 +106,7 @@ void add_broadcast_test()
p.add_instruction(rtg::add{}, l1, l3);
p.compile(rtg::cpu::cpu_target{});
auto result = p.eval({});
EXPECT(result.get_shape().packed());
std::vector<float> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 1, 2, 2, 3, 4, 4, 5, 6, 6, 7, 8};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment