#pragma once

// Project headers
#include "../cache/cache_config.hpp"
#include "../engine/infer_engine.hpp"
#include "infinicore/tensor.hpp"

// Third-party
#include <fmt/format.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

// Standard library
#include <memory>
#include <string>
#include <vector>

namespace py = pybind11;

9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
namespace infinilm::cache {

inline void bind_cache_config(py::module &m) {
    // First bind the CacheType enum
    py::enum_<CacheType>(m, "CacheType")
        .value("DYNAMIC", CacheType::DYNAMIC)
        .value("PAGED", CacheType::PAGED)
        .export_values();

    // Then bind the CacheResetMode enum
    py::enum_<CacheResetMode>(m, "CacheResetMode")
        .value("PRESERVE", CacheResetMode::PRESERVE)
        .value("RECREATE", CacheResetMode::RECREATE)
        .export_values();

    // Finally bind the CacheConfig struct
    py::class_<CacheConfig>(m, "CacheConfig")
        .def(py::init<>(), "Default constructor")
        .def(py::init<CacheType, size_t, size_t>(),
             py::arg("type") = CacheType::DYNAMIC,
             py::arg("num_layers") = 32,
             py::arg("max_kv_cache_length") = 4096,
             "Constructor with parameters")
        .def_readwrite("type", &CacheConfig::type, "Cache type")
        .def_readwrite("num_layers", &CacheConfig::num_layers, "Number of layers")
        .def_readwrite("max_kv_cache_length", &CacheConfig::max_kv_cache_length,
                       "Maximum KV cache length")
        .def_readwrite("initial_capacity", &CacheConfig::initial_capacity,
                       "Initial cache capacity in tokens")
        .def_readwrite("initial_batch_size", &CacheConfig::initial_batch_size,
                       "Initial batch size for cache allocation")
        .def_readwrite("growth_factor", &CacheConfig::growth_factor,
                       "Cache growth factor when resizing (e.g., 2.0 for doubling)")
        .def_readwrite("allow_expand", &CacheConfig::allow_expand,
                       "Whether to allow cache expansion")
        .def_readwrite("reset_mode", &CacheConfig::reset_mode,
                       "Cache reset mode")
        .def("__eq__", &CacheConfig::operator==, py::is_operator(),
             "Check if two CacheConfig objects are equal")
        .def("__ne__", &CacheConfig::operator!=, py::is_operator(),
             "Check if two CacheConfig objects are not equal")
        .def("__repr__", [](const CacheConfig &cfg) {
            return fmt::format("CacheConfig(type={}, num_layers={}, max_kv_cache_length={}, "
                               "initial_capacity={}, initial_batch_size={}, growth_factor={}, "
                               "allow_expand={}, reset_mode={})",
                               static_cast<int>(cfg.type), cfg.num_layers,
                               cfg.max_kv_cache_length, cfg.initial_capacity,
                               cfg.initial_batch_size, cfg.growth_factor,
                               cfg.allow_expand, static_cast<int>(cfg.reset_mode));
        });
}

} // namespace infinilm::cache

63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
namespace infinilm::engine::distributed {

inline void bind_dist_config(py::module &m) {
    py::class_<DistConfig>(m, "DistConfig")
        .def(py::init<>(), "Default constructor, empty device list")
        .def(py::init<int>(), py::arg("tp_size"),
             "Constructor with tensor parallel size, auto-assigns device IDs 0..tp_size-1")
        .def(py::init<const std::vector<int> &>(), py::arg("tp_device_ids"),
             "Constructor with explicit device IDs")
        .def_readwrite("tp_device_ids", &DistConfig::tp_device_ids,
                       "List of device IDs used in tensor parallelism")
        .def("__repr__", [](const DistConfig &cfg) {
            return std::string(cfg);
        })
        .def("__str__", [](const DistConfig &cfg) {
            return std::string(cfg);
        });
}

} // namespace infinilm::engine::distributed

namespace infinilm::engine {

/// Register the InferEngine class on the given Python module.
///
/// InferEngine uses a std::shared_ptr holder so Python and C++ can share
/// ownership of the same engine instance.
inline void bind_infer_engine(py::module &m) {
    py::class_<InferEngine, std::shared_ptr<InferEngine>>(m, "InferEngine")
        .def(py::init([](const InfinilmModel::Config &cfg,
                         const infinilm::engine::distributed::DistConfig &dist,
                         infinicore::Device::Type dev,
                         const infinilm::cache::CacheConfig &cache_config) {
                 // make_shared instead of raw `new`: exception-safe and
                 // matches the shared_ptr holder declared above.
                 return std::make_shared<InferEngine>(cfg, dist, dev, cache_config);
             }),
             py::arg("config"),
             py::arg("distributed_config") = distributed::DistConfig(),
             // NOTE(review): this default is evaluated once, at binding time,
             // from the then-current device context — confirm that is intended.
             py::arg("device_type") = infinicore::context::getDevice().getType(),
             py::arg("cache_config") = infinilm::cache::CacheConfig())
        .def("load_param", &InferEngine::load_param,
             py::arg("name"), py::arg("param"),
             "Load a parameter tensor into all workers (each worker picks its shard)")
        .def("state_dict", [](InferEngine &self) {
            // Returns a list with one dict per tensor-parallel rank,
            // mapping parameter name -> that rank's tensor (shard).
            py::list state_dict_tp_all;
            for (const auto &state_dict_tp : self.state_dict()) {
                py::dict result;
                for (const auto &[name, param] : state_dict_tp) {
                    result[py::cast(name)] = infinicore::Tensor(param);
                }
                state_dict_tp_all.append(result);
            }
            return state_dict_tp_all;
        })
        .def(
            "generate",
            [](InferEngine &self, py::object input_ids, py::object position_ids) -> infinicore::Tensor {
                return self.generate(input_ids.cast<infinicore::Tensor>(),
                                     position_ids.cast<infinicore::Tensor>());
            },
            // Named args so Python callers can use keyword arguments.
            py::arg("input_ids"), py::arg("position_ids"),
            "Run inference on all ranks with arbitrary arguments")
        .def("reset_cache", py::overload_cast<size_t>(&InferEngine::reset_cache),
             py::arg("pos") = 0,
             "Reset the internal cache in all workers to a specific position")
        .def("reset_cache",
             py::overload_cast<const cache::CacheConfig &, size_t>(&InferEngine::reset_cache),
             py::arg("cache_config"), py::arg("pos") = 0,
             "Reset cache with new KV configuration")
        .def("get_cache_config", &InferEngine::get_cache_config,
             "Get current KV configuration")
        .def("__repr__", [](const InferEngine &self) {
            return "<InferEngine: " + std::string(self.get_dist_config()) + ">";
        });
}

} // namespace infinilm::engine