/*
    tests/test_sequences_and_iterators.cpp -- supporting Python's sequence protocol, iterators,
    etc.

    Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>

    All rights reserved. Use of this source code is governed by a
    BSD-style license that can be found in the LICENSE file.
*/

#include <pybind11/operators.h>
#include <pybind11/stl.h>

#include "constructor_stats.h"
#include "pybind11_tests.h"

#include <algorithm>
#include <memory>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>

#ifdef PYBIND11_HAS_OPTIONAL
#    include <optional>
#endif // PYBIND11_HAS_OPTIONAL

template <typename T>
26
class NonZeroIterator {
27
28
    const T *ptr_;

29
public:
30
    explicit NonZeroIterator(const T *ptr) : ptr_(ptr) {}
31
32
33
34
35
    const T &operator*() const { return *ptr_; }
    NonZeroIterator &operator++() {
        ++ptr_;
        return *this;
    }
36
37
38
39
};

class NonZeroSentinel {};

40
41
template <typename A, typename B>
bool operator==(const NonZeroIterator<std::pair<A, B>> &it, const NonZeroSentinel &) {
42
43
    return !(*it).first || !(*it).second;
}
44

45
/* Iterator where dereferencing returns prvalues instead of references. */
/// Each `*it` produces a fresh copy of the current element, so the result never aliases the
/// underlying storage.
template <typename T>
class NonRefIterator {
public:
    explicit NonRefIterator(const T *start) : cursor_(start) {}

    /// Returns the current element by value.
    T operator*() const { return T(cursor_[0]); }

    /// Steps to the next element.
    NonRefIterator &operator++() {
        cursor_ += 1;
        return *this;
    }

    /// Two iterators are equal when they point at the same element.
    bool operator==(const NonRefIterator &rhs) const { return rhs.cursor_ == cursor_; }

private:
    const T *cursor_;
};

/// Move-only int wrapper. Moving out of an instance leaves -1 behind, which lets the tests
/// notice when an element was moved where it should only have been referenced.
class NonCopyableInt {
public:
    explicit NonCopyableInt(int value) : value_(value) {}
    ~NonCopyableInt() = default;

    // Copying is forbidden.
    NonCopyableInt(const NonCopyableInt &) = delete;
    NonCopyableInt &operator=(const NonCopyableInt &) = delete;

    // Moving transfers the value and stamps the source with -1 (detects unwanted moves).
    NonCopyableInt(NonCopyableInt &&other) noexcept : value_(std::exchange(other.value_, -1)) {}
    NonCopyableInt &operator=(NonCopyableInt &&other) noexcept {
        value_ = other.value_;
        other.value_ = -1; // written last, so a self-move also ends up holding -1
        return *this;
    }

    int get() const { return value_; }
    void set(int value) { value_ = value; }

private:
    int value_;
};
// Pair of move-only ints, used by the key/value iterator bindings for
// "VectorNonCopyableIntPair" (test_iterator_referencing).
using NonCopyableIntPair = std::pair<NonCopyableInt, NonCopyableInt>;
// Mark these vectors opaque so pybind11/stl.h does not convert them to Python lists; they are
// bound as their own classes via py::class_ inside the test module.
PYBIND11_MAKE_OPAQUE(std::vector<NonCopyableInt>);
PYBIND11_MAKE_OPAQUE(std::vector<NonCopyableIntPair>);

84
85
template <typename PythonType>
py::list test_random_access_iterator(PythonType x) {
86
    if (x.size() < 5) {
87
        throw py::value_error("Please provide at least 5 elements for testing.");
88
    }
89
90
91
92

    auto checks = py::list();
    auto assert_equal = [&checks](py::handle a, py::handle b) {
        auto result = PyObject_RichCompareBool(a.ptr(), b.ptr(), Py_EQ);
93
94
95
        if (result == -1) {
            throw py::error_already_set();
        }
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
        checks.append(result != 0);
    };

    auto it = x.begin();
    assert_equal(x[0], *it);
    assert_equal(x[0], it[0]);
    assert_equal(x[1], it[1]);

    assert_equal(x[1], *(++it));
    assert_equal(x[1], *(it++));
    assert_equal(x[2], *it);
    assert_equal(x[3], *(it += 1));
    assert_equal(x[2], *(--it));
    assert_equal(x[2], *(it--));
    assert_equal(x[1], *it);
    assert_equal(x[0], *(it -= 1));

    assert_equal(it->attr("real"), x[0].attr("real"));
    assert_equal((it + 1)->attr("real"), x[1].attr("real"));

    assert_equal(x[1], *(it + 1));
    assert_equal(x[1], *(1 + it));
    it += 3;
    assert_equal(x[1], *(it - 2));

    checks.append(static_cast<std::size_t>(x.end() - x.begin()) == x.size());
    checks.append((x.begin() + static_cast<std::ptrdiff_t>(x.size())) == x.end());
    checks.append(x.begin() < x.end());

    return checks;
}

128
TEST_SUBMODULE(sequences_and_iterators, m) {
    // test_sliceable
    // Only exposes __getitem__(slice): it echoes back the (start, stop, step) that
    // py::slice::compute produces for a sequence of length `size`.
    class Sliceable {
    public:
        explicit Sliceable(int n) : size(n) {}
        int size;
    };
    py::class_<Sliceable>(m, "Sliceable")
        .def(py::init<int>())
        .def("__getitem__", [](const Sliceable &s, const py::slice &slice) {
            py::ssize_t start = 0, stop = 0, step = 0, slicelength = 0;
            if (!slice.compute(s.size, &start, &stop, &step, &slicelength)) {
                throw py::error_already_set();
            }
            int istart = static_cast<int>(start);
            int istop = static_cast<int>(stop);
            int istep = static_cast<int>(step);
            return std::make_tuple(istart, istop, istep);
        });

    m.def("make_forward_slice_size_t", []() { return py::slice(0, -1, 1); });
    m.def("make_reversed_slice_object",
          []() { return py::slice(py::none(), py::none(), py::int_(-1)); });
#ifdef PYBIND11_HAS_OPTIONAL
    m.attr("has_optional") = true;
    m.def("make_reversed_slice_size_t_optional_verbose",
          []() { return py::slice(std::nullopt, std::nullopt, -1); });
    // Warning: The following spelling may still compile if optional<> is not present and give
    // wrong answers. Please use with caution.
    m.def("make_reversed_slice_size_t_optional", []() { return py::slice({}, {}, -1); });
#else
    m.attr("has_optional") = false;
#endif

    // test_sequence
    // Fixed-size float buffer exposing the Python sequence protocol. Every special member
    // reports to ConstructorStats (print_*), which the Python tests use to count copies/moves.
    class Sequence {
    public:
        // Zero-filled sequence of `size` elements (`new float[n]()` value-initialises).
        explicit Sequence(size_t size) : m_size(size), m_data(new float[size]()) {
            print_created(this, "of size", m_size);
        }
        // Copies `value` through begin()/end(); `&value[0]` would be undefined behaviour for an
        // empty vector.
        explicit Sequence(const std::vector<float> &value)
            : m_size(value.size()), m_data(new float[value.size()]) {
            print_created(this, "of size", m_size, "from std::vector");
            std::copy(value.begin(), value.end(), m_data.get());
        }
        Sequence(const Sequence &s) : m_size(s.m_size), m_data(new float[s.m_size]) {
            print_copy_created(this);
            std::copy(s.begin(), s.end(), m_data.get());
        }
        Sequence(Sequence &&s) noexcept : m_size(s.m_size), m_data(std::move(s.m_data)) {
            print_move_created(this);
            s.m_size = 0; // the moved-from unique_ptr is already null
        }

        // The buffer itself is released by the unique_ptr member.
        ~Sequence() { print_destroyed(this); }

        Sequence &operator=(const Sequence &s) {
            if (&s != this) {
                // Build the replacement first so *this is left untouched if allocation throws.
                std::unique_ptr<float[]> fresh(new float[s.m_size]);
                std::copy(s.begin(), s.end(), fresh.get());
                m_data = std::move(fresh);
                m_size = s.m_size;
            }
            print_copy_assigned(this);
            return *this;
        }

        Sequence &operator=(Sequence &&s) noexcept {
            if (&s != this) {
                m_size = s.m_size;
                m_data = std::move(s.m_data);
                s.m_size = 0;
            }
            print_move_assigned(this);
            return *this;
        }

        bool operator==(const Sequence &s) const {
            return m_size == s.size() && std::equal(begin(), end(), s.begin());
        }
        bool operator!=(const Sequence &s) const { return !operator==(s); }

        float operator[](size_t index) const { return m_data[index]; }
        float &operator[](size_t index) { return m_data[index]; }

        bool contains(float v) const { return std::find(begin(), end(), v) != end(); }

        Sequence reversed() const {
            Sequence result(m_size);
            std::reverse_copy(begin(), end(), result.m_data.get());
            return result;
        }

        size_t size() const { return m_size; }

        const float *begin() const { return m_data.get(); }
        const float *end() const { return m_data.get() + m_size; }

    private:
        size_t m_size;
        std::unique_ptr<float[]> m_data;
    };
    py::class_<Sequence>(m, "Sequence")
        .def(py::init<size_t>())
        .def(py::init<const std::vector<float> &>())
        /// Bare bones interface
        .def("__getitem__",
             [](const Sequence &s, size_t i) {
                 if (i >= s.size()) {
                     throw py::index_error();
                 }
                 return s[i];
             })
        .def("__setitem__",
             [](Sequence &s, size_t i, float v) {
                 if (i >= s.size()) {
                     throw py::index_error();
                 }
                 s[i] = v;
             })
        .def("__len__", &Sequence::size)
        /// Optional sequence protocol operations
        .def(
            "__iter__",
            [](const Sequence &s) { return py::make_iterator(s.begin(), s.end()); },
            py::keep_alive<0, 1>() /* Essential: keep object alive while iterator exists */)
        .def("__contains__", [](const Sequence &s, float v) { return s.contains(v); })
        .def("__reversed__", [](const Sequence &s) -> Sequence { return s.reversed(); })
        /// Slicing protocol (optional)
        .def("__getitem__",
             [](const Sequence &s, const py::slice &slice) -> Sequence * {
                 size_t start = 0, stop = 0, step = 0, slicelength = 0;
                 if (!slice.compute(s.size(), &start, &stop, &step, &slicelength)) {
                     throw py::error_already_set();
                 }
                 // A negative step arrives as a wrapped-around size_t; unsigned wrap-around
                 // makes `start += step` still walk backwards.
                 auto *seq = new Sequence(slicelength);
                 for (size_t i = 0; i < slicelength; ++i) {
                     (*seq)[i] = s[start];
                     start += step;
                 }
                 return seq;
             })
        .def("__setitem__",
             [](Sequence &s, const py::slice &slice, const Sequence &value) {
                 size_t start = 0, stop = 0, step = 0, slicelength = 0;
                 if (!slice.compute(s.size(), &start, &stop, &step, &slicelength)) {
                     throw py::error_already_set();
                 }
                 if (slicelength != value.size()) {
                     throw std::runtime_error(
                         "Left and right hand size of slice assignment have different sizes!");
                 }
                 for (size_t i = 0; i < slicelength; ++i) {
                     s[start] = value[i];
                     start += step;
                 }
             })
        /// Comparisons
        .def(py::self == py::self)
        .def(py::self != py::self)
        // Could also define py::self + py::self for concatenation, etc.
        ;

    // test_map_iterator
    // Interface of a map-like object that isn't (directly) an unordered_map, but provides some
    // basic map-like functionality.
    class StringMap {
    public:
        StringMap() = default;
        explicit StringMap(std::unordered_map<std::string, std::string> init)
            : map(std::move(init)) {}

        void set(const std::string &key, std::string val) { map[key] = std::move(val); }
        std::string get(const std::string &key) const { return map.at(key); }
        size_t size() const { return map.size(); }

    private:
        std::unordered_map<std::string, std::string> map;

    public:
        decltype(map.cbegin()) begin() const { return map.cbegin(); }
        decltype(map.cend()) end() const { return map.cend(); }
    };
    py::class_<StringMap>(m, "StringMap")
        .def(py::init<>())
        .def(py::init<std::unordered_map<std::string, std::string>>())
        .def("__getitem__",
             [](const StringMap &map, const std::string &key) {
                 try {
                     return map.get(key);
                 } catch (const std::out_of_range &) {
                     throw py::key_error("key '" + key + "' does not exist");
                 }
             })
        .def("__setitem__", &StringMap::set)
        .def("__len__", &StringMap::size)
        .def(
            "__iter__",
            [](const StringMap &map) { return py::make_key_iterator(map.begin(), map.end()); },
            py::keep_alive<0, 1>())
        .def(
            "items",
            [](const StringMap &map) { return py::make_iterator(map.begin(), map.end()); },
            py::keep_alive<0, 1>())
        .def(
            "values",
            [](const StringMap &map) { return py::make_value_iterator(map.begin(), map.end()); },
            py::keep_alive<0, 1>());

    // test_generalized_iterators
    class IntPairs {
    public:
        explicit IntPairs(std::vector<std::pair<int, int>> data) : data_(std::move(data)) {}
        const std::pair<int, int> *begin() const { return data_.data(); }
        // .end() only required for py::make_iterator(self) overload
        const std::pair<int, int> *end() const { return data_.data() + data_.size(); }

    private:
        std::vector<std::pair<int, int>> data_;
    };
    py::class_<IntPairs>(m, "IntPairs")
        .def(py::init<std::vector<std::pair<int, int>>>())
        // NOTE: the "nonzero*" iterators are unbounded (NonZeroIterator has no end pointer);
        // the data must contain a pair with a zero member or iteration reads past the end.
        .def(
            "nonzero",
            [](const IntPairs &s) {
                return py::make_iterator(NonZeroIterator<std::pair<int, int>>(s.begin()),
                                         NonZeroSentinel());
            },
            py::keep_alive<0, 1>())
        .def(
            "nonzero_keys",
            [](const IntPairs &s) {
                return py::make_key_iterator(NonZeroIterator<std::pair<int, int>>(s.begin()),
                                             NonZeroSentinel());
            },
            py::keep_alive<0, 1>())
        .def(
            "nonzero_values",
            [](const IntPairs &s) {
                return py::make_value_iterator(NonZeroIterator<std::pair<int, int>>(s.begin()),
                                               NonZeroSentinel());
            },
            py::keep_alive<0, 1>())

        // test iterator that returns values instead of references
        .def(
            "nonref",
            [](const IntPairs &s) {
                return py::make_iterator(NonRefIterator<std::pair<int, int>>(s.begin()),
                                         NonRefIterator<std::pair<int, int>>(s.end()));
            },
            py::keep_alive<0, 1>())
        .def(
            "nonref_keys",
            [](const IntPairs &s) {
                return py::make_key_iterator(NonRefIterator<std::pair<int, int>>(s.begin()),
                                             NonRefIterator<std::pair<int, int>>(s.end()));
            },
            py::keep_alive<0, 1>())
        .def(
            "nonref_values",
            [](const IntPairs &s) {
                return py::make_value_iterator(NonRefIterator<std::pair<int, int>>(s.begin()),
                                               NonRefIterator<std::pair<int, int>>(s.end()));
            },
            py::keep_alive<0, 1>())

        // test single-argument make_iterator
        .def(
            "simple_iterator",
            [](IntPairs &self) { return py::make_iterator(self); },
            py::keep_alive<0, 1>())
        .def(
            "simple_keys",
            [](IntPairs &self) { return py::make_key_iterator(self); },
            py::keep_alive<0, 1>())
        .def(
            "simple_values",
            [](IntPairs &self) { return py::make_value_iterator(self); },
            py::keep_alive<0, 1>())

        // Test iterator with an Extra (doesn't do anything useful, so not used
        // at runtime, but tests need to be able to compile with the correct
        // overload. See PR #3293.
        .def(
            "_make_iterator_extras",
            [](IntPairs &self) { return py::make_iterator(self, py::call_guard<int>()); },
            py::keep_alive<0, 1>())
        .def(
            "_make_key_extras",
            [](IntPairs &self) { return py::make_key_iterator(self, py::call_guard<int>()); },
            py::keep_alive<0, 1>())
        .def(
            "_make_value_extras",
            [](IntPairs &self) { return py::make_value_iterator(self, py::call_guard<int>()); },
            py::keep_alive<0, 1>());

    // test_iterator_referencing
    py::class_<NonCopyableInt>(m, "NonCopyableInt")
        .def(py::init<int>())
        .def("set", &NonCopyableInt::set)
        .def("__int__", &NonCopyableInt::get);
    py::class_<std::vector<NonCopyableInt>>(m, "VectorNonCopyableInt")
        .def(py::init<>())
        .def("append",
             [](std::vector<NonCopyableInt> &vec, int value) { vec.emplace_back(value); })
        .def("__iter__", [](std::vector<NonCopyableInt> &vec) {
            return py::make_iterator(vec.begin(), vec.end());
        });
    py::class_<std::vector<NonCopyableIntPair>>(m, "VectorNonCopyableIntPair")
        .def(py::init<>())
        .def("append",
             [](std::vector<NonCopyableIntPair> &vec, const std::pair<int, int> &value) {
                 vec.emplace_back(NonCopyableInt(value.first), NonCopyableInt(value.second));
             })
        .def("keys",
             [](std::vector<NonCopyableIntPair> &vec) {
                 return py::make_key_iterator(vec.begin(), vec.end());
             })
        .def("values", [](std::vector<NonCopyableIntPair> &vec) {
            return py::make_value_iterator(vec.begin(), vec.end());
        });

#if 0
    // Obsolete: special data structure for exposing custom iterator types to python
    // kept here for illustrative purposes because there might be some use cases which
    // are not covered by the much simpler py::make_iterator

    struct PySequenceIterator {
        PySequenceIterator(const Sequence &seq, py::object ref) : seq(seq), ref(ref) { }

        float next() {
            if (index == seq.size())
                throw py::stop_iteration();
            return seq[index++];
        }

        const Sequence &seq;
        py::object ref; // keep a reference
        size_t index = 0;
    };

    py::class_<PySequenceIterator>(seq, "Iterator")
        .def("__iter__", [](PySequenceIterator &it) -> PySequenceIterator& { return it; })
        .def("__next__", &PySequenceIterator::next);

    // On the actual Sequence object, the iterator would be constructed as follows:
    // .def("__iter__", [](py::object s) { return PySequenceIterator(s.cast<const Sequence &>(), s); })
#endif

    // test_python_iterator_in_cpp
    m.def("object_to_list", [](const py::object &o) {
        auto l = py::list();
        for (auto item : o) {
            l.append(item);
        }
        return l;
    });

    m.def("iterator_to_list", [](py::iterator it) {
        auto l = py::list();
        while (it != py::iterator::sentinel()) {
            l.append(*it);
            ++it;
        }
        return l;
    });

    // test_sequence_length: check that Python sequences can be converted to py::sequence.
    m.def("sequence_length", [](const py::sequence &seq) { return seq.size(); });

    // Make sure that py::iterator works with std algorithms
    m.def("count_none", [](const py::object &o) {
        return std::count_if(o.begin(), o.end(), [](py::handle h) { return h.is_none(); });
    });

    m.def("find_none", [](const py::object &o) {
        auto it = std::find_if(o.begin(), o.end(), [](py::handle h) { return h.is_none(); });
        // Check for "not found" explicitly instead of dereferencing the end iterator.
        return it != o.end() && it->is_none();
    });

    m.def("count_nonzeros", [](const py::dict &d) {
        return std::count_if(d.begin(), d.end(), [](std::pair<py::handle, py::handle> p) {
            return p.second.cast<int>() != 0;
        });
    });

    m.def("tuple_iterator", &test_random_access_iterator<py::tuple>);
    m.def("list_iterator", &test_random_access_iterator<py::list>);
    m.def("sequence_iterator", &test_random_access_iterator<py::sequence>);

    // test_iterator_passthrough
    // #181: iterator passthrough did not compile
    m.def("iterator_passthrough", [](py::iterator s) -> py::iterator {
        return py::make_iterator(std::begin(s), std::end(s));
    });

    // test_iterator_rvp
    // #388: Can't make iterators via make_iterator() with different r/v policies
    static std::vector<int> list = {1, 2, 3};
    m.def("make_iterator_1",
          []() { return py::make_iterator<py::return_value_policy::copy>(list); });
    m.def("make_iterator_2",
          []() { return py::make_iterator<py::return_value_policy::automatic>(list); });

    // test_iterator on c arrays
    // #4100: ensure lvalue required as increment operand
    class CArrayHolder {
    public:
        CArrayHolder(double x, double y, double z) {
            values[0] = x;
            values[1] = y;
            values[2] = z;
        }
        double values[3];
    };

    py::class_<CArrayHolder>(m, "CArrayHolder")
        .def(py::init<double, double, double>())
        .def(
            "__iter__",
            [](const CArrayHolder &v) { return py::make_iterator(v.values, v.values + 3); },
            py::keep_alive<0, 1>());
}