Commit f199ea9e authored by Paul's avatar Paul
Browse files

Fix tidy error

parent 50edee84
......@@ -82,7 +82,6 @@ struct raw_data
/**
* @brief Retrieves a single element of data
* @details [long description]
*
* @param n The index to retrieve the data from
* @tparam T The type of data to be retrieved
......@@ -97,6 +96,20 @@ struct raw_data
}
};
template<class T, class... Ts>
auto visit_all(T&& x, Ts&&... xs)
{
auto&& s = x.get_shape();
std::initializer_list<shape::type_t> types = {xs.get_shape().type()...};
if (!std::all_of(types.begin(), types.end(), [&](shape::type_t t) { return t == s.type(); }))
RTG_THROW("Types must be the same");
return [&](auto v) {
s.visit_type([&](auto as) {
v(make_view(s, as.from(x.data())), make_view(xs.get_shape(), as.from(xs.data()))...);
});
};
}
} // namespace rtg
#endif
......@@ -17,9 +17,7 @@ struct cpu_convolution
{
shape output_shape = compute_shape({args[0].get_shape(), args[1].get_shape()});
argument result{compute_shape({args[0].get_shape(), args[1].get_shape()})};
result.visit([&](auto output) {
args[0].visit([&](auto input) {
args[1].visit([&](auto weights) {
visit_all(result, args[0], args[1])([&](auto output, auto input, auto weights) {
auto in_n = input.get_shape().lens()[0];
auto in_c = input.get_shape().lens()[1];
auto in_h = input.get_shape().lens()[2];
......@@ -49,8 +47,6 @@ struct cpu_convolution
});
});
});
});
return result;
}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment