test_callbacks.py 5.32 KB
Newer Older
1
# -*- coding: utf-8 -*-
Dean Moldovan's avatar
Dean Moldovan committed
2
import pytest
3
from pybind11_tests import callbacks as m
4
from threading import Thread
5
import time
Dean Moldovan's avatar
Dean Moldovan committed
6
7


8
def test_callbacks():
    """Pass Python callables into the module and call callables it returns.

    Covers a plain function, a multi-argument function, ``functools.partial``
    objects and a lambda as inputs, then invokes the callables handed back by
    ``test_callback4`` (positional call) and ``test_callback5`` (keyword call).
    """
    from functools import partial

    def func1():
        return "func1"

    def func2(a, b, c, d):
        return "func2", a, b, c, d

    def func3(a):
        return "func3({})".format(a)

    assert m.test_callback1(func1) == "func1"
    assert m.test_callback2(func2) == ("func2", "Hello", "x", True, 5)
    # partial objects must be accepted wherever a zero-argument callable is.
    assert m.test_callback1(partial(func2, 1, 2, 3, 4)) == ("func2", 1, 2, 3, 4)
    assert m.test_callback1(partial(func3, "partial")) == "func3(partial)"
    assert m.test_callback3(lambda i: i + 1) == "func(43) = 44"

    # Callables returned by the module are usable from Python.
    f = m.test_callback4()
    assert f(43) == 44
    f = m.test_callback5()
    assert f(number=43) == 44


32
33
34
35
36
37
38
def test_bound_method_callback():
    """Use bound methods as callbacks for ``test_callback3``.

    First a bound method of a class defined in Python, then the ``triple``
    method of a ``CppBoundMethodTest`` instance provided by the module.
    """
    # Bound Python method:
    class MyClass:
        def double(self, val):
            return 2 * val

    z = MyClass()
    assert m.test_callback3(z.double) == "func(43) = 86"

    # Bound method of an object created by the module:
    z = m.CppBoundMethodTest()
    assert m.test_callback3(z.triple) == "func(43) = 129"
43
44


45
46
47
48
def test_keyword_args_and_generalized_unpacking():
    """Check how the module forwards positional and keyword arguments.

    ``f`` echoes whatever it receives as ``(args, kwargs)``, so each assertion
    pins the exact argument layout the module used for the call.  The error
    cases check the exception type and a fragment of its message.
    """

    def f(*args, **kwargs):
        return args, kwargs

    assert m.test_tuple_unpacking(f) == (("positional", 1, 2, 3, 4, 5, 6), {})
    assert m.test_dict_unpacking(f) == (
        ("positional", 1),
        {"key": "value", "a": 1, "b": 2},
    )
    assert m.test_keyword_args(f) == ((), {"x": 10, "y": 20})
    assert m.test_unpacking_and_keywords1(f) == ((1, 2), {"c": 3, "d": 4})
    assert m.test_unpacking_and_keywords2(f) == (
        ("positional", 1, 2, 3, 4, 5),
        {"key": "value", "a": 1, "b": 2, "c": 3, "d": 4, "e": 5},
    )

    # Duplicate keyword arguments are reported as TypeError.
    with pytest.raises(TypeError) as excinfo:
        m.test_unpacking_error1(f)
    assert "Got multiple values for keyword argument" in str(excinfo.value)

    with pytest.raises(TypeError) as excinfo:
        m.test_unpacking_error2(f)
    assert "Got multiple values for keyword argument" in str(excinfo.value)

    # Arguments that cannot be converted are reported as RuntimeError.
    with pytest.raises(RuntimeError) as excinfo:
        m.test_arg_conversion_error1(f)
    assert "Unable to convert call argument" in str(excinfo.value)

    with pytest.raises(RuntimeError) as excinfo:
        m.test_arg_conversion_error2(f)
    assert "Unable to convert call argument" in str(excinfo.value)


Dean Moldovan's avatar
Dean Moldovan committed
78
def test_lambda_closure_cleanup():
    """Check the payload statistics recorded after ``m.test_cleanup()``.

    No payload instance may remain alive; exactly one copy construction and
    at least one move construction must have been counted.
    """
    m.test_cleanup()
    cstats = m.payload_cstats()
    assert cstats.alive() == 0
    assert cstats.copy_constructions == 1
    assert cstats.move_constructions >= 1


86
def test_cpp_function_roundtrip():
    """Test if passing a function pointer from C++ -> Python -> C++ yields the original pointer"""

    assert (
        m.test_dummy_function(m.dummy_function) == "matches dummy_function: eval(1) = 2"
    )
    assert (
        m.test_dummy_function(m.roundtrip(m.dummy_function))
        == "matches dummy_function: eval(1) = 2"
    )
    # None passes through the roundtrip unchanged.
    assert m.roundtrip(None, expect_none=True) is None
    # A Python lambda is still callable, but is reported as not convertible
    # to a function pointer.
    assert (
        m.test_dummy_function(lambda x: x + 2)
        == "can't convert to function pointer: eval(1) = 3"
    )

    with pytest.raises(TypeError) as excinfo:
        m.test_dummy_function(m.dummy_function2)
    assert "incompatible function arguments" in str(excinfo.value)

    with pytest.raises(TypeError) as excinfo:
        m.test_dummy_function(lambda x, y: x + y)
    # Two message variants are accepted for the arity error.
    assert any(
        s in str(excinfo.value)
        for s in ("missing 1 required positional argument", "takes exactly 2 arguments")
    )
Dean Moldovan's avatar
Dean Moldovan committed
112
113
114


def test_function_signatures(doc):
    """Check the signatures reported by ``doc`` for callback-typed functions.

    ``doc`` is supplied as a test argument; it is applied to a function and
    its result is compared against the expected signature string.
    """
    assert doc(m.test_callback3) == "test_callback3(arg0: Callable[[int], int]) -> str"
    assert doc(m.test_callback4) == "test_callback4() -> Callable[[int], int]"
117
118
119


def test_movable_object():
    """``callback_with_movable`` must return exactly ``True`` for a no-op callback."""
    assert m.callback_with_movable(lambda _: None) is True
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140


def test_async_callbacks():
    """Check that a stateful callback handed to ``test_async_callback`` is run
    for the submitted work and that its results arrive in ``res``.

    The results are produced after ``test_async_callback`` returns, so the test
    waits for them.  Instead of one fixed pause it polls until every work item
    has reported (or a generous deadline passes), which keeps the test fast on
    quick machines and tolerant of slow ones.
    """

    # serves as state for async callback
    class Item:
        def __init__(self, value):
            self.value = value

    res = []

    # generate stateful lambda that will store result in `res`
    def gen_f():
        s = Item(3)
        return lambda j: res.append(s.value + j)

    # do some work async
    work = [1, 2, 3, 4]
    m.test_async_callback(gen_f(), work)

    # wait until work is done, but never longer than the deadline; if results
    # are still missing afterwards the assertion below reports the failure
    deadline = time.time() + 5.0
    while len(res) < len(work) and time.time() < deadline:
        time.sleep(0.01)

    assert sum(res) == sum([x + 3 for x in work])


def test_async_async_callbacks():
    """Run ``test_async_callbacks`` on a secondary thread and surface its outcome.

    An exception raised inside a ``Thread`` target does not propagate to the
    thread that calls ``join()``, so a failed assertion in the wrapped test
    would otherwise go unnoticed by the test runner.  The exception is
    captured in the worker and re-raised here.
    """
    failures = []

    def run():
        try:
            test_async_callbacks()
        except BaseException as exc:  # captured only to re-raise after join()
            failures.append(exc)

    t = Thread(target=run)
    t.start()
    t.join()

    if failures:
        raise failures[0]
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180


def test_callback_num_times():
    # Tiny micro-benchmark related to PR #2919: times how long the module
    # takes to invoke a no-op callback a fixed number of times and prints
    # the resulting rates.
    # Example runtimes (Intel Xeon 2.2GHz, fully optimized):
    #   num_millions  1, repeats  2:  0.1 secs
    #   num_millions 20, repeats 10: 11.5 secs
    million = 1000000
    millions = 1  # Try 20 for actual micro-benchmarking.
    n_repeats = 2  # Try 10.
    line_fmt = (
        "callback_num_times: {:d} million / {:.3f} seconds = {:.3f} million / second"
    )

    measured = []
    for i in range(n_repeats):
        started = time.time()
        m.callback_num_times(lambda: None, millions * million)
        elapsed = time.time() - started

        # Guard against a zero elapsed time reported by a coarse clock.
        if elapsed:
            per_second = millions / elapsed
        else:
            per_second = 0
        measured.append(per_second)

        # Blank line before the first report only.
        if i == 0:
            print()
        print(line_fmt.format(millions, elapsed, per_second))

    # A summary is only meaningful with more than one measurement.
    if len(measured) > 1:
        print("Min    Mean   Max")
        lowest = min(measured)
        average = sum(measured) / len(measured)
        highest = max(measured)
        print("{:6.3f} {:6.3f} {:6.3f}".format(lowest, average, highest))