Commit f199ea9e authored by Paul's avatar Paul
Browse files

Fix tidy error

parent 50edee84
...@@ -82,7 +82,6 @@ struct raw_data ...@@ -82,7 +82,6 @@ struct raw_data
/** /**
* @brief Retrieves a single element of data * @brief Retrieves a single element of data
* @details [long description]
* *
* @param n The index to retrieve the data from * @param n The index to retrieve the data from
* @tparam T The type of data to be retrieved * @tparam T The type of data to be retrieved
...@@ -97,6 +96,20 @@ struct raw_data ...@@ -97,6 +96,20 @@ struct raw_data
} }
}; };
template<class T, class... Ts>
auto visit_all(T&& x, Ts&&... xs)
{
auto&& s = x.get_shape();
std::initializer_list<shape::type_t> types = {xs.get_shape().type()...};
if (!std::all_of(types.begin(), types.end(), [&](shape::type_t t) { return t == s.type(); }))
RTG_THROW("Types must be the same");
return [&](auto v) {
s.visit_type([&](auto as) {
v(make_view(s, as.from(x.data())), make_view(xs.get_shape(), as.from(xs.data()))...);
});
};
}
} // namespace rtg } // namespace rtg
#endif #endif
...@@ -17,39 +17,35 @@ struct cpu_convolution ...@@ -17,39 +17,35 @@ struct cpu_convolution
{ {
shape output_shape = compute_shape({args[0].get_shape(), args[1].get_shape()}); shape output_shape = compute_shape({args[0].get_shape(), args[1].get_shape()});
argument result{compute_shape({args[0].get_shape(), args[1].get_shape()})}; argument result{compute_shape({args[0].get_shape(), args[1].get_shape()})};
result.visit([&](auto output) { visit_all(result, args[0], args[1])([&](auto output, auto input, auto weights) {
args[0].visit([&](auto input) { auto in_n = input.get_shape().lens()[0];
args[1].visit([&](auto weights) { auto in_c = input.get_shape().lens()[1];
auto in_n = input.get_shape().lens()[0]; auto in_h = input.get_shape().lens()[2];
auto in_c = input.get_shape().lens()[1]; auto in_w = input.get_shape().lens()[3];
auto in_h = input.get_shape().lens()[2];
auto in_w = input.get_shape().lens()[3]; auto wei_c = weights.get_shape().lens()[1];
auto wei_h = weights.get_shape().lens()[2];
auto wei_c = weights.get_shape().lens()[1]; auto wei_w = weights.get_shape().lens()[3];
auto wei_h = weights.get_shape().lens()[2];
auto wei_w = weights.get_shape().lens()[3]; dfor(in_n,
in_c,
dfor(in_n, in_h,
in_c, in_w)([&](std::size_t o, std::size_t w, std::size_t i, std::size_t j) {
in_h, const int start_x = i * op.stride[0] - op.padding[0];
in_w)([&](std::size_t o, std::size_t w, std::size_t i, std::size_t j) { const int start_y = j * op.stride[1] - op.padding[1];
const int start_x = i * op.stride[0] - op.padding[0];
const int start_y = j * op.stride[1] - op.padding[1]; double acc = 0;
dfor(wei_c, wei_h, wei_w)([&](std::size_t k, std::size_t x, std::size_t y) {
double acc = 0; const int in_x = start_x + x;
dfor(wei_c, wei_h, wei_w)([&](std::size_t k, std::size_t x, std::size_t y) { const int in_y = start_y + y;
const int in_x = start_x + x; if(in_x >= 0 && in_x < in_h && in_y >= 0 && in_y < in_w)
const int in_y = start_y + y; {
if(in_x >= 0 && in_x < in_h && in_y >= 0 && in_y < in_w) acc += input(o, k, in_x, in_y) * weights(w, k, x, y);
{ }
acc += input(o, k, in_x, in_y) * weights(w, k, x, y);
}
});
output(o, w, i, j) = acc;
});
}); });
output(o, w, i, j) = acc;
}); });
}); });
return result; return result;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment