OpenDAS / ktransformers · Commits

Commit 25cee581, authored Mar 31, 2025 by Atream

    add balance-serve, support concurrence

Parent: 8d0292aa

Showing 20 changed files with 2434 additions and 0 deletions (+2434, -0)
csrc/balance_serve/kvc2/test/test_queue_perf.cpp           +84   -0
csrc/balance_serve/kvc2/test/test_std_list.cpp             +38   -0
csrc/balance_serve/kvc2/test/xxHash_test.cpp               +31   -0
csrc/balance_serve/kvc2/unit_test.sh                       +36   -0
csrc/balance_serve/sched/CMakeLists.txt                    +19   -0
csrc/balance_serve/sched/bind.cpp                          +208  -0
csrc/balance_serve/sched/metrics.cpp                       +135  -0
csrc/balance_serve/sched/metrics.h                         +85   -0
csrc/balance_serve/sched/model_config.h                    +113  -0
csrc/balance_serve/sched/scheduler.cpp                     +916  -0
csrc/balance_serve/sched/scheduler.h                       +167  -0
csrc/balance_serve/sched/utils/all.hpp                     +3    -0
csrc/balance_serve/sched/utils/arithmetic.hpp              +8    -0
csrc/balance_serve/sched/utils/atomic_ptr_with_flags.hpp   +28   -0
csrc/balance_serve/sched/utils/csv.hpp                     +225  -0
csrc/balance_serve/sched/utils/easy_format.hpp             +16   -0
csrc/balance_serve/sched/utils/mpsc.hpp                    +109  -0
csrc/balance_serve/sched/utils/readable_number.hpp         +20   -0
csrc/balance_serve/sched/utils/statistics.hpp              +65   -0
csrc/balance_serve/sched/utils/timer.hpp                   +128  -0
csrc/balance_serve/kvc2/test/test_queue_perf.cpp (new file, mode 100644)

#include <chrono>
#include <iostream>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

#include "utils/lock_free_queue.hpp"

#define STDQ

int main() {
  const int num_producers = 48;
  const int num_items = 1e6;
#ifdef STDQ
  std::mutex lock;
  std::queue<int> queue;
#else
  MPSCQueue<int> queue;
#endif

  auto start_time = std::chrono::high_resolution_clock::now();

  // Launch multiple producer threads
  std::vector<std::thread> producers;
  for (int i = 0; i < num_producers; ++i) {
    producers.emplace_back([&queue, i
#ifdef STDQ
                            ,
                            &lock
#endif
    ]() {
      for (int j = 0; j < num_items; ++j) {
#ifdef STDQ
        std::lock_guard<std::mutex> guard(lock);
        queue.push(i * num_items + j);
#else
        queue.enqueue(std::make_shared<int>(i * num_items + j));
#endif
      }
    });
  }

  // Consumer thread
  std::thread consumer([&queue, num_producers
#ifdef STDQ
                        ,
                        &lock
#endif
  ]() {
    int count = 0;
    while (count < num_producers * num_items) {
#ifdef STDQ
      std::lock_guard<std::mutex> guard(lock);
      if (!queue.empty()) {
        queue.pop();
        count++;
      }
#else
      if (auto item = queue.dequeue()) {
        count++;
      }
#endif
    }
  });

  // Wait for all producers to finish
  for (auto& producer : producers) {
    producer.join();
  }

  // Wait for the consumer to finish
  consumer.join();

  auto end_time = std::chrono::high_resolution_clock::now();
  auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time).count();

#ifdef STDQ
  std::cout << "std::queue with mutex ";
#else
  std::cout << "lock free queue ";
#endif
  std::cout << "Processed " << num_producers * num_items / 1e6 << "M items in " << duration << " milliseconds "
            << num_producers * num_items / 1e3 / duration << " MOps." << std::endl;
  return 0;
}
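The benchmark only exercises two operations of the queue from utils/lock_free_queue.hpp: enqueue of a std::shared_ptr, and a non-blocking dequeue that yields nothing when the queue is empty. For reference, a minimal mutex-based stand-in with that same interface (NaiveMPSCQueue is a made-up name; the real MPSCQueue added by this commit is lock-free and may differ):

#include <memory>
#include <mutex>
#include <queue>

template <typename T>
class NaiveMPSCQueue {
 public:
  void enqueue(std::shared_ptr<T> item) {
    std::lock_guard<std::mutex> guard(mutex_);
    items_.push(std::move(item));
  }
  // Non-blocking: returns nullptr when the queue is empty.
  std::shared_ptr<T> dequeue() {
    std::lock_guard<std::mutex> guard(mutex_);
    if (items_.empty()) return nullptr;
    auto item = std::move(items_.front());
    items_.pop();
    return item;
  }

 private:
  std::mutex mutex_;
  std::queue<std::shared_ptr<T>> items_;
};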
csrc/balance_serve/kvc2/test/test_std_list.cpp (new file, mode 100644)

#include <iostream>
#include <iterator>
#include <vector>

int main() {
  std::vector<int> v = {0, 1, 2, 3, 4, 5};
  using RevIt = std::reverse_iterator<std::vector<int>::iterator>;

  const auto it = v.begin() + 3;
  RevIt r_it{it};

  std::cout << "*it == " << *it << '\n'
            << "*r_it == " << *r_it << '\n'
            << "*r_it.base() == " << *r_it.base() << '\n'
            << "*(r_it.base()-1) == " << *(r_it.base() - 1) << '\n';

  RevIt r_end{v.begin()};
  RevIt r_begin{v.end()};

  for (auto it = r_end.base(); it != r_begin.base(); ++it)
    std::cout << *it << ' ';
  std::cout << '\n';

  for (auto it = r_begin; it != r_end; ++it)
    std::cout << *it << ' ';
  std::cout << '\n';

  for (auto it = r_begin; it != r_end; ++it) {
    if (*it == 3) {
      v.erase(std::next(it).base());
    }
  }
  for (auto it : v)
    std::cout << it << ' ';
  std::cout << '\n';
}
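A note on what this scratch test demonstrates: for a std::reverse_iterator r built from iterator it, the standard guarantees r.base() == it while *r dereferences the element before it. That is exactly why *(r_it.base() - 1) prints the same value as *r_it, and why erasing the element a reverse iterator refers to goes through std::next(it).base(). (Strictly, after the erase the loop above keeps advancing a now-invalidated reverse iterator; it works in practice here only because a vector's storage does not move.) A self-contained check of the invariant:

#include <cassert>
#include <iterator>
#include <vector>

int main() {
  std::vector<int> v = {0, 1, 2, 3, 4, 5};
  auto it = v.begin() + 3;
  std::reverse_iterator<std::vector<int>::iterator> r_it{it};
  // base() points at `it`, but dereferencing yields the element before it.
  assert(r_it.base() == it);
  assert(&*r_it == &*(it - 1));            // *r_it == v[2]
  assert(&*r_it == &*(r_it.base() - 1));   // same identity, written as printed above
}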
csrc/balance_serve/kvc2/test/xxHash_test.cpp (new file, mode 100644)

#include <cstdlib>
#include <iostream>
#include <string>

#include "xxhash.h"

int main() {
  std::string t = "hello world";
  XXH64_hash_t hash = XXH64(t.data(), t.size(), 123);
  std::cout << hash << std::endl;

  {
    /* create a hash state */
    XXH64_state_t* const state = XXH64_createState();
    if (state == NULL)
      abort();
    if (XXH64_reset(state, 123) == XXH_ERROR)
      abort();
    if (XXH64_update(state, t.data(), 5) == XXH_ERROR)
      abort();
    if (XXH64_update(state, t.data() + 5, t.size() - 5) == XXH_ERROR)
      abort();
    /* Produce the final hash value */
    XXH64_hash_t const hash = XXH64_digest(state);
    /* State could be re-used; but in this example, it is simply freed */
    XXH64_freeState(state);
    std::cout << hash << std::endl;
  }
  return 0;
}
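Both printed lines should be identical: with the same seed, xxHash defines the streaming digest over the same bytes (in any chunking) to equal the one-shot digest. The test could therefore be tightened into an assertion, sketched here:

#include <cassert>
#include <string>
#include "xxhash.h"

int main() {
  std::string t = "hello world";
  // One-shot digest.
  XXH64_hash_t one_shot = XXH64(t.data(), t.size(), 123);
  // Streaming digest over the same bytes, split arbitrarily.
  XXH64_state_t* s = XXH64_createState();
  XXH64_reset(s, 123);
  XXH64_update(s, t.data(), 5);
  XXH64_update(s, t.data() + 5, t.size() - 5);
  XXH64_hash_t streamed = XXH64_digest(s);
  XXH64_freeState(s);
  assert(one_shot == streamed);  // same seed + same bytes => same hash
  return 0;
}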
csrc/balance_serve/kvc2/unit_test.sh (new file, mode 100755)

#!/bin/bash
# Check that a disk_cache_path argument was provided
if [ -z "$1" ]; then
    echo "Usage: $0 <disk_cache_path>"
    exit 1
fi

# Assign the disk_cache_path argument to a variable
disk_cache_path=$1

# Define the test commands, substituting disk_cache_path into each
tests=(
    "./build/test/kvc2_export_header_test --disk_cache_path=$disk_cache_path"
    "./build/test/kvcache_disk_insert_read_test --disk_cache_path=$disk_cache_path"
    "./build/test/kvcache_mem_eviction_test --disk_cache_path=$disk_cache_path"
    "./build/test/kvcache_mem_insert_read_test --disk_cache_path=$disk_cache_path"
    "./build/test/kvcache_save_load_test --disk_cache_path=$disk_cache_path"
)

# Iterate over each test command
for test in "${tests[@]}"; do
    echo "Running: $test"
    # Run the test and capture its output
    output=$($test)
    # Check whether the output contains "Test Passed"
    if echo "$output" | grep -q "Test Passed"; then
        echo "  Test Passed"
    else
        echo "  Test Failed"
    fi
    sleep 1
done
csrc/balance_serve/sched/CMakeLists.txt (new file, mode 100644)

set(CMAKE_CXX_FLAGS "-Og -march=native -Wall -Wextra -g -fPIC")
# set(CMAKE_CXX_FLAGS "-O3 -march=native -Wall -Wextra -fPIC")

set(UTILS_DIR ${CMAKE_CURRENT_SOURCE_DIR}/utils)

add_library(sched_metrics metrics.cpp)
target_include_directories(sched_metrics PRIVATE ${UTILS_DIR})
target_link_libraries(sched_metrics PUBLIC prometheus-cpp::pull)

add_library(sched scheduler.cpp)
target_include_directories(sched PRIVATE ${SPDLOG_DIR}/include ${FMT_DIR}/include ${UTILS_DIR} ${KVC2_INCLUDE_DIR})
target_link_libraries(sched PUBLIC pthread ${TORCH_LIBRARIES} kvc2 async_store sched_metrics)

pybind11_add_module(sched_ext bind.cpp)
target_link_libraries(sched_ext PUBLIC sched ${TORCH_LIBRARIES} ${TORCH_PYTHON_LIBRARY})
csrc/balance_serve/sched/bind.cpp (new file, mode 100644)

#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <memory>
#include "scheduler.h"

#include <torch/extension.h>

namespace py = pybind11;

PYBIND11_MODULE(sched_ext, m) {
  py::class_<scheduler::ModelSettings>(m, "ModelSettings")
      .def(py::init<>())
      .def_readwrite("model_path", &scheduler::ModelSettings::model_path)
      .def_readwrite("params_count", &scheduler::ModelSettings::params_count)
      .def_readwrite("layer_count", &scheduler::ModelSettings::layer_count)
      .def_readwrite("num_k_heads", &scheduler::ModelSettings::num_k_heads)
      .def_readwrite("k_head_dim", &scheduler::ModelSettings::k_head_dim)
      .def_readwrite("bytes_per_params", &scheduler::ModelSettings::bytes_per_params)
      .def_readwrite("bytes_per_kv_cache_element", &scheduler::ModelSettings::bytes_per_kv_cache_element)
      .def("params_size", &scheduler::ModelSettings::params_nbytes)
      .def("bytes_per_token_kv_cache", &scheduler::ModelSettings::bytes_per_token_kv_cache)
      // Add pickle support
      .def(py::pickle(
          [](const scheduler::ModelSettings& self) {
            // __getstate__
            return py::make_tuple(self.params_count, self.layer_count, self.num_k_heads, self.k_head_dim,
                                  self.bytes_per_params, self.bytes_per_kv_cache_element);
          },
          [](py::tuple t) {
            // __setstate__
            if (t.size() != 6)
              throw std::runtime_error("Invalid state! t.size() = " + std::to_string(t.size()));
            scheduler::ModelSettings ms;
            ms.params_count = t[0].cast<size_t>();
            ms.layer_count = t[1].cast<size_t>();
            ms.num_k_heads = t[2].cast<size_t>();
            ms.k_head_dim = t[3].cast<size_t>();
            ms.bytes_per_params = t[4].cast<double>();
            ms.bytes_per_kv_cache_element = t[5].cast<double>();
            return ms;
          }));

  py::class_<scheduler::SampleOptions>(m, "SampleOptions")
      .def(py::init<>())
      .def_readwrite("temperature", &scheduler::SampleOptions::temperature)
      .def_readwrite("top_p", &scheduler::SampleOptions::top_p)  // make sure top_p is accessible too
      .def(py::pickle(
          [](const scheduler::SampleOptions& self) {
            return py::make_tuple(self.temperature, self.top_p);  // serialize temperature and top_p
          },
          [](py::tuple t) {
            if (t.size() != 2)  // make sure the element count matches when unpickling
              throw std::runtime_error("Invalid state! t.size() = " + std::to_string(t.size()));
            scheduler::SampleOptions so;
            so.temperature = t[0].cast<double>();
            so.top_p = t[1].cast<double>();  // deserialize top_p
            return so;
          }));

  py::class_<scheduler::Settings>(m, "Settings")
      .def(py::init<>())
      .def_readwrite("model_name", &scheduler::Settings::model_name)
      .def_readwrite("quant_type", &scheduler::Settings::quant_type)
      .def_readwrite("model_settings", &scheduler::Settings::model_settings)
      .def_readwrite("page_size", &scheduler::Settings::page_size)
      .def_readwrite("gpu_device_id", &scheduler::Settings::gpu_device_id)
      .def_readwrite("gpu_memory_size", &scheduler::Settings::gpu_memory_size)
      .def_readwrite("memory_utilization_percentage", &scheduler::Settings::memory_utilization_percentage)
      .def_readwrite("max_batch_size", &scheduler::Settings::max_batch_size)
      .def_readwrite("recommended_chunk_prefill_token_count",
                     &scheduler::Settings::recommended_chunk_prefill_token_count)
      .def_readwrite("sample_options", &scheduler::Settings::sample_options)
      .def_readwrite("sched_metrics_port", &scheduler::Settings::sched_metrics_port)
      .def_readwrite("gpu_only", &scheduler::Settings::gpu_only)
      .def_readwrite("use_self_defined_head_dim", &scheduler::Settings::use_self_defined_head_dim)
      .def_readwrite("self_defined_head_dim", &scheduler::Settings::self_defined_head_dim)
      .def_readwrite("full_kv_cache_on_each_gpu", &scheduler::Settings::full_kv_cache_on_each_gpu)
      .def_readwrite("k_cache_on", &scheduler::Settings::k_cache_on)
      .def_readwrite("v_cache_on", &scheduler::Settings::v_cache_on)
      .def_readwrite("kvc2_config_path", &scheduler::Settings::kvc2_config_path)
      .def_readwrite("kvc2_root_path", &scheduler::Settings::kvc2_root_path)
      .def_readwrite("memory_pool_size_GB", &scheduler::Settings::memory_pool_size_GB)
      .def_readwrite("evict_count", &scheduler::Settings::evict_count)
      .def_readwrite("strategy_name", &scheduler::Settings::strategy_name)
      .def_readwrite("kvc2_metrics_port", &scheduler::Settings::kvc2_metrics_port)
      .def_readwrite("load_from_disk", &scheduler::Settings::load_from_disk)
      .def_readwrite("save_to_disk", &scheduler::Settings::save_to_disk)
      // derived
      .def_readwrite("gpu_device_count", &scheduler::Settings::gpu_device_count)
      .def_readwrite("total_kvcache_pages", &scheduler::Settings::total_kvcache_pages)
      .def_readwrite("devices", &scheduler::Settings::devices)
      .def("auto_derive", &scheduler::Settings::auto_derive);

  py::class_<scheduler::BatchQueryTodo, std::shared_ptr<scheduler::BatchQueryTodo>>(m, "BatchQueryTodo")
      .def(py::init<>())
      .def_readwrite("query_ids", &scheduler::BatchQueryTodo::query_ids)
      .def_readwrite("query_tokens", &scheduler::BatchQueryTodo::query_tokens)
      .def_readwrite("query_lengths", &scheduler::BatchQueryTodo::query_lengths)
      .def_readwrite("block_indexes", &scheduler::BatchQueryTodo::block_indexes)
      .def_readwrite("attn_masks", &scheduler::BatchQueryTodo::attn_masks)
      .def_readwrite("rope_ranges", &scheduler::BatchQueryTodo::rope_ranges)
      .def_readwrite("sample_options", &scheduler::BatchQueryTodo::sample_options)
      .def_readwrite("prefill_mini_batches", &scheduler::BatchQueryTodo::prefill_mini_batches)
      .def_readwrite("decode_mini_batches", &scheduler::BatchQueryTodo::decode_mini_batches)
      .def_readwrite("stop_criteria", &scheduler::BatchQueryTodo::stop_criteria)
      .def("debug", &scheduler::BatchQueryTodo::debug)
      .def(py::pickle(
          [](const scheduler::BatchQueryTodo& self) {
            return py::make_tuple(self.query_ids, self.query_tokens, self.query_lengths, self.block_indexes,
                                  self.attn_masks, self.rope_ranges, self.sample_options, self.prefill_mini_batches,
                                  self.decode_mini_batches, self.stop_criteria);
          },
          [](py::tuple t) {
            if (t.size() != 10)
              throw std::runtime_error("Invalid state! t.size() = " + std::to_string(t.size()));
            scheduler::BatchQueryTodo bqt;
            bqt.query_ids = t[0].cast<std::vector<scheduler::QueryID>>();
            bqt.query_tokens = t[1].cast<std::vector<torch::Tensor>>();
            bqt.query_lengths = t[2].cast<std::vector<scheduler::TokenLength>>();
            bqt.block_indexes = t[3].cast<std::vector<torch::Tensor>>();
            bqt.attn_masks = t[4].cast<std::optional<torch::Tensor>>();
            bqt.rope_ranges = t[5].cast<std::optional<torch::Tensor>>();
            bqt.sample_options = t[6].cast<std::vector<scheduler::SampleOptions>>();
            bqt.prefill_mini_batches = t[7].cast<std::vector<scheduler::PrefillTask>>();
            bqt.decode_mini_batches = t[8].cast<std::vector<std::vector<scheduler::QueryID>>>();
            bqt.stop_criteria = t[9].cast<std::vector<std::vector<std::vector<int>>>>();
            return bqt;
          }));

  py::class_<scheduler::QueryUpdate>(m, "QueryUpdate")
      .def(py::init<>())
      .def_readwrite("id", &scheduler::QueryUpdate::id)
      .def_readwrite("ok", &scheduler::QueryUpdate::ok)
      .def_readwrite("is_prefill", &scheduler::QueryUpdate::is_prefill)
      .def_readwrite("decode_done", &scheduler::QueryUpdate::decode_done)
      .def_readwrite("active_position", &scheduler::QueryUpdate::active_position)
      .def_readwrite("generated_token", &scheduler::QueryUpdate::generated_token)
      .def(py::pickle(
          [](const scheduler::QueryUpdate& self) {
            return py::make_tuple(self.id, self.ok, self.is_prefill, self.decode_done, self.active_position,
                                  self.generated_token);
          },
          [](py::tuple t) {
            if (t.size() != 6)
              throw std::runtime_error("Invalid state! t.size() = " + std::to_string(t.size()));
            scheduler::QueryUpdate qu;
            qu.id = t[0].cast<scheduler::QueryID>();
            qu.ok = t[1].cast<bool>();
            qu.is_prefill = t[2].cast<bool>();
            qu.decode_done = t[3].cast<bool>();
            qu.active_position = t[4].cast<scheduler::TokenLength>();
            qu.generated_token = t[5].cast<scheduler::Token>();
            return qu;
          }));

  py::class_<scheduler::InferenceContext>(m, "InferenceContext")
      .def(py::init<>())
      .def_readwrite("k_cache", &scheduler::InferenceContext::k_cache)
      .def_readwrite("v_cache", &scheduler::InferenceContext::v_cache);

  py::class_<scheduler::QueryAdd>(m, "QueryAdd")
      .def(py::init<>())
      .def_readwrite("query_token", &scheduler::QueryAdd::query_token)
      // .def_readwrite("attn_mask", &scheduler::QueryAdd::attn_mask)
      .def_readwrite("query_length", &scheduler::QueryAdd::query_length)
      .def_readwrite("estimated_length", &scheduler::QueryAdd::estimated_length)
      .def_readwrite("sample_options", &scheduler::QueryAdd::sample_options)
      .def_readwrite("user_id", &scheduler::QueryAdd::user_id)
      .def_readwrite("SLO_TTFT_ms", &scheduler::QueryAdd::SLO_TTFT_ms)
      .def_readwrite("SLO_TBT_ms", &scheduler::QueryAdd::SLO_TBT_ms)
      .def_readwrite("stop_criteria", &scheduler::QueryAdd::stop_criteria)
      .def("serialize", &scheduler::QueryAdd::serialize)
      .def_static("deserialize", &scheduler::QueryAdd::deserialize)
      .def(py::pickle(
          [](const scheduler::QueryAdd& self) {
            return py::make_tuple(self.query_token,
                                  // self.attn_mask,
                                  self.query_length, self.estimated_length, self.sample_options, self.user_id,
                                  self.SLO_TTFT_ms, self.SLO_TBT_ms, self.stop_criteria);
          },
          [](py::tuple t) {
            if (t.size() != 8)
              throw std::runtime_error("Invalid state! t.size() = " + std::to_string(t.size()));
            scheduler::QueryAdd qa;
            qa.query_token = t[0].cast<std::vector<scheduler::Token>>();
            // qa.attn_mask = t[1].cast<torch::Tensor>();
            qa.query_length = t[1].cast<scheduler::TokenLength>();
            qa.estimated_length = t[2].cast<scheduler::TokenLength>();
            qa.sample_options = t[3].cast<scheduler::SampleOptions>();
            qa.user_id = t[4].cast<scheduler::UserID>();
            qa.SLO_TTFT_ms = t[5].cast<int>();
            qa.SLO_TBT_ms = t[6].cast<int>();
            qa.stop_criteria = t[7].cast<std::vector<std::vector<int>>>();
            return qa;
          }));

  py::class_<scheduler::Scheduler, std::shared_ptr<scheduler::Scheduler>>(m, "Scheduler")
      .def("init", &scheduler::Scheduler::init)
      .def("run", &scheduler::Scheduler::run)
      .def("stop", &scheduler::Scheduler::stop)
      .def("add_query", &scheduler::Scheduler::add_query, py::call_guard<py::gil_scoped_release>())
      .def("cancel_query", &scheduler::Scheduler::cancel_query, py::call_guard<py::gil_scoped_release>())
      .def("update_last_batch", &scheduler::Scheduler::update_last_batch, py::call_guard<py::gil_scoped_release>())
      .def("get_inference_context", &scheduler::Scheduler::get_inference_context);

  m.def("create_scheduler", &scheduler::create_scheduler, "Create a new Scheduler instance");
}
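One detail worth noting in these bindings: add_query, cancel_query, and update_last_batch are wrapped in py::call_guard<py::gil_scoped_release>, because they block on the scheduler's event loop (update_last_batch even busy-waits for the next batch); releasing the GIL lets other Python threads keep running in the meantime. A minimal self-contained sketch of that pattern, with made-up module and function names:

#include <pybind11/pybind11.h>
#include <chrono>
#include <thread>
namespace py = pybind11;

// Hypothetical stand-in for a blocking call like add_query / update_last_batch.
void blocking_wait() {
  std::this_thread::sleep_for(std::chrono::seconds(1));
}

PYBIND11_MODULE(example, m) {
  // Without the call_guard, the GIL would stay held for the whole second.
  m.def("blocking_wait", &blocking_wait, py::call_guard<py::gil_scoped_release>());
}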
csrc/balance_serve/sched/metrics.cpp (new file, mode 100644)

#include "metrics.h"
#include <iostream>

// Constructor
Metrics::Metrics(const MetricsConfig& config)
    : registry_(std::make_shared<prometheus::Registry>()),
      exposer_(config.endpoint),
      stop_uptime_thread_(false),
      start_time_(std::chrono::steady_clock::now()) {
  // Shared bucket boundaries in milliseconds, up to 10000 ms (10 s)
  std::vector<double> common_buckets = {0.001, 0.005, 0.01,  0.05,  0.1,    0.5,    1.0,    5.0,
                                        10.0,  50.0,  100.0, 500.0, 1000.0, 5000.0, 10000.0};

  // Register the TTFT_ms histogram
  auto& TTFT_family = prometheus::BuildHistogram()
                          .Name(std::string(METRIC_PREFIX) + "_TTFT_ms")
                          .Help("Time to first token in milliseconds")
                          .Register(*registry_);
  TTFT_ms = &TTFT_family.Add({{"model", config.model_name}}, common_buckets);

  // Register the TBT_ms histogram
  auto& TBT_family = prometheus::BuildHistogram()
                         .Name(std::string(METRIC_PREFIX) + "_TBT_ms")
                         .Help("Time between tokens in milliseconds")
                         .Register(*registry_);
  TBT_ms = &TBT_family.Add({{"model", config.model_name}}, common_buckets);

  // Register the schedule_time histogram
  auto& schedule_time_family = prometheus::BuildHistogram()
                                   .Name(std::string(METRIC_PREFIX) + "_schedule_time_ms")
                                   .Help("Time to generate schedule in milliseconds")
                                   .Register(*registry_);
  schedule_time = &schedule_time_family.Add({{"model", config.model_name}}, common_buckets);

  // Register the generated_tokens counter
  auto& generated_tokens_family = prometheus::BuildCounter()
                                      .Name(std::string(METRIC_PREFIX) + "_generated_tokens_total")
                                      .Help("Total generated tokens")
                                      .Register(*registry_);
  generated_tokens = &generated_tokens_family.Add({{"model", config.model_name}});

  // Register the throughput_query gauge
  auto& throughput_query_family = prometheus::BuildGauge()
                                      .Name(std::string(METRIC_PREFIX) + "_throughput_query")
                                      .Help("Throughput per second based on queries")
                                      .Register(*registry_);
  throughput_query = &throughput_query_family.Add({{"model", config.model_name}});

  // Register the throughput_generated_tokens gauge
  auto& throughput_generated_tokens_family = prometheus::BuildGauge()
                                                 .Name(std::string(METRIC_PREFIX) + "_throughput_generated_tokens")
                                                 .Help("Throughput per second based on generated tokens")
                                                 .Register(*registry_);
  throughput_generated_tokens = &throughput_generated_tokens_family.Add({{"model", config.model_name}});

  // Register the event_count counter family
  event_count_family_ = &prometheus::BuildCounter()
                             .Name(std::string(METRIC_PREFIX) + "_event_count_total")
                             .Help("Count of various events")
                             .Register(*registry_);

  batch_count_family_ = &prometheus::BuildCounter()
                             .Name(std::string(METRIC_PREFIX) + "_batch_count_total")
                             .Help("Count of various batch by status")
                             .Register(*registry_);

  // Register the query_count counter family
  query_count_family_ = &prometheus::BuildCounter()
                             .Name(std::string(METRIC_PREFIX) + "_query_count_total")
                             .Help("Count of queries by status")
                             .Register(*registry_);

  // Register the uptime_ms gauge
  auto& uptime_family = prometheus::BuildGauge()
                            .Name(std::string(METRIC_PREFIX) + "_uptime_ms")
                            .Help("Uptime of the scheduler in milliseconds")
                            .Register(*registry_);
  uptime_ms = &uptime_family.Add({{"model", config.model_name}});

  // Register the per-GPU utilization gauges
  auto& gpu_util_family = prometheus::BuildGauge()
                              .Name(std::string(METRIC_PREFIX) + "_gpu_utilization_ratio")
                              .Help("Current GPU utilization ratio (0 to 1)")
                              .Register(*registry_);
  for (size_t i = 0; i < config.gpu_count; ++i) {
    gpu_utilization_gauges.push_back(
        &gpu_util_family.Add({{"gpu_id", std::to_string(i)}, {"model", config.model_name}}));
  }

  // Register the registry with the exposer
  exposer_.RegisterCollectable(registry_);

  // Start the uptime update thread
  StartUptimeUpdater();
}

// Destructor
Metrics::~Metrics() {
  StopUptimeUpdater();
}

// Start the uptime update thread
void Metrics::StartUptimeUpdater() {
  uptime_thread_ = std::thread([this]() {
    while (!stop_uptime_thread_) {
      auto now = std::chrono::steady_clock::now();
      std::chrono::duration<double, std::milli> uptime_duration = now - start_time_;
      uptime_ms->Set(uptime_duration.count());
      // fn_every_sec(this);
      std::this_thread::sleep_for(std::chrono::seconds(1));
    }
  });
}

// Stop the uptime update thread
void Metrics::StopUptimeUpdater() {
  stop_uptime_thread_ = true;
  if (uptime_thread_.joinable()) {
    uptime_thread_.join();
  }
}

// Look up an event_count counter by type
prometheus::Counter* Metrics::event_count(const std::string& type) {
  return &event_count_family_->Add({{"type", type}});  // more labels can be added as needed
}

// Look up a query_count counter by status
prometheus::Counter* Metrics::query_count(const std::string& status) {
  return &query_count_family_->Add({{"status", status}});  // more labels can be added as needed
}

prometheus::Counter* Metrics::batch_count(const std::string& type) {
  return &batch_count_family_->Add({{"type", type}});
}
csrc/balance_serve/sched/metrics.h (new file, mode 100644)

#ifndef Metrics_H
#define Metrics_H

#include <prometheus/counter.h>
#include <prometheus/exposer.h>
#include <prometheus/gauge.h>
#include <prometheus/histogram.h>
#include <prometheus/registry.h>
#include <atomic>
#include <chrono>
#include <functional>
#include <memory>
#include <string>
#include <thread>
#include <vector>
#include "timer.hpp"

// Metric name prefix
#define METRIC_PREFIX "scheduler"

class Metrics;

// Configuration struct
struct MetricsConfig {
  std::string endpoint;
  std::string model_name;  // model name, e.g. "gpt-4"
  size_t gpu_count;        // number of GPUs
};

// Metrics class: initializes Prometheus metrics from the config
class Metrics {
 public:
  // Constructor takes a MetricsConfig
  Metrics(const MetricsConfig& config);
  ~Metrics();

  // Non-copyable and non-assignable
  Metrics(const Metrics&) = delete;
  Metrics& operator=(const Metrics&) = delete;

  std::function<void(Metrics*)> fn_every_sec;

  // Metric pointers
  prometheus::Gauge* uptime_ms;
  prometheus::Histogram* TTFT_ms;
  prometheus::Histogram* TBT_ms;
  prometheus::Histogram* schedule_time;
  prometheus::Gauge* throughput_query;
  prometheus::Gauge* throughput_generated_tokens;
  prometheus::Counter* generated_tokens;
  std::vector<prometheus::Gauge*> gpu_utilization_gauges;

  // Counter families
  prometheus::Counter* event_count(const std::string& type);
  prometheus::Counter* query_count(const std::string& status);
  prometheus::Counter* batch_count(const std::string& type);

 private:
  std::shared_ptr<prometheus::Registry> registry_;
  prometheus::Exposer exposer_;

  // Counter families
  prometheus::Family<prometheus::Counter>* event_count_family_;
  prometheus::Family<prometheus::Counter>* batch_count_family_;
  prometheus::Family<prometheus::Counter>* query_count_family_;

  // Thread and control flag for updating uptime_ms
  std::thread uptime_thread_;
  std::atomic<bool> stop_uptime_thread_;

  // Start the uptime update thread
  void StartUptimeUpdater();
  // Stop the uptime update thread
  void StopUptimeUpdater();

  // Records the program start time
  std::chrono::steady_clock::time_point start_time_;
};

struct HistogramTimerWrapper {
  prometheus::Histogram* histogram;
  Timer timer;

  inline HistogramTimerWrapper(prometheus::Histogram* histogram) : histogram(histogram), timer() { timer.start(); }
  inline ~HistogramTimerWrapper() { histogram->Observe(timer.elapsedMs()); }
};

#endif  // Metrics_H
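Usage of the RAII wrapper above, mirroring how scheduler.cpp times its schedule step (schedule_once is a made-up name for illustration):

// Time a scope into a histogram via RAII.
void schedule_once(Metrics* met) {
  HistogramTimerWrapper t(met->schedule_time);  // timer starts in the constructor
  // ... scheduling work ...
}  // destructor calls met->schedule_time->Observe(elapsed milliseconds)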
csrc/balance_serve/sched/model_config.h (new file, mode 100644)

#ifndef __MODEL_CONFIG_HPP_
#define __MODEL_CONFIG_HPP_

#include <filesystem>
#include <fstream>
#include <iostream>
#include <map>
#include <string>
#include "nlohmann/json.hpp"

using DimSize = size_t;
using URL = std::string;
using ModelName = std::string;

// We must ensure this can be loaded from config.json
class ModelConfig {
 public:
  DimSize hidden_size;
  DimSize intermediate_size;
  size_t max_position_embeddings;
  std::string model_type;
  size_t num_attention_heads;
  size_t num_hidden_layers;
  size_t num_key_value_heads;
  size_t vocab_size;

  NLOHMANN_DEFINE_TYPE_INTRUSIVE(ModelConfig, hidden_size, intermediate_size, max_position_embeddings, model_type,
                                 num_attention_heads, num_hidden_layers, num_key_value_heads, vocab_size);

  void load_from(std::filesystem::path path) {
    std::ifstream i(path);
    nlohmann::json j;
    i >> j;
    *this = j.get<ModelConfig>();
  }
};

using QuantType = std::string;
static const QuantType NoQuantType = "";

class QuantConfig {
 public:
  QuantType name;
  // For GEMV
  QuantType type_of_dot_vector = NoQuantType;
  inline bool can_be_used_as_matrix() { return type_of_dot_vector != NoQuantType; }

  bool can_be_used_as_vector;
  double bytes_per_element;
  bool has_scale;
  bool has_min;
  size_t block_element_count;
  size_t block_element_size;

  URL reference = "";

  NLOHMANN_DEFINE_TYPE_INTRUSIVE_WITH_DEFAULT(QuantConfig, name, type_of_dot_vector, can_be_used_as_vector,
                                              bytes_per_element, has_scale, has_min, block_element_count,
                                              block_element_size, reference);
};

inline std::map<QuantType, QuantConfig> quant_configs;
inline std::map<ModelName, ModelConfig> model_configs;

inline void load_quant_configs(std::filesystem::path path) {
  // Start as an empty object so the get<>() below yields an empty map
  // when the file does not exist yet.
  nlohmann::json j = nlohmann::json::object();
  if (std::filesystem::exists(path)) {
    std::cout << __FUNCTION__ << " from " << path << std::endl;
    std::ifstream i(path);
    i >> j;
  } else {
    std::cout << __FUNCTION__ << " create new at " << path << std::endl;
  }
  quant_configs = j.get<std::map<QuantType, QuantConfig>>();
  std::cout << "Loaded Quant Configs" << std::endl;
  for (auto& [k, v] : quant_configs) {
    std::cout << " - " << k << std::endl;
  }
}

inline void dump_quant_configs(std::filesystem::path path) {
  std::ofstream o(path);
  nlohmann::json j = quant_configs;
  o << j.dump(4);
}

inline void load_model_configs(std::filesystem::path path) {
  // Same empty-object default as load_quant_configs above.
  nlohmann::json j = nlohmann::json::object();
  if (std::filesystem::exists(path)) {
    std::cout << __FUNCTION__ << " from " << path << std::endl;
    std::ifstream i(path);
    i >> j;
  } else {
    std::cout << __FUNCTION__ << " create new at " << path << std::endl;
  }
  model_configs = j.get<std::map<ModelName, ModelConfig>>();
  std::cout << "Loaded Model Configs" << std::endl;
  for (auto& [k, v] : model_configs) {
    std::cout << " - " << k << std::endl;
  }
}

inline void dump_model_configs(std::filesystem::path path) {
  std::ofstream o(path);
  nlohmann::json j = model_configs;
  o << j.dump(4);
}

#endif
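ModelConfig mirrors a subset of the fields in a Hugging Face style config.json, so loading one is a single call. A minimal sketch, with a hypothetical path:

#include <iostream>
#include "model_config.h"

int main() {
  ModelConfig cfg;
  cfg.load_from("/path/to/model/config.json");  // hypothetical path
  std::cout << cfg.model_type << ": " << cfg.num_hidden_layers << " layers, "
            << cfg.num_attention_heads << " attention heads" << std::endl;
  return 0;
}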
csrc/balance_serve/sched/scheduler.cpp (new file, mode 100644)

#define SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_INFO
#define FMT_HEADER_ONLY

#include "nlohmann/json.hpp"
#include "spdlog/spdlog.h"

#include <optional>
#include "scheduler.h"

#include <atomic>
#include <cassert>
#include <cstring>
#include <future>
#include <map>
#include <memory>
#include <queue>
#include <set>
#include <thread>
#include <variant>
#include "arithmetic.hpp"
#include "atomic_ptr_with_flags.hpp"
#include "easy_format.hpp"
#include "metrics.h"
#include "mpsc.hpp"
#include "timer.hpp"

#include "kvc2.h"

using json = nlohmann::json;

namespace scheduler {

void Settings::auto_derive() {
  gpu_device_count = gpu_device_id.size();
  if (torch::cuda::is_available()) {
    size_t gpu_count = torch::cuda::device_count();
    SPDLOG_INFO("Number of available GPUs: {}, want {}", gpu_count, gpu_device_count);
    if (gpu_count < gpu_device_count) {
      SPDLOG_ERROR("Not enough GPUs available.");
      exit(0);
    }
    for (size_t i = 0; i < gpu_device_count; i++) {
      devices.push_back(torch::Device(torch::kCUDA, gpu_device_id[i]));
    }
  } else {
    SPDLOG_ERROR("CUDA is not available on this system.");
    exit(0);
  }
  if (model_settings.num_k_heads % gpu_device_count != 0) {
    SPDLOG_ERROR("num_k_heads {} is not divisible by gpu_device_count {}", model_settings.num_k_heads,
                 gpu_device_count);
    assert(false);
  }
  size_t gpu_memory_available = gpu_memory_size * memory_utilization_percentage;
  if (gpu_memory_available * gpu_device_count < model_settings.params_nbytes()) {
    SPDLOG_ERROR("GPU memory size {}G is smaller than {}G", gpu_memory_available * gpu_device_count / 1e9,
                 model_settings.params_nbytes() / 1e9);
    assert(false);
  }
  assert(model_settings.k_head_dim % model_settings.num_k_heads == 0);
  size_t head_per_gpu = model_settings.num_k_heads / gpu_device_count;
  size_t gpu_memory_for_kv_cache = gpu_memory_available /*- model_settings.params_nbytes() / gpu_device_count*/;
  SPDLOG_INFO("Each GPU Total: {}MiB, Model Params: {}MiB, KVCache: {}MiB, Left: {}MiB", gpu_memory_size / (1 << 20),
              model_settings.params_nbytes() / gpu_device_count / (1 << 20), gpu_memory_for_kv_cache / (1 << 20),
              (gpu_memory_size - gpu_memory_available) / (1 << 20));
  size_t kv_cache_on_cnt = (size_t)(k_cache_on) + (size_t)(v_cache_on);
  size_t max_total_kvcache_pages =
      gpu_memory_for_kv_cache / (kv_cache_on_cnt * head_per_gpu * model_settings.k_head_dim *
                                 model_settings.bytes_per_kv_cache_element * page_size * model_settings.layer_count);
  if (total_kvcache_pages.has_value()) {
    if (total_kvcache_pages.value() > max_total_kvcache_pages) {
      SPDLOG_ERROR("total_kvcache_pages {} is larger than max_total_kvcache_pages {}", total_kvcache_pages.value(),
                   max_total_kvcache_pages);
      assert(false);
    }
  } else {
    total_kvcache_pages = max_total_kvcache_pages;
    SPDLOG_INFO("total_kvcache_pages is auto derived as {}", max_total_kvcache_pages);
  }
  if (page_size % 256 != 0) {
    SPDLOG_ERROR("page_size {} is not divisible by 256", page_size);
    assert(false);
  }
  if (page_size < 256) {
    SPDLOG_ERROR("page_size {} is smaller than 256", page_size);
    assert(false);
  }
}

std::string BatchQueryTodo::debug() {
  std::string re = "BatchQueryTodo: ";
  re += "QueryIDs: ";
  for (auto& id : query_ids) {
    re += std::to_string(id) + " ";
  }
  return re;
}

bool BatchQueryTodo::empty() {
  return prefill_mini_batches.empty() && decode_mini_batches.empty();
}

struct QueryMaintainer;
struct Query {
  QueryID id;
  torch::Tensor query_token;
  TokenLength prompt_length;
  TokenLength no_kvcache_from;
  TokenLength estimated_length;
  SampleOptions sample_options;
  UserID user_id;
  std::optional<int> SLO_TTFT_ms;
  std::optional<int> SLO_TBT_ms;
  std::vector<std::vector<int>> stop_criteria;

  // status
  // Query status changes in this order
  enum Status { Received, Preparing, Ready, Prefill, Decode, Done };
  Status plan_status = Received;
  TokenLength active_position;  // the position where there is no kvcache yet
  TokenLength plan_position;    // the position where there is no kvcache yet, according to the plan
  size_t prepare_try_count = 0;

  std::shared_ptr<kvc2::DoubleCacheHandleInterface> kvc2_handle = nullptr;

  // derived from kvc2_handle
  torch::Tensor block_index;  // block indexes

  struct QueryContext {
    ModelName model_name;
    QuantType quant_type;
    kvc2::KVC2Interface* kvc2_interface;
    QueryMaintainer* query_maintainer;
    Metrics* met;
  } ctx;

  void after_load(bool ok);
  void to_status(Status to);

  void export_metrics() { ctx.met->query_count(status_to_string(plan_status))->Increment(1); }

  Query(QueryID id, QueryAdd query_add, QueryContext context)
      : id(id),
        prompt_length(query_add.query_length),
        no_kvcache_from(0),
        estimated_length(query_add.estimated_length),
        sample_options(query_add.sample_options),
        user_id(query_add.user_id),
        SLO_TTFT_ms(query_add.SLO_TTFT_ms),
        SLO_TBT_ms(query_add.SLO_TBT_ms),
        stop_criteria(query_add.stop_criteria),
        ctx(context) {
    std::vector<int64_t> shape = {int64_t(query_add.estimated_length)};
    query_token = torch::zeros(shape, torch::TensorOptions().dtype(torch::kInt32));
    assert(query_token.is_contiguous());
    if (query_token.is_contiguous() == false) {
      SPDLOG_ERROR("Query Token must be contiguous!");
      exit(1);
    }
    memcpy(query_token.data_ptr(), query_add.query_token.data(), query_add.query_length * sizeof(Token));
    no_kvcache_from = 0;  // maybe match prefix later
    export_metrics();
  }

  Token& token_at(size_t idx) { return reinterpret_cast<Token*>(query_token.data_ptr())[idx]; }

  void absorb_update(const QueryUpdate& update) {
    SPDLOG_DEBUG("{}", update.debug());
    active_position = update.active_position;
    kvc2_handle->append_tokens(&token_at(0), active_position);
    // active_position is length - 1
    if (update.is_prefill) {
      if (active_position == prompt_length) {
        token_at(active_position) = update.generated_token;
        ctx.met->generated_tokens->Increment(1);
      }
    } else {
      token_at(active_position) = update.generated_token;
      ctx.met->generated_tokens->Increment(1);
    }

    if (update.decode_done || active_position == estimated_length - 1) {
      to_status(Done);
    }
  }

  void absorb_prefill_task(const PrefillTask& task) {
    auto& [id, start, length] = task;
    this->plan_position = start + length;
    if (this->plan_position == prompt_length) {
      to_status(Decode);
    }
  }

  void absorb_decode_task([[maybe_unused]] const QueryID& task) { this->plan_position += 1; }

  PrefillTask get_prefill_task(size_t prefill_length) {
    if (prefill_length + plan_position > prompt_length) {
      prefill_length = prompt_length - plan_position;
    }
    return {id, plan_position, prefill_length};
  }

  static std::string status_to_string(Status status) {
    switch (status) {
      case Received:
        return "Received";
      case Preparing:
        return "Preparing";
      case Ready:
        return "Ready";
      case Prefill:
        return "Prefill";
      case Decode:
        return "Decode";
      case Done:
        return "Done";
    }
    assert(false);
  }

  void debug() {
    std::string status_string = status_to_string(plan_status);
    SPDLOG_DEBUG(
        "Query {}, prompt_length {}, estimated_length {}, plan status {}, plan position {} "
        "active position {}",
        id, prompt_length, estimated_length, status_string, plan_position, active_position);
  }
};

std::string QueryUpdate::debug() const {
  return fmt::format("Query {}, ok {}, is_prefill {}, done {}, active_position {}, gen token {}", id, ok, is_prefill,
                     decode_done, active_position, generated_token);
}

using Q = std::shared_ptr<Query>;
struct KVC2_Maintainer {
  Settings settings;
  std::vector<torch::Tensor> k_cache;
  std::vector<torch::Tensor> v_cache;
  std::shared_ptr<kvc2::KVC2Interface> kvc2_interface;

  KVC2_Maintainer(Settings settings) : settings(settings) {
    // SPDLOG_WARN("Creating KVC2 Instance {}", settings.kvc2_root_path);
    assert(settings.kvc2_root_path.size() > 0);
    // SPDLOG_WARN("Sizeof KVC2Config {} upper", sizeof(kvc2::KVC2Config));
    kvc2::GPUPageCacheConfig gpu_cache_config{
        .gpu_only = settings.gpu_only,
        .gpu_devices_id = settings.gpu_device_id,
        .layer_count = settings.model_settings.layer_count,
        .total_kvcache_pages = settings.total_kvcache_pages.value(),
        .num_token_per_page = settings.page_size,
        .num_k_heads = settings.model_settings.num_k_heads,
        .k_head_dim = settings.use_self_defined_head_dim ? settings.self_defined_head_dim
                                                         : settings.model_settings.k_head_dim,
        .full_kv_cache_on_each_gpu = settings.full_kv_cache_on_each_gpu,
        .k_cache_on = settings.k_cache_on,
        .v_cache_on = settings.v_cache_on,
        .tensor_type = torch::kBFloat16,
    };
    auto model_configs_path = std::filesystem::path(settings.kvc2_config_path) / "model_configs.json";
    load_model_configs(model_configs_path);
    auto my_model_config = ModelConfig();
    my_model_config.load_from(std::filesystem::path(settings.model_settings.model_path) / "config.json");
    model_configs[settings.model_name] = my_model_config;
    dump_model_configs(model_configs_path);

    kvc2::KVC2Config kvc2_config = {
        .k_cache_on = settings.k_cache_on,
        .v_cache_on = settings.v_cache_on,
        .gpu_only = settings.gpu_only,
        .load_from_disk = settings.load_from_disk,
        .save_to_disk = settings.save_to_disk,
        .path = settings.kvc2_root_path,
        .config_path = settings.kvc2_config_path,
        .num_token_per_page = settings.page_size,
        .memory_pool_size = size_t(settings.memory_pool_size_GB * 1e9),
        .evict_count = settings.evict_count,
        .gpu_cache_config = gpu_cache_config,
        .metrics_port = settings.kvc2_metrics_port,
    };
    kvc2_interface = kvc2::create_kvc2(kvc2_config);
    if (settings.load_from_disk)
      kvc2_interface->load();
    SPDLOG_DEBUG("KVC2 created ok");
    auto [k_cache, v_cache] = kvc2_interface->get_kvcache();
    this->k_cache = k_cache;
    this->v_cache = v_cache;
  }
};
using EventAddQuery = std::pair<QueryAdd, std::promise<QueryID>*>;
using EventUpdateQuery = BatchQueryUpdate;
using EventTakenBatch = std::shared_ptr<BatchQueryTodo>;
struct EventPrepare {
  QueryID query_id;
  bool first_try;
};
struct EventPrepared {
  QueryID query_id;
  bool ok;
};
struct EventQueryStatus {
  QueryID query_id;
  Query::Status now_status;
};
struct EventSchedule {};

using Event = std::variant<EventAddQuery, EventUpdateQuery, EventTakenBatch, EventPrepare, EventPrepared,
                           EventQueryStatus, EventSchedule>;

template <typename T>
std::string event_name(const T& event);

template <>
std::string event_name(const EventAddQuery&) {
  return "EventAddQuery";
}
template <>
std::string event_name(const EventUpdateQuery&) {
  return "EventUpdateQuery";
}
template <>
std::string event_name(const EventTakenBatch&) {
  return "EventTakenBatch";
}
template <>
std::string event_name(const EventPrepare&) {
  return "EventPrepare";
}
template <>
std::string event_name(const EventPrepared&) {
  return "EventPrepared";
}
template <>
std::string event_name(const EventQueryStatus&) {
  return "EventQueryStatus";
}
template <>
std::string event_name(const EventSchedule&) {
  return "EventSchedule";
}

// event_name for the variant, implemented with std::visit
std::string event_name(const Event& event) {
  return std::visit([](const auto& e) { return event_name(e); }, event);
}

static_assert(std::is_copy_constructible<Event>::value);
static_assert(std::is_move_constructible<Event>::value);
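A side note on dispatch style: the run() loop below visits this variant with an if constexpr chain over std::decay_t<decltype(event)>. An equivalent and slightly more compact alternative is the classic overloaded-lambda visitor, sketched here with made-up event types:

#include <string>
#include <variant>

struct EventA {};
struct EventB {};
using Ev = std::variant<EventA, EventB>;

// Classic overload helper: one call operator per alternative (C++17).
template <class... Ts> struct overloaded : Ts... { using Ts::operator()...; };
template <class... Ts> overloaded(Ts...) -> overloaded<Ts...>;

std::string handle(const Ev& e) {
  return std::visit(overloaded{
                        [](const EventA&) { return std::string("A"); },
                        [](const EventB&) { return std::string("B"); },
                    },
                    e);
}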
struct QueryMaintainer : public Scheduler {
  // accessed only by the event loop
  Settings settings;
  QueryID query_id_counter = NoQueryID + 1;
  std::map<QueryID, Q> query_map;
  std::shared_ptr<KVC2_Maintainer> kvc2_maintainer;
  std::shared_ptr<Metrics> met;

  // accessed by multiple threads
  std::atomic_bool stop_flag = false;
  // TODO consider correctness of event loop
  MPSCQueueConsumerLock<Event> event_loop_queue;
  // std::binary_semaphore batch_ready{0};
  AtomicPtrWithFlag<BatchQueryTodo> next_batch;

  QueryMaintainer() = default;

  void gen_batch_query_todo(BatchQueryTodo* re, const std::set<Q>& queries) {
    std::vector<std::vector<QueryID>> d_batch(2);
    size_t last_decode_batch = 0;
    size_t prefill_num = 0;
    size_t decode_num = 0;
    size_t preill_length = 0;
    for (auto& q : queries) {
      if (q->plan_status == Query::Prefill) {
        prefill_num += 1;
      }
      if (q->plan_status == Query::Decode) {
        decode_num += 1;
      }
    }
    if (prefill_num >= 2 || (prefill_num == 1 && settings.max_batch_size - 2 < decode_num)) {
      preill_length = settings.recommended_chunk_prefill_token_count;
    } else {
      preill_length = settings.recommended_chunk_prefill_token_count * 2;
    }

    for (auto& q : queries) {
      re->query_ids.push_back(q->id);
      re->query_tokens.push_back(q->query_token);
      re->query_lengths.push_back(q->prompt_length);
      if (q->plan_status == Query::Prefill) {
        re->prefill_mini_batches.push_back(q->get_prefill_task(preill_length));
        assert(re->prefill_mini_batches.size() <= 2);
      }
      if (q->plan_status == Query::Decode) {
        d_batch[last_decode_batch].push_back(q->id);
        // last_decode_batch = 1 - last_decode_batch;
        if (d_batch[last_decode_batch].size() == settings.max_batch_size - 1) {
          last_decode_batch += 1;
          assert(last_decode_batch < 2);
        }
      }
      re->block_indexes.push_back(q->block_index);
      re->sample_options.push_back(q->sample_options);
      re->stop_criteria.push_back(q->stop_criteria);
    }
    re->attn_masks = std::nullopt;
    re->rope_ranges = std::nullopt;
    for (auto& b : d_batch) {
      if (b.empty())
        continue;
      re->decode_mini_batches.push_back(b);
    }
    met->batch_count("Generated")->Increment(1);
  }

  // Interface
  void init(Settings settings) override {
    SPDLOG_INFO(
        "\nScheduler Settings:\n"
        "  model_name: {}\n"
        "  quant_type: {}\n"
        "  model_path: {}\n"
        "  params_count: {}\n"
        "  layer_count: {}\n"
        "  num_k_heads: {}\n"
        "  k_head_dim: {}\n"
        "  bytes_per_params: {}\n"
        "  bytes_per_kv_cache_element: {}\n"
        "  page_size: {}\n"
        "  gpu_device_id: {}\n"
        "  gpu_memory_size: {}\n"
        "  memory_utilization_percentage: {}\n"
        "  max_batch_size: {}\n"
        "  recommended_chunk_prefill_token_count: {}\n"
        "  sched_metrics_port: {}\n"
        "  kvc2_config_path: {}\n"
        "  kvc2_root_path: {}\n"
        "  memory_pool_size_GB: {}\n"
        "  evict_count: {}\n"
        "  kvc2_metrics_port: {}\n"
        "  load_from_disk: {}\n"
        "  save_to_disk: {}\n"
        "  strategy_name: {}\n"
        "  gpu_device_count: {}\n",
        settings.model_name, settings.quant_type, settings.model_settings.model_path,
        settings.model_settings.params_count, settings.model_settings.layer_count,
        settings.model_settings.num_k_heads, settings.model_settings.k_head_dim,
        settings.model_settings.bytes_per_params, settings.model_settings.bytes_per_kv_cache_element,
        settings.page_size, format_vector(settings.gpu_device_id), readable_number(settings.gpu_memory_size),
        settings.memory_utilization_percentage, settings.max_batch_size,
        settings.recommended_chunk_prefill_token_count, settings.sched_metrics_port, settings.kvc2_config_path,
        settings.kvc2_root_path, settings.memory_pool_size_GB, settings.evict_count, settings.kvc2_metrics_port,
        settings.load_from_disk, settings.save_to_disk, settings.strategy_name, settings.gpu_device_count);

    this->settings = settings;
    kvc2_maintainer = std::shared_ptr<KVC2_Maintainer>(new KVC2_Maintainer(settings));
    MetricsConfig met_conf = {
        .endpoint = "0.0.0.0:" + std::to_string(settings.sched_metrics_port),
        .model_name = settings.model_name,
        .gpu_count = settings.gpu_device_count,
    };
    SPDLOG_INFO("Creating scheduler metrics exporter on {}", met_conf.endpoint);
    met = std::make_shared<Metrics>(met_conf);
    met->fn_every_sec = [](Metrics* met) {
      auto generated_tokens = met->generated_tokens->Collect().counter.value;
      SPDLOG_INFO("Last Sec Generated Tokens {}", generated_tokens);
    };
  }

  Query::QueryContext get_query_context() {
    return Query::QueryContext{
        .model_name = settings.model_name,
        .quant_type = settings.quant_type,
        .kvc2_interface = kvc2_maintainer->kvc2_interface.get(),
        .query_maintainer = this,
        .met = met.get(),
    };
  }

  QueryID add_query(QueryAdd query_add) override {
    std::promise<QueryID> p;
    event_loop_queue.enqueue(EventAddQuery(query_add, &p));
    return p.get_future().get();
  }

  void cancel_query(QueryID id) override {
    SPDLOG_INFO("Cancel Query");
    SPDLOG_INFO("sched:{} Cancel Query", fmt::ptr(this));
    auto it = query_map.find(id);
    if (it == query_map.end()) {
      SPDLOG_ERROR("Query {} is not found", id);
      return;
    }
    query_map.erase(it);
  }

  // Updates the last batch's results and gets the next batch.
  // In most cases the next batch is already ready; if not, busy-wait for it.
  std::shared_ptr<BatchQueryTodo> update_last_batch(BatchQueryUpdate updates) override {
    event_loop_queue.enqueue(updates);
    // Busy Wait
    while (true) {
      auto [ptr, is_new] = next_batch.touch_load();
      // SPDLOG_INFO("ptr {} is_new {}", fmt::ptr(ptr), is_new);
      if (is_new) {
        // SPDLOG_DEBUG("New Batch {}", fmt::ptr(ptr));
        auto re = std::shared_ptr<BatchQueryTodo>(ptr);
        event_loop_queue.enqueue(re);
        return re;
      } else {
        // // here to busy wait
        // SPDLOG_INFO("Not New");
        // using namespace std::chrono_literals;
        // std::this_thread::sleep_for(1s);
      }
    }
  }

  InferenceContext get_inference_context() override {
    InferenceContext re;
    re.k_cache = kvc2_maintainer->k_cache;
    re.v_cache = kvc2_maintainer->v_cache;
    // kvc2_maintainer->k_cache[0][0][0][0][0][0] = 42;  // test whether we pass this to inference loop
    return re;
  }

  virtual void strategy_add_query(Q new_query) = 0;
  virtual void strategy_update_query(const EventUpdateQuery& update) = 0;
  virtual void strategy_taken_batch(const EventTakenBatch& batch) = 0;
  virtual void strategy_prepare(const EventPrepare& prepare) = 0;
  virtual void strategy_prepared(const EventPrepared& prepared) = 0;
  virtual void strategy_query_status(const EventQueryStatus& query_status) = 0;
  virtual void strategy_schedule(const EventSchedule& event, BatchQueryTodo* new_batch) = 0;

  void tackle_event(EventAddQuery& event) {
    auto& query_add = event.first;
    QueryID id = query_id_counter;
    event.second->set_value(id);
    query_id_counter += 1;
    Q new_query(new Query(id, query_add, get_query_context()));
    query_map[id] = new_query;
    SPDLOG_INFO("New Query {} is added", id);
    strategy_add_query(new_query);
  }

  void tackle_event(const EventUpdateQuery& update) {
    // SPDLOG_INFO("Tackle Update Query");
    for (auto& u : update) {
      if (u.ok == false) {
        SPDLOG_ERROR("Query {} is not executed OK", u.id);
        exit(1);
      }
      auto q = query_map[u.id];
      if (q->plan_status == Query::Status::Prefill || q->plan_status == Query::Status::Decode) {
        q->absorb_update(u);
      } else {
        SPDLOG_DEBUG("Query {} is not in Prefill or Decode status, do not update it", u.id);
      }
    }
    strategy_update_query(update);
  }

  void tackle_event(const EventTakenBatch& batch) {
    met->batch_count("Taken")->Increment(1);
    for (auto& task : batch->prefill_mini_batches) {
      auto [id, s, l] = task;
      if (l == 0)
        continue;
      query_map.at(id)->absorb_prefill_task(task);
    }
    for (auto& mini_batch : batch->decode_mini_batches) {
      for (auto& id : mini_batch) {
        query_map.at(id)->absorb_decode_task(id);
      }
    }
    strategy_taken_batch(batch);
  }

  void tackle_event(const EventPrepare& event) { strategy_prepare(event); }

  void tackle_event(const EventPrepared& event) { strategy_prepared(event); }

  void tackle_event(const EventQueryStatus& event) { strategy_query_status(event); }

  void tackle_event(const EventSchedule& event) {
    // SPDLOG_INFO("Tackle Schedule Event");
    HistogramTimerWrapper t(met->schedule_time);
    BatchQueryTodo* new_batch = new BatchQueryTodo;
    strategy_schedule(event, new_batch);
    // if (new_batch->query_ids.empty()) {
    //   SPDLOG_INFO("Nothing todo");
    //   delete new_batch;
    //   return;
    // }
    auto [old_batch, flag] = next_batch.exchange(new_batch, true);
    if (new_batch->empty() == false) {
      SPDLOG_DEBUG("set new batch {}", fmt::ptr(new_batch));
    }
    if (flag) {
      SPDLOG_INFO("Batch {} is not consumed", fmt::ptr(old_batch));
      delete old_batch;
    }
  }

  void run() override {
    std::thread([this]() {
      SPDLOG_WARN("Starting Scheduler Event Loop");
      while (stop_flag.load() == false) {
        auto event = event_loop_queue.dequeue();
        met->event_count(event_name(event))->Increment(1);
        std::visit(
            [this](auto event) {
              using T = std::decay_t<decltype(event)>;
              // SPDLOG_INFO("Event Loop: {}", typeid(T).name());
              if constexpr (std::is_same_v<T, EventAddQuery>) {
                tackle_event(event);
              } else if constexpr (std::is_same_v<T, EventUpdateQuery>) {
                tackle_event(event);
              } else if constexpr (std::is_same_v<T, EventTakenBatch>) {
                tackle_event(event);
              } else if constexpr (std::is_same_v<T, EventPrepare>) {
                tackle_event(event);
              } else if constexpr (std::is_same_v<T, EventPrepared>) {
                tackle_event(event);
              } else if constexpr (std::is_same_v<T, EventQueryStatus>) {
                tackle_event(event);
              } else if constexpr (std::is_same_v<T, EventSchedule>) {
                tackle_event(event);
              } else {
                SPDLOG_ERROR("Should not be here");
                assert(false);
              }
            },
            event);
        if (event_loop_queue.size() == 0 && std::holds_alternative<EventSchedule>(event) == false) {
          // if this is not a schedule event, we need to schedule one
          event_loop_queue.enqueue(EventSchedule());
        }
      }
    }).detach();
  }

  void stop() override { stop_flag.store(true); }

  ~QueryMaintainer() {
    kvc2_maintainer->kvc2_interface->save();
    stop();
  }
};
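The next_batch handoff above relies on AtomicPtrWithFlag from utils/atomic_ptr_with_flags.hpp (added by this commit but not shown in this view): tackle_event(EventSchedule) publishes with exchange(new_batch, true), and update_last_batch spins on touch_load() until the flag reports a new pointer, so each batch is handed out exactly once. A mutex-based sketch of the semantics those two calls appear to have; the real 28-line header is presumably lock-free and may differ:

#include <mutex>
#include <utility>

template <typename T>
class AtomicPtrWithFlagSketch {
 public:
  // Publish a new pointer; returns the previous pointer and whether it was
  // still unconsumed (so the caller can free it, as tackle_event does).
  std::pair<T*, bool> exchange(T* p, bool is_new) {
    std::lock_guard<std::mutex> g(m_);
    auto old = std::make_pair(ptr_, new_flag_);
    ptr_ = p;
    new_flag_ = is_new;
    return old;
  }
  // Consume: returns the pointer and whether it is new since the last load,
  // clearing the flag so the same batch is taken only once.
  std::pair<T*, bool> touch_load() {
    std::lock_guard<std::mutex> g(m_);
    auto re = std::make_pair(ptr_, new_flag_);
    new_flag_ = false;
    return re;
  }

 private:
  std::mutex m_;
  T* ptr_ = nullptr;
  bool new_flag_ = false;
};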
void Query::to_status(Status to) {
  SPDLOG_DEBUG("Calling to status query {}, to {}", id, status_to_string(to));
  switch (to) {
    case Received:
      assert(false);
      break;
    case Preparing:
      SPDLOG_INFO("Preparing Query {} {}", id,
                  prepare_try_count > 0 ? (std::to_string(prepare_try_count) + " Try") : "");
      prepare_try_count += 1;
      ctx.kvc2_interface->lookup_to_gpu_async(
          ctx.model_name, ctx.quant_type, static_cast<kvc2::Token*>(query_token.data_ptr()), prompt_length,
          estimated_length, [this](std::shared_ptr<kvc2::DoubleCacheHandleInterface> handle) {
            if (handle == nullptr) {
              SPDLOG_INFO("Get handle from kvc2 Failed.");
              this->after_load(false);
            } else {
              SPDLOG_INFO("Get handle from kvc2 Success.");
              this->kvc2_handle = handle;
              this->to_status(Ready);
              this->after_load(true);
            }
          });
      break;
    case Ready:
      SPDLOG_INFO("Ready Query {}", id);
      break;
    case Prefill:
      SPDLOG_INFO("Prefilling Query {}", id);
      // assert(plan_status == Received);
      plan_position = kvc2_handle->matched_length();
      if (prompt_length - plan_position == 0) {
        assert(prompt_length > 0);
        plan_position -= 1;
      }
      break;
    case Decode:
      SPDLOG_INFO("Decoding Query {}", id);
      // assert(plan_status == Prefill);
      break;
    case Done:
      SPDLOG_INFO("Finish Query {}", id);
      kvc2_handle = nullptr;
      ctx.query_maintainer->event_loop_queue.enqueue(EventQueryStatus{
          .query_id = id,
          .now_status = to,
      });
      // assert(plan_status == Decode);
      break;
  }
  plan_status = to;
  export_metrics();
}

void Query::after_load(bool ok) {
  if (ok) {
    size_t page_count = div_up(estimated_length, ctx.query_maintainer->settings.page_size);
    std::vector<int64_t> shape;
    shape.push_back(page_count);
    block_index = torch::zeros(shape, torch::TensorOptions().dtype(torch::kInt32)).contiguous();
    auto ptr = reinterpret_cast<int32_t*>(block_index.data_ptr());
    auto vec_idx = kvc2_handle->get_gpu_block_idx();
    for (size_t i = 0; i < vec_idx.size(); i++) {
      ptr[i] = vec_idx[i];
    }
    no_kvcache_from = kvc2_handle->matched_length();
  }

  if (ok) {
    ctx.query_maintainer->event_loop_queue.enqueue(EventPrepared{
        .query_id = id,
        .ok = ok,
    });
  } else {
    ctx.query_maintainer->event_loop_queue.enqueue(EventPrepare{
        .query_id = id,
        .first_try = false,
    });
  }
}
struct FCFS_single_prefill : public QueryMaintainer {
  std::queue<Q> queue;
  std::queue<Q> ready_queue;
  bool has_query_preparing = false;
  std::optional<EventPrepare> wait_done_prepare = std::nullopt;
  std::set<Q> active_query;  // ongoing queries for the LLM

  // Interface: all of these run on a single thread.
  void strategy_add_query(Q new_query) override {
    queue.push(new_query);
    if (has_query_preparing == false) {
      has_query_preparing = true;
      auto next_q = queue.front();
      queue.pop();
      event_loop_queue.enqueue(EventPrepare{next_q->id, true});
    }
  }

  void strategy_update_query(const EventUpdateQuery& update) override {
    for (auto u : update) {
      auto& q = query_map[u.id];
      if (q->plan_status == Query::Done) {
        active_query.erase(q);
      }
    }
  }

  void strategy_taken_batch(const EventTakenBatch& batch) override {
    for (auto& q : batch->query_ids) {
      if (query_map[q]->plan_status != Query::Done) {
        active_query.insert(query_map[q]);
      }
    }
  }

  void strategy_prepare(const EventPrepare& prepare) override {
    if (prepare.first_try) {
      auto& q = query_map[prepare.query_id];
      q->to_status(Query::Preparing);
    } else {
      assert(wait_done_prepare.has_value() == false);
      wait_done_prepare = prepare;
      wait_done_prepare->first_try = true;
    }
  }

  void strategy_prepared(const EventPrepared& prepared) override {
    assert(prepared.ok);
    ready_queue.push(query_map[prepared.query_id]);
    if (queue.empty() == false) {
      auto next_q_prepare = queue.front();
      queue.pop();
      event_loop_queue.enqueue(EventPrepare{next_q_prepare->id, true});
    } else {
      has_query_preparing = false;
    }
  }

  void strategy_query_status(const EventQueryStatus& query_status) override {
    if (query_status.now_status == Query::Done) {
      if (wait_done_prepare.has_value()) {
        event_loop_queue.enqueue(wait_done_prepare.value());
        wait_done_prepare = std::nullopt;
      }
    }
  }

  void strategy_schedule([[maybe_unused]] const EventSchedule& event, BatchQueryTodo* new_batch) override {
    bool have_prefill = false;
    for (auto& q : active_query) {
      if (q->plan_status == Query::Prefill) {
        have_prefill = true;
      }
    }
    if (have_prefill == false && ready_queue.empty() == false && active_query.size() < settings.max_batch_size) {
      auto next_q = ready_queue.front();  // copy before pop(); a reference would dangle
      ready_queue.pop();
      SPDLOG_INFO("Active query {}", next_q->id);
      active_query.insert(next_q);
      next_q->to_status(Query::Prefill);
    }
    if (active_query.empty() == false) {
      SPDLOG_INFO("Active Query Size {}", active_query.size());
    }
    for (auto& q : active_query) {
      q->debug();
    }
    gen_batch_query_todo(new_batch, active_query);
  }
};
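// FCFS relaxes the single-prefill restriction: it admits ready queries until
// max_prefill_count (2) prefills are active, matching the <=2 mini-batch limit
// imposed by FlashInfer attention (see BatchQueryTodo in scheduler.h).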
struct FCFS : public FCFS_single_prefill {
  void strategy_schedule([[maybe_unused]] const EventSchedule& event, BatchQueryTodo* new_batch) override {
    int prefill_count = 0;
    const int max_prefill_count = 2;
    for (auto& q : active_query) {
      if (q->plan_status == Query::Prefill) {
        prefill_count += 1;
      }
    }
    while (prefill_count < max_prefill_count && ready_queue.empty() == false &&
           active_query.size() < settings.max_batch_size) {
      auto next_q = ready_queue.front();
      ready_queue.pop();
      SPDLOG_INFO("Active query {}", next_q->id);
      active_query.insert(next_q);
      next_q->to_status(Query::Prefill);
      prefill_count += 1;
    }
    if (active_query.empty() == false) {
      SPDLOG_DEBUG("Active Query Size {}", active_query.size());
    }
    for (auto& q : active_query) {
      q->debug();
    }
    gen_batch_query_todo(new_batch, active_query);
  }
};
std::shared_ptr<Scheduler> create_scheduler(Settings settings) {
  spdlog::set_level(spdlog::level::debug);
  std::shared_ptr<Scheduler> re;
  SPDLOG_INFO("Using Strategy {}", settings.strategy_name);
  if (settings.strategy_name == "FCFS-single-prefill") {
    re = std::shared_ptr<Scheduler>(new FCFS_single_prefill());
  } else if (settings.strategy_name == "FCFS") {
    re = std::shared_ptr<Scheduler>(new FCFS());
  } else {
    SPDLOG_ERROR("Unknown strategy {}", settings.strategy_name);
    return nullptr;  // avoid dereferencing a null scheduler below
  }
  re->init(settings);
  return re;
}
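A minimal caller-side sketch (illustrative, not part of this commit; the field values are assumptions) of how a scheduler is created from these settings:

// Illustrative only: pick a strategy by name and start the scheduler.
scheduler::Settings settings;
settings.strategy_name = "FCFS";  // or "FCFS-single-prefill"
settings.page_size = 256;         // tokens per KV-cache page
settings.max_batch_size = 256;
auto sched = scheduler::create_scheduler(settings);
if (sched) {
  sched->run();
}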
NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(SampleOptions, temperature, top_p);
NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(QueryAdd, query_token, query_length, estimated_length, sample_options, user_id,
                                   SLO_TTFT_ms, SLO_TBT_ms);

std::string QueryAdd::serialize() {
  json j = *this;
  return j.dump();
}

QueryAdd QueryAdd::deserialize(const std::string& input) {
  json j = json::parse(input);
  return j.get<QueryAdd>();
}

};  // namespace scheduler
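A round-trip sketch of the JSON (de)serialization above (illustrative):

// Illustrative: QueryAdd round-trips through nlohmann::json.
scheduler::QueryAdd q;
q.query_token = {1, 2, 3};
q.query_length = 3;
q.estimated_length = 128;
std::string wire = q.serialize();
auto back = scheduler::QueryAdd::deserialize(wire);
// back.query_length == 3; note stop_criteria is not part of the JSON mapping above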
csrc/balance_serve/sched/scheduler.h
0 → 100644
View file @
25cee581
#pragma once
#include <torch/torch.h>
#include <cstdint>
#include <memory>
#include <optional>
#include <vector>
#include "model_config.h"
namespace scheduler {

using Token = uint32_t;
using QueryID = uint64_t;
constexpr QueryID NoQueryID = 0;
using TokenLength = size_t;
using BatchID = uint64_t;
using PageCount = size_t;

struct ModelSettings {
  std::string model_path;
  size_t params_count;
  size_t layer_count;
  size_t num_k_heads;
  size_t k_head_dim;
  double bytes_per_params;
  double bytes_per_kv_cache_element;

  inline size_t params_nbytes() { return params_count * bytes_per_params; }
  inline size_t bytes_per_token_kv_cache() { return bytes_per_kv_cache_element * num_k_heads * k_head_dim; }
};

struct SampleOptions {
  double temperature = 1.0;
  double top_p = 1.0;
};
struct Settings {
  // Something is awkward here: kvc2 only uses model_name and quant_type to get model info.
  ModelName model_name;
  QuantType quant_type;
  // model_settings is ignored by kvc2
  ModelSettings model_settings;

  size_t page_size = 256;  // how many tokens fit in a page
  std::vector<size_t> gpu_device_id;
  size_t gpu_memory_size;  // memory size in bytes of each GPU
  double memory_utilization_percentage;
  size_t max_batch_size = 256;
  size_t recommended_chunk_prefill_token_count;
  SampleOptions sample_options;
  size_t sched_metrics_port;

  // for kvc2
  bool gpu_only;
  bool use_self_defined_head_dim = false;
  size_t self_defined_head_dim;
  bool full_kv_cache_on_each_gpu = false;
  bool k_cache_on = true;
  bool v_cache_on = true;
  std::string kvc2_config_path;
  std::string kvc2_root_path;
  double memory_pool_size_GB = 100;
  size_t evict_count = 20;
  size_t kvc2_metrics_port;
  bool load_from_disk = false;
  bool save_to_disk = false;

  // for strategy
  std::string strategy_name;

  // derived
  size_t gpu_device_count;
  std::optional<size_t> total_kvcache_pages;
  std::vector<torch::Device> devices;

  void auto_derive();
};
using PrefillTask = std::tuple<QueryID, TokenLength, TokenLength>;  // id, start, length

struct BatchQueryTodo {
  // query
  std::vector<QueryID> query_ids;
  std::vector<torch::Tensor> query_tokens;
  std::vector<TokenLength> query_lengths;
  std::vector<torch::Tensor> block_indexes;  // (max_num_blocks_per_seq), dtype torch.int32
  std::optional<torch::Tensor> attn_masks;
  std::optional<torch::Tensor> rope_ranges;
  std::vector<SampleOptions> sample_options;
  std::vector<std::vector<std::vector<int>>> stop_criteria;

  // Mini batches; two adjacent mini batches are executed together.
  // The task count must be <= 2 because of FlashInfer attention.
  std::vector<PrefillTask> prefill_mini_batches;          // a prefill mini batch has exactly 1 prefill
  std::vector<std::vector<QueryID>> decode_mini_batches;  // a decode mini batch has multiple decodes

  std::string debug();
  bool empty();
};

struct QueryUpdate {
  QueryID id;
  bool ok;
  bool is_prefill;
  bool decode_done;             // not used for now
  TokenLength active_position;  // the position with no kvcache yet: kvcache[active_position] == None
  Token generated_token;

  std::string debug() const;
};

using BatchQueryUpdate = std::vector<QueryUpdate>;
struct InferenceContext {
  std::vector<torch::Tensor> k_cache;  // [gpu num] (layer_count, num blocks, page size, kheadnum, head_dim)
  std::vector<torch::Tensor> v_cache;
};

using UserID = int64_t;
constexpr UserID NoUser = -1;
const int MAX_SLO_TIME = 1e9;

struct QueryAdd {
  std::vector<Token> query_token;  // int here
  // torch::Tensor attn_mask;
  TokenLength query_length;
  TokenLength estimated_length;
  std::vector<std::vector<int>> stop_criteria;
  SampleOptions sample_options;
  UserID user_id;
  int SLO_TTFT_ms = MAX_SLO_TIME;
  int SLO_TBT_ms = MAX_SLO_TIME;

  std::string serialize();
  static QueryAdd deserialize(const std::string& input);
};
class Scheduler {
 public:
  virtual void init(Settings settings) = 0;
  virtual void run() = 0;
  virtual void stop() = 0;

  // the webserver calls these
  virtual QueryID add_query(QueryAdd query) = 0;
  virtual void cancel_query(QueryID id) = 0;

  // the inference loop calls these
  virtual std::shared_ptr<BatchQueryTodo> update_last_batch(BatchQueryUpdate updates) = 0;
  virtual InferenceContext get_inference_context() = 0;

  virtual ~Scheduler() = default;
};

std::shared_ptr<Scheduler> create_scheduler(Settings settings);

};  // namespace scheduler
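A sketch of the ping-pong contract implied by update_last_batch (illustrative; `sched` is the pointer returned by create_scheduler, while `running` and `run_model` are placeholders for the real engine loop):

// Illustrative inference loop: report the last batch's results, receive the next batch.
scheduler::BatchQueryUpdate updates;  // empty on the first call
while (running) {
  std::shared_ptr<scheduler::BatchQueryTodo> todo = sched->update_last_batch(updates);
  if (!todo || todo->empty()) continue;  // nothing scheduled yet
  updates = run_model(*todo);            // hypothetical: execute prefill/decode mini batches
}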
csrc/balance_serve/sched/utils/all.hpp
0 → 100644
View file @
25cee581
#pragma once
#include "readable_number.hpp"
#include "timer.hpp"
csrc/balance_serve/sched/utils/arithmetic.hpp
0 → 100644
View file @
25cee581
#pragma once
#include <type_traits>

template <typename T, typename U>
T div_up(T x, U by) {
  static_assert(std::is_integral_v<T>);
  static_assert(std::is_integral_v<U>);
  return (x + by - 1) / by;
}
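For instance, the scheduler uses div_up to turn an estimated token length into a KV-cache page count (illustrative):

size_t pages = div_up(size_t(1000), size_t(256));  // == 4 pages for 1000 tokens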
csrc/balance_serve/sched/utils/atomic_ptr_with_flags.hpp
0 → 100644
View file @
25cee581
#pragma once
#include <atomic>
#include <cstdint>
#include <utility>

template <typename T>
struct AtomicPtrWithFlag {
  // The flag is packed into the (unused) top bit of the 64-bit pointer value.
  constexpr static uint64_t mask = 1ull << 63;
  std::atomic_uint64_t ptr = 0;

  std::pair<T*, bool> load(std::memory_order order = std::memory_order_seq_cst) {
    uint64_t val = ptr.load(order);
    return {reinterpret_cast<T*>(val & (~mask)), val & mask};
  }

  void store(T* p, bool flag, std::memory_order order = std::memory_order_seq_cst) {
    ptr.store(reinterpret_cast<uint64_t>(p) | (flag ? mask : 0), order);
  }

  std::pair<T*, bool> exchange(T* p, bool flag, std::memory_order order = std::memory_order_seq_cst) {
    uint64_t val = ptr.exchange(reinterpret_cast<uint64_t>(p) | (flag ? mask : 0), order);
    return {reinterpret_cast<T*>(val & (~mask)), val & mask};
  }

  // Atomically clear the flag and return the previous (pointer, flag) pair.
  std::pair<T*, bool> touch_load(std::memory_order order = std::memory_order_seq_cst) {
    uint64_t val = ptr.fetch_and(~mask, order);
    return {reinterpret_cast<T*>(val & (~mask)), val & mask};
  }

  bool check_flag(std::memory_order order = std::memory_order_seq_cst) { return ptr.load(order) & mask; }
};
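A usage sketch (illustrative): the pointer and the flag travel in one atomic word, so both can be read, swapped, or cleared in a single operation:

AtomicPtrWithFlag<int> slot;
int value = 42;
slot.store(&value, /*flag=*/true);      // publish pointer and flag together
auto [p, flagged] = slot.touch_load();  // read both and atomically clear the flag
// p == &value, flagged == true; a second touch_load would report flagged == false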
csrc/balance_serve/sched/utils/csv.hpp
0 → 100644
View file @
25cee581
#ifndef CSV_READER_HPP
#define CSV_READER_HPP

#include <fstream>
#include <iostream>
#include <mutex>
#include <sstream>
#include <stdexcept>
#include <string>
#include <thread>
#include <vector>

namespace csv {

/**
 * @brief Parses a CSV line into individual fields, handling quoted fields with
 * commas and newlines.
 *
 * @param line The CSV line to parse.
 * @return A vector of strings, each representing a field in the CSV line.
 */
inline std::vector<std::string> parse_csv_line(const std::string& line) {
  std::vector<std::string> result;
  std::string field;
  bool in_quotes = false;

  for (size_t i = 0; i < line.length(); ++i) {
    char c = line[i];
    if (c == '"') {
      // Handle double quotes inside quoted fields
      if (in_quotes && i + 1 < line.length() && line[i + 1] == '"') {
        field += '"';
        ++i;
      } else {
        in_quotes = !in_quotes;
      }
    } else if (c == ',' && !in_quotes) {
      result.push_back(field);
      field.clear();
    } else {
      field += c;
    }
  }
  result.push_back(field);
  return result;
}

/**
 * @brief Reads a CSV file and returns a vector of pairs containing column names
 * and their corresponding data vectors.
 *
 * This function reads the header to obtain column names and uses multithreading
 * to read and parse the CSV file in chunks.
 *
 * @param filename The path to the CSV file.
 * @return A vector of pairs, each containing a column name and a vector of data
 * for that column.
 */
inline std::vector<std::pair<std::string, std::vector<std::string>>> read_csv(const std::string& filename) {
  std::cout << "Reading CSV file: " << filename << std::endl;

  // Open the file
  std::ifstream file(filename);
  if (!file) {
    throw std::runtime_error("Cannot open file");
  }

  // Read the header line and parse column names
  std::string header_line;
  std::getline(file, header_line);
  std::vector<std::string> column_names = parse_csv_line(header_line);

  // Prepare the result vector with column names
  std::vector<std::pair<std::string, std::vector<std::string>>> result;
  for (const auto& name : column_names) {
    result.emplace_back(name, std::vector<std::string>());
  }

  // Read the rest of the file into a string buffer
  std::stringstream buffer;
  buffer << file.rdbuf();
  std::string content = buffer.str();

  // Determine the number of threads to use
  unsigned int num_threads = std::thread::hardware_concurrency();
  if (num_threads == 0)
    num_threads = 4;  // Default to 4 threads if hardware_concurrency returns 0

  // Calculate chunk start positions based on content size
  std::vector<size_t> chunk_starts;
  size_t content_size = content.size();
  size_t chunk_size = content_size / num_threads;
  chunk_starts.push_back(0);
  for (unsigned int i = 1; i < num_threads; ++i) {
    size_t pos = i * chunk_size;
    // Adjust position to the next newline character to ensure we start at the
    // beginning of a line
    while (pos < content_size && content[pos] != '\n') {
      ++pos;
    }
    if (pos < content_size) {
      ++pos;  // Skip the newline character
    }
    chunk_starts.push_back(pos);
  }
  chunk_starts.push_back(content_size);

  // Create threads to parse each chunk
  std::vector<std::vector<std::vector<std::string>>> thread_results(num_threads);
  std::vector<std::thread> threads;
  for (unsigned int i = 0; i < num_threads; ++i) {
    size_t start = chunk_starts[i];
    size_t end = chunk_starts[i + 1];
    threads.emplace_back([&content, start, end, &thread_results, i]() {
      std::vector<std::vector<std::string>> local_result;
      size_t pos = start;
      while (pos < end) {
        size_t next_pos = content.find('\n', pos);
        if (next_pos == std::string::npos || next_pos > end) {
          next_pos = end;
        }
        std::string line = content.substr(pos, next_pos - pos);
        if (!line.empty()) {
          local_result.push_back(parse_csv_line(line));
        }
        pos = next_pos + 1;
      }
      thread_results[i] = std::move(local_result);
    });
  }

  // Wait for all threads to finish
  for (auto& t : threads) {
    t.join();
  }

  // Combine the results from all threads into the final result
  for (const auto& local_result : thread_results) {
    for (const auto& row : local_result) {
      for (size_t i = 0; i < row.size(); ++i) {
        if (i < result.size()) {
          result[i].second.push_back(row[i]);
        }
      }
    }
  }

  return result;
}

/**
 * @brief Writes the CSV data into a file.
 *
 * @param filename The path to the output CSV file.
 * @param data A vector of pairs, each containing a column name and a vector of
 * data for that column.
 */
inline void write_csv(const std::string& filename,
                      const std::vector<std::pair<std::string, std::vector<std::string>>>& data) {
  std::cout << "Writing CSV file: " << filename << std::endl;

  // Open the file for writing
  std::ofstream file(filename);
  if (!file) {
    throw std::runtime_error("Cannot open file for writing");
  }

  // Check that all columns have the same number of rows
  if (data.empty()) {
    return;  // Nothing to write
  }
  size_t num_rows = data[0].second.size();
  for (const auto& column : data) {
    if (column.second.size() != num_rows) {
      throw std::runtime_error("All columns must have the same number of rows");
    }
  }

  // Write the header
  for (size_t i = 0; i < data.size(); ++i) {
    file << data[i].first;
    if (i != data.size() - 1) {
      file << ',';
    }
  }
  file << '\n';

  // Write the data rows
  for (size_t row = 0; row < num_rows; ++row) {
    for (size_t col = 0; col < data.size(); ++col) {
      const std::string& field = data[col].second[row];
      // Handle CSV escaping
      std::string escaped_field = field;
      bool needs_quotes = false;
      if (escaped_field.find('"') != std::string::npos) {
        needs_quotes = true;
        // Escape double quotes
        size_t pos = 0;
        while ((pos = escaped_field.find('"', pos)) != std::string::npos) {
          escaped_field.insert(pos, "\"");
          pos += 2;
        }
      }
      if (escaped_field.find(',') != std::string::npos || escaped_field.find('\n') != std::string::npos) {
        needs_quotes = true;
      }
      if (needs_quotes) {
        file << '"' << escaped_field << '"';
      } else {
        file << escaped_field;
      }
      if (col != data.size() - 1) {
        file << ',';
      }
    }
    file << '\n';
  }
}

}  // namespace csv

#endif  // CSV_READER_HPP
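A round-trip sketch (illustrative; the file name is arbitrary):

std::vector<std::pair<std::string, std::vector<std::string>>> table = {
    {"query_id", {"1", "2"}},
    {"note", {"plain", "has, comma"}},  // the comma forces quoting on write
};
csv::write_csv("metrics.csv", table);
auto loaded = csv::read_csv("metrics.csv");
// loaded[1].second[1] == "has, comma"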
csrc/balance_serve/sched/utils/easy_format.hpp
0 → 100644
View file @
25cee581
#pragma once
#include <sstream>
#include <string>
#include <vector>

template <typename T>
std::string format_vector(const std::vector<T>& v) {
  std::ostringstream oss;
  if (v.empty())
    return "[]";
  for (size_t i = 0; i < v.size(); ++i) {
    oss << v[i];
    if (i < v.size() - 1)
      oss << ", ";  // comma-separated
  }
  return oss.str();
}
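For example (illustrative):

std::vector<int> ids = {1, 2, 3};
std::string s = format_vector(ids);  // "1, 2, 3"; an empty vector yields "[]"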
csrc/balance_serve/sched/utils/mpsc.hpp
0 → 100644
View file @
25cee581
#pragma once
#include <atomic>
#include <cassert>
#include <chrono>
#include <iostream>
#include <optional>
#include <semaphore>

template <typename T>
class MPSCQueue {
  struct Node {
    T data;
    std::atomic<Node*> next;
    Node() : next(nullptr) {}
    Node(T data_) : data(std::move(data_)), next(nullptr) {}
  };

  std::atomic<Node*> head;
  Node* tail;

 public:
  std::atomic_size_t enqueue_count = 0;
  size_t dequeue_count = 0;

  MPSCQueue() {
    Node* dummy = new Node();
    head.store(dummy, std::memory_order_seq_cst);
    tail = dummy;
  }

  ~MPSCQueue() {
    Node* node = tail;
    while (node) {
      Node* next = node->next.load(std::memory_order_seq_cst);
      delete node;
      node = next;
    }
  }

  // Called by producers (any thread).
  void enqueue(T data) {
    enqueue_count.fetch_add(1);
    Node* node = new Node(std::move(data));
    Node* prev_head = head.exchange(node, std::memory_order_seq_cst);
    prev_head->next.store(node, std::memory_order_seq_cst);
  }

  // Called by the single consumer.
  std::optional<T> dequeue() {
    Node* next = tail->next.load(std::memory_order_seq_cst);
    if (next) {
      T res = std::move(next->data);
      delete tail;
      tail = next;
      dequeue_count += 1;
      return res;
    }
    return std::nullopt;
  }

  size_t size() { return enqueue_count.load() - dequeue_count; }
};

template <typename T>
class MPSCQueueConsumerLock {
  MPSCQueue<T> queue;
  std::counting_semaphore<> sema{0};

 public:
  void enqueue(T data) {
    queue.enqueue(std::move(data));
    // std::atomic_thread_fence(std::memory_order_seq_cst);
    // Inserted because the memory order might be wrong; not entirely sure about this.
    sema.release();
  }

  T dequeue() {
    auto re = queue.dequeue();
    if (re.has_value()) {
      while (sema.try_acquire() == false) {
        std::cerr << __FILE__ << ":" << __FUNCTION__
                  << " sema try_acquire should succeed; retrying, please check" << std::endl;
        // assert(false);
      }
      return re.value();
    }
    sema.acquire();
    return queue.dequeue().value();
  }

  template <typename Rep, typename Period>
  std::optional<T> try_dequeue_for(std::chrono::duration<Rep, Period> dur) {
    auto re = queue.dequeue();
    if (re.has_value()) {
      while (sema.try_acquire() == false) {
        std::cerr << __FILE__ << ":" << __FUNCTION__
                  << " sema try_acquire should succeed; retrying, please check" << std::endl;
        // assert(false);
      }
      return re.value();
    }
    if (sema.try_acquire_for(dur)) {
      return queue.dequeue().value();
    } else {
      return std::nullopt;
    }
  }

  size_t size() { return queue.size(); }
};
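A usage sketch (illustrative): several producers enqueue while the single consumer blocks on the semaphore:

MPSCQueueConsumerLock<int> q;
std::vector<std::thread> producers;
for (int i = 0; i < 4; ++i) {
  producers.emplace_back([&q, i] { q.enqueue(i); });
}
for (int received = 0; received < 4; ++received) {
  int v = q.dequeue();  // blocks until an item is available
  (void)v;
}
for (auto& t : producers) t.join();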
csrc/balance_serve/sched/utils/readable_number.hpp
0 → 100644
View file @
25cee581
#pragma once
#include <array>
#include <iomanip>
#include <sstream>
#include <string>

inline std::array<std::string, 7> units = {"", "K", "M", "G", "T", "P", "E"};

inline std::string readable_number(size_t size) {
  size_t unit_index = 0;
  double readable_size = size;
  while (readable_size >= 1000 && unit_index < units.size() - 1) {
    readable_size /= 1000;
    unit_index++;
  }
  std::ostringstream ss;
  ss << std::fixed << std::setprecision(2) << readable_size;
  std::string str = ss.str();
  return str + units[unit_index];
}
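For example (illustrative):

readable_number(1234567);  // "1.23M"
readable_number(512);      // "512.00"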
csrc/balance_serve/sched/utils/statistics.hpp
0 → 100644
View file @
25cee581
#ifndef STATISTICS_HPP
#define STATISTICS_HPP

#include <chrono>
#include <iostream>
#include <string>
#include <unordered_map>

class Statistics {
 public:
  // Increment the counter for a given key by a specified value (default is 1)
  void increment_counter(const std::string& key, int64_t value = 1) { counters_[key] += value; }

  int64_t& get_counter(const std::string& key) { return counters_[key]; }

  // Start the timer for a given key
  void start_timer(const std::string& key) { active_timers_[key] = std::chrono::high_resolution_clock::now(); }

  // Stop the timer for a given key and update the total time and count
  void stop_timer(const std::string& key) {
    auto start_it = active_timers_.find(key);
    if (start_it != active_timers_.end()) {
      auto duration = std::chrono::high_resolution_clock::now() - start_it->second;
      timings_[key].total_time += duration;
      timings_[key].count += 1;
      active_timers_.erase(start_it);
    } else {
      // Handle error: stop_timer called without a matching start_timer
      std::cerr << "Warning: stop_timer called for key '" << key << "' without a matching start_timer.\n";
    }
  }

  // Print out the collected statistical information
  void report() const {
    std::cout << "Counters:\n";
    for (const auto& kv : counters_) {
      std::cout << "  " << kv.first << ": " << kv.second << "\n";
    }
    std::cout << "\nTimers:\n";
    for (const auto& kv : timings_) {
      std::cout << "  " << kv.first << ": count = " << kv.second.count
                << ", total_time = " << kv.second.total_time.count() << "s"
                << ", average_time = "
                << (kv.second.count > 0 ? kv.second.total_time.count() / kv.second.count : 0) << "s\n";
    }
  }

 private:
  // Mapping from key to counter
  std::unordered_map<std::string, int64_t> counters_;

  // Struct to hold timing information for a key
  struct TimingInfo {
    int64_t count = 0;
    std::chrono::duration<double> total_time = std::chrono::duration<double>::zero();
  };

  // Mapping from key to timing information
  std::unordered_map<std::string, TimingInfo> timings_;

  // Mapping from key to the start time of active timers
  std::unordered_map<std::string, std::chrono::high_resolution_clock::time_point> active_timers_;
};

#endif  // STATISTICS_HPP
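A usage sketch (illustrative):

Statistics stats;
stats.increment_counter("queries", 3);
stats.start_timer("batch");
// ... timed work ...
stats.stop_timer("batch");
stats.report();  // prints counters, then per-key count / total / average times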
csrc/balance_serve/sched/utils/timer.hpp
0 → 100644
View file @
25cee581
#pragma once
#include <cassert>
#include <chrono>
#include <iomanip>
#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include "readable_number.hpp"

inline std::string doubleToStringR2(double value) {
  std::stringstream stream;
  stream << std::fixed << std::setprecision(2) << value;
  return stream.str();
}

class Timer {
 public:
  std::string name;
  bool tmp_timer = false;

  Timer() {}
  Timer(std::string name) : name(name), tmp_timer(true) { start(); }
  ~Timer() {
    if (tmp_timer) {
      std::cout << name << " " << elapsedMs() << " ms" << std::endl;
    }
  }

  void start() {
    m_startTime = std::chrono::high_resolution_clock::now();
    assert(m_isRunning == false);
    m_isRunning = true;
  }

  void stop() {
    m_endTime = std::chrono::high_resolution_clock::now();
    assert(m_isRunning == true);
    m_isRunning = false;
    m_runningNs += elapsedNs();
  }

  double elapsedNs() {
    std::chrono::time_point<std::chrono::high_resolution_clock> endTime;
    if (m_isRunning) {
      endTime = std::chrono::high_resolution_clock::now();
    } else {
      endTime = m_endTime;
    }
    return std::chrono::duration_cast<std::chrono::nanoseconds>(endTime - m_startTime).count();
  }

  void printElapsedMilliseconds() { std::cout << elapsedNs() / 1e6 << " ms" << std::endl; }

  static std::string ns_to_string(double duration) {
    auto nano_sec = duration;
    if (nano_sec >= 1000) {
      auto micro_sec = nano_sec / 1000.0;
      if (micro_sec >= 1000) {
        auto milli_sec = micro_sec / 1000.0;
        if (milli_sec >= 1000) {
          auto seconds = milli_sec / 1000.0;
          if (seconds >= 60.0) {
            auto minutes = seconds / 60.0;
            if (minutes >= 60.0) {
              auto hours = minutes / 60.0;
              return doubleToStringR2(hours) + " h";
            } else {
              return doubleToStringR2(minutes) + " min";
            }
          } else {
            return doubleToStringR2(seconds) + " sec";
          }
        } else {
          return doubleToStringR2(milli_sec) + " ms";
        }
      } else {
        return doubleToStringR2(micro_sec) + " us";
      }
    } else {
      return doubleToStringR2(nano_sec) + " ns";
    }
  }

  double runningTimeNs() { return m_runningNs; }

  std::string runningTime() {
    auto duration = m_runningNs;
    return ns_to_string(duration);
  }

  std::string elapsedTime() { return ns_to_string(elapsedNs()); }

  double elapsedMs() { return elapsedNs() / 1e6; }

  std::string report_throughput(size_t op_cnt) {
    double ops = op_cnt / elapsedMs() * 1000;
    return readable_number(ops) + "op/s";
  }

  void merge(Timer& other) {
    assert(m_isRunning == false);
    assert(other.m_isRunning == false);
    m_runningNs += other.runningTimeNs();
  }

 private:
  std::chrono::time_point<std::chrono::high_resolution_clock> m_startTime;
  std::chrono::time_point<std::chrono::high_resolution_clock> m_endTime;
  bool m_isRunning = false;
  double m_runningNs = 0.0;
};

class Counter {
 public:
  Counter() {}
  std::map<std::string, size_t> counters;
  void inc(const char* name, size_t num) { counters[name] += num; };
  void print() {
    for (auto& p : counters) {
      std::cout << p.first << " : " << p.second << std::endl;
    }
  };
};
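A usage sketch (illustrative): a named Timer reports once at scope exit (RAII), while start()/stop() accumulate into runningTime():

{
  Timer t("load_weights");  // prints "load_weights <n> ms" when the scope ends
  // ... work ...
}
Timer acc;
acc.start();
// ... work ...
acc.stop();
std::cout << acc.runningTime() << std::endl;  // e.g. "1.23 ms"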