Unverified Commit e8738144 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Remove vector of indices and use a lazy iota range instead (#838)

* Create lazy range

* Formatting

* Use lazy iota
parent c310bc5c
......@@ -2,14 +2,15 @@
#define MIGRAPHX_GUARD_RTGLIB_IOTA_ITERATOR_HPP
#include <migraphx/config.hpp>
#include <migraphx/functional.hpp>
#include <iterator>
#include <type_traits>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class F, class Iterator = std::size_t>
struct iota_iterator
template <class F, class Iterator = std::ptrdiff_t>
struct basic_iota_iterator
{
Iterator index;
F f;
......@@ -20,40 +21,40 @@ struct iota_iterator
using pointer = typename std::add_pointer<value_type>::type;
using iterator_category = std::random_access_iterator_tag;
iota_iterator& operator+=(int n)
basic_iota_iterator& operator+=(int n)
{
index += n;
return *this;
}
iota_iterator& operator-=(int n)
basic_iota_iterator& operator-=(int n)
{
index -= n;
return *this;
}
iota_iterator& operator++()
basic_iota_iterator& operator++()
{
index++;
return *this;
}
iota_iterator& operator--()
basic_iota_iterator& operator--()
{
index--;
return *this;
}
iota_iterator operator++(int) // NOLINT
basic_iota_iterator operator++(int) // NOLINT
{
iota_iterator it = *this;
basic_iota_iterator it = *this;
index++;
return it;
}
iota_iterator operator--(int) // NOLINT
basic_iota_iterator operator--(int) // NOLINT
{
iota_iterator it = *this;
basic_iota_iterator it = *this;
index--;
return it;
}
......@@ -61,55 +62,71 @@ struct iota_iterator
reference operator*() const { return f(index); }
};
template <class T, class F>
inline basic_iota_iterator<F, T> make_basic_iota_iterator(T x, F f)
{
return basic_iota_iterator<F, T>{x, f};
}
template <class F, class Iterator>
inline basic_iota_iterator<F, Iterator> operator+(basic_iota_iterator<F, Iterator> x,
std::ptrdiff_t y)
{
return x += y;
}
template <class F, class Iterator>
inline iota_iterator<F, Iterator> operator+(iota_iterator<F, Iterator> x,
iota_iterator<F, Iterator> y)
inline basic_iota_iterator<F, Iterator> operator+(std::ptrdiff_t x,
basic_iota_iterator<F, Iterator> y)
{
return iota_iterator<F, Iterator>(x.index + y.index, x.f);
return y + x;
}
template <class F, class Iterator>
inline std::ptrdiff_t operator-(iota_iterator<F, Iterator> x, iota_iterator<F, Iterator> y)
inline std::ptrdiff_t operator-(basic_iota_iterator<F, Iterator> x,
basic_iota_iterator<F, Iterator> y)
{
return x.index - y.index;
}
template <class F, class Iterator>
inline bool operator==(iota_iterator<F, Iterator> x, iota_iterator<F, Iterator> y)
inline bool operator==(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
return x.index == y.index;
}
template <class F, class Iterator>
inline bool operator!=(iota_iterator<F, Iterator> x, iota_iterator<F, Iterator> y)
inline bool operator!=(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
return x.index != y.index;
}
template <class F, class Iterator>
inline bool operator<(iota_iterator<F, Iterator> x, iota_iterator<F, Iterator> y)
inline bool operator<(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
return x.index < y.index;
}
template <class F, class Iterator>
inline bool operator>(iota_iterator<F, Iterator> x, iota_iterator<F, Iterator> y)
inline bool operator>(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
return x.index > y.index;
}
template <class F, class Iterator>
inline bool operator>=(iota_iterator<F, Iterator> x, iota_iterator<F, Iterator> y)
inline bool operator>=(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
return x.index >= y.index;
}
template <class F, class Iterator>
inline bool operator<=(iota_iterator<F, Iterator> x, iota_iterator<F, Iterator> y)
inline bool operator<=(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
return x.index <= y.index;
}
using iota_iterator = basic_iota_iterator<id>;
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -5,6 +5,7 @@
#include <vector>
#include <initializer_list>
#include <migraphx/rank.hpp>
#include <migraphx/iota_iterator.hpp>
#include <migraphx/type_name.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/requires.hpp>
......@@ -208,12 +209,18 @@ struct iterator_range
Iterator end() const { return last; }
};
template <class Iterator>
template <class Iterator, MIGRAPHX_REQUIRES(not std::is_integral<Iterator>{})>
iterator_range<Iterator> range(Iterator start, Iterator last)
{
return {start, last};
}
inline iterator_range<iota_iterator> range(std::ptrdiff_t start, std::ptrdiff_t last)
{
return {{start, {}}, {last, {}}};
}
inline iterator_range<iota_iterator> range(std::ptrdiff_t last) { return range(0, last); }
template <class Iterator>
iterator_range<Iterator> range(std::pair<Iterator, Iterator> p)
{
......
......@@ -35,9 +35,10 @@ struct tensor_view_iterator_read
template <class T>
struct tensor_view
{
using value_type = T;
using iterator = iota_iterator<tensor_view_iterator_read<tensor_view<T>>>;
using const_iterator = iota_iterator<tensor_view_iterator_read<const tensor_view<T>>>;
using value_type = T;
using iterator = basic_iota_iterator<tensor_view_iterator_read<tensor_view<T>>, std::size_t>;
using const_iterator =
basic_iota_iterator<tensor_view_iterator_read<const tensor_view<T>>, std::size_t>;
tensor_view() : m_data(nullptr) {}
tensor_view(shape s, T* d) : m_data(d), m_shape(std::move(s)) {}
......
......@@ -376,9 +376,7 @@ struct find_resize
return;
}
arg_ind.visit([&](auto v) { vec_ind.assign(v.begin(), v.end()); });
std::vector<int> index(out_shape.elements());
std::iota(index.begin(), index.end(), 0);
if(not std::all_of(index.begin(), index.end(), [&](auto i) {
if(not all_of(range(out_shape.elements()), [&](auto i) {
auto out_idx = out_shape.multi(i);
auto in_idx = out_idx;
std::transform(out_idx.begin(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment