Unverified Commit e8738144 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Remove vector of indices and use a lazy iota range instead (#838)

* Create lazy range

* Formatting

* Use lazy iota
parent c310bc5c
...@@ -2,14 +2,15 @@ ...@@ -2,14 +2,15 @@
#define MIGRAPHX_GUARD_RTGLIB_IOTA_ITERATOR_HPP #define MIGRAPHX_GUARD_RTGLIB_IOTA_ITERATOR_HPP
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/functional.hpp>
#include <iterator> #include <iterator>
#include <type_traits> #include <type_traits>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
template <class F, class Iterator = std::size_t> template <class F, class Iterator = std::ptrdiff_t>
struct iota_iterator struct basic_iota_iterator
{ {
Iterator index; Iterator index;
F f; F f;
...@@ -20,40 +21,40 @@ struct iota_iterator ...@@ -20,40 +21,40 @@ struct iota_iterator
using pointer = typename std::add_pointer<value_type>::type; using pointer = typename std::add_pointer<value_type>::type;
using iterator_category = std::random_access_iterator_tag; using iterator_category = std::random_access_iterator_tag;
iota_iterator& operator+=(int n) basic_iota_iterator& operator+=(int n)
{ {
index += n; index += n;
return *this; return *this;
} }
iota_iterator& operator-=(int n) basic_iota_iterator& operator-=(int n)
{ {
index -= n; index -= n;
return *this; return *this;
} }
iota_iterator& operator++() basic_iota_iterator& operator++()
{ {
index++; index++;
return *this; return *this;
} }
iota_iterator& operator--() basic_iota_iterator& operator--()
{ {
index--; index--;
return *this; return *this;
} }
iota_iterator operator++(int) // NOLINT basic_iota_iterator operator++(int) // NOLINT
{ {
iota_iterator it = *this; basic_iota_iterator it = *this;
index++; index++;
return it; return it;
} }
iota_iterator operator--(int) // NOLINT basic_iota_iterator operator--(int) // NOLINT
{ {
iota_iterator it = *this; basic_iota_iterator it = *this;
index--; index--;
return it; return it;
} }
...@@ -61,55 +62,71 @@ struct iota_iterator ...@@ -61,55 +62,71 @@ struct iota_iterator
reference operator*() const { return f(index); } reference operator*() const { return f(index); }
}; };
template <class T, class F>
inline basic_iota_iterator<F, T> make_basic_iota_iterator(T x, F f)
{
return basic_iota_iterator<F, T>{x, f};
}
template <class F, class Iterator>
inline basic_iota_iterator<F, Iterator> operator+(basic_iota_iterator<F, Iterator> x,
std::ptrdiff_t y)
{
return x += y;
}
template <class F, class Iterator> template <class F, class Iterator>
inline iota_iterator<F, Iterator> operator+(iota_iterator<F, Iterator> x, inline basic_iota_iterator<F, Iterator> operator+(std::ptrdiff_t x,
iota_iterator<F, Iterator> y) basic_iota_iterator<F, Iterator> y)
{ {
return iota_iterator<F, Iterator>(x.index + y.index, x.f); return y + x;
} }
template <class F, class Iterator> template <class F, class Iterator>
inline std::ptrdiff_t operator-(iota_iterator<F, Iterator> x, iota_iterator<F, Iterator> y) inline std::ptrdiff_t operator-(basic_iota_iterator<F, Iterator> x,
basic_iota_iterator<F, Iterator> y)
{ {
return x.index - y.index; return x.index - y.index;
} }
template <class F, class Iterator> template <class F, class Iterator>
inline bool operator==(iota_iterator<F, Iterator> x, iota_iterator<F, Iterator> y) inline bool operator==(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{ {
return x.index == y.index; return x.index == y.index;
} }
template <class F, class Iterator> template <class F, class Iterator>
inline bool operator!=(iota_iterator<F, Iterator> x, iota_iterator<F, Iterator> y) inline bool operator!=(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{ {
return x.index != y.index; return x.index != y.index;
} }
template <class F, class Iterator> template <class F, class Iterator>
inline bool operator<(iota_iterator<F, Iterator> x, iota_iterator<F, Iterator> y) inline bool operator<(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{ {
return x.index < y.index; return x.index < y.index;
} }
template <class F, class Iterator> template <class F, class Iterator>
inline bool operator>(iota_iterator<F, Iterator> x, iota_iterator<F, Iterator> y) inline bool operator>(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{ {
return x.index > y.index; return x.index > y.index;
} }
template <class F, class Iterator> template <class F, class Iterator>
inline bool operator>=(iota_iterator<F, Iterator> x, iota_iterator<F, Iterator> y) inline bool operator>=(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{ {
return x.index >= y.index; return x.index >= y.index;
} }
template <class F, class Iterator> template <class F, class Iterator>
inline bool operator<=(iota_iterator<F, Iterator> x, iota_iterator<F, Iterator> y) inline bool operator<=(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{ {
return x.index <= y.index; return x.index <= y.index;
} }
using iota_iterator = basic_iota_iterator<id>;
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include <vector> #include <vector>
#include <initializer_list> #include <initializer_list>
#include <migraphx/rank.hpp> #include <migraphx/rank.hpp>
#include <migraphx/iota_iterator.hpp>
#include <migraphx/type_name.hpp> #include <migraphx/type_name.hpp>
#include <migraphx/errors.hpp> #include <migraphx/errors.hpp>
#include <migraphx/requires.hpp> #include <migraphx/requires.hpp>
...@@ -208,12 +209,18 @@ struct iterator_range ...@@ -208,12 +209,18 @@ struct iterator_range
Iterator end() const { return last; } Iterator end() const { return last; }
}; };
template <class Iterator> template <class Iterator, MIGRAPHX_REQUIRES(not std::is_integral<Iterator>{})>
iterator_range<Iterator> range(Iterator start, Iterator last) iterator_range<Iterator> range(Iterator start, Iterator last)
{ {
return {start, last}; return {start, last};
} }
inline iterator_range<iota_iterator> range(std::ptrdiff_t start, std::ptrdiff_t last)
{
return {{start, {}}, {last, {}}};
}
inline iterator_range<iota_iterator> range(std::ptrdiff_t last) { return range(0, last); }
template <class Iterator> template <class Iterator>
iterator_range<Iterator> range(std::pair<Iterator, Iterator> p) iterator_range<Iterator> range(std::pair<Iterator, Iterator> p)
{ {
......
...@@ -35,9 +35,10 @@ struct tensor_view_iterator_read ...@@ -35,9 +35,10 @@ struct tensor_view_iterator_read
template <class T> template <class T>
struct tensor_view struct tensor_view
{ {
using value_type = T; using value_type = T;
using iterator = iota_iterator<tensor_view_iterator_read<tensor_view<T>>>; using iterator = basic_iota_iterator<tensor_view_iterator_read<tensor_view<T>>, std::size_t>;
using const_iterator = iota_iterator<tensor_view_iterator_read<const tensor_view<T>>>; using const_iterator =
basic_iota_iterator<tensor_view_iterator_read<const tensor_view<T>>, std::size_t>;
tensor_view() : m_data(nullptr) {} tensor_view() : m_data(nullptr) {}
tensor_view(shape s, T* d) : m_data(d), m_shape(std::move(s)) {} tensor_view(shape s, T* d) : m_data(d), m_shape(std::move(s)) {}
......
...@@ -376,9 +376,7 @@ struct find_resize ...@@ -376,9 +376,7 @@ struct find_resize
return; return;
} }
arg_ind.visit([&](auto v) { vec_ind.assign(v.begin(), v.end()); }); arg_ind.visit([&](auto v) { vec_ind.assign(v.begin(), v.end()); });
std::vector<int> index(out_shape.elements()); if(not all_of(range(out_shape.elements()), [&](auto i) {
std::iota(index.begin(), index.end(), 0);
if(not std::all_of(index.begin(), index.end(), [&](auto i) {
auto out_idx = out_shape.multi(i); auto out_idx = out_shape.multi(i);
auto in_idx = out_idx; auto in_idx = out_idx;
std::transform(out_idx.begin(), std::transform(out_idx.begin(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment