# -*- coding: utf-8 -*-
import pytest

from pybind11_tests import ConstructorStats
from pybind11_tests import sequences_and_iterators as m


def isclose(a, b, rel_tol=1e-05, abs_tol=0.0):
    """Return True when ``a`` and ``b`` are approximately equal.

    Uses the same criterion as ``math.isclose()`` (Python 3.5+): the absolute
    difference may not exceed the larger of ``rel_tol`` times the bigger
    magnitude and ``abs_tol``.
    """
    difference = abs(a - b)
    tolerance = max(rel_tol * max(abs(a), abs(b)), abs_tol)
    return difference <= tolerance


def allclose(a_list, b_list, rel_tol=1e-05, abs_tol=0.0):
    """Return True if both iterables have equal length and are element-wise close.

    Elements are paired in order and compared with ``isclose`` using the given
    tolerances.
    """
    # Materialise both inputs once: they may be arbitrary iterables, and a
    # bare ``zip`` would silently truncate to the shorter one, making e.g. an
    # empty sequence "close" to any other sequence.
    a_items = list(a_list)
    b_items = list(b_list)
    if len(a_items) != len(b_items):
        return False
    return all(
        isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol) for a, b in zip(a_items, b_items)
    )


def test_slice_constructors():
    """Slices built on the C++ side compare equal to the given Python slices."""
    forward = m.make_forward_slice_size_t()
    assert forward == slice(0, -1, 1)

    backward = m.make_reversed_slice_object()
    assert backward == slice(None, None, -1)


@pytest.mark.skipif(not m.has_optional, reason="no <optional>")
def test_slice_constructors_explicit_optional():
    """Both optional-based slice factories produce ``slice(None, None, -1)``."""
    expected = slice(None, None, -1)
    for factory in (
        m.make_reversed_slice_size_t_optional,
        m.make_reversed_slice_size_t_optional_verbose,
    ):
        assert factory() == expected


def test_generalized_iterators():
    """IntPairs' nonzero iterators yield pairs, keys or values up to (not
    including) the first pair that contains a zero, and stay exhausted."""
    assert list(m.IntPairs([(1, 2), (3, 4), (0, 5)]).nonzero()) == [(1, 2), (3, 4)]
    assert list(m.IntPairs([(1, 2), (2, 0), (0, 3), (4, 5)]).nonzero()) == [(1, 2)]
    assert list(m.IntPairs([(0, 3), (1, 2), (3, 4)]).nonzero()) == []

    assert list(m.IntPairs([(1, 2), (3, 4), (0, 5)]).nonzero_keys()) == [1, 3]
    assert list(m.IntPairs([(1, 2), (2, 0), (0, 3), (4, 5)]).nonzero_keys()) == [1]
    assert list(m.IntPairs([(0, 3), (1, 2), (3, 4)]).nonzero_keys()) == []

    assert list(m.IntPairs([(1, 2), (3, 4), (0, 5)]).nonzero_values()) == [2, 4]
    assert list(m.IntPairs([(1, 2), (2, 0), (0, 3), (4, 5)]).nonzero_values()) == [2]
    assert list(m.IntPairs([(0, 3), (1, 2), (3, 4)]).nonzero_values()) == []

    # __next__ must continue to raise StopIteration once exhausted. Check every
    # iterator flavour exercised above (pairs, keys and values), not just some.
    for method_name in ("nonzero", "nonzero_keys", "nonzero_values"):
        it = getattr(m.IntPairs([(0, 0)]), method_name)()
        for _ in range(3):
            with pytest.raises(StopIteration):
                next(it)


def test_iterator_referencing():
    """Test that iterators reference rather than copy their referents."""
    # Vector of non-copyable ints: mutate each element through the iterator
    # and check the change is visible in the container afterwards.
    ints = m.VectorNonCopyableInt()
    for number in (3, 5):
        ints.append(number)
    assert list(map(int, ints)) == [3, 5]
    for element in ints:
        element.set(int(element) + 1)
    assert list(map(int, ints)) == [4, 6]

    # Same for the key and value iterators of a vector of pairs.
    pairs = m.VectorNonCopyableIntPair()
    for pair in ([3, 4], [5, 7]):
        pairs.append(pair)
    assert list(map(int, pairs.keys())) == [3, 5]
    assert list(map(int, pairs.values())) == [4, 7]
    for key in pairs.keys():
        key.set(int(key) + 1)
    for value in pairs.values():
        value.set(int(value) + 10)
    assert list(map(int, pairs.keys())) == [4, 6]
    assert list(map(int, pairs.values())) == [14, 17]


def test_sliceable():
    """Indexing Sliceable(100) with a slice returns a (start, stop, step) tuple.

    Each expected tuple equals what ``slice(...).indices(100)`` produces.
    """
    sliceable = m.Sliceable(100)
    cases = [
        (slice(None, None, None), (0, 100, 1)),
        (slice(10, None, None), (10, 100, 1)),
        (slice(None, 10, None), (0, 10, 1)),
        (slice(None, None, 10), (0, 100, 10)),
        (slice(-10, None, None), (90, 100, 1)),
        (slice(None, -10, None), (0, 90, 1)),
        (slice(None, None, -10), (99, -1, -10)),
        (slice(50, 60, 1), (50, 60, 1)),
        (slice(50, 60, -1), (50, 60, -1)),
    ]
    for key, expected in cases:
        assert sliceable[key] == expected


def test_sequence():
    """Exercise the bound ``Sequence``: element access, reversal and slicing,
    slice assignment, iterator exhaustion, and object lifetimes as tracked by
    ``ConstructorStats``."""
    cstats = ConstructorStats.get(m.Sequence)

    seq = m.Sequence(5)
    assert cstats.values() == ["of size", "5"]

    assert "Sequence" in repr(seq)
    assert len(seq) == 5
    assert seq[0] == 0
    assert seq[3] == 0
    assert 12.34 not in seq
    seq[0] = 12.34
    seq[3] = 56.78
    assert 12.34 in seq
    assert isclose(seq[0], 12.34)
    assert isclose(seq[3], 56.78)

    # reversed() and a [::-1] slice each construct a new size-5 Sequence.
    reversed_seq = reversed(seq)
    assert cstats.values() == ["of size", "5"]

    sliced_seq = seq[::-1]
    assert cstats.values() == ["of size", "5"]

    # An exhausted iterator must keep raising StopIteration; it also keeps
    # its (empty) Sequence alive, which the alive() counts below rely on.
    empty_iter = iter(m.Sequence(0))
    for _ in range(3):
        with pytest.raises(StopIteration):
            next(empty_iter)
    assert cstats.values() == ["of size", "0"]

    expected = [0, 56.78, 0, 0, 12.34]
    assert allclose(reversed_seq, expected)
    assert allclose(sliced_seq, expected)
    assert reversed_seq == sliced_seq

    # Extended-slice assignment from another Sequence built from a list.
    reversed_seq[0::2] = m.Sequence([2.0, 2.0, 2.0])
    assert cstats.values() == ["of size", "3", "from std::vector"]
    assert allclose(reversed_seq, [2, 56.78, 2, 0, 2])

    # Alive: seq, reversed_seq, sliced_seq and the one held by empty_iter.
    # Delete them one at a time and watch the count drop.
    assert cstats.alive() == 4
    del empty_iter
    assert cstats.alive() == 3
    del seq
    assert cstats.alive() == 2
    del reversed_seq
    assert cstats.alive() == 1
    del sliced_seq
    assert cstats.alive() == 0

    assert cstats.values() == []
    for counter in (
        "default_constructions",
        "copy_constructions",
        "copy_assignments",
        "move_assignments",
    ):
        assert getattr(cstats, counter) == 0
    assert cstats.move_constructions >= 1


def test_sequence_length():
    """#2076: an exception raised by ``len(arg)`` must propagate to the caller."""

    class LenFailure(RuntimeError):
        pass

    class FakeSequence:
        # Looks like a sequence (has __getitem__), but len() always fails.
        def __getitem__(self, index):
            return None

        def __len__(self):
            raise LenFailure()

    with pytest.raises(LenFailure):
        m.sequence_length(FakeSequence())

    assert m.sequence_length([1, 2, 3]) == 3
    assert m.sequence_length("hello") == 5


def test_map_iterator():
    """StringMap supports lookup, insertion, and key/item/value iteration."""
    sm = m.StringMap({"hi": "bye", "black": "white"})
    assert sm["hi"] == "bye"
    assert len(sm) == 2
    assert sm["black"] == "white"

    # A plain subscript, not ``assert sm[...]``: the lookup itself must raise,
    # and it must not depend on assertions being enabled.
    with pytest.raises(KeyError):
        sm["orange"]
    sm["orange"] = "banana"
    assert sm["orange"] == "banana"

    expected = {"hi": "bye", "black": "white", "orange": "banana"}
    # Check the full key set and item mapping. Per-element loops alone would
    # pass vacuously if iteration yielded nothing.
    assert len(sm) == len(expected)
    assert sorted(sm) == sorted(expected)
    for k in sm:
        assert sm[k] == expected[k]
    assert dict(sm.items()) == expected
    assert list(sm.values()) == [expected[k] for k in sm]

    it = iter(m.StringMap({}))
    for _ in range(3):  # __next__ must continue to raise StopIteration
        with pytest.raises(StopIteration):
            next(it)


def test_python_iterator_in_cpp():
    """Python iterables and iterators passed into C++ helper functions."""
    values = (1, 2, 3)
    as_list = [1, 2, 3]
    assert m.object_to_list(values) == as_list
    assert m.object_to_list(iter(values)) == as_list
    assert m.iterator_to_list(iter(values)) == as_list

    with pytest.raises(TypeError) as excinfo:
        m.object_to_list(1)
    assert "object is not iterable" in str(excinfo.value)

    with pytest.raises(TypeError) as excinfo:
        m.iterator_to_list(1)
    assert "incompatible function arguments" in str(excinfo.value)

    error_message = "py::iterator::advance() should propagate errors"

    def bad_next_call():
        raise RuntimeError(error_message)

    # iter(callable, sentinel) calls bad_next_call on the first advance.
    with pytest.raises(RuntimeError) as excinfo:
        m.iterator_to_list(iter(bad_next_call, None))
    assert str(excinfo.value) == error_message

    items = [1, None, 0, None]
    assert m.count_none(items) == 2
    assert m.find_none(items) is True
    assert m.count_nonzeros({"a": 0, "b": 1, "c": 2}) == 2

    numbers = range(5)
    assert all(m.tuple_iterator(tuple(numbers)))
    assert all(m.list_iterator(list(numbers)))
    assert all(m.sequence_iterator(numbers))


def test_iterator_passthrough():
    """#181: iterator passthrough did not compile.

    The iterator returned by ``iterator_passthrough`` must yield exactly the
    values of the iterator passed in.
    """
    # Use the module-level ``m`` instead of re-importing the same name locally.
    values = [3, 5, 7, 9, 11, 13, 15]
    assert list(m.iterator_passthrough(iter(values))) == values


def test_iterator_rvp():
    """#388: Can't make iterators via make_iterator() with different r/v policies"""
    # The module-level ``m`` is this same module; no local re-import needed.
    assert list(m.make_iterator_1()) == [1, 2, 3]
    assert list(m.make_iterator_2()) == [1, 2, 3]
    # Iterators made with different return-value policies must have distinct types.
    assert not isinstance(m.make_iterator_1(), type(m.make_iterator_2()))