// scheduler.h
#pragma once
#include <torch/torch.h>
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <tuple>
#include <vector>
#include "model_config.h"

namespace scheduler {

using Token = uint32_t;
using QueryID = uint64_t;
constexpr QueryID NoQueryID = 0;

using TokenLength = size_t;
using BatchID = uint64_t;

using PageCount = size_t;

struct ModelSettings {
  std::string model_path;
  size_t params_count;
  size_t layer_count;
  size_t num_k_heads;
  size_t k_head_dim;

  double bytes_per_params;
  double bytes_per_kv_cache_element;

  inline size_t params_nbytes() const { return params_count * bytes_per_params; }
  inline size_t bytes_per_token_kv_cache() const { return bytes_per_kv_cache_element * num_k_heads * k_head_dim; }
};
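
// Illustrative example (hypothetical numbers): with num_k_heads = 32,
// k_head_dim = 128 and bytes_per_kv_cache_element = 2 (e.g. fp16), the
// formula above gives bytes_per_token_kv_cache() = 2 * 32 * 128 = 8192 bytes.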

struct SampleOptions {
  double temperature = 1.0;
  double top_p = 1.0;
};

struct Settings {
  // Something is awkward here: kvc2 only uses model_name and quant_type to get model info.
  ModelName model_name;
  QuantType quant_type;
  // model_settings is ignored by kvc2
  ModelSettings model_settings;

  size_t page_size = 256;             // number of tokens in a page
  std::vector<size_t> gpu_device_id;  // ids of the GPUs to use
  size_t gpu_memory_size;             // memory size in bytes of each GPU
  double memory_utilization_percentage;

  size_t max_batch_size = 256;

  size_t recommended_chunk_prefill_token_count;
  SampleOptions sample_options;
  size_t sched_metrics_port;

  // for kvc2
  bool gpu_only;
  bool use_self_defined_head_dim = false;
  size_t self_defined_head_dim;
  bool full_kv_cache_on_each_gpu = false;
  bool k_cache_on = true;
  bool v_cache_on = true;
  std::string kvc2_config_path;
  std::string kvc2_root_path;
  double memory_pool_size_GB = 100;
  size_t evict_count = 20;
  size_t kvc2_metrics_port;
  bool load_from_disk = false;
  bool save_to_disk = false;

  // for strategy
  std::string strategy_name;

  // derived
  size_t gpu_device_count;
  std::optional<size_t> total_kvcache_pages;
  std::vector<torch::Device> devices;
  void auto_derive();
};
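
// A minimal, illustrative way to fill Settings before handing it to
// create_scheduler(). Values are hypothetical; auto_derive() is expected to
// fill the derived fields (gpu_device_count, devices, ...), but its exact
// behaviour is defined by the implementation:
//
//   scheduler::Settings s;
//   s.model_name = /* some model name */;
//   s.page_size = 256;
//   s.gpu_device_id = {0, 1};
//   s.gpu_memory_size = 80ull << 30;        // 80 GiB per GPU
//   s.memory_utilization_percentage = 0.9;
//   s.max_batch_size = 256;
//   s.auto_derive();                        // derive gpu_device_count, devices, ...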

using PrefillTask = std::tuple<QueryID, TokenLength, TokenLength>;  // id, start, length

struct BatchQueryTodo {
  // query
  std::vector<QueryID> query_ids;
  std::vector<torch::Tensor> query_tokens;
  std::vector<TokenLength> query_lengths;
  std::vector<torch::Tensor> block_indexes;  // (max_num_blocks_per_seq), dtype torch.int32.
  std::optional<torch::Tensor> attn_masks;
  std::optional<torch::Tensor> rope_ranges;
  std::vector<SampleOptions> sample_options;
  std::vector<std::vector<std::vector<int>>> stop_criteria;

  // mini batches; two adjacent mini batches are executed together
  // task count must be <= 2, because of FlashInfer attention
  std::vector<PrefillTask> prefill_mini_batches;          // a prefill mini batch has only 1 prefill
  std::vector<std::vector<QueryID>> decode_mini_batches;  // a decode mini batch has multiple decodes

  std::string debug();
  bool empty();
};
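
// Illustrative shape of a BatchQueryTodo (hypothetical ids and lengths): one
// prefill chunk and one decode mini batch, respecting the "<= 2 tasks"
// constraint noted above:
//
//   todo.prefill_mini_batches = { {7 /*id*/, 0 /*start*/, 512 /*length*/} };
//   todo.decode_mini_batches  = { {3, 4, 5} };  // one decode step for queries 3, 4, 5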

struct QueryUpdate {
  QueryID id;
  bool ok;
  bool is_prefill;
  bool decode_done;             // unused for now
  TokenLength active_position;  // the first position with no kvcache yet,
                                // i.e. kvcache[active_position] == None

  Token generated_token;

  std::string debug() const;
};

using BatchQueryUpdate = std::vector<QueryUpdate>;

struct InferenceContext {
  std::vector<torch::Tensor> k_cache;  // [gpu num] x (layer_count, num_blocks,
                                       //  page_size, num_k_heads, head_dim)
  std::vector<torch::Tensor> v_cache;
};

using UserID = int64_t;
constexpr UserID NoUser = -1;
constexpr int MAX_SLO_TIME = 1'000'000'000;

struct QueryAdd {
  std::vector<Token> query_token;  // token ids as plain integers, not a tensor
  // torch::Tensor attn_mask;
  TokenLength query_length;
  TokenLength estimated_length;

  std::vector<std::vector<int>> stop_criteria;

  SampleOptions sample_options;

  UserID user_id;
  int SLO_TTFT_ms = MAX_SLO_TIME;
  int SLO_TBT_ms = MAX_SLO_TIME;

  std::string serialize();
  static QueryAdd deserialize(const std::string& input);
};

class Scheduler {
 public:
  virtual void init(Settings settings) = 0;

  virtual void run() = 0;
  virtual void stop() = 0;

  // the webserver calls this
  virtual QueryID add_query(QueryAdd query) = 0;
  virtual void cancel_query(QueryID id) = 0;

  // the inference loop calls this
  virtual std::shared_ptr<BatchQueryTodo> update_last_batch(BatchQueryUpdate updates) = 0;
  virtual InferenceContext get_inference_context() = 0;

  virtual ~Scheduler() = default;
};

std::shared_ptr<Scheduler> create_scheduler(Settings settings);
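
// Usage sketch (an assumption about the intended call flow, inferred from the
// comments above; the real contract may differ):
//
//   auto sched = scheduler::create_scheduler(settings);
//   sched->run();
//
//   // webserver thread
//   scheduler::QueryID qid = sched->add_query(query_add);
//
//   // inference loop
//   scheduler::BatchQueryUpdate updates;                 // results of the last batch
//   auto todo = sched->update_last_batch(updates);       // next batch to execute
//   auto ctx  = sched->get_inference_context();          // per-GPU k/v cache tensors
//
//   sched->stop();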

}  // namespace scheduler