Commit 86576438 authored by Paul's avatar Paul
Browse files

Add shape for each

parent 9af6974d
...@@ -3,11 +3,10 @@ ...@@ -3,11 +3,10 @@
#define RTG_GUARD_RAW_DATA_HPP #define RTG_GUARD_RAW_DATA_HPP
#include <rtg/tensor_view.hpp> #include <rtg/tensor_view.hpp>
#include <rtg/requires.hpp>
namespace rtg { namespace rtg {
#define RTG_REQUIRES(...) class = typename std::enable_if<(__VA_ARGS__)>::type
struct raw_data_base struct raw_data_base
{ {
}; };
......
#ifndef RTG_GUARD_RTGLIB_REQUIRES_HPP
#define RTG_GUARD_RTGLIB_REQUIRES_HPP
#include <type_traits>
namespace rtg {
template<bool... Bs>
struct and_
: std::is_same<and_<Bs...>, and_<(Bs || true)...>>
{};
#define RTG_REQUIRES(...) class = typename std::enable_if<and_<__VA_ARGS__, true>{}>::type
} // namespace rtg
#endif
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include <vector> #include <vector>
#include <cassert> #include <cassert>
#include <ostream> #include <ostream>
#include <numeric>
#include <rtg/errors.hpp> #include <rtg/errors.hpp>
...@@ -60,6 +61,14 @@ struct shape ...@@ -60,6 +61,14 @@ struct shape
std::size_t index(std::initializer_list<std::size_t> l) const; std::size_t index(std::initializer_list<std::size_t> l) const;
std::size_t index(const std::vector<std::size_t>& l) const; std::size_t index(const std::vector<std::size_t>& l) const;
template<class Iterator>
std::size_t index(Iterator start, Iterator last) const
{
assert(std::distance(start, last) <= this->lens().size());
assert(this->lens().size() == this->strides().size());
return std::inner_product(start, last, this->strides().begin(), std::size_t{0});
}
// Map element index to space index // Map element index to space index
std::size_t index(std::size_t i) const; std::size_t index(std::size_t i) const;
......
#ifndef RTG_GUARD_RTGLIB_SHAPE_FOR_EACH_HPP
#define RTG_GUARD_RTGLIB_SHAPE_FOR_EACH_HPP
#include <rtg/shape.hpp>
#include <algorithm>
namespace rtg {
template<class F>
void shape_for_each(const rtg::shape& s, F f)
{
// Ensure calls to f use const ref to vector
auto call = [&f](const std::vector<std::size_t>& i) { f(i); };
std::vector<std::size_t> indices(s.lens().size());
for(std::size_t i = 0;i < s.elements();i++) {
std::transform(s.strides().begin(),
s.strides().end(),
s.lens().begin(),
indices.begin(),
[&](std::size_t stride, std::size_t len) { return (i / stride) % len; });
call(indices);
}
}
} // namespace rtg
#endif
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <rtg/shape.hpp> #include <rtg/shape.hpp>
#include <rtg/float_equal.hpp> #include <rtg/float_equal.hpp>
#include <rtg/requires.hpp>
#include <iostream> #include <iostream>
...@@ -25,18 +26,30 @@ struct tensor_view ...@@ -25,18 +26,30 @@ struct tensor_view
const T* data() const { return this->m_data; } const T* data() const { return this->m_data; }
template <class... Ts> template <class... Ts, RTG_REQUIRES(std::is_integral<Ts>{}...)>
const T& operator()(Ts... xs) const const T& operator()(Ts... xs) const
{ {
return m_data[m_shape.index({xs...})]; return m_data[m_shape.index({static_cast<std::size_t>(xs)...})];
} }
template <class... Ts> template <class... Ts, RTG_REQUIRES(std::is_integral<Ts>{}...)>
T& operator()(Ts... xs) T& operator()(Ts... xs)
{ {
return m_data[m_shape.index({static_cast<std::size_t>(xs)...})]; return m_data[m_shape.index({static_cast<std::size_t>(xs)...})];
} }
template <class Iterator, RTG_REQUIRES(not std::is_integral<Iterator>{})>
const T& operator()(Iterator start, Iterator last) const
{
return m_data[m_shape.index(start, last)];
}
template <class Iterator, RTG_REQUIRES(not std::is_integral<Iterator>{})>
T& operator()(Iterator start, Iterator last)
{
return m_data[m_shape.index(start, last)];
}
T& operator[](std::size_t i) T& operator[](std::size_t i)
{ {
assert(!this->empty() && i < this->size()); assert(!this->empty() && i < this->size());
......
...@@ -63,7 +63,8 @@ std::size_t shape::index(const std::vector<std::size_t>& l) const ...@@ -63,7 +63,8 @@ std::size_t shape::index(const std::vector<std::size_t>& l) const
std::size_t shape::index(std::size_t i) const std::size_t shape::index(std::size_t i) const
{ {
assert(this->lens().size() == this->strides().size()); assert(this->lens().size() == this->strides().size());
return std::inner_product( if (this->packed()) return i;
else return std::inner_product(
this->lens().begin(), this->lens().begin(),
this->lens().end(), this->lens().end(),
this->strides().begin(), this->strides().begin(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment