#include "scheduler.h"
#include <memory>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <torch/extension.h>

namespace py = pybind11;

PYBIND11_MODULE(sched_ext, m) {
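  // ModelSettings: static description of the model (parameter count, layer
  // count, KV-head geometry, element sizes) used for memory sizing. Note that
  // model_path is exposed below but is not part of the pickled state.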
  py::class_<scheduler::ModelSettings>(m, "ModelSettings")
      .def(py::init<>())
      .def_readwrite("model_path", &scheduler::ModelSettings::model_path)
      .def_readwrite("params_count", &scheduler::ModelSettings::params_count)
      .def_readwrite("layer_count", &scheduler::ModelSettings::layer_count)
      .def_readwrite("num_k_heads", &scheduler::ModelSettings::num_k_heads)
      .def_readwrite("k_head_dim", &scheduler::ModelSettings::k_head_dim)
      .def_readwrite("bytes_per_params",
                     &scheduler::ModelSettings::bytes_per_params)
      .def_readwrite("bytes_per_kv_cache_element",
                     &scheduler::ModelSettings::bytes_per_kv_cache_element)
      .def("params_size", &scheduler::ModelSettings::params_nbytes)
      .def("bytes_per_token_kv_cache",
           &scheduler::ModelSettings::bytes_per_token_kv_cache)
      // Add pickle support
      .def(py::pickle(
          [](const scheduler::ModelSettings &self) { // __getstate__
            return py::make_tuple(self.params_count, self.layer_count,
                                  self.num_k_heads, self.k_head_dim,
                                  self.bytes_per_params,
                                  self.bytes_per_kv_cache_element);
          },
          [](py::tuple t) { // __setstate__
            if (t.size() != 6)
              throw std::runtime_error("Invalid state! t.size() = " +
                                       std::to_string(t.size()));
            scheduler::ModelSettings ms;
            ms.params_count = t[0].cast<size_t>();
            ms.layer_count = t[1].cast<size_t>();
            ms.num_k_heads = t[2].cast<size_t>();
            ms.k_head_dim = t[3].cast<size_t>();
            ms.bytes_per_params = t[4].cast<double>();
            ms.bytes_per_kv_cache_element = t[5].cast<double>();
            return ms;
          }));

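  // SampleOptions: per-query sampling parameters (temperature and top_p).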
  py::class_<scheduler::SampleOptions>(m, "SampleOptions")
      .def(py::init<>())
      .def_readwrite("temperature", &scheduler::SampleOptions::temperature)
      .def_readwrite("top_p",
                     &scheduler::SampleOptions::top_p) // ensure top_p is accessible as well
      .def(py::pickle(
          [](const scheduler::SampleOptions &self) {
            return py::make_tuple(self.temperature,
                                  self.top_p); // serialize temperature and top_p
          },
          [](py::tuple t) {
            if (t.size() != 2) // ensure the unpacked tuple has the expected size
              throw std::runtime_error("Invalid state! t.size() = " +
                                       std::to_string(t.size()));
            scheduler::SampleOptions so;
            so.temperature = t[0].cast<double>();
            so.top_p = t[1].cast<double>(); // deserialize top_p
            return so;
          }));

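  // Settings: top-level scheduler configuration. The fields after the
  // "derived" marker below are computed from the primary fields, presumably
  // by auto_derive().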
  py::class_<scheduler::Settings>(m, "Settings")
      .def(py::init<>())
      .def_readwrite("model_name", &scheduler::Settings::model_name)
      .def_readwrite("quant_type", &scheduler::Settings::quant_type)
      .def_readwrite("model_settings", &scheduler::Settings::model_settings)
      .def_readwrite("page_size", &scheduler::Settings::page_size)
      .def_readwrite("gpu_device_id", &scheduler::Settings::gpu_device_id)
      .def_readwrite("gpu_memory_size", &scheduler::Settings::gpu_memory_size)
      .def_readwrite("memory_utilization_percentage",
                     &scheduler::Settings::memory_utilization_percentage)
      .def_readwrite("max_batch_size", &scheduler::Settings::max_batch_size)
      .def_readwrite(
          "recommended_chunk_prefill_token_count",
          &scheduler::Settings::recommended_chunk_prefill_token_count)
      .def_readwrite("sample_options", &scheduler::Settings::sample_options)
      .def_readwrite("sched_metrics_port",
                     &scheduler::Settings::sched_metrics_port)
      .def_readwrite("gpu_only", &scheduler::Settings::gpu_only)
      .def_readwrite("use_self_defined_head_dim",
                     &scheduler::Settings::use_self_defined_head_dim)
      .def_readwrite("self_defined_head_dim",
                     &scheduler::Settings::self_defined_head_dim)
      .def_readwrite("full_kv_cache_on_each_gpu",
                     &scheduler::Settings::full_kv_cache_on_each_gpu)
      .def_readwrite("k_cache_on", &scheduler::Settings::k_cache_on)
      .def_readwrite("v_cache_on", &scheduler::Settings::v_cache_on)
      .def_readwrite("kvc2_config_path", &scheduler::Settings::kvc2_config_path)
      .def_readwrite("kvc2_root_path", &scheduler::Settings::kvc2_root_path)
      .def_readwrite("memory_pool_size_GB",
                     &scheduler::Settings::memory_pool_size_GB)
      .def_readwrite("evict_count", &scheduler::Settings::evict_count)
      .def_readwrite("strategy_name", &scheduler::Settings::strategy_name)
      .def_readwrite("kvc2_metrics_port",
                     &scheduler::Settings::kvc2_metrics_port)
      .def_readwrite("load_from_disk", &scheduler::Settings::load_from_disk)
      .def_readwrite("save_to_disk", &scheduler::Settings::save_to_disk)
      // derived
      .def_readwrite("gpu_device_count", &scheduler::Settings::gpu_device_count)
      .def_readwrite("total_kvcache_pages",
                     &scheduler::Settings::total_kvcache_pages)
      .def_readwrite("devices", &scheduler::Settings::devices)
      .def("auto_derive", &scheduler::Settings::auto_derive);

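  // BatchQueryTodo: a batch of work for the next inference step (token
  // tensors, KV-cache block indexes, optional attention masks / RoPE ranges,
  // and prefill/decode mini-batches). Bound with a shared_ptr holder so
  // ownership can be shared between C++ and Python.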
  py::class_<scheduler::BatchQueryTodo,
             std::shared_ptr<scheduler::BatchQueryTodo>>(m, "BatchQueryTodo")
      .def(py::init<>())
      .def_readwrite("query_ids", &scheduler::BatchQueryTodo::query_ids)
      .def_readwrite("query_tokens", &scheduler::BatchQueryTodo::query_tokens)
      .def_readwrite("query_lengths", &scheduler::BatchQueryTodo::query_lengths)
      .def_readwrite("block_indexes", &scheduler::BatchQueryTodo::block_indexes)
      .def_readwrite("attn_masks", &scheduler::BatchQueryTodo::attn_masks)
      .def_readwrite("rope_ranges", &scheduler::BatchQueryTodo::rope_ranges)
      .def_readwrite("sample_options",
                     &scheduler::BatchQueryTodo::sample_options)
      .def_readwrite("prefill_mini_batches",
                     &scheduler::BatchQueryTodo::prefill_mini_batches)
      .def_readwrite("decode_mini_batches",
                     &scheduler::BatchQueryTodo::decode_mini_batches)
      .def_readwrite("stop_criteria", &scheduler::BatchQueryTodo::stop_criteria)
      .def("debug", &scheduler::BatchQueryTodo::debug)
      .def(py::pickle(
          [](const scheduler::BatchQueryTodo &self) {
            return py::make_tuple(
                self.query_ids, self.query_tokens, self.query_lengths,
                self.block_indexes, self.attn_masks, self.rope_ranges,
                self.sample_options, self.prefill_mini_batches,
                self.decode_mini_batches, self.stop_criteria);
          },
          [](py::tuple t) {
            if (t.size() != 10)
              throw std::runtime_error("Invalid state! t.size() = " +
                                       std::to_string(t.size()));
            scheduler::BatchQueryTodo bqt;
            bqt.query_ids = t[0].cast<std::vector<scheduler::QueryID>>();
            bqt.query_tokens = t[1].cast<std::vector<torch::Tensor>>();
            bqt.query_lengths =
                t[2].cast<std::vector<scheduler::TokenLength>>();
            bqt.block_indexes = t[3].cast<std::vector<torch::Tensor>>();
            bqt.attn_masks = t[4].cast<std::optional<torch::Tensor>>();
            bqt.rope_ranges = t[5].cast<std::optional<torch::Tensor>>();
            bqt.sample_options =
                t[6].cast<std::vector<scheduler::SampleOptions>>();
            bqt.prefill_mini_batches =
                t[7].cast<std::vector<scheduler::PrefillTask>>();
            bqt.decode_mini_batches =
                t[8].cast<std::vector<std::vector<scheduler::QueryID>>>();
            bqt.stop_criteria =
                t[9].cast<std::vector<std::vector<std::vector<int>>>>();
            return bqt;
          }));

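  // QueryUpdate: per-query status reported back after a batch (success flag,
  // prefill/decode phase, new active position, and the generated token).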
  py::class_<scheduler::QueryUpdate>(m, "QueryUpdate")
      .def(py::init<>())
      .def_readwrite("id", &scheduler::QueryUpdate::id)
      .def_readwrite("ok", &scheduler::QueryUpdate::ok)
      .def_readwrite("is_prefill", &scheduler::QueryUpdate::is_prefill)
      .def_readwrite("decode_done", &scheduler::QueryUpdate::decode_done)
      .def_readwrite("active_position",
                     &scheduler::QueryUpdate::active_position)
      .def_readwrite("generated_token",
                     &scheduler::QueryUpdate::generated_token)
      .def(py::pickle(
          [](const scheduler::QueryUpdate &self) {
            return py::make_tuple(self.id, self.ok, self.is_prefill,
                                  self.decode_done, self.active_position,
                                  self.generated_token);
          },
          [](py::tuple t) {
            if (t.size() != 6)
              throw std::runtime_error("Invalid state! t.size() = " +
                                       std::to_string(t.size()));
            scheduler::QueryUpdate qu;
            qu.id = t[0].cast<scheduler::QueryID>();
            qu.ok = t[1].cast<bool>();
            qu.is_prefill = t[2].cast<bool>();
            qu.decode_done = t[3].cast<bool>();
            qu.active_position = t[4].cast<scheduler::TokenLength>();
            qu.generated_token = t[5].cast<scheduler::Token>();
            return qu;
          }));

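  // InferenceContext: handles to the scheduler's K/V cache tensors.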
  py::class_<scheduler::InferenceContext>(m, "InferenceContext")
      .def(py::init<>())
      .def_readwrite("k_cache", &scheduler::InferenceContext::k_cache)
      .def_readwrite("v_cache", &scheduler::InferenceContext::v_cache);

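  // QueryAdd: descriptor for enqueueing a new query (prompt tokens, lengths,
  // sampling options, user id, SLO targets, stop criteria). attn_mask is
  // currently commented out of both the bindings and the pickle state, so
  // the pickle tuple holds 8 elements.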
  py::class_<scheduler::QueryAdd>(m, "QueryAdd")
      .def(py::init<>())
      .def_readwrite("query_token", &scheduler::QueryAdd::query_token)
      // .def_readwrite("attn_mask", &scheduler::QueryAdd::attn_mask)
      .def_readwrite("query_length", &scheduler::QueryAdd::query_length)
      .def_readwrite("estimated_length", &scheduler::QueryAdd::estimated_length)
      .def_readwrite("sample_options", &scheduler::QueryAdd::sample_options)
      .def_readwrite("user_id", &scheduler::QueryAdd::user_id)
      .def_readwrite("SLO_TTFT_ms", &scheduler::QueryAdd::SLO_TTFT_ms)
      .def_readwrite("SLO_TBT_ms", &scheduler::QueryAdd::SLO_TBT_ms)
      .def_readwrite("stop_criteria", &scheduler::QueryAdd::stop_criteria)
      .def("serialize", &scheduler::QueryAdd::serialize)
      .def_static("deserialize", &scheduler::QueryAdd::deserialize)
      .def(py::pickle(
          [](const scheduler::QueryAdd &self) {
            return py::make_tuple(self.query_token,
                                  // self.attn_mask,
                                  self.query_length, self.estimated_length,
                                  self.sample_options, self.user_id,
                                  self.SLO_TTFT_ms, self.SLO_TBT_ms,
                                  self.stop_criteria);
          },
          [](py::tuple t) {
            if (t.size() != 8)
              throw std::runtime_error("Invalid state! t.size() = " +
                                       std::to_string(t.size()));
            scheduler::QueryAdd qa;
            qa.query_token = t[0].cast<std::vector<scheduler::Token>>();
            // qa.attn_mask = t[1].cast<torch::Tensor>();
            qa.query_length = t[1].cast<scheduler::TokenLength>();
            qa.estimated_length = t[2].cast<scheduler::TokenLength>();
            qa.sample_options = t[3].cast<scheduler::SampleOptions>();
            qa.user_id = t[4].cast<scheduler::UserID>();
            qa.SLO_TTFT_ms = t[5].cast<int>();
            qa.SLO_TBT_ms = t[6].cast<int>();
            qa.stop_criteria = t[7].cast<std::vector<std::vector<int>>>();
            return qa;
          }));

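  // Scheduler: no py::init is bound, so instances come from create_scheduler()
  // below. The potentially blocking calls (add_query, cancel_query,
  // update_last_batch) release the GIL so Python threads can keep running.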
  py::class_<scheduler::Scheduler, std::shared_ptr<scheduler::Scheduler>>(
      m, "Scheduler")
      .def("init", &scheduler::Scheduler::init)
      .def("run", &scheduler::Scheduler::run)
      .def("stop", &scheduler::Scheduler::stop)
      .def("add_query", &scheduler::Scheduler::add_query,
           py::call_guard<py::gil_scoped_release>())
      .def("cancel_query", &scheduler::Scheduler::cancel_query,
           py::call_guard<py::gil_scoped_release>())
      .def("update_last_batch", &scheduler::Scheduler::update_last_batch,
           py::call_guard<py::gil_scoped_release>())
      .def("get_inference_context",
           &scheduler::Scheduler::get_inference_context);

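  // Factory function: the only way to construct a Scheduler from Python,
  // since the class itself exposes no constructor.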
  m.def("create_scheduler", &scheduler::create_scheduler,
        "Create a new Scheduler instance");
}
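
// Example Python usage (a sketch only; the exact signatures of
// create_scheduler(), init(), and add_query() are assumptions, not
// documented API):
//
//   import sched_ext
//
//   settings = sched_ext.Settings()
//   settings.model_name = "my-model"   # hypothetical value
//   settings.gpu_device_id = 0
//   settings.auto_derive()             # fill in the derived fields
//
//   sched = sched_ext.create_scheduler(settings)
//   sched.init()
//   sched.run()
//
//   q = sched_ext.QueryAdd()
//   q.query_token = [1, 2, 3]          # hypothetical prompt token ids
//   q.query_length = 3
//   q.estimated_length = 64
//   q.sample_options = sched_ext.SampleOptions()
//   sched.add_query(q)                 # releases the GIL while it blocks
//
//   sched.stop()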