Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ox696c
ktransformers
Commits
64de7843
Commit
64de7843
authored
Apr 08, 2025
by
qiyuxinlin
Browse files
format kvc2, delete quant_configs, move model_configs to ~/.ktransformers
parent
9dd24ecd
Changes
31
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
744 additions
and
625 deletions
+744
-625
csrc/balance_serve/kvc2/src/async_store.cpp
csrc/balance_serve/kvc2/src/async_store.cpp
+9
-9
csrc/balance_serve/kvc2/src/gpu_cache.cpp
csrc/balance_serve/kvc2/src/gpu_cache.cpp
+5
-5
csrc/balance_serve/kvc2/src/metrics.h
csrc/balance_serve/kvc2/src/metrics.h
+5
-5
csrc/balance_serve/kvc2/src/model_config.h
csrc/balance_serve/kvc2/src/model_config.h
+38
-22
csrc/balance_serve/kvc2/src/page_aligned_memory_pool.cpp
csrc/balance_serve/kvc2/src/page_aligned_memory_pool.cpp
+11
-9
csrc/balance_serve/kvc2/src/page_aligned_memory_pool.h
csrc/balance_serve/kvc2/src/page_aligned_memory_pool.h
+8
-7
csrc/balance_serve/kvc2/src/prefix.cpp
csrc/balance_serve/kvc2/src/prefix.cpp
+11
-13
csrc/balance_serve/kvc2/test/kvc2test/common.hpp
csrc/balance_serve/kvc2/test/kvc2test/common.hpp
+4
-4
csrc/balance_serve/kvc2/test/kvc2test/flush-back.cpp
csrc/balance_serve/kvc2/test/kvc2test/flush-back.cpp
+2
-2
csrc/balance_serve/kvc2/test/kvc2test/lookup-gpu-mt-without-vcache.cpp
...serve/kvc2/test/kvc2test/lookup-gpu-mt-without-vcache.cpp
+1
-1
csrc/balance_serve/kvc2/test/kvc2test/lookup-without-vcache.cpp
...alance_serve/kvc2/test/kvc2test/lookup-without-vcache.cpp
+1
-2
csrc/balance_serve/kvc2/test/page_pool_test.cpp
csrc/balance_serve/kvc2/test/page_pool_test.cpp
+7
-9
csrc/balance_serve/kvc2/test/test_periodic_task.cpp
csrc/balance_serve/kvc2/test/test_periodic_task.cpp
+97
-105
csrc/balance_serve/sched/bind.cpp
csrc/balance_serve/sched/bind.cpp
+105
-64
csrc/balance_serve/sched/metrics.cpp
csrc/balance_serve/sched/metrics.cpp
+68
-56
csrc/balance_serve/sched/metrics.h
csrc/balance_serve/sched/metrics.h
+31
-28
csrc/balance_serve/sched/model_config.h
csrc/balance_serve/sched/model_config.h
+28
-22
csrc/balance_serve/sched/scheduler.cpp
csrc/balance_serve/sched/scheduler.cpp
+283
-239
csrc/balance_serve/sched/scheduler.h
csrc/balance_serve/sched/scheduler.h
+29
-21
csrc/balance_serve/sched/utils/arithmetic.hpp
csrc/balance_serve/sched/utils/arithmetic.hpp
+1
-2
No files found.
csrc/balance_serve/kvc2/src/async_store.cpp
View file @
64de7843
...
@@ -35,23 +35,23 @@ struct ArrayStore {
...
@@ -35,23 +35,23 @@ struct ArrayStore {
if
(
to
<=
size
)
{
if
(
to
<=
size
)
{
return
;
return
;
}
}
//TODO: extend file
//
TODO: extend file
size
=
to
;
size
=
to
;
//LOG_INFO("Extend file to `, size `", to, size_in_bytes());
//
LOG_INFO("Extend file to `, size `", to, size_in_bytes());
}
}
ArrayStore
(
size_t
element_size
,
size_t
size
,
std
::
filesystem
::
path
data_path
)
ArrayStore
(
size_t
element_size
,
size_t
size
,
std
::
filesystem
::
path
data_path
)
:
element_size
(
element_size
),
:
element_size
(
element_size
),
element_size_aligned
((
element_size
+
DeviceBlockSize
-
1
)
/
DeviceBlockSize
),
element_size_aligned
((
element_size
+
DeviceBlockSize
-
1
)
/
DeviceBlockSize
),
data_path
(
data_path
)
{
data_path
(
data_path
)
{
//TODO: prefix cache
//
TODO: prefix cache
}
}
void
read
(
size_t
index
,
void
*
buffer
)
{
void
read
(
size_t
index
,
void
*
buffer
)
{
//TODO: read from file
//
TODO: read from file
}
}
void
write
(
size_t
index
,
void
*
buffer
)
{
void
write
(
size_t
index
,
void
*
buffer
)
{
//TODO: write to file
//
TODO: write to file
}
}
};
};
...
@@ -98,15 +98,15 @@ struct IODealerImpl {
...
@@ -98,15 +98,15 @@ struct IODealerImpl {
IODealerImpl
(
bool
use_io_uring
,
int
IO_DEPTH
)
:
use_io_uring
(
use_io_uring
),
IO_DEPTH
(
IO_DEPTH
)
{}
IODealerImpl
(
bool
use_io_uring
,
int
IO_DEPTH
)
:
use_io_uring
(
use_io_uring
),
IO_DEPTH
(
IO_DEPTH
)
{}
void
queue_consumer
()
{
void
queue_consumer
()
{
//TODO:
//
TODO:
}
}
void
io_perf
()
{
void
io_perf
()
{
//TODO:
//
TODO:
}
}
void
io_dealer
()
{
void
io_dealer
()
{
//TODO:
//
TODO:
}
}
};
};
...
@@ -130,7 +130,7 @@ void IODealer::stop() {
...
@@ -130,7 +130,7 @@ void IODealer::stop() {
if
(
io_impl
->
stop
)
{
if
(
io_impl
->
stop
)
{
return
;
return
;
}
}
//LOG_INFO("Stopping IO Dealer");
//
LOG_INFO("Stopping IO Dealer");
io_impl
->
stop
=
true
;
io_impl
->
stop
=
true
;
}
}
...
...
csrc/balance_serve/kvc2/src/gpu_cache.cpp
View file @
64de7843
...
@@ -77,7 +77,6 @@ GPUPageCache::GPUPageCache(GPUPageCacheConfig& config) : config(config) {
...
@@ -77,7 +77,6 @@ GPUPageCache::GPUPageCache(GPUPageCacheConfig& config) : config(config) {
gpu_only_occupations
.
resize
(
config
.
total_kvcache_pages
,
false
);
gpu_only_occupations
.
resize
(
config
.
total_kvcache_pages
,
false
);
}
}
num_free_pages
=
config
.
total_kvcache_pages
;
num_free_pages
=
config
.
total_kvcache_pages
;
for
(
size_t
i
=
0
;
i
<
config
.
layer_count
;
i
++
)
{
for
(
size_t
i
=
0
;
i
<
config
.
layer_count
;
i
++
)
{
if
(
config
.
k_cache_on
)
if
(
config
.
k_cache_on
)
...
@@ -248,18 +247,19 @@ void GPUPageCache::append_col_to_request(std::vector<std::shared_ptr<CudaStreamM
...
@@ -248,18 +247,19 @@ void GPUPageCache::append_col_to_request(std::vector<std::shared_ptr<CudaStreamM
auto
gpu_block_idx
=
k_handles
[
0
][
at
]
->
gpu_block_idx
.
value
();
auto
gpu_block_idx
=
k_handles
[
0
][
at
]
->
gpu_block_idx
.
value
();
for
(
size_t
layer
=
0
;
layer
<
config
.
layer_count
;
layer
++
)
{
for
(
size_t
layer
=
0
;
layer
<
config
.
layer_count
;
layer
++
)
{
for
(
size_t
which_gpu
=
0
;
which_gpu
<
config
.
gpu_devices_id
.
size
();
which_gpu
++
)
{
for
(
size_t
which_gpu
=
0
;
which_gpu
<
config
.
gpu_devices_id
.
size
();
which_gpu
++
)
{
if
(
config
.
k_cache_on
)
{
if
(
config
.
k_cache_on
)
{
assert
(
k_handles
[
layer
][
at
]
->
data
!=
nullptr
);
assert
(
k_handles
[
layer
][
at
]
->
data
!=
nullptr
);
reqs
[
which_gpu
]
->
sizes
.
push_back
(
tp_size
[
which_gpu
]);
reqs
[
which_gpu
]
->
sizes
.
push_back
(
tp_size
[
which_gpu
]);
reqs
[
which_gpu
]
->
host_mem_addresses
.
push_back
(
offset_by_bytes
(
k_handles
[
layer
][
at
]
->
data
,
tp_offset
[
which_gpu
]));
reqs
[
which_gpu
]
->
host_mem_addresses
.
push_back
(
offset_by_bytes
(
k_handles
[
layer
][
at
]
->
data
,
tp_offset
[
which_gpu
]));
reqs
[
which_gpu
]
->
device_mem_addresses
.
push_back
(
k_cache
[
which_gpu
][
layer
][
gpu_block_idx
].
data_ptr
());
reqs
[
which_gpu
]
->
device_mem_addresses
.
push_back
(
k_cache
[
which_gpu
][
layer
][
gpu_block_idx
].
data_ptr
());
}
}
if
(
config
.
v_cache_on
)
{
if
(
config
.
v_cache_on
)
{
assert
(
v_handles
[
layer
][
at
]
->
data
!=
nullptr
);
assert
(
v_handles
[
layer
][
at
]
->
data
!=
nullptr
);
reqs
[
which_gpu
]
->
sizes
.
push_back
(
tp_size
[
which_gpu
]);
reqs
[
which_gpu
]
->
sizes
.
push_back
(
tp_size
[
which_gpu
]);
reqs
[
which_gpu
]
->
host_mem_addresses
.
push_back
(
offset_by_bytes
(
v_handles
[
layer
][
at
]
->
data
,
tp_offset
[
which_gpu
]));
reqs
[
which_gpu
]
->
host_mem_addresses
.
push_back
(
offset_by_bytes
(
v_handles
[
layer
][
at
]
->
data
,
tp_offset
[
which_gpu
]));
reqs
[
which_gpu
]
->
device_mem_addresses
.
push_back
(
v_cache
[
which_gpu
][
layer
][
gpu_block_idx
].
data_ptr
());
reqs
[
which_gpu
]
->
device_mem_addresses
.
push_back
(
v_cache
[
which_gpu
][
layer
][
gpu_block_idx
].
data_ptr
());
}
}
}
}
...
...
csrc/balance_serve/kvc2/src/metrics.h
View file @
64de7843
#pragma once
#pragma once
#include "prometheus/counter.h"
#include "prometheus/exposer.h"
#include "prometheus/gauge.h"
#include "prometheus/histogram.h"
#include "prometheus/registry.h"
#include <atomic>
#include <atomic>
#include <chrono>
#include <chrono>
#include <memory>
#include <memory>
#include <string>
#include <string>
#include <thread>
#include <thread>
#include <vector>
#include <vector>
#include "prometheus/counter.h"
#include "prometheus/exposer.h"
#include "prometheus/gauge.h"
#include "prometheus/histogram.h"
#include "prometheus/registry.h"
#include "utils/timer.hpp"
#include "utils/timer.hpp"
...
...
csrc/balance_serve/kvc2/src/model_config.h
View file @
64de7843
#ifndef __MODEL_CONFIG_HPP_
#ifndef __MODEL_CONFIG_HPP_
#define __MODEL_CONFIG_HPP_
#define __MODEL_CONFIG_HPP_
#include <iostream>
#include "nlohmann/json.hpp"
#include "nlohmann/json.hpp"
#include <iostream>
#include <filesystem>
#include <filesystem>
#include <fstream>
#include <fstream>
...
@@ -13,7 +13,7 @@ using ModelName = std::string;
...
@@ -13,7 +13,7 @@ using ModelName = std::string;
// We must assure this can be load by config.json
// We must assure this can be load by config.json
class
ModelConfig
{
class
ModelConfig
{
public:
public:
DimSize
hidden_size
;
DimSize
hidden_size
;
DimSize
intermediate_size
;
DimSize
intermediate_size
;
size_t
max_position_embeddings
;
size_t
max_position_embeddings
;
...
@@ -23,10 +23,13 @@ class ModelConfig {
...
@@ -23,10 +23,13 @@ class ModelConfig {
size_t
num_key_value_heads
;
size_t
num_key_value_heads
;
size_t
vocab_size
;
size_t
vocab_size
;
NLOHMANN_DEFINE_TYPE_INTRUSIVE
(
ModelConfig
,
hidden_size
,
intermediate_size
,
max_position_embeddings
,
model_type
,
NLOHMANN_DEFINE_TYPE_INTRUSIVE
(
ModelConfig
,
hidden_size
,
intermediate_size
,
num_attention_heads
,
num_hidden_layers
,
num_key_value_heads
,
vocab_size
);
max_position_embeddings
,
model_type
,
num_attention_heads
,
num_hidden_layers
,
num_key_value_heads
,
vocab_size
);
void
load_from
(
std
::
filesystem
::
path
path
)
{
void
load_from
(
std
::
filesystem
::
path
path
)
{
std
::
cout
<<
"Load from "
<<
path
<<
std
::
endl
;
std
::
ifstream
i
(
path
);
std
::
ifstream
i
(
path
);
nlohmann
::
json
j
;
nlohmann
::
json
j
;
i
>>
j
;
i
>>
j
;
...
@@ -38,12 +41,14 @@ using QuantType = std::string;
...
@@ -38,12 +41,14 @@ using QuantType = std::string;
static
const
QuantType
NoQuantType
=
""
;
static
const
QuantType
NoQuantType
=
""
;
class
QuantConfig
{
class
QuantConfig
{
public:
public:
QuantType
name
;
QuantType
name
;
// For GEMV
// For GEMV
QuantType
type_of_dot_vector
=
NoQuantType
;
QuantType
type_of_dot_vector
=
NoQuantType
;
inline
bool
can_be_used_as_matrix
()
{
return
type_of_dot_vector
!=
NoQuantType
;
}
inline
bool
can_be_used_as_matrix
()
{
return
type_of_dot_vector
!=
NoQuantType
;
}
bool
can_be_used_as_vector
;
bool
can_be_used_as_vector
;
...
@@ -56,8 +61,11 @@ class QuantConfig {
...
@@ -56,8 +61,11 @@ class QuantConfig {
URL
reference
=
""
;
URL
reference
=
""
;
NLOHMANN_DEFINE_TYPE_INTRUSIVE_WITH_DEFAULT
(
QuantConfig
,
name
,
type_of_dot_vector
,
can_be_used_as_vector
,
NLOHMANN_DEFINE_TYPE_INTRUSIVE_WITH_DEFAULT
(
QuantConfig
,
name
,
bytes_per_element
,
has_scale
,
has_min
,
block_element_count
,
type_of_dot_vector
,
can_be_used_as_vector
,
bytes_per_element
,
has_scale
,
has_min
,
block_element_count
,
block_element_size
,
reference
);
block_element_size
,
reference
);
};
};
...
@@ -65,14 +73,18 @@ inline std::map<QuantType, QuantConfig> quant_configs;
...
@@ -65,14 +73,18 @@ inline std::map<QuantType, QuantConfig> quant_configs;
inline
std
::
map
<
ModelName
,
ModelConfig
>
model_configs
;
inline
std
::
map
<
ModelName
,
ModelConfig
>
model_configs
;
inline
void
load_quant_configs
(
std
::
filesystem
::
path
path
)
{
inline
void
load_quant_configs
(
std
::
filesystem
::
path
path
)
{
std
::
cout
<<
__FUNCTION__
<<
" from "
<<
path
<<
std
::
endl
;
std
::
ifstream
i
(
path
);
nlohmann
::
json
j
;
nlohmann
::
json
j
;
i
>>
j
;
if
(
std
::
filesystem
::
exists
(
path
))
{
quant_configs
=
j
.
get
<
std
::
map
<
QuantType
,
QuantConfig
>>
();
std
::
cout
<<
__FUNCTION__
<<
" from "
<<
path
<<
std
::
endl
;
std
::
cout
<<
"Loaded Quant Configs"
<<
std
::
endl
;
std
::
ifstream
i
(
path
);
for
(
auto
&
[
k
,
v
]
:
quant_configs
)
{
i
>>
j
;
std
::
cout
<<
" - "
<<
k
<<
std
::
endl
;
quant_configs
=
j
.
get
<
std
::
map
<
QuantType
,
QuantConfig
>>
();
std
::
cout
<<
"Loaded Quant Configs"
<<
std
::
endl
;
for
(
auto
&
[
k
,
v
]
:
quant_configs
)
{
std
::
cout
<<
" - "
<<
k
<<
std
::
endl
;
}
}
else
{
std
::
cout
<<
__FUNCTION__
<<
" no file at "
<<
path
<<
std
::
endl
;
}
}
}
}
...
@@ -83,14 +95,18 @@ inline void dump_quant_configs(std::filesystem::path path) {
...
@@ -83,14 +95,18 @@ inline void dump_quant_configs(std::filesystem::path path) {
}
}
inline
void
load_model_configs
(
std
::
filesystem
::
path
path
)
{
inline
void
load_model_configs
(
std
::
filesystem
::
path
path
)
{
std
::
cout
<<
__FUNCTION__
<<
" from "
<<
path
<<
std
::
endl
;
std
::
ifstream
i
(
path
);
nlohmann
::
json
j
;
nlohmann
::
json
j
;
i
>>
j
;
if
(
std
::
filesystem
::
exists
(
path
))
{
model_configs
=
j
.
get
<
std
::
map
<
ModelName
,
ModelConfig
>>
();
std
::
cout
<<
__FUNCTION__
<<
" from "
<<
path
<<
std
::
endl
;
std
::
cout
<<
"Loaded Model Configs"
<<
std
::
endl
;
std
::
ifstream
i
(
path
);
for
(
auto
&
[
k
,
v
]
:
model_configs
)
{
i
>>
j
;
std
::
cout
<<
" - "
<<
k
<<
std
::
endl
;
model_configs
=
j
.
get
<
std
::
map
<
ModelName
,
ModelConfig
>>
();
std
::
cout
<<
"Loaded Model Configs"
<<
std
::
endl
;
for
(
auto
&
[
k
,
v
]
:
model_configs
)
{
std
::
cout
<<
" - "
<<
k
<<
std
::
endl
;
}
}
else
{
std
::
cout
<<
__FUNCTION__
<<
" no file at "
<<
path
<<
std
::
endl
;
}
}
}
}
...
...
csrc/balance_serve/kvc2/src/page_aligned_memory_pool.cpp
View file @
64de7843
...
@@ -17,13 +17,14 @@ PageAlignedMemoryPool::PageAlignedMemoryPool(size_t size_in_bytes) {
...
@@ -17,13 +17,14 @@ PageAlignedMemoryPool::PageAlignedMemoryPool(size_t size_in_bytes) {
assert
(
total_pages
>=
Blocks
);
assert
(
total_pages
>=
Blocks
);
page_per_block
=
total_pages
/
Blocks
;
page_per_block
=
total_pages
/
Blocks
;
for
(
size_t
block_index
=
0
;
block_index
<
Blocks
;
block_index
++
)
{
for
(
size_t
block_index
=
0
;
block_index
<
Blocks
;
block_index
++
)
{
first_page
[
block_index
]
=
reinterpret_cast
<
void
*>
(
reinterpret_cast
<
intptr_t
>
(
data
)
+
static_cast
<
intptr_t
>
(
block_index
)
*
page_per_block
*
PageSize
);
first_page
[
block_index
]
=
reinterpret_cast
<
void
*>
(
reinterpret_cast
<
intptr_t
>
(
data
)
+
static_cast
<
intptr_t
>
(
block_index
)
*
page_per_block
*
PageSize
);
count_page
[
block_index
]
=
count_page
[
block_index
]
=
block_index
==
Blocks
-
1
?
(
total_pages
-
page_per_block
*
(
Blocks
-
1
))
:
page_per_block
;
block_index
==
Blocks
-
1
?
(
total_pages
-
page_per_block
*
(
Blocks
-
1
))
:
page_per_block
;
SPDLOG_DEBUG
(
"first_page[{}] = {}, count_page[{}] = {}"
,
SPDLOG_DEBUG
(
"first_page[{}] = {}, count_page[{}] = {}"
,
block_index
,
block_index
,
reinterpret_cast
<
intptr_t
>
(
first_page
[
block_index
])
-
reinterpret_cast
<
intptr_t
>
(
data
),
reinterpret_cast
<
intptr_t
>
(
first_page
[
block_index
])
-
reinterpret_cast
<
intptr_t
>
(
data
),
block_index
,
block_index
,
count_page
[
block_index
]);
count_page
[
block_index
]);
bitmap
[
block_index
].
resize
(
count_page
[
block_index
],
0
);
bitmap
[
block_index
].
resize
(
count_page
[
block_index
],
0
);
}
}
SPDLOG_INFO
(
"PageAlignedMemoryPool with size {} Mbytes, {} pages"
,
total_size
/
(
1
<<
20
),
page_count
());
SPDLOG_INFO
(
"PageAlignedMemoryPool with size {} Mbytes, {} pages"
,
total_size
/
(
1
<<
20
),
page_count
());
...
@@ -53,7 +54,7 @@ void* PageAlignedMemoryPool::alloc_in_block(size_t block_index, size_t alloc_siz
...
@@ -53,7 +54,7 @@ void* PageAlignedMemoryPool::alloc_in_block(size_t block_index, size_t alloc_siz
size_t
free_pages
=
0
;
size_t
free_pages
=
0
;
for
(
size_t
i
=
0
;
i
<
count_page
[
block_index
];
i
++
)
{
for
(
size_t
i
=
0
;
i
<
count_page
[
block_index
];
i
++
)
{
if
(
bitmap
[
block_index
][
i
]
==
0
)
{
if
(
bitmap
[
block_index
][
i
]
==
0
)
{
free_pages
++
;
free_pages
++
;
if
(
free_pages
==
alloc_size
)
{
if
(
free_pages
==
alloc_size
)
{
size_t
page_index
=
i
+
1
-
free_pages
;
size_t
page_index
=
i
+
1
-
free_pages
;
for
(
size_t
page
=
page_index
;
page
<
page_index
+
alloc_size
;
page
++
)
{
for
(
size_t
page
=
page_index
;
page
<
page_index
+
alloc_size
;
page
++
)
{
...
@@ -73,7 +74,7 @@ void* PageAlignedMemoryPool::alloc_in_block(size_t block_index, size_t alloc_siz
...
@@ -73,7 +74,7 @@ void* PageAlignedMemoryPool::alloc_in_block(size_t block_index, size_t alloc_siz
void
*
PageAlignedMemoryPool
::
alloc
(
size_t
size
)
{
void
*
PageAlignedMemoryPool
::
alloc
(
size_t
size
)
{
size_t
alloc_size
=
div_up
(
size
,
PageSize
);
size_t
alloc_size
=
div_up
(
size
,
PageSize
);
auto
cnt
=
now_block
.
fetch_add
(
1
,
std
::
memory_order_relaxed
);
auto
cnt
=
now_block
.
fetch_add
(
1
,
std
::
memory_order_relaxed
);
for
(
size_t
i
=
0
;
i
<
Blocks
;
i
++
)
{
for
(
size_t
i
=
0
;
i
<
Blocks
;
i
++
)
{
auto
result
=
alloc_in_block
((
i
+
cnt
)
%
Blocks
,
alloc_size
);
auto
result
=
alloc_in_block
((
i
+
cnt
)
%
Blocks
,
alloc_size
);
if
(
result
!=
nullptr
)
{
if
(
result
!=
nullptr
)
{
allocated
.
fetch_add
(
alloc_size
*
PageSize
,
std
::
memory_order_relaxed
);
allocated
.
fetch_add
(
alloc_size
*
PageSize
,
std
::
memory_order_relaxed
);
...
@@ -119,5 +120,6 @@ void PageAlignedMemoryPool::defragment() {}
...
@@ -119,5 +120,6 @@ void PageAlignedMemoryPool::defragment() {}
/// 调试打印
/// 调试打印
std
::
string
PageAlignedMemoryPool
::
debug
()
{
std
::
string
PageAlignedMemoryPool
::
debug
()
{
return
fmt
::
format
(
"PageAlignedMemoryPool: total_size: {}MB, allocated: {}, alloc/free count: {}/{}
\n
"
,
return
fmt
::
format
(
"PageAlignedMemoryPool: total_size: {}MB, allocated: {}, alloc/free count: {}/{}
\n
"
,
readable_number
(
total_size
),
readable_number
(
size_t
(
allocated
)),
size_t
(
alloc_count
),
size_t
(
free_count
));
readable_number
(
total_size
),
readable_number
(
size_t
(
allocated
)),
size_t
(
alloc_count
),
size_t
(
free_count
));
}
}
csrc/balance_serve/kvc2/src/page_aligned_memory_pool.h
View file @
64de7843
#pragma once
#pragma once
#include <algorithm> // std::sort
#include <cstddef> // size_t
#include <mutex> // std::mutex
#include <vector>
#include <assert.h>
#include <assert.h>
#include <
bitset>
#include <
algorithm> // std::sort
#include <atomic>
#include <atomic>
#include <bitset>
#include <cstddef> // size_t
#include <mutex> // std::mutex
#include <vector>
constexpr
size_t
PageSize
=
4096
;
constexpr
size_t
PageSize
=
4096
;
...
@@ -18,7 +18,7 @@ struct PageAlignedMemoryPool {
...
@@ -18,7 +18,7 @@ struct PageAlignedMemoryPool {
void
*
data
=
nullptr
;
void
*
data
=
nullptr
;
size_t
total_size
=
0
,
total_pages
=
0
;
size_t
total_size
=
0
,
total_pages
=
0
;
std
::
atomic_size_t
now_block
=
0
;
std
::
atomic_size_t
now_block
=
0
;
std
::
atomic_size_t
allocated
=
0
;
// allocated_size
std
::
atomic_size_t
allocated
=
0
;
// allocated_size
std
::
atomic_size_t
alloc_count
=
0
;
std
::
atomic_size_t
alloc_count
=
0
;
...
@@ -26,10 +26,11 @@ struct PageAlignedMemoryPool {
...
@@ -26,10 +26,11 @@ struct PageAlignedMemoryPool {
std
::
mutex
lock
[
Blocks
];
std
::
mutex
lock
[
Blocks
];
size_t
page_per_block
=
0
;
size_t
page_per_block
=
0
;
void
*
first_page
[
Blocks
];
void
*
first_page
[
Blocks
];
size_t
count_page
[
Blocks
];
size_t
count_page
[
Blocks
];
std
::
vector
<
int8_t
>
bitmap
[
Blocks
];
std
::
vector
<
int8_t
>
bitmap
[
Blocks
];
void
*
alloc_in_block
(
size_t
block_index
,
size_t
alloc_size
);
void
*
alloc_in_block
(
size_t
block_index
,
size_t
alloc_size
);
public:
public:
/// 构造函数和析构函数
/// 构造函数和析构函数
explicit
PageAlignedMemoryPool
(
size_t
size_in_bytes
);
explicit
PageAlignedMemoryPool
(
size_t
size_in_bytes
);
...
...
csrc/balance_serve/kvc2/src/prefix.cpp
View file @
64de7843
...
@@ -339,7 +339,7 @@ struct Prefix {
...
@@ -339,7 +339,7 @@ struct Prefix {
void
update_location
(
CacheInfo
info
,
Location
location
)
{
locations
.
location_map
[
info
]
=
location
;
}
void
update_location
(
CacheInfo
info
,
Location
location
)
{
locations
.
location_map
[
info
]
=
location
;
}
Prefix
*
to_first_prefix_without_disk_locations
(
CacheInfo
k_info
/*, CacheInfo v_info*/
)
{
// just k_info
Prefix
*
to_first_prefix_without_disk_locations
(
CacheInfo
k_info
/*, CacheInfo v_info*/
)
{
// just k_info
auto
now_prefix
=
this
;
auto
now_prefix
=
this
;
while
(
now_prefix
->
prev
!=
nullptr
)
{
while
(
now_prefix
->
prev
!=
nullptr
)
{
auto
&
prev
=
now_prefix
->
prev
;
auto
&
prev
=
now_prefix
->
prev
;
...
@@ -561,7 +561,7 @@ struct PrefixTree {
...
@@ -561,7 +561,7 @@ struct PrefixTree {
if
(
need_lock
)
{
if
(
need_lock
)
{
sl
=
std
::
shared_lock
<
std
::
shared_mutex
>
(
rw_lock
);
sl
=
std
::
shared_lock
<
std
::
shared_mutex
>
(
rw_lock
);
}
}
//TODO: prefix cache
//
TODO: prefix cache
}
}
PrefixMatch
look_up_or_insert
(
Token
*
data
,
TokenLength
length
)
{
PrefixMatch
look_up_or_insert
(
Token
*
data
,
TokenLength
length
)
{
...
@@ -579,7 +579,6 @@ struct PrefixTree {
...
@@ -579,7 +579,6 @@ struct PrefixTree {
return
re
;
return
re
;
}
}
std
::
shared_ptr
<
Prefix
>
new_prefix_node
(
Prefix
*
prev
,
TokenLength
prev_match_length
,
Token
*
data
,
TokenLength
length
,
std
::
shared_ptr
<
Prefix
>
new_prefix_node
(
Prefix
*
prev
,
TokenLength
prev_match_length
,
Token
*
data
,
TokenLength
length
,
bool
need_lock
=
true
)
{
bool
need_lock
=
true
)
{
std
::
unique_lock
<
std
::
shared_mutex
>
ul
;
std
::
unique_lock
<
std
::
shared_mutex
>
ul
;
...
@@ -700,9 +699,7 @@ struct DoubleCacheHandle : public DoubleCacheHandleInterface {
...
@@ -700,9 +699,7 @@ struct DoubleCacheHandle : public DoubleCacheHandleInterface {
}
}
}
}
}
}
std
::
vector
<
MatchStatus
>
matched_status
()
override
{
std
::
vector
<
MatchStatus
>
matched_status
()
override
{
assert
(
false
);
}
assert
(
false
);
}
bool
any_match
()
{
bool
any_match
()
{
if
(
enable_alt
)
{
if
(
enable_alt
)
{
...
@@ -1066,7 +1063,6 @@ struct DoubleCacheHandle : public DoubleCacheHandleInterface {
...
@@ -1066,7 +1063,6 @@ struct DoubleCacheHandle : public DoubleCacheHandleInterface {
};
};
struct
KVC2
:
KVC2Interface
{
struct
KVC2
:
KVC2Interface
{
KVC2Config
config
;
KVC2Config
config
;
std
::
shared_ptr
<
Metrics
>
met
;
std
::
shared_ptr
<
Metrics
>
met
;
...
@@ -1194,7 +1190,7 @@ struct KVC2 : KVC2Interface {
...
@@ -1194,7 +1190,7 @@ struct KVC2 : KVC2Interface {
auto
v_loc
=
disk_cache
->
allocate
(
h
->
v_info
(),
div_up
(
new_length
,
NumTokenPerBlock
));
auto
v_loc
=
disk_cache
->
allocate
(
h
->
v_info
(),
div_up
(
new_length
,
NumTokenPerBlock
));
h
->
k_seg_locs
.
add_location
(
now_prefix
->
start_length
/
NumTokenPerBlock
,
k_loc
);
h
->
k_seg_locs
.
add_location
(
now_prefix
->
start_length
/
NumTokenPerBlock
,
k_loc
);
h
->
v_seg_locs
.
add_location
(
now_prefix
->
start_length
/
NumTokenPerBlock
,
v_loc
);
h
->
v_seg_locs
.
add_location
(
now_prefix
->
start_length
/
NumTokenPerBlock
,
v_loc
);
// split it to prefix trees
// split it to prefix trees
for
(
auto
tail
=
h
->
match
.
prefix
;
tail
!=
now_prefix
->
prev
;
tail
=
tail
->
prev
)
{
for
(
auto
tail
=
h
->
match
.
prefix
;
tail
!=
now_prefix
->
prev
;
tail
=
tail
->
prev
)
{
TokenLength
local_ids_length
=
tail
->
local_length
();
TokenLength
local_ids_length
=
tail
->
local_length
();
...
@@ -1207,7 +1203,7 @@ struct KVC2 : KVC2Interface {
...
@@ -1207,7 +1203,7 @@ struct KVC2 : KVC2Interface {
// allocate a big space on disk
// allocate a big space on disk
auto
k_loc
=
disk_cache
->
allocate
(
h
->
k_info
(),
div_up
(
new_length
,
NumTokenPerBlock
));
auto
k_loc
=
disk_cache
->
allocate
(
h
->
k_info
(),
div_up
(
new_length
,
NumTokenPerBlock
));
h
->
k_seg_locs
.
add_location
(
now_prefix
->
start_length
/
NumTokenPerBlock
,
k_loc
);
h
->
k_seg_locs
.
add_location
(
now_prefix
->
start_length
/
NumTokenPerBlock
,
k_loc
);
// split it to prefix trees
// split it to prefix trees
for
(
auto
tail
=
h
->
match
.
prefix
;
tail
!=
now_prefix
->
prev
;
tail
=
tail
->
prev
)
{
for
(
auto
tail
=
h
->
match
.
prefix
;
tail
!=
now_prefix
->
prev
;
tail
=
tail
->
prev
)
{
TokenLength
local_ids_length
=
tail
->
local_length
();
TokenLength
local_ids_length
=
tail
->
local_length
();
...
@@ -1231,7 +1227,7 @@ struct KVC2 : KVC2Interface {
...
@@ -1231,7 +1227,7 @@ struct KVC2 : KVC2Interface {
h
->
kvc2_top
=
this
;
h
->
kvc2_top
=
this
;
h
->
set_cache_info
(
model_name
,
quant_type
,
config
.
k_cache_on
,
config
.
v_cache_on
);
h
->
set_cache_info
(
model_name
,
quant_type
,
config
.
k_cache_on
,
config
.
v_cache_on
);
h
->
ids
=
Tokens
(
id
,
id
+
length
);
h
->
ids
=
Tokens
(
id
,
id
+
length
);
if
(
config
.
k_cache_on
)
if
(
config
.
k_cache_on
)
h
->
set_raw_handles
(
true
,
k_cache
);
h
->
set_raw_handles
(
true
,
k_cache
);
if
(
config
.
v_cache_on
)
if
(
config
.
v_cache_on
)
...
@@ -1261,7 +1257,7 @@ struct KVC2 : KVC2Interface {
...
@@ -1261,7 +1257,7 @@ struct KVC2 : KVC2Interface {
re
->
kvc2_top
=
this
;
re
->
kvc2_top
=
this
;
SPDLOG_DEBUG
(
"Lookup TokenLength {}"
,
length
);
SPDLOG_DEBUG
(
"Lookup TokenLength {}"
,
length
);
if
(
config
.
gpu_only
==
false
)
{
if
(
config
.
gpu_only
==
false
)
{
//TODO:
//
TODO:
}
}
return
re
;
return
re
;
};
};
...
@@ -1694,9 +1690,11 @@ void GPUPageCache::gpu_background_flush() {
...
@@ -1694,9 +1690,11 @@ void GPUPageCache::gpu_background_flush() {
if
(
col_uls
.
empty
())
if
(
col_uls
.
empty
())
continue
;
continue
;
for
(
size_t
l
=
0
;
l
<
config
.
layer_count
;
l
++
)
{
for
(
size_t
l
=
0
;
l
<
config
.
layer_count
;
l
++
)
{
if
(
config
.
k_cache_on
&&
(
occupations
[
l
][
i
]
->
gpu_cc
.
dirty
.
load
()
==
false
||
occupations
[
l
][
i
]
->
cpu_cc
.
dirty
.
load
()))
if
(
config
.
k_cache_on
&&
(
occupations
[
l
][
i
]
->
gpu_cc
.
dirty
.
load
()
==
false
||
occupations
[
l
][
i
]
->
cpu_cc
.
dirty
.
load
()))
goto
next_gpu_page
;
goto
next_gpu_page
;
if
(
config
.
v_cache_on
&&
(
v_occupations
[
l
][
i
]
->
gpu_cc
.
dirty
.
load
()
==
false
||
v_occupations
[
l
][
i
]
->
cpu_cc
.
dirty
.
load
()))
if
(
config
.
v_cache_on
&&
(
v_occupations
[
l
][
i
]
->
gpu_cc
.
dirty
.
load
()
==
false
||
v_occupations
[
l
][
i
]
->
cpu_cc
.
dirty
.
load
()))
goto
next_gpu_page
;
goto
next_gpu_page
;
}
}
...
...
csrc/balance_serve/kvc2/test/kvc2test/common.hpp
View file @
64de7843
...
@@ -139,18 +139,18 @@ std::vector<Token> random_ids(size_t length, std::mt19937& gen) {
...
@@ -139,18 +139,18 @@ std::vector<Token> random_ids(size_t length, std::mt19937& gen) {
return
re
;
return
re
;
}
}
std
::
vector
<
layer_data
>
slice
(
std
::
vector
<
layer_data
>&
h1
,
size_t
start
,
size_t
end
){
std
::
vector
<
layer_data
>
slice
(
std
::
vector
<
layer_data
>&
h1
,
size_t
start
,
size_t
end
)
{
std
::
vector
<
layer_data
>
re
;
std
::
vector
<
layer_data
>
re
;
for
(
auto
&
l
:
h1
){
for
(
auto
&
l
:
h1
)
{
layer_data
new_layer
;
layer_data
new_layer
;
new_layer
.
insert
(
new_layer
.
end
(),
l
.
begin
()
+
start
,
l
.
begin
()
+
end
);
new_layer
.
insert
(
new_layer
.
end
(),
l
.
begin
()
+
start
,
l
.
begin
()
+
end
);
re
.
push_back
(
new_layer
);
re
.
push_back
(
new_layer
);
}
}
return
re
;
return
re
;
}
}
void
cmp_handle_data
(
std
::
vector
<
layer_data
>
h1
,
std
::
vector
<
layer_data
>
h2
,
void
cmp_handle_data
(
std
::
vector
<
layer_data
>
h1
,
std
::
vector
<
layer_data
>
h2
,
std
::
optional
<
size_t
>
blocks
=
std
::
nullopt
)
{
std
::
optional
<
size_t
>
blocks
=
std
::
nullopt
)
{
assert
(
h1
.
size
()
==
h2
.
size
());
assert
(
h1
.
size
()
==
h2
.
size
());
for
(
size_t
i
=
0
;
i
<
h1
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
h1
.
size
();
i
++
)
{
...
...
csrc/balance_serve/kvc2/test/kvc2test/flush-back.cpp
View file @
64de7843
...
@@ -7,9 +7,9 @@ int main(int argc, char* argv[]) {
...
@@ -7,9 +7,9 @@ int main(int argc, char* argv[]) {
config
.
gpu_cache_config
->
total_kvcache_pages
=
12
;
config
.
gpu_cache_config
->
total_kvcache_pages
=
12
;
auto
kvc2
=
kvc2
::
create_kvc2
(
config
);
auto
kvc2
=
kvc2
::
create_kvc2
(
config
);
// #pragma omp parallel for
// #pragma omp parallel for
for
(
size_t
ti
=
0
;
ti
<
2
;
ti
++
)
{
for
(
size_t
ti
=
0
;
ti
<
2
;
ti
++
)
{
SPDLOG_WARN
(
"Test {}"
,
ti
);
SPDLOG_WARN
(
"Test {}"
,
ti
);
auto
[
kcache
,
vcache
]
=
kvc2
->
get_kvcache
();
auto
[
kcache
,
vcache
]
=
kvc2
->
get_kvcache
();
std
::
mt19937
gen
(
ti
+
123
);
std
::
mt19937
gen
(
ti
+
123
);
size_t
total_page
=
10
;
size_t
total_page
=
10
;
...
...
csrc/balance_serve/kvc2/test/kvc2test/lookup-gpu-mt-without-vcache.cpp
View file @
64de7843
...
@@ -14,7 +14,7 @@ int main(int argc, char* argv[]) {
...
@@ -14,7 +14,7 @@ int main(int argc, char* argv[]) {
qw25_7B_gpu_config
.
v_cache_on
=
false
;
qw25_7B_gpu_config
.
v_cache_on
=
false
;
config
.
gpu_cache_config
=
qw25_7B_gpu_config
;
config
.
gpu_cache_config
=
qw25_7B_gpu_config
;
config
.
v_cache_on
=
false
;
config
.
v_cache_on
=
false
;
init
(
argc
,
argv
);
init
(
argc
,
argv
);
spdlog
::
set_level
(
spdlog
::
level
::
debug
);
spdlog
::
set_level
(
spdlog
::
level
::
debug
);
auto
kvc2
=
kvc2
::
create_kvc2
(
config
);
auto
kvc2
=
kvc2
::
create_kvc2
(
config
);
...
...
csrc/balance_serve/kvc2/test/kvc2test/lookup-without-vcache.cpp
View file @
64de7843
...
@@ -11,11 +11,10 @@
...
@@ -11,11 +11,10 @@
#include "common.hpp"
#include "common.hpp"
int
main
(
int
argc
,
char
*
argv
[])
{
int
main
(
int
argc
,
char
*
argv
[])
{
qw25_7B_gpu_config
.
v_cache_on
=
false
;
qw25_7B_gpu_config
.
v_cache_on
=
false
;
config
.
gpu_cache_config
=
qw25_7B_gpu_config
;
config
.
gpu_cache_config
=
qw25_7B_gpu_config
;
config
.
v_cache_on
=
false
;
config
.
v_cache_on
=
false
;
init
(
argc
,
argv
);
init
(
argc
,
argv
);
spdlog
::
set_level
(
spdlog
::
level
::
debug
);
spdlog
::
set_level
(
spdlog
::
level
::
debug
);
auto
kvc2
=
kvc2
::
create_kvc2
(
config
);
auto
kvc2
=
kvc2
::
create_kvc2
(
config
);
...
...
csrc/balance_serve/kvc2/test/page_pool_test.cpp
View file @
64de7843
#include <unistd.h>
#include <iostream>
#include <iostream>
#include <random>
#include <thread>
#include <thread>
#include <vector>
#include <vector>
#include <random>
#include <unistd.h>
#include "page_aligned_memory_pool.cpp"
#include "page_aligned_memory_pool.cpp"
#define SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_DEBUG
#define SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_DEBUG
#define FMT_HEADER_ONLY
#define FMT_HEADER_ONLY
#include "spdlog/spdlog.h"
#include "spdlog/spdlog.h"
// 每个线程执行的任务
// 每个线程执行的任务
void
thread_task
(
PageAlignedMemoryPool
&
pool
)
{
void
thread_task
(
PageAlignedMemoryPool
&
pool
)
{
std
::
mt19937
gen
(
123
);
std
::
mt19937
gen
(
123
);
...
@@ -22,8 +21,8 @@ void thread_task(PageAlignedMemoryPool& pool) {
...
@@ -22,8 +21,8 @@ void thread_task(PageAlignedMemoryPool& pool) {
void
*
ptr
=
pool
.
alloc
(
size
);
void
*
ptr
=
pool
.
alloc
(
size
);
// SPDLOG_DEBUG(pool.debug());
// SPDLOG_DEBUG(pool.debug());
if
(
ptr
)
{
if
(
ptr
)
{
pool
.
free
(
ptr
,
size
);
pool
.
free
(
ptr
,
size
);
// allocated.push_back({ptr, size});
// allocated.push_back({ptr, size});
}
}
// sleep((int)(gen() % 1000) / 1000.0);
// sleep((int)(gen() % 1000) / 1000.0);
}
}
...
@@ -35,21 +34,20 @@ void thread_task(PageAlignedMemoryPool& pool) {
...
@@ -35,21 +34,20 @@ void thread_task(PageAlignedMemoryPool& pool) {
int
main
(
int
argc
,
char
*
argv
[])
{
int
main
(
int
argc
,
char
*
argv
[])
{
spdlog
::
set_level
(
spdlog
::
level
::
debug
);
spdlog
::
set_level
(
spdlog
::
level
::
debug
);
// 创建一个内存池
// 创建一个内存池
PageAlignedMemoryPool
pool
(
40ll
*
1024
*
1024
*
1024
);
// 40 G
PageAlignedMemoryPool
pool
(
40ll
*
1024
*
1024
*
1024
);
// 40 G
// 创建线程
// 创建线程
const
int
num_threads
=
32
;
const
int
num_threads
=
32
;
std
::
vector
<
std
::
thread
>
threads
;
std
::
vector
<
std
::
thread
>
threads
;
for
(
int
i
=
0
;
i
<
num_threads
;
++
i
)
{
for
(
int
i
=
0
;
i
<
num_threads
;
++
i
)
{
threads
.
emplace_back
(
thread_task
,
std
::
ref
(
pool
));
threads
.
emplace_back
(
thread_task
,
std
::
ref
(
pool
));
}
}
// 等待所有线程完成
// 等待所有线程完成
for
(
auto
&
t
:
threads
)
{
for
(
auto
&
t
:
threads
)
{
t
.
join
();
t
.
join
();
}
}
// 输出调试信息
// 输出调试信息
...
...
csrc/balance_serve/kvc2/test/test_periodic_task.cpp
View file @
64de7843
#include "utils/periodic_task.hpp"
#include <atomic>
#include <cassert>
#include <chrono>
#include <chrono>
#include <cstdio>
#include <cstdio>
#include <future>
#include <iostream>
#include <iostream>
#include <thread>
#include <thread>
#include <future>
#include "utils/periodic_task.hpp"
#include <atomic>
#include <cassert>
// 1. 任务是否按预期执行
// 1. 任务是否按预期执行
void
testPeriodicTaskExecution
()
{
void
testPeriodicTaskExecution
()
{
std
::
atomic
<
int
>
execution_count
{
0
};
std
::
atomic
<
int
>
execution_count
{
0
};
auto
task
=
[
&
execution_count
]()
{
auto
task
=
[
&
execution_count
]()
{
execution_count
++
;
};
execution_count
++
;
};
periodic
::
PeriodicTask
periodic_task
(
task
,
std
::
chrono
::
milliseconds
(
50
));
periodic
::
PeriodicTask
periodic_task
(
task
,
std
::
chrono
::
milliseconds
(
50
));
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
seconds
(
2
));
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
seconds
(
2
));
assert
(
execution_count
>=
20
);
// 确保任务执行了至少 20 次
assert
(
execution_count
>=
20
);
// 确保任务执行了至少 20 次
std
::
cout
<<
"Test 1 passed: Task executed periodically."
<<
std
::
endl
;
std
::
cout
<<
"Test 1 passed: Task executed periodically."
<<
std
::
endl
;
std
::
cout
<<
"Task executed "
<<
execution_count
.
load
()
<<
" times."
<<
std
::
endl
;
std
::
cout
<<
"Task executed "
<<
execution_count
.
load
()
<<
" times."
<<
std
::
endl
;
}
}
// 2. 提前唤醒任务的功能
// 2. 提前唤醒任务的功能
void
testWakeUpImmediately
()
{
void
testWakeUpImmediately
()
{
std
::
atomic
<
int
>
execution_count
{
0
};
std
::
atomic
<
int
>
execution_count
{
0
};
auto
task
=
[
&
execution_count
]()
{
auto
task
=
[
&
execution_count
]()
{
execution_count
++
;
};
execution_count
++
;
};
periodic
::
PeriodicTask
periodic_task
(
task
,
std
::
chrono
::
milliseconds
(
200
));
periodic
::
PeriodicTask
periodic_task
(
task
,
std
::
chrono
::
milliseconds
(
200
));
// 提前唤醒任务
// 提前唤醒任务
periodic_task
.
wakeUp
();
periodic_task
.
wakeUp
();
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
50
));
// 等待任务执行
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
50
));
// 等待任务执行
std
::
cout
<<
"Execution count after wakeUp: "
<<
execution_count
.
load
()
<<
std
::
endl
;
std
::
cout
<<
"Execution count after wakeUp: "
<<
execution_count
.
load
()
<<
std
::
endl
;
assert
(
execution_count
==
1
);
// 确保任务立即执行
assert
(
execution_count
==
1
);
// 确保任务立即执行
std
::
cout
<<
"Test 2 passed: Task woke up immediately."
<<
std
::
endl
;
std
::
cout
<<
"Test 2 passed: Task woke up immediately."
<<
std
::
endl
;
}
}
// 3. wakeUpWait() 的等待功能
// 3. wakeUpWait() 的等待功能
void
testWakeUpWait
()
{
void
testWakeUpWait
()
{
std
::
promise
<
void
>
promise
;
std
::
promise
<
void
>
promise
;
std
::
future
<
void
>
future
=
promise
.
get_future
();
std
::
future
<
void
>
future
=
promise
.
get_future
();
auto
task
=
[
&
promise
]()
{
auto
task
=
[
&
promise
]()
{
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
100
));
// 模拟任务执行
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
100
));
// 模拟任务执行
promise
.
set_value
();
// 任务完成时设置 promise
promise
.
set_value
();
// 任务完成时设置 promise
};
};
periodic
::
PeriodicTask
periodic_task
(
task
,
std
::
chrono
::
milliseconds
(
200
));
periodic
::
PeriodicTask
periodic_task
(
task
,
std
::
chrono
::
milliseconds
(
200
));
// 调用 wakeUpWait 并等待任务完成
// 调用 wakeUpWait 并等待任务完成
std
::
future
<
void
>
wakeup_future
=
periodic_task
.
wakeUpWait
();
std
::
future
<
void
>
wakeup_future
=
periodic_task
.
wakeUpWait
();
wakeup_future
.
wait
();
// 等待任务完成
wakeup_future
.
wait
();
// 等待任务完成
assert
(
wakeup_future
.
valid
());
// 确保 future 是有效的
assert
(
wakeup_future
.
valid
());
// 确保 future 是有效的
std
::
cout
<<
"Test 3 passed: wakeUpWait() works correctly."
<<
std
::
endl
;
std
::
cout
<<
"Test 3 passed: wakeUpWait() works correctly."
<<
std
::
endl
;
std
::
cout
<<
"wakeUpWait() future is valid."
<<
std
::
endl
;
std
::
cout
<<
"wakeUpWait() future is valid."
<<
std
::
endl
;
}
}
// 4. 任务抛出异常的处理
// 4. 任务抛出异常的处理
void
testTaskExceptionHandling
()
{
void
testTaskExceptionHandling
()
{
auto
task
=
[]()
{
auto
task
=
[]()
{
throw
std
::
runtime_error
(
"Test exception"
);
};
throw
std
::
runtime_error
(
"Test exception"
);
};
periodic
::
PeriodicTask
periodic_task
(
task
,
std
::
chrono
::
milliseconds
(
200
));
periodic
::
PeriodicTask
periodic_task
(
task
,
std
::
chrono
::
milliseconds
(
200
));
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
300
));
// 等待一段时间
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
300
));
// 等待一段时间
std
::
cout
<<
"Test 4 passed: Task exception is handled correctly."
<<
std
::
endl
;
std
::
cout
<<
"Test 4 passed: Task exception is handled correctly."
<<
std
::
endl
;
std
::
cout
<<
"Exception handled and task did not crash."
<<
std
::
endl
;
std
::
cout
<<
"Exception handled and task did not crash."
<<
std
::
endl
;
}
}
// 5. 线程是否能正确停止
// 5. 线程是否能正确停止
void
testTaskStop
()
{
void
testTaskStop
()
{
std
::
atomic
<
bool
>
stopped
{
false
};
std
::
atomic
<
bool
>
stopped
{
false
};
auto
task
=
[
&
stopped
]()
{
auto
task
=
[
&
stopped
]()
{
while
(
!
stopped
)
{
while
(
!
stopped
)
{
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
50
));
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
50
));
}
}
};
};
periodic
::
PeriodicTask
periodic_task
(
task
,
std
::
chrono
::
milliseconds
(
100
));
periodic
::
PeriodicTask
periodic_task
(
task
,
std
::
chrono
::
milliseconds
(
100
));
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
seconds
(
1
));
// 运行一段时间
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
seconds
(
1
));
// 运行一段时间
stopped
=
true
;
// 请求停止
stopped
=
true
;
// 请求停止
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
50
));
// 等待线程停止
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
50
));
// 等待线程停止
std
::
cout
<<
"Test 5 passed: Task thread stops correctly."
<<
std
::
endl
;
std
::
cout
<<
"Test 5 passed: Task thread stops correctly."
<<
std
::
endl
;
std
::
cout
<<
"Task has been stopped successfully."
<<
std
::
endl
;
std
::
cout
<<
"Task has been stopped successfully."
<<
std
::
endl
;
}
}
// 6. 高频唤醒的情况下任务执行是否正常
// 6. 高频唤醒的情况下任务执行是否正常
void
testHighFrequencyWakeUp
()
{
void
testHighFrequencyWakeUp
()
{
std
::
atomic
<
int
>
execution_count
{
0
};
std
::
atomic
<
int
>
execution_count
{
0
};
auto
task
=
[
&
execution_count
]()
{
auto
task
=
[
&
execution_count
]()
{
execution_count
++
;
};
execution_count
++
;
};
periodic
::
PeriodicTask
periodic_task
(
task
,
std
::
chrono
::
milliseconds
(
200
));
periodic
::
PeriodicTask
periodic_task
(
task
,
std
::
chrono
::
milliseconds
(
200
));
for
(
int
i
=
0
;
i
<
100
;
++
i
)
{
for
(
int
i
=
0
;
i
<
100
;
++
i
)
{
periodic_task
.
wakeUp
();
periodic_task
.
wakeUp
();
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
10
));
// 每 10 毫秒唤醒一次
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
10
));
// 每 10 毫秒唤醒一次
}
}
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
seconds
(
1
));
// 等待任务执行完成
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
seconds
(
1
));
// 等待任务执行完成
assert
(
execution_count
>
50
);
// 确保任务至少执行了 50 次
assert
(
execution_count
>
50
);
// 确保任务至少执行了 50 次
std
::
cout
<<
"Test 6 passed: Task handles frequent wake ups correctly."
<<
std
::
endl
;
std
::
cout
<<
"Test 6 passed: Task handles frequent wake ups correctly."
<<
std
::
endl
;
std
::
cout
<<
"Task executed "
<<
execution_count
.
load
()
<<
" times."
<<
std
::
endl
;
std
::
cout
<<
"Task executed "
<<
execution_count
.
load
()
<<
" times."
<<
std
::
endl
;
}
}
// 7. 多个 wakeUpWait() 调用的处理
// 7. 多个 wakeUpWait() 调用的处理
void
testMultipleWakeUpWait
()
{
void
testMultipleWakeUpWait
()
{
std
::
atomic
<
int
>
execution_count
{
0
};
std
::
atomic
<
int
>
execution_count
{
0
};
auto
task
=
[
&
execution_count
]()
{
auto
task
=
[
&
execution_count
]()
{
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
100
));
// 模拟任务执行
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
100
));
// 模拟任务执行
execution_count
++
;
execution_count
++
;
};
};
periodic
::
PeriodicTask
periodic_task
(
task
,
std
::
chrono
::
milliseconds
(
200
));
periodic
::
PeriodicTask
periodic_task
(
task
,
std
::
chrono
::
milliseconds
(
200
));
// 同时调用两个 wakeUpWait
// 同时调用两个 wakeUpWait
std
::
future
<
void
>
future1
=
periodic_task
.
wakeUpWait
();
std
::
future
<
void
>
future1
=
periodic_task
.
wakeUpWait
();
std
::
future
<
void
>
future2
=
periodic_task
.
wakeUpWait
();
std
::
future
<
void
>
future2
=
periodic_task
.
wakeUpWait
();
future1
.
wait
();
future1
.
wait
();
future2
.
wait
();
future2
.
wait
();
assert
(
execution_count
==
1
);
// 确保任务只执行了一次
assert
(
execution_count
==
1
);
// 确保任务只执行了一次
std
::
cout
<<
"Test 7 passed: Multiple wakeUpWait() calls are handled correctly."
<<
std
::
endl
;
std
::
cout
<<
"Test 7 passed: Multiple wakeUpWait() calls are handled correctly."
<<
std
::
endl
;
std
::
cout
<<
"Task executed "
<<
execution_count
.
load
()
<<
" times."
<<
std
::
endl
;
std
::
cout
<<
"Task executed "
<<
execution_count
.
load
()
<<
" times."
<<
std
::
endl
;
}
}
// 8. 任务函数为空的边界情况
// 8. 任务函数为空的边界情况
void
testEmptyTaskFunction
()
{
void
testEmptyTaskFunction
()
{
auto
task
=
[]()
{
auto
task
=
[]()
{
// 空任务函数
// 空任务函数
};
};
periodic
::
PeriodicTask
periodic_task
(
task
,
std
::
chrono
::
milliseconds
(
100
));
periodic
::
PeriodicTask
periodic_task
(
task
,
std
::
chrono
::
milliseconds
(
100
));
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
seconds
(
1
));
// 等待一段时间
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
seconds
(
1
));
// 等待一段时间
std
::
cout
<<
"Test 8 passed: Empty task function works correctly."
<<
std
::
endl
;
std
::
cout
<<
"Test 8 passed: Empty task function works correctly."
<<
std
::
endl
;
std
::
cout
<<
"Empty task function executed without issues."
<<
std
::
endl
;
std
::
cout
<<
"Empty task function executed without issues."
<<
std
::
endl
;
}
}
int
main
()
{
int
main
()
{
std
::
cout
<<
"Starting tests..."
<<
std
::
endl
;
std
::
cout
<<
"Starting tests..."
<<
std
::
endl
;
// testWakeUpImmediately();
// testWakeUpImmediately();
testPeriodicTaskExecution
();
testPeriodicTaskExecution
();
testWakeUpImmediately
();
testWakeUpImmediately
();
testWakeUpWait
();
testWakeUpWait
();
testTaskExceptionHandling
();
testTaskExceptionHandling
();
testTaskStop
();
testTaskStop
();
testHighFrequencyWakeUp
();
testHighFrequencyWakeUp
();
testMultipleWakeUpWait
();
testMultipleWakeUpWait
();
testEmptyTaskFunction
();
testEmptyTaskFunction
();
std
::
cout
<<
"All tests passed!"
<<
std
::
endl
;
std
::
cout
<<
"All tests passed!"
<<
std
::
endl
;
return
0
;
return
0
;
}
}
csrc/balance_serve/sched/bind.cpp
View file @
64de7843
#include "scheduler.h"
#include <memory>
#include <pybind11/numpy.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/stl.h>
#include <memory>
#include "scheduler.h"
#include <torch/extension.h>
#include <torch/extension.h>
...
@@ -16,19 +16,25 @@ PYBIND11_MODULE(sched_ext, m) {
...
@@ -16,19 +16,25 @@ PYBIND11_MODULE(sched_ext, m) {
.
def_readwrite
(
"layer_count"
,
&
scheduler
::
ModelSettings
::
layer_count
)
.
def_readwrite
(
"layer_count"
,
&
scheduler
::
ModelSettings
::
layer_count
)
.
def_readwrite
(
"num_k_heads"
,
&
scheduler
::
ModelSettings
::
num_k_heads
)
.
def_readwrite
(
"num_k_heads"
,
&
scheduler
::
ModelSettings
::
num_k_heads
)
.
def_readwrite
(
"k_head_dim"
,
&
scheduler
::
ModelSettings
::
k_head_dim
)
.
def_readwrite
(
"k_head_dim"
,
&
scheduler
::
ModelSettings
::
k_head_dim
)
.
def_readwrite
(
"bytes_per_params"
,
&
scheduler
::
ModelSettings
::
bytes_per_params
)
.
def_readwrite
(
"bytes_per_params"
,
.
def_readwrite
(
"bytes_per_kv_cache_element"
,
&
scheduler
::
ModelSettings
::
bytes_per_kv_cache_element
)
&
scheduler
::
ModelSettings
::
bytes_per_params
)
.
def_readwrite
(
"bytes_per_kv_cache_element"
,
&
scheduler
::
ModelSettings
::
bytes_per_kv_cache_element
)
.
def
(
"params_size"
,
&
scheduler
::
ModelSettings
::
params_nbytes
)
.
def
(
"params_size"
,
&
scheduler
::
ModelSettings
::
params_nbytes
)
.
def
(
"bytes_per_token_kv_cache"
,
&
scheduler
::
ModelSettings
::
bytes_per_token_kv_cache
)
.
def
(
"bytes_per_token_kv_cache"
,
&
scheduler
::
ModelSettings
::
bytes_per_token_kv_cache
)
// 添加 pickle 支持
// 添加 pickle 支持
.
def
(
py
::
pickle
(
.
def
(
py
::
pickle
(
[](
const
scheduler
::
ModelSettings
&
self
)
{
// __getstate__
[](
const
scheduler
::
ModelSettings
&
self
)
{
// __getstate__
return
py
::
make_tuple
(
self
.
params_count
,
self
.
layer_count
,
self
.
num_k_heads
,
self
.
k_head_dim
,
return
py
::
make_tuple
(
self
.
params_count
,
self
.
layer_count
,
self
.
bytes_per_params
,
self
.
bytes_per_kv_cache_element
);
self
.
num_k_heads
,
self
.
k_head_dim
,
self
.
bytes_per_params
,
self
.
bytes_per_kv_cache_element
);
},
},
[](
py
::
tuple
t
)
{
// __setstate__
[](
py
::
tuple
t
)
{
// __setstate__
if
(
t
.
size
()
!=
6
)
if
(
t
.
size
()
!=
6
)
throw
std
::
runtime_error
(
"Invalid state! t.size() = "
+
std
::
to_string
(
t
.
size
()));
throw
std
::
runtime_error
(
"Invalid state! t.size() = "
+
std
::
to_string
(
t
.
size
()));
scheduler
::
ModelSettings
ms
;
scheduler
::
ModelSettings
ms
;
ms
.
params_count
=
t
[
0
].
cast
<
size_t
>
();
ms
.
params_count
=
t
[
0
].
cast
<
size_t
>
();
ms
.
layer_count
=
t
[
1
].
cast
<
size_t
>
();
ms
.
layer_count
=
t
[
1
].
cast
<
size_t
>
();
...
@@ -40,22 +46,24 @@ PYBIND11_MODULE(sched_ext, m) {
...
@@ -40,22 +46,24 @@ PYBIND11_MODULE(sched_ext, m) {
}));
}));
py
::
class_
<
scheduler
::
SampleOptions
>
(
m
,
"SampleOptions"
)
py
::
class_
<
scheduler
::
SampleOptions
>
(
m
,
"SampleOptions"
)
.
def
(
py
::
init
<>
())
.
def
(
py
::
init
<>
())
.
def_readwrite
(
"temperature"
,
&
scheduler
::
SampleOptions
::
temperature
)
.
def_readwrite
(
"temperature"
,
&
scheduler
::
SampleOptions
::
temperature
)
.
def_readwrite
(
"top_p"
,
&
scheduler
::
SampleOptions
::
top_p
)
// 确保 top_p 也能被访问
.
def_readwrite
(
"top_p"
,
.
def
(
py
::
pickle
(
&
scheduler
::
SampleOptions
::
top_p
)
// 确保 top_p 也能被访问
[](
const
scheduler
::
SampleOptions
&
self
)
{
.
def
(
py
::
pickle
(
return
py
::
make_tuple
(
self
.
temperature
,
self
.
top_p
);
// 序列化 temperature 和 top_p
[](
const
scheduler
::
SampleOptions
&
self
)
{
},
return
py
::
make_tuple
(
self
.
temperature
,
[](
py
::
tuple
t
)
{
self
.
top_p
);
// 序列化 temperature 和 top_p
if
(
t
.
size
()
!=
2
)
// 确保解包时参数数量匹配
},
throw
std
::
runtime_error
(
"Invalid state! t.size() = "
+
std
::
to_string
(
t
.
size
()));
[](
py
::
tuple
t
)
{
if
(
t
.
size
()
!=
2
)
// 确保解包时参数数量匹配
throw
std
::
runtime_error
(
"Invalid state! t.size() = "
+
std
::
to_string
(
t
.
size
()));
scheduler
::
SampleOptions
so
;
scheduler
::
SampleOptions
so
;
so
.
temperature
=
t
[
0
].
cast
<
double
>
();
so
.
temperature
=
t
[
0
].
cast
<
double
>
();
so
.
top_p
=
t
[
1
].
cast
<
double
>
();
// 反序列化 top_p
so
.
top_p
=
t
[
1
].
cast
<
double
>
();
// 反序列化 top_p
return
so
;
return
so
;
}
}));
));
py
::
class_
<
scheduler
::
Settings
>
(
m
,
"Settings"
)
py
::
class_
<
scheduler
::
Settings
>
(
m
,
"Settings"
)
.
def
(
py
::
init
<>
())
.
def
(
py
::
init
<>
())
...
@@ -65,33 +73,43 @@ PYBIND11_MODULE(sched_ext, m) {
...
@@ -65,33 +73,43 @@ PYBIND11_MODULE(sched_ext, m) {
.
def_readwrite
(
"page_size"
,
&
scheduler
::
Settings
::
page_size
)
.
def_readwrite
(
"page_size"
,
&
scheduler
::
Settings
::
page_size
)
.
def_readwrite
(
"gpu_device_id"
,
&
scheduler
::
Settings
::
gpu_device_id
)
.
def_readwrite
(
"gpu_device_id"
,
&
scheduler
::
Settings
::
gpu_device_id
)
.
def_readwrite
(
"gpu_memory_size"
,
&
scheduler
::
Settings
::
gpu_memory_size
)
.
def_readwrite
(
"gpu_memory_size"
,
&
scheduler
::
Settings
::
gpu_memory_size
)
.
def_readwrite
(
"memory_utilization_percentage"
,
&
scheduler
::
Settings
::
memory_utilization_percentage
)
.
def_readwrite
(
"memory_utilization_percentage"
,
&
scheduler
::
Settings
::
memory_utilization_percentage
)
.
def_readwrite
(
"max_batch_size"
,
&
scheduler
::
Settings
::
max_batch_size
)
.
def_readwrite
(
"max_batch_size"
,
&
scheduler
::
Settings
::
max_batch_size
)
.
def_readwrite
(
"recommended_chunk_prefill_token_count"
,
.
def_readwrite
(
&
scheduler
::
Settings
::
recommended_chunk_prefill_token_count
)
"recommended_chunk_prefill_token_count"
,
&
scheduler
::
Settings
::
recommended_chunk_prefill_token_count
)
.
def_readwrite
(
"sample_options"
,
&
scheduler
::
Settings
::
sample_options
)
.
def_readwrite
(
"sample_options"
,
&
scheduler
::
Settings
::
sample_options
)
.
def_readwrite
(
"sched_metrics_port"
,
&
scheduler
::
Settings
::
sched_metrics_port
)
.
def_readwrite
(
"sched_metrics_port"
,
&
scheduler
::
Settings
::
sched_metrics_port
)
.
def_readwrite
(
"gpu_only"
,
&
scheduler
::
Settings
::
gpu_only
)
.
def_readwrite
(
"gpu_only"
,
&
scheduler
::
Settings
::
gpu_only
)
.
def_readwrite
(
"use_self_defined_head_dim"
,
&
scheduler
::
Settings
::
use_self_defined_head_dim
)
.
def_readwrite
(
"use_self_defined_head_dim"
,
.
def_readwrite
(
"self_defined_head_dim"
,
&
scheduler
::
Settings
::
self_defined_head_dim
)
&
scheduler
::
Settings
::
use_self_defined_head_dim
)
.
def_readwrite
(
"full_kv_cache_on_each_gpu"
,
&
scheduler
::
Settings
::
full_kv_cache_on_each_gpu
)
.
def_readwrite
(
"self_defined_head_dim"
,
&
scheduler
::
Settings
::
self_defined_head_dim
)
.
def_readwrite
(
"full_kv_cache_on_each_gpu"
,
&
scheduler
::
Settings
::
full_kv_cache_on_each_gpu
)
.
def_readwrite
(
"k_cache_on"
,
&
scheduler
::
Settings
::
k_cache_on
)
.
def_readwrite
(
"k_cache_on"
,
&
scheduler
::
Settings
::
k_cache_on
)
.
def_readwrite
(
"v_cache_on"
,
&
scheduler
::
Settings
::
v_cache_on
)
.
def_readwrite
(
"v_cache_on"
,
&
scheduler
::
Settings
::
v_cache_on
)
.
def_readwrite
(
"kvc2_config_path"
,
&
scheduler
::
Settings
::
kvc2_config_path
)
.
def_readwrite
(
"kvc2_config_path"
,
&
scheduler
::
Settings
::
kvc2_config_path
)
.
def_readwrite
(
"kvc2_root_path"
,
&
scheduler
::
Settings
::
kvc2_root_path
)
.
def_readwrite
(
"kvc2_root_path"
,
&
scheduler
::
Settings
::
kvc2_root_path
)
.
def_readwrite
(
"memory_pool_size_GB"
,
&
scheduler
::
Settings
::
memory_pool_size_GB
)
.
def_readwrite
(
"memory_pool_size_GB"
,
&
scheduler
::
Settings
::
memory_pool_size_GB
)
.
def_readwrite
(
"evict_count"
,
&
scheduler
::
Settings
::
evict_count
)
.
def_readwrite
(
"evict_count"
,
&
scheduler
::
Settings
::
evict_count
)
.
def_readwrite
(
"strategy_name"
,
&
scheduler
::
Settings
::
strategy_name
)
.
def_readwrite
(
"strategy_name"
,
&
scheduler
::
Settings
::
strategy_name
)
.
def_readwrite
(
"kvc2_metrics_port"
,
&
scheduler
::
Settings
::
kvc2_metrics_port
)
.
def_readwrite
(
"kvc2_metrics_port"
,
&
scheduler
::
Settings
::
kvc2_metrics_port
)
.
def_readwrite
(
"load_from_disk"
,
&
scheduler
::
Settings
::
load_from_disk
)
.
def_readwrite
(
"load_from_disk"
,
&
scheduler
::
Settings
::
load_from_disk
)
.
def_readwrite
(
"save_to_disk"
,
&
scheduler
::
Settings
::
save_to_disk
)
.
def_readwrite
(
"save_to_disk"
,
&
scheduler
::
Settings
::
save_to_disk
)
// derived
// derived
.
def_readwrite
(
"gpu_device_count"
,
&
scheduler
::
Settings
::
gpu_device_count
)
.
def_readwrite
(
"gpu_device_count"
,
&
scheduler
::
Settings
::
gpu_device_count
)
.
def_readwrite
(
"total_kvcache_pages"
,
&
scheduler
::
Settings
::
total_kvcache_pages
)
.
def_readwrite
(
"total_kvcache_pages"
,
&
scheduler
::
Settings
::
total_kvcache_pages
)
.
def_readwrite
(
"devices"
,
&
scheduler
::
Settings
::
devices
)
.
def_readwrite
(
"devices"
,
&
scheduler
::
Settings
::
devices
)
.
def
(
"auto_derive"
,
&
scheduler
::
Settings
::
auto_derive
);
.
def
(
"auto_derive"
,
&
scheduler
::
Settings
::
auto_derive
);
py
::
class_
<
scheduler
::
BatchQueryTodo
,
std
::
shared_ptr
<
scheduler
::
BatchQueryTodo
>>
(
m
,
"BatchQueryTodo"
)
py
::
class_
<
scheduler
::
BatchQueryTodo
,
std
::
shared_ptr
<
scheduler
::
BatchQueryTodo
>>
(
m
,
"BatchQueryTodo"
)
.
def
(
py
::
init
<>
())
.
def
(
py
::
init
<>
())
.
def_readwrite
(
"query_ids"
,
&
scheduler
::
BatchQueryTodo
::
query_ids
)
.
def_readwrite
(
"query_ids"
,
&
scheduler
::
BatchQueryTodo
::
query_ids
)
.
def_readwrite
(
"query_tokens"
,
&
scheduler
::
BatchQueryTodo
::
query_tokens
)
.
def_readwrite
(
"query_tokens"
,
&
scheduler
::
BatchQueryTodo
::
query_tokens
)
...
@@ -99,31 +117,42 @@ PYBIND11_MODULE(sched_ext, m) {
...
@@ -99,31 +117,42 @@ PYBIND11_MODULE(sched_ext, m) {
.
def_readwrite
(
"block_indexes"
,
&
scheduler
::
BatchQueryTodo
::
block_indexes
)
.
def_readwrite
(
"block_indexes"
,
&
scheduler
::
BatchQueryTodo
::
block_indexes
)
.
def_readwrite
(
"attn_masks"
,
&
scheduler
::
BatchQueryTodo
::
attn_masks
)
.
def_readwrite
(
"attn_masks"
,
&
scheduler
::
BatchQueryTodo
::
attn_masks
)
.
def_readwrite
(
"rope_ranges"
,
&
scheduler
::
BatchQueryTodo
::
rope_ranges
)
.
def_readwrite
(
"rope_ranges"
,
&
scheduler
::
BatchQueryTodo
::
rope_ranges
)
.
def_readwrite
(
"sample_options"
,
&
scheduler
::
BatchQueryTodo
::
sample_options
)
.
def_readwrite
(
"sample_options"
,
.
def_readwrite
(
"prefill_mini_batches"
,
&
scheduler
::
BatchQueryTodo
::
prefill_mini_batches
)
&
scheduler
::
BatchQueryTodo
::
sample_options
)
.
def_readwrite
(
"decode_mini_batches"
,
&
scheduler
::
BatchQueryTodo
::
decode_mini_batches
)
.
def_readwrite
(
"prefill_mini_batches"
,
&
scheduler
::
BatchQueryTodo
::
prefill_mini_batches
)
.
def_readwrite
(
"decode_mini_batches"
,
&
scheduler
::
BatchQueryTodo
::
decode_mini_batches
)
.
def_readwrite
(
"stop_criteria"
,
&
scheduler
::
BatchQueryTodo
::
stop_criteria
)
.
def_readwrite
(
"stop_criteria"
,
&
scheduler
::
BatchQueryTodo
::
stop_criteria
)
.
def
(
"debug"
,
&
scheduler
::
BatchQueryTodo
::
debug
)
.
def
(
"debug"
,
&
scheduler
::
BatchQueryTodo
::
debug
)
.
def
(
py
::
pickle
(
.
def
(
py
::
pickle
(
[](
const
scheduler
::
BatchQueryTodo
&
self
)
{
[](
const
scheduler
::
BatchQueryTodo
&
self
)
{
return
py
::
make_tuple
(
self
.
query_ids
,
self
.
query_tokens
,
self
.
query_lengths
,
self
.
block_indexes
,
return
py
::
make_tuple
(
self
.
attn_masks
,
self
.
rope_ranges
,
self
.
sample_options
,
self
.
prefill_mini_batches
,
self
.
query_ids
,
self
.
query_tokens
,
self
.
query_lengths
,
self
.
decode_mini_batches
,
self
.
stop_criteria
);
self
.
block_indexes
,
self
.
attn_masks
,
self
.
rope_ranges
,
self
.
sample_options
,
self
.
prefill_mini_batches
,
self
.
decode_mini_batches
,
self
.
stop_criteria
);
},
},
[](
py
::
tuple
t
)
{
[](
py
::
tuple
t
)
{
if
(
t
.
size
()
!=
10
)
if
(
t
.
size
()
!=
10
)
throw
std
::
runtime_error
(
"Invalid state! t.size() = "
+
std
::
to_string
(
t
.
size
()));
throw
std
::
runtime_error
(
"Invalid state! t.size() = "
+
std
::
to_string
(
t
.
size
()));
scheduler
::
BatchQueryTodo
bqt
;
scheduler
::
BatchQueryTodo
bqt
;
bqt
.
query_ids
=
t
[
0
].
cast
<
std
::
vector
<
scheduler
::
QueryID
>>
();
bqt
.
query_ids
=
t
[
0
].
cast
<
std
::
vector
<
scheduler
::
QueryID
>>
();
bqt
.
query_tokens
=
t
[
1
].
cast
<
std
::
vector
<
torch
::
Tensor
>>
();
bqt
.
query_tokens
=
t
[
1
].
cast
<
std
::
vector
<
torch
::
Tensor
>>
();
bqt
.
query_lengths
=
t
[
2
].
cast
<
std
::
vector
<
scheduler
::
TokenLength
>>
();
bqt
.
query_lengths
=
t
[
2
].
cast
<
std
::
vector
<
scheduler
::
TokenLength
>>
();
bqt
.
block_indexes
=
t
[
3
].
cast
<
std
::
vector
<
torch
::
Tensor
>>
();
bqt
.
block_indexes
=
t
[
3
].
cast
<
std
::
vector
<
torch
::
Tensor
>>
();
bqt
.
attn_masks
=
t
[
4
].
cast
<
std
::
optional
<
torch
::
Tensor
>>
();
bqt
.
attn_masks
=
t
[
4
].
cast
<
std
::
optional
<
torch
::
Tensor
>>
();
bqt
.
rope_ranges
=
t
[
5
].
cast
<
std
::
optional
<
torch
::
Tensor
>>
();
bqt
.
rope_ranges
=
t
[
5
].
cast
<
std
::
optional
<
torch
::
Tensor
>>
();
bqt
.
sample_options
=
t
[
6
].
cast
<
std
::
vector
<
scheduler
::
SampleOptions
>>
();
bqt
.
sample_options
=
bqt
.
prefill_mini_batches
=
t
[
7
].
cast
<
std
::
vector
<
scheduler
::
PrefillTask
>>
();
t
[
6
].
cast
<
std
::
vector
<
scheduler
::
SampleOptions
>>
();
bqt
.
decode_mini_batches
=
t
[
8
].
cast
<
std
::
vector
<
std
::
vector
<
scheduler
::
QueryID
>>>
();
bqt
.
prefill_mini_batches
=
bqt
.
stop_criteria
=
t
[
9
].
cast
<
std
::
vector
<
std
::
vector
<
std
::
vector
<
int
>>>>
();
t
[
7
].
cast
<
std
::
vector
<
scheduler
::
PrefillTask
>>
();
bqt
.
decode_mini_batches
=
t
[
8
].
cast
<
std
::
vector
<
std
::
vector
<
scheduler
::
QueryID
>>>
();
bqt
.
stop_criteria
=
t
[
9
].
cast
<
std
::
vector
<
std
::
vector
<
std
::
vector
<
int
>>>>
();
return
bqt
;
return
bqt
;
}));
}));
...
@@ -133,16 +162,20 @@ PYBIND11_MODULE(sched_ext, m) {
...
@@ -133,16 +162,20 @@ PYBIND11_MODULE(sched_ext, m) {
.
def_readwrite
(
"ok"
,
&
scheduler
::
QueryUpdate
::
ok
)
.
def_readwrite
(
"ok"
,
&
scheduler
::
QueryUpdate
::
ok
)
.
def_readwrite
(
"is_prefill"
,
&
scheduler
::
QueryUpdate
::
is_prefill
)
.
def_readwrite
(
"is_prefill"
,
&
scheduler
::
QueryUpdate
::
is_prefill
)
.
def_readwrite
(
"decode_done"
,
&
scheduler
::
QueryUpdate
::
decode_done
)
.
def_readwrite
(
"decode_done"
,
&
scheduler
::
QueryUpdate
::
decode_done
)
.
def_readwrite
(
"active_position"
,
&
scheduler
::
QueryUpdate
::
active_position
)
.
def_readwrite
(
"active_position"
,
.
def_readwrite
(
"generated_token"
,
&
scheduler
::
QueryUpdate
::
generated_token
)
&
scheduler
::
QueryUpdate
::
active_position
)
.
def_readwrite
(
"generated_token"
,
&
scheduler
::
QueryUpdate
::
generated_token
)
.
def
(
py
::
pickle
(
.
def
(
py
::
pickle
(
[](
const
scheduler
::
QueryUpdate
&
self
)
{
[](
const
scheduler
::
QueryUpdate
&
self
)
{
return
py
::
make_tuple
(
self
.
id
,
self
.
ok
,
self
.
is_prefill
,
self
.
decode_done
,
self
.
active_position
,
return
py
::
make_tuple
(
self
.
id
,
self
.
ok
,
self
.
is_prefill
,
self
.
decode_done
,
self
.
active_position
,
self
.
generated_token
);
self
.
generated_token
);
},
},
[](
py
::
tuple
t
)
{
[](
py
::
tuple
t
)
{
if
(
t
.
size
()
!=
6
)
if
(
t
.
size
()
!=
6
)
throw
std
::
runtime_error
(
"Invalid state! t.size() = "
+
std
::
to_string
(
t
.
size
()));
throw
std
::
runtime_error
(
"Invalid state! t.size() = "
+
std
::
to_string
(
t
.
size
()));
scheduler
::
QueryUpdate
qu
;
scheduler
::
QueryUpdate
qu
;
qu
.
id
=
t
[
0
].
cast
<
scheduler
::
QueryID
>
();
qu
.
id
=
t
[
0
].
cast
<
scheduler
::
QueryID
>
();
qu
.
ok
=
t
[
1
].
cast
<
bool
>
();
qu
.
ok
=
t
[
1
].
cast
<
bool
>
();
...
@@ -156,8 +189,7 @@ PYBIND11_MODULE(sched_ext, m) {
...
@@ -156,8 +189,7 @@ PYBIND11_MODULE(sched_ext, m) {
py
::
class_
<
scheduler
::
InferenceContext
>
(
m
,
"InferenceContext"
)
py
::
class_
<
scheduler
::
InferenceContext
>
(
m
,
"InferenceContext"
)
.
def
(
py
::
init
<>
())
.
def
(
py
::
init
<>
())
.
def_readwrite
(
"k_cache"
,
&
scheduler
::
InferenceContext
::
k_cache
)
.
def_readwrite
(
"k_cache"
,
&
scheduler
::
InferenceContext
::
k_cache
)
.
def_readwrite
(
"v_cache"
,
&
scheduler
::
InferenceContext
::
v_cache
)
.
def_readwrite
(
"v_cache"
,
&
scheduler
::
InferenceContext
::
v_cache
);
;
py
::
class_
<
scheduler
::
QueryAdd
>
(
m
,
"QueryAdd"
)
py
::
class_
<
scheduler
::
QueryAdd
>
(
m
,
"QueryAdd"
)
.
def
(
py
::
init
<>
())
.
def
(
py
::
init
<>
())
...
@@ -173,15 +205,18 @@ PYBIND11_MODULE(sched_ext, m) {
...
@@ -173,15 +205,18 @@ PYBIND11_MODULE(sched_ext, m) {
.
def
(
"serialize"
,
&
scheduler
::
QueryAdd
::
serialize
)
.
def
(
"serialize"
,
&
scheduler
::
QueryAdd
::
serialize
)
.
def_static
(
"deserialize"
,
&
scheduler
::
QueryAdd
::
deserialize
)
.
def_static
(
"deserialize"
,
&
scheduler
::
QueryAdd
::
deserialize
)
.
def
(
py
::
pickle
(
.
def
(
py
::
pickle
(
[](
const
scheduler
::
QueryAdd
&
self
)
{
[](
const
scheduler
::
QueryAdd
&
self
)
{
return
py
::
make_tuple
(
self
.
query_token
,
return
py
::
make_tuple
(
self
.
query_token
,
// self.attn_mask,
// self.attn_mask,
self
.
query_length
,
self
.
estimated_length
,
self
.
sample_options
,
self
.
user_id
,
self
.
query_length
,
self
.
estimated_length
,
self
.
SLO_TTFT_ms
,
self
.
SLO_TBT_ms
,
self
.
stop_criteria
);
self
.
sample_options
,
self
.
user_id
,
self
.
SLO_TTFT_ms
,
self
.
SLO_TBT_ms
,
self
.
stop_criteria
);
},
},
[](
py
::
tuple
t
)
{
[](
py
::
tuple
t
)
{
if
(
t
.
size
()
!=
8
)
if
(
t
.
size
()
!=
8
)
throw
std
::
runtime_error
(
"Invalid state! t.size() = "
+
std
::
to_string
(
t
.
size
()));
throw
std
::
runtime_error
(
"Invalid state! t.size() = "
+
std
::
to_string
(
t
.
size
()));
scheduler
::
QueryAdd
qa
;
scheduler
::
QueryAdd
qa
;
qa
.
query_token
=
t
[
0
].
cast
<
std
::
vector
<
scheduler
::
Token
>>
();
qa
.
query_token
=
t
[
0
].
cast
<
std
::
vector
<
scheduler
::
Token
>>
();
// qa.attn_mask = t[1].cast<torch::Tensor>();
// qa.attn_mask = t[1].cast<torch::Tensor>();
...
@@ -195,14 +230,20 @@ PYBIND11_MODULE(sched_ext, m) {
...
@@ -195,14 +230,20 @@ PYBIND11_MODULE(sched_ext, m) {
return
qa
;
return
qa
;
}));
}));
py
::
class_
<
scheduler
::
Scheduler
,
std
::
shared_ptr
<
scheduler
::
Scheduler
>>
(
m
,
"Scheduler"
)
py
::
class_
<
scheduler
::
Scheduler
,
std
::
shared_ptr
<
scheduler
::
Scheduler
>>
(
m
,
"Scheduler"
)
.
def
(
"init"
,
&
scheduler
::
Scheduler
::
init
)
.
def
(
"init"
,
&
scheduler
::
Scheduler
::
init
)
.
def
(
"run"
,
&
scheduler
::
Scheduler
::
run
)
.
def
(
"run"
,
&
scheduler
::
Scheduler
::
run
)
.
def
(
"stop"
,
&
scheduler
::
Scheduler
::
stop
)
.
def
(
"stop"
,
&
scheduler
::
Scheduler
::
stop
)
.
def
(
"add_query"
,
&
scheduler
::
Scheduler
::
add_query
,
py
::
call_guard
<
py
::
gil_scoped_release
>
())
.
def
(
"add_query"
,
&
scheduler
::
Scheduler
::
add_query
,
.
def
(
"cancel_query"
,
&
scheduler
::
Scheduler
::
cancel_query
,
py
::
call_guard
<
py
::
gil_scoped_release
>
())
py
::
call_guard
<
py
::
gil_scoped_release
>
())
.
def
(
"update_last_batch"
,
&
scheduler
::
Scheduler
::
update_last_batch
,
py
::
call_guard
<
py
::
gil_scoped_release
>
())
.
def
(
"cancel_query"
,
&
scheduler
::
Scheduler
::
cancel_query
,
.
def
(
"get_inference_context"
,
&
scheduler
::
Scheduler
::
get_inference_context
);
py
::
call_guard
<
py
::
gil_scoped_release
>
())
.
def
(
"update_last_batch"
,
&
scheduler
::
Scheduler
::
update_last_batch
,
py
::
call_guard
<
py
::
gil_scoped_release
>
())
.
def
(
"get_inference_context"
,
&
scheduler
::
Scheduler
::
get_inference_context
);
m
.
def
(
"create_scheduler"
,
&
scheduler
::
create_scheduler
,
"Create a new Scheduler instance"
);
m
.
def
(
"create_scheduler"
,
&
scheduler
::
create_scheduler
,
"Create a new Scheduler instance"
);
}
}
csrc/balance_serve/sched/metrics.cpp
View file @
64de7843
...
@@ -2,89 +2,101 @@
...
@@ -2,89 +2,101 @@
#include <iostream>
#include <iostream>
// 构造函数
// 构造函数
Metrics
::
Metrics
(
const
MetricsConfig
&
config
)
Metrics
::
Metrics
(
const
MetricsConfig
&
config
)
:
registry_
(
std
::
make_shared
<
prometheus
::
Registry
>
()),
:
registry_
(
std
::
make_shared
<
prometheus
::
Registry
>
()),
exposer_
(
config
.
endpoint
),
exposer_
(
config
.
endpoint
),
stop_uptime_thread_
(
false
),
stop_uptime_thread_
(
false
),
start_time_
(
std
::
chrono
::
steady_clock
::
now
())
{
start_time_
(
std
::
chrono
::
steady_clock
::
now
())
{
// 定义统一的桶大小,最大为 10000 ms (10 s)
// 定义统一的桶大小,最大为 10000 ms (10 s)
std
::
vector
<
double
>
common_buckets
=
{
0.001
,
0.005
,
0.01
,
0.05
,
0.1
,
0.5
,
1.0
,
5.0
,
std
::
vector
<
double
>
common_buckets
=
{
10.0
,
50.0
,
100.0
,
500.0
,
1000.0
,
5000.0
,
10000.0
};
// 毫秒
0.001
,
0.005
,
0.01
,
0.05
,
0.1
,
0.5
,
1.0
,
5.0
,
10.0
,
50.0
,
100.0
,
500.0
,
1000.0
,
5000.0
,
10000.0
};
// 毫秒
// 注册 TTFT_ms Histogram
// 注册 TTFT_ms Histogram
auto
&
TTFT_family
=
prometheus
::
BuildHistogram
()
auto
&
TTFT_family
=
prometheus
::
BuildHistogram
()
.
Name
(
std
::
string
(
METRIC_PREFIX
)
+
"_TTFT_ms"
)
.
Name
(
std
::
string
(
METRIC_PREFIX
)
+
"_TTFT_ms"
)
.
Help
(
"Time to first token in milliseconds"
)
.
Help
(
"Time to first token in milliseconds"
)
.
Register
(
*
registry_
);
.
Register
(
*
registry_
);
TTFT_ms
=
&
TTFT_family
.
Add
({{
"model"
,
config
.
model_name
}},
common_buckets
);
TTFT_ms
=
&
TTFT_family
.
Add
({{
"model"
,
config
.
model_name
}},
common_buckets
);
// 注册 TBT_ms Histogram
// 注册 TBT_ms Histogram
auto
&
TBT_family
=
prometheus
::
BuildHistogram
()
auto
&
TBT_family
=
prometheus
::
BuildHistogram
()
.
Name
(
std
::
string
(
METRIC_PREFIX
)
+
"_TBT_ms"
)
.
Name
(
std
::
string
(
METRIC_PREFIX
)
+
"_TBT_ms"
)
.
Help
(
"Time between tokens in milliseconds"
)
.
Help
(
"Time between tokens in milliseconds"
)
.
Register
(
*
registry_
);
.
Register
(
*
registry_
);
TBT_ms
=
&
TBT_family
.
Add
({{
"model"
,
config
.
model_name
}},
common_buckets
);
TBT_ms
=
&
TBT_family
.
Add
({{
"model"
,
config
.
model_name
}},
common_buckets
);
// 注册 schedule_time Histogram
// 注册 schedule_time Histogram
auto
&
schedule_time_family
=
prometheus
::
BuildHistogram
()
auto
&
schedule_time_family
=
.
Name
(
std
::
string
(
METRIC_PREFIX
)
+
"_schedule_time_ms"
)
prometheus
::
BuildHistogram
()
.
Help
(
"Time to generate schedule in milliseconds"
)
.
Name
(
std
::
string
(
METRIC_PREFIX
)
+
"_schedule_time_ms"
)
.
Register
(
*
registry_
);
.
Help
(
"Time to generate schedule in milliseconds"
)
schedule_time
=
&
schedule_time_family
.
Add
({{
"model"
,
config
.
model_name
}},
common_buckets
);
.
Register
(
*
registry_
);
schedule_time
=
&
schedule_time_family
.
Add
({{
"model"
,
config
.
model_name
}},
common_buckets
);
// 注册 generated_tokens Counter
// 注册 generated_tokens Counter
auto
&
generated_tokens_family
=
prometheus
::
BuildCounter
()
auto
&
generated_tokens_family
=
.
Name
(
std
::
string
(
METRIC_PREFIX
)
+
"_generated_tokens_total"
)
prometheus
::
BuildCounter
()
.
Help
(
"Total generated tokens"
)
.
Name
(
std
::
string
(
METRIC_PREFIX
)
+
"_generated_tokens_total"
)
.
Register
(
*
registry_
);
.
Help
(
"Total generated tokens"
)
generated_tokens
=
&
generated_tokens_family
.
Add
({{
"model"
,
config
.
model_name
}});
.
Register
(
*
registry_
);
generated_tokens
=
&
generated_tokens_family
.
Add
({{
"model"
,
config
.
model_name
}});
// 注册 throughput_query Gauge
// 注册 throughput_query Gauge
auto
&
throughput_query_family
=
prometheus
::
BuildGauge
()
auto
&
throughput_query_family
=
.
Name
(
std
::
string
(
METRIC_PREFIX
)
+
"_throughput_query"
)
prometheus
::
BuildGauge
()
.
Help
(
"Throughput per second based on queries"
)
.
Name
(
std
::
string
(
METRIC_PREFIX
)
+
"_throughput_query"
)
.
Register
(
*
registry_
);
.
Help
(
"Throughput per second based on queries"
)
throughput_query
=
&
throughput_query_family
.
Add
({{
"model"
,
config
.
model_name
}});
.
Register
(
*
registry_
);
throughput_query
=
&
throughput_query_family
.
Add
({{
"model"
,
config
.
model_name
}});
// 注册 throughput_generated_tokens Gauge
// 注册 throughput_generated_tokens Gauge
auto
&
throughput_generated_tokens_family
=
prometheus
::
BuildGauge
()
auto
&
throughput_generated_tokens_family
=
.
Name
(
std
::
string
(
METRIC_PREFIX
)
+
"_throughput_generated_tokens"
)
prometheus
::
BuildGauge
()
.
Help
(
"Throughput per second based on generated tokens"
)
.
Name
(
std
::
string
(
METRIC_PREFIX
)
+
"_throughput_generated_tokens"
)
.
Register
(
*
registry_
);
.
Help
(
"Throughput per second based on generated tokens"
)
throughput_generated_tokens
=
&
throughput_generated_tokens_family
.
Add
({{
"model"
,
config
.
model_name
}});
.
Register
(
*
registry_
);
throughput_generated_tokens
=
&
throughput_generated_tokens_family
.
Add
({{
"model"
,
config
.
model_name
}});
// 注册 event_count Counter family
// 注册 event_count Counter family
event_count_family_
=
&
prometheus
::
BuildCounter
()
event_count_family_
=
.
Name
(
std
::
string
(
METRIC_PREFIX
)
+
"_event_count_total"
)
&
prometheus
::
BuildCounter
()
.
Help
(
"Count of various events"
)
.
Name
(
std
::
string
(
METRIC_PREFIX
)
+
"_event_count_total"
)
.
Register
(
*
registry_
);
.
Help
(
"Count of various events"
)
.
Register
(
*
registry_
);
batch_count_family_
=
&
prometheus
::
BuildCounter
()
.
Name
(
std
::
string
(
METRIC_PREFIX
)
+
"_batch_count_total"
)
batch_count_family_
=
.
Help
(
"Count of various batch by status"
)
&
prometheus
::
BuildCounter
()
.
Register
(
*
registry_
);
.
Name
(
std
::
string
(
METRIC_PREFIX
)
+
"_batch_count_total"
)
.
Help
(
"Count of various batch by status"
)
.
Register
(
*
registry_
);
// 注册 query_count Counter family
// 注册 query_count Counter family
query_count_family_
=
&
prometheus
::
BuildCounter
()
query_count_family_
=
.
Name
(
std
::
string
(
METRIC_PREFIX
)
+
"_query_count_total"
)
&
prometheus
::
BuildCounter
()
.
Help
(
"Count of queries by status"
)
.
Name
(
std
::
string
(
METRIC_PREFIX
)
+
"_query_count_total"
)
.
Register
(
*
registry_
);
.
Help
(
"Count of queries by status"
)
.
Register
(
*
registry_
);
// 注册 uptime_ms Gauge
// 注册 uptime_ms Gauge
auto
&
uptime_family
=
prometheus
::
BuildGauge
()
auto
&
uptime_family
=
prometheus
::
BuildGauge
()
.
Name
(
std
::
string
(
METRIC_PREFIX
)
+
"_uptime_ms"
)
.
Name
(
std
::
string
(
METRIC_PREFIX
)
+
"_uptime_ms"
)
.
Help
(
"Uptime of the scheduler in milliseconds"
)
.
Help
(
"Uptime of the scheduler in milliseconds"
)
.
Register
(
*
registry_
);
.
Register
(
*
registry_
);
uptime_ms
=
&
uptime_family
.
Add
({{
"model"
,
config
.
model_name
}});
uptime_ms
=
&
uptime_family
.
Add
({{
"model"
,
config
.
model_name
}});
// 注册 GPU 利用率 Gauges
// 注册 GPU 利用率 Gauges
auto
&
gpu_util_family
=
prometheus
::
BuildGauge
()
auto
&
gpu_util_family
=
.
Name
(
std
::
string
(
METRIC_PREFIX
)
+
"_gpu_utilization_ratio"
)
prometheus
::
BuildGauge
()
.
Help
(
"Current GPU utilization ratio (0 to 1)"
)
.
Name
(
std
::
string
(
METRIC_PREFIX
)
+
"_gpu_utilization_ratio"
)
.
Register
(
*
registry_
);
.
Help
(
"Current GPU utilization ratio (0 to 1)"
)
.
Register
(
*
registry_
);
for
(
size_t
i
=
0
;
i
<
config
.
gpu_count
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
config
.
gpu_count
;
++
i
)
{
gpu_utilization_gauges
.
push_back
(
gpu_utilization_gauges
.
push_back
(
&
gpu_util_family
.
Add
(
&
gpu_util_family
.
Add
(
{{
"gpu_id"
,
std
::
to_string
(
i
)},
{
"model"
,
config
.
model_name
}}));
{{
"gpu_id"
,
std
::
to_string
(
i
)},
{
"model"
,
config
.
model_name
}}));
}
}
// 将 Registry 注册到 Exposer 中
// 将 Registry 注册到 Exposer 中
...
@@ -95,16 +107,15 @@ Metrics::Metrics(const MetricsConfig& config)
...
@@ -95,16 +107,15 @@ Metrics::Metrics(const MetricsConfig& config)
}
}
// 析构函数
// 析构函数
// Destructor: stop (and join) the background uptime-updater thread
// before any metric objects it touches are torn down.
Metrics::~Metrics() {
  StopUptimeUpdater();
}
// 启动 uptime 更新线程
// 启动 uptime 更新线程
void
Metrics
::
StartUptimeUpdater
()
{
void
Metrics
::
StartUptimeUpdater
()
{
uptime_thread_
=
std
::
thread
([
this
]()
{
uptime_thread_
=
std
::
thread
([
this
]()
{
while
(
!
stop_uptime_thread_
)
{
while
(
!
stop_uptime_thread_
)
{
auto
now
=
std
::
chrono
::
steady_clock
::
now
();
auto
now
=
std
::
chrono
::
steady_clock
::
now
();
std
::
chrono
::
duration
<
double
,
std
::
milli
>
uptime_duration
=
now
-
start_time_
;
std
::
chrono
::
duration
<
double
,
std
::
milli
>
uptime_duration
=
now
-
start_time_
;
uptime_ms
->
Set
(
uptime_duration
.
count
());
uptime_ms
->
Set
(
uptime_duration
.
count
());
// fn_every_sec(this);
// fn_every_sec(this);
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
seconds
(
1
));
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
seconds
(
1
));
...
@@ -121,15 +132,16 @@ void Metrics::StopUptimeUpdater() {
...
@@ -121,15 +132,16 @@ void Metrics::StopUptimeUpdater() {
}
}
// 获取 event_count 指标
// 获取 event_count 指标
// Look up (creating on first use) the event counter labelled with the
// given event type. More labels can be added here if ever needed.
prometheus::Counter* Metrics::event_count(const std::string& type) {
  auto& counter = event_count_family_->Add({{"type", type}});
  return &counter;
}
}
// 获取 query_count 指标
// 获取 query_count 指标
// Look up (creating on first use) the query counter labelled with the
// given query status. More labels can be added here if ever needed.
prometheus::Counter* Metrics::query_count(const std::string& status) {
  auto& counter = query_count_family_->Add({{"status", status}});
  return &counter;
}
}
// Look up (creating on first use) the batch counter labelled with the
// given batch type/status.
prometheus::Counter* Metrics::batch_count(const std::string& type) {
  auto& counter = batch_count_family_->Add({{"type", type}});
  return &counter;
}
}
csrc/balance_serve/sched/metrics.h
View file @
64de7843
#ifndef Metrics_H
#ifndef Metrics_H
#define Metrics_H
#define Metrics_H
#include <atomic>
#include <chrono>
#include <memory>
#include <prometheus/counter.h>
#include <prometheus/counter.h>
#include <prometheus/exposer.h>
#include <prometheus/exposer.h>
#include <prometheus/gauge.h>
#include <prometheus/gauge.h>
#include <prometheus/histogram.h>
#include <prometheus/histogram.h>
#include <prometheus/registry.h>
#include <prometheus/registry.h>
#include <atomic>
#include <chrono>
#include <memory>
#include <string>
#include <string>
#include <thread>
#include <thread>
#include <vector>
#include <vector>
...
@@ -21,46 +21,46 @@ class Metrics;
...
@@ -21,46 +21,46 @@ class Metrics;
// 配置结构体
// 配置结构体
// Configuration for the Metrics exporter.
struct MetricsConfig {
  std::string endpoint;    // address the Prometheus exposer listens on
  std::string model_name;  // model name label, e.g. "gpt-4"
  size_t gpu_count;        // number of GPUs to create utilization gauges for
};
// Metrics 类,根据配置初始化 Prometheus 指标
// Metrics 类,根据配置初始化 Prometheus 指标
class
Metrics
{
class
Metrics
{
public:
public:
// 构造函数传入 MetricsConfig
// 构造函数传入 MetricsConfig
Metrics
(
const
MetricsConfig
&
config
);
Metrics
(
const
MetricsConfig
&
config
);
~
Metrics
();
~
Metrics
();
// 禁止拷贝和赋值
// 禁止拷贝和赋值
Metrics
(
const
Metrics
&
)
=
delete
;
Metrics
(
const
Metrics
&
)
=
delete
;
Metrics
&
operator
=
(
const
Metrics
&
)
=
delete
;
Metrics
&
operator
=
(
const
Metrics
&
)
=
delete
;
std
::
function
<
void
(
Metrics
*
)
>
fn_every_sec
;
std
::
function
<
void
(
Metrics
*
)
>
fn_every_sec
;
// 指标指针
// 指标指针
prometheus
::
Gauge
*
uptime_ms
;
prometheus
::
Gauge
*
uptime_ms
;
prometheus
::
Histogram
*
TTFT_ms
;
prometheus
::
Histogram
*
TTFT_ms
;
prometheus
::
Histogram
*
TBT_ms
;
prometheus
::
Histogram
*
TBT_ms
;
prometheus
::
Histogram
*
schedule_time
;
prometheus
::
Histogram
*
schedule_time
;
prometheus
::
Gauge
*
throughput_query
;
prometheus
::
Gauge
*
throughput_query
;
prometheus
::
Gauge
*
throughput_generated_tokens
;
prometheus
::
Gauge
*
throughput_generated_tokens
;
prometheus
::
Counter
*
generated_tokens
;
prometheus
::
Counter
*
generated_tokens
;
std
::
vector
<
prometheus
::
Gauge
*>
gpu_utilization_gauges
;
std
::
vector
<
prometheus
::
Gauge
*>
gpu_utilization_gauges
;
// 计数器家族
// 计数器家族
prometheus
::
Counter
*
event_count
(
const
std
::
string
&
type
);
prometheus
::
Counter
*
event_count
(
const
std
::
string
&
type
);
prometheus
::
Counter
*
query_count
(
const
std
::
string
&
status
);
prometheus
::
Counter
*
query_count
(
const
std
::
string
&
status
);
prometheus
::
Counter
*
batch_count
(
const
std
::
string
&
type
);
prometheus
::
Counter
*
batch_count
(
const
std
::
string
&
type
);
private:
private:
std
::
shared_ptr
<
prometheus
::
Registry
>
registry_
;
std
::
shared_ptr
<
prometheus
::
Registry
>
registry_
;
prometheus
::
Exposer
exposer_
;
prometheus
::
Exposer
exposer_
;
// 计数器家族
// 计数器家族
prometheus
::
Family
<
prometheus
::
Counter
>
*
event_count_family_
;
prometheus
::
Family
<
prometheus
::
Counter
>
*
event_count_family_
;
prometheus
::
Family
<
prometheus
::
Counter
>
*
batch_count_family_
;
prometheus
::
Family
<
prometheus
::
Counter
>
*
batch_count_family_
;
prometheus
::
Family
<
prometheus
::
Counter
>
*
query_count_family_
;
prometheus
::
Family
<
prometheus
::
Counter
>
*
query_count_family_
;
// 线程和控制变量用于更新 uptime_ms
// 线程和控制变量用于更新 uptime_ms
std
::
thread
uptime_thread_
;
std
::
thread
uptime_thread_
;
...
@@ -76,10 +76,13 @@ class Metrics {
...
@@ -76,10 +76,13 @@ class Metrics {
};
};
struct
HistogramTimerWrapper
{
struct
HistogramTimerWrapper
{
prometheus
::
Histogram
*
histogram
;
prometheus
::
Histogram
*
histogram
;
Timer
timer
;
Timer
timer
;
inline
HistogramTimerWrapper
(
prometheus
::
Histogram
*
histogram
)
:
histogram
(
histogram
),
timer
()
{
timer
.
start
();
}
inline
HistogramTimerWrapper
(
prometheus
::
Histogram
*
histogram
)
:
histogram
(
histogram
),
timer
()
{
timer
.
start
();
}
inline
~
HistogramTimerWrapper
()
{
histogram
->
Observe
(
timer
.
elapsedMs
());
}
inline
~
HistogramTimerWrapper
()
{
histogram
->
Observe
(
timer
.
elapsedMs
());
}
};
};
#endif
// Metrics_H
#endif // Metrics_H
csrc/balance_serve/sched/model_config.h
View file @
64de7843
#ifndef __MODEL_CONFIG_HPP_
#ifndef __MODEL_CONFIG_HPP_
#define __MODEL_CONFIG_HPP_
#define __MODEL_CONFIG_HPP_
#include <iostream>
#include "nlohmann/json.hpp"
#include "nlohmann/json.hpp"
#include <iostream>
#include <filesystem>
#include <filesystem>
#include <fstream>
#include <fstream>
...
@@ -13,7 +13,7 @@ using ModelName = std::string;
...
@@ -13,7 +13,7 @@ using ModelName = std::string;
// We must assure this can be load by config.json
// We must assure this can be load by config.json
class
ModelConfig
{
class
ModelConfig
{
public:
public:
DimSize
hidden_size
;
DimSize
hidden_size
;
DimSize
intermediate_size
;
DimSize
intermediate_size
;
size_t
max_position_embeddings
;
size_t
max_position_embeddings
;
...
@@ -23,10 +23,13 @@ class ModelConfig {
...
@@ -23,10 +23,13 @@ class ModelConfig {
size_t
num_key_value_heads
;
size_t
num_key_value_heads
;
size_t
vocab_size
;
size_t
vocab_size
;
NLOHMANN_DEFINE_TYPE_INTRUSIVE
(
ModelConfig
,
hidden_size
,
intermediate_size
,
max_position_embeddings
,
model_type
,
NLOHMANN_DEFINE_TYPE_INTRUSIVE
(
ModelConfig
,
hidden_size
,
intermediate_size
,
num_attention_heads
,
num_hidden_layers
,
num_key_value_heads
,
vocab_size
);
max_position_embeddings
,
model_type
,
num_attention_heads
,
num_hidden_layers
,
num_key_value_heads
,
vocab_size
);
void
load_from
(
std
::
filesystem
::
path
path
)
{
void
load_from
(
std
::
filesystem
::
path
path
)
{
std
::
cout
<<
"Load from "
<<
path
<<
std
::
endl
;
std
::
ifstream
i
(
path
);
std
::
ifstream
i
(
path
);
nlohmann
::
json
j
;
nlohmann
::
json
j
;
i
>>
j
;
i
>>
j
;
...
@@ -38,12 +41,14 @@ using QuantType = std::string;
...
@@ -38,12 +41,14 @@ using QuantType = std::string;
static
const
QuantType
NoQuantType
=
""
;
static
const
QuantType
NoQuantType
=
""
;
class
QuantConfig
{
class
QuantConfig
{
public:
public:
QuantType
name
;
QuantType
name
;
// For GEMV
// For GEMV
QuantType
type_of_dot_vector
=
NoQuantType
;
QuantType
type_of_dot_vector
=
NoQuantType
;
inline
bool
can_be_used_as_matrix
()
{
return
type_of_dot_vector
!=
NoQuantType
;
}
inline
bool
can_be_used_as_matrix
()
{
return
type_of_dot_vector
!=
NoQuantType
;
}
bool
can_be_used_as_vector
;
bool
can_be_used_as_vector
;
...
@@ -56,8 +61,11 @@ class QuantConfig {
...
@@ -56,8 +61,11 @@ class QuantConfig {
URL
reference
=
""
;
URL
reference
=
""
;
NLOHMANN_DEFINE_TYPE_INTRUSIVE_WITH_DEFAULT
(
QuantConfig
,
name
,
type_of_dot_vector
,
can_be_used_as_vector
,
NLOHMANN_DEFINE_TYPE_INTRUSIVE_WITH_DEFAULT
(
QuantConfig
,
name
,
bytes_per_element
,
has_scale
,
has_min
,
block_element_count
,
type_of_dot_vector
,
can_be_used_as_vector
,
bytes_per_element
,
has_scale
,
has_min
,
block_element_count
,
block_element_size
,
reference
);
block_element_size
,
reference
);
};
};
...
@@ -70,14 +78,13 @@ inline void load_quant_configs(std::filesystem::path path) {
...
@@ -70,14 +78,13 @@ inline void load_quant_configs(std::filesystem::path path) {
std
::
cout
<<
__FUNCTION__
<<
" from "
<<
path
<<
std
::
endl
;
std
::
cout
<<
__FUNCTION__
<<
" from "
<<
path
<<
std
::
endl
;
std
::
ifstream
i
(
path
);
std
::
ifstream
i
(
path
);
i
>>
j
;
i
>>
j
;
quant_configs
=
j
.
get
<
std
::
map
<
QuantType
,
QuantConfig
>>
();
std
::
cout
<<
"Loaded Quant Configs"
<<
std
::
endl
;
for
(
auto
&
[
k
,
v
]
:
quant_configs
)
{
std
::
cout
<<
" - "
<<
k
<<
std
::
endl
;
}
}
else
{
}
else
{
std
::
cout
<<
__FUNCTION__
<<
" create new at "
<<
path
<<
std
::
endl
;
std
::
cout
<<
__FUNCTION__
<<
" no file at "
<<
path
<<
std
::
endl
;
}
quant_configs
=
j
.
get
<
std
::
map
<
QuantType
,
QuantConfig
>>
();
std
::
cout
<<
"Loaded Quant Configs"
<<
std
::
endl
;
for
(
auto
&
[
k
,
v
]
:
quant_configs
)
{
std
::
cout
<<
" - "
<<
k
<<
std
::
endl
;
}
}
}
}
...
@@ -93,14 +100,13 @@ inline void load_model_configs(std::filesystem::path path) {
...
@@ -93,14 +100,13 @@ inline void load_model_configs(std::filesystem::path path) {
std
::
cout
<<
__FUNCTION__
<<
" from "
<<
path
<<
std
::
endl
;
std
::
cout
<<
__FUNCTION__
<<
" from "
<<
path
<<
std
::
endl
;
std
::
ifstream
i
(
path
);
std
::
ifstream
i
(
path
);
i
>>
j
;
i
>>
j
;
model_configs
=
j
.
get
<
std
::
map
<
ModelName
,
ModelConfig
>>
();
std
::
cout
<<
"Loaded Model Configs"
<<
std
::
endl
;
for
(
auto
&
[
k
,
v
]
:
model_configs
)
{
std
::
cout
<<
" - "
<<
k
<<
std
::
endl
;
}
}
else
{
}
else
{
std
::
cout
<<
__FUNCTION__
<<
" create new at "
<<
path
<<
std
::
endl
;
std
::
cout
<<
__FUNCTION__
<<
" no file at "
<<
path
<<
std
::
endl
;
}
model_configs
=
j
.
get
<
std
::
map
<
ModelName
,
ModelConfig
>>
();
std
::
cout
<<
"Loaded Model Configs"
<<
std
::
endl
;
for
(
auto
&
[
k
,
v
]
:
model_configs
)
{
std
::
cout
<<
" - "
<<
k
<<
std
::
endl
;
}
}
}
}
...
...
csrc/balance_serve/sched/scheduler.cpp
View file @
64de7843
This diff is collapsed.
Click to expand it.
csrc/balance_serve/sched/scheduler.h
View file @
64de7843
#pragma once
#pragma once
#include
<torch/torch
.h
>
#include
"model_config
.h
"
#include <cstdint>
#include <cstdint>
#include <memory>
#include <memory>
#include <optional>
#include <optional>
#include <torch/torch.h>
#include <vector>
#include <vector>
#include "model_config.h"
namespace
scheduler
{
namespace
scheduler
{
...
@@ -28,7 +28,9 @@ struct ModelSettings {
...
@@ -28,7 +28,9 @@ struct ModelSettings {
double
bytes_per_kv_cache_element
;
double
bytes_per_kv_cache_element
;
inline
size_t
params_nbytes
()
{
return
params_count
*
bytes_per_params
;
}
inline
size_t
params_nbytes
()
{
return
params_count
*
bytes_per_params
;
}
inline
size_t
bytes_per_token_kv_cache
()
{
return
bytes_per_kv_cache_element
*
num_k_heads
*
k_head_dim
;
}
inline
size_t
bytes_per_token_kv_cache
()
{
return
bytes_per_kv_cache_element
*
num_k_heads
*
k_head_dim
;
}
};
};
struct
SampleOptions
{
struct
SampleOptions
{
...
@@ -37,15 +39,16 @@ struct SampleOptions {
...
@@ -37,15 +39,16 @@ struct SampleOptions {
};
};
struct
Settings
{
struct
Settings
{
// something is aukward here, kvc2 only use model_name and quant_type to get model infos.
// something is aukward here, kvc2 only use model_name and quant_type to get
// model infos.
ModelName
model_name
;
ModelName
model_name
;
QuantType
quant_type
;
QuantType
quant_type
;
// model_setting is ignore by kvc2
// model_setting is ignore by kvc2
ModelSettings
model_settings
;
ModelSettings
model_settings
;
size_t
page_size
=
256
;
// how many token in a page
size_t
page_size
=
256
;
// how many token in a page
std
::
vector
<
size_t
>
gpu_device_id
;
//
std
::
vector
<
size_t
>
gpu_device_id
;
//
size_t
gpu_memory_size
;
// memory size in bytes of each GPU, each
size_t
gpu_memory_size
;
// memory size in bytes of each GPU, each
double
memory_utilization_percentage
;
double
memory_utilization_percentage
;
size_t
max_batch_size
=
256
;
size_t
max_batch_size
=
256
;
...
@@ -79,14 +82,16 @@ struct Settings {
...
@@ -79,14 +82,16 @@ struct Settings {
void
auto_derive
();
void
auto_derive
();
};
};
using
PrefillTask
=
std
::
tuple
<
QueryID
,
TokenLength
,
TokenLength
>
;
// id, start, length
using
PrefillTask
=
std
::
tuple
<
QueryID
,
TokenLength
,
TokenLength
>
;
// id, start, length
struct
BatchQueryTodo
{
struct
BatchQueryTodo
{
// query
// query
std
::
vector
<
QueryID
>
query_ids
;
std
::
vector
<
QueryID
>
query_ids
;
std
::
vector
<
torch
::
Tensor
>
query_tokens
;
std
::
vector
<
torch
::
Tensor
>
query_tokens
;
std
::
vector
<
TokenLength
>
query_lengths
;
std
::
vector
<
TokenLength
>
query_lengths
;
std
::
vector
<
torch
::
Tensor
>
block_indexes
;
// (max_num_blocks_per_seq), dtype torch.int32.
std
::
vector
<
torch
::
Tensor
>
block_indexes
;
// (max_num_blocks_per_seq), dtype torch.int32.
std
::
optional
<
torch
::
Tensor
>
attn_masks
;
std
::
optional
<
torch
::
Tensor
>
attn_masks
;
std
::
optional
<
torch
::
Tensor
>
rope_ranges
;
std
::
optional
<
torch
::
Tensor
>
rope_ranges
;
std
::
vector
<
SampleOptions
>
sample_options
;
std
::
vector
<
SampleOptions
>
sample_options
;
...
@@ -94,8 +99,10 @@ struct BatchQueryTodo {
...
@@ -94,8 +99,10 @@ struct BatchQueryTodo {
// mini batches, adjacent two mini batches are executed together
// mini batches, adjacent two mini batches are executed together
// tasks count must be <=2, because of flash infer attention
// tasks count must be <=2, because of flash infer attention
std
::
vector
<
PrefillTask
>
prefill_mini_batches
;
// prefill minibatch only has 1 prefill
std
::
vector
<
PrefillTask
>
std
::
vector
<
std
::
vector
<
QueryID
>>
decode_mini_batches
;
// decode minibatch has multiple decode
prefill_mini_batches
;
// prefill minibatch only has 1 prefill
std
::
vector
<
std
::
vector
<
QueryID
>>
decode_mini_batches
;
// decode minibatch has multiple decode
std
::
string
debug
();
std
::
string
debug
();
bool
empty
();
bool
empty
();
...
@@ -105,9 +112,9 @@ struct QueryUpdate {
...
@@ -105,9 +112,9 @@ struct QueryUpdate {
QueryID
id
;
QueryID
id
;
bool
ok
;
bool
ok
;
bool
is_prefill
;
bool
is_prefill
;
bool
decode_done
;
// no use for now
bool
decode_done
;
// no use for now
TokenLength
active_position
;
// the position where no kvcache now,
TokenLength
active_position
;
// the position where no kvcache now,
// kvcache[active_position] == None
// kvcache[active_position] == None
Token
generated_token
;
Token
generated_token
;
...
@@ -117,8 +124,8 @@ struct QueryUpdate {
...
@@ -117,8 +124,8 @@ struct QueryUpdate {
using
BatchQueryUpdate
=
std
::
vector
<
QueryUpdate
>
;
using
BatchQueryUpdate
=
std
::
vector
<
QueryUpdate
>
;
struct
InferenceContext
{
struct
InferenceContext
{
std
::
vector
<
torch
::
Tensor
>
k_cache
;
// [gpu num] (layer_count, num blocks,
std
::
vector
<
torch
::
Tensor
>
k_cache
;
// [gpu num] (layer_count, num blocks,
// page size, kheadnum, head_dim)
// page size, kheadnum, head_dim)
std
::
vector
<
torch
::
Tensor
>
v_cache
;
std
::
vector
<
torch
::
Tensor
>
v_cache
;
};
};
...
@@ -127,7 +134,7 @@ constexpr UserID NoUser = -1;
...
@@ -127,7 +134,7 @@ constexpr UserID NoUser = -1;
const
int
MAX_SLO_TIME
=
1e9
;
const
int
MAX_SLO_TIME
=
1e9
;
struct
QueryAdd
{
struct
QueryAdd
{
std
::
vector
<
Token
>
query_token
;
// int here
std
::
vector
<
Token
>
query_token
;
// int here
// torch::Tensor attn_mask;
// torch::Tensor attn_mask;
TokenLength
query_length
;
TokenLength
query_length
;
TokenLength
estimated_length
;
TokenLength
estimated_length
;
...
@@ -141,11 +148,11 @@ struct QueryAdd {
...
@@ -141,11 +148,11 @@ struct QueryAdd {
int
SLO_TBT_ms
=
MAX_SLO_TIME
;
int
SLO_TBT_ms
=
MAX_SLO_TIME
;
std
::
string
serialize
();
std
::
string
serialize
();
static
QueryAdd
deserialize
(
const
std
::
string
&
input
);
static
QueryAdd
deserialize
(
const
std
::
string
&
input
);
};
};
class
Scheduler
{
class
Scheduler
{
public:
public:
virtual
void
init
(
Settings
settings
)
=
0
;
virtual
void
init
(
Settings
settings
)
=
0
;
virtual
void
run
()
=
0
;
virtual
void
run
()
=
0
;
...
@@ -156,7 +163,8 @@ class Scheduler {
...
@@ -156,7 +163,8 @@ class Scheduler {
virtual
void
cancel_query
(
QueryID
id
)
=
0
;
virtual
void
cancel_query
(
QueryID
id
)
=
0
;
// inference loop call this
// inference loop call this
virtual
std
::
shared_ptr
<
BatchQueryTodo
>
update_last_batch
(
BatchQueryUpdate
updates
)
=
0
;
virtual
std
::
shared_ptr
<
BatchQueryTodo
>
update_last_batch
(
BatchQueryUpdate
updates
)
=
0
;
virtual
InferenceContext
get_inference_context
()
=
0
;
virtual
InferenceContext
get_inference_context
()
=
0
;
virtual
~
Scheduler
()
=
default
;
virtual
~
Scheduler
()
=
default
;
...
@@ -164,4 +172,4 @@ class Scheduler {
...
@@ -164,4 +172,4 @@ class Scheduler {
std
::
shared_ptr
<
Scheduler
>
create_scheduler
(
Settings
settings
);
std
::
shared_ptr
<
Scheduler
>
create_scheduler
(
Settings
settings
);
};
// namespace scheduler
};
// namespace scheduler
\ No newline at end of file
\ No newline at end of file
csrc/balance_serve/sched/utils/arithmetic.hpp
View file @
64de7843
#include <type_traits>
#include <type_traits>
template
<
typename
T
,
typename
U
>
template
<
typename
T
,
typename
U
>
T
div_up
(
T
x
,
U
by
)
{
T
div_up
(
T
x
,
U
by
)
{
static_assert
(
std
::
is_integral_v
<
T
>
);
static_assert
(
std
::
is_integral_v
<
T
>
);
static_assert
(
std
::
is_integral_v
<
U
>
);
static_assert
(
std
::
is_integral_v
<
U
>
);
return
(
x
+
by
-
1
)
/
by
;
return
(
x
+
by
-
1
)
/
by
;
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment