Commit 50edee84 authored by Paul's avatar Paul
Browse files

Formatting

parent fd26582a
...@@ -6,17 +6,13 @@ namespace rtg { ...@@ -6,17 +6,13 @@ namespace rtg {
// Multidimensional for loop // Multidimensional for loop
inline auto dfor() inline auto dfor()
{ {
return [](auto f) return [](auto f) { f(); };
{
f();
};
} }
template<class T, class... Ts> template <class T, class... Ts>
auto dfor(T x, Ts... xs) auto dfor(T x, Ts... xs)
{ {
return [=](auto f) return [=](auto f) {
{
for(T i = 0; i < x; i++) for(T i = 0; i < x; i++)
{ {
dfor(xs...)([&](Ts... is) { f(i, is...); }); dfor(xs...)([&](Ts... is) { f(i, is...); });
......
...@@ -37,16 +37,14 @@ struct element_type ...@@ -37,16 +37,14 @@ struct element_type
}; };
template <class T> template <class T>
using remove_ptr = typename std::conditional_t<std::is_pointer<T>{}, using remove_ptr = typename std::
std::remove_pointer<T>, conditional_t<std::is_pointer<T>{}, std::remove_pointer<T>, element_type<T>>::type;
element_type<T>>::type;
template <class T> template <class T>
using shared = std::shared_ptr<remove_ptr<T>>; using shared = std::shared_ptr<remove_ptr<T>>;
} // namespace rtg } // namespace rtg
#define RTG_MANAGE_PTR(T, F) \ #define RTG_MANAGE_PTR(T, F) rtg::manage_ptr<std::remove_pointer_t<T>, decltype(&F), &F> // NOLINT
rtg::manage_ptr<std::remove_pointer_t<T>, decltype(&F), &F> // NOLINT
#endif #endif
...@@ -4,20 +4,15 @@ ...@@ -4,20 +4,15 @@
#include <rtg/dfor.hpp> #include <rtg/dfor.hpp>
#include <rtg/operators.hpp> #include <rtg/operators.hpp>
namespace rtg { namespace cpu { namespace rtg {
namespace cpu {
struct cpu_convolution struct cpu_convolution
{ {
convolution op; convolution op;
std::string name() const std::string name() const { return "cpu::convolution"; }
{ shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); }
return "cpu::convolution";
}
shape compute_shape(std::vector<shape> inputs) const
{
return op.compute_shape(inputs);
}
argument compute(std::vector<argument> args) const argument compute(std::vector<argument> args) const
{ {
shape output_shape = compute_shape({args[0].get_shape(), args[1].get_shape()}); shape output_shape = compute_shape({args[0].get_shape(), args[1].get_shape()});
...@@ -34,7 +29,10 @@ struct cpu_convolution ...@@ -34,7 +29,10 @@ struct cpu_convolution
auto wei_h = weights.get_shape().lens()[2]; auto wei_h = weights.get_shape().lens()[2];
auto wei_w = weights.get_shape().lens()[3]; auto wei_w = weights.get_shape().lens()[3];
dfor(in_n, in_c, in_h, in_w)([&](std::size_t o, std::size_t w, std::size_t i, std::size_t j) { dfor(in_n,
in_c,
in_h,
in_w)([&](std::size_t o, std::size_t w, std::size_t i, std::size_t j) {
const int start_x = i * op.stride[0] - op.padding[0]; const int start_x = i * op.stride[0] - op.padding[0];
const int start_y = j * op.stride[1] - op.padding[1]; const int start_y = j * op.stride[1] - op.padding[1];
...@@ -59,16 +57,10 @@ struct cpu_convolution ...@@ -59,16 +57,10 @@ struct cpu_convolution
struct relu struct relu
{ {
std::string name() const std::string name() const { return "cpu::relu"; }
{ shape compute_shape(std::vector<shape> inputs) const { return inputs.front(); }
return "cpu::relu";
}
shape compute_shape(std::vector<shape> inputs) const
{
return inputs.front();
}
argument compute(std::vector<argument> args) const argument compute(std::vector<argument> args) const
{ {
argument result{args[0].get_shape()}; argument result{args[0].get_shape()};
result.visit([&](auto output) { result.visit([&](auto output) {
...@@ -84,14 +76,18 @@ struct relu ...@@ -84,14 +76,18 @@ struct relu
struct cpu_apply struct cpu_apply
{ {
program * prog; program* prog;
void apply() void apply()
{ {
for(auto it = prog->begin();it != prog->end();it++) { for(auto it = prog->begin(); it != prog->end(); it++)
if (it->op.name() == "convolution") { {
if(it->op.name() == "convolution")
{
apply_convolution(it); apply_convolution(it);
} else if (it->op.name() == "activation") { }
else if(it->op.name() == "activation")
{
apply_activation(it); apply_activation(it);
} }
} }
...@@ -109,18 +105,11 @@ struct cpu_apply ...@@ -109,18 +105,11 @@ struct cpu_apply
if(op.mode == "relu") if(op.mode == "relu")
prog->replace_instruction(ins, relu{}, ins->arguments); prog->replace_instruction(ins, relu{}, ins->arguments);
} }
}; };
std::string cpu_target::name() const std::string cpu_target::name() const { return "cpu"; }
{
return "cpu";
}
void cpu_target::apply(program& p) const void cpu_target::apply(program& p) const { cpu_apply{&p}.apply(); }
{
cpu_apply{&p}.apply();
}
} // namespace cpu } // namespace cpu
......
...@@ -3,7 +3,8 @@ ...@@ -3,7 +3,8 @@
#include <rtg/program.hpp> #include <rtg/program.hpp>
namespace rtg { namespace cpu { namespace rtg {
namespace cpu {
struct cpu_target struct cpu_target
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment