Commit 4d555bcb authored by Scott Thornton

Added broadcast operator for add ... still need to add test

parent 9af6974d
@@ -366,14 +366,57 @@ struct flatten
std::string name() const { return "flatten"; }
};
struct broadcast
{
uint64_t axis = 0;
std::string name() const { return "broadcast"; }
shape compute_shape(std::vector<shape> inputs) const
{
auto t = inputs.at(0).type();
auto shape0 = inputs.at(0);
auto shape1 = inputs.at(1);
auto shape0_lens = shape0.lens();
auto shape1_lens = shape1.lens();
        auto shape0_strides = shape0.strides();
        auto shape1_strides = shape1.strides();
        if (std::all_of(shape1_lens.cbegin(),
                        shape1_lens.cend(),
                        [](auto x) { return x == 1; }))
{
if (axis != 0) RTG_THROW("when broadcasting tensor of size 1, axis should be 0");
std::vector<size_t> bcast_shape_lens = shape0_lens;
std::vector<size_t> bcast_shape_strides(bcast_shape_lens.size(), 0);
return {t, bcast_shape_lens, bcast_shape_strides};
}
else
{
for (size_t i = 0; i < shape1_lens.size(); i++)
{
            if (shape0_lens[i+axis] != shape1_lens[i])
                RTG_THROW("when broadcasting, succeeding sizes must match");
}
std::vector<size_t> bcast_shape_lens = shape0_lens;
std::vector<size_t> bcast_shape_strides(bcast_shape_lens.size(), 0);
for (size_t i = 0; i < shape1_strides.size(); i++) {
bcast_shape_strides[i+axis] = shape1_strides[i];
}
return {t, bcast_shape_lens, bcast_shape_strides};
}
}
argument compute(shape output_shape, std::vector<argument> args) const
{
return {output_shape, std::move(args.at(1).data)};
}
};
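// Worked example (hypothetical shapes, not taken from the code below):
// broadcasting shape1 = {4,3} (strides {3,1}) into shape0 = {2,4,3,4} at
// axis = 1 yields lens {2,4,3,4} with strides {0,3,1,0}. The zero strides
// map every index along axes 0 and 3 onto the same underlying element,
// which is why compute() can simply reuse args.at(1)'s buffer.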
struct binary
{
uint64_t broadcast = 0;
shape compute_shape(std::vector<shape> inputs) const
{
// TODO(wsttiger@gmail.com) Check this for numpy-style broadcasting operations
check_shapes{inputs}.has(2).same_type().same_dims();
return inputs.at(0);
}
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
};
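// Note: same_dims() here presumably compares lens, not strides, so an add
// whose second input is the zero-stride shape produced by broadcast above
// still passes this check; the Add-with-broadcast rewrite in the ONNX
// parser below relies on that.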
struct add : binary
...
@@ -76,6 +76,9 @@ struct onnx_parser
}
return prog.add_instruction(op, args);
});
add_op("MatMul", [this](attribute_map, std::vector<rtg::instruction_ref> args) {
return prog.add_instruction(rtg::gemm{}, args);
});
add_op("MaxPool", [this](attribute_map attributes, std::vector<rtg::instruction_ref> args) {
rtg::pooling op{"max"};
// for(auto&& p:attributes) std::cout << p.first << std::endl;
@@ -106,6 +109,28 @@ struct onnx_parser
rtg::literal v = parse_value(attributes.at("value"));
return prog.add_literal(v);
});
add_op("Add", [this](attribute_map attributes, std::vector<rtg::instruction_ref> args) {
if (contains(attributes, "broadcast"))
{
uint64_t broadcast = parse_value(attributes.at("broadcast")).at<uint64_t>();
if (broadcast != 0) {
uint64_t axis = (contains(attributes, "axis")) ?
parse_value(attributes.at("axis")).at<uint64_t>() : 0;
auto l = prog.add_instruction(rtg::broadcast{axis}, args);
return prog.add_instruction(rtg::add{}, args[0], l);
}
}
return prog.add_instruction(rtg::add{}, args);
});
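        // Sketch of the resulting lowering (hypothetical shapes, not from a
        // real model): an ONNX Add with broadcast=1 and axis=1 on inputs of
        // lens {2,4,3,4} and {4,3} becomes
        //   l = broadcast{1}(arg0, arg1)  // zero-stride view of arg1
        //   y = add{}(arg0, l)
        // while broadcast=0 (or a missing attribute) emits a plain add.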
add_op("Sub", [this](attribute_map, std::vector<rtg::instruction_ref> args) {
return prog.add_instruction(rtg::sub{}, args);
});
add_op("Mul", [this](attribute_map, std::vector<rtg::instruction_ref> args) {
return prog.add_instruction(rtg::mul{}, args);
});
add_op("Div", [this](attribute_map, std::vector<rtg::instruction_ref> args) {
return prog.add_instruction(rtg::div{}, args);
});
}
template <class F>
...
@@ -281,6 +281,63 @@ struct softmax2d
}
};
struct add_with_broadcast
{
add op;
std::string name() const { return "add_with_broadcast"; }
shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); }
argument compute(shape output_shape, std::vector<argument> args) const
{
size_t ndims = output_shape.lens().size();
argument result{output_shape};
visit_all(result, args[0], args[1])([&](auto output, auto input0, auto input1) {
            if (ndims == 0)
            {
                output(0) = input0(0) + input1(0);
            }
            else if (ndims == 1)
{
for (size_t i = 0; i < output_shape.lens()[0]; i++)
{
output(i) = input0(i) + input1(i);
}
}
else if (ndims == 2)
{
dfor(output_shape.lens()[0],
output_shape.lens()[1])(
[&](std::size_t i0, std::size_t i1) {
output(i0,i1) = input0(i0,i1) + input1(i0,i1);
});
}
else if (ndims == 3)
{
dfor(output_shape.lens()[0],
output_shape.lens()[1],
output_shape.lens()[2])(
[&](std::size_t i0, std::size_t i1, std::size_t i2) {
output(i0,i1,i2) = input0(i0,i1,i2) + input1(i0,i1,i2);
});
}
else if (ndims == 4)
{
dfor(output_shape.lens()[0],
output_shape.lens()[1],
output_shape.lens()[2],
output_shape.lens()[3])(
[&](std::size_t i0, std::size_t i1, std::size_t i2, std::size_t i3) {
output(i0,i1,i2,i3) = input0(i0,i1,i2,i3) + input1(i0,i1,i2,i3);
});
}
else
{
RTG_THROW("current not support tensors with ndim > 4");
}
});
return result;
}
};
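// The rank dispatch above handles broadcast inputs with no explicit index
// remapping because each tensor view indexes through its own strides: for a
// broadcast argument with strides {0,3,1,0}, input1(i0,i1,i2,i3) re-reads
// the same element along the zero-stride axes.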
struct add_op
{
std::string name() const { return "add"; }
@@ -393,6 +450,22 @@ struct cpu_apply
{
apply_tan(it);
}
else if(it->op.name() == "add")
{
apply_add(it);
}
else if(it->op.name() == "sub")
{
apply_sub(it);
}
else if(it->op.name() == "mul")
{
apply_mul(it);
}
else if(it->op.name() == "div")
{
apply_div(it);
}
}
}
@@ -465,6 +538,28 @@ struct cpu_apply
{
prog->replace_instruction(ins, cpu_unary<tan_op>{}, ins->arguments);
}
void apply_add(instruction_ref ins)
{
auto&& op = any_cast<add>(ins->op);
        // lower through add_with_broadcast rather than cpu_binary<add_op> so
        // that zero-stride (broadcast) second inputs are indexed correctly
prog->replace_instruction(ins, add_with_broadcast{op}, ins->arguments);
}
void apply_sub(instruction_ref ins)
{
prog->replace_instruction(ins, cpu_binary<sub_op>{}, ins->arguments);
}
void apply_mul(instruction_ref ins)
{
prog->replace_instruction(ins, cpu_binary<mul_op>{}, ins->arguments);
}
void apply_div(instruction_ref ins)
{
prog->replace_instruction(ins, cpu_binary<div_op>{}, ins->arguments);
}
};
std::string cpu_target::name() const { return "cpu"; }
...
@@ -6,6 +6,29 @@
#include "test.hpp"
#include "verify.hpp"
// Manual sanity check of the broadcast shape/stride computation:
// broadcasting a {4,3} tensor into {2,4,3,4} at axis 1 should print
// lens "2 4 3 4" and strides "0 3 1 0".
void fred()
{
size_t axis = 1;
rtg::shape shape0{rtg::shape::float_type, {2,4,3,4}};
rtg::shape shape1{rtg::shape::float_type, {4,3}};
std::vector<size_t> shape0_lens = shape0.lens();
std::vector<size_t> shape1_lens = shape1.lens();
std::vector<size_t> shape0_strides = shape0.strides();
std::vector<size_t> shape1_strides = shape1.strides();
for (size_t i = 0; i < shape1.lens().size(); i++) {
assert(shape0_lens[i+axis] == shape1_lens[i]);
}
std::vector<size_t> bcast_shape_lens = shape0_lens;
std::vector<size_t> bcast_shape_strides(bcast_shape_lens.size(), 0);
for (size_t i = 0; i < shape1_strides.size(); i++) {
bcast_shape_strides[i+axis] = shape1_strides[i];
}
for (auto x : bcast_shape_lens) std::cout << x << " ";
std::cout << "\n";
for (auto x : bcast_shape_strides) std::cout << x << " ";
std::cout << "\n";
}
void exp_test()
{
rtg::program p;
@@ -62,6 +85,66 @@ void tan_test()
EXPECT(test::verify_range(results_vector, gold));
}
void add_test()
{
rtg::program p;
rtg::shape s{rtg::shape::float_type, {3}};
auto l1 = p.add_literal(rtg::literal{s, {-1, 0, 1}});
auto l2 = p.add_literal(rtg::literal{s, { 1, 2, 3}});
p.add_instruction(rtg::add{}, l1, l2);
p.compile(rtg::cpu::cpu_target{});
auto result = p.eval({});
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 2, 4};
EXPECT(test::verify_range(results_vector, gold));
}
void sub_test()
{
rtg::program p;
rtg::shape s{rtg::shape::float_type, {3}};
auto l1 = p.add_literal(rtg::literal{s, {-1, 0, 1}});
auto l2 = p.add_literal(rtg::literal{s, { 1, 2, 3}});
p.add_instruction(rtg::sub{}, l1, l2);
p.compile(rtg::cpu::cpu_target{});
auto result = p.eval({});
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-2, -2, -2};
EXPECT(test::verify_range(results_vector, gold));
}
void mul_test()
{
rtg::program p;
rtg::shape s{rtg::shape::float_type, {3}};
auto l1 = p.add_literal(rtg::literal{s, {-1, 0, 1}});
auto l2 = p.add_literal(rtg::literal{s, { 1, 2, 3}});
p.add_instruction(rtg::mul{}, l1, l2);
p.compile(rtg::cpu::cpu_target{});
auto result = p.eval({});
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-1, 0, 3};
EXPECT(test::verify_range(results_vector, gold));
}
void div_test()
{
rtg::program p;
rtg::shape s{rtg::shape::float_type, {3}};
auto l1 = p.add_literal(rtg::literal{s, {-1.0f, 0.5f, 1.0f}});
auto l2 = p.add_literal(rtg::literal{s, { 1.0f, 2.0f, 4.0f}});
p.add_instruction(rtg::div{}, l1, l2);
p.compile(rtg::cpu::cpu_target{});
auto result = p.eval({});
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-1.f, 0.25f, 0.25f};
EXPECT(test::verify_range(results_vector, gold));
}
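// A minimal sketch of the broadcast-add test the commit message says is
// still missing. The function name add_broadcast_test, the shapes, and the
// gold values are assumptions for illustration, not part of this commit
// (it would also need a call in main() below).
void add_broadcast_test()
{
    rtg::program p;
    rtg::shape s0{rtg::shape::float_type, {2, 3}};
    rtg::shape s1{rtg::shape::float_type, {3}};
    auto l1 = p.add_literal(rtg::literal{s0, {0, 1, 2, 3, 4, 5}});
    auto l2 = p.add_literal(rtg::literal{s1, {1, 2, 3}});
    // broadcast l2 across axis 1 of l1 (zero-stride view), then add
    auto b = p.add_instruction(rtg::broadcast{1}, l1, l2);
    p.add_instruction(rtg::add{}, l1, b);
    p.compile(rtg::cpu::cpu_target{});
    auto result = p.eval({});
    std::vector<float> results_vector(6);
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
    // row 0: {0,1,2} + {1,2,3}; row 1: {3,4,5} + {1,2,3}
    std::vector<float> gold = {1, 3, 5, 4, 6, 8};
    EXPECT(test::verify_range(results_vector, gold));
}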
void reshape_test()
{
rtg::shape a_shape{rtg::shape::float_type, {24, 1, 1, 1}};
@@ -394,10 +477,14 @@ void conv2d_padding_stride_test()
int main()
{
fred();
exp_test();
sin_test();
cos_test();
tan_test();
add_test();
sub_test();
    mul_test();
    div_test();
gemm_test();
reshape_test();
softmax_test();
...