Commit 989bb3d5 authored by Alec's avatar Alec Committed by GitHub
Browse files

feat: make block_size input for indexer, router, publisher (#66)

parent dd31a322
...@@ -247,7 +247,7 @@ index 1ca9e49d..b1591c0c 100644 ...@@ -247,7 +247,7 @@ index 1ca9e49d..b1591c0c 100644
# Reuse the cached content hash # Reuse the cached content hash
diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py
index c5b3b04f..c72001f7 100644 index c5b3b04f..12cd4dc9 100644
--- a/vllm/core/block_manager.py --- a/vllm/core/block_manager.py
+++ b/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py
@@ -10,7 +10,10 @@ from vllm.core.block.interfaces import Block @@ -10,7 +10,10 @@ from vllm.core.block.interfaces import Block
...@@ -269,7 +269,7 @@ index c5b3b04f..c72001f7 100644 ...@@ -269,7 +269,7 @@ index c5b3b04f..c72001f7 100644
block_size: int, block_size: int,
num_gpu_blocks: int, num_gpu_blocks: int,
num_cpu_blocks: int, num_cpu_blocks: int,
@@ -91,11 +95,28 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager): @@ -91,11 +95,29 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
self.watermark_blocks = int(watermark * num_gpu_blocks) self.watermark_blocks = int(watermark * num_gpu_blocks)
...@@ -285,7 +285,8 @@ index c5b3b04f..c72001f7 100644 ...@@ -285,7 +285,8 @@ index c5b3b04f..c72001f7 100644
+ namespace=VLLM_KV_NAMESPACE, + namespace=VLLM_KV_NAMESPACE,
+ component=VLLM_KV_COMPONENT, + component=VLLM_KV_COMPONENT,
+ worker_id=VLLM_WORKER_ID, + worker_id=VLLM_WORKER_ID,
+ lib_path=VLLM_KV_CAPI_PATH) + lib_path=VLLM_KV_CAPI_PATH,
+ kv_block_size=block_size)
+ else: + else:
+ self.event_manager = None + self.event_manager = None
+ +
...@@ -300,10 +301,10 @@ index c5b3b04f..c72001f7 100644 ...@@ -300,10 +301,10 @@ index c5b3b04f..c72001f7 100644
self.block_tables: Dict[SeqId, BlockTable] = {} self.block_tables: Dict[SeqId, BlockTable] = {}
diff --git a/vllm/core/event_manager.py b/vllm/core/event_manager.py diff --git a/vllm/core/event_manager.py b/vllm/core/event_manager.py
new file mode 100644 new file mode 100644
index 00000000..d3706700 index 00000000..a27af580
--- /dev/null --- /dev/null
+++ b/vllm/core/event_manager.py +++ b/vllm/core/event_manager.py
@@ -0,0 +1,102 @@ @@ -0,0 +1,108 @@
+# SPDX-License-Identifier: Apache-2.0 +# SPDX-License-Identifier: Apache-2.0
+import ctypes +import ctypes
+import logging +import logging
...@@ -324,16 +325,22 @@ index 00000000..d3706700 ...@@ -324,16 +325,22 @@ index 00000000..d3706700
+class KVCacheEventManager: +class KVCacheEventManager:
+ +
+ def __init__(self, namespace: str, component: str, worker_id: int, + def __init__(self, namespace: str, component: str, worker_id: int,
+ lib_path: str): + lib_path: str, kv_block_size: int):
+ self.lib = None + self.lib = None
+ +
+ try: + try:
+ self.lib = ctypes.CDLL(lib_path) + self.lib = ctypes.CDLL(lib_path)
+ self.lib.dynamo_llm_init.argtypes = [c_char_p, c_char_p, c_int64] + self.lib.dynamo_llm_init.argtypes = [
+ c_char_p,
+ c_char_p,
+ c_int64,
+ c_uint32,
+ ]
+ self.lib.dynamo_llm_init.restype = c_uint32 + self.lib.dynamo_llm_init.restype = c_uint32
+ +
+ result = self.lib.dynamo_llm_init(namespace.encode(), + result = self.lib.dynamo_llm_init(
+ component.encode(), worker_id) + namespace.encode(), component.encode(), worker_id, kv_block_size
+ )
+ if result == DynamoResult.OK: + if result == DynamoResult.OK:
+ logger.info( + logger.info(
+ "KVCacheEventManager initialized successfully. Ready to publish KV Cache Events" + "KVCacheEventManager initialized successfully. Ready to publish KV Cache Events"
......
...@@ -262,7 +262,8 @@ cd /workspace/examples/python_rs/llm/vllm ...@@ -262,7 +262,8 @@ cd /workspace/examples/python_rs/llm/vllm
RUST_LOG=info python3 -m kv_router.router \ RUST_LOG=info python3 -m kv_router.router \
--routing-strategy prefix \ --routing-strategy prefix \
--model-name deepseek-ai/DeepSeek-R1-Distill-Llama-8B \ --model-name deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--min-workers 1 --min-workers 1 \
--block-size 64
``` ```
You can choose only the prefix strategy for now: You can choose only the prefix strategy for now:
......
...@@ -173,13 +173,21 @@ async def worker(runtime: DistributedRuntime, args: Namespace): ...@@ -173,13 +173,21 @@ async def worker(runtime: DistributedRuntime, args: Namespace):
endpoint = router_component.endpoint("generate") endpoint = router_component.endpoint("generate")
if args.custom_router: if args.custom_router:
indexer = KvIndexer(kv_listener) # @REVIEWER - I'm not currently checking if block size matches that of the engine
# If they don't match things will silently fail
# The preferred solution would be for the KV Indexer to read from the MDC in etcd and not bother the user at all
# The second solution would be to do KvIndexer(kv_listener, MDC.block_size)
# as this ensures block size matches that of the engine
# In this case we need to do some sort of handshake or check in case a user just puts in a random block size
indexer = KvIndexer(kv_listener, args.block_size)
metrics_aggregator = KvMetricsAggregator(kv_listener) metrics_aggregator = KvMetricsAggregator(kv_listener)
await endpoint.serve_endpoint( await endpoint.serve_endpoint(
CustomRouter(indexer, metrics_aggregator).generate CustomRouter(indexer, metrics_aggregator).generate
) )
else: else:
router = KvRouter(runtime, kv_listener) # TODO Read block_size from MDC
router = KvRouter(runtime, kv_listener, args.block_size)
await endpoint.serve_endpoint(Router(router, args.routing_strategy).generate) await endpoint.serve_endpoint(Router(router, args.routing_strategy).generate)
...@@ -208,6 +216,11 @@ if __name__ == "__main__": ...@@ -208,6 +216,11 @@ if __name__ == "__main__":
default="deepseek-ai/DeepSeek-R1-Distill-Llama-8B", default="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
help="Model that is being served", help="Model that is being served",
) )
parser.add_argument(
"--block-size",
type=int,
help="KV block size",
)
parser.add_argument( parser.add_argument(
"--custom-router", "--custom-router",
type=bool, type=bool,
......
...@@ -63,7 +63,7 @@ PROCESSOR_CMD="RUST_LOG=info python3 -m kv_router.processor \ ...@@ -63,7 +63,7 @@ PROCESSOR_CMD="RUST_LOG=info python3 -m kv_router.processor \
--model $MODEL_NAME \ --model $MODEL_NAME \
--tokenizer $MODEL_NAME \ --tokenizer $MODEL_NAME \
--enable-prefix-caching \ --enable-prefix-caching \
--block-size 64 \ --block-size 32 \
--max-model-len 16384 " --max-model-len 16384 "
tmux new-session -d -s "$SESSION_NAME-processor" tmux new-session -d -s "$SESSION_NAME-processor"
tmux send-keys -t "$SESSION_NAME-processor" "$INIT_CMD && $PROCESSOR_CMD" C-m tmux send-keys -t "$SESSION_NAME-processor" "$INIT_CMD && $PROCESSOR_CMD" C-m
...@@ -74,7 +74,8 @@ tmux send-keys -t "$SESSION_NAME-processor" "$INIT_CMD && $PROCESSOR_CMD" C-m ...@@ -74,7 +74,8 @@ tmux send-keys -t "$SESSION_NAME-processor" "$INIT_CMD && $PROCESSOR_CMD" C-m
ROUTER_CMD="RUST_LOG=info python3 -m kv_router.router \ ROUTER_CMD="RUST_LOG=info python3 -m kv_router.router \
--model $MODEL_NAME \ --model $MODEL_NAME \
--routing-strategy $ROUTING_STRATEGY \ --routing-strategy $ROUTING_STRATEGY \
--min-workers $NUM_WORKERS " --min-workers $NUM_WORKERS \
--block-size 32"
tmux new-session -d -s "$SESSION_NAME-router" tmux new-session -d -s "$SESSION_NAME-router"
tmux send-keys -t "$SESSION_NAME-router" "$INIT_CMD && $ROUTER_CMD" C-m tmux send-keys -t "$SESSION_NAME-router" "$INIT_CMD && $ROUTER_CMD" C-m
......
...@@ -51,6 +51,7 @@ class Router: ...@@ -51,6 +51,7 @@ class Router:
self.routing_strategy = RoutingStrategy.PREFIX self.routing_strategy = RoutingStrategy.PREFIX
self.runtime = dynamo_context["runtime"] self.runtime = dynamo_context["runtime"]
self.min_workers = 1 self.min_workers = 1
self.kv_block_size = 64
@async_onstart @async_onstart
async def init_engine(self): async def init_engine(self):
...@@ -77,7 +78,7 @@ class Router: ...@@ -77,7 +78,7 @@ class Router:
kv_listener = self.runtime.namespace("dynamo").component(self.model_name) kv_listener = self.runtime.namespace("dynamo").component(self.model_name)
await kv_listener.create_service() await kv_listener.create_service()
self.router = KvRouter(self.runtime, kv_listener) self.router = KvRouter(self.runtime, kv_listener, self.kv_block_size)
@dynamo_endpoint() @dynamo_endpoint()
async def generate(self, request: Tokens): async def generate(self, request: Tokens):
......
...@@ -53,6 +53,7 @@ pub unsafe extern "C" fn dynamo_llm_init( ...@@ -53,6 +53,7 @@ pub unsafe extern "C" fn dynamo_llm_init(
namespace_c_str: *const c_char, namespace_c_str: *const c_char,
component_c_str: *const c_char, component_c_str: *const c_char,
worker_id: i64, worker_id: i64,
kv_block_size: u32,
) -> DynamoLlmResult { ) -> DynamoLlmResult {
initialize_tracing(); initialize_tracing();
let wk = match WK.get_or_try_init(Worker::from_settings) { let wk = match WK.get_or_try_init(Worker::from_settings) {
...@@ -94,9 +95,9 @@ pub unsafe extern "C" fn dynamo_llm_init( ...@@ -94,9 +95,9 @@ pub unsafe extern "C" fn dynamo_llm_init(
}; };
match result { match result {
Ok(_) => match KV_PUB Ok(_) => match KV_PUB.get_or_try_init(move || {
.get_or_try_init(move || dynamo_create_kv_publisher(namespace, component, worker_id)) dynamo_create_kv_publisher(namespace, component, worker_id, kv_block_size as usize)
{ }) {
Ok(_) => DynamoLlmResult::OK, Ok(_) => DynamoLlmResult::OK,
Err(e) => { Err(e) => {
eprintln!("Failed to initialize distributed runtime: {:?}", e); eprintln!("Failed to initialize distributed runtime: {:?}", e);
...@@ -138,6 +139,7 @@ fn dynamo_create_kv_publisher( ...@@ -138,6 +139,7 @@ fn dynamo_create_kv_publisher(
namespace: String, namespace: String,
component: String, component: String,
worker_id: i64, worker_id: i64,
kv_block_size: usize,
) -> Result<KvEventPublisher, anyhow::Error> { ) -> Result<KvEventPublisher, anyhow::Error> {
tracing::info!("Creating KV Publisher for model: {}", component); tracing::info!("Creating KV Publisher for model: {}", component);
match DRT match DRT
...@@ -146,7 +148,7 @@ fn dynamo_create_kv_publisher( ...@@ -146,7 +148,7 @@ fn dynamo_create_kv_publisher(
{ {
Ok(drt) => { Ok(drt) => {
let backend = drt.namespace(namespace)?.component(component)?; let backend = drt.namespace(namespace)?.component(component)?;
KvEventPublisher::new(drt.clone(), backend, worker_id) KvEventPublisher::new(drt.clone(), backend, worker_id, kv_block_size)
} }
Err(e) => Err(e), Err(e) => Err(e),
} }
...@@ -156,10 +158,13 @@ fn kv_event_create_stored_block_from_parts( ...@@ -156,10 +158,13 @@ fn kv_event_create_stored_block_from_parts(
block_hash: u64, block_hash: u64,
token_ids: *const u32, token_ids: *const u32,
num_tokens: usize, num_tokens: usize,
kv_block_size: usize,
_lora_id: u64, _lora_id: u64,
) -> KvCacheStoredBlockData { ) -> KvCacheStoredBlockData {
let tokens_hash = let tokens_hash = compute_block_hash_for_seq(
compute_block_hash_for_seq(unsafe { std::slice::from_raw_parts(token_ids, num_tokens) })[0]; unsafe { std::slice::from_raw_parts(token_ids, num_tokens) },
kv_block_size,
)[0];
KvCacheStoredBlockData { KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(block_hash), block_hash: ExternalSequenceBlockHash(block_hash),
tokens_hash, tokens_hash,
...@@ -168,23 +173,22 @@ fn kv_event_create_stored_block_from_parts( ...@@ -168,23 +173,22 @@ fn kv_event_create_stored_block_from_parts(
static WARN_COUNT: AtomicU32 = AtomicU32::new(0); static WARN_COUNT: AtomicU32 = AtomicU32::new(0);
fn kv_event_create_stored_from_parts( fn kv_event_create_stored_from_parts(
event_id: u64, kv_params: DynamoKvStoredEventParams,
token_ids: *const u32, kv_block_size: usize,
num_block_tokens: *const usize,
block_ids: *const u64,
num_blocks: usize,
parent_hash: Option<u64>,
lora_id: u64,
) -> KvCacheEvent { ) -> KvCacheEvent {
let mut blocks: Vec<KvCacheStoredBlockData> = Vec::new(); let mut blocks: Vec<KvCacheStoredBlockData> = Vec::new();
let mut token_offset: usize = 0; let mut token_offset: usize = 0;
for block_idx in 0..num_blocks { for block_idx in 0..kv_params.num_blocks {
let block_hash = unsafe { *block_ids.offset(block_idx.try_into().unwrap()) }; let block_hash = unsafe { *kv_params.block_ids.offset(block_idx.try_into().unwrap()) };
let tokens = unsafe { token_ids.offset(token_offset.try_into().unwrap()) }; let tokens = unsafe { kv_params.token_ids.offset(token_offset.try_into().unwrap()) };
let num_toks = unsafe { *num_block_tokens.offset(block_idx.try_into().unwrap()) }; let num_toks = unsafe {
// compute hash only apply to full block (KV_BLOCK_SIZE token) *kv_params
if num_toks != 64 { .num_block_tokens
.offset(block_idx.try_into().unwrap())
};
if num_toks != kv_block_size {
if WARN_COUNT if WARN_COUNT
.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |c| { .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |c| {
if c < 3 { if c < 3 {
...@@ -196,7 +200,8 @@ fn kv_event_create_stored_from_parts( ...@@ -196,7 +200,8 @@ fn kv_event_create_stored_from_parts(
.is_ok() .is_ok()
{ {
tracing::warn!( tracing::warn!(
"Block size must be 64 tokens to be published. Block size is: {}", "Block not published. Block size must be {} tokens to be published. Block size is: {}",
kv_block_size,
num_toks num_toks
); );
} }
...@@ -204,16 +209,20 @@ fn kv_event_create_stored_from_parts( ...@@ -204,16 +209,20 @@ fn kv_event_create_stored_from_parts(
} }
token_offset += num_toks; token_offset += num_toks;
blocks.push(kv_event_create_stored_block_from_parts( blocks.push(kv_event_create_stored_block_from_parts(
block_hash, tokens, num_toks, lora_id, block_hash,
tokens,
num_toks,
kv_block_size,
kv_params.lora_id,
)); ));
} }
KvCacheEvent { KvCacheEvent {
data: KvCacheEventData::Stored(KvCacheStoreData { data: KvCacheEventData::Stored(KvCacheStoreData {
blocks, blocks,
parent_hash: parent_hash.map(ExternalSequenceBlockHash), parent_hash: kv_params.parent_hash.map(ExternalSequenceBlockHash),
}), }),
event_id, event_id: kv_params.event_id,
} }
} }
...@@ -234,6 +243,16 @@ fn kv_event_create_removed_from_parts( ...@@ -234,6 +243,16 @@ fn kv_event_create_removed_from_parts(
} }
} }
pub struct DynamoKvStoredEventParams {
pub event_id: u64,
pub token_ids: *const u32,
pub num_block_tokens: *const usize,
pub block_ids: *const u64,
pub num_blocks: usize,
pub parent_hash: Option<u64>,
pub lora_id: u64,
}
/// # Safety /// # Safety
/// parent_hash is passed as pointer to indicate whether the blocks /// parent_hash is passed as pointer to indicate whether the blocks
/// has a parent hash or not. nullptr is used to represent no parent hash /// has a parent hash or not. nullptr is used to represent no parent hash
...@@ -247,7 +266,6 @@ pub unsafe extern "C" fn dynamo_kv_event_publish_stored( ...@@ -247,7 +266,6 @@ pub unsafe extern "C" fn dynamo_kv_event_publish_stored(
parent_hash: *const u64, parent_hash: *const u64,
lora_id: u64, lora_id: u64,
) -> DynamoLlmResult { ) -> DynamoLlmResult {
let publisher = KV_PUB.get().unwrap();
let parent_hash = { let parent_hash = {
if parent_hash.is_null() { if parent_hash.is_null() {
None None
...@@ -255,7 +273,7 @@ pub unsafe extern "C" fn dynamo_kv_event_publish_stored( ...@@ -255,7 +273,7 @@ pub unsafe extern "C" fn dynamo_kv_event_publish_stored(
Some(unsafe { *parent_hash }) Some(unsafe { *parent_hash })
} }
}; };
let event = kv_event_create_stored_from_parts( let kv_params = DynamoKvStoredEventParams {
event_id, event_id,
token_ids, token_ids,
num_block_tokens, num_block_tokens,
...@@ -263,7 +281,9 @@ pub unsafe extern "C" fn dynamo_kv_event_publish_stored( ...@@ -263,7 +281,9 @@ pub unsafe extern "C" fn dynamo_kv_event_publish_stored(
num_blocks, num_blocks,
parent_hash, parent_hash,
lora_id, lora_id,
); };
let publisher = KV_PUB.get().unwrap();
let event = kv_event_create_stored_from_parts(kv_params, publisher.kv_block_size());
match publisher.publish(event) { match publisher.publish(event) {
Ok(_) => DynamoLlmResult::OK, Ok(_) => DynamoLlmResult::OK,
Err(e) => { Err(e) => {
...@@ -290,68 +310,33 @@ pub extern "C" fn dynamo_kv_event_publish_removed( ...@@ -290,68 +310,33 @@ pub extern "C" fn dynamo_kv_event_publish_removed(
} }
} }
// #[no_mangle] // Need to setup etcd and nats to run these tests
// pub extern "C" fn dynamo_kv_publish_store_event( // #[cfg(test)]
// event_id: u64, // mod tests {
// token_ids: *const u32, // use super::*;
// num_tokens: usize, // use std::ffi::CString;
// lora_id: u64,
// ) -> DynamoLlmResult {
// // if event.is_null() || token_ids.is_null() {
// // return dynamoKvErrorType::INVALID_TOKEN_IDS;
// // }
// // let tokens = unsafe { std::slice::from_raw_parts(token_ids, num_tokens) }.to_vec(); // #[test]
// // let new_event = Box::new(KvCacheStoreData { // fn test_dynamo_llm_init() {
// // event_id, // // Create C-compatible strings
// // lora_id, // let namespace = CString::new("test_namespace").unwrap();
// // token_ids: tokens, // let component = CString::new("test_component").unwrap();
// // block_hashes: Vec::new(),
// // });
// // unsafe { *event = Box::into_raw(new_event) }; // // Call the init function
// let result = unsafe {
// dynamo_llm_init(
// namespace.as_ptr(),
// component.as_ptr(),
// 1, // worker_id
// 32, // kv_block_size
// )
// };
// DynamoLlmResult::OK // assert_eq!(result as u32, DynamoLlmResult::OK as u32);
// }
// #[no_mangle]
// pub extern "C" fn dynamo_kv_event_create_removed(
// event_id: u64,
// block_hashes: *const u64,
// num_hashes: usize,
// ) -> DynamoLlmResult {
// // if event.is_null() || block_hashes.is_null() {
// // return -1;
// // }
// // let hashes = unsafe { std::slice::from_raw_parts(block_hashes, num_hashes) }.to_vec(); // assert!(WK.get().is_some());
// // let new_event = Box::new(KvCacheRemoveData {
// // event_id,
// // lora_id: 0,
// // token_ids: Vec::new(),
// // block_hashes: hashes,
// // });
// // unsafe { *event = Box::into_raw(new_event) }; // let shutdown_result = dynamo_llm_shutdown();
// // 0 // assert_eq!(shutdown_result as u32, DynamoLlmResult::OK as u32);
// DynamoLlmResult::OK
// } // }
// /// create load publisher object and return a handle
// /// load publisher will instantiate the nats service and tie its stats handler to
// /// a watch channel receiver. the watch channel sender will be attach to the
// /// handle and calls to [`dynamo_load_stats_publish`] issue the stats to the watch t
// pub extern "C" fn dynamo_load_publisher_create() -> *mut LoadPublisher {
// // let publisher = Box::new(LoadPublisher::new());
// // Box::into_raw(publisher)
// }
// pub extern "C" fn dynamo_load_stats_publish(
// publisher: *mut LoadPublisher,
// active_slots: u64,
// total_slots: u64,
// active_kv: u64,
// total_kv: u64,
// ) {
// // let publisher = unsafe { &mut *publisher };
// } // }
...@@ -28,12 +28,13 @@ pub(crate) struct KvRouter { ...@@ -28,12 +28,13 @@ pub(crate) struct KvRouter {
impl KvRouter { impl KvRouter {
#[new] #[new]
// [FXIME] 'drt' can be obtained from 'component' // [FXIME] 'drt' can be obtained from 'component'
fn new(drt: DistributedRuntime, component: Component) -> PyResult<Self> { fn new(drt: DistributedRuntime, component: Component, kv_block_size: usize) -> PyResult<Self> {
let runtime = pyo3_async_runtimes::tokio::get_runtime(); let runtime = pyo3_async_runtimes::tokio::get_runtime();
runtime.block_on(async { runtime.block_on(async {
let inner = llm_rs::kv_router::KvRouter::from_runtime( let inner = llm_rs::kv_router::KvRouter::from_runtime(
drt.inner.clone(), drt.inner.clone(),
component.inner.clone(), component.inner.clone(),
kv_block_size,
) )
.await .await
.map_err(to_pyerr)?; .map_err(to_pyerr)?;
...@@ -138,7 +139,7 @@ pub(crate) struct KvIndexer { ...@@ -138,7 +139,7 @@ pub(crate) struct KvIndexer {
#[pymethods] #[pymethods]
impl KvIndexer { impl KvIndexer {
#[new] #[new]
fn new(component: Component) -> PyResult<Self> { fn new(component: Component, kv_block_size: usize) -> PyResult<Self> {
let runtime = pyo3_async_runtimes::tokio::get_runtime(); let runtime = pyo3_async_runtimes::tokio::get_runtime();
runtime.block_on(async { runtime.block_on(async {
let kv_subject = component let kv_subject = component
...@@ -147,6 +148,7 @@ impl KvIndexer { ...@@ -147,6 +148,7 @@ impl KvIndexer {
let inner: Arc<llm_rs::kv_router::indexer::KvIndexer> = let inner: Arc<llm_rs::kv_router::indexer::KvIndexer> =
llm_rs::kv_router::indexer::KvIndexer::new( llm_rs::kv_router::indexer::KvIndexer::new(
component.inner.drt().runtime().child_token(), component.inner.drt().runtime().child_token(),
kv_block_size,
) )
.into(); .into();
let mut kv_events_rx = component let mut kv_events_rx = component
......
...@@ -54,20 +54,27 @@ async def distributed_runtime(): ...@@ -54,20 +54,27 @@ async def distributed_runtime():
return DistributedRuntime(loop) return DistributedRuntime(loop)
# TODO Figure out how to test with different kv_block_size
# Right now I get an error in EventPublisher init when I run this test
# back to back. It occurs when calling dynamo_llm_init and I think is related to the
# OnceCell initializations not being reset.
# The test works individually if I run it with 32, then 11, then 64.
# @pytest.mark.parametrize("kv_block_size", [11, 32, 64])
async def test_event_handler(distributed_runtime): async def test_event_handler(distributed_runtime):
kv_block_size = 32
namespace = "kv_test" namespace = "kv_test"
component = "event" component = "event"
# publisher # publisher
worker_id = 233 worker_id = 233
event_publisher = EventPublisher(namespace, component, worker_id) event_publisher = EventPublisher(namespace, component, worker_id, kv_block_size)
# indexer # indexer
kv_listener = distributed_runtime.namespace(namespace).component(component) kv_listener = distributed_runtime.namespace(namespace).component(component)
await kv_listener.create_service() await kv_listener.create_service()
indexer = KvIndexer(kv_listener) indexer = KvIndexer(kv_listener, kv_block_size)
test_token = [3] * 64 test_token = [3] * kv_block_size
lora_id = 0 # lora_id is not used in the indexer lora_id = 0 # lora_id is not used in the indexer
scores = await indexer.find_matches_for_request(test_token, lora_id) scores = await indexer.find_matches_for_request(test_token, lora_id)
assert not scores.scores assert not scores.scores
...@@ -86,6 +93,8 @@ async def test_event_handler(distributed_runtime): ...@@ -86,6 +93,8 @@ async def test_event_handler(distributed_runtime):
scores = await indexer.find_matches_for_request(test_token, lora_id) scores = await indexer.find_matches_for_request(test_token, lora_id)
assert not scores.scores assert not scores.scores
event_publisher.shutdown()
# KV events # KV events
class DynamoResult: class DynamoResult:
...@@ -94,18 +103,21 @@ class DynamoResult: ...@@ -94,18 +103,21 @@ class DynamoResult:
class EventPublisher: class EventPublisher:
def __init__(self, namespace: str, component: str, worker_id: int): def __init__(
self, namespace: str, component: str, worker_id: int, kv_block_size: int
):
self.event_id_counter = 0 self.event_id_counter = 0
self.block_ids: List[int] = [] self.block_ids: List[int] = []
# load event publisher library # load event publisher library
self.lib = ctypes.CDLL(os.environ["VLLM_KV_CAPI_PATH"]) self.lib = ctypes.CDLL(os.environ["VLLM_KV_CAPI_PATH"])
self.lib.dynamo_llm_init.argtypes = [c_char_p, c_char_p, c_int64] self.lib.dynamo_llm_init.argtypes = [c_char_p, c_char_p, c_int64, c_uint32]
self.lib.dynamo_llm_init.restype = c_uint32 self.lib.dynamo_llm_init.restype = c_uint32
result = self.lib.dynamo_llm_init( result = self.lib.dynamo_llm_init(
namespace.encode(), component.encode(), worker_id namespace.encode(), component.encode(), worker_id, kv_block_size
) )
assert result == DynamoResult.OK assert result == DynamoResult.OK
self.lib.dynamo_kv_event_publish_stored.argtypes = [ self.lib.dynamo_kv_event_publish_stored.argtypes = [
ctypes.c_uint64, # event_id ctypes.c_uint64, # event_id
ctypes.POINTER(ctypes.c_uint32), # token_ids ctypes.POINTER(ctypes.c_uint32), # token_ids
...@@ -158,6 +170,10 @@ class EventPublisher: ...@@ -158,6 +170,10 @@ class EventPublisher:
assert result == DynamoResult.OK assert result == DynamoResult.OK
def shutdown(self):
result = self.lib.dynamo_llm_shutdown()
assert result == DynamoResult.OK
async def test_metrics_aggregator(distributed_runtime): async def test_metrics_aggregator(distributed_runtime):
namespace = "kv_test" namespace = "kv_test"
......
...@@ -1447,6 +1447,7 @@ dependencies = [ ...@@ -1447,6 +1447,7 @@ dependencies = [
"regex", "regex",
"reqwest", "reqwest",
"rstest", "rstest",
"rstest_reuse",
"semver", "semver",
"sentencepiece", "sentencepiece",
"serde", "serde",
...@@ -4647,6 +4648,17 @@ dependencies = [ ...@@ -4647,6 +4648,17 @@ dependencies = [
"unicode-ident", "unicode-ident",
] ]
[[package]]
name = "rstest_reuse"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b3a8fb4672e840a587a66fc577a5491375df51ddb88f2a2c2a792598c326fe14"
dependencies = [
"quote",
"rand",
"syn 2.0.98",
]
[[package]] [[package]]
name = "rustc-demangle" name = "rustc-demangle"
version = "0.1.24" version = "0.1.24"
......
...@@ -164,6 +164,7 @@ pythonize = { version = "0.23", optional = true } ...@@ -164,6 +164,7 @@ pythonize = { version = "0.23", optional = true }
proptest = "1.5.0" proptest = "1.5.0"
reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "rustls-tls"] } reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "rustls-tls"] }
rstest = "0.18.2" rstest = "0.18.2"
rstest_reuse = "0.7.0"
tempfile = "3.17.1" tempfile = "3.17.1"
hf-hub = "0.4.1" hf-hub = "0.4.1"
insta = { version = "1.41", features = [ insta = { version = "1.41", features = [
......
...@@ -54,6 +54,7 @@ impl KvRouter { ...@@ -54,6 +54,7 @@ impl KvRouter {
pub async fn from_runtime( pub async fn from_runtime(
runtime: DistributedRuntime, runtime: DistributedRuntime,
backend: Component, backend: Component,
kv_block_size: usize,
) -> Result<Arc<Self>> { ) -> Result<Arc<Self>> {
let nats_client = runtime.nats_client(); let nats_client = runtime.nats_client();
let service_name = backend.service_name(); let service_name = backend.service_name();
...@@ -63,7 +64,14 @@ impl KvRouter { ...@@ -63,7 +64,14 @@ impl KvRouter {
tracing::info!("Component Namespace {}", backend.namespace()); tracing::info!("Component Namespace {}", backend.namespace());
tracing::info!("Component Service Name {}", service_name); tracing::info!("Component Service Name {}", service_name);
tracing::info!("KV Subject {}", kv_subject); tracing::info!("KV Subject {}", kv_subject);
Self::new(nats_client, service_name, kv_subject, namespace).await Self::new(
nats_client,
service_name,
kv_subject,
namespace,
kv_block_size,
)
.await
} }
pub async fn new( pub async fn new(
...@@ -71,6 +79,7 @@ impl KvRouter { ...@@ -71,6 +79,7 @@ impl KvRouter {
service_name: String, service_name: String,
kv_subject: String, kv_subject: String,
namespace: Namespace, namespace: Namespace,
kv_block_size: usize,
) -> Result<Arc<Self>> { ) -> Result<Arc<Self>> {
let cancellation_token = CancellationToken::new(); let cancellation_token = CancellationToken::new();
let (ep_tx, ep_rx) = tokio::sync::mpsc::channel(128); let (ep_tx, ep_rx) = tokio::sync::mpsc::channel(128);
...@@ -82,8 +91,8 @@ impl KvRouter { ...@@ -82,8 +91,8 @@ impl KvRouter {
cancellation_token.clone(), cancellation_token.clone(),
)); ));
let indexer = KvIndexer::new(cancellation_token.clone()); let indexer = KvIndexer::new(cancellation_token.clone(), kv_block_size);
let scheduler = KvScheduler::start(ep_rx, namespace).await?; let scheduler = KvScheduler::start(ep_rx, namespace, kv_block_size).await?;
tracing::debug!("subscribing to kv events: {}", kv_subject); tracing::debug!("subscribing to kv events: {}", kv_subject);
let mut kv_events_rx = nats_client.client().subscribe(kv_subject).await?; let mut kv_events_rx = nats_client.client().subscribe(kv_subject).await?;
......
...@@ -116,9 +116,9 @@ pub fn compute_block_hash(data: &[u8]) -> LocalBlockHash { ...@@ -116,9 +116,9 @@ pub fn compute_block_hash(data: &[u8]) -> LocalBlockHash {
/// ### Returns /// ### Returns
/// ///
/// A vector of `LocalBlockHash` representing the computed hashes for each chunk of tokens. /// A vector of `LocalBlockHash` representing the computed hashes for each chunk of tokens.
pub fn compute_block_hash_for_seq(tokens: &[u32]) -> Vec<LocalBlockHash> { pub fn compute_block_hash_for_seq(tokens: &[u32], kv_block_size: usize) -> Vec<LocalBlockHash> {
tokens tokens
.chunks_exact(KV_BLOCK_SIZE) // Split into chunks of KV_BLOCK_SIZE elements .chunks_exact(kv_block_size) // Split into chunks of kv_block_size elements
.map(|chunk| { .map(|chunk| {
let bytes: Vec<u8> = chunk let bytes: Vec<u8> = chunk
.iter() .iter()
...@@ -503,6 +503,8 @@ pub struct KvIndexer { ...@@ -503,6 +503,8 @@ pub struct KvIndexer {
remove_worker_tx: mpsc::Sender<WorkerId>, remove_worker_tx: mpsc::Sender<WorkerId>,
/// A handle to the background task managing the KV store. /// A handle to the background task managing the KV store.
task: OnceLock<std::thread::JoinHandle<()>>, task: OnceLock<std::thread::JoinHandle<()>>,
/// The size of the KV block this indexer can handle.
kv_block_size: usize,
} }
impl KvIndexer { impl KvIndexer {
...@@ -519,6 +521,7 @@ impl KvIndexer { ...@@ -519,6 +521,7 @@ impl KvIndexer {
pub fn new_with_frequency( pub fn new_with_frequency(
token: CancellationToken, token: CancellationToken,
expiration_duration: Option<Duration>, expiration_duration: Option<Duration>,
kv_block_size: usize,
) -> Self { ) -> Self {
let (event_tx, event_rx) = mpsc::channel::<RouterEvent>(2048); let (event_tx, event_rx) = mpsc::channel::<RouterEvent>(2048);
let (match_tx, match_rx) = mpsc::channel::<MatchRequest>(128); let (match_tx, match_rx) = mpsc::channel::<MatchRequest>(128);
...@@ -581,11 +584,12 @@ impl KvIndexer { ...@@ -581,11 +584,12 @@ impl KvIndexer {
match_tx, match_tx,
remove_worker_tx, remove_worker_tx,
task: once, task: once,
kv_block_size,
} }
} }
pub fn new(token: CancellationToken) -> Self { pub fn new(token: CancellationToken, kv_block_size: usize) -> Self {
Self::new_with_frequency(token, None) Self::new_with_frequency(token, None, kv_block_size)
} }
/// Get a sender for `RouterEvent`s. /// Get a sender for `RouterEvent`s.
...@@ -633,7 +637,7 @@ impl KvIndexerInterface for KvIndexer { ...@@ -633,7 +637,7 @@ impl KvIndexerInterface for KvIndexer {
tokens, tokens,
tokens.len() tokens.len()
); );
let sequence = compute_block_hash_for_seq(tokens); let sequence = compute_block_hash_for_seq(tokens, self.kv_block_size);
log::debug!("Computed sequence: {:?}", sequence); log::debug!("Computed sequence: {:?}", sequence);
self.find_matches(sequence).await self.find_matches(sequence).await
} }
...@@ -665,6 +669,8 @@ pub struct ShardedMatchRequest { ...@@ -665,6 +669,8 @@ pub struct ShardedMatchRequest {
pub struct KvIndexerSharded { pub struct KvIndexerSharded {
/// A `CancellationToken` for managing shutdown. /// A `CancellationToken` for managing shutdown.
cancel: CancellationToken, cancel: CancellationToken,
/// The size of the KV block this indexer can handle.
kv_block_size: usize,
worker_assignments: HashMap<WorkerId, usize>, worker_assignments: HashMap<WorkerId, usize>,
worker_counts: Vec<usize>, worker_counts: Vec<usize>,
...@@ -690,6 +696,7 @@ impl KvIndexerSharded { ...@@ -690,6 +696,7 @@ impl KvIndexerSharded {
token: CancellationToken, token: CancellationToken,
num_shards: usize, num_shards: usize,
expiration_duration: Option<Duration>, expiration_duration: Option<Duration>,
kv_block_size: usize,
) -> Self { ) -> Self {
let worker_assignments: HashMap<WorkerId, usize> = HashMap::new(); let worker_assignments: HashMap<WorkerId, usize> = HashMap::new();
let worker_counts: Vec<usize> = vec![0; num_shards]; let worker_counts: Vec<usize> = vec![0; num_shards];
...@@ -758,6 +765,7 @@ impl KvIndexerSharded { ...@@ -758,6 +765,7 @@ impl KvIndexerSharded {
Self { Self {
cancel: token, cancel: token,
kv_block_size,
worker_assignments, worker_assignments,
worker_counts, worker_counts,
event_tx, event_tx,
...@@ -767,8 +775,8 @@ impl KvIndexerSharded { ...@@ -767,8 +775,8 @@ impl KvIndexerSharded {
} }
} }
pub fn new(token: CancellationToken, num_shards: usize) -> Self { pub fn new(token: CancellationToken, num_shards: usize, kv_block_size: usize) -> Self {
Self::new_with_frequency(token, num_shards, None) Self::new_with_frequency(token, num_shards, None, kv_block_size)
} }
} }
...@@ -827,7 +835,7 @@ impl KvIndexerInterface for KvIndexerSharded { ...@@ -827,7 +835,7 @@ impl KvIndexerInterface for KvIndexerSharded {
&self, &self,
tokens: &[u32], tokens: &[u32],
) -> Result<OverlapScores, KvRouterError> { ) -> Result<OverlapScores, KvRouterError> {
let sequence = compute_block_hash_for_seq(tokens); let sequence = compute_block_hash_for_seq(tokens, self.kv_block_size);
self.find_matches(sequence).await self.find_matches(sequence).await
} }
...@@ -875,6 +883,7 @@ mod tests { ...@@ -875,6 +883,7 @@ mod tests {
use super::*; use super::*;
use rstest::rstest; use rstest::rstest;
use rstest_reuse::{self, *};
use tokio::time; use tokio::time;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
...@@ -1180,64 +1189,67 @@ mod tests { ...@@ -1180,64 +1189,67 @@ mod tests {
assert!(result.len() == 2 && result[&worker_0] == 2 && result[&worker_1] == 1); assert!(result.len() == 2 && result[&worker_0] == 2 && result[&worker_1] == 1);
} }
#[test] #[rstest]
fn test_compute_block_hash_for_seq() { #[case(11)]
#[case(32)]
#[case(64)]
fn test_compute_block_hash_for_seq(#[case] kv_block_size: usize) {
// create a sequence of 64 elements // create a sequence of 64 elements
let sequence = (0..KV_BLOCK_SIZE).map(|i| i as u32).collect::<Vec<u32>>(); let sequence = (0..kv_block_size).map(|i| i as u32).collect::<Vec<u32>>();
let hashes = compute_block_hash_for_seq(&sequence); let hashes = compute_block_hash_for_seq(&sequence, kv_block_size);
assert_eq!(hashes.len(), 1); assert_eq!(hashes.len(), 1);
// create a sequence of 65 elements // create a sequence of 65 elements
let sequence = (0..(KV_BLOCK_SIZE + 1)) let sequence = (0..(kv_block_size + 1))
.map(|i| i as u32) .map(|i| i as u32)
.collect::<Vec<u32>>(); .collect::<Vec<u32>>();
let hashes = compute_block_hash_for_seq(&sequence); let hashes = compute_block_hash_for_seq(&sequence, kv_block_size);
assert_eq!(hashes.len(), 1); assert_eq!(hashes.len(), 1);
// create a sequence of 129 elements // create a sequence of 129 elements
let sequence = (0..(2 * KV_BLOCK_SIZE + 1)) let sequence = (0..(2 * kv_block_size + 1))
.map(|i| i as u32) .map(|i| i as u32)
.collect::<Vec<u32>>(); .collect::<Vec<u32>>();
let hashes = compute_block_hash_for_seq(&sequence); let hashes = compute_block_hash_for_seq(&sequence, kv_block_size);
assert_eq!(hashes.len(), 2); assert_eq!(hashes.len(), 2);
} }
fn make_indexer(token: &CancellationToken, num_shards: usize) -> Box<dyn KvIndexerInterface> { fn make_indexer(
token: &CancellationToken,
num_shards: usize,
kv_block_size: usize,
) -> Box<dyn KvIndexerInterface> {
if num_shards == 1 { if num_shards == 1 {
Box::new(KvIndexer::new(token.clone())) Box::new(KvIndexer::new(token.clone(), kv_block_size))
} else { } else {
Box::new(KvIndexerSharded::new(token.clone(), num_shards)) Box::new(KvIndexerSharded::new(
token.clone(),
num_shards,
kv_block_size,
))
} }
} }
#[template]
#[rstest] #[rstest]
#[case(1)] fn indexer_template(
#[case(2)] #[values(1, 3, 8)] num_shards: usize,
#[case(3)] #[values(11, 32, 64)] kv_block_size: usize,
#[case(4)] ) {
#[case(5)] }
#[case(6)]
#[case(7)]
#[case(8)]
#[tokio::test] #[tokio::test]
async fn test_kv_indexer_new(#[case] num_shards: usize) { #[apply(indexer_template)]
let token = CancellationToken::new(); async fn test_kv_indexer_new(num_shards: usize, kv_block_size: usize) {
let _ = make_indexer(&token, num_shards); let token: CancellationToken = CancellationToken::new();
let _ = make_indexer(&token, num_shards, kv_block_size);
} }
#[rstest]
#[case(1)]
#[case(2)]
#[case(3)]
#[case(4)]
#[case(5)]
#[case(6)]
#[case(7)]
#[case(8)]
#[tokio::test] #[tokio::test]
async fn test_find_matches(#[case] num_shards: usize) { #[apply(indexer_template)]
async fn test_find_matches(num_shards: usize, kv_block_size: usize) {
let token = CancellationToken::new(); let token = CancellationToken::new();
let kv_indexer = make_indexer(&token, num_shards); let kv_indexer = make_indexer(&token, num_shards, kv_block_size);
let sequence = vec![compute_block_hash(b"test data")]; let sequence = vec![compute_block_hash(b"test data")];
let scores = kv_indexer.find_matches(sequence).await; let scores = kv_indexer.find_matches(sequence).await;
...@@ -1245,19 +1257,11 @@ mod tests { ...@@ -1245,19 +1257,11 @@ mod tests {
assert!(scores.unwrap().scores.is_empty()); assert!(scores.unwrap().scores.is_empty());
} }
#[rstest]
#[case(1)]
#[case(2)]
#[case(3)]
#[case(4)]
#[case(5)]
#[case(6)]
#[case(7)]
#[case(8)]
#[tokio::test] #[tokio::test]
async fn test_find_matches_for_request(#[case] num_shards: usize) { #[apply(indexer_template)]
async fn test_find_matches_for_request(num_shards: usize, kv_block_size: usize) {
let token = CancellationToken::new(); let token = CancellationToken::new();
let kv_indexer = make_indexer(&token, num_shards); let kv_indexer = make_indexer(&token, num_shards, kv_block_size);
let tokens = vec![1, 2, 3, 4]; let tokens = vec![1, 2, 3, 4];
let scores = kv_indexer.find_matches_for_request(&tokens).await; let scores = kv_indexer.find_matches_for_request(&tokens).await;
...@@ -1265,21 +1269,13 @@ mod tests { ...@@ -1265,21 +1269,13 @@ mod tests {
assert!(scores.unwrap().scores.is_empty()); assert!(scores.unwrap().scores.is_empty());
} }
#[rstest]
#[case(1)]
#[case(2)]
#[case(3)]
#[case(4)]
#[case(5)]
#[case(6)]
#[case(7)]
#[case(8)]
#[tokio::test] #[tokio::test]
async fn test_apply_event(#[case] num_shards: usize) { #[apply(indexer_template)]
async fn test_apply_event(num_shards: usize, kv_block_size: usize) {
let worker_id = 0; let worker_id = 0;
let token = CancellationToken::new(); let token = CancellationToken::new();
let mut kv_indexer = make_indexer(&token, num_shards); let mut kv_indexer = make_indexer(&token, num_shards, kv_block_size);
let event = create_store_event(worker_id, 1, vec![1, 2, 3], None); let event = create_store_event(worker_id, 1, vec![1, 2, 3], None);
kv_indexer.apply_event(event).await; kv_indexer.apply_event(event).await;
...@@ -1287,43 +1283,34 @@ mod tests { ...@@ -1287,43 +1283,34 @@ mod tests {
// No assertion here, just ensuring it runs without panic // No assertion here, just ensuring it runs without panic
} }
#[rstest]
#[case(1)]
#[case(2)]
#[case(3)]
#[case(4)]
#[case(5)]
#[case(6)]
#[case(7)]
#[case(8)]
#[tokio::test] #[tokio::test]
async fn test_shutdown(#[case] num_shards: usize) { #[apply(indexer_template)]
async fn test_shutdown(num_shards: usize, kv_block_size: usize) {
let token = CancellationToken::new(); let token = CancellationToken::new();
let mut kv_indexer = make_indexer(&token, num_shards); let mut kv_indexer = make_indexer(&token, num_shards, kv_block_size);
kv_indexer.shutdown(); kv_indexer.shutdown();
} }
#[rstest]
#[case(1)]
#[case(2)]
#[case(3)]
#[case(4)]
#[case(5)]
#[case(6)]
#[case(7)]
#[case(8)]
#[tokio::test] #[tokio::test]
async fn test_frequency(#[case] num_shards: usize) { #[apply(indexer_template)]
async fn test_frequency(num_shards: usize, kv_block_size: usize) {
let mut kv_indexer: Box<dyn KvIndexerInterface>; let mut kv_indexer: Box<dyn KvIndexerInterface>;
let token = CancellationToken::new(); let token = CancellationToken::new();
let duration = Some(Duration::from_millis(50)); let duration = Some(Duration::from_millis(50));
if num_shards == 1 { if num_shards == 1 {
kv_indexer = Box::new(KvIndexer::new_with_frequency(token, duration)); kv_indexer = Box::new(KvIndexer::new_with_frequency(
token,
duration,
kv_block_size,
));
} else { } else {
kv_indexer = Box::new(KvIndexerSharded::new_with_frequency( kv_indexer = Box::new(KvIndexerSharded::new_with_frequency(
token, num_shards, duration, token,
num_shards,
duration,
kv_block_size,
)); ));
} }
......
...@@ -15,13 +15,6 @@ ...@@ -15,13 +15,6 @@
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
// Currently hard-coding the block size to be 64 tokens, and
// assuming the LLM framework aligns with this size.
// The KV publisher and subscriber conveys hash values of the tokens,
// for performance reason, therefore the block size needs to be consistent
// so that the computed hash value is the same on both sizes.
pub const KV_BLOCK_SIZE: usize = 64;
#[derive(Debug, Clone, Serialize, Deserialize, Default)] #[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ForwardPassMetrics { pub struct ForwardPassMetrics {
pub request_active_slots: u64, pub request_active_slots: u64,
......
...@@ -31,12 +31,18 @@ use tracing as log; ...@@ -31,12 +31,18 @@ use tracing as log;
pub struct KvEventPublisher { pub struct KvEventPublisher {
tx: mpsc::UnboundedSender<KvCacheEvent>, tx: mpsc::UnboundedSender<KvCacheEvent>,
kv_block_size: usize,
} }
impl KvEventPublisher { impl KvEventPublisher {
pub fn new(drt: DistributedRuntime, backend: Component, worker_id: i64) -> Result<Self> { pub fn new(
drt: DistributedRuntime,
backend: Component,
worker_id: i64,
kv_block_size: usize,
) -> Result<Self> {
let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>(); let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>();
let p = KvEventPublisher { tx }; let p = KvEventPublisher { tx, kv_block_size };
start_publish_task(drt, backend, worker_id, rx); start_publish_task(drt, backend, worker_id, rx);
Ok(p) Ok(p)
...@@ -46,6 +52,10 @@ impl KvEventPublisher { ...@@ -46,6 +52,10 @@ impl KvEventPublisher {
log::debug!("Publish event: {:?}", event); log::debug!("Publish event: {:?}", event);
self.tx.send(event) self.tx.send(event)
} }
pub fn kv_block_size(&self) -> usize {
self.kv_block_size
}
} }
fn start_publish_task( fn start_publish_task(
......
...@@ -20,7 +20,7 @@ use std::borrow::BorrowMut; ...@@ -20,7 +20,7 @@ use std::borrow::BorrowMut;
use std::cmp::min; use std::cmp::min;
use crate::kv_router::indexer::OverlapScores; use crate::kv_router::indexer::OverlapScores;
pub use crate::kv_router::protocols::{ForwardPassMetrics, KV_BLOCK_SIZE}; pub use crate::kv_router::protocols::ForwardPassMetrics;
use crate::kv_router::scoring::ProcessedEndpoints; use crate::kv_router::scoring::ProcessedEndpoints;
use crate::kv_router::KV_HIT_RATE_SUBJECT; use crate::kv_router::KV_HIT_RATE_SUBJECT;
...@@ -112,6 +112,7 @@ impl KvScheduler { ...@@ -112,6 +112,7 @@ impl KvScheduler {
pub async fn start( pub async fn start(
endpoints_rx: tokio::sync::mpsc::Receiver<ProcessedEndpoints>, endpoints_rx: tokio::sync::mpsc::Receiver<ProcessedEndpoints>,
ns: Namespace, ns: Namespace,
kv_block_size: usize,
) -> Result<Self, KvSchedulerError> { ) -> Result<Self, KvSchedulerError> {
let mut endpoints_rx = endpoints_rx; let mut endpoints_rx = endpoints_rx;
...@@ -178,7 +179,8 @@ impl KvScheduler { ...@@ -178,7 +179,8 @@ impl KvScheduler {
}; };
tracing::debug!("selected"); tracing::debug!("selected");
loop { loop {
match select_worker(endpoints.borrow_mut(), &request, &event_tx) { match select_worker(endpoints.borrow_mut(), &request, &event_tx, kv_block_size)
{
Ok(worker_id) => { Ok(worker_id) => {
request.respond(worker_id); request.respond(worker_id);
continue 'outer; continue 'outer;
...@@ -237,6 +239,7 @@ pub fn select_worker( ...@@ -237,6 +239,7 @@ pub fn select_worker(
workers: &mut ProcessedEndpoints, workers: &mut ProcessedEndpoints,
request: &SchedulingRequest, request: &SchedulingRequest,
event_tx: &tokio::sync::mpsc::UnboundedSender<KVHitRateEvent>, event_tx: &tokio::sync::mpsc::UnboundedSender<KVHitRateEvent>,
kv_block_size: usize,
) -> Result<i64, KvSchedulerError> { ) -> Result<i64, KvSchedulerError> {
// balance mode prioritizes balancing load across workers // balance mode prioritizes balancing load across workers
let balance_threshold: f64 = 0.1; let balance_threshold: f64 = 0.1;
...@@ -249,7 +252,7 @@ pub fn select_worker( ...@@ -249,7 +252,7 @@ pub fn select_worker(
// Compute each worker's score // Compute each worker's score
let mut best_index = None; let mut best_index = None;
let mut best_cost = f64::INFINITY; let mut best_cost = f64::INFINITY;
// [FIXME] REMOVE ONLY FOR TESTING
if workers.endpoints.is_empty() { if workers.endpoints.is_empty() {
return Err(KvSchedulerError::NoEndpoints); return Err(KvSchedulerError::NoEndpoints);
} }
...@@ -268,7 +271,7 @@ pub fn select_worker( ...@@ -268,7 +271,7 @@ pub fn select_worker(
// [FIXME] multiple endpoints of the same worker cause out of bound error // [FIXME] multiple endpoints of the same worker cause out of bound error
let worker_id = workers.worker_ids[i]; let worker_id = workers.worker_ids[i];
let overlap_score = request.overlap.scores.get(&worker_id).map_or(0, |x| *x); let overlap_score = request.overlap.scores.get(&worker_id).map_or(0, |x| *x);
let overlap_score = overlap_score as usize * KV_BLOCK_SIZE; let overlap_score = overlap_score as usize * kv_block_size;
let new_tokens = request.isl_tokens.saturating_sub(overlap_score); let new_tokens = request.isl_tokens.saturating_sub(overlap_score);
let normalized_new_tokens = new_tokens as f64 / request.isl_tokens as f64; let normalized_new_tokens = new_tokens as f64 / request.isl_tokens as f64;
...@@ -296,14 +299,14 @@ pub fn select_worker( ...@@ -296,14 +299,14 @@ pub fn select_worker(
} }
if let Some(best_index) = best_index { if let Some(best_index) = best_index {
let total_blocks = min(request.isl_tokens / KV_BLOCK_SIZE, 1); let total_blocks = min(request.isl_tokens / kv_block_size, 1);
workers.endpoints[best_index].data.request_active_slots += 1; workers.endpoints[best_index].data.request_active_slots += 1;
workers.endpoints[best_index].data.kv_active_blocks += total_blocks as u64; workers.endpoints[best_index].data.kv_active_blocks += total_blocks as u64;
// Optimization - pass this to a channel for emitting events, async task, etc. to avoid blocking the scheduler // Optimization - pass this to a channel for emitting events, async task, etc. to avoid blocking the scheduler
let best_worker_id = workers.endpoints[best_index].worker_id(); let best_worker_id = workers.endpoints[best_index].worker_id();
let isl_blocks = request.isl_tokens / KV_BLOCK_SIZE; let isl_blocks = request.isl_tokens / kv_block_size;
let overlap_blocks = request let overlap_blocks = request
.overlap .overlap
.scores .scores
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment