Commit 8f330074 authored by Paul
Browse files

Formatting

parent a4e698cc
...@@ -70,10 +70,7 @@ struct cpu_contiguous ...@@ -70,10 +70,7 @@ struct cpu_contiguous
{ {
contiguous op; contiguous op;
std::string name() const { return "cpu::contiguous"; } std::string name() const { return "cpu::contiguous"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); }
{
return op.compute_shape(inputs);
}
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
...@@ -82,31 +79,30 @@ struct cpu_contiguous ...@@ -82,31 +79,30 @@ struct cpu_contiguous
auto ndim = output_shape.lens().size(); auto ndim = output_shape.lens().size();
using value_type = typename decltype(input)::value_type; using value_type = typename decltype(input)::value_type;
value_type* ptr = static_cast<value_type*>(output.data()); value_type* ptr = static_cast<value_type*>(output.data());
if (ndim == 2) { if(ndim == 2)
dfor(input_shape.lens()[0], {
input_shape.lens()[1])( dfor(input_shape.lens()[0], input_shape.lens()[1])(
[&](std::size_t i0, std::size_t i1) { [&](std::size_t i0, std::size_t i1) { *ptr++ = input(i0, i1); });
*ptr++ = input(i0,i1);
});
} }
else if (ndim == 3) { else if(ndim == 3)
dfor(input_shape.lens()[0], {
input_shape.lens()[1], dfor(input_shape.lens()[0], input_shape.lens()[1], input_shape.lens()[2])(
input_shape.lens()[2])(
[&](std::size_t i0, std::size_t i1, std::size_t i2) { [&](std::size_t i0, std::size_t i1, std::size_t i2) {
*ptr++ = input(i0,i1,i2); *ptr++ = input(i0, i1, i2);
}); });
} }
else if (ndim == 4) { else if(ndim == 4)
{
dfor(input_shape.lens()[0], dfor(input_shape.lens()[0],
input_shape.lens()[1], input_shape.lens()[1],
input_shape.lens()[2], input_shape.lens()[2],
input_shape.lens()[3])( input_shape.lens()[3])(
[&](std::size_t i0, std::size_t i1, std::size_t i2, std::size_t i3) { [&](std::size_t i0, std::size_t i1, std::size_t i2, std::size_t i3) {
*ptr++ = input(i0,i1,i2,i3); *ptr++ = input(i0, i1, i2, i3);
}); });
} }
else if (ndim == 5) { else if(ndim == 5)
{
dfor(input_shape.lens()[0], dfor(input_shape.lens()[0],
input_shape.lens()[1], input_shape.lens()[1],
input_shape.lens()[2], input_shape.lens()[2],
...@@ -116,11 +112,10 @@ struct cpu_contiguous ...@@ -116,11 +112,10 @@ struct cpu_contiguous
std::size_t i1, std::size_t i1,
std::size_t i2, std::size_t i2,
std::size_t i3, std::size_t i3,
std::size_t i4) { std::size_t i4) { *ptr++ = input(i0, i1, i2, i3, i4); });
*ptr++ = input(i0,i1,i2,i3,i4);
});
} }
else if (ndim == 6) { else if(ndim == 6)
{
dfor(input_shape.lens()[0], dfor(input_shape.lens()[0],
input_shape.lens()[1], input_shape.lens()[1],
input_shape.lens()[2], input_shape.lens()[2],
...@@ -132,9 +127,7 @@ struct cpu_contiguous ...@@ -132,9 +127,7 @@ struct cpu_contiguous
std::size_t i2, std::size_t i2,
std::size_t i3, std::size_t i3,
std::size_t i4, std::size_t i4,
std::size_t i5) { std::size_t i5) { *ptr++ = input(i0, i1, i2, i3, i4, i5); });
*ptr++ = input(i0,i1,i2,i3,i4,i5);
});
} }
}); });
return result; return result;
...@@ -425,22 +418,16 @@ struct cpu_apply ...@@ -425,22 +418,16 @@ struct cpu_apply
program* prog; program* prog;
std::unordered_map<std::string, std::function<void(instruction_ref)>> apply_map{}; std::unordered_map<std::string, std::function<void(instruction_ref)>> apply_map{};
template<class T> template <class T>
auto simple_op() auto simple_op()
{ {
return [this](instruction_ref ins) return [this](instruction_ref ins) { apply_simple_op<T>(ins); };
{
apply_simple_op<T>(ins);
};
} }
template<class T, class Op> template <class T, class Op>
auto extend_op() auto extend_op()
{ {
return [this](instruction_ref ins) return [this](instruction_ref ins) { apply_extend_op<T, Op>(ins); };
{
apply_extend_op<T, Op>(ins);
};
} }
void init() void init()
...@@ -479,13 +466,13 @@ struct cpu_apply ...@@ -479,13 +466,13 @@ struct cpu_apply
} }
} }
template<class T> template <class T>
void apply_simple_op(instruction_ref ins) void apply_simple_op(instruction_ref ins)
{ {
prog->replace_instruction(ins, T{}, ins->arguments); prog->replace_instruction(ins, T{}, ins->arguments);
} }
template<class T, class Op> template <class T, class Op>
void apply_extend_op(instruction_ref ins) void apply_extend_op(instruction_ref ins)
{ {
auto&& op = any_cast<Op>(ins->op); auto&& op = any_cast<Op>(ins->op);
......
...@@ -394,21 +394,21 @@ void conv2d_padding_stride_test() ...@@ -394,21 +394,21 @@ void conv2d_padding_stride_test()
void transpose_test() void transpose_test()
{ {
rtg::shape a_shape{rtg::shape::float_type, {1,2,2,3}}; rtg::shape a_shape{rtg::shape::float_type, {1, 2, 2, 3}};
std::vector<float> data(12); std::vector<float> data(12);
std::iota(data.begin(), data.end(), 0); std::iota(data.begin(), data.end(), 0);
{ {
rtg::program p; rtg::program p;
auto l = p.add_literal(rtg::literal{a_shape, data}); auto l = p.add_literal(rtg::literal{a_shape, data});
std::vector<int64_t> perm = {0,3,1,2}; std::vector<int64_t> perm = {0, 3, 1, 2};
p.add_instruction(rtg::transpose{perm}, l); p.add_instruction(rtg::transpose{perm}, l);
p.compile(rtg::cpu::cpu_target{}); p.compile(rtg::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
result.visit([&] (auto output){ result.visit([&](auto output) {
std::vector<size_t> new_lens = {1,3,2,2}; std::vector<size_t> new_lens = {1, 3, 2, 2};
std::vector<size_t> new_strides = {12,1,6,3}; std::vector<size_t> new_strides = {12, 1, 6, 3};
EXPECT(bool{output.get_shape().lens() == new_lens}); EXPECT(bool{output.get_shape().lens() == new_lens});
EXPECT(bool{output.get_shape().strides() == new_strides}); EXPECT(bool{output.get_shape().strides() == new_strides});
}); });
...@@ -416,7 +416,7 @@ void transpose_test() ...@@ -416,7 +416,7 @@ void transpose_test()
{ {
rtg::program p; rtg::program p;
auto l = p.add_literal(rtg::literal{a_shape, data}); auto l = p.add_literal(rtg::literal{a_shape, data});
std::vector<int64_t> perm = {0,3,1,2}; std::vector<int64_t> perm = {0, 3, 1, 2};
auto result = p.add_instruction(rtg::transpose{perm}, l); auto result = p.add_instruction(rtg::transpose{perm}, l);
p.add_instruction(rtg::contiguous{}, result); p.add_instruction(rtg::contiguous{}, result);
p.compile(rtg::cpu::cpu_target{}); p.compile(rtg::cpu::cpu_target{});
...@@ -429,8 +429,9 @@ void transpose_test() ...@@ -429,8 +429,9 @@ void transpose_test()
} }
} }
void contiguous_test() { void contiguous_test()
rtg::shape a_shape{rtg::shape::float_type, {1,3,2,2}, {12,1,6,3}}; {
rtg::shape a_shape{rtg::shape::float_type, {1, 3, 2, 2}, {12, 1, 6, 3}};
std::vector<float> data(12); std::vector<float> data(12);
std::iota(data.begin(), data.end(), 0); std::iota(data.begin(), data.end(), 0);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment