test_interpreter.cpp 16.1 KB
Newer Older
1
#include <pybind11/embed.h>
2

3
4
// Silence MSVC C++17 deprecation warning from Catch regarding std::uncaught_exceptions (up to
// catch 2.0.1; this should be fixed in the next catch release after 2.0.1).
5
PYBIND11_WARNING_DISABLE_MSVC(4996)
6

7
#include <catch.hpp>
8
#include <cstdlib>
9
10
#include <fstream>
#include <functional>
11
12
#include <thread>
#include <utility>
13

14
15
16
17
18
namespace py = pybind11;
using namespace py::literals;

/// Abstract C++ base class that widget_module exposes to Python as `Widget`.
/// It stores a message given at construction; subclasses must implement
/// the_answer() and argv0().
class Widget {
public:
    explicit Widget(std::string message) : message_(std::move(message)) {}
    virtual ~Widget() = default;

    /// Returns a copy of the message passed to the constructor.
    std::string the_message() const { return message_; }
    virtual int the_answer() const = 0;
    virtual std::string argv0() const = 0;

private:
    // Trailing underscore avoids shadowing by the constructor parameter of the same name.
    std::string message_;
};

class PyWidget final : public Widget {
    using Widget::Widget;

33
    int the_answer() const override { PYBIND11_OVERRIDE_PURE(int, Widget, the_answer); }
34
    std::string argv0() const override { PYBIND11_OVERRIDE_PURE(std::string, Widget, argv0); }
35
36
};

37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
/// Base class with one overridable virtual, bound by trampoline_module so that Python
/// subclasses can override func().
class test_override_cache_helper {
public:
    test_override_cache_helper() = default;

    // Non-copyable. The copy operations are user-declared (deleted), so no implicit
    // move operations are generated either.
    test_override_cache_helper(const test_override_cache_helper &) = delete;
    test_override_cache_helper &operator=(const test_override_cache_helper &) = delete;

    // Default implementation used when no Python override exists.
    virtual int func() { return 0; }

    virtual ~test_override_cache_helper() = default;
};

// Trampoline for test_override_cache_helper: PYBIND11_OVERRIDE dispatches func() to a
// Python override when one exists, otherwise to the C++ base implementation.
class test_override_cache_helper_trampoline : public test_override_cache_helper {
    int func() override {
        PYBIND11_OVERRIDE(int, test_override_cache_helper, func);
    }
};

53
// Embedded module exposing the abstract Widget (via the PyWidget trampoline) and a
// trivial `add` function used throughout the tests as a liveness check.
PYBIND11_EMBEDDED_MODULE(widget_module, m) {
    py::class_<Widget, PyWidget>(m, "Widget")
        .def(py::init<std::string>())
        .def_property_readonly("the_message", &Widget::the_message);

    m.def("add", [](int i, int j) { return i + j; });
}
60

61
// Embedded module binding test_override_cache_helper with a shared_ptr holder and the
// trampoline as alias type, so Python subclasses can override func().
PYBIND11_EMBEDDED_MODULE(trampoline_module, m) {
    py::class_<test_override_cache_helper,
               test_override_cache_helper_trampoline,
               std::shared_ptr<test_override_cache_helper>>(m, "test_override_cache_helper")
        // init_alias always constructs the trampoline type, even for the base class.
        .def(py::init_alias<>())
        .def("func", &test_override_cache_helper::func);
}

69
// Embedded module whose initializer always fails with a plain C++ exception.
PYBIND11_EMBEDDED_MODULE(throw_exception, ) {
    throw std::runtime_error("C++ Error");
}
70

71
72
73
// Embedded module whose initializer fails with a Python error: looking up a key
// that is absent from an empty dict raises KeyError('missing').
PYBIND11_EMBEDDED_MODULE(throw_error_already_set, ) {
    auto d = py::dict();
    d["missing"].cast<py::object>();
}

76
77
78
79
80
81
82
TEST_CASE("PYTHONPATH is used to update sys.path") {
    // The setup for this TEST_CASE is in catch.cpp!
    auto sys_module = py::module_::import("sys");
    const std::string sys_path = py::str(sys_module.attr("path")).cast<std::string>();
    // The marker directory name must appear somewhere in the stringified sys.path.
    REQUIRE_THAT(sys_path,
                 Catch::Matchers::Contains("pybind11_test_embed_PYTHONPATH_2099743835476552"));
}

83
TEST_CASE("Pass classes and data between modules defined in C++ and Python") {
    auto module_ = py::module_::import("test_interpreter");
    REQUIRE(py::hasattr(module_, "DerivedWidget"));

    // Run Python code with C++-provided locals merged with the module's namespace.
    auto locals = py::dict("hello"_a = "Hello, World!", "x"_a = 5, **module_.attr("__dict__"));
    py::exec(R"(
        widget = DerivedWidget("{} - {}".format(hello, x))
        message = widget.the_message
    )",
             py::globals(),
             locals);
    REQUIRE(locals["message"].cast<std::string>() == "Hello, World! - 5");

    // Construct the Python subclass directly from C++.
    auto py_widget = module_.attr("DerivedWidget")("The question");
    auto message = py_widget.attr("the_message");
    REQUIRE(message.cast<std::string>() == "The question");

    // The Python object can be viewed as the C++ base; the virtual call is
    // dispatched back into Python's implementation of the_answer().
    const auto &cpp_widget = py_widget.cast<const Widget &>();
    REQUIRE(cpp_widget.the_answer() == 42);
}

104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
TEST_CASE("Override cache") {
    auto module_ = py::module_::import("test_trampoline");
    REQUIRE(py::hasattr(module_, "func"));
    REQUIRE(py::hasattr(module_, "func2"));

    auto locals = py::dict(**module_.attr("__dict__"));

    // Repeatedly obtain instances from the module's func()/func2() and call the
    // virtual through C++ shared_ptrs.
    // NOTE(review): the declaration order inside the loop fixes the destruction order
    // (Python handle first, then the shared_ptrs); keep it when restructuring.
    for (int iteration = 0; iteration < 1500; ++iteration) {
        std::shared_ptr<test_override_cache_helper> p_obj;
        std::shared_ptr<test_override_cache_helper> p_obj2;

        py::object loc_inst = locals["func"]();
        p_obj = py::cast<std::shared_ptr<test_override_cache_helper>>(loc_inst);
        const int answer = p_obj->func();
        REQUIRE(answer == 42);

        loc_inst = locals["func2"]();
        p_obj2 = py::cast<std::shared_ptr<test_override_cache_helper>>(loc_inst);
        p_obj2->func();
    }
}

131
TEST_CASE("Import error handling") {
    REQUIRE_NOTHROW(py::module_::import("widget_module"));
    // A C++ exception in a module initializer surfaces as ImportError.
    REQUIRE_THROWS_WITH(py::module_::import("throw_exception"), "ImportError: C++ Error");
    REQUIRE_THROWS_WITH(py::module_::import("throw_error_already_set"),
                        Catch::Contains("ImportError: initialization failed"));

    // From Python, the original KeyError must be chained as the ImportError's __cause__.
    auto locals = py::dict("is_keyerror"_a = false, "message"_a = "not set");
    py::exec(R"(
        try:
            import throw_error_already_set
        except ImportError as e:
            is_keyerror = type(e.__cause__) == KeyError
            message = str(e.__cause__)
    )",
             py::globals(),
             locals);
    REQUIRE(locals["is_keyerror"].cast<bool>());
    REQUIRE(locals["message"].cast<std::string>() == "'missing'");
}
150

151
152
153
154
155
156
157
158
159
160
161
162
163
164
TEST_CASE("There can be only one interpreter") {
    // scoped_interpreter is move-constructible only.
    static_assert(std::is_move_constructible<py::scoped_interpreter>::value, "");
    static_assert(!std::is_move_assignable<py::scoped_interpreter>::value, "");
    static_assert(!std::is_copy_constructible<py::scoped_interpreter>::value, "");
    static_assert(!std::is_copy_assignable<py::scoped_interpreter>::value, "");

    // A second interpreter cannot be started while one is running.
    REQUIRE_THROWS_WITH(py::initialize_interpreter(), "The interpreter is already running");
    REQUIRE_THROWS_WITH(py::scoped_interpreter(), "The interpreter is already running");

    py::finalize_interpreter();
    REQUIRE_NOTHROW(py::scoped_interpreter());
    {
        // Moving transfers ownership; presumably only pyi2 finalizes on scope exit.
        auto pyi1 = py::scoped_interpreter();
        auto pyi2 = std::move(pyi1);
    }
    // Leave an interpreter running for the tests that follow.
    py::initialize_interpreter();
}
168

169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
#if PY_VERSION_HEX >= PYBIND11_PYCONFIG_SUPPORT_PY_VERSION_HEX
// Both tests below shut down the already-running interpreter, start one from a
// caller-provided PyConfig, and start a plain interpreter again on the way out.
TEST_CASE("Custom PyConfig") {
    py::finalize_interpreter();
    PyConfig cfg;
    PyConfig_InitPythonConfig(&cfg);
    REQUIRE_NOTHROW(py::scoped_interpreter{&cfg});
    {
        py::scoped_interpreter interp{&cfg};
        auto widget_mod = py::module_::import("widget_module");
        REQUIRE(widget_mod.attr("add")(1, 41).cast<int>() == 42);
    }
    py::initialize_interpreter();
}

TEST_CASE("Custom PyConfig with argv") {
    py::finalize_interpreter();
    {
        PyConfig cfg;
        PyConfig_InitIsolatedConfig(&cfg);
        char *argv[] = {strdup("a.out")};
        py::scoped_interpreter argv_scope{&cfg, 1, argv};
        // The caller's buffer is released immediately. The check below would read
        // freed memory if the interpreter still pointed at it.
        std::free(argv[0]);
        auto test_mod = py::module_::import("test_interpreter");
        auto py_widget = test_mod.attr("DerivedWidget")("The question");
        const auto &cpp_widget = py_widget.cast<const Widget &>();
        REQUIRE(cpp_widget.argv0() == "a.out");
    }
    py::initialize_interpreter();
}
#endif

TEST_CASE("Add program dir to path") {
    // Number of entries currently on sys.path.
    auto get_sys_path_size = []() -> size_t {
        auto path_list = py::module_::import("sys").attr("path");
        return py::len(path_list);
    };
    // Asserts that sys.path grew relative to the baseline measured below.
    auto validate_path_len = [&get_sys_path_size](size_t default_len) {
#if PY_VERSION_HEX < 0x030A0000
        // It seems a value remains in sys.path
        // left by the previous call of scoped_interpreter ctor.
        REQUIRE(get_sys_path_size() > default_len);
#else
        REQUIRE(get_sys_path_size() == default_len + 1);
#endif
    };
    py::finalize_interpreter();

    // Baseline: interpreter started with the last constructor argument set to false
    // (presumably add_program_dir_to_path -- confirm against embed.h).
    size_t sys_path_default_size = 0;
    {
        py::scoped_interpreter baseline_interp{true, 0, nullptr, false};
        sys_path_default_size = get_sys_path_size();
    }
    {
        py::scoped_interpreter default_interp{}; // expected to append some to sys.path
        validate_path_len(sys_path_default_size);
    }
#if PY_VERSION_HEX >= PYBIND11_PYCONFIG_SUPPORT_PY_VERSION_HEX
    {
        PyConfig config;
        PyConfig_InitPythonConfig(&config);
        py::scoped_interpreter config_interp{&config}; // expected to append some to sys.path
        validate_path_len(sys_path_default_size);
    }
#endif
    py::initialize_interpreter();
}

235
bool has_pybind11_internals_builtin() {
236
237
238
239
    auto builtins = py::handle(PyEval_GetBuiltins());
    return builtins.contains(PYBIND11_INTERNALS_ID);
};

240
bool has_pybind11_internals_static() {
241
    auto **&ipp = py::detail::get_internals_pp();
242
    return (ipp != nullptr) && (*ipp != nullptr);
243
244
}

245
246
TEST_CASE("Restart the interpreter") {
    // Verify pre-restart state.
    REQUIRE(py::module_::import("widget_module").attr("add")(1, 2).cast<int>() == 3);
    REQUIRE(has_pybind11_internals_builtin());
    REQUIRE(has_pybind11_internals_static());
    REQUIRE(py::module_::import("external_module").attr("A")(123).attr("value").cast<int>()
            == 123);

    // local and foreign module internals should point to the same internals:
    REQUIRE(reinterpret_cast<uintptr_t>(*py::detail::get_internals_pp())
            == py::module_::import("external_module").attr("internals_at")().cast<uintptr_t>());

    // Restart the interpreter.
    py::finalize_interpreter();
    REQUIRE(Py_IsInitialized() == 0);

    py::initialize_interpreter();
    REQUIRE(Py_IsInitialized() == 1);

    // Internals are deleted after a restart and recreated on demand.
    REQUIRE_FALSE(has_pybind11_internals_builtin());
    REQUIRE_FALSE(has_pybind11_internals_static());
    pybind11::detail::get_internals();
    REQUIRE(has_pybind11_internals_builtin());
    REQUIRE(has_pybind11_internals_static());
    REQUIRE(reinterpret_cast<uintptr_t>(*py::detail::get_internals_pp())
            == py::module_::import("external_module").attr("internals_at")().cast<uintptr_t>());

    // Make sure that an interpreter whose internals are first created during
    // finalization (here, by a capsule destructor) still gets them destroyed.
    py::finalize_interpreter();
    py::initialize_interpreter();
    bool ran = false;
    py::module_::import("__main__").attr("internals_destroy_test")
        = py::capsule(&ran, [](void *flag) {
              py::detail::get_internals();
              *static_cast<bool *>(flag) = true;
          });
    REQUIRE_FALSE(has_pybind11_internals_builtin());
    REQUIRE_FALSE(has_pybind11_internals_static());
    REQUIRE_FALSE(ran);
    py::finalize_interpreter();
    REQUIRE(ran);
    py::initialize_interpreter();
    REQUIRE_FALSE(has_pybind11_internals_builtin());
    REQUIRE_FALSE(has_pybind11_internals_static());

    // C++ modules can be reloaded.
    auto cpp_module = py::module_::import("widget_module");
    REQUIRE(cpp_module.attr("add")(1, 2).cast<int>() == 3);

    // C++ type information is reloaded and can be used in python modules.
    auto py_module = py::module_::import("test_interpreter");
    auto py_widget = py_module.attr("DerivedWidget")("Hello after restart");
    REQUIRE(py_widget.attr("the_message").cast<std::string>() == "Hello after restart");
}

TEST_CASE("Subinterpreter") {
    // Add tags to the modules in the main interpreter and test the basics.
    py::module_::import("__main__").attr("main_tag") = "main interpreter";
    {
        auto m = py::module_::import("widget_module");
        m.attr("extension_module_tag") = "added to module in main interpreter";

        REQUIRE(m.attr("add")(1, 2).cast<int>() == 3);
    }
    REQUIRE(has_pybind11_internals_builtin());
    REQUIRE(has_pybind11_internals_static());

    /// Create and switch to a subinterpreter.
    auto *main_tstate = PyThreadState_Get();
    auto *sub_tstate = Py_NewInterpreter();

    // Subinterpreters get their own copy of builtins. detail::get_internals() still
    // works by returning from the static variable, i.e. all interpreters share a single
    // global pybind11::internals;
    REQUIRE_FALSE(has_pybind11_internals_builtin());
    REQUIRE(has_pybind11_internals_static());

    // Module tags should be gone. Check `main_tag`, the attribute actually set on the
    // main interpreter's __main__ above.
    REQUIRE_FALSE(py::hasattr(py::module_::import("__main__"), "main_tag"));
    {
        auto m = py::module_::import("widget_module");
        REQUIRE_FALSE(py::hasattr(m, "extension_module_tag"));

        // Function bindings should still work.
        REQUIRE(m.attr("add")(1, 2).cast<int>() == 3);
    }

    // Restore main interpreter.
    Py_EndInterpreter(sub_tstate);
    PyThreadState_Swap(main_tstate);

    // The main interpreter's tags are untouched.
    REQUIRE(py::hasattr(py::module_::import("__main__"), "main_tag"));
    REQUIRE(py::hasattr(py::module_::import("widget_module"), "extension_module_tag"));
}
341
342
343
344
345
346
347

TEST_CASE("Execution frame") {
    // An embedded interpreter has no active execution frame; py::exec is expected to
    // fall back to `__main__.__dict__` as its globals.
    py::exec("var = dict(number=42)");
    const int number = py::globals()["var"]["number"].cast<int>();
    REQUIRE(number == 42);
}
348
349
350
351
352
353
354
355

TEST_CASE("Threads") {
    // Restart interpreter to ensure threads are not initialized
    py::finalize_interpreter();
    py::initialize_interpreter();
    REQUIRE_FALSE(has_pybind11_internals_static());

    constexpr auto num_threads = 10;
    auto locals = py::dict("count"_a = 0);

    {
        // Release the GIL so the worker threads can acquire it.
        py::gil_scoped_release gil_release{};

        auto threads = std::vector<std::thread>();
        threads.reserve(num_threads);
        for (auto i = 0; i < num_threads; ++i) {
            threads.emplace_back([&]() {
                // Each increment runs with the GIL held, so updates are serialized.
                py::gil_scoped_acquire gil{};
                locals["count"] = locals["count"].cast<int>() + 1;
            });
        }

        // All threads are joined before `locals` goes out of scope.
        for (auto &thread : threads) {
            thread.join();
        }
    }

    REQUIRE(locals["count"].cast<int>() == num_threads);
}
376
377
378
379
380

// Scope exit utility https://stackoverflow.com/a/36644501/7255855
// Runs the stored callable when the object goes out of scope, including during stack
// unwinding; an empty callable is skipped. Used to restore global state and delete
// temporary files at the end of a test.
//
// Non-copyable: a copy would hold the same callable and run the cleanup twice.
// The destructor is implicitly noexcept, so a callable that throws terminates the program.
struct scope_exit {
    std::function<void()> f_;
    explicit scope_exit(std::function<void()> f) noexcept : f_(std::move(f)) {}
    scope_exit(const scope_exit &) = delete;
    scope_exit &operator=(const scope_exit &) = delete;
    ~scope_exit() {
        if (f_) {
            f_();
        }
    }
};

TEST_CASE("Reload module from file") {
    // Disable generation of cached bytecode (.pyc files) for this test, otherwise
    // Python might pick up an old version from the cache instead of the new versions
    // of the .py files generated below
    auto sys = py::module_::import("sys");
    bool dont_write_bytecode = sys.attr("dont_write_bytecode").cast<bool>();
    sys.attr("dont_write_bytecode") = true;
    // Reset the value at scope exit
    scope_exit reset_dont_write_bytecode(
        [&]() { sys.attr("dont_write_bytecode") = dont_write_bytecode; });

    std::string module_name = "test_module_reload";
    std::string module_file = module_name + ".py";

    // Create the module .py file. Fail here, rather than with a confusing import
    // error later, if the file could not be written.
    std::ofstream test_module(module_file);
    test_module << "def test():\n";
    test_module << "    return 1\n";
    REQUIRE(test_module.good());
    test_module.close();
    // Delete the file at scope exit
    scope_exit delete_module_file([&]() { std::remove(module_file.c_str()); });

    // Import the module from file
    auto module_ = py::module_::import(module_name.c_str());
    int result = module_.attr("test")().cast<int>();
    REQUIRE(result == 1);

    // Update the module .py file with a small change
    test_module.open(module_file);
    test_module << "def test():\n";
    test_module << "    return 2\n";
    REQUIRE(test_module.good());
    test_module.close();

    // Reload the module
    module_.reload();
    result = module_.attr("test")().cast<int>();
    REQUIRE(result == 2);
}
426
427
428
429
430
431
432
433
434
435
436
437
438
439

TEST_CASE("sys.argv gets initialized properly") {
    py::finalize_interpreter();
    {
        // Without argv, DerivedWidget.argv0() reports an empty string.
        py::scoped_interpreter default_scope;
        auto module_ = py::module_::import("test_interpreter");
        auto py_widget = module_.attr("DerivedWidget")("The question");
        const auto &cpp_widget = py_widget.cast<const Widget &>();
        REQUIRE(cpp_widget.argv0().empty());
    }

    {
        char *argv[] = {strdup("a.out")};
        py::scoped_interpreter argv_scope(true, 1, argv);
        // The caller's buffer is released immediately. The check below would read
        // freed memory if the interpreter still pointed at it.
        std::free(argv[0]);
        auto module_ = py::module_::import("test_interpreter");
        auto py_widget = module_.attr("DerivedWidget")("The question");
        const auto &cpp_widget = py_widget.cast<const Widget &>();
        REQUIRE(cpp_widget.argv0() == "a.out");
    }
    py::initialize_interpreter();
}
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465

TEST_CASE("make_iterator can be called before then after finalizing an interpreter") {
    // Reproduction of issue #2101 (https://github.com/pybind/pybind11/issues/2101)
    py::finalize_interpreter();

    std::vector<int> container;
    // Starts a fresh interpreter, builds an iterator over `container`, and tears the
    // interpreter down again when `guard` leaves scope (after `iter` is destroyed).
    auto iterate_in_fresh_interpreter = [&container]() {
        py::scoped_interpreter guard;
        auto iter = py::make_iterator(container.begin(), container.end());
    };

    iterate_in_fresh_interpreter();
    REQUIRE_NOTHROW(iterate_in_fresh_interpreter());

    py::initialize_interpreter();
}