Commit 8f330074 authored by Paul's avatar Paul
Browse files

Formatting

parent a4e698cc
...@@ -57,8 +57,8 @@ struct cpu_convolution ...@@ -57,8 +57,8 @@ struct cpu_convolution
struct cpu_transpose struct cpu_transpose
{ {
transpose op; transpose op;
std::string name() const { return "cpu::transpose"; } std::string name() const { return "cpu::transpose"; }
shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); } shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
...@@ -70,71 +70,64 @@ struct cpu_contiguous ...@@ -70,71 +70,64 @@ struct cpu_contiguous
{ {
contiguous op; contiguous op;
std::string name() const { return "cpu::contiguous"; } std::string name() const { return "cpu::contiguous"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); }
{
return op.compute_shape(inputs);
}
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto input) { visit_all(result, args[0])([&](auto output, auto input) {
auto input_shape = args[0].get_shape(); auto input_shape = args[0].get_shape();
auto ndim = output_shape.lens().size(); auto ndim = output_shape.lens().size();
using value_type = typename decltype(input)::value_type; using value_type = typename decltype(input)::value_type;
value_type* ptr = static_cast<value_type*>(output.data()); value_type* ptr = static_cast<value_type*>(output.data());
if (ndim == 2) { if(ndim == 2)
dfor(input_shape.lens()[0], {
input_shape.lens()[1])( dfor(input_shape.lens()[0], input_shape.lens()[1])(
[&](std::size_t i0, std::size_t i1) { [&](std::size_t i0, std::size_t i1) { *ptr++ = input(i0, i1); });
*ptr++ = input(i0,i1);
});
} }
else if (ndim == 3) { else if(ndim == 3)
dfor(input_shape.lens()[0], {
input_shape.lens()[1], dfor(input_shape.lens()[0], input_shape.lens()[1], input_shape.lens()[2])(
input_shape.lens()[2])(
[&](std::size_t i0, std::size_t i1, std::size_t i2) { [&](std::size_t i0, std::size_t i1, std::size_t i2) {
*ptr++ = input(i0,i1,i2); *ptr++ = input(i0, i1, i2);
}); });
} }
else if (ndim == 4) { else if(ndim == 4)
{
dfor(input_shape.lens()[0], dfor(input_shape.lens()[0],
input_shape.lens()[1], input_shape.lens()[1],
input_shape.lens()[2], input_shape.lens()[2],
input_shape.lens()[3])( input_shape.lens()[3])(
[&](std::size_t i0, std::size_t i1, std::size_t i2, std::size_t i3) { [&](std::size_t i0, std::size_t i1, std::size_t i2, std::size_t i3) {
*ptr++ = input(i0,i1,i2,i3); *ptr++ = input(i0, i1, i2, i3);
}); });
} }
else if (ndim == 5) { else if(ndim == 5)
{
dfor(input_shape.lens()[0], dfor(input_shape.lens()[0],
input_shape.lens()[1], input_shape.lens()[1],
input_shape.lens()[2], input_shape.lens()[2],
input_shape.lens()[3], input_shape.lens()[3],
input_shape.lens()[4])( input_shape.lens()[4])(
[&](std::size_t i0, [&](std::size_t i0,
std::size_t i1, std::size_t i1,
std::size_t i2, std::size_t i2,
std::size_t i3, std::size_t i3,
std::size_t i4) { std::size_t i4) { *ptr++ = input(i0, i1, i2, i3, i4); });
*ptr++ = input(i0,i1,i2,i3,i4);
});
} }
else if (ndim == 6) { else if(ndim == 6)
{
dfor(input_shape.lens()[0], dfor(input_shape.lens()[0],
input_shape.lens()[1], input_shape.lens()[1],
input_shape.lens()[2], input_shape.lens()[2],
input_shape.lens()[3], input_shape.lens()[3],
input_shape.lens()[4], input_shape.lens()[4],
input_shape.lens()[5])( input_shape.lens()[5])(
[&](std::size_t i0, [&](std::size_t i0,
std::size_t i1, std::size_t i1,
std::size_t i2, std::size_t i2,
std::size_t i3, std::size_t i3,
std::size_t i4, std::size_t i4,
std::size_t i5) { std::size_t i5) { *ptr++ = input(i0, i1, i2, i3, i4, i5); });
*ptr++ = input(i0,i1,i2,i3,i4,i5);
});
} }
}); });
return result; return result;
...@@ -425,40 +418,34 @@ struct cpu_apply ...@@ -425,40 +418,34 @@ struct cpu_apply
program* prog; program* prog;
std::unordered_map<std::string, std::function<void(instruction_ref)>> apply_map{}; std::unordered_map<std::string, std::function<void(instruction_ref)>> apply_map{};
template<class T> template <class T>
auto simple_op() auto simple_op()
{ {
return [this](instruction_ref ins) return [this](instruction_ref ins) { apply_simple_op<T>(ins); };
{
apply_simple_op<T>(ins);
};
} }
template<class T, class Op> template <class T, class Op>
auto extend_op() auto extend_op()
{ {
return [this](instruction_ref ins) return [this](instruction_ref ins) { apply_extend_op<T, Op>(ins); };
{
apply_extend_op<T, Op>(ins);
};
} }
void init() void init()
{ {
apply_map["convolution"] = extend_op<cpu_convolution, convolution>(); apply_map["convolution"] = extend_op<cpu_convolution, convolution>();
apply_map["gemm"] = extend_op<cpu_gemm, gemm>(); apply_map["gemm"] = extend_op<cpu_gemm, gemm>();
apply_map["reshape"] = extend_op<cpu_reshape, reshape>(); apply_map["reshape"] = extend_op<cpu_reshape, reshape>();
apply_map["contiguous"] = extend_op<cpu_contiguous, contiguous>(); apply_map["contiguous"] = extend_op<cpu_contiguous, contiguous>();
apply_map["transpose"] = extend_op<cpu_transpose, transpose>(); apply_map["transpose"] = extend_op<cpu_transpose, transpose>();
apply_map["identity"] = simple_op<cpu_unary<identity_op>>(); apply_map["identity"] = simple_op<cpu_unary<identity_op>>();
apply_map["tanh"] = simple_op<cpu_unary<tanh_op>>(); apply_map["tanh"] = simple_op<cpu_unary<tanh_op>>();
apply_map["sigmoid"] = simple_op<cpu_unary<sigmoid_op>>(); apply_map["sigmoid"] = simple_op<cpu_unary<sigmoid_op>>();
apply_map["exp"] = simple_op<cpu_unary<exp_op>>(); apply_map["exp"] = simple_op<cpu_unary<exp_op>>();
apply_map["neg"] = simple_op<cpu_unary<neg_op>>(); apply_map["neg"] = simple_op<cpu_unary<neg_op>>();
apply_map["sin"] = simple_op<cpu_unary<sin_op>>(); apply_map["sin"] = simple_op<cpu_unary<sin_op>>();
apply_map["cos"] = simple_op<cpu_unary<cos_op>>(); apply_map["cos"] = simple_op<cpu_unary<cos_op>>();
apply_map["tan"] = simple_op<cpu_unary<tan_op>>(); apply_map["tan"] = simple_op<cpu_unary<tan_op>>();
apply_map["softmax"] = simple_op<softmax2d>(); apply_map["softmax"] = simple_op<softmax2d>();
} }
...@@ -471,7 +458,7 @@ struct cpu_apply ...@@ -471,7 +458,7 @@ struct cpu_apply
if(it->op.name() == "activation") if(it->op.name() == "activation")
{ {
apply_activation(it); apply_activation(it);
} }
else if(apply_map.count(it->op.name()) > 0) else if(apply_map.count(it->op.name()) > 0)
{ {
apply_map.at(it->op.name())(it); apply_map.at(it->op.name())(it);
...@@ -479,13 +466,13 @@ struct cpu_apply ...@@ -479,13 +466,13 @@ struct cpu_apply
} }
} }
template<class T> template <class T>
void apply_simple_op(instruction_ref ins) void apply_simple_op(instruction_ref ins)
{ {
prog->replace_instruction(ins, T{}, ins->arguments); prog->replace_instruction(ins, T{}, ins->arguments);
} }
template<class T, class Op> template <class T, class Op>
void apply_extend_op(instruction_ref ins) void apply_extend_op(instruction_ref ins)
{ {
auto&& op = any_cast<Op>(ins->op); auto&& op = any_cast<Op>(ins->op);
......
...@@ -394,30 +394,30 @@ void conv2d_padding_stride_test() ...@@ -394,30 +394,30 @@ void conv2d_padding_stride_test()
void transpose_test() void transpose_test()
{ {
rtg::shape a_shape{rtg::shape::float_type, {1,2,2,3}}; rtg::shape a_shape{rtg::shape::float_type, {1, 2, 2, 3}};
std::vector<float> data(12); std::vector<float> data(12);
std::iota(data.begin(), data.end(), 0); std::iota(data.begin(), data.end(), 0);
{ {
rtg::program p; rtg::program p;
auto l = p.add_literal(rtg::literal{a_shape, data}); auto l = p.add_literal(rtg::literal{a_shape, data});
std::vector<int64_t> perm = {0,3,1,2}; std::vector<int64_t> perm = {0, 3, 1, 2};
p.add_instruction(rtg::transpose{perm}, l); p.add_instruction(rtg::transpose{perm}, l);
p.compile(rtg::cpu::cpu_target{}); p.compile(rtg::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
result.visit([&] (auto output){ result.visit([&](auto output) {
std::vector<size_t> new_lens = {1,3,2,2}; std::vector<size_t> new_lens = {1, 3, 2, 2};
std::vector<size_t> new_strides = {12,1,6,3}; std::vector<size_t> new_strides = {12, 1, 6, 3};
EXPECT(bool{output.get_shape().lens() == new_lens}); EXPECT(bool{output.get_shape().lens() == new_lens});
EXPECT(bool{output.get_shape().strides() == new_strides}); EXPECT(bool{output.get_shape().strides() == new_strides});
}); });
} }
{ {
rtg::program p; rtg::program p;
auto l = p.add_literal(rtg::literal{a_shape, data}); auto l = p.add_literal(rtg::literal{a_shape, data});
std::vector<int64_t> perm = {0,3,1,2}; std::vector<int64_t> perm = {0, 3, 1, 2};
auto result = p.add_instruction(rtg::transpose{perm}, l); auto result = p.add_instruction(rtg::transpose{perm}, l);
p.add_instruction(rtg::contiguous{}, result); p.add_instruction(rtg::contiguous{}, result);
p.compile(rtg::cpu::cpu_target{}); p.compile(rtg::cpu::cpu_target{});
auto result2 = p.eval({}); auto result2 = p.eval({});
...@@ -429,8 +429,9 @@ void transpose_test() ...@@ -429,8 +429,9 @@ void transpose_test()
} }
} }
void contiguous_test() { void contiguous_test()
rtg::shape a_shape{rtg::shape::float_type, {1,3,2,2}, {12,1,6,3}}; {
rtg::shape a_shape{rtg::shape::float_type, {1, 3, 2, 2}, {12, 1, 6, 3}};
std::vector<float> data(12); std::vector<float> data(12);
std::iota(data.begin(), data.end(), 0); std::iota(data.begin(), data.end(), 0);
...@@ -442,9 +443,9 @@ void contiguous_test() { ...@@ -442,9 +443,9 @@ void contiguous_test() {
std::vector<float> results_vector(12); std::vector<float> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<size_t> new_lens = {1, 3, 2, 2}; std::vector<size_t> new_lens = {1, 3, 2, 2};
std::vector<size_t> new_strides = {12, 1, 6, 3}; std::vector<size_t> new_strides = {12, 1, 6, 3};
std::vector<float> gold = {0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11}; std::vector<float> gold = {0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11};
EXPECT(test::verify_range(results_vector, gold)); EXPECT(test::verify_range(results_vector, gold));
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment