// engine.hpp — pybind11 bindings for infinilm::engine (InferEngine) and
// infinilm::engine::distributed (DistConfig).
#pragma once

#include "../../engine/infer_engine.hpp"
#include "infinicore/tensor.hpp"

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <memory>
#include <optional>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

namespace py = pybind11;

namespace infinilm::engine::distributed {

// Expose distributed::DistConfig to Python as "DistConfig".
inline void bind_dist_config(py::module &m) {
    // One stringifier serves both __repr__ and __str__; it relies on
    // DistConfig's conversion operator to std::string.
    auto stringify = [](const DistConfig &config) {
        return std::string(config);
    };

    py::class_<DistConfig>(m, "DistConfig")
        .def(py::init<>(), "Default constructor, empty device list")
        .def(py::init<int>(), py::arg("tp_size"),
             "Constructor with tensor parallel size, auto-assigns device IDs 0..tp_size-1")
        .def(py::init<const std::vector<int> &>(), py::arg("tp_device_ids"),
             "Constructor with explicit device IDs")
        .def_readwrite("tp_device_ids", &DistConfig::tp_device_ids,
                       "List of device IDs used in tensor parallelism")
        .def("__repr__", stringify)
        .def("__str__", stringify);
}

} // namespace infinilm::engine::distributed

namespace infinilm::engine {

// Expose InferEngine (plus its nested Input/Output structs) to Python.
//
// NOTE(review): the original registered the full method chain twice — once
// after each constructor overload. Duplicate pybind11 registrations with the
// same signature are dead code (the first registration always matches), and
// the duplicated get_cache_config dereferenced a possibly-null config. Both
// constructors are kept; the methods are now registered exactly once, with
// the null-safe get_cache_config.
inline void bind_infer_engine(py::module &m) {
    py::class_<InferEngine, std::shared_ptr<InferEngine>> infer_engine(m, "InferEngine");

    // --- Constructors -------------------------------------------------------
    // Both overloads unwrap the shared_ptr cache config to a non-owning raw
    // pointer; nullptr means "no cache configured".

    // (1) Build from an in-memory model configuration.
    infer_engine.def(
        py::init([](const InfinilmModel::Config &cfg,
                    const distributed::DistConfig &dist,
                    infinicore::Device::Type dev,
                    std::shared_ptr<const infinilm::cache::CacheConfig> cache_cfg,
                    bool enable_graph_compiling,
                    const std::string &attention_backend) {
            return std::make_shared<InferEngine>(
                cfg,
                dist,
                dev,
                cache_cfg ? cache_cfg.get() : nullptr,
                enable_graph_compiling,
                infinilm::backends::parse_attention_backend(attention_backend));
        }),
        py::arg("config"),
        py::arg("distributed_config") = distributed::DistConfig(),
        py::arg("device_type") = infinicore::context::getDevice().getType(),
        py::arg("cache_config") = py::none(),
        py::arg("enable_graph_compiling") = false,
        py::arg("attention_backend") = "default");

    // (2) Build from a model path on disk.
    infer_engine.def(
        py::init([](const std::string &model_path,
                    const distributed::DistConfig &dist,
                    infinicore::Device::Type dev,
                    std::shared_ptr<const infinilm::cache::CacheConfig> cache_cfg,
                    bool enable_graph_compiling,
                    const std::string &attention_backend) {
            return std::make_shared<InferEngine>(
                model_path,
                dist,
                dev,
                cache_cfg ? cache_cfg.get() : nullptr,
                enable_graph_compiling,
                infinilm::backends::parse_attention_backend(attention_backend));
        }),
        py::arg("model_path") = "",
        py::arg("distributed_config") = distributed::DistConfig(),
        py::arg("device_type") = infinicore::context::getDevice().getType(),
        py::arg("cache_config") = py::none(),
        py::arg("enable_graph_compiling") = false,
        py::arg("attention_backend") = "default");

    // --- Methods (registered once) ------------------------------------------
    infer_engine
        .def("load_param", &InferEngine::load_param,
             py::arg("name"), py::arg("param"),
             "Load a parameter tensor into all workers (each worker picks its shard)")
        // Returns a list with one {name: Tensor} dict per tensor-parallel rank.
        .def("state_dict", [](InferEngine &self) {
            py::list state_dict_tp_all;
            for (const auto &state_dict_tp : self.state_dict()) {
                py::dict result;
                for (const auto &[name, param] : state_dict_tp) {
                    result[py::cast(name)] = infinicore::Tensor(param);
                }
                state_dict_tp_all.append(result);
            }
            return state_dict_tp_all;
        })
        .def(
            "forward",
            [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output {
                return self.forward(input);
            },
            "Run inference on all ranks with arbitrary arguments")
        .def(
            "reset_cache",
            [](InferEngine &self, std::shared_ptr<const cache::CacheConfig> cfg) {
                self.reset_cache(cfg ? cfg.get() : nullptr);
            },
            py::arg("cache_config") = py::none())
        // Returns an owning copy of the engine's cache config, or None when no
        // cache is configured (the null check guards cfg->unique_copy()).
        .def("get_cache_config", [](const InferEngine &self) -> std::shared_ptr<cache::CacheConfig> {
            auto cfg = self.get_cache_config();
            return cfg ? std::shared_ptr<cache::CacheConfig>(cfg->unique_copy()) : nullptr;
        })
        .def("__repr__", [](const InferEngine &self) {
            return "<InferEngine: " + std::string(self.get_dist_config()) + ">";
        });

    // --- InferEngine.Input ---------------------------------------------------
    // All tensor fields are optional; sampling parameters (temperature, top_p,
    // top_k) are accepted only as keyword arguments and validated eagerly.
    py::class_<InferEngine::Input>(infer_engine, "Input")
        .def(
            py::init([](std::optional<infinicore::Tensor> input_ids,
                        std::optional<infinicore::Tensor> position_ids,
                        std::optional<infinicore::Tensor> past_sequence_lengths,
                        std::optional<infinicore::Tensor> total_sequence_lengths,
                        std::optional<infinicore::Tensor> input_offsets,
                        std::optional<infinicore::Tensor> cu_seqlens,
                        std::optional<infinicore::Tensor> block_tables,
                        std::optional<infinicore::Tensor> slot_mapping,
                        py::kwargs kwargs) {
                InferEngine::Input input{
                    std::move(input_ids),
                    std::move(position_ids),
                    std::move(past_sequence_lengths),
                    std::move(total_sequence_lengths),
                    std::move(input_offsets),
                    std::move(cu_seqlens),
                    std::move(block_tables),
                    std::move(slot_mapping),
                };

                // Explicit sampling defaults (top_k == 1 means greedy).
                input.temperature = 1.0f;
                input.top_p = 1.0f;
                input.top_k = 1;

                // Only these names may be passed as keywords.
                static const std::unordered_set<std::string> allowed_kwargs = {
                    "temperature",
                    "top_p",
                    "top_k",
                };

                for (auto &item : kwargs) {
                    const std::string key = py::cast<std::string>(item.first);

                    if (allowed_kwargs.find(key) == allowed_kwargs.end()) {
                        throw py::value_error(
                            "InferEngine.Input got an unexpected keyword argument '" + key + "'");
                    }

                    if (key == "temperature") {
                        input.temperature = py::cast<float>(item.second);
                    } else if (key == "top_p") {
                        input.top_p = py::cast<float>(item.second);
                    } else if (key == "top_k") {
                        input.top_k = py::cast<int>(item.second);
                    }
                }

                return input;
            }),
            py::arg("input_ids") = std::nullopt,
            py::arg("position_ids") = std::nullopt,
            py::arg("past_sequence_lengths") = std::nullopt,
            py::arg("total_sequence_lengths") = std::nullopt,
            py::arg("input_offsets") = std::nullopt,
            py::arg("cu_seqlens") = std::nullopt,
            py::arg("block_tables") = std::nullopt,
            py::arg("slot_mapping") = std::nullopt)
        .def_readwrite("input_ids", &InferEngine::Input::input_ids)
        .def_readwrite("position_ids", &InferEngine::Input::position_ids)
        .def_readwrite("past_sequence_lengths", &InferEngine::Input::past_sequence_lengths)
        .def_readwrite("total_sequence_lengths", &InferEngine::Input::total_sequence_lengths)
        .def_readwrite("input_offsets", &InferEngine::Input::input_offsets)
        .def_readwrite("cu_seqlens", &InferEngine::Input::cu_seqlens)
        .def_readwrite("block_tables", &InferEngine::Input::block_tables)
        .def_readwrite("slot_mapping", &InferEngine::Input::slot_mapping)
        .def_readwrite("temperature", &InferEngine::Input::temperature)
        .def_readwrite("top_k", &InferEngine::Input::top_k)
        .def_readwrite("top_p", &InferEngine::Input::top_p);

    // --- InferEngine.Output --------------------------------------------------
    py::class_<InferEngine::Output>(infer_engine, "Output")
        .def_readwrite("output_ids", &InferEngine::Output::output_ids, "Output tensor");
}

} // namespace infinilm::engine