#pragma once
#include "model_config.h"
#include <cstddef>
#include <cstdint>
#include <memory>
#include <optional>
#include <torch/torch.h>
#include <string>
#include <tuple>
#include <vector>

namespace scheduler {

using Token = uint32_t;
using QueryID = uint64_t;
constexpr QueryID NoQueryID = 0;

using TokenLength = size_t;
using BatchID = uint64_t;

using PageCount = size_t;

struct ModelSettings {
  std::string model_path;
  size_t params_count;
  size_t layer_count;
  size_t num_k_heads;
  size_t k_head_dim;

  double bytes_per_params;
  double bytes_per_kv_cache_element;

  inline size_t params_nbytes() { return params_count * bytes_per_params; }
  inline size_t bytes_per_token_kv_cache() {
    return bytes_per_kv_cache_element * num_k_heads * k_head_dim;
  }
};
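
// Illustrative worked example of the two helpers above (hypothetical numbers,
// not tied to any real model): with params_count = 7e9 and bytes_per_params =
// 2.0 (fp16), params_nbytes() == 14e9 bytes; with num_k_heads = 8,
// k_head_dim = 128 and bytes_per_kv_cache_element = 2.0,
// bytes_per_token_kv_cache() == 2048 bytes of K cache per token per layer
// (the formula does not multiply by layer_count).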

struct SampleOptions {
  double temperature = 1.0;
  double top_p = 1.0;
};

struct Settings {
  // Something is awkward here: kvc2 only uses model_name and quant_type to
  // look up model info.
  ModelName model_name;
  QuantType quant_type;
  // model_settings is ignored by kvc2
  ModelSettings model_settings;

  size_t page_size = 256;            // how many tokens in a page
  std::vector<size_t> gpu_device_id; // ids of the GPUs to use
  size_t gpu_memory_size;            // memory size in bytes of each GPU
  double memory_utilization_percentage;

  size_t max_batch_size = 256;

  size_t recommended_chunk_prefill_token_count;
  SampleOptions sample_options;
  size_t sched_metrics_port;

  // for kvc2
  bool gpu_only;
  bool use_self_defined_head_dim = false;
  size_t self_defined_head_dim;
  bool full_kv_cache_on_each_gpu = false;
  bool k_cache_on = true;
  bool v_cache_on = true;
  std::string kvc2_config_path;
  std::string kvc2_root_path;
  double memory_pool_size_GB = 100;
  size_t evict_count = 20;
  size_t kvc2_metrics_port;
  bool load_from_disk = false;
  bool save_to_disk = false;

  // for strategy
  std::string strategy_name;

  // derived fields, filled in by auto_derive()
  size_t gpu_device_count;
  std::optional<size_t> total_kvcache_pages;
  std::vector<torch::Device> devices;
  void auto_derive();
};

using PrefillTask =
    std::tuple<QueryID, TokenLength, TokenLength>; // id, start, length

struct BatchQueryTodo {
  // query
  std::vector<QueryID> query_ids;
  std::vector<torch::Tensor> query_tokens;
  std::vector<TokenLength> query_lengths;
  std::vector<torch::Tensor>
      block_indexes; // (max_num_blocks_per_seq), dtype torch.int32.
  std::optional<torch::Tensor> attn_masks;
  std::optional<torch::Tensor> rope_ranges;
  std::vector<SampleOptions> sample_options;
  std::vector<std::vector<std::vector<int>>> stop_criteria;

  // Mini batches; two adjacent mini batches are executed together.
  // The task count must be <= 2 because of FlashInfer attention.
  std::vector<PrefillTask>
      prefill_mini_batches; // a prefill mini batch holds exactly 1 prefill task
  std::vector<std::vector<QueryID>>
      decode_mini_batches; // a decode mini batch holds multiple decode tasks

  std::string debug();
  bool empty();
};
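
// Sketch of one possible batch layout (illustrative only; the ids and lengths
// are made up): a single prefill mini batch covering tokens [0, 512) of
// query 7, paired with a decode mini batch that steps queries 3, 5 and 9 by
// one token each, for a combined task count of 2 as required above.
//
//   BatchQueryTodo todo;
//   todo.prefill_mini_batches = {{QueryID{7}, TokenLength{0}, TokenLength{512}}};
//   todo.decode_mini_batches  = {{QueryID{3}, QueryID{5}, QueryID{9}}};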

struct QueryUpdate {
  QueryID id;
  bool ok;
  bool is_prefill;
  bool decode_done;            // unused for now
  TokenLength active_position; // the first position with no kvcache yet,
                               // i.e. kvcache[active_position] == None

  Token generated_token;

  std::string debug() const;
};
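
// Illustrative reading of active_position (follows directly from the comment
// above): after prefilling 100 tokens, kvcache[0..99] is populated, so
// active_position == 100; after one decode step it advances to 101.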

using BatchQueryUpdate = std::vector<QueryUpdate>;

struct InferenceContext {
  std::vector<torch::Tensor> k_cache; // [gpu num] (layer_count, num_blocks,
                                      // page_size, num_k_heads, head_dim)
  std::vector<torch::Tensor> v_cache;
};
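
// Illustrative indexing of the layout above: k_cache[g] is the K cache tensor
// on GPU g, and k_cache[g][l][b][s] is the (num_k_heads, head_dim) entry for
// slot s of block b in layer l. v_cache presumably mirrors this layout when
// v_cache_on is set.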

using UserID = int64_t;
constexpr UserID NoUser = -1;
const int MAX_SLO_TIME = 1e9;

struct QueryAdd {
  std::vector<Token> query_token; // token ids as plain integers
  // torch::Tensor attn_mask;
  TokenLength query_length;
  TokenLength estimated_length;

  std::vector<std::vector<int>> stop_criteria;

  SampleOptions sample_options;

  UserID user_id;
  int SLO_TTFT_ms = MAX_SLO_TIME;
  int SLO_TBT_ms = MAX_SLO_TIME;

  std::string serialize();
  static QueryAdd deserialize(const std::string &input);
};

class Scheduler {
public:
  virtual void init(Settings settings) = 0;

  virtual void run() = 0;
  virtual void stop() = 0;

  // webserver call this
  virtual QueryID add_query(QueryAdd query) = 0;
  virtual void cancel_query(QueryID id) = 0;

  // inference loop call this
  virtual std::shared_ptr<BatchQueryTodo>
  update_last_batch(BatchQueryUpdate updates) = 0;
  virtual InferenceContext get_inference_context() = 0;

  virtual ~Scheduler() = default;
};

std::shared_ptr<Scheduler> create_scheduler(Settings settings);
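
// Minimal usage sketch (illustrative; the field values are hypothetical and
// the exact init/run contract depends on the concrete Scheduler behind
// create_scheduler):
//
//   scheduler::Settings settings;
//   settings.gpu_device_id = {0};
//   settings.gpu_memory_size = 24ull << 30;      // 24 GiB per GPU
//   settings.memory_utilization_percentage = 0.9;
//   settings.auto_derive();
//
//   auto sched = scheduler::create_scheduler(settings);
//   sched->run();
//
//   // Webserver side: enqueue a query.
//   scheduler::QueryAdd q;
//   q.query_token = {1, 2, 3};                   // prompt token ids (placeholder)
//   q.query_length = q.query_token.size();
//   q.estimated_length = 128;
//   scheduler::QueryID id = sched->add_query(q);
//
//   // Inference loop: report results of the previous batch, get the next one.
//   scheduler::BatchQueryUpdate updates;
//   auto todo = sched->update_last_batch(updates);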

} // namespace scheduler