import pytest

# Module-wide marker applied to every test in this file.
# NOTE(review): `pytest.requires_numpy` is not a stock pytest attribute; presumably the
# project's conftest registers it as a skip-if-NumPy-missing marker -- confirm.
pytestmark = pytest.requires_numpy

# Keep the module importable (so the marker above can take effect) when NumPy is absent.
# NOTE(review): `pytest.suppress` also looks conftest-provided -- confirm.
with pytest.suppress(ImportError):
    import numpy as np


def test_vectorize(capture):
    """Check results and per-element call logs of the vectorized bindings.

    ``vectorized_func3`` is checked once with a complex scalar array.
    ``vectorized_func`` and ``vectorized_func2`` are then each run against the
    same series of inputs (scalars, 0-d arrays, 1-d and 2-d arrays, broadcast
    shapes, Fortran-ordered arrays and strided views). For every call both the
    numeric result and the text recorded by the ``capture`` fixture are
    asserted, which pins the order in which elements are visited.
    """
    from pybind11_tests import vectorized_func, vectorized_func2, vectorized_func3

    assert np.isclose(vectorized_func3(np.array(3 + 7j)), [6 + 14j])

    for f in [vectorized_func, vectorized_func2]:
        # Plain Python scalars: a single call.
        with capture:
            assert np.isclose(f(1, 2, 3), 6)
        assert capture == "my_func(x:int=1, y:float=2, z:float=3)"
        # 0-d arrays behave like scalars: still a single call.
        with capture:
            assert np.isclose(f(np.array(1), np.array(2), 3), 6)
        assert capture == "my_func(x:int=1, y:float=2, z:float=3)"
        # 1-d arrays: one call per element, in index order.
        with capture:
            assert np.allclose(f(np.array([1, 3]), np.array([2, 4]), 3), [6, 36])
        assert capture == """
            my_func(x:int=1, y:float=2, z:float=3)
            my_func(x:int=3, y:float=4, z:float=3)
        """
        with capture:
            a = np.array([[1, 2], [3, 4]], order='F')
            b = np.array([[10, 20], [30, 40]], order='F')
            c = 3
            result = f(a, b, c)
            assert np.allclose(result, a * b * c)
            assert result.flags.f_contiguous
        # All inputs are F order and full or singletons, so the result is in col-major order:
        assert capture == """
            my_func(x:int=1, y:float=10, z:float=3)
            my_func(x:int=3, y:float=30, z:float=3)
            my_func(x:int=2, y:float=20, z:float=3)
            my_func(x:int=4, y:float=40, z:float=3)
        """
        # Two C-ordered 2x3 arrays of identical shape: visited row by row.
        with capture:
            a, b, c = np.array([[1, 3, 5], [7, 9, 11]]), np.array([[2, 4, 6], [8, 10, 12]]), 3
            assert np.allclose(f(a, b, c), a * b * c)
        assert capture == """
            my_func(x:int=1, y:float=2, z:float=3)
            my_func(x:int=3, y:float=4, z:float=3)
            my_func(x:int=5, y:float=6, z:float=3)
            my_func(x:int=7, y:float=8, z:float=3)
            my_func(x:int=9, y:float=10, z:float=3)
            my_func(x:int=11, y:float=12, z:float=3)
        """
        # b has shape (3,) and is broadcast across both rows of a: y repeats 2, 3, 4 per row.
        with capture:
            a, b, c = np.array([[1, 2, 3], [4, 5, 6]]), np.array([2, 3, 4]), 2
            assert np.allclose(f(a, b, c), a * b * c)
        assert capture == """
            my_func(x:int=1, y:float=2, z:float=2)
            my_func(x:int=2, y:float=3, z:float=2)
            my_func(x:int=3, y:float=4, z:float=2)
            my_func(x:int=4, y:float=2, z:float=2)
            my_func(x:int=5, y:float=3, z:float=2)
            my_func(x:int=6, y:float=4, z:float=2)
        """
        # b has shape (2, 1) and is broadcast across columns: y is constant within a row.
        with capture:
            a, b, c = np.array([[1, 2, 3], [4, 5, 6]]), np.array([[2], [3]]), 2
            assert np.allclose(f(a, b, c), a * b * c)
        assert capture == """
            my_func(x:int=1, y:float=2, z:float=2)
            my_func(x:int=2, y:float=2, z:float=2)
            my_func(x:int=3, y:float=2, z:float=2)
            my_func(x:int=4, y:float=3, z:float=2)
            my_func(x:int=5, y:float=3, z:float=2)
            my_func(x:int=6, y:float=3, z:float=2)
        """
        # Same broadcast with an F-ordered a: the expected log is still row by row.
        with capture:
            a, b, c = np.array([[1, 2, 3], [4, 5, 6]], order='F'), np.array([[2], [3]]), 2
            assert np.allclose(f(a, b, c), a * b * c)
        assert capture == """
            my_func(x:int=1, y:float=2, z:float=2)
            my_func(x:int=2, y:float=2, z:float=2)
            my_func(x:int=3, y:float=2, z:float=2)
            my_func(x:int=4, y:float=3, z:float=2)
            my_func(x:int=5, y:float=3, z:float=2)
            my_func(x:int=6, y:float=3, z:float=2)
        """
        # Strided view of a C-ordered array ([::, ::2] keeps every other column).
        with capture:
            a, b, c = np.array([[1, 2, 3], [4, 5, 6]])[::, ::2], np.array([[2], [3]]), 2
            assert np.allclose(f(a, b, c), a * b * c)
        assert capture == """
            my_func(x:int=1, y:float=2, z:float=2)
            my_func(x:int=3, y:float=2, z:float=2)
            my_func(x:int=4, y:float=3, z:float=2)
            my_func(x:int=6, y:float=3, z:float=2)
        """
        # Strided view of an F-ordered array: same elements, same expected visiting order.
        with capture:
            a, b, c = np.array([[1, 2, 3], [4, 5, 6]], order='F')[::, ::2], np.array([[2], [3]]), 2
            assert np.allclose(f(a, b, c), a * b * c)
        assert capture == """
            my_func(x:int=1, y:float=2, z:float=2)
            my_func(x:int=3, y:float=2, z:float=2)
            my_func(x:int=4, y:float=3, z:float=2)
            my_func(x:int=6, y:float=3, z:float=2)
        """


def test_type_selection():
    """Check that ``selective_func`` reports the branch matching the array dtype.

    One-element arrays of dtype int32, float32 and complex64 are passed in and
    the returned message is compared against the expected branch name.
    """
    from pybind11_tests import selective_func

    assert selective_func(np.array([1], dtype=np.int32)) == "Int branch taken."
    assert selective_func(np.array([1.0], dtype=np.float32)) == "Float branch taken."
    assert selective_func(np.array([1.0j], dtype=np.complex64)) == "Complex float branch taken."


def test_docs(doc):
    """Check the signature text that the ``doc`` fixture yields for ``vectorized_func``.

    The expected signature takes three ndarray arguments (int32, float32,
    float64) and returns ``object``.
    """
    from pybind11_tests import vectorized_func

    assert doc(vectorized_func) == """
        vectorized_func(arg0: numpy.ndarray[int32], arg1: numpy.ndarray[float32], arg2: numpy.ndarray[float64]) -> object
    """  # noqa: E501 line too long


def test_trivial_broadcasting():
    """Check how argument layouts are classified and which layout the result gets.

    ``vectorized_is_trivial`` is expected to return one of
    ``trivial.c_trivial``, ``trivial.f_trivial`` or ``trivial.non_trivial``
    for a given argument triple. The last group of assertions checks that
    ``vectorized_func`` returns C- or F-contiguous results for the
    corresponding inputs.
    """
    from pybind11_tests import vectorized_is_trivial, trivial, vectorized_func

    # Scalars, 0-d arrays and same-shaped C-ordered arrays are all C-trivial.
    assert vectorized_is_trivial(1, 2, 3) == trivial.c_trivial
    assert vectorized_is_trivial(np.array(1), np.array(2), 3) == trivial.c_trivial
    assert vectorized_is_trivial(np.array([1, 3]), np.array([2, 4]), 3) == trivial.c_trivial
    assert trivial.c_trivial == vectorized_is_trivial(
        np.array([[1, 3, 5], [7, 9, 11]]), np.array([[2, 4, 6], [8, 10, 12]]), 3)
    # Arguments of differing (non-singleton) shapes are non-trivial.
    assert vectorized_is_trivial(
        np.array([[1, 2, 3], [4, 5, 6]]), np.array([2, 3, 4]), 2) == trivial.non_trivial
    assert vectorized_is_trivial(
        np.array([[1, 2, 3], [4, 5, 6]]), np.array([[2], [3]]), 2) == trivial.non_trivial
    # C-ordered 2x4 arrays in int32 / float32 / float64.
    z1 = np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype='int32')
    z2 = np.array(z1, dtype='float32')
    z3 = np.array(z1, dtype='float64')
    assert vectorized_is_trivial(z1, z2, z3) == trivial.c_trivial
    assert vectorized_is_trivial(1, z2, z3) == trivial.c_trivial
    assert vectorized_is_trivial(z1, 1, z3) == trivial.c_trivial
    assert vectorized_is_trivial(z1, z2, 1) == trivial.c_trivial
    # A strided int32 view as the first argument is non-trivial ...
    assert vectorized_is_trivial(z1[::2, ::2], 1, 1) == trivial.non_trivial
    # ... but the same int32 view as the third argument is expected to be C-trivial.
    # NOTE(review): presumably a dtype conversion produces a fresh contiguous array -- confirm.
    assert vectorized_is_trivial(1, 1, z1[::2, ::2]) == trivial.c_trivial
    # The float64 strided view in the same position is non-trivial.
    assert vectorized_is_trivial(1, 1, z3[::2, ::2]) == trivial.non_trivial
    # [1::4, 1::4] on a 2x4 array selects a single element.
    assert vectorized_is_trivial(z1, 1, z3[1::4, 1::4]) == trivial.c_trivial

    # Fortran-ordered counterparts.
    y1 = np.array(z1, order='F')
    y2 = np.array(y1)
    y3 = np.array(y1)
    assert vectorized_is_trivial(y1, y2, y3) == trivial.f_trivial
    assert vectorized_is_trivial(y1, 1, 1) == trivial.f_trivial
    assert vectorized_is_trivial(1, y2, 1) == trivial.f_trivial
    assert vectorized_is_trivial(1, 1, y3) == trivial.f_trivial
    # Mixing a full F-ordered array with a full C-ordered one is non-trivial.
    assert vectorized_is_trivial(y1, z2, 1) == trivial.non_trivial
    # With a single-element view, the remaining full-size array decides the expected layout.
    assert vectorized_is_trivial(z1[1::4, 1::4], y2, 1) == trivial.f_trivial
    assert vectorized_is_trivial(y1[1::4, 1::4], z2, 1) == trivial.c_trivial

    # The result layout is expected to follow the classification asserted above.
    assert vectorized_func(z1, z2, z3).flags.c_contiguous
    assert vectorized_func(y1, y2, y3).flags.f_contiguous
    assert vectorized_func(z1, 1, 1).flags.c_contiguous
    assert vectorized_func(1, y2, 1).flags.f_contiguous
    assert vectorized_func(z1[1::4, 1::4], y2, 1).flags.f_contiguous
    assert vectorized_func(y1[1::4, 1::4], z2, 1).flags.c_contiguous


def test_passthrough_arguments(doc):
    """Check ``vec_passthrough``'s signature text and its result for mixed arguments.

    The call combines plain scalars, a ``NonPODClass`` instance, arrays of
    shape (1, 3) and (3, 1), and one array that the original test marks as
    not vectorized; the result is compared element-wise with a 3x3 array.
    """
    from pybind11_tests import vec_passthrough, NonPODClass

    expected_signature = (
        "vec_passthrough("
        "arg0: float, arg1: numpy.ndarray[float64], arg2: numpy.ndarray[float64], "
        "arg3: numpy.ndarray[int32], arg4: int, arg5: m.NonPODClass, arg6: numpy.ndarray[float64]"
        ") -> object")
    assert doc(vec_passthrough) == expected_signature

    row = np.array([[10, 20, 30]], dtype='float64')
    not_vectorized = np.array([100, 200])  # NOT a vectorized argument
    column = np.array([[1000], [2000], [3000]], dtype='int')
    needs_cast = np.array([[1000000, 2000000, 3000000]], dtype='int')  # requires casting

    actual = vec_passthrough(
        1, row, not_vectorized, column, 10000, NonPODClass(100000), needs_cast)
    expected = np.array([
        [1111111, 2111121, 3111131],
        [1112111, 2112121, 3112131],
        [1113111, 2113121, 3113131],
    ])
    assert np.all(actual == expected)


def test_method_vectorization():
    """Check that ``VectorizeTestClass.method`` broadcasts a (2,) and a (2, 1) array."""
    from pybind11_tests import VectorizeTestClass

    instance = VectorizeTestClass(3)
    row = np.array([1, 2], dtype='int')
    column = np.array([[10], [20]], dtype='float32')

    outcome = instance.method(row, column)
    expected = [[14, 15], [24, 25]]
    assert np.all(outcome == expected)


def test_array_collapse():
    """Check when ``vectorized_func`` returns a non-array versus an ndarray.

    Scalar and 0-d inputs must not yield an ``np.ndarray``; list inputs must
    yield an ndarray whose shape matches the nesting of the list.
    """
    from pybind11_tests import vectorized_func

    # Scalars and 0-d arrays: the result is not an ndarray.
    for first_arg in (1, np.array(1)):
        assert not isinstance(vectorized_func(first_arg, 2, 3), np.ndarray)

    # A one-element list keeps its single dimension.
    flat = vectorized_func([1], 2, 3)
    assert isinstance(flat, np.ndarray)
    assert flat.shape == (1, )

    # A triply nested one-element list keeps all three dimensions.
    nested = vectorized_func(1, [[[2]]], 3)
    assert isinstance(nested, np.ndarray)
    assert nested.shape == (1, 1, 1)