test_interpreter.cpp 13.7 KB
Newer Older
1
#include <pybind11/embed.h>
2
3

#ifdef _MSC_VER
4
5
6
// Silence MSVC C++17 deprecation warning from Catch regarding std::uncaught_exceptions (up to
// catch 2.0.1; this should be fixed in the next catch release after 2.0.1).
#    pragma warning(disable : 4996)
7
8
#endif

9
#include <catch.hpp>
10
#include <cstdlib>
11
12
#include <fstream>
#include <functional>
13
14
#include <thread>
#include <utility>
15

16
17
18
19
20
namespace py = pybind11;
using namespace py::literals;

// Abstract C++ base class exposed to Python as widget_module.Widget. Its pure
// virtual functions are forwarded to Python overrides by the PyWidget trampoline.
class Widget {
public:
    explicit Widget(std::string message) : message_(std::move(message)) {}
    virtual ~Widget() = default;

    /// Returns a copy of the message passed to the constructor.
    std::string the_message() const { return message_; }

    /// Pure virtual; PyWidget dispatches it to the Python override.
    virtual int the_answer() const = 0;
    /// Pure virtual; PyWidget dispatches it to the Python override.
    virtual std::string argv0() const = 0;

private:
    std::string message_;
};

class PyWidget final : public Widget {
    using Widget::Widget;

35
    int the_answer() const override { PYBIND11_OVERRIDE_PURE(int, Widget, the_answer); }
36
    std::string argv0() const override { PYBIND11_OVERRIDE_PURE(std::string, Widget, argv0); }
37
38
};

39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
// C++ base class with a single overridable method, used by the "Override cache"
// test. Python subclasses override func() via test_override_cache_helper_trampoline.
class test_override_cache_helper {
public:
    /// Default implementation: always returns 0.
    virtual int func() { return 0; }

    test_override_cache_helper() = default;
    virtual ~test_override_cache_helper() = default;

    // Instances are not copyable.
    test_override_cache_helper &operator=(const test_override_cache_helper &) = delete;
    test_override_cache_helper(const test_override_cache_helper &) = delete;
};

// Trampoline forwarding func() to a Python override when one exists, otherwise
// to test_override_cache_helper::func().
class test_override_cache_helper_trampoline : public test_override_cache_helper {
    int func() override {
        PYBIND11_OVERRIDE(int, test_override_cache_helper, func);
    }
};

55
PYBIND11_EMBEDDED_MODULE(widget_module, m) {
    // Widget is registered with PyWidget as its trampoline so Python code can subclass it.
    py::class_<Widget, PyWidget> widget_class(m, "Widget");
    widget_class.def(py::init<std::string>());
    widget_class.def_property_readonly("the_message", &Widget::the_message);

    // Trivial free function the tests use to check the module's bindings work.
    m.def("add", [](int lhs, int rhs) { return lhs + rhs; });
}
62

63
PYBIND11_EMBEDDED_MODULE(trampoline_module, m) {
    // Held by std::shared_ptr so the tests can cast instances to shared_ptr.
    // py::init_alias<> always constructs the trampoline type.
    using helper_t = test_override_cache_helper;
    py::class_<helper_t, test_override_cache_helper_trampoline, std::shared_ptr<helper_t>>
        helper_class(m, "test_override_cache_helper");
    helper_class.def(py::init_alias<>());
    helper_class.def("func", &helper_t::func);
}

71
// Module whose initialization always throws a C++ exception; importing it is
// expected to raise ImportError("C++ Error").
PYBIND11_EMBEDDED_MODULE(throw_exception, ) {
    throw std::runtime_error("C++ Error");
}
72

73
74
75
// Module whose initialization fails with a Python error (a KeyError for the key
// 'missing'), which the import test expects to see as the ImportError's __cause__.
PYBIND11_EMBEDDED_MODULE(throw_error_already_set, ) {
    py::dict empty;
    // Looking up a key that does not exist throws; the result is never used.
    empty["missing"].cast<py::object>();
}

TEST_CASE("Pass classes and data between modules defined in C++ and Python") {
    auto module_ = py::module_::import("test_interpreter");
    REQUIRE(py::hasattr(module_, "DerivedWidget"));

    // Locals for the snippet: our own values plus everything in the Python module.
    auto locals = py::dict("hello"_a = "Hello, World!", "x"_a = 5, **module_.attr("__dict__"));
    // pybind11 strips the common leading indentation of raw-string snippets, so every
    // line of the Python code below must share the same indent (no column-0 lines).
    py::exec(R"(
        widget = DerivedWidget("{} - {}".format(hello, x))
        message = widget.the_message
    )",
             py::globals(),
             locals);
    REQUIRE(locals["message"].cast<std::string>() == "Hello, World! - 5");

    auto py_widget = module_.attr("DerivedWidget")("The question");
    auto message = py_widget.attr("the_message");
    REQUIRE(message.cast<std::string>() == "The question");

    // The Python subclass instance can be used as the C++ base; the_answer() is
    // forwarded to the Python override through the PyWidget trampoline.
    const auto &cpp_widget = py_widget.cast<const Widget &>();
    REQUIRE(cpp_widget.the_answer() == 42);
}

99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
TEST_CASE("Override cache") {
    auto module_ = py::module_::import("test_trampoline");
    REQUIRE(py::hasattr(module_, "func"));
    REQUIRE(py::hasattr(module_, "func2"));

    // Copy of the module namespace, used to look up the two factory callables.
    auto scope = py::dict(**module_.attr("__dict__"));

    // Create Python-derived instances over and over and call func() through the C++
    // base. The repetition is presumably there so freed objects get reused and any
    // stale override-cache entry would be hit.
    for (int iteration = 0; iteration < 1500; ++iteration) {
        std::shared_ptr<test_override_cache_helper> first;
        std::shared_ptr<test_override_cache_helper> second;

        py::object instance = scope["func"]();
        first = py::cast<std::shared_ptr<test_override_cache_helper>>(instance);
        const int answer = first->func();
        REQUIRE(answer == 42);

        instance = scope["func2"]();
        second = py::cast<std::shared_ptr<test_override_cache_helper>>(instance);
        second->func();
    }
}

126
TEST_CASE("Import error handling") {
    REQUIRE_NOTHROW(py::module_::import("widget_module"));
    // A C++ exception thrown during module initialization surfaces as ImportError.
    REQUIRE_THROWS_WITH(py::module_::import("throw_exception"), "ImportError: C++ Error");
    REQUIRE_THROWS_WITH(py::module_::import("throw_error_already_set"),
                        Catch::Contains("ImportError: initialization failed"));

    // The Python error raised during initialization is chained as the ImportError's
    // __cause__. The snippet's lines must share one indent so it can be dedented.
    auto locals = py::dict("is_keyerror"_a = false, "message"_a = "not set");
    py::exec(R"(
        try:
            import throw_error_already_set
        except ImportError as e:
            is_keyerror = type(e.__cause__) == KeyError
            message = str(e.__cause__)
    )",
             py::globals(),
             locals);
    REQUIRE(locals["is_keyerror"].cast<bool>());
    REQUIRE(locals["message"].cast<std::string>() == "'missing'");
}
145

146
147
148
149
150
151
152
153
154
155
156
157
158
159
TEST_CASE("There can be only one interpreter") {
    // scoped_interpreter can be move-constructed, but never copied or move-assigned.
    static_assert(std::is_move_constructible<py::scoped_interpreter>::value, "");
    static_assert(!std::is_move_assignable<py::scoped_interpreter>::value, "");
    static_assert(!std::is_copy_constructible<py::scoped_interpreter>::value, "");
    static_assert(!std::is_copy_assignable<py::scoped_interpreter>::value, "");

    // An interpreter is already running here, so both ways of starting one must refuse.
    REQUIRE_THROWS_WITH(py::initialize_interpreter(), "The interpreter is already running");
    REQUIRE_THROWS_WITH(py::scoped_interpreter(), "The interpreter is already running");

    // Once it is shut down, a scoped interpreter starts and stops cleanly, including
    // when ownership is moved into another scoped_interpreter.
    py::finalize_interpreter();
    REQUIRE_NOTHROW(py::scoped_interpreter());
    {
        py::scoped_interpreter original{};
        py::scoped_interpreter moved_to{std::move(original)};
    }

    // Leave a running interpreter behind for the following test cases.
    py::initialize_interpreter();
}
163

164
bool has_pybind11_internals_builtin() {
165
166
167
168
    auto builtins = py::handle(PyEval_GetBuiltins());
    return builtins.contains(PYBIND11_INTERNALS_ID);
};

169
bool has_pybind11_internals_static() {
170
    auto **&ipp = py::detail::get_internals_pp();
171
    return (ipp != nullptr) && (*ipp != nullptr);
172
173
}

174
175
TEST_CASE("Restart the interpreter") {
    // Before the restart: bindings work and internals are visible both through
    // builtins and through the static pointer.
    REQUIRE(py::module_::import("widget_module").attr("add")(1, 2).cast<int>() == 3);
    REQUIRE(has_pybind11_internals_builtin());
    REQUIRE(has_pybind11_internals_static());
    REQUIRE(py::module_::import("external_module").attr("A")(123).attr("value").cast<int>()
            == 123);

    // The embedding code and external_module must point at the same internals.
    REQUIRE(reinterpret_cast<uintptr_t>(*py::detail::get_internals_pp())
            == py::module_::import("external_module").attr("internals_at")().cast<uintptr_t>());

    // Tear the interpreter down and bring it back up.
    py::finalize_interpreter();
    REQUIRE(Py_IsInitialized() == 0);
    py::initialize_interpreter();
    REQUIRE(Py_IsInitialized() == 1);

    // A restart deletes the internals; get_internals() recreates them on demand.
    REQUIRE_FALSE(has_pybind11_internals_builtin());
    REQUIRE_FALSE(has_pybind11_internals_static());
    pybind11::detail::get_internals();
    REQUIRE(has_pybind11_internals_builtin());
    REQUIRE(has_pybind11_internals_static());
    REQUIRE(reinterpret_cast<uintptr_t>(*py::detail::get_internals_pp())
            == py::module_::import("external_module").attr("internals_at")().cast<uintptr_t>());

    // Internals first created *during* finalization (here, from the destructor of a
    // capsule stored on __main__) must still be gone after the next restart.
    py::finalize_interpreter();
    py::initialize_interpreter();
    bool ran = false;
    py::module_::import("__main__").attr("internals_destroy_test")
        = py::capsule(&ran, [](void *flag) {
              py::detail::get_internals();
              *static_cast<bool *>(flag) = true;
          });
    REQUIRE_FALSE(has_pybind11_internals_builtin());
    REQUIRE_FALSE(has_pybind11_internals_static());
    REQUIRE_FALSE(ran);
    py::finalize_interpreter();
    REQUIRE(ran);
    py::initialize_interpreter();
    REQUIRE_FALSE(has_pybind11_internals_builtin());
    REQUIRE_FALSE(has_pybind11_internals_static());

    // C++ modules can be reloaded.
    auto cpp_module = py::module_::import("widget_module");
    REQUIRE(cpp_module.attr("add")(1, 2).cast<int>() == 3);

    // C++ type information is reloaded and can be used in python modules.
    auto py_module = py::module_::import("test_interpreter");
    auto py_widget = py_module.attr("DerivedWidget")("Hello after restart");
    REQUIRE(py_widget.attr("the_message").cast<std::string>() == "Hello after restart");
}

TEST_CASE("Subinterpreter") {
    // Add tags to the modules in the main interpreter and test the basics.
    py::module_::import("__main__").attr("main_tag") = "main interpreter";
    {
        auto m = py::module_::import("widget_module");
        m.attr("extension_module_tag") = "added to module in main interpreter";

        REQUIRE(m.attr("add")(1, 2).cast<int>() == 3);
    }
    REQUIRE(has_pybind11_internals_builtin());
    REQUIRE(has_pybind11_internals_static());

    /// Create and switch to a subinterpreter.
    auto *main_tstate = PyThreadState_Get();
    auto *sub_tstate = Py_NewInterpreter();

    // Subinterpreters get their own copy of builtins. detail::get_internals() still
    // works by returning from the static variable, i.e. all interpreters share a single
    // global pybind11::internals;
    REQUIRE_FALSE(has_pybind11_internals_builtin());
    REQUIRE(has_pybind11_internals_static());

    // Module tags set in the main interpreter must not be visible here. Check the
    // attribute that was actually set above ("main_tag"); any other name would
    // make this check pass vacuously.
    REQUIRE_FALSE(py::hasattr(py::module_::import("__main__"), "main_tag"));
    {
        auto m = py::module_::import("widget_module");
        REQUIRE_FALSE(py::hasattr(m, "extension_module_tag"));

        // Function bindings should still work.
        REQUIRE(m.attr("add")(1, 2).cast<int>() == 3);
    }

    // Restore main interpreter.
    Py_EndInterpreter(sub_tstate);
    PyThreadState_Swap(main_tstate);

    REQUIRE(py::hasattr(py::module_::import("__main__"), "main_tag"));
    REQUIRE(py::hasattr(py::module_::import("widget_module"), "extension_module_tag"));
}
270
271
272
273
274
275
276

TEST_CASE("Execution frame") {
    // An embedded interpreter has no current execution frame. py::exec must then use
    // `__main__.__dict__` as its globals, so the new variable is visible via py::globals().
    py::exec("var = dict(number=42)");
    const int number = py::globals()["var"]["number"].cast<int>();
    REQUIRE(number == 42);
}
277
278
279
280
281
282
283
284

TEST_CASE("Threads") {
    // Restart interpreter to ensure threads are not initialized
    py::finalize_interpreter();
    py::initialize_interpreter();
    REQUIRE_FALSE(has_pybind11_internals_static());

    constexpr auto num_threads = 10;
    auto locals = py::dict("count"_a = 0);

    {
        // Give up the GIL so the worker threads can take turns acquiring it.
        py::gil_scoped_release gil_release{};
        REQUIRE(has_pybind11_internals_static());

        std::vector<std::thread> workers;
        for (int t = 0; t < num_threads; ++t) {
            workers.emplace_back([&locals]() {
                // Each worker bumps the shared counter while holding the GIL.
                py::gil_scoped_acquire gil{};
                locals["count"] = locals["count"].cast<int>() + 1;
            });
        }

        for (std::thread &worker : workers) {
            worker.join();
        }
    }

    REQUIRE(locals["count"].cast<int>() == num_threads);
}
306
307
308
309
310

// Scope exit utility https://stackoverflow.com/a/36644501/7255855
// Runs the stored callable (if any) when the guard goes out of scope.
struct scope_exit {
    std::function<void()> f_;

    /// @param f callable run once from the destructor; an empty function is ignored.
    explicit scope_exit(std::function<void()> f) noexcept : f_(std::move(f)) {}

    // A copy would hold the same callable and run it a second time when the copy
    // is destroyed, so copying the guard is disabled.
    scope_exit(const scope_exit &) = delete;
    scope_exit &operator=(const scope_exit &) = delete;

    ~scope_exit() {
        if (f_) {
            f_();
        }
    }
};

TEST_CASE("Reload module from file") {
    // Disable generation of cached bytecode (.pyc files) for this test, otherwise
    // Python might pick up an old version from the cache instead of the new versions
    // of the .py files generated below
    auto sys = py::module_::import("sys");
    bool saved_dont_write_bytecode = sys.attr("dont_write_bytecode").cast<bool>();
    sys.attr("dont_write_bytecode") = true;
    // Restore the saved setting however the test exits.
    scope_exit reset_dont_write_bytecode(
        [&]() { sys.attr("dont_write_bytecode") = saved_dont_write_bytecode; });

    std::string module_name = "test_module_reload";
    std::string module_file = module_name + ".py";

    // First version of the module: test() returns 1.
    std::ofstream source_out(module_file);
    source_out << "def test():\n" << "    return 1\n";
    source_out.close();
    // Remove the generated file however the test exits.
    scope_exit delete_module_file([&]() { std::remove(module_file.c_str()); });

    auto module_ = py::module_::import(module_name.c_str());
    int result = module_.attr("test")().cast<int>();
    REQUIRE(result == 1);

    // Second version of the module: test() returns 2.
    source_out.open(module_file);
    source_out << "def test():\n" << "    return 2\n";
    source_out.close();

    // Reloading must pick up the new source.
    module_.reload();
    result = module_.attr("test")().cast<int>();
    REQUIRE(result == 2);
}
356
357
358
359
360
361
362
363
364
365
366
367
368
369

TEST_CASE("sys.argv gets initialized properly") {
    py::finalize_interpreter();
    {
        // No argc/argv given: the Python-side argv0() is expected to be empty.
        py::scoped_interpreter default_scope;
        auto module_ = py::module_::import("test_interpreter");
        auto py_widget = module_.attr("DerivedWidget")("The question");
        const auto &cpp_widget = py_widget.cast<const Widget &>();
        REQUIRE(cpp_widget.argv0().empty());
    }

    {
        // argv[0] lives in a buffer we own: no POSIX-only strdup, and nothing leaks
        // if starting the interpreter throws.
        std::string arg0 = "a.out";
        char *argv[] = {&arg0[0]};
        py::scoped_interpreter argv_scope(true, 1, argv);
        // The interpreter must have taken its own copy of argv. Overwrite our buffer
        // in place so that any aliasing is detected reliably (reading freed memory
        // may still happen to yield "a.out").
        arg0.assign(arg0.size(), 'X');

        auto module_ = py::module_::import("test_interpreter");
        auto py_widget = module_.attr("DerivedWidget")("The question");
        const auto &cpp_widget = py_widget.cast<const Widget &>();
        REQUIRE(cpp_widget.argv0() == "a.out");
    }
    py::initialize_interpreter();
}
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395

TEST_CASE("make_iterator can be called before then after finalizing an interpreter") {
    // Reproduction of issue #2101 (https://github.com/pybind/pybind11/issues/2101)
    py::finalize_interpreter();

    std::vector<int> container;

    // First interpreter lifetime: create an iterator over the container.
    {
        py::scoped_interpreter first_interpreter;
        auto iter = py::make_iterator(container.begin(), container.end());
    }

    // Second lifetime: the same call must not throw after the earlier finalize.
    REQUIRE_NOTHROW([&]() {
        py::scoped_interpreter second_interpreter;
        auto iter = py::make_iterator(container.begin(), container.end());
    }());

    py::initialize_interpreter();
}