import pytest

# NOTE(review): ``pytest.requires_numpy`` and ``pytest.suppress`` are not part
# of pytest itself; presumably the project's conftest attaches them (a skip
# marker when NumPy is missing, and ``contextlib.suppress``) -- confirm there.
pytestmark = pytest.requires_numpy

# Import NumPy only if it is available; the module-level ``pytestmark`` above
# is what keeps the tests from running without it.
with pytest.suppress(ImportError):
    import numpy as np


def test_vectorize(capture):
    """Check the vectorized bindings against NumPy-style broadcasting.

    ``vectorized_func3`` is checked on a single complex 0-d array.
    ``vectorized_func`` and ``vectorized_func2`` are each called on scalars,
    0-d arrays, 1-d and 2-d arrays, broadcast row/column operands,
    Fortran-ordered input and strided views. For every case the result must
    equal ``a * b * c``, and the captured output must list one
    ``my_func(...)`` line per output element, in exactly the asserted order.
    """
    from pybind11_tests import vectorized_func, vectorized_func2, vectorized_func3

    assert np.isclose(vectorized_func3(np.array(3 + 7j)), [6 + 14j])

    for f in [vectorized_func, vectorized_func2]:
        # Plain Python scalars.
        with capture:
            assert np.isclose(f(1, 2, 3), 6)
        assert capture == "my_func(x:int=1, y:float=2, z:float=3)"
        # 0-d arrays mixed with a scalar.
        with capture:
            assert np.isclose(f(np.array(1), np.array(2), 3), 6)
        assert capture == "my_func(x:int=1, y:float=2, z:float=3)"
        # Equal-length 1-d arrays.
        with capture:
            assert np.allclose(f(np.array([1, 3]), np.array([2, 4]), 3), [6, 36])
        assert capture == """
            my_func(x:int=1, y:float=2, z:float=3)
            my_func(x:int=3, y:float=4, z:float=3)
        """
        # Equal-shape 2-d arrays.
        with capture:
            a, b, c = np.array([[1, 3, 5], [7, 9, 11]]), np.array([[2, 4, 6], [8, 10, 12]]), 3
            assert np.allclose(f(a, b, c), a * b * c)
        assert capture == """
            my_func(x:int=1, y:float=2, z:float=3)
            my_func(x:int=3, y:float=4, z:float=3)
            my_func(x:int=5, y:float=6, z:float=3)
            my_func(x:int=7, y:float=8, z:float=3)
            my_func(x:int=9, y:float=10, z:float=3)
            my_func(x:int=11, y:float=12, z:float=3)
        """
        # A (3,) row broadcast against a (2, 3) matrix.
        with capture:
            a, b, c = np.array([[1, 2, 3], [4, 5, 6]]), np.array([2, 3, 4]), 2
            assert np.allclose(f(a, b, c), a * b * c)
        assert capture == """
            my_func(x:int=1, y:float=2, z:float=2)
            my_func(x:int=2, y:float=3, z:float=2)
            my_func(x:int=3, y:float=4, z:float=2)
            my_func(x:int=4, y:float=2, z:float=2)
            my_func(x:int=5, y:float=3, z:float=2)
            my_func(x:int=6, y:float=4, z:float=2)
        """
        # A (2, 1) column broadcast against a (2, 3) matrix.
        with capture:
            a, b, c = np.array([[1, 2, 3], [4, 5, 6]]), np.array([[2], [3]]), 2
            assert np.allclose(f(a, b, c), a * b * c)
        assert capture == """
            my_func(x:int=1, y:float=2, z:float=2)
            my_func(x:int=2, y:float=2, z:float=2)
            my_func(x:int=3, y:float=2, z:float=2)
            my_func(x:int=4, y:float=3, z:float=2)
            my_func(x:int=5, y:float=3, z:float=2)
            my_func(x:int=6, y:float=3, z:float=2)
        """
        # Fortran-ordered input: the expected calls are still listed in
        # row-major order of the result (x = 1, 2, ..., 6).
        with capture:
            a, b, c = np.array([[1, 2, 3], [4, 5, 6]], order='F'), np.array([[2], [3]]), 2
            assert np.allclose(f(a, b, c), a * b * c)
        assert capture == """
            my_func(x:int=1, y:float=2, z:float=2)
            my_func(x:int=2, y:float=2, z:float=2)
            my_func(x:int=3, y:float=2, z:float=2)
            my_func(x:int=4, y:float=3, z:float=2)
            my_func(x:int=5, y:float=3, z:float=2)
            my_func(x:int=6, y:float=3, z:float=2)
        """
        # Strided (every other column) view of a C-ordered array.
        with capture:
            a, b, c = np.array([[1, 2, 3], [4, 5, 6]])[::, ::2], np.array([[2], [3]]), 2
            assert np.allclose(f(a, b, c), a * b * c)
        assert capture == """
            my_func(x:int=1, y:float=2, z:float=2)
            my_func(x:int=3, y:float=2, z:float=2)
            my_func(x:int=4, y:float=3, z:float=2)
            my_func(x:int=6, y:float=3, z:float=2)
        """
        # Strided view of a Fortran-ordered array.
        with capture:
            a, b, c = np.array([[1, 2, 3], [4, 5, 6]], order='F')[::, ::2], np.array([[2], [3]]), 2
            assert np.allclose(f(a, b, c), a * b * c)
        assert capture == """
            my_func(x:int=1, y:float=2, z:float=2)
            my_func(x:int=3, y:float=2, z:float=2)
            my_func(x:int=4, y:float=3, z:float=2)
            my_func(x:int=6, y:float=3, z:float=2)
        """


def test_type_selection():
    """Check that ``selective_func`` reports a different branch per input dtype.

    int32, float32 and complex64 arrays must each produce their own message.
    """
    from pybind11_tests import selective_func

    assert selective_func(np.array([1], dtype=np.int32)) == "Int branch taken."
    assert selective_func(np.array([1.0], dtype=np.float32)) == "Float branch taken."
    assert selective_func(np.array([1.0j], dtype=np.complex64)) == "Complex float branch taken."


def test_docs(doc):
    """Check the signature docstring generated for ``vectorized_func``.

    The expected text records the per-argument array element types
    (int32, float32, float64) and the ``object`` return annotation.
    """
    from pybind11_tests import vectorized_func

    assert doc(vectorized_func) == """
        vectorized_func(arg0: numpy.ndarray[int32], arg1: numpy.ndarray[float32], arg2: numpy.ndarray[float64]) -> object
    """  # noqa: E501 line too long


def test_trivial_broadcasting():
    """Pin down which argument triples ``vectorized_is_trivial`` accepts.

    Every case passes three arguments (scalars or NumPy arrays) and states
    whether the binding's result must be truthy or falsy.
    """
    from pybind11_tests import vectorized_is_trivial as is_trivial

    def check(expected, *args):
        # Truthiness comparison, equivalent to a bare ``assert`` / ``assert not``.
        assert bool(is_trivial(*args)) is expected

    # Scalars, 0-d arrays and equal-shape arrays.
    check(True, 1, 2, 3)
    check(True, np.array(1), np.array(2), 3)
    check(True, np.array([1, 3]), np.array([2, 4]), 3)
    check(True, np.array([[1, 3, 5], [7, 9, 11]]), np.array([[2, 4, 6], [8, 10, 12]]), 3)
    # Operands whose shapes must be broadcast: a (3,) row and a (2, 1) column.
    check(False, np.array([[1, 2, 3], [4, 5, 6]]), np.array([2, 3, 4]), 2)
    check(False, np.array([[1, 2, 3], [4, 5, 6]]), np.array([[2], [3]]), 2)

    # The same 2x4 data as int32 / float32 / float64, plus strided views of it.
    ints = np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype='int32')
    singles = np.array(ints, dtype='float32')
    doubles = np.array(ints, dtype='float64')
    check(True, ints, singles, doubles)
    check(False, ints[::2, ::2], 1, 1)
    check(True, 1, 1, ints[::2, ::2])
    check(False, 1, 1, doubles[::2, ::2])
    check(True, ints, 1, doubles[1::4, 1::4])

    # Fortran-ordered input, alone or mixed with C-ordered arrays and scalars.
    fortran = np.array(ints, order='F')
    check(False, fortran, np.array(fortran), np.array(fortran))
    check(False, fortran, singles, doubles)
    check(False, fortran, 1, 1)