// engine.hpp — committed by PanZezhong.
// pybind11 bindings that expose DistConfig and InferEngine (together with its
// nested Input and Output types) to Python.
#include "../../engine/infer_engine.hpp"
#include "infinicore/tensor.hpp"
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

namespace py = pybind11;

namespace infinilm::engine::distributed {

/// Registers the Python class `DistConfig` on module `m`.
///
/// Exposes three constructors (no arguments, a tensor-parallel size, or an
/// explicit list of device IDs), the read/write `tp_device_ids` attribute, and
/// `__repr__` / `__str__`, both produced by DistConfig's std::string conversion.
inline void bind_dist_config(py::module &m) {
    // One renderer shared by __repr__ and __str__ so both always print the same text.
    const auto render = [](const DistConfig &config) {
        return std::string(config);
    };

    py::class_<DistConfig> dist_config(m, "DistConfig");

    dist_config.def(py::init<>(), "Default constructor, empty device list");
    dist_config.def(py::init<int>(), py::arg("tp_size"),
                    "Constructor with tensor parallel size, auto-assigns device IDs 0..tp_size-1");
    dist_config.def(py::init<const std::vector<int> &>(), py::arg("tp_device_ids"),
                    "Constructor with explicit device IDs");

    dist_config.def_readwrite("tp_device_ids", &DistConfig::tp_device_ids,
                              "List of device IDs used in tensor parallelism");

    dist_config.def("__repr__", render);
    dist_config.def("__str__", render);
}

} // namespace infinilm::engine::distributed

namespace infinilm::engine {

/// Registers `InferEngine` and its nested `Input` / `Output` types on module `m`.
///
/// `InferEngine` gets two `__init__` overloads, one taking a model config and
/// one taking a model path, registered in that order. pybind11 picks between
/// them by argument types. Every other method is bound exactly once and serves
/// engines built either way; binding it a second time would only register a
/// redundant, never-dispatched overload.
inline void bind_infer_engine(py::module &m) {
    py::class_<InferEngine, std::shared_ptr<InferEngine>> infer_engine(m, "InferEngine");

    // NOTE(review): this default is evaluated once, when bind_infer_engine runs,
    // not on every constructor call. A device selected in infinicore::context
    // afterwards is not picked up by the default. Confirm this is intended.
    const auto default_device_type = infinicore::context::getDevice().getType();

    infer_engine
        // Overload 1: construct from an in-memory model configuration.
        .def(py::init([](
                          const InfinilmModel::Config &cfg,
                          const distributed::DistConfig &dist,
                          infinicore::Device::Type dev,
                          std::shared_ptr<const infinilm::cache::CacheConfig> cache_cfg,
                          bool enable_graph_compiling) {
                 // cache_config=None arrives as an empty shared_ptr and is forwarded as nullptr.
                 return std::make_shared<InferEngine>(
                     cfg,
                     dist,
                     dev,
                     cache_cfg ? cache_cfg.get() : nullptr,
                     enable_graph_compiling);
             }),
             py::arg("config"),
             py::arg("distributed_config") = distributed::DistConfig(),
             py::arg("device_type") = default_device_type,
             py::arg("cache_config") = py::none(),
             py::arg("enable_graph_compiling") = false)
        // Overload 2: construct from `model_path` (forwarded unchanged to InferEngine).
        .def(py::init([](
                          const std::string &model_path,
                          const distributed::DistConfig &dist,
                          infinicore::Device::Type dev,
                          std::shared_ptr<const infinilm::cache::CacheConfig> cache_cfg,
                          bool enable_graph_compiling) {
                 return std::make_shared<InferEngine>(
                     model_path,
                     dist,
                     dev,
                     cache_cfg ? cache_cfg.get() : nullptr,
                     enable_graph_compiling);
             }),
             py::arg("model_path") = "",
             py::arg("distributed_config") = distributed::DistConfig(),
             py::arg("device_type") = default_device_type,
             py::arg("cache_config") = py::none(),
             py::arg("enable_graph_compiling") = false)
        .def("load_param", &InferEngine::load_param,
             py::arg("name"), py::arg("param"),
             "Load a parameter tensor into all workers (each worker picks its shard)")
        // Returns a Python list with one {name: Tensor} dict per element of
        // InferEngine::state_dict().
        .def("state_dict", [](InferEngine &self) {
            py::list state_dict_tp_all;
            for (const auto &state_dict_tp : self.state_dict()) {
                py::dict result;
                for (const auto &[name, param] : state_dict_tp) {
                    result[py::cast(name)] = infinicore::Tensor(param);
                }
                state_dict_tp_all.append(result);
            }
            return state_dict_tp_all;
        })
        .def("forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments")
        // cache_config=None is forwarded to InferEngine::reset_cache as nullptr.
        .def("reset_cache", [](InferEngine &self, std::shared_ptr<const cache::CacheConfig> cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none())
        // Returns an owned copy of the current cache config. If there is none
        // (e.g. after reset_cache(None)), it returns None instead of
        // dereferencing a null pointer.
        .def("get_cache_config", [](const InferEngine &self) -> std::shared_ptr<cache::CacheConfig> {
            auto cfg = self.get_cache_config();
            if (!cfg) {
                return nullptr;
            }
            return std::shared_ptr<cache::CacheConfig>(cfg->unique_copy());
        })
        .def("__repr__", [](const InferEngine &self) { return "<InferEngine: " + std::string(self.get_dist_config()) + ">"; });

    py::class_<InferEngine::Input>(infer_engine, "Input")
        .def(
            py::init([](
                         std::optional<infinicore::Tensor> input_ids,
                         std::optional<infinicore::Tensor> position_ids,
                         std::optional<infinicore::Tensor> past_sequence_lengths,
                         std::optional<infinicore::Tensor> total_sequence_lengths,
                         std::optional<infinicore::Tensor> input_offsets,
                         std::optional<infinicore::Tensor> cu_seqlens,
                         std::optional<infinicore::Tensor> block_tables,
                         std::optional<infinicore::Tensor> slot_mapping,
                         py::kwargs kwargs) {
                InferEngine::Input input{
                    std::move(input_ids),
                    std::move(position_ids),
                    std::move(past_sequence_lengths),
                    std::move(total_sequence_lengths),
                    std::move(input_offsets),
                    std::move(cu_seqlens),
                    std::move(block_tables),
                    std::move(slot_mapping),
                };

                // Sampling defaults; each may be overridden by a keyword argument below.
                input.temperature = 1.0f;
                input.top_p = 1.0f;
                input.top_k = 1;

                // Only the sampling parameters are accepted as extra keyword
                // arguments. Any other key throws py::value_error instead of
                // being silently ignored.
                for (const auto &item : kwargs) {
                    const auto key = py::cast<std::string>(item.first);
                    if (key == "temperature") {
                        input.temperature = py::cast<float>(item.second);
                    } else if (key == "top_p") {
                        input.top_p = py::cast<float>(item.second);
                    } else if (key == "top_k") {
                        input.top_k = py::cast<int>(item.second);
                    } else {
                        throw py::value_error(
                            "InferEngine.Input got an unexpected keyword argument '" + key + "'");
                    }
                }

                return input;
            }),
            py::arg("input_ids") = std::nullopt,
            py::arg("position_ids") = std::nullopt,
            py::arg("past_sequence_lengths") = std::nullopt,
            py::arg("total_sequence_lengths") = std::nullopt,
            py::arg("input_offsets") = std::nullopt,
            py::arg("cu_seqlens") = std::nullopt,
            py::arg("block_tables") = std::nullopt,
            py::arg("slot_mapping") = std::nullopt)
        .def_readwrite("input_ids", &InferEngine::Input::input_ids)
        .def_readwrite("position_ids", &InferEngine::Input::position_ids)
        .def_readwrite("past_sequence_lengths", &InferEngine::Input::past_sequence_lengths)
        .def_readwrite("total_sequence_lengths", &InferEngine::Input::total_sequence_lengths)
        .def_readwrite("input_offsets", &InferEngine::Input::input_offsets)
        .def_readwrite("cu_seqlens", &InferEngine::Input::cu_seqlens)
        .def_readwrite("block_tables", &InferEngine::Input::block_tables)
        .def_readwrite("slot_mapping", &InferEngine::Input::slot_mapping)
        .def_readwrite("temperature", &InferEngine::Input::temperature)
        .def_readwrite("top_k", &InferEngine::Input::top_k)
        .def_readwrite("top_p", &InferEngine::Input::top_p);

    py::class_<InferEngine::Output>(infer_engine, "Output")
        .def_readwrite("output_ids", &InferEngine::Output::output_ids, "Output tensor");
}

} // namespace infinilm::engine