Unverified Commit 4d28180c authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Make tensor view work with non-standard shapes (#712)



* Add initial iterator implementation

* Formatting

* Access index with bracket

* Add cppcheck suppression
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 404f416d
#ifndef MIGRAPHX_GUARD_RTGLIB_IOTA_ITERATOR_HPP
#define MIGRAPHX_GUARD_RTGLIB_IOTA_ITERATOR_HPP
#include <migraphx/config.hpp>
#include <iterator>
#include <type_traits>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class F, class Iterator = std::size_t>
struct iota_iterator
{
Iterator index;
F f;
using difference_type = std::ptrdiff_t;
using reference = decltype(f(std::declval<Iterator>()));
using value_type = typename std::remove_reference<reference>::type;
using pointer = typename std::add_pointer<value_type>::type;
using iterator_category = std::random_access_iterator_tag;
iota_iterator& operator+=(int n)
{
index += n;
return *this;
}
iota_iterator& operator-=(int n)
{
index -= n;
return *this;
}
iota_iterator& operator++()
{
index++;
return *this;
}
iota_iterator& operator--()
{
index--;
return *this;
}
iota_iterator operator++(int) // NOLINT
{
iota_iterator it = *this;
index++;
return it;
}
iota_iterator operator--(int) // NOLINT
{
iota_iterator it = *this;
index--;
return it;
}
// TODO: operator->
reference operator*() const { return f(index); }
};
template <class F, class Iterator>
inline iota_iterator<F, Iterator> operator+(iota_iterator<F, Iterator> x,
iota_iterator<F, Iterator> y)
{
return iota_iterator<F, Iterator>(x.index + y.index, x.f);
}
template <class F, class Iterator>
inline std::ptrdiff_t operator-(iota_iterator<F, Iterator> x, iota_iterator<F, Iterator> y)
{
return x.index - y.index;
}
template <class F, class Iterator>
inline bool operator==(iota_iterator<F, Iterator> x, iota_iterator<F, Iterator> y)
{
return x.index == y.index;
}
template <class F, class Iterator>
inline bool operator!=(iota_iterator<F, Iterator> x, iota_iterator<F, Iterator> y)
{
return x.index != y.index;
}
template <class F, class Iterator>
inline bool operator<(iota_iterator<F, Iterator> x, iota_iterator<F, Iterator> y)
{
return x.index < y.index;
}
template <class F, class Iterator>
inline bool operator>(iota_iterator<F, Iterator> x, iota_iterator<F, Iterator> y)
{
return x.index > y.index;
}
template <class F, class Iterator>
inline bool operator>=(iota_iterator<F, Iterator> x, iota_iterator<F, Iterator> y)
{
return x.index >= y.index;
}
template <class F, class Iterator>
inline bool operator<=(iota_iterator<F, Iterator> x, iota_iterator<F, Iterator> y)
{
return x.index <= y.index;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
#include <migraphx/float_equal.hpp> #include <migraphx/float_equal.hpp>
#include <migraphx/requires.hpp> #include <migraphx/requires.hpp>
#include <migraphx/iota_iterator.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <iostream> #include <iostream>
...@@ -20,10 +21,23 @@ T as_number(T x) ...@@ -20,10 +21,23 @@ T as_number(T x)
inline int32_t as_number(int8_t x) { return static_cast<int32_t>(x); } inline int32_t as_number(int8_t x) { return static_cast<int32_t>(x); }
inline uint32_t as_number(uint8_t x) { return static_cast<uint32_t>(x); } inline uint32_t as_number(uint8_t x) { return static_cast<uint32_t>(x); }
template <class T>
struct tensor_view_iterator_read
{
T* view;
auto& operator()(std::size_t n) const
{
assert(view != nullptr);
return (*view)[n];
}
};
template <class T> template <class T>
struct tensor_view struct tensor_view
{ {
using value_type = T; using value_type = T;
using iterator = iota_iterator<tensor_view_iterator_read<tensor_view<T>>>;
using const_iterator = iota_iterator<tensor_view_iterator_read<const tensor_view<T>>>;
tensor_view() : m_data(nullptr) {} tensor_view() : m_data(nullptr) {}
tensor_view(shape s, T* d) : m_data(d), m_shape(std::move(s)) {} tensor_view(shape s, T* d) : m_data(d), m_shape(std::move(s)) {}
...@@ -105,36 +119,15 @@ struct tensor_view ...@@ -105,36 +119,15 @@ struct tensor_view
return m_data[m_shape.index(this->size() - 1)]; return m_data[m_shape.index(this->size() - 1)];
} }
// TODO: Add iterators so it can handle nonstandard tensors // cppcheck-suppress functionConst
T* begin() iterator begin() { return {0, {this}}; }
{
assert(this->m_shape.standard() or this->empty());
return m_data;
}
T* end() // cppcheck-suppress functionConst
{ iterator end() { return {this->size(), {this}}; }
assert(this->m_shape.standard() or this->empty());
if(this->empty())
return m_data;
else
return m_data + this->size();
}
const T* begin() const const_iterator begin() const { return {0, {this}}; }
{
assert(this->m_shape.standard() or this->empty());
return m_data;
}
const T* end() const const_iterator end() const { return {this->size(), {this}}; }
{
assert(this->m_shape.standard() or this->empty());
if(this->empty())
return m_data;
else
return m_data + this->size();
}
template <class U = T> template <class U = T>
std::vector<U> to_vector() const std::vector<U> to_vector() const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment