Commit 5637af7b authored by Dean Moldovan's avatar Dean Moldovan Committed by Wenzel Jakob
Browse files

Add lightweight iterators for tuple, list and sequence

Slightly reduces binary size (range for loops over tuple/list benefit
a lot). The iterators are compatible with std algorithms.
parent 1fac1b9f
...@@ -499,24 +499,125 @@ struct tuple_item { ...@@ -499,24 +499,125 @@ struct tuple_item {
}; };
NAMESPACE_END(accessor_policies) NAMESPACE_END(accessor_policies)
struct dict_iterator { /// STL iterator template used for tuple, list, sequence and dict
template <typename Policy>
class generic_iterator : public Policy {
using It = generic_iterator;
public: public:
explicit dict_iterator(handle dict = handle(), ssize_t pos = -1) : dict(dict), pos(pos) { } using difference_type = ssize_t;
dict_iterator& operator++() { using iterator_category = typename Policy::iterator_category;
if (!PyDict_Next(dict.ptr(), &pos, &key.ptr(), &value.ptr())) using value_type = typename Policy::value_type;
pos = -1; using reference = typename Policy::reference;
return *this; using pointer = typename Policy::pointer;
}
std::pair<handle, handle> operator*() const { generic_iterator() = default;
return std::make_pair(key, value); generic_iterator(handle seq, ssize_t index) : Policy(seq, index) { }
}
bool operator==(const dict_iterator &it) const { return it.pos == pos; } reference operator*() const { return Policy::dereference(); }
bool operator!=(const dict_iterator &it) const { return it.pos != pos; } reference operator[](difference_type n) const { return *(*this + n); }
pointer operator->() const { return **this; }
It &operator++() { Policy::increment(); return *this; }
It operator++(int) { auto copy = *this; Policy::increment(); return copy; }
It &operator--() { Policy::decrement(); return *this; }
It operator--(int) { auto copy = *this; Policy::decrement(); return copy; }
It &operator+=(difference_type n) { Policy::advance(n); return *this; }
It &operator-=(difference_type n) { Policy::advance(-n); return *this; }
friend It operator+(const It &a, difference_type n) { auto copy = a; return copy += n; }
friend It operator+(difference_type n, const It &b) { return b + n; }
friend It operator-(const It &a, difference_type n) { auto copy = a; return copy -= n; }
friend difference_type operator-(const It &a, const It &b) { return a.distance_to(b); }
friend bool operator==(const It &a, const It &b) { return a.equal(b); }
friend bool operator!=(const It &a, const It &b) { return !(a == b); }
friend bool operator< (const It &a, const It &b) { return b - a > 0; }
friend bool operator> (const It &a, const It &b) { return b < a; }
friend bool operator>=(const It &a, const It &b) { return !(a < b); }
friend bool operator<=(const It &a, const It &b) { return !(a > b); }
};
NAMESPACE_BEGIN(iterator_policies)
/// Quick proxy class needed to implement ``operator->`` for iterators which can't return pointers
template <typename T>
struct arrow_proxy {
T value;
arrow_proxy(T &&value) : value(std::move(value)) { }
T *operator->() const { return &value; }
};
/// Lightweight iterator policy using just a simple pointer: see ``PySequence_Fast_ITEMS``
class sequence_fast_readonly {
protected:
using iterator_category = std::random_access_iterator_tag;
using value_type = handle;
using reference = const handle;
using pointer = arrow_proxy<const handle>;
sequence_fast_readonly(handle obj, ssize_t n) : ptr(PySequence_Fast_ITEMS(obj.ptr()) + n) { }
reference dereference() const { return *ptr; }
void increment() { ++ptr; }
void decrement() { --ptr; }
void advance(ssize_t n) { ptr += n; }
bool equal(const sequence_fast_readonly &b) const { return ptr == b.ptr; }
ssize_t distance_to(const sequence_fast_readonly &b) const { return ptr - b.ptr; }
private: private:
handle dict, key, value; PyObject **ptr;
ssize_t pos = 0;
}; };
/// Full read and write access using the sequence protocol: see ``detail::sequence_accessor``
class sequence_slow_readwrite {
protected:
using iterator_category = std::random_access_iterator_tag;
using value_type = object;
using reference = sequence_accessor;
using pointer = arrow_proxy<const sequence_accessor>;
sequence_slow_readwrite(handle obj, ssize_t index) : obj(obj), index(index) { }
reference dereference() const { return {obj, static_cast<size_t>(index)}; }
void increment() { ++index; }
void decrement() { --index; }
void advance(ssize_t n) { index += n; }
bool equal(const sequence_slow_readwrite &b) const { return index == b.index; }
ssize_t distance_to(const sequence_slow_readwrite &b) const { return index - b.index; }
private:
handle obj;
ssize_t index;
};
/// Python's dictionary protocol permits this to be a forward iterator
class dict_readonly {
protected:
using iterator_category = std::forward_iterator_tag;
using value_type = std::pair<handle, handle>;
using reference = const value_type;
using pointer = arrow_proxy<const value_type>;
dict_readonly() = default;
dict_readonly(handle obj, ssize_t pos) : obj(obj), pos(pos) { increment(); }
reference dereference() const { return {key, value}; }
void increment() { if (!PyDict_Next(obj.ptr(), &pos, &key, &value)) { pos = -1; } }
bool equal(const dict_readonly &b) const { return pos == b.pos; }
private:
handle obj;
PyObject *key, *value;
ssize_t pos = -1;
};
NAMESPACE_END(iterator_policies)
using tuple_iterator = generic_iterator<iterator_policies::sequence_fast_readonly>;
using list_iterator = generic_iterator<iterator_policies::sequence_fast_readonly>;
using sequence_iterator = generic_iterator<iterator_policies::sequence_slow_readwrite>;
using dict_iterator = generic_iterator<iterator_policies::dict_readonly>;
inline bool PyIterable_Check(PyObject *obj) { inline bool PyIterable_Check(PyObject *obj) {
PyObject *iter = PyObject_GetIter(obj); PyObject *iter = PyObject_GetIter(obj);
if (iter) { if (iter) {
...@@ -916,6 +1017,8 @@ public: ...@@ -916,6 +1017,8 @@ public:
} }
size_t size() const { return (size_t) PyTuple_Size(m_ptr); } size_t size() const { return (size_t) PyTuple_Size(m_ptr); }
detail::tuple_accessor operator[](size_t index) const { return {*this, index}; } detail::tuple_accessor operator[](size_t index) const { return {*this, index}; }
detail::tuple_iterator begin() const { return {*this, 0}; }
detail::tuple_iterator end() const { return {*this, PyTuple_GET_SIZE(m_ptr)}; }
}; };
class dict : public object { class dict : public object {
...@@ -931,8 +1034,8 @@ public: ...@@ -931,8 +1034,8 @@ public:
explicit dict(Args &&...args) : dict(collector(std::forward<Args>(args)...).kwargs()) { } explicit dict(Args &&...args) : dict(collector(std::forward<Args>(args)...).kwargs()) { }
size_t size() const { return (size_t) PyDict_Size(m_ptr); } size_t size() const { return (size_t) PyDict_Size(m_ptr); }
detail::dict_iterator begin() const { return (++detail::dict_iterator(*this, 0)); } detail::dict_iterator begin() const { return {*this, 0}; }
detail::dict_iterator end() const { return detail::dict_iterator(); } detail::dict_iterator end() const { return {}; }
void clear() const { PyDict_Clear(ptr()); } void clear() const { PyDict_Clear(ptr()); }
bool contains(handle key) const { return PyDict_Contains(ptr(), key.ptr()) == 1; } bool contains(handle key) const { return PyDict_Contains(ptr(), key.ptr()) == 1; }
bool contains(const char *key) const { return PyDict_Contains(ptr(), pybind11::str(key).ptr()) == 1; } bool contains(const char *key) const { return PyDict_Contains(ptr(), pybind11::str(key).ptr()) == 1; }
...@@ -948,9 +1051,11 @@ private: ...@@ -948,9 +1051,11 @@ private:
class sequence : public object { class sequence : public object {
public: public:
PYBIND11_OBJECT(sequence, object, PySequence_Check) PYBIND11_OBJECT_DEFAULT(sequence, object, PySequence_Check)
size_t size() const { return (size_t) PySequence_Size(m_ptr); } size_t size() const { return (size_t) PySequence_Size(m_ptr); }
detail::sequence_accessor operator[](size_t index) const { return {*this, index}; } detail::sequence_accessor operator[](size_t index) const { return {*this, index}; }
detail::sequence_iterator begin() const { return {*this, 0}; }
detail::sequence_iterator end() const { return {*this, PySequence_Size(m_ptr)}; }
}; };
class list : public object { class list : public object {
...@@ -961,6 +1066,8 @@ public: ...@@ -961,6 +1066,8 @@ public:
} }
size_t size() const { return (size_t) PyList_Size(m_ptr); } size_t size() const { return (size_t) PyList_Size(m_ptr); }
detail::list_accessor operator[](size_t index) const { return {*this, index}; } detail::list_accessor operator[](size_t index) const { return {*this, index}; }
detail::list_iterator begin() const { return {*this, 0}; }
detail::list_iterator end() const { return {*this, PyList_GET_SIZE(m_ptr)}; }
template <typename T> void append(T &&val) const { template <typename T> void append(T &&val) const {
PyList_Append(m_ptr, detail::object_or_cast(std::forward<T>(val)).ptr()); PyList_Append(m_ptr, detail::object_or_cast(std::forward<T>(val)).ptr());
} }
......
...@@ -169,6 +169,47 @@ bool operator==(const NonZeroIterator<std::pair<A, B>>& it, const NonZeroSentine ...@@ -169,6 +169,47 @@ bool operator==(const NonZeroIterator<std::pair<A, B>>& it, const NonZeroSentine
return !(*it).first || !(*it).second; return !(*it).first || !(*it).second;
} }
template <typename PythonType>
py::list test_random_access_iterator(PythonType x) {
if (x.size() < 5)
throw py::value_error("Please provide at least 5 elements for testing.");
auto checks = py::list();
auto assert_equal = [&checks](py::handle a, py::handle b) {
auto result = PyObject_RichCompareBool(a.ptr(), b.ptr(), Py_EQ);
if (result == -1) { throw py::error_already_set(); }
checks.append(result != 0);
};
auto it = x.begin();
assert_equal(x[0], *it);
assert_equal(x[0], it[0]);
assert_equal(x[1], it[1]);
assert_equal(x[1], *(++it));
assert_equal(x[1], *(it++));
assert_equal(x[2], *it);
assert_equal(x[3], *(it += 1));
assert_equal(x[2], *(--it));
assert_equal(x[2], *(it--));
assert_equal(x[1], *it);
assert_equal(x[0], *(it -= 1));
assert_equal(it->attr("real"), x[0].attr("real"));
assert_equal((it + 1)->attr("real"), x[1].attr("real"));
assert_equal(x[1], *(it + 1));
assert_equal(x[1], *(1 + it));
it += 3;
assert_equal(x[1], *(it - 2));
checks.append(static_cast<std::size_t>(x.end() - x.begin()) == x.size());
checks.append((x.begin() + static_cast<std::ptrdiff_t>(x.size())) == x.end());
checks.append(x.begin() < x.end());
return checks;
}
test_initializer sequences_and_iterators([](py::module &pm) { test_initializer sequences_and_iterators([](py::module &pm) {
auto m = pm.def_submodule("sequences_and_iterators"); auto m = pm.def_submodule("sequences_and_iterators");
...@@ -300,4 +341,14 @@ test_initializer sequences_and_iterators([](py::module &pm) { ...@@ -300,4 +341,14 @@ test_initializer sequences_and_iterators([](py::module &pm) {
auto it = std::find_if(o.begin(), o.end(), [](py::handle h) { return h.is_none(); }); auto it = std::find_if(o.begin(), o.end(), [](py::handle h) { return h.is_none(); });
return it->is_none(); return it->is_none();
}); });
m.def("count_nonzeros", [](py::dict d) {
return std::count_if(d.begin(), d.end(), [](std::pair<py::handle, py::handle> p) {
return p.second.cast<int>() != 0;
});
});
m.def("tuple_iterator", [](py::tuple x) { return test_random_access_iterator(x); });
m.def("list_iterator", [](py::list x) { return test_random_access_iterator(x); });
m.def("sequence_iterator", [](py::sequence x) { return test_random_access_iterator(x); });
}); });
...@@ -117,3 +117,9 @@ def test_python_iterator_in_cpp(): ...@@ -117,3 +117,9 @@ def test_python_iterator_in_cpp():
l = [1, None, 0, None] l = [1, None, 0, None]
assert m.count_none(l) == 2 assert m.count_none(l) == 2
assert m.find_none(l) is True assert m.find_none(l) is True
assert m.count_nonzeros({"a": 0, "b": 1, "c": 2}) == 2
r = range(5)
assert all(m.tuple_iterator(tuple(r)))
assert all(m.list_iterator(list(r)))
assert all(m.sequence_iterator(r))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment