OpenDAS / ktransformers — Commits — 877aec85

Commit 877aec85 (unverified), authored Apr 09, 2025 by Yuhao Tsui, committed by GitHub on Apr 09, 2025
Merge branch 'kvcache-ai:main' into main
Parents: 84164f58, 9037bf30
Changes: 251. Showing 20 changed files with 2587 additions and 0 deletions (+2587, -0).
Files shown:

  csrc/balance_serve/kvc2/test/test_lock_free_queue.cpp     +56   -0
  csrc/balance_serve/kvc2/test/test_periodic_task.cpp       +163  -0
  csrc/balance_serve/kvc2/test/test_queue_perf.cpp          +84   -0
  csrc/balance_serve/kvc2/test/test_std_list.cpp            +38   -0
  csrc/balance_serve/kvc2/test/xxHash_test.cpp              +31   -0
  csrc/balance_serve/kvc2/unit_test.sh                      +36   -0
  csrc/balance_serve/sched/CMakeLists.txt                   +20   -0
  csrc/balance_serve/sched/bind.cpp                         +249  -0
  csrc/balance_serve/sched/metrics.cpp                      +147  -0
  csrc/balance_serve/sched/metrics.h                        +88   -0
  csrc/balance_serve/sched/model_config.h                   +119  -0
  csrc/balance_serve/sched/scheduler.cpp                    +960  -0
  csrc/balance_serve/sched/scheduler.h                      +175  -0
  csrc/balance_serve/sched/utils/all.hpp                    +3    -0
  csrc/balance_serve/sched/utils/arithmetic.hpp             +7    -0
  csrc/balance_serve/sched/utils/atomic_ptr_with_flags.hpp  +35   -0
  csrc/balance_serve/sched/utils/csv.hpp                    +229  -0
  csrc/balance_serve/sched/utils/easy_format.hpp            +15   -0
  csrc/balance_serve/sched/utils/mpsc.hpp                   +112  -0
  csrc/balance_serve/sched/utils/readable_number.hpp        +20   -0
csrc/balance_serve/kvc2/test/test_lock_free_queue.cpp (new file, mode 100644)
#include <chrono>
#include <future>
#include <iostream>
#include <thread>
#include <vector>

#include "utils/lock_free_queue.hpp"

struct Item {
  int value;
  std::promise<void> promise;
};

int main() {
  MPSCQueue<Item> queue;
  std::vector<std::thread> producers;
  const int num_producers = 4;
  const int items_per_producer = 5;

  // Start the producer threads
  for (int i = 0; i < num_producers; ++i) {
    producers.emplace_back([&queue, i]() {
      for (int j = 0; j < items_per_producer; ++j) {
        auto item = std::make_shared<Item>();
        item->value = i * items_per_producer + j;
        std::future<void> future = item->promise.get_future();
        queue.enqueue(item);
        future.wait();  // wait for the consumer to finish processing
      }
    });
  }

  // Start the consumer thread
  std::thread consumer([&queue, num_producers, items_per_producer]() {
    int total_items = num_producers * items_per_producer;
    int processed = 0;
    while (processed < total_items) {
      std::shared_ptr<Item> item = queue.dequeue();
      if (item) {
        std::cout << "Consumed item with value: " << item->value << std::endl;
        item->promise.set_value();  // notify the producer
        ++processed;
      } else {
        // if the queue is empty, sleep or yield the thread
        std::this_thread::yield();
      }
    }
  });

  // Wait for all threads to finish
  for (auto& producer : producers) {
    producer.join();
  }
  consumer.join();
  return 0;
}
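The test drives the queue purely through enqueue(std::shared_ptr<Item>) and a non-blocking dequeue() that returns nullptr when empty. A minimal sketch of that contract, using a mutex-backed stand-in rather than the actual lock-free structure in utils/lock_free_queue.hpp (the class name below is illustrative, not from the repo):

#include <memory>
#include <mutex>
#include <queue>

template <typename T>
class MPSCQueueSketch {
 public:
  void enqueue(std::shared_ptr<T> item) {
    std::lock_guard<std::mutex> g(mu_);
    q_.push(std::move(item));
  }

  // Non-blocking: returns nullptr if the queue is currently empty,
  // which is why the test's consumer yields and retries.
  std::shared_ptr<T> dequeue() {
    std::lock_guard<std::mutex> g(mu_);
    if (q_.empty()) return nullptr;
    auto item = std::move(q_.front());
    q_.pop();
    return item;
  }

 private:
  std::mutex mu_;
  std::queue<std::shared_ptr<T>> q_;
};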
csrc/balance_serve/kvc2/test/test_periodic_task.cpp (new file, mode 100644)
#include <atomic>
#include <cassert>
#include <chrono>
#include <cstdio>
#include <future>
#include <iostream>
#include <thread>

#include "utils/periodic_task.hpp"

// 1. Whether the task executes as expected
void testPeriodicTaskExecution() {
  std::atomic<int> execution_count{0};
  auto task = [&execution_count]() { execution_count++; };
  periodic::PeriodicTask periodic_task(task, std::chrono::milliseconds(50));
  std::this_thread::sleep_for(std::chrono::seconds(2));
  assert(execution_count >= 20);  // the task should have run at least 20 times
  std::cout << "Test 1 passed: Task executed periodically." << std::endl;
  std::cout << "Task executed " << execution_count.load() << " times." << std::endl;
}

// 2. Waking the task up ahead of schedule
void testWakeUpImmediately() {
  std::atomic<int> execution_count{0};
  auto task = [&execution_count]() { execution_count++; };
  periodic::PeriodicTask periodic_task(task, std::chrono::milliseconds(200));
  // wake the task up early
  periodic_task.wakeUp();
  std::this_thread::sleep_for(std::chrono::milliseconds(50));  // wait for the task to run
  std::cout << "Execution count after wakeUp: " << execution_count.load() << std::endl;
  assert(execution_count == 1);  // the task should have run immediately
  std::cout << "Test 2 passed: Task woke up immediately." << std::endl;
}

// 3. The waiting behavior of wakeUpWait()
void testWakeUpWait() {
  std::promise<void> promise;
  std::future<void> future = promise.get_future();
  auto task = [&promise]() {
    std::this_thread::sleep_for(std::chrono::milliseconds(100));  // simulate the task running
    promise.set_value();  // set the promise when the task completes
  };
  periodic::PeriodicTask periodic_task(task, std::chrono::milliseconds(200));
  // call wakeUpWait and wait for the task to finish
  std::future<void> wakeup_future = periodic_task.wakeUpWait();
  wakeup_future.wait();  // wait for the task to finish
  assert(wakeup_future.valid());  // make sure the future is valid
  std::cout << "Test 3 passed: wakeUpWait() works correctly." << std::endl;
  std::cout << "wakeUpWait() future is valid." << std::endl;
}

// 4. Handling of exceptions thrown by the task
void testTaskExceptionHandling() {
  auto task = []() { throw std::runtime_error("Test exception"); };
  periodic::PeriodicTask periodic_task(task, std::chrono::milliseconds(200));
  std::this_thread::sleep_for(std::chrono::milliseconds(300));  // wait for a while
  std::cout << "Test 4 passed: Task exception is handled correctly." << std::endl;
  std::cout << "Exception handled and task did not crash." << std::endl;
}

// 5. Whether the thread stops correctly
void testTaskStop() {
  std::atomic<bool> stopped{false};
  auto task = [&stopped]() {
    while (!stopped) {
      std::this_thread::sleep_for(std::chrono::milliseconds(50));
    }
  };
  periodic::PeriodicTask periodic_task(task, std::chrono::milliseconds(100));
  std::this_thread::sleep_for(std::chrono::seconds(1));  // run for a while
  stopped = true;  // request a stop
  std::this_thread::sleep_for(std::chrono::milliseconds(50));  // wait for the thread to stop
  std::cout << "Test 5 passed: Task thread stops correctly." << std::endl;
  std::cout << "Task has been stopped successfully." << std::endl;
}

// 6. Whether the task behaves correctly under high-frequency wake-ups
void testHighFrequencyWakeUp() {
  std::atomic<int> execution_count{0};
  auto task = [&execution_count]() { execution_count++; };
  periodic::PeriodicTask periodic_task(task, std::chrono::milliseconds(200));
  for (int i = 0; i < 100; ++i) {
    periodic_task.wakeUp();
    std::this_thread::sleep_for(std::chrono::milliseconds(10));  // wake up every 10 ms
  }
  std::this_thread::sleep_for(std::chrono::seconds(1));  // let pending executions finish
  assert(execution_count > 50);  // the task should have run more than 50 times
  std::cout << "Test 6 passed: Task handles frequent wake ups correctly." << std::endl;
  std::cout << "Task executed " << execution_count.load() << " times." << std::endl;
}

// 7. Handling multiple wakeUpWait() calls
void testMultipleWakeUpWait() {
  std::atomic<int> execution_count{0};
  auto task = [&execution_count]() {
    std::this_thread::sleep_for(std::chrono::milliseconds(100));  // simulate the task running
    execution_count++;
  };
  periodic::PeriodicTask periodic_task(task, std::chrono::milliseconds(200));
  // issue two wakeUpWait calls at the same time
  std::future<void> future1 = periodic_task.wakeUpWait();
  std::future<void> future2 = periodic_task.wakeUpWait();
  future1.wait();
  future2.wait();
  assert(execution_count == 1);  // the task should have run exactly once
  std::cout << "Test 7 passed: Multiple wakeUpWait() calls are handled correctly." << std::endl;
  std::cout << "Task executed " << execution_count.load() << " times." << std::endl;
}

// 8. Edge case: an empty task function
void testEmptyTaskFunction() {
  auto task = []() {
    // empty task body
  };
  periodic::PeriodicTask periodic_task(task, std::chrono::milliseconds(100));
  std::this_thread::sleep_for(std::chrono::seconds(1));  // wait for a while
  std::cout << "Test 8 passed: Empty task function works correctly." << std::endl;
  std::cout << "Empty task function executed without issues." << std::endl;
}

int main() {
  std::cout << "Starting tests..." << std::endl;
  // testWakeUpImmediately();
  testPeriodicTaskExecution();
  testWakeUpImmediately();
  testWakeUpWait();
  testTaskExceptionHandling();
  testTaskStop();
  testHighFrequencyWakeUp();
  testMultipleWakeUpWait();
  testEmptyTaskFunction();
  std::cout << "All tests passed!" << std::endl;
  return 0;
}
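Taken together, the eight tests pin down the PeriodicTask contract: run a callable every interval, allow early wake-ups via wakeUp(), coalesce concurrent wakeUpWait() calls into a single execution whose completion fulfils every returned future, survive exceptions thrown by the task, and stop cleanly on destruction. A minimal sketch that satisfies these tests, assuming a condition-variable worker loop (the real utils/periodic_task.hpp may be implemented differently):

#include <chrono>
#include <condition_variable>
#include <functional>
#include <future>
#include <mutex>
#include <thread>
#include <vector>

namespace periodic_sketch {
class PeriodicTask {
 public:
  PeriodicTask(std::function<void()> task, std::chrono::milliseconds interval)
      : task_(std::move(task)), interval_(interval), worker_([this] { loop(); }) {}

  ~PeriodicTask() {
    {
      std::lock_guard<std::mutex> g(mu_);
      stop_ = true;
    }
    cv_.notify_all();
    worker_.join();
  }

  // Run the task ahead of schedule.
  void wakeUp() {
    {
      std::lock_guard<std::mutex> g(mu_);
      wake_ = true;
    }
    cv_.notify_all();
  }

  // Wake the task and get a future fulfilled after the next run completes.
  // Several callers registering before the run share that single run.
  std::future<void> wakeUpWait() {
    std::future<void> f;
    {
      std::lock_guard<std::mutex> g(mu_);
      waiters_.emplace_back();
      f = waiters_.back().get_future();
      wake_ = true;
    }
    cv_.notify_all();
    return f;
  }

 private:
  void loop() {
    std::unique_lock<std::mutex> lk(mu_);
    while (!stop_) {
      cv_.wait_for(lk, interval_, [this] { return wake_ || stop_; });
      if (stop_) break;
      wake_ = false;
      auto waiters = std::move(waiters_);  // take all pending waiters for this run
      lk.unlock();
      try {
        task_();  // exceptions must not kill the worker thread (test 4)
      } catch (...) {
      }
      for (auto& p : waiters) p.set_value();
      lk.lock();
    }
  }

  std::function<void()> task_;
  std::chrono::milliseconds interval_;
  std::mutex mu_;
  std::condition_variable cv_;
  bool wake_ = false;
  bool stop_ = false;
  std::vector<std::promise<void>> waiters_;
  std::thread worker_;  // declared last so loop() only sees initialized members
};
}  // namespace periodic_sketch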
csrc/balance_serve/kvc2/test/test_queue_perf.cpp (new file, mode 100644)
#include <chrono>
#include <iostream>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

#include "utils/lock_free_queue.hpp"

#define STDQ

int main() {
  const int num_producers = 48;
  const int num_items = 1e6;

#ifdef STDQ
  std::mutex lock;
  std::queue<int> queue;
#else
  MPSCQueue<int> queue;
#endif

  auto start_time = std::chrono::high_resolution_clock::now();

  // Launch multiple producer threads
  std::vector<std::thread> producers;
  for (int i = 0; i < num_producers; ++i) {
    producers.emplace_back([&queue, i
#ifdef STDQ
                            , &lock
#endif
    ]() {
      for (int j = 0; j < num_items; ++j) {
#ifdef STDQ
        std::lock_guard<std::mutex> guard(lock);
        queue.push(i * num_items + j);
#else
        queue.enqueue(std::make_shared<int>(i * num_items + j));
#endif
      }
    });
  }

  // Consumer thread
  std::thread consumer([&queue, num_producers
#ifdef STDQ
                        , &lock
#endif
  ]() {
    int count = 0;
    while (count < num_producers * num_items) {
#ifdef STDQ
      std::lock_guard<std::mutex> guard(lock);
      if (!queue.empty()) {
        queue.pop();
        count++;
      }
#else
      if (auto item = queue.dequeue()) {
        count++;
      }
#endif
    }
  });

  // Wait for all producers to finish
  for (auto& producer : producers) {
    producer.join();
  }
  // Wait for the consumer to finish
  consumer.join();

  auto end_time = std::chrono::high_resolution_clock::now();
  auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time).count();
#ifdef STDQ
  std::cout << "std::queue with mutex ";
#else
  std::cout << "lock free queue ";
#endif
  std::cout << "Processed " << num_producers * num_items / 1e6 << "M items in " << duration << " milliseconds "
            << num_producers * num_items / 1e3 / duration << " MOps." << std::endl;
  return 0;
}
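As committed, STDQ is defined, so this benchmark measures the mutex-protected std::queue baseline; commenting out the #define STDQ line flips every #ifdef branch to the MPSCQueue path while keeping the same 48-producer / 1-consumer layout, so the two variants are directly comparable under identical load.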
csrc/balance_serve/kvc2/test/test_std_list.cpp (new file, mode 100644)
#include <iostream>
#include <iterator>
#include <vector>

int main() {
  std::vector<int> v = {0, 1, 2, 3, 4, 5};
  using RevIt = std::reverse_iterator<std::vector<int>::iterator>;
  const auto it = v.begin() + 3;
  RevIt r_it{it};
  std::cout << "*it == " << *it << '\n'
            << "*r_it == " << *r_it << '\n'
            << "*r_it.base() == " << *r_it.base() << '\n'
            << "*(r_it.base()-1) == " << *(r_it.base() - 1) << '\n';

  RevIt r_end{v.begin()};
  RevIt r_begin{v.end()};
  for (auto it = r_end.base(); it != r_begin.base(); ++it)
    std::cout << *it << ' ';
  std::cout << '\n';
  for (auto it = r_begin; it != r_end; ++it)
    std::cout << *it << ' ';
  std::cout << '\n';

  for (auto it = r_begin; it != r_end; ++it) {
    if (*it == 3) {
      v.erase(std::next(it).base());
    }
  }
  for (auto it : v)
    std::cout << it << ' ';
  std::cout << '\n';
}
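The printed values illustrate the reverse-iterator offset rule: r_it.base() refers to the element one past the one *r_it denotes, so with it = v.begin() + 3 we get *it == 3 and *r_it == 2, while *r_it.base() == 3 and *(r_it.base() - 1) == 2. That offset is also why the erase of the element equal to 3 goes through std::next(it).base() rather than it.base().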
csrc/balance_serve/kvc2/test/xxHash_test.cpp (new file, mode 100644)
#include "xxhash.h"
#include <iostream>
int
main
()
{
std
::
string
t
=
"hello world"
;
XXH64_hash_t
hash
=
XXH64
(
t
.
data
(),
t
.
size
(),
123
);
std
::
cout
<<
hash
<<
std
::
endl
;
{
/* create a hash state */
XXH64_state_t
*
const
state
=
XXH64_createState
();
if
(
state
==
NULL
)
abort
();
if
(
XXH64_reset
(
state
,
123
)
==
XXH_ERROR
)
abort
();
if
(
XXH64_update
(
state
,
t
.
data
(),
5
)
==
XXH_ERROR
)
abort
();
if
(
XXH64_update
(
state
,
t
.
data
()
+
5
,
t
.
size
()
-
5
)
==
XXH_ERROR
)
abort
();
/* Produce the final hash value */
XXH64_hash_t
const
hash
=
XXH64_digest
(
state
);
/* State could be re-used; but in this example, it is simply freed */
XXH64_freeState
(
state
);
std
::
cout
<<
hash
<<
std
::
endl
;
}
return
0
;
}
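Both prints should show the same value: for a fixed seed (123 here), hashing the whole buffer in a single XXH64 call and streaming it through XXH64_reset / XXH64_update / XXH64_digest in two chunks are defined by xxHash to produce identical results.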
csrc/balance_serve/kvc2/unit_test.sh (new file, mode 100755)
#!/bin/bash

# Check that a disk_cache_path argument was provided
if [ -z "$1" ]; then
  echo "Usage: $0 <disk_cache_path>"
  exit 1
fi

# Assign the disk_cache_path argument to a variable
disk_cache_path=$1

# Define the array of test commands, substituting disk_cache_path
tests=(
  "./build/test/kvc2_export_header_test --disk_cache_path=$disk_cache_path"
  "./build/test/kvcache_disk_insert_read_test --disk_cache_path=$disk_cache_path"
  "./build/test/kvcache_mem_eviction_test --disk_cache_path=$disk_cache_path"
  "./build/test/kvcache_mem_insert_read_test --disk_cache_path=$disk_cache_path"
  "./build/test/kvcache_save_load_test --disk_cache_path=$disk_cache_path"
)

# Iterate over each test command
for test in "${tests[@]}"; do
  echo "Running: $test"
  # Run the test and capture its output
  output=$($test)
  # Check whether the output contains "Test Passed"
  if echo "$output" | grep -q "Test Passed"; then
    echo " Test Passed"
  else
    echo " Test Failed"
  fi
  sleep 1
done
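The script takes the cache directory as its single argument, e.g. ./unit_test.sh /tmp/kvc2_cache (an illustrative path, not one from the repo), and treats a test as passing only if the binary's stdout contains the literal string "Test Passed".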
csrc/balance_serve/sched/CMakeLists.txt (new file, mode 100644)
set(CMAKE_CXX_FLAGS "-Og -march=native -Wall -Wextra -g -fPIC")
# set(CMAKE_CXX_FLAGS "-O3 -march=native -Wall -Wextra -fPIC")
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=${_GLIBCXX_USE_CXX11_ABI})

set(UTILS_DIR ${CMAKE_CURRENT_SOURCE_DIR}/utils)

add_library(sched_metrics metrics.cpp)
target_include_directories(sched_metrics PRIVATE ${UTILS_DIR})
target_link_libraries(sched_metrics PUBLIC prometheus-cpp::pull)

add_library(sched scheduler.cpp)
target_include_directories(sched PRIVATE ${SPDLOG_DIR}/include ${FMT_DIR}/include ${UTILS_DIR} ${KVC2_INCLUDE_DIR})
target_link_libraries(sched PUBLIC pthread ${TORCH_LIBRARIES} kvc2 async_store sched_metrics)

pybind11_add_module(sched_ext bind.cpp)
target_link_libraries(sched_ext PUBLIC sched ${TORCH_LIBRARIES} ${TORCH_PYTHON_LIBRARY})
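Three targets are defined here: sched_metrics (the Prometheus exporter), sched (the scheduler core, linked against kvc2, async_store, and torch), and the pybind11 extension sched_ext that Python imports. Note that the active CMAKE_CXX_FLAGS line selects a debug-friendly -Og build; the commented-out line is the -O3 release variant.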
csrc/balance_serve/sched/bind.cpp (new file, mode 100644)
#include "scheduler.h"
#include <memory>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <torch/extension.h>
namespace
py
=
pybind11
;
PYBIND11_MODULE
(
sched_ext
,
m
)
{
py
::
class_
<
scheduler
::
ModelSettings
>
(
m
,
"ModelSettings"
)
.
def
(
py
::
init
<>
())
.
def_readwrite
(
"model_path"
,
&
scheduler
::
ModelSettings
::
model_path
)
.
def_readwrite
(
"params_count"
,
&
scheduler
::
ModelSettings
::
params_count
)
.
def_readwrite
(
"layer_count"
,
&
scheduler
::
ModelSettings
::
layer_count
)
.
def_readwrite
(
"num_k_heads"
,
&
scheduler
::
ModelSettings
::
num_k_heads
)
.
def_readwrite
(
"k_head_dim"
,
&
scheduler
::
ModelSettings
::
k_head_dim
)
.
def_readwrite
(
"bytes_per_params"
,
&
scheduler
::
ModelSettings
::
bytes_per_params
)
.
def_readwrite
(
"bytes_per_kv_cache_element"
,
&
scheduler
::
ModelSettings
::
bytes_per_kv_cache_element
)
.
def
(
"params_size"
,
&
scheduler
::
ModelSettings
::
params_nbytes
)
.
def
(
"bytes_per_token_kv_cache"
,
&
scheduler
::
ModelSettings
::
bytes_per_token_kv_cache
)
// 添加 pickle 支持
.
def
(
py
::
pickle
(
[](
const
scheduler
::
ModelSettings
&
self
)
{
// __getstate__
return
py
::
make_tuple
(
self
.
params_count
,
self
.
layer_count
,
self
.
num_k_heads
,
self
.
k_head_dim
,
self
.
bytes_per_params
,
self
.
bytes_per_kv_cache_element
);
},
[](
py
::
tuple
t
)
{
// __setstate__
if
(
t
.
size
()
!=
6
)
throw
std
::
runtime_error
(
"Invalid state! t.size() = "
+
std
::
to_string
(
t
.
size
()));
scheduler
::
ModelSettings
ms
;
ms
.
params_count
=
t
[
0
].
cast
<
size_t
>
();
ms
.
layer_count
=
t
[
1
].
cast
<
size_t
>
();
ms
.
num_k_heads
=
t
[
2
].
cast
<
size_t
>
();
ms
.
k_head_dim
=
t
[
3
].
cast
<
size_t
>
();
ms
.
bytes_per_params
=
t
[
4
].
cast
<
double
>
();
ms
.
bytes_per_kv_cache_element
=
t
[
5
].
cast
<
double
>
();
return
ms
;
}));
py
::
class_
<
scheduler
::
SampleOptions
>
(
m
,
"SampleOptions"
)
.
def
(
py
::
init
<>
())
.
def_readwrite
(
"temperature"
,
&
scheduler
::
SampleOptions
::
temperature
)
.
def_readwrite
(
"top_p"
,
&
scheduler
::
SampleOptions
::
top_p
)
// 确保 top_p 也能被访问
.
def
(
py
::
pickle
(
[](
const
scheduler
::
SampleOptions
&
self
)
{
return
py
::
make_tuple
(
self
.
temperature
,
self
.
top_p
);
// 序列化 temperature 和 top_p
},
[](
py
::
tuple
t
)
{
if
(
t
.
size
()
!=
2
)
// 确保解包时参数数量匹配
throw
std
::
runtime_error
(
"Invalid state! t.size() = "
+
std
::
to_string
(
t
.
size
()));
scheduler
::
SampleOptions
so
;
so
.
temperature
=
t
[
0
].
cast
<
double
>
();
so
.
top_p
=
t
[
1
].
cast
<
double
>
();
// 反序列化 top_p
return
so
;
}));
py
::
class_
<
scheduler
::
Settings
>
(
m
,
"Settings"
)
.
def
(
py
::
init
<>
())
.
def_readwrite
(
"model_name"
,
&
scheduler
::
Settings
::
model_name
)
.
def_readwrite
(
"quant_type"
,
&
scheduler
::
Settings
::
quant_type
)
.
def_readwrite
(
"model_settings"
,
&
scheduler
::
Settings
::
model_settings
)
.
def_readwrite
(
"page_size"
,
&
scheduler
::
Settings
::
page_size
)
.
def_readwrite
(
"gpu_device_id"
,
&
scheduler
::
Settings
::
gpu_device_id
)
.
def_readwrite
(
"gpu_memory_size"
,
&
scheduler
::
Settings
::
gpu_memory_size
)
.
def_readwrite
(
"memory_utilization_percentage"
,
&
scheduler
::
Settings
::
memory_utilization_percentage
)
.
def_readwrite
(
"max_batch_size"
,
&
scheduler
::
Settings
::
max_batch_size
)
.
def_readwrite
(
"recommended_chunk_prefill_token_count"
,
&
scheduler
::
Settings
::
recommended_chunk_prefill_token_count
)
.
def_readwrite
(
"sample_options"
,
&
scheduler
::
Settings
::
sample_options
)
.
def_readwrite
(
"sched_metrics_port"
,
&
scheduler
::
Settings
::
sched_metrics_port
)
.
def_readwrite
(
"gpu_only"
,
&
scheduler
::
Settings
::
gpu_only
)
.
def_readwrite
(
"use_self_defined_head_dim"
,
&
scheduler
::
Settings
::
use_self_defined_head_dim
)
.
def_readwrite
(
"self_defined_head_dim"
,
&
scheduler
::
Settings
::
self_defined_head_dim
)
.
def_readwrite
(
"full_kv_cache_on_each_gpu"
,
&
scheduler
::
Settings
::
full_kv_cache_on_each_gpu
)
.
def_readwrite
(
"k_cache_on"
,
&
scheduler
::
Settings
::
k_cache_on
)
.
def_readwrite
(
"v_cache_on"
,
&
scheduler
::
Settings
::
v_cache_on
)
.
def_readwrite
(
"kvc2_config_path"
,
&
scheduler
::
Settings
::
kvc2_config_path
)
.
def_readwrite
(
"kvc2_root_path"
,
&
scheduler
::
Settings
::
kvc2_root_path
)
.
def_readwrite
(
"memory_pool_size_GB"
,
&
scheduler
::
Settings
::
memory_pool_size_GB
)
.
def_readwrite
(
"evict_count"
,
&
scheduler
::
Settings
::
evict_count
)
.
def_readwrite
(
"strategy_name"
,
&
scheduler
::
Settings
::
strategy_name
)
.
def_readwrite
(
"kvc2_metrics_port"
,
&
scheduler
::
Settings
::
kvc2_metrics_port
)
.
def_readwrite
(
"load_from_disk"
,
&
scheduler
::
Settings
::
load_from_disk
)
.
def_readwrite
(
"save_to_disk"
,
&
scheduler
::
Settings
::
save_to_disk
)
// derived
.
def_readwrite
(
"gpu_device_count"
,
&
scheduler
::
Settings
::
gpu_device_count
)
.
def_readwrite
(
"total_kvcache_pages"
,
&
scheduler
::
Settings
::
total_kvcache_pages
)
.
def_readwrite
(
"devices"
,
&
scheduler
::
Settings
::
devices
)
.
def
(
"auto_derive"
,
&
scheduler
::
Settings
::
auto_derive
);
py
::
class_
<
scheduler
::
BatchQueryTodo
,
std
::
shared_ptr
<
scheduler
::
BatchQueryTodo
>>
(
m
,
"BatchQueryTodo"
)
.
def
(
py
::
init
<>
())
.
def_readwrite
(
"query_ids"
,
&
scheduler
::
BatchQueryTodo
::
query_ids
)
.
def_readwrite
(
"query_tokens"
,
&
scheduler
::
BatchQueryTodo
::
query_tokens
)
.
def_readwrite
(
"query_lengths"
,
&
scheduler
::
BatchQueryTodo
::
query_lengths
)
.
def_readwrite
(
"block_indexes"
,
&
scheduler
::
BatchQueryTodo
::
block_indexes
)
.
def_readwrite
(
"attn_masks"
,
&
scheduler
::
BatchQueryTodo
::
attn_masks
)
.
def_readwrite
(
"rope_ranges"
,
&
scheduler
::
BatchQueryTodo
::
rope_ranges
)
.
def_readwrite
(
"sample_options"
,
&
scheduler
::
BatchQueryTodo
::
sample_options
)
.
def_readwrite
(
"prefill_mini_batches"
,
&
scheduler
::
BatchQueryTodo
::
prefill_mini_batches
)
.
def_readwrite
(
"decode_mini_batches"
,
&
scheduler
::
BatchQueryTodo
::
decode_mini_batches
)
.
def_readwrite
(
"stop_criteria"
,
&
scheduler
::
BatchQueryTodo
::
stop_criteria
)
.
def
(
"debug"
,
&
scheduler
::
BatchQueryTodo
::
debug
)
.
def
(
py
::
pickle
(
[](
const
scheduler
::
BatchQueryTodo
&
self
)
{
return
py
::
make_tuple
(
self
.
query_ids
,
self
.
query_tokens
,
self
.
query_lengths
,
self
.
block_indexes
,
self
.
attn_masks
,
self
.
rope_ranges
,
self
.
sample_options
,
self
.
prefill_mini_batches
,
self
.
decode_mini_batches
,
self
.
stop_criteria
);
},
[](
py
::
tuple
t
)
{
if
(
t
.
size
()
!=
10
)
throw
std
::
runtime_error
(
"Invalid state! t.size() = "
+
std
::
to_string
(
t
.
size
()));
scheduler
::
BatchQueryTodo
bqt
;
bqt
.
query_ids
=
t
[
0
].
cast
<
std
::
vector
<
scheduler
::
QueryID
>>
();
bqt
.
query_tokens
=
t
[
1
].
cast
<
std
::
vector
<
torch
::
Tensor
>>
();
bqt
.
query_lengths
=
t
[
2
].
cast
<
std
::
vector
<
scheduler
::
TokenLength
>>
();
bqt
.
block_indexes
=
t
[
3
].
cast
<
std
::
vector
<
torch
::
Tensor
>>
();
bqt
.
attn_masks
=
t
[
4
].
cast
<
std
::
optional
<
torch
::
Tensor
>>
();
bqt
.
rope_ranges
=
t
[
5
].
cast
<
std
::
optional
<
torch
::
Tensor
>>
();
bqt
.
sample_options
=
t
[
6
].
cast
<
std
::
vector
<
scheduler
::
SampleOptions
>>
();
bqt
.
prefill_mini_batches
=
t
[
7
].
cast
<
std
::
vector
<
scheduler
::
PrefillTask
>>
();
bqt
.
decode_mini_batches
=
t
[
8
].
cast
<
std
::
vector
<
std
::
vector
<
scheduler
::
QueryID
>>>
();
bqt
.
stop_criteria
=
t
[
9
].
cast
<
std
::
vector
<
std
::
vector
<
std
::
vector
<
int
>>>>
();
return
bqt
;
}));
py
::
class_
<
scheduler
::
QueryUpdate
>
(
m
,
"QueryUpdate"
)
.
def
(
py
::
init
<>
())
.
def_readwrite
(
"id"
,
&
scheduler
::
QueryUpdate
::
id
)
.
def_readwrite
(
"ok"
,
&
scheduler
::
QueryUpdate
::
ok
)
.
def_readwrite
(
"is_prefill"
,
&
scheduler
::
QueryUpdate
::
is_prefill
)
.
def_readwrite
(
"decode_done"
,
&
scheduler
::
QueryUpdate
::
decode_done
)
.
def_readwrite
(
"active_position"
,
&
scheduler
::
QueryUpdate
::
active_position
)
.
def_readwrite
(
"generated_token"
,
&
scheduler
::
QueryUpdate
::
generated_token
)
.
def
(
py
::
pickle
(
[](
const
scheduler
::
QueryUpdate
&
self
)
{
return
py
::
make_tuple
(
self
.
id
,
self
.
ok
,
self
.
is_prefill
,
self
.
decode_done
,
self
.
active_position
,
self
.
generated_token
);
},
[](
py
::
tuple
t
)
{
if
(
t
.
size
()
!=
6
)
throw
std
::
runtime_error
(
"Invalid state! t.size() = "
+
std
::
to_string
(
t
.
size
()));
scheduler
::
QueryUpdate
qu
;
qu
.
id
=
t
[
0
].
cast
<
scheduler
::
QueryID
>
();
qu
.
ok
=
t
[
1
].
cast
<
bool
>
();
qu
.
is_prefill
=
t
[
2
].
cast
<
bool
>
();
qu
.
decode_done
=
t
[
3
].
cast
<
bool
>
();
qu
.
active_position
=
t
[
4
].
cast
<
scheduler
::
TokenLength
>
();
qu
.
generated_token
=
t
[
5
].
cast
<
scheduler
::
Token
>
();
return
qu
;
}));
py
::
class_
<
scheduler
::
InferenceContext
>
(
m
,
"InferenceContext"
)
.
def
(
py
::
init
<>
())
.
def_readwrite
(
"k_cache"
,
&
scheduler
::
InferenceContext
::
k_cache
)
.
def_readwrite
(
"v_cache"
,
&
scheduler
::
InferenceContext
::
v_cache
);
py
::
class_
<
scheduler
::
QueryAdd
>
(
m
,
"QueryAdd"
)
.
def
(
py
::
init
<>
())
.
def_readwrite
(
"query_token"
,
&
scheduler
::
QueryAdd
::
query_token
)
// .def_readwrite("attn_mask", &scheduler::QueryAdd::attn_mask)
.
def_readwrite
(
"query_length"
,
&
scheduler
::
QueryAdd
::
query_length
)
.
def_readwrite
(
"estimated_length"
,
&
scheduler
::
QueryAdd
::
estimated_length
)
.
def_readwrite
(
"sample_options"
,
&
scheduler
::
QueryAdd
::
sample_options
)
.
def_readwrite
(
"user_id"
,
&
scheduler
::
QueryAdd
::
user_id
)
.
def_readwrite
(
"SLO_TTFT_ms"
,
&
scheduler
::
QueryAdd
::
SLO_TTFT_ms
)
.
def_readwrite
(
"SLO_TBT_ms"
,
&
scheduler
::
QueryAdd
::
SLO_TBT_ms
)
.
def_readwrite
(
"stop_criteria"
,
&
scheduler
::
QueryAdd
::
stop_criteria
)
.
def
(
"serialize"
,
&
scheduler
::
QueryAdd
::
serialize
)
.
def_static
(
"deserialize"
,
&
scheduler
::
QueryAdd
::
deserialize
)
.
def
(
py
::
pickle
(
[](
const
scheduler
::
QueryAdd
&
self
)
{
return
py
::
make_tuple
(
self
.
query_token
,
// self.attn_mask,
self
.
query_length
,
self
.
estimated_length
,
self
.
sample_options
,
self
.
user_id
,
self
.
SLO_TTFT_ms
,
self
.
SLO_TBT_ms
,
self
.
stop_criteria
);
},
[](
py
::
tuple
t
)
{
if
(
t
.
size
()
!=
8
)
throw
std
::
runtime_error
(
"Invalid state! t.size() = "
+
std
::
to_string
(
t
.
size
()));
scheduler
::
QueryAdd
qa
;
qa
.
query_token
=
t
[
0
].
cast
<
std
::
vector
<
scheduler
::
Token
>>
();
// qa.attn_mask = t[1].cast<torch::Tensor>();
qa
.
query_length
=
t
[
1
].
cast
<
scheduler
::
TokenLength
>
();
qa
.
estimated_length
=
t
[
2
].
cast
<
scheduler
::
TokenLength
>
();
qa
.
sample_options
=
t
[
3
].
cast
<
scheduler
::
SampleOptions
>
();
qa
.
user_id
=
t
[
4
].
cast
<
scheduler
::
UserID
>
();
qa
.
SLO_TTFT_ms
=
t
[
5
].
cast
<
int
>
();
qa
.
SLO_TBT_ms
=
t
[
6
].
cast
<
int
>
();
qa
.
stop_criteria
=
t
[
7
].
cast
<
std
::
vector
<
std
::
vector
<
int
>>>
();
return
qa
;
}));
py
::
class_
<
scheduler
::
Scheduler
,
std
::
shared_ptr
<
scheduler
::
Scheduler
>>
(
m
,
"Scheduler"
)
.
def
(
"init"
,
&
scheduler
::
Scheduler
::
init
)
.
def
(
"run"
,
&
scheduler
::
Scheduler
::
run
)
.
def
(
"stop"
,
&
scheduler
::
Scheduler
::
stop
)
.
def
(
"add_query"
,
&
scheduler
::
Scheduler
::
add_query
,
py
::
call_guard
<
py
::
gil_scoped_release
>
())
.
def
(
"cancel_query"
,
&
scheduler
::
Scheduler
::
cancel_query
,
py
::
call_guard
<
py
::
gil_scoped_release
>
())
.
def
(
"update_last_batch"
,
&
scheduler
::
Scheduler
::
update_last_batch
,
py
::
call_guard
<
py
::
gil_scoped_release
>
())
.
def
(
"get_inference_context"
,
&
scheduler
::
Scheduler
::
get_inference_context
);
m
.
def
(
"create_scheduler"
,
&
scheduler
::
create_scheduler
,
"Create a new Scheduler instance"
);
}
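The py::pickle definitions make ModelSettings, SampleOptions, BatchQueryTodo, QueryUpdate, and QueryAdd picklable, so these objects can cross Python process boundaries (e.g. with multiprocessing) instead of being usable only in-process. Note that ModelSettings' __getstate__ tuple carries only the six numeric fields; model_path is not serialized and does not survive a pickle round trip. The blocking entry points add_query, cancel_query, and update_last_batch are wrapped in py::gil_scoped_release so a waiting caller does not hold the GIL.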
csrc/balance_serve/sched/metrics.cpp (new file, mode 100644)
#include "metrics.h"
#include <iostream>
// 构造函数
Metrics
::
Metrics
(
const
MetricsConfig
&
config
)
:
registry_
(
std
::
make_shared
<
prometheus
::
Registry
>
()),
exposer_
(
config
.
endpoint
),
stop_uptime_thread_
(
false
),
start_time_
(
std
::
chrono
::
steady_clock
::
now
())
{
// 定义统一的桶大小,最大为 10000 ms (10 s)
std
::
vector
<
double
>
common_buckets
=
{
0.001
,
0.005
,
0.01
,
0.05
,
0.1
,
0.5
,
1.0
,
5.0
,
10.0
,
50.0
,
100.0
,
500.0
,
1000.0
,
5000.0
,
10000.0
};
// 毫秒
// 注册 TTFT_ms Histogram
auto
&
TTFT_family
=
prometheus
::
BuildHistogram
()
.
Name
(
std
::
string
(
METRIC_PREFIX
)
+
"_TTFT_ms"
)
.
Help
(
"Time to first token in milliseconds"
)
.
Register
(
*
registry_
);
TTFT_ms
=
&
TTFT_family
.
Add
({{
"model"
,
config
.
model_name
}},
common_buckets
);
// 注册 TBT_ms Histogram
auto
&
TBT_family
=
prometheus
::
BuildHistogram
()
.
Name
(
std
::
string
(
METRIC_PREFIX
)
+
"_TBT_ms"
)
.
Help
(
"Time between tokens in milliseconds"
)
.
Register
(
*
registry_
);
TBT_ms
=
&
TBT_family
.
Add
({{
"model"
,
config
.
model_name
}},
common_buckets
);
// 注册 schedule_time Histogram
auto
&
schedule_time_family
=
prometheus
::
BuildHistogram
()
.
Name
(
std
::
string
(
METRIC_PREFIX
)
+
"_schedule_time_ms"
)
.
Help
(
"Time to generate schedule in milliseconds"
)
.
Register
(
*
registry_
);
schedule_time
=
&
schedule_time_family
.
Add
({{
"model"
,
config
.
model_name
}},
common_buckets
);
// 注册 generated_tokens Counter
auto
&
generated_tokens_family
=
prometheus
::
BuildCounter
()
.
Name
(
std
::
string
(
METRIC_PREFIX
)
+
"_generated_tokens_total"
)
.
Help
(
"Total generated tokens"
)
.
Register
(
*
registry_
);
generated_tokens
=
&
generated_tokens_family
.
Add
({{
"model"
,
config
.
model_name
}});
// 注册 throughput_query Gauge
auto
&
throughput_query_family
=
prometheus
::
BuildGauge
()
.
Name
(
std
::
string
(
METRIC_PREFIX
)
+
"_throughput_query"
)
.
Help
(
"Throughput per second based on queries"
)
.
Register
(
*
registry_
);
throughput_query
=
&
throughput_query_family
.
Add
({{
"model"
,
config
.
model_name
}});
// 注册 throughput_generated_tokens Gauge
auto
&
throughput_generated_tokens_family
=
prometheus
::
BuildGauge
()
.
Name
(
std
::
string
(
METRIC_PREFIX
)
+
"_throughput_generated_tokens"
)
.
Help
(
"Throughput per second based on generated tokens"
)
.
Register
(
*
registry_
);
throughput_generated_tokens
=
&
throughput_generated_tokens_family
.
Add
({{
"model"
,
config
.
model_name
}});
// 注册 event_count Counter family
event_count_family_
=
&
prometheus
::
BuildCounter
()
.
Name
(
std
::
string
(
METRIC_PREFIX
)
+
"_event_count_total"
)
.
Help
(
"Count of various events"
)
.
Register
(
*
registry_
);
batch_count_family_
=
&
prometheus
::
BuildCounter
()
.
Name
(
std
::
string
(
METRIC_PREFIX
)
+
"_batch_count_total"
)
.
Help
(
"Count of various batch by status"
)
.
Register
(
*
registry_
);
// 注册 query_count Counter family
query_count_family_
=
&
prometheus
::
BuildCounter
()
.
Name
(
std
::
string
(
METRIC_PREFIX
)
+
"_query_count_total"
)
.
Help
(
"Count of queries by status"
)
.
Register
(
*
registry_
);
// 注册 uptime_ms Gauge
auto
&
uptime_family
=
prometheus
::
BuildGauge
()
.
Name
(
std
::
string
(
METRIC_PREFIX
)
+
"_uptime_ms"
)
.
Help
(
"Uptime of the scheduler in milliseconds"
)
.
Register
(
*
registry_
);
uptime_ms
=
&
uptime_family
.
Add
({{
"model"
,
config
.
model_name
}});
// 注册 GPU 利用率 Gauges
auto
&
gpu_util_family
=
prometheus
::
BuildGauge
()
.
Name
(
std
::
string
(
METRIC_PREFIX
)
+
"_gpu_utilization_ratio"
)
.
Help
(
"Current GPU utilization ratio (0 to 1)"
)
.
Register
(
*
registry_
);
for
(
size_t
i
=
0
;
i
<
config
.
gpu_count
;
++
i
)
{
gpu_utilization_gauges
.
push_back
(
&
gpu_util_family
.
Add
(
{{
"gpu_id"
,
std
::
to_string
(
i
)},
{
"model"
,
config
.
model_name
}}));
}
// 将 Registry 注册到 Exposer 中
exposer_
.
RegisterCollectable
(
registry_
);
// 启动 uptime 更新线程
StartUptimeUpdater
();
}
// 析构函数
Metrics
::~
Metrics
()
{
StopUptimeUpdater
();
}
// 启动 uptime 更新线程
void
Metrics
::
StartUptimeUpdater
()
{
uptime_thread_
=
std
::
thread
([
this
]()
{
while
(
!
stop_uptime_thread_
)
{
auto
now
=
std
::
chrono
::
steady_clock
::
now
();
std
::
chrono
::
duration
<
double
,
std
::
milli
>
uptime_duration
=
now
-
start_time_
;
uptime_ms
->
Set
(
uptime_duration
.
count
());
// fn_every_sec(this);
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
seconds
(
1
));
}
});
}
// 停止 uptime 更新线程
void
Metrics
::
StopUptimeUpdater
()
{
stop_uptime_thread_
=
true
;
if
(
uptime_thread_
.
joinable
())
{
uptime_thread_
.
join
();
}
}
// 获取 event_count 指标
prometheus
::
Counter
*
Metrics
::
event_count
(
const
std
::
string
&
type
)
{
return
&
event_count_family_
->
Add
({{
"type"
,
type
}});
// 可根据需要添加更多标签
}
// 获取 query_count 指标
prometheus
::
Counter
*
Metrics
::
query_count
(
const
std
::
string
&
status
)
{
return
&
query_count_family_
->
Add
(
{{
"status"
,
status
}});
// 可根据需要添加更多标签
}
prometheus
::
Counter
*
Metrics
::
batch_count
(
const
std
::
string
&
type
)
{
return
&
batch_count_family_
->
Add
({{
"type"
,
type
}});
}
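event_count, query_count, and batch_count call Family::Add on every lookup; in prometheus-cpp, Add with an already-registered label set returns the existing metric rather than creating a duplicate, so calling these helpers per event is safe, at the cost of a label-set lookup each time.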
csrc/balance_serve/sched/metrics.h (new file, mode 100644)
#ifndef Metrics_H
#define Metrics_H

#include <atomic>
#include <chrono>
#include <functional>
#include <memory>
#include <prometheus/counter.h>
#include <prometheus/exposer.h>
#include <prometheus/gauge.h>
#include <prometheus/histogram.h>
#include <prometheus/registry.h>
#include <string>
#include <thread>
#include <vector>

#include "timer.hpp"

// Metric name prefix
#define METRIC_PREFIX "scheduler"

class Metrics;

// Configuration struct
struct MetricsConfig {
  std::string endpoint;
  std::string model_name;  // model name, e.g. "gpt-4"
  size_t gpu_count;        // number of GPUs
};

// Metrics class: initializes Prometheus metrics from the given config
class Metrics {
 public:
  // The constructor takes a MetricsConfig
  Metrics(const MetricsConfig& config);
  ~Metrics();

  // Non-copyable, non-assignable
  Metrics(const Metrics&) = delete;
  Metrics& operator=(const Metrics&) = delete;

  std::function<void(Metrics*)> fn_every_sec;

  // Metric pointers
  prometheus::Gauge* uptime_ms;
  prometheus::Histogram* TTFT_ms;
  prometheus::Histogram* TBT_ms;
  prometheus::Histogram* schedule_time;
  prometheus::Gauge* throughput_query;
  prometheus::Gauge* throughput_generated_tokens;
  prometheus::Counter* generated_tokens;
  std::vector<prometheus::Gauge*> gpu_utilization_gauges;

  // Counter families
  prometheus::Counter* event_count(const std::string& type);
  prometheus::Counter* query_count(const std::string& status);
  prometheus::Counter* batch_count(const std::string& type);

 private:
  std::shared_ptr<prometheus::Registry> registry_;
  prometheus::Exposer exposer_;

  // Counter families
  prometheus::Family<prometheus::Counter>* event_count_family_;
  prometheus::Family<prometheus::Counter>* batch_count_family_;
  prometheus::Family<prometheus::Counter>* query_count_family_;

  // Thread and control flag for updating uptime_ms
  std::thread uptime_thread_;
  std::atomic<bool> stop_uptime_thread_;

  // Start the uptime-update thread
  void StartUptimeUpdater();
  // Stop the uptime-update thread
  void StopUptimeUpdater();

  // Program start time
  std::chrono::steady_clock::time_point start_time_;
};

struct HistogramTimerWrapper {
  prometheus::Histogram* histogram;
  Timer timer;

  inline HistogramTimerWrapper(prometheus::Histogram* histogram) : histogram(histogram), timer() { timer.start(); }
  inline ~HistogramTimerWrapper() { histogram->Observe(timer.elapsedMs()); }
};

#endif  // Metrics_H
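HistogramTimerWrapper is a scope guard: it starts its Timer on construction and, on destruction, records timer.elapsedMs() into the wrapped histogram. A minimal usage sketch (the function name here is hypothetical; scheduler.cpp uses the wrapper exactly this way around met->schedule_time):

void one_scheduling_pass(Metrics& met) {
  HistogramTimerWrapper t(met.schedule_time);  // timer starts here
  // ... the work being timed ...
}  // t's destructor calls met.schedule_time->Observe(elapsed ms)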
csrc/balance_serve/sched/model_config.h (new file, mode 100644)
#ifndef __MODEL_CONFIG_HPP_
#define __MODEL_CONFIG_HPP_

#include "nlohmann/json.hpp"

#include <filesystem>
#include <fstream>
#include <iostream>
#include <map>

using DimSize = size_t;
using URL = std::string;
using ModelName = std::string;

// We must ensure this can be loaded from config.json
class ModelConfig {
 public:
  DimSize hidden_size;
  DimSize intermediate_size;
  size_t max_position_embeddings;
  std::string model_type;
  size_t num_attention_heads;
  size_t num_hidden_layers;
  size_t num_key_value_heads;
  size_t vocab_size;

  NLOHMANN_DEFINE_TYPE_INTRUSIVE(ModelConfig, hidden_size, intermediate_size, max_position_embeddings, model_type,
                                 num_attention_heads, num_hidden_layers, num_key_value_heads, vocab_size);

  void load_from(std::filesystem::path path) {
    std::cout << "Load from " << path << std::endl;
    std::ifstream i(path);
    nlohmann::json j;
    i >> j;
    *this = j.get<ModelConfig>();
  }
};

using QuantType = std::string;
static const QuantType NoQuantType = "";

class QuantConfig {
 public:
  QuantType name;
  // For GEMV
  QuantType type_of_dot_vector = NoQuantType;
  inline bool can_be_used_as_matrix() { return type_of_dot_vector != NoQuantType; }

  bool can_be_used_as_vector;
  double bytes_per_element;
  bool has_scale;
  bool has_min;
  size_t block_element_count;
  size_t block_element_size;

  URL reference = "";

  NLOHMANN_DEFINE_TYPE_INTRUSIVE_WITH_DEFAULT(QuantConfig, name, type_of_dot_vector, can_be_used_as_vector,
                                              bytes_per_element, has_scale, has_min, block_element_count,
                                              block_element_size, reference);
};

inline std::map<QuantType, QuantConfig> quant_configs;
inline std::map<ModelName, ModelConfig> model_configs;

inline void load_quant_configs(std::filesystem::path path) {
  nlohmann::json j;
  if (std::filesystem::exists(path)) {
    std::cout << __FUNCTION__ << " from " << path << std::endl;
    std::ifstream i(path);
    i >> j;
    quant_configs = j.get<std::map<QuantType, QuantConfig>>();
    std::cout << "Loaded Quant Configs" << std::endl;
    for (auto& [k, v] : quant_configs) {
      std::cout << " - " << k << std::endl;
    }
  } else {
    std::cout << __FUNCTION__ << " no file at " << path << std::endl;
  }
}

inline void dump_quant_configs(std::filesystem::path path) {
  std::ofstream o(path);
  nlohmann::json j = quant_configs;
  o << j.dump(4);
}

inline void load_model_configs(std::filesystem::path path) {
  nlohmann::json j;
  if (std::filesystem::exists(path)) {
    std::cout << __FUNCTION__ << " from " << path << std::endl;
    std::ifstream i(path);
    i >> j;
    model_configs = j.get<std::map<ModelName, ModelConfig>>();
    std::cout << "Loaded Model Configs" << std::endl;
    for (auto& [k, v] : model_configs) {
      std::cout << " - " << k << std::endl;
    }
  } else {
    std::cout << __FUNCTION__ << " no file at " << path << std::endl;
  }
}

inline void dump_model_configs(std::filesystem::path path) {
  std::ofstream o(path);
  nlohmann::json j = model_configs;
  o << j.dump(4);
}

#endif
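Since ModelConfig mirrors a subset of a Hugging Face-style config.json via NLOHMANN_DEFINE_TYPE_INTRUSIVE, loading one is just a matter of pointing load_from at the model directory. A small sketch, with a placeholder model path:

#include <iostream>
#include "model_config.h"

int main() {
  ModelConfig cfg;
  // "/models/example" is an illustrative path, not one from this repo.
  cfg.load_from(std::filesystem::path("/models/example") / "config.json");
  std::cout << cfg.model_type << ": " << cfg.num_hidden_layers << " layers, vocab "
            << cfg.vocab_size << std::endl;
}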
csrc/balance_serve/sched/scheduler.cpp (new file, mode 100644)
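The file below follows a single-consumer event-loop design: the public Scheduler methods (add_query, cancel_query, update_last_batch) only enqueue events onto an MPSC queue, and one detached thread owns all mutable scheduling state (query_map, next_batch), which is why the per-query bookkeeping needs no locks.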
#define SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_INFO
#define FMT_HEADER_ONLY
#include "nlohmann/json.hpp"
#include "spdlog/spdlog.h"
#include "scheduler.h"
#include <optional>
#include "arithmetic.hpp"
#include "atomic_ptr_with_flags.hpp"
#include "easy_format.hpp"
#include "metrics.h"
#include "mpsc.hpp"
#include "timer.hpp"
#include <atomic>
#include <cassert>
#include <future>
#include <memory>
#include <queue>
#include "kvc2.h"
using
json
=
nlohmann
::
json
;
namespace
scheduler
{
void
Settings
::
auto_derive
()
{
gpu_device_count
=
gpu_device_id
.
size
();
if
(
torch
::
cuda
::
is_available
())
{
size_t
gpu_count
=
torch
::
cuda
::
device_count
();
SPDLOG_INFO
(
"Number of available GPUs: {}, want {}"
,
gpu_count
,
gpu_device_count
);
if
(
gpu_count
<
gpu_device_count
)
{
SPDLOG_ERROR
(
"Not enough GPUs available."
);
exit
(
0
);
}
for
(
size_t
i
=
0
;
i
<
gpu_device_count
;
i
++
)
{
devices
.
push_back
(
torch
::
Device
(
torch
::
kCUDA
,
gpu_device_id
[
i
]));
}
}
else
{
SPDLOG_ERROR
(
"CUDA is not available on this system."
);
exit
(
0
);
}
if
(
model_settings
.
num_k_heads
%
gpu_device_count
!=
0
)
{
SPDLOG_ERROR
(
"num_k_heads {} is not divisible by gpu_device_count {}"
,
model_settings
.
num_k_heads
,
gpu_device_count
);
assert
(
false
);
}
size_t
gpu_memory_available
=
gpu_memory_size
*
memory_utilization_percentage
;
if
(
gpu_memory_available
*
gpu_device_count
<
model_settings
.
params_nbytes
())
{
SPDLOG_ERROR
(
"GPU memory size {}G is smaller than {}G"
,
gpu_memory_available
*
gpu_device_count
/
1e9
,
model_settings
.
params_nbytes
()
/
1e9
);
assert
(
false
);
}
assert
(
model_settings
.
k_head_dim
%
model_settings
.
num_k_heads
==
0
);
size_t
head_per_gpu
=
model_settings
.
num_k_heads
/
gpu_device_count
;
size_t
gpu_memory_for_kv_cache
=
gpu_memory_available
/*- model_settings.params_nbytes() /
gpu_device_count*/
;
SPDLOG_INFO
(
"Each GPU Total: {}MiB, Model Params: {}MiB, KVCache: {}MiB, Left: {}MiB"
,
gpu_memory_size
/
(
1
<<
20
),
model_settings
.
params_nbytes
()
/
gpu_device_count
/
(
1
<<
20
),
gpu_memory_for_kv_cache
/
(
1
<<
20
),
(
gpu_memory_size
-
gpu_memory_available
)
/
(
1
<<
20
));
size_t
kv_cache_on_cnt
=
(
size_t
)(
k_cache_on
)
+
(
size_t
)(
v_cache_on
);
size_t
max_total_kvcache_pages
=
gpu_memory_for_kv_cache
/
(
kv_cache_on_cnt
*
head_per_gpu
*
model_settings
.
k_head_dim
*
model_settings
.
bytes_per_kv_cache_element
*
page_size
*
model_settings
.
layer_count
);
if
(
total_kvcache_pages
.
has_value
())
{
if
(
total_kvcache_pages
.
value
()
>
max_total_kvcache_pages
)
{
SPDLOG_ERROR
(
"total_kvcache_pages {} is larger than max_total_kvcache_pages {}"
,
total_kvcache_pages
.
value
(),
max_total_kvcache_pages
);
assert
(
false
);
}
}
else
{
total_kvcache_pages
=
max_total_kvcache_pages
;
SPDLOG_INFO
(
"total_kvcache_pages is auto derived as {}"
,
max_total_kvcache_pages
);
}
if
(
page_size
%
256
!=
0
)
{
SPDLOG_ERROR
(
"page_size {} is not divisible by 256"
,
page_size
);
assert
(
false
);
}
if
(
page_size
<
256
)
{
SPDLOG_ERROR
(
"page_size {} is smaller than 256"
,
page_size
);
assert
(
false
);
}
}
std
::
string
BatchQueryTodo
::
debug
()
{
std
::
string
re
=
"BatchQueryTodo: "
;
re
+=
"QueryIDs: "
;
for
(
auto
&
id
:
query_ids
)
{
re
+=
std
::
to_string
(
id
)
+
" "
;
}
return
re
;
}
bool
BatchQueryTodo
::
empty
()
{
return
prefill_mini_batches
.
empty
()
&&
decode_mini_batches
.
empty
();
}
struct
QueryMaintainer
;
struct
Query
{
QueryID
id
;
torch
::
Tensor
query_token
;
TokenLength
prompt_length
;
TokenLength
no_kvcache_from
;
TokenLength
estimated_length
;
SampleOptions
sample_options
;
UserID
user_id
;
std
::
optional
<
int
>
SLO_TTFT_ms
;
std
::
optional
<
int
>
SLO_TBT_ms
;
std
::
vector
<
std
::
vector
<
int
>>
stop_criteria
;
// status
// Query status changed by this order
enum
Status
{
Received
,
Preparing
,
Ready
,
Prefill
,
Decode
,
Done
};
Status
plan_status
=
Received
;
TokenLength
active_position
;
// the position where no kvcache now
TokenLength
plan_position
;
// the position where no kvcache now, in plan
size_t
prepare_try_count
=
0
;
std
::
shared_ptr
<
kvc2
::
DoubleCacheHandleInterface
>
kvc2_handle
=
nullptr
;
// derived from kvc2_handle
torch
::
Tensor
block_index
;
// block indexes
struct
QueryContext
{
ModelName
model_name
;
QuantType
quant_type
;
kvc2
::
KVC2Interface
*
kvc2_interface
;
QueryMaintainer
*
query_maintainer
;
Metrics
*
met
;
}
ctx
;
void
after_load
(
bool
ok
);
void
to_status
(
Status
to
);
void
export_metrics
()
{
ctx
.
met
->
query_count
(
status_to_string
(
plan_status
))
->
Increment
(
1
);
}
Query
(
QueryID
id
,
QueryAdd
query_add
,
QueryContext
context
)
:
id
(
id
),
prompt_length
(
query_add
.
query_length
),
no_kvcache_from
(
0
),
estimated_length
(
query_add
.
estimated_length
),
sample_options
(
query_add
.
sample_options
),
user_id
(
query_add
.
user_id
),
SLO_TTFT_ms
(
query_add
.
SLO_TTFT_ms
),
SLO_TBT_ms
(
query_add
.
SLO_TBT_ms
),
stop_criteria
(
query_add
.
stop_criteria
),
ctx
(
context
)
{
std
::
vector
<
int64_t
>
shape
=
{
int64_t
(
query_add
.
estimated_length
)};
query_token
=
torch
::
zeros
(
shape
,
torch
::
TensorOptions
().
dtype
(
torch
::
kInt32
));
assert
(
query_token
.
is_contiguous
());
if
(
query_token
.
is_contiguous
()
==
false
)
{
SPDLOG_ERROR
(
"Query Token must be contiguous!"
);
exit
(
1
);
}
memcpy
(
query_token
.
data_ptr
(),
query_add
.
query_token
.
data
(),
query_add
.
query_length
*
sizeof
(
Token
));
no_kvcache_from
=
0
;
// maybe match prefix later
export_metrics
();
}
Token
&
token_at
(
size_t
idx
)
{
return
reinterpret_cast
<
Token
*>
(
query_token
.
data_ptr
())[
idx
];
}
void
absorb_update
(
const
QueryUpdate
&
update
)
{
SPDLOG_DEBUG
(
"{}"
,
update
.
debug
());
active_position
=
update
.
active_position
;
kvc2_handle
->
append_tokens
(
&
token_at
(
0
),
active_position
);
// active_position is length -1
if
(
update
.
is_prefill
)
{
if
(
active_position
==
prompt_length
)
{
token_at
(
active_position
)
=
update
.
generated_token
;
ctx
.
met
->
generated_tokens
->
Increment
(
1
);
}
}
else
{
token_at
(
active_position
)
=
update
.
generated_token
;
ctx
.
met
->
generated_tokens
->
Increment
(
1
);
}
if
(
update
.
decode_done
||
active_position
==
estimated_length
-
1
)
{
to_status
(
Done
);
}
}
void
absorb_prefill_task
(
const
PrefillTask
&
task
)
{
auto
&
[
id
,
start
,
length
]
=
task
;
this
->
plan_position
=
start
+
length
;
if
(
this
->
plan_position
==
prompt_length
)
{
to_status
(
Decode
);
}
}
void
absorb_decode_task
([[
maybe_unused
]]
const
QueryID
&
task
)
{
this
->
plan_position
+=
1
;
}
PrefillTask
get_prefill_task
(
size_t
prefill_length
)
{
if
(
prefill_length
+
plan_position
>
prompt_length
)
{
prefill_length
=
prompt_length
-
plan_position
;
}
return
{
id
,
plan_position
,
prefill_length
};
}
static
std
::
string
status_to_string
(
Status
status
)
{
switch
(
status
)
{
case
Received
:
return
"Received"
;
case
Preparing
:
return
"Preparing"
;
case
Ready
:
return
"Ready"
;
case
Prefill
:
return
"Prefill"
;
case
Decode
:
return
"Decode"
;
case
Done
:
return
"Done"
;
}
assert
(
false
);
}
void
debug
()
{
std
::
string
status_string
=
status_to_string
(
plan_status
);
SPDLOG_DEBUG
(
"Query {}, prompt_length {}, estimated_length {}, plan status "
"{}, plan position {} "
"active position {}"
,
id
,
prompt_length
,
estimated_length
,
status_string
,
plan_position
,
active_position
);
}
};
std
::
string
QueryUpdate
::
debug
()
const
{
return
fmt
::
format
(
"Query {}, ok {}, is_prefill {}, done {}, active_position "
"{}, gen token {}"
,
id
,
ok
,
is_prefill
,
decode_done
,
active_position
,
generated_token
);
}
using
Q
=
std
::
shared_ptr
<
Query
>
;
struct
KVC2_Maintainer
{
Settings
settings
;
std
::
vector
<
torch
::
Tensor
>
k_cache
;
std
::
vector
<
torch
::
Tensor
>
v_cache
;
std
::
shared_ptr
<
kvc2
::
KVC2Interface
>
kvc2_interface
;
KVC2_Maintainer
(
Settings
settings
)
:
settings
(
settings
)
{
// SPDLOG_WARN("Creating KVC2 Instance {}", settings.kvc2_root_path);
assert
(
settings
.
kvc2_root_path
.
size
()
>
0
);
// SPDLOG_WARN("Sizeof KVC2Config {} upper", sizeof(kvc2::KVC2Config));
kvc2
::
GPUPageCacheConfig
gpu_cache_config
{
.
gpu_only
=
settings
.
gpu_only
,
.
gpu_devices_id
=
settings
.
gpu_device_id
,
.
layer_count
=
settings
.
model_settings
.
layer_count
,
.
total_kvcache_pages
=
settings
.
total_kvcache_pages
.
value
(),
.
num_token_per_page
=
settings
.
page_size
,
.
num_k_heads
=
settings
.
model_settings
.
num_k_heads
,
.
k_head_dim
=
settings
.
use_self_defined_head_dim
?
settings
.
self_defined_head_dim
:
settings
.
model_settings
.
k_head_dim
,
.
full_kv_cache_on_each_gpu
=
settings
.
full_kv_cache_on_each_gpu
,
.
k_cache_on
=
settings
.
k_cache_on
,
.
v_cache_on
=
settings
.
v_cache_on
,
.
tensor_type
=
torch
::
kBFloat16
,
};
auto
model_configs_path
=
std
::
filesystem
::
path
(
settings
.
kvc2_config_path
)
/
"model_configs.json"
;
load_model_configs
(
model_configs_path
);
auto
my_model_config
=
ModelConfig
();
my_model_config
.
load_from
(
std
::
filesystem
::
path
(
settings
.
model_settings
.
model_path
)
/
"config.json"
);
model_configs
[
settings
.
model_name
]
=
my_model_config
;
dump_model_configs
(
model_configs_path
);
kvc2
::
KVC2Config
kvc2_config
=
{
.
k_cache_on
=
settings
.
k_cache_on
,
.
v_cache_on
=
settings
.
v_cache_on
,
.
gpu_only
=
settings
.
gpu_only
,
.
load_from_disk
=
settings
.
load_from_disk
,
.
save_to_disk
=
settings
.
save_to_disk
,
.
path
=
settings
.
kvc2_root_path
,
.
config_path
=
settings
.
kvc2_config_path
,
.
num_token_per_page
=
settings
.
page_size
,
.
memory_pool_size
=
size_t
(
settings
.
memory_pool_size_GB
*
1e9
),
.
evict_count
=
settings
.
evict_count
,
.
gpu_cache_config
=
gpu_cache_config
,
.
metrics_port
=
settings
.
kvc2_metrics_port
,
};
kvc2_interface
=
kvc2
::
create_kvc2
(
kvc2_config
);
if
(
settings
.
load_from_disk
)
kvc2_interface
->
load
();
SPDLOG_DEBUG
(
"KVC2 created ok"
);
auto
[
k_cache
,
v_cache
]
=
kvc2_interface
->
get_kvcache
();
this
->
k_cache
=
k_cache
;
this
->
v_cache
=
v_cache
;
}
};
using
EventAddQuery
=
std
::
pair
<
QueryAdd
,
std
::
promise
<
QueryID
>
*>
;
using
EventUpdateQuery
=
BatchQueryUpdate
;
using
EventTakenBatch
=
std
::
shared_ptr
<
BatchQueryTodo
>
;
struct
EventPrepare
{
QueryID
query_id
;
bool
first_try
;
};
struct
EventPrepared
{
QueryID
query_id
;
bool
ok
;
};
struct
EventQueryStatus
{
QueryID
query_id
;
Query
::
Status
now_status
;
};
struct
EventSchedule
{};
using
Event
=
std
::
variant
<
EventAddQuery
,
EventUpdateQuery
,
EventTakenBatch
,
EventPrepare
,
EventPrepared
,
EventQueryStatus
,
EventSchedule
>
;
template
<
typename
T
>
std
::
string
event_name
(
const
T
&
event
);
template
<
>
std
::
string
event_name
(
const
EventAddQuery
&
)
{
return
"EventAddQuery"
;
}
template
<
>
std
::
string
event_name
(
const
EventUpdateQuery
&
)
{
return
"EventUpdateQuery"
;
}
template
<
>
std
::
string
event_name
(
const
EventTakenBatch
&
)
{
return
"EventTakenBatch"
;
}
template
<
>
std
::
string
event_name
(
const
EventPrepare
&
)
{
return
"EventPrepare"
;
}
template
<
>
std
::
string
event_name
(
const
EventPrepared
&
)
{
return
"EventPrepared"
;
}
template
<
>
std
::
string
event_name
(
const
EventQueryStatus
&
)
{
return
"EventQueryStatus"
;
}
template
<
>
std
::
string
event_name
(
const
EventSchedule
&
)
{
return
"EventSchedule"
;
}
// 用 std::visit 实现对 variant 的 event_name
std
::
string
event_name
(
const
Event
&
event
)
{
return
std
::
visit
([](
const
auto
&
e
)
{
return
event_name
(
e
);
},
event
);
}
static_assert
(
std
::
is_copy_constructible
<
Event
>::
value
);
static_assert
(
std
::
is_move_constructible
<
Event
>::
value
);
struct
QueryMaintainer
:
public
Scheduler
{
// only get access by event loop
Settings
settings
;
QueryID
query_id_counter
=
NoQueryID
+
1
;
std
::
map
<
QueryID
,
Q
>
query_map
;
std
::
shared_ptr
<
KVC2_Maintainer
>
kvc2_maintainer
;
std
::
shared_ptr
<
Metrics
>
met
;
// multi-thread visit
std
::
atomic_bool
stop_flag
=
false
;
// TODO consider correctness of event loop
MPSCQueueConsumerLock
<
Event
>
event_loop_queue
;
// std::binary_semaphore batch_ready{0};
AtomicPtrWithFlag
<
BatchQueryTodo
>
next_batch
;
QueryMaintainer
()
=
default
;
void
gen_batch_query_todo
(
BatchQueryTodo
*
re
,
const
std
::
set
<
Q
>
&
queries
)
{
std
::
vector
<
std
::
vector
<
QueryID
>>
d_batch
(
2
);
size_t
last_decode_batch
=
0
;
size_t
prefill_num
=
0
;
size_t
decode_num
=
0
;
size_t
preill_length
=
0
;
for
(
auto
&
q
:
queries
)
{
if
(
q
->
plan_status
==
Query
::
Prefill
)
{
prefill_num
+=
1
;
}
if
(
q
->
plan_status
==
Query
::
Decode
)
{
decode_num
+=
1
;
}
}
if
(
prefill_num
>=
2
||
(
prefill_num
==
1
&&
settings
.
max_batch_size
-
2
<
decode_num
))
{
preill_length
=
settings
.
recommended_chunk_prefill_token_count
;
}
else
{
preill_length
=
settings
.
recommended_chunk_prefill_token_count
*
2
;
}
for
(
auto
&
q
:
queries
)
{
re
->
query_ids
.
push_back
(
q
->
id
);
re
->
query_tokens
.
push_back
(
q
->
query_token
);
re
->
query_lengths
.
push_back
(
q
->
prompt_length
);
if
(
q
->
plan_status
==
Query
::
Prefill
)
{
re
->
prefill_mini_batches
.
push_back
(
q
->
get_prefill_task
(
preill_length
));
assert
(
re
->
prefill_mini_batches
.
size
()
<=
2
);
}
if
(
q
->
plan_status
==
Query
::
Decode
)
{
d_batch
[
last_decode_batch
].
push_back
(
q
->
id
);
// last_decode_batch = 1 - last_decode_batch;
if
(
d_batch
[
last_decode_batch
].
size
()
==
settings
.
max_batch_size
-
1
)
{
last_decode_batch
+=
1
;
assert
(
last_decode_batch
<
2
);
}
}
re
->
block_indexes
.
push_back
(
q
->
block_index
);
re
->
sample_options
.
push_back
(
q
->
sample_options
);
re
->
stop_criteria
.
push_back
(
q
->
stop_criteria
);
}
re
->
attn_masks
=
std
::
nullopt
;
re
->
rope_ranges
=
std
::
nullopt
;
for
(
auto
&
b
:
d_batch
)
{
if
(
b
.
empty
())
continue
;
re
->
decode_mini_batches
.
push_back
(
b
);
}
met
->
batch_count
(
"Generated"
)
->
Increment
(
1
);
}
// Interface
void
init
(
Settings
settings
)
override
{
SPDLOG_INFO
(
"
\n
Scheduler Settings:
\n
"
" model_name: {}
\n
"
" quant_type: {}
\n
"
" model_path: {}
\n
"
" params_count: {}
\n
"
" layer_count: {}
\n
"
" num_k_heads: {}
\n
"
" k_head_dim: {}
\n
"
" bytes_per_params: {}
\n
"
" bytes_per_kv_cache_element: {}
\n
"
" page_size: {}
\n
"
" gpu_device_id: {}
\n
"
" gpu_memory_size: {}
\n
"
" memory_utilization_percentage: {}
\n
"
" max_batch_size: {}
\n
"
" recommended_chunk_prefill_token_count: {}
\n
"
" sched_metrics_port: {}
\n
"
" kvc2_config_path: {}
\n
"
" kvc2_root_path: {}
\n
"
" memory_pool_size_GB: {}
\n
"
" evict_count: {}
\n
"
" kvc2_metrics_port: {}
\n
"
" load_from_disk: {}
\n
"
" save_to_disk: {}
\n
"
" strategy_name: {}
\n
"
" gpu_device_count: {}
\n
"
,
settings
.
model_name
,
settings
.
quant_type
,
settings
.
model_settings
.
model_path
,
settings
.
model_settings
.
params_count
,
settings
.
model_settings
.
layer_count
,
settings
.
model_settings
.
num_k_heads
,
settings
.
model_settings
.
k_head_dim
,
settings
.
model_settings
.
bytes_per_params
,
settings
.
model_settings
.
bytes_per_kv_cache_element
,
settings
.
page_size
,
format_vector
(
settings
.
gpu_device_id
),
readable_number
(
settings
.
gpu_memory_size
),
settings
.
memory_utilization_percentage
,
settings
.
max_batch_size
,
settings
.
recommended_chunk_prefill_token_count
,
settings
.
sched_metrics_port
,
settings
.
kvc2_config_path
,
settings
.
kvc2_root_path
,
settings
.
memory_pool_size_GB
,
settings
.
evict_count
,
settings
.
kvc2_metrics_port
,
settings
.
load_from_disk
,
settings
.
save_to_disk
,
settings
.
strategy_name
,
settings
.
gpu_device_count
);
this
->
settings
=
settings
;
kvc2_maintainer
=
std
::
shared_ptr
<
KVC2_Maintainer
>
(
new
KVC2_Maintainer
(
settings
));
MetricsConfig
met_conf
=
{
.
endpoint
=
"0.0.0.0:"
+
std
::
to_string
(
settings
.
sched_metrics_port
),
.
model_name
=
settings
.
model_name
,
.
gpu_count
=
settings
.
gpu_device_count
,
};
SPDLOG_INFO
(
"Creating scheduler metrics exporter on {}"
,
met_conf
.
endpoint
);
met
=
std
::
make_shared
<
Metrics
>
(
met_conf
);
met
->
fn_every_sec
=
[](
Metrics
*
met
)
{
auto
generated_tokens
=
met
->
generated_tokens
->
Collect
().
counter
.
value
;
SPDLOG_INFO
(
"Last Sec Generated Tokens {}"
,
generated_tokens
);
};
}
Query
::
QueryContext
get_query_context
()
{
return
Query
::
QueryContext
{
.
model_name
=
settings
.
model_name
,
.
quant_type
=
settings
.
quant_type
,
.
kvc2_interface
=
kvc2_maintainer
->
kvc2_interface
.
get
(),
.
query_maintainer
=
this
,
.
met
=
met
.
get
(),
};
}
QueryID
add_query
(
QueryAdd
query_add
)
override
{
std
::
promise
<
QueryID
>
p
;
event_loop_queue
.
enqueue
(
EventAddQuery
(
query_add
,
&
p
));
return
p
.
get_future
().
get
();
}
void
cancel_query
(
QueryID
id
)
override
{
SPDLOG_INFO
(
"Cancel Query"
);
SPDLOG_INFO
(
"sched:{} Cancel Query"
,
fmt
::
ptr
(
this
));
auto
it
=
query_map
.
find
(
id
);
if
(
it
==
query_map
.
end
())
{
SPDLOG_ERROR
(
"Query {} is not found"
,
id
);
return
;
}
query_map
.
erase
(
it
);
}
// Here this function update last batch results and get the next batch
// in most cases, the batch is ready,
// if not, busy wait to get it
std
::
shared_ptr
<
BatchQueryTodo
>
update_last_batch
(
BatchQueryUpdate
updates
)
override
{
event_loop_queue
.
enqueue
(
updates
);
// Busy Wait
while
(
true
)
{
auto
[
ptr
,
is_new
]
=
next_batch
.
touch_load
();
// SPDLOG_INFO("ptr {} is_new {}", fmt::ptr(ptr), is_new);
if
(
is_new
)
{
// SPDLOG_DEBUG("New Batch {}", fmt::ptr(ptr));
auto
re
=
std
::
shared_ptr
<
BatchQueryTodo
>
(
ptr
);
event_loop_queue
.
enqueue
(
re
);
return
re
;
}
else
{
// // here to busy wait
// SPDLOG_INFO("Not New");
// using namespace std::chrono_literals;
// std::this_thread::sleep_for(1s);
}
}
}
  InferenceContext get_inference_context() override {
    InferenceContext re;
    re.k_cache = kvc2_maintainer->k_cache;
    re.v_cache = kvc2_maintainer->v_cache;
    // kvc2_maintainer->k_cache[0][0][0][0][0][0] = 42; // test whether we pass this to the inference loop
    return re;
  }

  virtual void strategy_add_query(Q new_query) = 0;
  virtual void strategy_update_query(const EventUpdateQuery& update) = 0;
  virtual void strategy_taken_batch(const EventTakenBatch& batch) = 0;
  virtual void strategy_prepare(const EventPrepare& prepare) = 0;
  virtual void strategy_prepared(const EventPrepared& prepared) = 0;
  virtual void strategy_query_status(const EventQueryStatus& query_status) = 0;
  virtual void strategy_schedule(const EventSchedule& event, BatchQueryTodo* new_batch) = 0;

  void tackle_event(EventAddQuery& event) {
    auto& query_add = event.first;
    QueryID id = query_id_counter;
    event.second->set_value(id);
    query_id_counter += 1;
    Q new_query(new Query(id, query_add, get_query_context()));
    query_map[id] = new_query;
    SPDLOG_INFO("New Query {} is added", id);
    strategy_add_query(new_query);
  }

  void tackle_event(const EventUpdateQuery& update) {
    // SPDLOG_INFO("Tackle Update Query");
    for (auto& u : update) {
      if (u.ok == false) {
        SPDLOG_ERROR("Query {} is not executed OK", u.id);
        exit(1);
      }
      auto q = query_map[u.id];
      if (q->plan_status == Query::Status::Prefill || q->plan_status == Query::Status::Decode) {
        q->absorb_update(u);
      } else {
        SPDLOG_DEBUG("Query {} is not in Prefill or Decode status, do not update it", u.id);
      }
    }
    strategy_update_query(update);
  }

  void tackle_event(const EventTakenBatch& batch) {
    met->batch_count("Taken")->Increment(1);
    for (auto& task : batch->prefill_mini_batches) {
      auto [id, s, l] = task;
      if (l == 0)
        continue;
      query_map.at(id)->absorb_prefill_task(task);
    }
    for (auto& mini_batch : batch->decode_mini_batches) {
      for (auto& id : mini_batch) {
        query_map.at(id)->absorb_decode_task(id);
      }
    }
    strategy_taken_batch(batch);
  }

  void tackle_event(const EventPrepare& event) { strategy_prepare(event); }

  void tackle_event(const EventPrepared& event) { strategy_prepared(event); }

  void tackle_event(const EventQueryStatus& event) { strategy_query_status(event); }

  void tackle_event(const EventSchedule& event) {
    // SPDLOG_INFO("Tackle Schedule Event");
    HistogramTimerWrapper t(met->schedule_time);
    BatchQueryTodo* new_batch = new BatchQueryTodo;
    strategy_schedule(event, new_batch);
    // if (new_batch->query_ids.empty()) {
    //   SPDLOG_INFO("Nothing todo");
    //   delete new_batch;
    //   return;
    // }
    auto [old_batch, flag] = next_batch.exchange(new_batch, true);
    if (new_batch->empty() == false) {
      SPDLOG_DEBUG("set new batch {}", fmt::ptr(new_batch));
    }
    if (flag) {
      SPDLOG_INFO("Batch {} is not consumed", fmt::ptr(old_batch));
      delete old_batch;
    }
  }

  void run() override {
    std::thread([this]() {
      SPDLOG_WARN("Starting Scheduler Event Loop");
      while (stop_flag.load() == false) {
        auto event = event_loop_queue.dequeue();
        met->event_count(event_name(event))->Increment(1);
        std::visit(
            [this](auto event) {
              using T = std::decay_t<decltype(event)>;
              // SPDLOG_INFO("Event Loop: {}", typeid(T).name());
              if constexpr (std::is_same_v<T, EventAddQuery>) {
                tackle_event(event);
              } else if constexpr (std::is_same_v<T, EventUpdateQuery>) {
                tackle_event(event);
              } else if constexpr (std::is_same_v<T, EventTakenBatch>) {
                tackle_event(event);
              } else if constexpr (std::is_same_v<T, EventPrepare>) {
                tackle_event(event);
              } else if constexpr (std::is_same_v<T, EventPrepared>) {
                tackle_event(event);
              } else if constexpr (std::is_same_v<T, EventQueryStatus>) {
                tackle_event(event);
              } else if constexpr (std::is_same_v<T, EventSchedule>) {
                tackle_event(event);
              } else {
                SPDLOG_ERROR("Should not be here");
                assert(false);
              }
            },
            event);
        if (event_loop_queue.size() == 0 && std::holds_alternative<EventSchedule>(event) == false) {
          // if this is not a schedule event, we need to schedule one
          event_loop_queue.enqueue(EventSchedule());
        }
      }
    }).detach();
  }

  void stop() override { stop_flag.store(true); }

  ~QueryMaintainer() {
    kvc2_maintainer->kvc2_interface->save();
    stop();
  }
};

void Query::to_status(Status to) {
  SPDLOG_DEBUG("Calling to status query {}, to {}", id, status_to_string(to));
  switch (to) {
    case Received:
      assert(false);
      break;
    case Preparing:
      SPDLOG_INFO("Preparing Query {} {}", id,
                  prepare_try_count > 0 ? (std::to_string(prepare_try_count) + " Try") : "");
      prepare_try_count += 1;
      ctx.kvc2_interface->lookup_to_gpu_async(
          ctx.model_name, ctx.quant_type, static_cast<kvc2::Token*>(query_token.data_ptr()),
          prompt_length, estimated_length,
          [this](std::shared_ptr<kvc2::DoubleCacheHandleInterface> handle) {
            if (handle == nullptr) {
              SPDLOG_INFO("Get handle from kvc2 Failed.");
              this->after_load(false);
            } else {
              SPDLOG_INFO("Get handle from kvc2 Success.");
              this->kvc2_handle = handle;
              this->to_status(Ready);
              this->after_load(true);
            }
          });
      break;
    case Ready:
      SPDLOG_INFO("Ready Query {}", id);
      break;
    case Prefill:
      SPDLOG_INFO("Prefilling Query {}", id);
      // assert(plan_status == Received);
      plan_position = kvc2_handle->matched_length();
      if (prompt_length - plan_position == 0) {
        assert(prompt_length > 0);
        plan_position -= 1;
      }
      break;
    case Decode:
      SPDLOG_INFO("Decoding Query {}", id);
      // assert(plan_status == Prefill);
      break;
    case Done:
      SPDLOG_INFO("Finish Query {}", id);
      kvc2_handle = nullptr;
      ctx.query_maintainer->event_loop_queue.enqueue(EventQueryStatus{
          .query_id = id,
          .now_status = to,
      });
      // assert(plan_status == Decode);
      break;
  }
  plan_status = to;
  export_metrics();
}

void Query::after_load(bool ok) {
  if (ok) {
    size_t page_count = div_up(estimated_length, ctx.query_maintainer->settings.page_size);
    std::vector<int64_t> shape;
    shape.push_back(page_count);
    block_index = torch::zeros(shape, torch::TensorOptions().dtype(torch::kInt32)).contiguous();
    auto ptr = reinterpret_cast<int32_t*>(block_index.data_ptr());
    auto vec_idx = kvc2_handle->get_gpu_block_idx();
    for (size_t i = 0; i < vec_idx.size(); i++) {
      ptr[i] = vec_idx[i];
    }
    no_kvcache_from = kvc2_handle->matched_length();
  }
  if (ok) {
    ctx.query_maintainer->event_loop_queue.enqueue(EventPrepared{
        .query_id = id,
        .ok = ok,
    });
  } else {
    ctx.query_maintainer->event_loop_queue.enqueue(EventPrepare{
        .query_id = id,
        .first_try = false,
    });
  }
}

struct FCFS_single_prefill : public QueryMaintainer {
  std::queue<Q> queue;
  std::queue<Q> ready_queue;
  bool has_query_preparing = false;
  std::optional<EventPrepare> wait_done_prepare = std::nullopt;
  std::set<Q> active_query;  // ongoing queries for the LLM

  // Interface: all of these are executed in a single thread.
  void strategy_add_query(Q new_query) override {
    queue.push(new_query);
    if (has_query_preparing == false) {
      has_query_preparing = true;
      auto next_q = queue.front();
      queue.pop();
      event_loop_queue.enqueue(EventPrepare{next_q->id, true});
    }
  }

  void strategy_update_query(const EventUpdateQuery& update) override {
    for (auto u : update) {
      auto& q = query_map[u.id];
      if (q->plan_status == Query::Done) {
        active_query.erase(q);
      }
    }
  }

  void strategy_taken_batch(const EventTakenBatch& batch) override {
    for (auto& q : batch->query_ids) {
      if (query_map[q]->plan_status != Query::Done) {
        active_query.insert(query_map[q]);
      }
    }
  }

  void strategy_prepare(const EventPrepare& prepare) override {
    if (prepare.first_try) {
      auto& q = query_map[prepare.query_id];
      q->to_status(Query::Preparing);
    } else {
      assert(wait_done_prepare.has_value() == false);
      wait_done_prepare = prepare;
      wait_done_prepare->first_try = true;
    }
  }

  void strategy_prepared(const EventPrepared& prepared) override {
    assert(prepared.ok);
    ready_queue.push(query_map[prepared.query_id]);
    if (queue.empty() == false) {
      auto next_q_prepare = queue.front();
      queue.pop();
      event_loop_queue.enqueue(EventPrepare{next_q_prepare->id, true});
    } else {
      has_query_preparing = false;
    }
  }

  void strategy_query_status(const EventQueryStatus& query_status) override {
    if (query_status.now_status == Query::Done) {
      if (wait_done_prepare.has_value()) {
        event_loop_queue.enqueue(wait_done_prepare.value());
        wait_done_prepare = std::nullopt;
      }
    }
  }

  void strategy_schedule([[maybe_unused]] const EventSchedule& event, BatchQueryTodo* new_batch) override {
    bool have_prefill = false;
    for (auto& q : active_query) {
      if (q->plan_status == Query::Prefill) {
        have_prefill = true;
      }
    }
    if (have_prefill == false && ready_queue.empty() == false &&
        active_query.size() < settings.max_batch_size) {
      auto next_q = ready_queue.front();
      ready_queue.pop();
      SPDLOG_INFO("Active query {}", next_q->id);
      active_query.insert(next_q);
      next_q->to_status(Query::Prefill);
    }
    if (active_query.empty() == false)
      SPDLOG_INFO("Active Query Size {}", active_query.size());
    for (auto& q : active_query) {
      q->debug();
    }
    gen_batch_query_todo(new_batch, active_query);
  }
};

struct FCFS : public FCFS_single_prefill {
  void strategy_schedule([[maybe_unused]] const EventSchedule& event, BatchQueryTodo* new_batch) override {
    int prefill_count = 0;
    const int max_prefill_count = 2;
    for (auto& q : active_query) {
      if (q->plan_status == Query::Prefill) {
        prefill_count += 1;
      }
    }
    while (prefill_count < max_prefill_count && ready_queue.empty() == false &&
           active_query.size() < settings.max_batch_size) {
      auto next_q = ready_queue.front();
      ready_queue.pop();
      SPDLOG_INFO("Active query {}", next_q->id);
      active_query.insert(next_q);
      next_q->to_status(Query::Prefill);
      prefill_count += 1;
    }
    if (active_query.empty() == false) {
      SPDLOG_DEBUG("Active Query Size {}", active_query.size());
    }
    for (auto& q : active_query) {
      q->debug();
    }
    gen_batch_query_todo(new_batch, active_query);
  }
};

std::shared_ptr<Scheduler> create_scheduler(Settings settings) {
  spdlog::set_level(spdlog::level::debug);
  std::shared_ptr<Scheduler> re;
  SPDLOG_INFO("Using Strategy {}", settings.strategy_name);
  if (settings.strategy_name == "FCFS-single-prefill") {
    re = std::shared_ptr<Scheduler>(new FCFS_single_prefill());
  } else if (settings.strategy_name == "FCFS") {
    re = std::shared_ptr<Scheduler>(new FCFS());
  } else {
    SPDLOG_ERROR("Unknown strategy {}", settings.strategy_name);
  }
  re->init(settings);
  return re;
}

NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(SampleOptions, temperature, top_p);
NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(QueryAdd, query_token, query_length, estimated_length, sample_options,
                                   user_id, SLO_TTFT_ms, SLO_TBT_ms);

std::string QueryAdd::serialize() {
  json j = *this;
  return j.dump();
}

QueryAdd QueryAdd::deserialize(const std::string& input) {
  json j = json::parse(input);
  return j.get<QueryAdd>();
}
};  // namespace scheduler
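To make the interface above concrete, here is a minimal, hypothetical driver sketching how a webserver thread and the inference loop would call the scheduler. The Settings values and the run_model stub are illustrative assumptions, not part of this commit; a real Settings also needs the model and kvc2 fields filled in before auto_derive().

#include "scheduler.h"
#include <iostream>

using namespace scheduler;

// Hypothetical model step: consume a batch, report one QueryUpdate per query.
static BatchQueryUpdate run_model(std::shared_ptr<BatchQueryTodo> /*batch*/) {
  return {};  // a real implementation fills id/ok/active_position/generated_token
}

int main() {
  Settings settings;
  settings.strategy_name = "FCFS";  // or "FCFS-single-prefill"
  settings.max_batch_size = 64;     // illustrative value
  settings.auto_derive();

  auto sched = create_scheduler(settings);
  sched->run();  // starts the detached event-loop thread

  // Webserver side: add_query blocks on a promise until the event loop
  // assigns a QueryID.
  QueryAdd q;
  q.query_token = {1, 2, 3};
  q.query_length = 3;
  q.estimated_length = 128;
  QueryID id = sched->add_query(q);
  std::cout << "submitted query " << id << std::endl;

  // Inference side: each call reports the previous batch's results and
  // busy-waits for the next BatchQueryTodo.
  BatchQueryUpdate updates;  // empty on the first call
  for (int step = 0; step < 3; ++step) {
    auto batch = sched->update_last_batch(updates);
    updates = run_model(batch);
  }
  sched->stop();
}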
csrc/balance_serve/sched/scheduler.h
0 → 100644
#pragma once
#include "model_config.h"
#include <cstdint>
#include <memory>
#include <optional>
#include <torch/torch.h>
#include <vector>

namespace scheduler {

using Token = uint32_t;
using QueryID = uint64_t;
constexpr QueryID NoQueryID = 0;
using TokenLength = size_t;
using BatchID = uint64_t;
using PageCount = size_t;

struct ModelSettings {
  std::string model_path;
  size_t params_count;
  size_t layer_count;
  size_t num_k_heads;
  size_t k_head_dim;
  double bytes_per_params;
  double bytes_per_kv_cache_element;

  inline size_t params_nbytes() { return params_count * bytes_per_params; }
  inline size_t bytes_per_token_kv_cache() {
    return bytes_per_kv_cache_element * num_k_heads * k_head_dim;
  }
};

struct SampleOptions {
  double temperature = 1.0;
  double top_p = 1.0;
};

struct Settings {
  // Something is awkward here: kvc2 only uses model_name and quant_type to
  // get model infos.
  ModelName model_name;
  QuantType quant_type;
  ModelSettings model_settings;  // ignored by kvc2
  size_t page_size = 256;        // how many tokens in a page
  std::vector<size_t> gpu_device_id;
  size_t gpu_memory_size;  // memory size in bytes of each GPU
  double memory_utilization_percentage;
  size_t max_batch_size = 256;
  size_t recommended_chunk_prefill_token_count;
  SampleOptions sample_options;
  size_t sched_metrics_port;

  // for kvc2
  bool gpu_only;
  bool use_self_defined_head_dim = false;
  size_t self_defined_head_dim;
  bool full_kv_cache_on_each_gpu = false;
  bool k_cache_on = true;
  bool v_cache_on = true;
  std::string kvc2_config_path;
  std::string kvc2_root_path;
  double memory_pool_size_GB = 100;
  size_t evict_count = 20;
  size_t kvc2_metrics_port;
  bool load_from_disk = false;
  bool save_to_disk = false;

  // for strategy
  std::string strategy_name;

  // derived
  size_t gpu_device_count;
  std::optional<size_t> total_kvcache_pages;
  std::vector<torch::Device> devices;

  void auto_derive();
};

using PrefillTask = std::tuple<QueryID, TokenLength, TokenLength>;  // id, start, length

struct BatchQueryTodo {
  // query
  std::vector<QueryID> query_ids;
  std::vector<torch::Tensor> query_tokens;
  std::vector<TokenLength> query_lengths;
  std::vector<torch::Tensor> block_indexes;  // (max_num_blocks_per_seq), dtype torch.int32
  std::optional<torch::Tensor> attn_masks;
  std::optional<torch::Tensor> rope_ranges;
  std::vector<SampleOptions> sample_options;
  std::vector<std::vector<std::vector<int>>> stop_criteria;

  // Mini batches: two adjacent mini batches are executed together. The task
  // count must be <= 2 because of the FlashInfer attention kernel.
  std::vector<PrefillTask> prefill_mini_batches;          // a prefill mini batch has exactly 1 prefill
  std::vector<std::vector<QueryID>> decode_mini_batches;  // a decode mini batch has multiple decodes

  std::string debug();
  bool empty();
};

struct QueryUpdate {
  QueryID id;
  bool ok;
  bool is_prefill;
  bool decode_done;  // unused for now
  TokenLength active_position;  // the position that has no kvcache yet:
                                // kvcache[active_position] == None
  Token generated_token;

  std::string debug() const;
};

using BatchQueryUpdate = std::vector<QueryUpdate>;

struct InferenceContext {
  std::vector<torch::Tensor> k_cache;  // [gpu num] (layer_count, num blocks, page size, kheadnum, head_dim)
  std::vector<torch::Tensor> v_cache;
};

using UserID = int64_t;
constexpr UserID NoUser = -1;
const int MAX_SLO_TIME = 1e9;

struct QueryAdd {
  std::vector<Token> query_token;  // int here
  // torch::Tensor attn_mask;
  TokenLength query_length;
  TokenLength estimated_length;
  std::vector<std::vector<int>> stop_criteria;
  SampleOptions sample_options;
  UserID user_id;
  int SLO_TTFT_ms = MAX_SLO_TIME;
  int SLO_TBT_ms = MAX_SLO_TIME;

  std::string serialize();
  static QueryAdd deserialize(const std::string& input);
};

class Scheduler {
 public:
  virtual void init(Settings settings) = 0;
  virtual void run() = 0;
  virtual void stop() = 0;

  // the webserver calls these
  virtual QueryID add_query(QueryAdd query) = 0;
  virtual void cancel_query(QueryID id) = 0;

  // the inference loop calls these
  virtual std::shared_ptr<BatchQueryTodo> update_last_batch(BatchQueryUpdate updates) = 0;
  virtual InferenceContext get_inference_context() = 0;

  virtual ~Scheduler() = default;
};

std::shared_ptr<Scheduler> create_scheduler(Settings settings);

};  // namespace scheduler
\ No newline at end of file
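As a worked example of the two ModelSettings helpers (all numbers are illustrative, not taken from this commit): a 7B-parameter model at 2 bytes per parameter occupies params_nbytes() = 7e9 * 2 = 14 GB, and with 32 K-heads of dimension 128 at 2 bytes per element, bytes_per_token_kv_cache() = 2 * 32 * 128 = 8192 bytes per token (presumably per layer, since layer_count is tracked separately).

#include "scheduler.h"
#include <cstdio>

int main() {
  scheduler::ModelSettings m{};
  m.params_count = 7'000'000'000;  // illustrative
  m.bytes_per_params = 2.0;        // e.g. fp16 weights
  m.num_k_heads = 32;
  m.k_head_dim = 128;
  m.bytes_per_kv_cache_element = 2.0;
  std::printf("weights: %zu bytes\n", m.params_nbytes());                  // 14000000000
  std::printf("kv per token: %zu bytes\n", m.bytes_per_token_kv_cache());  // 8192
}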
csrc/balance_serve/sched/utils/all.hpp
0 → 100644
#pragma once
#include "readable_number.hpp"
#include "timer.hpp"
\ No newline at end of file
csrc/balance_serve/sched/utils/arithmetic.hpp
0 → 100644
#include <type_traits>

template <typename T, typename U>
T div_up(T x, U by) {
  static_assert(std::is_integral_v<T>);
  static_assert(std::is_integral_v<U>);
  return (x + by - 1) / by;
}
\ No newline at end of file
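For instance, the scheduler uses this in Query::after_load to size the page table: a request with estimated_length = 300 tokens and 256-token pages needs div_up(300, 256) = 2 pages. A compilable check:

#include "arithmetic.hpp"
#include <cassert>

int main() {
  assert(div_up(300, 256) == 2);  // a partial last page still needs a full page
  assert(div_up(512, 256) == 2);  // exact multiples don't over-allocate
  assert(div_up(1, 256) == 1);
  return 0;
}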
csrc/balance_serve/sched/utils/atomic_ptr_with_flags.hpp
0 → 100644
#include <atomic>

template <typename T>
struct AtomicPtrWithFlag {
  constexpr static uint64_t mask = 1ull << 63;
  std::atomic_uint64_t ptr = 0;

  std::pair<T*, bool> load(std::memory_order order = std::memory_order_seq_cst) {
    uint64_t val = ptr.load(order);
    return {reinterpret_cast<T*>(val & (~mask)), val & mask};
  }

  void store(T* p, bool flag, std::memory_order order = std::memory_order_seq_cst) {
    ptr.store(reinterpret_cast<uint64_t>(p) | (flag ? mask : 0), order);
  }

  std::pair<T*, bool> exchange(T* p, bool flag, std::memory_order order = std::memory_order_seq_cst) {
    uint64_t val = ptr.exchange(reinterpret_cast<uint64_t>(p) | (flag ? mask : 0), order);
    return {reinterpret_cast<T*>(val & (~mask)), val & mask};
  }

  std::pair<T*, bool> touch_load(std::memory_order order = std::memory_order_seq_cst) {
    uint64_t val = ptr.fetch_and(~mask, order);
    return {reinterpret_cast<T*>(val & (~mask)), val & mask};
  }

  bool check_flag(std::memory_order order = std::memory_order_seq_cst) {
    return ptr.load(order) & mask;
  }
};
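A minimal two-thread sketch of how this tagged pointer works as a "fresh value" mailbox, mirroring how the scheduler's next_batch hands batches to the inference loop; the int payload is illustrative:

#include "atomic_ptr_with_flags.hpp"
#include <cstdio>
#include <thread>

int main() {
  AtomicPtrWithFlag<int> mailbox;  // starts as {nullptr, false}

  std::thread producer([&] {
    auto [old_ptr, unconsumed] = mailbox.exchange(new int(42), true);
    if (unconsumed) delete old_ptr;  // the consumer never saw the old value
  });

  std::thread consumer([&] {
    while (true) {
      auto [ptr, is_new] = mailbox.touch_load();  // clears the freshness bit
      if (is_new) {
        std::printf("got %d\n", *ptr);
        delete ptr;
        break;
      }
    }
  });

  producer.join();
  consumer.join();
}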
csrc/balance_serve/sched/utils/csv.hpp
0 → 100644
#ifndef CSV_READER_HPP
#define CSV_READER_HPP

#include <fstream>
#include <iostream>
#include <mutex>
#include <sstream>
#include <stdexcept>
#include <string>
#include <thread>
#include <vector>

namespace csv {

/**
 * @brief Parses a CSV line into individual fields, handling quoted fields with
 * commas and newlines.
 *
 * @param line The CSV line to parse.
 * @return A vector of strings, each representing a field in the CSV line.
 */
inline std::vector<std::string> parse_csv_line(const std::string& line) {
  std::vector<std::string> result;
  std::string field;
  bool in_quotes = false;
  for (size_t i = 0; i < line.length(); ++i) {
    char c = line[i];
    if (c == '"') {
      // Handle double quotes inside quoted fields
      if (in_quotes && i + 1 < line.length() && line[i + 1] == '"') {
        field += '"';
        ++i;
      } else {
        in_quotes = !in_quotes;
      }
    } else if (c == ',' && !in_quotes) {
      result.push_back(field);
      field.clear();
    } else {
      field += c;
    }
  }
  result.push_back(field);
  return result;
}

/**
 * @brief Reads a CSV file and returns a vector of pairs containing column names
 * and their corresponding data vectors.
 *
 * This function reads the header to obtain column names and uses multithreading
 * to read and parse the CSV file in chunks.
 *
 * @param filename The path to the CSV file.
 * @return A vector of pairs, each containing a column name and a vector of data
 * for that column.
 */
inline std::vector<std::pair<std::string, std::vector<std::string>>> read_csv(const std::string& filename) {
  std::cout << "Reading CSV file: " << filename << std::endl;
  // Open the file
  std::ifstream file(filename);
  if (!file) {
    throw std::runtime_error("Cannot open file");
  }
  // Read the header line and parse column names
  std::string header_line;
  std::getline(file, header_line);
  std::vector<std::string> column_names = parse_csv_line(header_line);
  // Prepare the result vector with column names
  std::vector<std::pair<std::string, std::vector<std::string>>> result;
  for (const auto& name : column_names) {
    result.emplace_back(name, std::vector<std::string>());
  }
  // Read the rest of the file into a string buffer
  std::stringstream buffer;
  buffer << file.rdbuf();
  std::string content = buffer.str();
  // Determine the number of threads to use
  unsigned int num_threads = std::thread::hardware_concurrency();
  if (num_threads == 0)
    num_threads = 4;  // Default to 4 threads if hardware_concurrency returns 0
  // Calculate chunk start positions based on content size
  std::vector<size_t> chunk_starts;
  size_t content_size = content.size();
  size_t chunk_size = content_size / num_threads;
  chunk_starts.push_back(0);
  for (unsigned int i = 1; i < num_threads; ++i) {
    size_t pos = i * chunk_size;
    // Adjust position to the next newline character to ensure we start at the
    // beginning of a line
    while (pos < content_size && content[pos] != '\n') {
      ++pos;
    }
    if (pos < content_size) {
      ++pos;  // Skip the newline character
    }
    chunk_starts.push_back(pos);
  }
  chunk_starts.push_back(content_size);
  // Create threads to parse each chunk
  std::vector<std::vector<std::vector<std::string>>> thread_results(num_threads);
  std::vector<std::thread> threads;
  for (unsigned int i = 0; i < num_threads; ++i) {
    size_t start = chunk_starts[i];
    size_t end = chunk_starts[i + 1];
    threads.emplace_back([&content, start, end, &thread_results, i]() {
      std::vector<std::vector<std::string>> local_result;
      size_t pos = start;
      while (pos < end) {
        size_t next_pos = content.find('\n', pos);
        if (next_pos == std::string::npos || next_pos > end) {
          next_pos = end;
        }
        std::string line = content.substr(pos, next_pos - pos);
        if (!line.empty()) {
          local_result.push_back(parse_csv_line(line));
        }
        pos = next_pos + 1;
      }
      thread_results[i] = std::move(local_result);
    });
  }
  // Wait for all threads to finish
  for (auto& t : threads) {
    t.join();
  }
  // Combine the results from all threads into the final result
  for (const auto& local_result : thread_results) {
    for (const auto& row : local_result) {
      for (size_t i = 0; i < row.size(); ++i) {
        if (i < result.size()) {
          result[i].second.push_back(row[i]);
        }
      }
    }
  }
  return result;
}

/**
 * @brief Writes the CSV data into a file.
 *
 * @param filename The path to the output CSV file.
 * @param data A vector of pairs, each containing a column name and a vector of
 * data for that column.
 */
inline void write_csv(const std::string& filename,
                      const std::vector<std::pair<std::string, std::vector<std::string>>>& data) {
  std::cout << "Writing CSV file: " << filename << std::endl;
  // Open the file for writing
  std::ofstream file(filename);
  if (!file) {
    throw std::runtime_error("Cannot open file for writing");
  }
  // Check that all columns have the same number of rows
  if (data.empty()) {
    return;  // Nothing to write
  }
  size_t num_rows = data[0].second.size();
  for (const auto& column : data) {
    if (column.second.size() != num_rows) {
      throw std::runtime_error("All columns must have the same number of rows");
    }
  }
  // Write the header
  for (size_t i = 0; i < data.size(); ++i) {
    file << data[i].first;
    if (i != data.size() - 1) {
      file << ',';
    }
  }
  file << '\n';
  // Write the data rows
  for (size_t row = 0; row < num_rows; ++row) {
    for (size_t col = 0; col < data.size(); ++col) {
      const std::string& field = data[col].second[row];
      // Handle CSV escaping
      std::string escaped_field = field;
      bool needs_quotes = false;
      if (escaped_field.find('"') != std::string::npos) {
        needs_quotes = true;
        // Escape double quotes by doubling them
        size_t pos = 0;
        while ((pos = escaped_field.find('"', pos)) != std::string::npos) {
          escaped_field.insert(pos, "\"");
          pos += 2;
        }
      }
      if (escaped_field.find(',') != std::string::npos || escaped_field.find('\n') != std::string::npos) {
        needs_quotes = true;
      }
      if (needs_quotes) {
        file << '"' << escaped_field << '"';
      } else {
        file << escaped_field;
      }
      if (col != data.size() - 1) {
        file << ',';
      }
    }
    file << '\n';
  }
}

}  // namespace csv

#endif  // CSV_READER_HPP
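A round-trip sketch of the two helpers; the file name demo.csv is illustrative:

#include "csv.hpp"

int main() {
  // Two columns, including fields that need quoting/escaping.
  std::vector<std::pair<std::string, std::vector<std::string>>> data = {
      {"name", {"alice", "quote\"y"}},
      {"note", {"plain", "has, comma"}},
  };
  csv::write_csv("demo.csv", data);       // quotes and escapes the tricky fields
  auto back = csv::read_csv("demo.csv");  // parses them back, chunked by thread
  std::cout << back[1].second[1] << std::endl;  // prints: has, comma
  return 0;
}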
csrc/balance_serve/sched/utils/easy_format.hpp
0 → 100644
#include <sstream>
#include <string>
#include <vector>

template <typename T>
std::string format_vector(const std::vector<T>& v) {
  std::ostringstream oss;
  if (v.empty())
    return "[]";
  for (size_t i = 0; i < v.size(); ++i) {
    oss << v[i];
    if (i < v.size() - 1)
      oss << ", ";  // comma-separated
  }
  return oss.str();
}
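For example (note that only the empty case gets brackets):

#include "easy_format.hpp"
#include <iostream>

int main() {
  std::vector<int> v{1, 2, 3};
  std::cout << format_vector(v) << std::endl;                   // prints: 1, 2, 3
  std::cout << format_vector(std::vector<int>{}) << std::endl;  // prints: []
}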
csrc/balance_serve/sched/utils/mpsc.hpp
0 → 100644
#include <atomic>
#include <cassert>
#include <iostream>
#include <optional>
#include <semaphore>

template <typename T>
class MPSCQueue {
  struct Node {
    T data;
    std::atomic<Node*> next;
    Node() : next(nullptr) {}
    Node(T data_) : data(std::move(data_)), next(nullptr) {}
  };

  std::atomic<Node*> head;
  Node* tail;

 public:
  std::atomic_size_t enqueue_count = 0;
  size_t dequeue_count = 0;

  MPSCQueue() {
    Node* dummy = new Node();
    head.store(dummy, std::memory_order_seq_cst);
    tail = dummy;
  }

  ~MPSCQueue() {
    Node* node = tail;
    while (node) {
      Node* next = node->next.load(std::memory_order_seq_cst);
      delete node;
      node = next;
    }
  }

  // called by producers
  void enqueue(T data) {
    enqueue_count.fetch_add(1);
    Node* node = new Node(std::move(data));
    Node* prev_head = head.exchange(node, std::memory_order_seq_cst);
    prev_head->next.store(node, std::memory_order_seq_cst);
  }

  // called by the single consumer
  std::optional<T> dequeue() {
    Node* next = tail->next.load(std::memory_order_seq_cst);
    if (next) {
      T res = std::move(next->data);
      delete tail;
      tail = next;
      dequeue_count += 1;
      return res;
    }
    return std::nullopt;
  }

  size_t size() { return enqueue_count.load() - dequeue_count; }
};

template <typename T>
class MPSCQueueConsumerLock {
  MPSCQueue<T> queue;
  std::counting_semaphore<> sema{0};

 public:
  void enqueue(T data) {
    queue.enqueue(std::move(data));
    // std::atomic_thread_fence(std::memory_order_seq_cst); // Inserting this
    // because the memory order might be wrong, I am also not that sure about
    // this.
    sema.release();
  }

  T dequeue() {
    auto re = queue.dequeue();
    if (re.has_value()) {
      while (sema.try_acquire() == false) {
        std::cerr << __FILE__ << ":" << __FUNCTION__
                  << " sema try acquire should be success, retrying, please check" << std::endl;
        // assert(false);
      }
      return re.value();
    }
    sema.acquire();
    return queue.dequeue().value();
  }

  template <typename Rep, typename Period>
  std::optional<T> try_dequeue_for(std::chrono::duration<Rep, Period> dur) {
    auto re = queue.dequeue();
    if (re.has_value()) {
      while (sema.try_acquire() == false) {
        std::cerr << __FILE__ << ":" << __FUNCTION__
                  << " sema try acquire should be success, retrying, please check" << std::endl;
        // assert(false);
      }
      return re.value();
    }
    if (sema.try_acquire_for(dur)) {
      return queue.dequeue().value();
    } else {
      return std::nullopt;
    }
  }

  size_t size() { return queue.size(); }
};
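A small sketch of the blocking variant with several producers and a single consumer; the counts are illustrative:

#include "mpsc.hpp"
#include <cstdio>
#include <thread>
#include <vector>

int main() {
  MPSCQueueConsumerLock<int> q;

  std::vector<std::thread> producers;
  for (int p = 0; p < 4; ++p) {
    producers.emplace_back([&q, p] {
      for (int i = 0; i < 100; ++i) q.enqueue(p * 100 + i);
    });
  }

  long long sum = 0;
  for (int i = 0; i < 400; ++i) {
    sum += q.dequeue();  // blocks on the semaphore when the queue is empty
  }
  for (auto& t : producers) t.join();
  std::printf("sum = %lld\n", sum);  // the values 0..399 sum to 79800
}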
csrc/balance_serve/sched/utils/readable_number.hpp
0 → 100644
#pragma once
#include <array>
#include <iomanip>
#include <sstream>
#include <string>

inline std::array<std::string, 7> units = {"", "K", "M", "G", "T", "P", "E"};

inline std::string readable_number(size_t size) {
  size_t unit_index = 0;
  double readable_size = size;
  while (readable_size >= 1000 && unit_index < units.size() - 1) {
    readable_size /= 1000;
    unit_index++;
  }
  std::ostringstream ss;
  ss << std::fixed << std::setprecision(2) << readable_size;
  std::string str = ss.str();
  return str + units[unit_index];
}
\ No newline at end of file
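For instance (note the units are decimal, scaling by 1000 rather than 1024):

#include "readable_number.hpp"
#include <iostream>

int main() {
  std::cout << readable_number(987) << "\n";            // 987.00
  std::cout << readable_number(1234) << "\n";           // 1.23K
  std::cout << readable_number(5'000'000'000) << "\n";  // 5.00G
}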