bind.cpp 11.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <memory>
#include "scheduler.h"

#include <torch/extension.h>

namespace py = pybind11;

namespace {

// Raises the error that every __setstate__ below reports when the pickled
// state tuple has an unexpected number of entries. Centralised so the message
// (and exception type, which Python sees as RuntimeError) stays identical
// across all bound classes.
[[noreturn]] void throw_invalid_pickle_state(size_t state_size) {
  throw std::runtime_error("Invalid state! t.size() = " + std::to_string(state_size));
}

}  // namespace

// Python extension module `sched_ext`.
//
// Exposes the scheduler configuration types (ModelSettings, SampleOptions,
// Settings), the per-step data exchanged with the caller (BatchQueryTodo,
// QueryUpdate, QueryAdd, InferenceContext), the Scheduler itself, and the
// `create_scheduler` factory.
//
// Classes registered with py::pickle can be pickled; their state is a plain
// tuple whose element order is fixed by the matching __getstate__ lambda and
// documented next to each class.
PYBIND11_MODULE(sched_ext, m) {
  // ModelSettings pickle state (7 elements):
  //   (params_count, layer_count, num_k_heads, k_head_dim,
  //    bytes_per_params, bytes_per_kv_cache_element, model_path)
  // model_path is stored LAST so that 6-element states written before it was
  // included still unpickle; such states leave model_path default-constructed.
  py::class_<scheduler::ModelSettings>(m, "ModelSettings")
      .def(py::init<>())
      .def_readwrite("model_path", &scheduler::ModelSettings::model_path)
      .def_readwrite("params_count", &scheduler::ModelSettings::params_count)
      .def_readwrite("layer_count", &scheduler::ModelSettings::layer_count)
      .def_readwrite("num_k_heads", &scheduler::ModelSettings::num_k_heads)
      .def_readwrite("k_head_dim", &scheduler::ModelSettings::k_head_dim)
      .def_readwrite("bytes_per_params", &scheduler::ModelSettings::bytes_per_params)
      .def_readwrite("bytes_per_kv_cache_element", &scheduler::ModelSettings::bytes_per_kv_cache_element)
      // Python-facing name differs from the C++ method (params_nbytes).
      .def("params_size", &scheduler::ModelSettings::params_nbytes)
      .def("bytes_per_token_kv_cache", &scheduler::ModelSettings::bytes_per_token_kv_cache)
      // Pickle support.
      .def(py::pickle(
          [](const scheduler::ModelSettings& self) {  // __getstate__
            return py::make_tuple(self.params_count, self.layer_count, self.num_k_heads, self.k_head_dim,
                                  self.bytes_per_params, self.bytes_per_kv_cache_element, self.model_path);
          },
          [](py::tuple t) {  // __setstate__
            // 7 = current layout; 6 = legacy layout without model_path.
            if (t.size() != 6 && t.size() != 7)
              throw_invalid_pickle_state(t.size());
            scheduler::ModelSettings ms;
            ms.params_count = t[0].cast<size_t>();
            ms.layer_count = t[1].cast<size_t>();
            ms.num_k_heads = t[2].cast<size_t>();
            ms.k_head_dim = t[3].cast<size_t>();
            ms.bytes_per_params = t[4].cast<double>();
            ms.bytes_per_kv_cache_element = t[5].cast<double>();
            if (t.size() == 7)
              ms.model_path = t[6].cast<decltype(ms.model_path)>();
            return ms;
          }));

  // SampleOptions pickle state (2 elements): (temperature, top_p).
  py::class_<scheduler::SampleOptions>(m, "SampleOptions")
      .def(py::init<>())
      .def_readwrite("temperature", &scheduler::SampleOptions::temperature)
      .def_readwrite("top_p", &scheduler::SampleOptions::top_p)
      .def(py::pickle(
          [](const scheduler::SampleOptions& self) {  // __getstate__
            return py::make_tuple(self.temperature, self.top_p);
          },
          [](py::tuple t) {  // __setstate__
            if (t.size() != 2)
              throw_invalid_pickle_state(t.size());
            scheduler::SampleOptions so;
            so.temperature = t[0].cast<double>();
            so.top_p = t[1].cast<double>();
            return so;
          }));

  // Scheduler-wide settings. No pickle support is registered for this class.
  py::class_<scheduler::Settings>(m, "Settings")
      .def(py::init<>())
      .def_readwrite("model_name", &scheduler::Settings::model_name)
      .def_readwrite("quant_type", &scheduler::Settings::quant_type)
      .def_readwrite("model_settings", &scheduler::Settings::model_settings)
      .def_readwrite("page_size", &scheduler::Settings::page_size)
      .def_readwrite("gpu_device_id", &scheduler::Settings::gpu_device_id)
      .def_readwrite("gpu_memory_size", &scheduler::Settings::gpu_memory_size)
      .def_readwrite("memory_utilization_percentage", &scheduler::Settings::memory_utilization_percentage)
      .def_readwrite("max_batch_size", &scheduler::Settings::max_batch_size)
      .def_readwrite("recommended_chunk_prefill_token_count",
                     &scheduler::Settings::recommended_chunk_prefill_token_count)
      .def_readwrite("sample_options", &scheduler::Settings::sample_options)
      .def_readwrite("sched_metrics_port", &scheduler::Settings::sched_metrics_port)
      .def_readwrite("gpu_only", &scheduler::Settings::gpu_only)
      .def_readwrite("use_self_defined_head_dim", &scheduler::Settings::use_self_defined_head_dim)
      .def_readwrite("self_defined_head_dim", &scheduler::Settings::self_defined_head_dim)
      .def_readwrite("full_kv_cache_on_each_gpu", &scheduler::Settings::full_kv_cache_on_each_gpu)
      .def_readwrite("k_cache_on", &scheduler::Settings::k_cache_on)
      .def_readwrite("v_cache_on", &scheduler::Settings::v_cache_on)
      .def_readwrite("kvc2_config_path", &scheduler::Settings::kvc2_config_path)
      .def_readwrite("kvc2_root_path", &scheduler::Settings::kvc2_root_path)
      .def_readwrite("memory_pool_size_GB", &scheduler::Settings::memory_pool_size_GB)
      .def_readwrite("evict_count", &scheduler::Settings::evict_count)
      .def_readwrite("strategy_name", &scheduler::Settings::strategy_name)
      .def_readwrite("kvc2_metrics_port", &scheduler::Settings::kvc2_metrics_port)
      .def_readwrite("load_from_disk", &scheduler::Settings::load_from_disk)
      .def_readwrite("save_to_disk", &scheduler::Settings::save_to_disk)
      // Derived fields (presumably filled in by auto_derive — confirm in scheduler.h).
      .def_readwrite("gpu_device_count", &scheduler::Settings::gpu_device_count)
      .def_readwrite("total_kvcache_pages", &scheduler::Settings::total_kvcache_pages)
      .def_readwrite("devices", &scheduler::Settings::devices)
      .def("auto_derive", &scheduler::Settings::auto_derive);

  // BatchQueryTodo is held by std::shared_ptr on the Python side.
  // Pickle state (10 elements):
  //   (query_ids, query_tokens, query_lengths, block_indexes, attn_masks,
  //    rope_ranges, sample_options, prefill_mini_batches, decode_mini_batches,
  //    stop_criteria)
  py::class_<scheduler::BatchQueryTodo, std::shared_ptr<scheduler::BatchQueryTodo>>(m, "BatchQueryTodo")
      .def(py::init<>())
      .def_readwrite("query_ids", &scheduler::BatchQueryTodo::query_ids)
      .def_readwrite("query_tokens", &scheduler::BatchQueryTodo::query_tokens)
      .def_readwrite("query_lengths", &scheduler::BatchQueryTodo::query_lengths)
      .def_readwrite("block_indexes", &scheduler::BatchQueryTodo::block_indexes)
      .def_readwrite("attn_masks", &scheduler::BatchQueryTodo::attn_masks)
      .def_readwrite("rope_ranges", &scheduler::BatchQueryTodo::rope_ranges)
      .def_readwrite("sample_options", &scheduler::BatchQueryTodo::sample_options)
      .def_readwrite("prefill_mini_batches", &scheduler::BatchQueryTodo::prefill_mini_batches)
      .def_readwrite("decode_mini_batches", &scheduler::BatchQueryTodo::decode_mini_batches)
      .def_readwrite("stop_criteria", &scheduler::BatchQueryTodo::stop_criteria)
      .def("debug", &scheduler::BatchQueryTodo::debug)
      .def(py::pickle(
          [](const scheduler::BatchQueryTodo& self) {  // __getstate__
            return py::make_tuple(self.query_ids, self.query_tokens, self.query_lengths, self.block_indexes,
                                  self.attn_masks, self.rope_ranges, self.sample_options, self.prefill_mini_batches,
                                  self.decode_mini_batches, self.stop_criteria);
          },
          [](py::tuple t) {  // __setstate__
            if (t.size() != 10)
              throw_invalid_pickle_state(t.size());
            scheduler::BatchQueryTodo bqt;
            bqt.query_ids = t[0].cast<std::vector<scheduler::QueryID>>();
            bqt.query_tokens = t[1].cast<std::vector<torch::Tensor>>();
            bqt.query_lengths = t[2].cast<std::vector<scheduler::TokenLength>>();
            bqt.block_indexes = t[3].cast<std::vector<torch::Tensor>>();
            // Optional tensors: None in Python maps to an empty std::optional.
            bqt.attn_masks = t[4].cast<std::optional<torch::Tensor>>();
            bqt.rope_ranges = t[5].cast<std::optional<torch::Tensor>>();
            bqt.sample_options = t[6].cast<std::vector<scheduler::SampleOptions>>();
            bqt.prefill_mini_batches = t[7].cast<std::vector<scheduler::PrefillTask>>();
            bqt.decode_mini_batches = t[8].cast<std::vector<std::vector<scheduler::QueryID>>>();
            bqt.stop_criteria = t[9].cast<std::vector<std::vector<std::vector<int>>>>();
            return bqt;
          }));

  // QueryUpdate pickle state (6 elements):
  //   (id, ok, is_prefill, decode_done, active_position, generated_token)
  py::class_<scheduler::QueryUpdate>(m, "QueryUpdate")
      .def(py::init<>())
      .def_readwrite("id", &scheduler::QueryUpdate::id)
      .def_readwrite("ok", &scheduler::QueryUpdate::ok)
      .def_readwrite("is_prefill", &scheduler::QueryUpdate::is_prefill)
      .def_readwrite("decode_done", &scheduler::QueryUpdate::decode_done)
      .def_readwrite("active_position", &scheduler::QueryUpdate::active_position)
      .def_readwrite("generated_token", &scheduler::QueryUpdate::generated_token)
      .def(py::pickle(
          [](const scheduler::QueryUpdate& self) {  // __getstate__
            return py::make_tuple(self.id, self.ok, self.is_prefill, self.decode_done, self.active_position,
                                  self.generated_token);
          },
          [](py::tuple t) {  // __setstate__
            if (t.size() != 6)
              throw_invalid_pickle_state(t.size());
            scheduler::QueryUpdate qu;
            qu.id = t[0].cast<scheduler::QueryID>();
            qu.ok = t[1].cast<bool>();
            qu.is_prefill = t[2].cast<bool>();
            qu.decode_done = t[3].cast<bool>();
            qu.active_position = t[4].cast<scheduler::TokenLength>();
            qu.generated_token = t[5].cast<scheduler::Token>();
            return qu;
          }));

  // KV-cache handles returned by Scheduler.get_inference_context.
  py::class_<scheduler::InferenceContext>(m, "InferenceContext")
      .def(py::init<>())
      .def_readwrite("k_cache", &scheduler::InferenceContext::k_cache)
      .def_readwrite("v_cache", &scheduler::InferenceContext::v_cache);

  // QueryAdd pickle state (8 elements):
  //   (query_token, query_length, estimated_length, sample_options, user_id,
  //    SLO_TTFT_ms, SLO_TBT_ms, stop_criteria)
  // The C++ attn_mask member is intentionally neither exposed nor pickled.
  py::class_<scheduler::QueryAdd>(m, "QueryAdd")
      .def(py::init<>())
      .def_readwrite("query_token", &scheduler::QueryAdd::query_token)
      .def_readwrite("query_length", &scheduler::QueryAdd::query_length)
      .def_readwrite("estimated_length", &scheduler::QueryAdd::estimated_length)
      .def_readwrite("sample_options", &scheduler::QueryAdd::sample_options)
      .def_readwrite("user_id", &scheduler::QueryAdd::user_id)
      .def_readwrite("SLO_TTFT_ms", &scheduler::QueryAdd::SLO_TTFT_ms)
      .def_readwrite("SLO_TBT_ms", &scheduler::QueryAdd::SLO_TBT_ms)
      .def_readwrite("stop_criteria", &scheduler::QueryAdd::stop_criteria)
      .def("serialize", &scheduler::QueryAdd::serialize)
      .def_static("deserialize", &scheduler::QueryAdd::deserialize)
      .def(py::pickle(
          [](const scheduler::QueryAdd& self) {  // __getstate__
            return py::make_tuple(self.query_token, self.query_length, self.estimated_length, self.sample_options,
                                  self.user_id, self.SLO_TTFT_ms, self.SLO_TBT_ms, self.stop_criteria);
          },
          [](py::tuple t) {  // __setstate__
            if (t.size() != 8)
              throw_invalid_pickle_state(t.size());
            scheduler::QueryAdd qa;
            qa.query_token = t[0].cast<std::vector<scheduler::Token>>();
            qa.query_length = t[1].cast<scheduler::TokenLength>();
            qa.estimated_length = t[2].cast<scheduler::TokenLength>();
            qa.sample_options = t[3].cast<scheduler::SampleOptions>();
            qa.user_id = t[4].cast<scheduler::UserID>();
            qa.SLO_TTFT_ms = t[5].cast<int>();
            qa.SLO_TBT_ms = t[6].cast<int>();
            qa.stop_criteria = t[7].cast<std::vector<std::vector<int>>>();
            return qa;
          }));

  // Scheduler is held by std::shared_ptr and has no Python constructor; obtain
  // one via create_scheduler. add_query, cancel_query and update_last_batch
  // release the GIL for the duration of the C++ call, so other Python threads
  // can run meanwhile; their implementations must not touch Python objects
  // without reacquiring it. init/run/stop/get_inference_context keep the GIL.
  py::class_<scheduler::Scheduler, std::shared_ptr<scheduler::Scheduler>>(m, "Scheduler")
      .def("init", &scheduler::Scheduler::init)
      .def("run", &scheduler::Scheduler::run)
      .def("stop", &scheduler::Scheduler::stop)
      .def("add_query", &scheduler::Scheduler::add_query, py::call_guard<py::gil_scoped_release>())
      .def("cancel_query", &scheduler::Scheduler::cancel_query, py::call_guard<py::gil_scoped_release>())
      .def("update_last_batch", &scheduler::Scheduler::update_last_batch, py::call_guard<py::gil_scoped_release>())
      .def("get_inference_context", &scheduler::Scheduler::get_inference_context);

  m.def("create_scheduler", &scheduler::create_scheduler, "Create a new Scheduler instance");
}