#include "../engine/infer_engine.hpp"
#include "infinicore/tensor.hpp"

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

namespace py = pybind11;

namespace infinilm::engine::distributed {

/// Bind the DistConfig struct to Python.
///
/// Exposes the three construction forms (empty, tp_size, explicit device-id
/// list), the mutable `tp_device_ids` field, and string conversion via the
/// struct's `operator std::string`.
inline void bind_dist_config(py::module &m) {
    py::class_<DistConfig>(m, "DistConfig")
        .def(py::init<>(), "Default constructor, empty device list")
        .def(py::init<int>(), py::arg("tp_size"),
             "Constructor with tensor parallel size, auto-assigns device IDs 0..tp_size-1")
        .def(py::init<const std::vector<int> &>(), py::arg("tp_device_ids"),
             "Constructor with explicit device IDs")
        .def_readwrite("tp_device_ids", &DistConfig::tp_device_ids,
                       "List of device IDs used in tensor parallelism")
        // Both __repr__ and __str__ delegate to DistConfig's string conversion.
        .def("__repr__", [](const DistConfig &cfg) { return std::string(cfg); })
        .def("__str__", [](const DistConfig &cfg) { return std::string(cfg); });
}

} // namespace infinilm::engine::distributed

namespace infinilm::engine {

/// Bind the InferEngine class to Python.
///
/// The Python constructor takes a LlamaConfig plus optional distributed
/// config and device type; the config is type-erased into std::any before
/// being handed to the C++ constructor.
inline void bind_infer_engine(py::module &m) {
    py::class_<InferEngine, std::shared_ptr<InferEngine>>(m, "InferEngine")
        .def(py::init([](const infinilm::models::llama::LlamaConfig &cfg,
                         const infinilm::engine::distributed::DistConfig &dist,
                         infinicore::Device::Type dev) {
                 // make_shared matches the shared_ptr holder declared above
                 // and is exception-safe (no raw `new`).
                 return std::make_shared<InferEngine>(std::any(cfg), dist, dev);
             }),
             py::arg("config"),
             py::arg("distributed_config") = distributed::DistConfig(),
             // Default to the currently active device type.
             py::arg("device_type") = infinicore::context::getDevice().getType())
        .def("load_param", &InferEngine::load_param,
             py::arg("name"), py::arg("param"),
             "Load a parameter tensor into all workers (each worker picks its shard)")
        .def("state_dict",
             [](InferEngine &self) {
                 // Return a dictionary containing references to the whole
                 // state of the module, converted to Python-visible tensors.
                 auto state_dict = self.state_dict();
                 py::dict result;
                 for (const auto &[name, param] : state_dict) {
                     result[py::cast(name)] = infinicore::Tensor(param);
                 }
                 return result;
             })
        .def("generate",
             [](InferEngine &self, py::object input_ids,
                py::object position_ids) -> infinicore::Tensor {
                 // Accept generic Python objects and convert to tensors so
                 // callers may pass anything cast-convertible to a Tensor.
                 return self.generate(input_ids.cast<infinicore::Tensor>(),
                                      position_ids.cast<infinicore::Tensor>());
             },
             "Run inference on all ranks with arbitrary arguments");

    // Optionally, you can add __repr__ for debugging.
    // NOTE(review): the returned literal appears to have been stripped by the
    // same angle-bracket mangling that hit the template arguments; it was
    // likely something like "<InferEngine>" — confirm against the original.
    m.attr("InferEngine").attr("__repr__") =
        py::cpp_function([](const InferEngine &self) { return ""; });
}

} // namespace infinilm::engine