Commit 95549529 authored by Paul's avatar Paul
Browse files

Merge branch 'shape-for-each'

parents cbf4c8d6 01e21e12
......@@ -106,6 +106,8 @@ rocm_enable_cppcheck(
${CMAKE_CURRENT_SOURCE_DIR}/src/include
${CMAKE_CURRENT_SOURCE_DIR}/src/targets/cpu/include
${CMAKE_CURRENT_SOURCE_DIR}/src/targets/miopen/include
DEFINE
CPPCHECK=1
)
add_subdirectory(src)
......
......@@ -3,11 +3,10 @@
#define RTG_GUARD_RAW_DATA_HPP
#include <rtg/tensor_view.hpp>
#include <rtg/requires.hpp>
namespace rtg {
#define RTG_REQUIRES(...) class = typename std::enable_if<(__VA_ARGS__)>::type
struct raw_data_base
{
};
......
#ifndef RTG_GUARD_RTGLIB_REQUIRES_HPP
#define RTG_GUARD_RTGLIB_REQUIRES_HPP
#include <type_traits>
namespace rtg {
template <bool... Bs>
struct and_ : std::is_same<and_<Bs...>, and_<(Bs || true)...>> // NOLINT
{
};
#ifdef CPPCHECK
#define RTG_REQUIRES(...) class = void
#else
#define RTG_REQUIRES(...) class = typename std::enable_if<and_<__VA_ARGS__, true>{}>::type
#endif
} // namespace rtg
#endif
......@@ -4,6 +4,7 @@
#include <vector>
#include <cassert>
#include <ostream>
#include <numeric>
#include <rtg/errors.hpp>
......@@ -61,6 +62,14 @@ struct shape
std::size_t index(std::initializer_list<std::size_t> l) const;
std::size_t index(const std::vector<std::size_t>& l) const;
template <class Iterator>
std::size_t index(Iterator start, Iterator last) const
{
assert(std::distance(start, last) <= this->lens().size());
assert(this->lens().size() == this->strides().size());
return std::inner_product(start, last, this->strides().begin(), std::size_t{0});
}
// Map element index to space index
std::size_t index(std::size_t i) const;
......
#ifndef RTG_GUARD_RTGLIB_SHAPE_FOR_EACH_HPP
#define RTG_GUARD_RTGLIB_SHAPE_FOR_EACH_HPP
#include <rtg/shape.hpp>
#include <algorithm>
namespace rtg {
template <class F>
void shape_for_each(const rtg::shape& s, F f)
{
// Ensure calls to f use const ref to vector
auto call = [&f](const std::vector<std::size_t>& i) { f(i); };
std::vector<std::size_t> indices(s.lens().size());
for(std::size_t i = 0; i < s.elements(); i++)
{
std::transform(s.strides().begin(),
s.strides().end(),
s.lens().begin(),
indices.begin(),
[&](std::size_t stride, std::size_t len) { return (i / stride) % len; });
call(indices);
}
}
} // namespace rtg
#endif
......@@ -3,6 +3,7 @@
#include <rtg/shape.hpp>
#include <rtg/float_equal.hpp>
#include <rtg/requires.hpp>
#include <iostream>
......@@ -25,18 +26,30 @@ struct tensor_view
const T* data() const { return this->m_data; }
template <class... Ts>
template <class... Ts, RTG_REQUIRES(std::is_integral<Ts>{}...)>
const T& operator()(Ts... xs) const
{
return m_data[m_shape.index({xs...})];
return m_data[m_shape.index({static_cast<std::size_t>(xs)...})];
}
template <class... Ts>
template <class... Ts, RTG_REQUIRES(std::is_integral<Ts>{}...)>
T& operator()(Ts... xs)
{
return m_data[m_shape.index({static_cast<std::size_t>(xs)...})];
}
template <class Iterator, RTG_REQUIRES(not std::is_integral<Iterator>{})>
const T& operator()(Iterator start, Iterator last) const
{
return m_data[m_shape.index(start, last)];
}
template <class Iterator, RTG_REQUIRES(not std::is_integral<Iterator>{})>
T& operator()(Iterator start, Iterator last)
{
return m_data[m_shape.index(start, last)];
}
T& operator[](std::size_t i)
{
assert(!this->empty() && i < this->size());
......
......@@ -63,13 +63,16 @@ std::size_t shape::index(const std::vector<std::size_t>& l) const
std::size_t shape::index(std::size_t i) const
{
assert(this->lens().size() == this->strides().size());
return std::inner_product(
this->lens().begin(),
this->lens().end(),
this->strides().begin(),
std::size_t{0},
std::plus<std::size_t>{},
[&](std::size_t len, std::size_t stride) { return ((i / stride) % len) * stride; });
if(this->packed())
return i;
else
return std::inner_product(
this->lens().begin(),
this->lens().end(),
this->strides().begin(),
std::size_t{0},
std::plus<std::size_t>{},
[&](std::size_t len, std::size_t stride) { return ((i / stride) % len) * stride; });
}
bool shape::packed() const { return this->m_packed; }
std::size_t shape::element_space() const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment