"src/include/blockwise_4d_tensor_op.hpp" did not exist on "5ce19234a4538d52e18837a84ebe7c1fef224c71"
Commit 97d5b3ca authored by Paul's avatar Paul
Browse files

Formatting

parent 86576438
...@@ -5,10 +5,10 @@ ...@@ -5,10 +5,10 @@
namespace rtg { namespace rtg {
template<bool... Bs> template <bool... Bs>
struct and_ struct and_ : std::is_same<and_<Bs...>, and_<(Bs || true)...>>
: std::is_same<and_<Bs...>, and_<(Bs || true)...>> {
{}; };
#define RTG_REQUIRES(...) class = typename std::enable_if<and_<__VA_ARGS__, true>{}>::type #define RTG_REQUIRES(...) class = typename std::enable_if<and_<__VA_ARGS__, true>{}>::type
......
...@@ -61,8 +61,8 @@ struct shape ...@@ -61,8 +61,8 @@ struct shape
std::size_t index(std::initializer_list<std::size_t> l) const; std::size_t index(std::initializer_list<std::size_t> l) const;
std::size_t index(const std::vector<std::size_t>& l) const; std::size_t index(const std::vector<std::size_t>& l) const;
template<class Iterator> template <class Iterator>
std::size_t index(Iterator start, Iterator last) const std::size_t index(Iterator start, Iterator last) const
{ {
assert(std::distance(start, last) <= this->lens().size()); assert(std::distance(start, last) <= this->lens().size());
......
...@@ -6,23 +6,23 @@ ...@@ -6,23 +6,23 @@
namespace rtg { namespace rtg {
template<class F> template <class F>
void shape_for_each(const rtg::shape& s, F f) void shape_for_each(const rtg::shape& s, F f)
{ {
// Ensure calls to f use const ref to vector // Ensure calls to f use const ref to vector
auto call = [&f](const std::vector<std::size_t>& i) { f(i); }; auto call = [&f](const std::vector<std::size_t>& i) { f(i); };
std::vector<std::size_t> indices(s.lens().size()); std::vector<std::size_t> indices(s.lens().size());
for(std::size_t i = 0;i < s.elements();i++) { for(std::size_t i = 0; i < s.elements(); i++)
{
std::transform(s.strides().begin(), std::transform(s.strides().begin(),
s.strides().end(), s.strides().end(),
s.lens().begin(), s.lens().begin(),
indices.begin(), indices.begin(),
[&](std::size_t stride, std::size_t len) { return (i / stride) % len; }); [&](std::size_t stride, std::size_t len) { return (i / stride) % len; });
call(indices); call(indices);
} }
} }
} // namespace rtg } // namespace rtg
#endif #endif
...@@ -63,14 +63,16 @@ std::size_t shape::index(const std::vector<std::size_t>& l) const ...@@ -63,14 +63,16 @@ std::size_t shape::index(const std::vector<std::size_t>& l) const
std::size_t shape::index(std::size_t i) const std::size_t shape::index(std::size_t i) const
{ {
assert(this->lens().size() == this->strides().size()); assert(this->lens().size() == this->strides().size());
if (this->packed()) return i; if(this->packed())
else return std::inner_product( return i;
this->lens().begin(), else
this->lens().end(), return std::inner_product(
this->strides().begin(), this->lens().begin(),
std::size_t{0}, this->lens().end(),
std::plus<std::size_t>{}, this->strides().begin(),
[&](std::size_t len, std::size_t stride) { return ((i / stride) % len) * stride; }); std::size_t{0},
std::plus<std::size_t>{},
[&](std::size_t len, std::size_t stride) { return ((i / stride) % len) * stride; });
} }
bool shape::packed() const { return this->m_packed; } bool shape::packed() const { return this->m_packed; }
std::size_t shape::element_space() const std::size_t shape::element_space() const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment