// engine.hpp — pybind11 bindings for DistConfig and InferEngine.
#include "../engine/infer_engine.hpp"
#include "infinicore/tensor.hpp"
#include <memory>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

namespace py = pybind11;

namespace infinilm::engine::distributed {

/// Registers the `DistConfig` class on the Python module `m`.
///
/// Exposes three constructors (default, by tensor-parallel size, by explicit
/// device-id list), the read/write `tp_device_ids` attribute, and
/// `__repr__` / `__str__`.
inline void bind_dist_config(py::module &m) {
    // Both __repr__ and __str__ render the config through DistConfig's
    // conversion to std::string, so a single callable serves both.
    auto to_text = [](const DistConfig &self) {
        return std::string(self);
    };

    py::class_<DistConfig> dist_config(m, "DistConfig");

    // Constructors.
    dist_config.def(py::init<>(), "Default constructor, empty device list");
    dist_config.def(py::init<int>(), py::arg("tp_size"),
                    "Constructor with tensor parallel size, auto-assigns device IDs 0..tp_size-1");
    dist_config.def(py::init<const std::vector<int> &>(), py::arg("tp_device_ids"),
                    "Constructor with explicit device IDs");

    // Attributes.
    dist_config.def_readwrite("tp_device_ids", &DistConfig::tp_device_ids,
                              "List of device IDs used in tensor parallelism");

    // String representations.
    dist_config.def("__repr__", to_text);
    dist_config.def("__str__", to_text);
}

} // namespace infinilm::engine::distributed

namespace infinilm::engine {

/// Registers the `InferEngine` class on the Python module `m`.
///
/// Exposed members:
///   - constructor `(config, distributed_config=DistConfig(), device_type=<current device type>)`
///   - `load_param(name, param)`
///   - `state_dict()` -> list of dicts mapping parameter name to tensor
///   - `generate(input_ids, position_ids)` -> tensor
///   - `reset_cache(pos=0, async=False)`
///   - `__repr__`
inline void bind_infer_engine(py::module &m) {

    py::class_<InferEngine, std::shared_ptr<InferEngine>>(m, "InferEngine")
        .def(py::init([](const infinilm::models::llama::LlamaConfig &cfg,
                         const infinilm::engine::distributed::DistConfig &dist,
                         infinicore::Device::Type dev) {
                 // Construct straight into the shared_ptr holder declared on the
                 // class instead of handing pybind11 a raw owning pointer.
                 return std::make_shared<InferEngine>(std::any(cfg), dist, dev);
             }),
             py::arg("config"),
             py::arg("distributed_config") = distributed::DistConfig(),
             // The default below is evaluated once, when this binding is
             // registered, not on every constructor call.
             py::arg("device_type") = infinicore::context::getDevice().getType())
        .def("load_param", &InferEngine::load_param,
             py::arg("name"), py::arg("param"),
             "Load a parameter tensor into all workers (each worker picks its shard)")
        .def("state_dict", [](InferEngine &self) {
            // Build a Python list holding one dict per entry yielded by
            // self.state_dict(); each dict maps parameter name -> Tensor.
            // NOTE(review): presumably one entry per tensor-parallel worker — confirm.
            py::list state_dict_tp_all;
            for (const auto &state_dict_tp : self.state_dict()) {
                py::dict result;
                for (const auto &[name, param] : state_dict_tp) {
                    result[py::cast(name)] = infinicore::Tensor(param);
                }
                state_dict_tp_all.append(result);
            }

            return state_dict_tp_all;
        })
        .def(
            "generate",
            // Arguments arrive as generic Python objects and are cast to
            // infinicore::Tensor here; a failed cast surfaces as a Python exception.
            [](InferEngine &self, py::object input_ids, py::object position_ids) -> infinicore::Tensor {
                return self.generate(input_ids.cast<infinicore::Tensor>(),
                                     position_ids.cast<infinicore::Tensor>());
            },
            py::arg("input_ids"), py::arg("position_ids"),
            "Run inference on all ranks with arbitrary arguments")
        // NOTE(review): `async` is a reserved word in Python 3.7+, so callers cannot
        // write `reset_cache(async=True)`; it has to be passed positionally or via
        // `**{"async": True}`. The name is kept as-is to preserve the existing interface.
        .def("reset_cache", &InferEngine::reset_cache,
             py::arg("pos") = 0, py::arg("async") = false,
             "Reset the internal cache in all workers to a specific position (clears state between generations). "
             "By default, this is synchronous. If async=True, this becomes asynchronous (unstable - use with caution).")
        // __repr__ for debugging. Registered through .def() so pybind11 marks it as a
        // method and `self` is bound; assigning a bare py::cpp_function to the class
        // attribute would not bind the instance.
        .def("__repr__", [](const InferEngine &self) {
            return "<InferEngine: " + std::string(self.get_dist_config()) + ">";
        });
}

} // namespace infinilm::engine