Commit 989bb3d5 authored by Alec, committed by GitHub
Browse files

feat: make block_size input for indexer, router, publisher (#66)

parent dd31a322
......@@ -247,7 +247,7 @@ index 1ca9e49d..b1591c0c 100644
# Reuse the cached content hash
diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py
index c5b3b04f..c72001f7 100644
index c5b3b04f..12cd4dc9 100644
--- a/vllm/core/block_manager.py
+++ b/vllm/core/block_manager.py
@@ -10,7 +10,10 @@ from vllm.core.block.interfaces import Block
......@@ -269,7 +269,7 @@ index c5b3b04f..c72001f7 100644
block_size: int,
num_gpu_blocks: int,
num_cpu_blocks: int,
@@ -91,11 +95,28 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
@@ -91,11 +95,29 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
self.watermark_blocks = int(watermark * num_gpu_blocks)
......@@ -285,7 +285,8 @@ index c5b3b04f..c72001f7 100644
+ namespace=VLLM_KV_NAMESPACE,
+ component=VLLM_KV_COMPONENT,
+ worker_id=VLLM_WORKER_ID,
+ lib_path=VLLM_KV_CAPI_PATH)
+ lib_path=VLLM_KV_CAPI_PATH,
+ kv_block_size=block_size)
+ else:
+ self.event_manager = None
+
......@@ -300,10 +301,10 @@ index c5b3b04f..c72001f7 100644
self.block_tables: Dict[SeqId, BlockTable] = {}
diff --git a/vllm/core/event_manager.py b/vllm/core/event_manager.py
new file mode 100644
index 00000000..d3706700
index 00000000..a27af580
--- /dev/null
+++ b/vllm/core/event_manager.py
@@ -0,0 +1,102 @@
@@ -0,0 +1,108 @@
+# SPDX-License-Identifier: Apache-2.0
+import ctypes
+import logging
......@@ -324,16 +325,22 @@ index 00000000..d3706700
+class KVCacheEventManager:
+
+ def __init__(self, namespace: str, component: str, worker_id: int,
+ lib_path: str):
+ lib_path: str, kv_block_size: int):
+ self.lib = None
+
+ try:
+ self.lib = ctypes.CDLL(lib_path)
+ self.lib.dynamo_llm_init.argtypes = [c_char_p, c_char_p, c_int64]
+ self.lib.dynamo_llm_init.argtypes = [
+ c_char_p,
+ c_char_p,
+ c_int64,
+ c_uint32,
+ ]
+ self.lib.dynamo_llm_init.restype = c_uint32
+
+ result = self.lib.dynamo_llm_init(namespace.encode(),
+ component.encode(), worker_id)
+ result = self.lib.dynamo_llm_init(
+ namespace.encode(), component.encode(), worker_id, kv_block_size
+ )
+ if result == DynamoResult.OK:
+ logger.info(
+ "KVCacheEventManager initialized successfully. Ready to publish KV Cache Events"
......
......@@ -262,7 +262,8 @@ cd /workspace/examples/python_rs/llm/vllm
RUST_LOG=info python3 -m kv_router.router \
--routing-strategy prefix \
--model-name deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--min-workers 1
--min-workers 1 \
--block-size 64
```
You can choose only the prefix strategy for now:
......
......@@ -173,13 +173,21 @@ async def worker(runtime: DistributedRuntime, args: Namespace):
endpoint = router_component.endpoint("generate")
if args.custom_router:
indexer = KvIndexer(kv_listener)
# @REVIEWER - I'm not currently checking if block size matches that of the engine
# If they don't match things will silently fail
# The preferred solution would be for the KV Indexer to read from the MDC in etcd and not bother the user at all
# The second solution would be to do KvIndexer(kv_listener, MDC.block_size)
# as this ensures block size matches that of the engine
# In this case we need to do some sort of handshake or check in case a user just puts in a random block size
indexer = KvIndexer(kv_listener, args.block_size)
metrics_aggregator = KvMetricsAggregator(kv_listener)
await endpoint.serve_endpoint(
CustomRouter(indexer, metrics_aggregator).generate
)
else:
router = KvRouter(runtime, kv_listener)
# TODO Read block_size from MDC
router = KvRouter(runtime, kv_listener, args.block_size)
await endpoint.serve_endpoint(Router(router, args.routing_strategy).generate)
......@@ -208,6 +216,11 @@ if __name__ == "__main__":
default="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
help="Model that is being served",
)
parser.add_argument(
"--block-size",
type=int,
help="KV block size",
)
parser.add_argument(
"--custom-router",
type=bool,
......
......@@ -63,7 +63,7 @@ PROCESSOR_CMD="RUST_LOG=info python3 -m kv_router.processor \
--model $MODEL_NAME \
--tokenizer $MODEL_NAME \
--enable-prefix-caching \
--block-size 64 \
--block-size 32 \
--max-model-len 16384 "
tmux new-session -d -s "$SESSION_NAME-processor"
tmux send-keys -t "$SESSION_NAME-processor" "$INIT_CMD && $PROCESSOR_CMD" C-m
......@@ -74,7 +74,8 @@ tmux send-keys -t "$SESSION_NAME-processor" "$INIT_CMD && $PROCESSOR_CMD" C-m
ROUTER_CMD="RUST_LOG=info python3 -m kv_router.router \
--model $MODEL_NAME \
--routing-strategy $ROUTING_STRATEGY \
--min-workers $NUM_WORKERS "
--min-workers $NUM_WORKERS \
--block-size 32"
tmux new-session -d -s "$SESSION_NAME-router"
tmux send-keys -t "$SESSION_NAME-router" "$INIT_CMD && $ROUTER_CMD" C-m
......
......@@ -51,6 +51,7 @@ class Router:
self.routing_strategy = RoutingStrategy.PREFIX
self.runtime = dynamo_context["runtime"]
self.min_workers = 1
self.kv_block_size = 64
@async_onstart
async def init_engine(self):
......@@ -77,7 +78,7 @@ class Router:
kv_listener = self.runtime.namespace("dynamo").component(self.model_name)
await kv_listener.create_service()
self.router = KvRouter(self.runtime, kv_listener)
self.router = KvRouter(self.runtime, kv_listener, self.kv_block_size)
@dynamo_endpoint()
async def generate(self, request: Tokens):
......
......@@ -53,6 +53,7 @@ pub unsafe extern "C" fn dynamo_llm_init(
namespace_c_str: *const c_char,
component_c_str: *const c_char,
worker_id: i64,
kv_block_size: u32,
) -> DynamoLlmResult {
initialize_tracing();
let wk = match WK.get_or_try_init(Worker::from_settings) {
......@@ -94,9 +95,9 @@ pub unsafe extern "C" fn dynamo_llm_init(
};
match result {
Ok(_) => match KV_PUB
.get_or_try_init(move || dynamo_create_kv_publisher(namespace, component, worker_id))
{
Ok(_) => match KV_PUB.get_or_try_init(move || {
dynamo_create_kv_publisher(namespace, component, worker_id, kv_block_size as usize)
}) {
Ok(_) => DynamoLlmResult::OK,
Err(e) => {
eprintln!("Failed to initialize distributed runtime: {:?}", e);
......@@ -138,6 +139,7 @@ fn dynamo_create_kv_publisher(
namespace: String,
component: String,
worker_id: i64,
kv_block_size: usize,
) -> Result<KvEventPublisher, anyhow::Error> {
tracing::info!("Creating KV Publisher for model: {}", component);
match DRT
......@@ -146,7 +148,7 @@ fn dynamo_create_kv_publisher(
{
Ok(drt) => {
let backend = drt.namespace(namespace)?.component(component)?;
KvEventPublisher::new(drt.clone(), backend, worker_id)
KvEventPublisher::new(drt.clone(), backend, worker_id, kv_block_size)
}
Err(e) => Err(e),
}
......@@ -156,10 +158,13 @@ fn kv_event_create_stored_block_from_parts(
block_hash: u64,
token_ids: *const u32,
num_tokens: usize,
kv_block_size: usize,
_lora_id: u64,
) -> KvCacheStoredBlockData {
let tokens_hash =
compute_block_hash_for_seq(unsafe { std::slice::from_raw_parts(token_ids, num_tokens) })[0];
let tokens_hash = compute_block_hash_for_seq(
unsafe { std::slice::from_raw_parts(token_ids, num_tokens) },
kv_block_size,
)[0];
KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(block_hash),
tokens_hash,
......@@ -168,23 +173,22 @@ fn kv_event_create_stored_block_from_parts(
static WARN_COUNT: AtomicU32 = AtomicU32::new(0);
fn kv_event_create_stored_from_parts(
event_id: u64,
token_ids: *const u32,
num_block_tokens: *const usize,
block_ids: *const u64,
num_blocks: usize,
parent_hash: Option<u64>,
lora_id: u64,
kv_params: DynamoKvStoredEventParams,
kv_block_size: usize,
) -> KvCacheEvent {
let mut blocks: Vec<KvCacheStoredBlockData> = Vec::new();
let mut token_offset: usize = 0;
for block_idx in 0..num_blocks {
let block_hash = unsafe { *block_ids.offset(block_idx.try_into().unwrap()) };
let tokens = unsafe { token_ids.offset(token_offset.try_into().unwrap()) };
let num_toks = unsafe { *num_block_tokens.offset(block_idx.try_into().unwrap()) };
// compute hash only apply to full block (KV_BLOCK_SIZE token)
if num_toks != 64 {
for block_idx in 0..kv_params.num_blocks {
let block_hash = unsafe { *kv_params.block_ids.offset(block_idx.try_into().unwrap()) };
let tokens = unsafe { kv_params.token_ids.offset(token_offset.try_into().unwrap()) };
let num_toks = unsafe {
*kv_params
.num_block_tokens
.offset(block_idx.try_into().unwrap())
};
if num_toks != kv_block_size {
if WARN_COUNT
.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |c| {
if c < 3 {
......@@ -196,7 +200,8 @@ fn kv_event_create_stored_from_parts(
.is_ok()
{
tracing::warn!(
"Block size must be 64 tokens to be published. Block size is: {}",
"Block not published. Block size must be {} tokens to be published. Block size is: {}",
kv_block_size,
num_toks
);
}
......@@ -204,16 +209,20 @@ fn kv_event_create_stored_from_parts(
}
token_offset += num_toks;
blocks.push(kv_event_create_stored_block_from_parts(
block_hash, tokens, num_toks, lora_id,
block_hash,
tokens,
num_toks,
kv_block_size,
kv_params.lora_id,
));
}
KvCacheEvent {
data: KvCacheEventData::Stored(KvCacheStoreData {
blocks,
parent_hash: parent_hash.map(ExternalSequenceBlockHash),
parent_hash: kv_params.parent_hash.map(ExternalSequenceBlockHash),
}),
event_id,
event_id: kv_params.event_id,
}
}
......@@ -234,6 +243,16 @@ fn kv_event_create_removed_from_parts(
}
}
pub struct DynamoKvStoredEventParams {
pub event_id: u64,
pub token_ids: *const u32,
pub num_block_tokens: *const usize,
pub block_ids: *const u64,
pub num_blocks: usize,
pub parent_hash: Option<u64>,
pub lora_id: u64,
}
/// # Safety
/// parent_hash is passed as pointer to indicate whether the blocks
/// has a parent hash or not. nullptr is used to represent no parent hash
......@@ -247,7 +266,6 @@ pub unsafe extern "C" fn dynamo_kv_event_publish_stored(
parent_hash: *const u64,
lora_id: u64,
) -> DynamoLlmResult {
let publisher = KV_PUB.get().unwrap();
let parent_hash = {
if parent_hash.is_null() {
None
......@@ -255,7 +273,7 @@ pub unsafe extern "C" fn dynamo_kv_event_publish_stored(
Some(unsafe { *parent_hash })
}
};
let event = kv_event_create_stored_from_parts(
let kv_params = DynamoKvStoredEventParams {
event_id,
token_ids,
num_block_tokens,
......@@ -263,7 +281,9 @@ pub unsafe extern "C" fn dynamo_kv_event_publish_stored(
num_blocks,
parent_hash,
lora_id,
);
};
let publisher = KV_PUB.get().unwrap();
let event = kv_event_create_stored_from_parts(kv_params, publisher.kv_block_size());
match publisher.publish(event) {
Ok(_) => DynamoLlmResult::OK,
Err(e) => {
......@@ -290,68 +310,33 @@ pub extern "C" fn dynamo_kv_event_publish_removed(
}
}
// #[no_mangle]
// pub extern "C" fn dynamo_kv_publish_store_event(
// event_id: u64,
// token_ids: *const u32,
// num_tokens: usize,
// lora_id: u64,
// ) -> DynamoLlmResult {
// // if event.is_null() || token_ids.is_null() {
// // return dynamoKvErrorType::INVALID_TOKEN_IDS;
// // }
// // let tokens = unsafe { std::slice::from_raw_parts(token_ids, num_tokens) }.to_vec();
// // let new_event = Box::new(KvCacheStoreData {
// // event_id,
// // lora_id,
// // token_ids: tokens,
// // block_hashes: Vec::new(),
// // });
// Need to setup etcd and nats to run these tests
// #[cfg(test)]
// mod tests {
// use super::*;
// use std::ffi::CString;
// // unsafe { *event = Box::into_raw(new_event) };
// #[test]
// fn test_dynamo_llm_init() {
// // Create C-compatible strings
// let namespace = CString::new("test_namespace").unwrap();
// let component = CString::new("test_component").unwrap();
// DynamoLlmResult::OK
// }
// #[no_mangle]
// pub extern "C" fn dynamo_kv_event_create_removed(
// event_id: u64,
// block_hashes: *const u64,
// num_hashes: usize,
// ) -> DynamoLlmResult {
// // if event.is_null() || block_hashes.is_null() {
// // return -1;
// // }
// // Call the init function
// let result = unsafe {
// dynamo_llm_init(
// namespace.as_ptr(),
// component.as_ptr(),
// 1, // worker_id
// 32, // kv_block_size
// )
// };
// // let hashes = unsafe { std::slice::from_raw_parts(block_hashes, num_hashes) }.to_vec();
// // let new_event = Box::new(KvCacheRemoveData {
// // event_id,
// // lora_id: 0,
// // token_ids: Vec::new(),
// // block_hashes: hashes,
// // });
// assert_eq!(result as u32, DynamoLlmResult::OK as u32);
// // unsafe { *event = Box::into_raw(new_event) };
// // 0
// DynamoLlmResult::OK
// }
// /// create load publisher object and return a handle
// /// load publisher will instantiate the nats service and tie its stats handler to
// /// a watch channel receiver. the watch channel sender will be attach to the
// /// handle and calls to [`dynamo_load_stats_publish`] issue the stats to the watch t
// pub extern "C" fn dynamo_load_publisher_create() -> *mut LoadPublisher {
// // let publisher = Box::new(LoadPublisher::new());
// // Box::into_raw(publisher)
// }
// assert!(WK.get().is_some());
// pub extern "C" fn dynamo_load_stats_publish(
// publisher: *mut LoadPublisher,
// active_slots: u64,
// total_slots: u64,
// active_kv: u64,
// total_kv: u64,
// ) {
// // let publisher = unsafe { &mut *publisher };
// let shutdown_result = dynamo_llm_shutdown();
// assert_eq!(shutdown_result as u32, DynamoLlmResult::OK as u32);
// }
// }
......@@ -28,12 +28,13 @@ pub(crate) struct KvRouter {
impl KvRouter {
#[new]
// [FIXME] 'drt' can be obtained from 'component'
fn new(drt: DistributedRuntime, component: Component) -> PyResult<Self> {
fn new(drt: DistributedRuntime, component: Component, kv_block_size: usize) -> PyResult<Self> {
let runtime = pyo3_async_runtimes::tokio::get_runtime();
runtime.block_on(async {
let inner = llm_rs::kv_router::KvRouter::from_runtime(
drt.inner.clone(),
component.inner.clone(),
kv_block_size,
)
.await
.map_err(to_pyerr)?;
......@@ -138,7 +139,7 @@ pub(crate) struct KvIndexer {
#[pymethods]
impl KvIndexer {
#[new]
fn new(component: Component) -> PyResult<Self> {
fn new(component: Component, kv_block_size: usize) -> PyResult<Self> {
let runtime = pyo3_async_runtimes::tokio::get_runtime();
runtime.block_on(async {
let kv_subject = component
......@@ -147,6 +148,7 @@ impl KvIndexer {
let inner: Arc<llm_rs::kv_router::indexer::KvIndexer> =
llm_rs::kv_router::indexer::KvIndexer::new(
component.inner.drt().runtime().child_token(),
kv_block_size,
)
.into();
let mut kv_events_rx = component
......
......@@ -54,20 +54,27 @@ async def distributed_runtime():
return DistributedRuntime(loop)
# TODO Figure out how to test with different kv_block_size
# Right now I get an error in EventPublisher init when I run this test
# back to back. It occurs when calling dynamo_llm_init and I think is related to the
# OnceCell initializations not being reset.
# The test works individually if I run it with 32, then 11, then 64.
# @pytest.mark.parametrize("kv_block_size", [11, 32, 64])
async def test_event_handler(distributed_runtime):
kv_block_size = 32
namespace = "kv_test"
component = "event"
# publisher
worker_id = 233
event_publisher = EventPublisher(namespace, component, worker_id)
event_publisher = EventPublisher(namespace, component, worker_id, kv_block_size)
# indexer
kv_listener = distributed_runtime.namespace(namespace).component(component)
await kv_listener.create_service()
indexer = KvIndexer(kv_listener)
indexer = KvIndexer(kv_listener, kv_block_size)
test_token = [3] * 64
test_token = [3] * kv_block_size
lora_id = 0 # lora_id is not used in the indexer
scores = await indexer.find_matches_for_request(test_token, lora_id)
assert not scores.scores
......@@ -86,6 +93,8 @@ async def test_event_handler(distributed_runtime):
scores = await indexer.find_matches_for_request(test_token, lora_id)
assert not scores.scores
event_publisher.shutdown()
# KV events
class DynamoResult:
......@@ -94,18 +103,21 @@ class DynamoResult:
class EventPublisher:
def __init__(self, namespace: str, component: str, worker_id: int):
def __init__(
self, namespace: str, component: str, worker_id: int, kv_block_size: int
):
self.event_id_counter = 0
self.block_ids: List[int] = []
# load event publisher library
self.lib = ctypes.CDLL(os.environ["VLLM_KV_CAPI_PATH"])
self.lib.dynamo_llm_init.argtypes = [c_char_p, c_char_p, c_int64]
self.lib.dynamo_llm_init.argtypes = [c_char_p, c_char_p, c_int64, c_uint32]
self.lib.dynamo_llm_init.restype = c_uint32
result = self.lib.dynamo_llm_init(
namespace.encode(), component.encode(), worker_id
namespace.encode(), component.encode(), worker_id, kv_block_size
)
assert result == DynamoResult.OK
self.lib.dynamo_kv_event_publish_stored.argtypes = [
ctypes.c_uint64, # event_id
ctypes.POINTER(ctypes.c_uint32), # token_ids
......@@ -158,6 +170,10 @@ class EventPublisher:
assert result == DynamoResult.OK
def shutdown(self):
result = self.lib.dynamo_llm_shutdown()
assert result == DynamoResult.OK
async def test_metrics_aggregator(distributed_runtime):
namespace = "kv_test"
......
......@@ -1447,6 +1447,7 @@ dependencies = [
"regex",
"reqwest",
"rstest",
"rstest_reuse",
"semver",
"sentencepiece",
"serde",
......@@ -4647,6 +4648,17 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "rstest_reuse"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b3a8fb4672e840a587a66fc577a5491375df51ddb88f2a2c2a792598c326fe14"
dependencies = [
"quote",
"rand",
"syn 2.0.98",
]
[[package]]
name = "rustc-demangle"
version = "0.1.24"
......
......@@ -164,6 +164,7 @@ pythonize = { version = "0.23", optional = true }
proptest = "1.5.0"
reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "rustls-tls"] }
rstest = "0.18.2"
rstest_reuse = "0.7.0"
tempfile = "3.17.1"
hf-hub = "0.4.1"
insta = { version = "1.41", features = [
......
......@@ -54,6 +54,7 @@ impl KvRouter {
pub async fn from_runtime(
runtime: DistributedRuntime,
backend: Component,
kv_block_size: usize,
) -> Result<Arc<Self>> {
let nats_client = runtime.nats_client();
let service_name = backend.service_name();
......@@ -63,7 +64,14 @@ impl KvRouter {
tracing::info!("Component Namespace {}", backend.namespace());
tracing::info!("Component Service Name {}", service_name);
tracing::info!("KV Subject {}", kv_subject);
Self::new(nats_client, service_name, kv_subject, namespace).await
Self::new(
nats_client,
service_name,
kv_subject,
namespace,
kv_block_size,
)
.await
}
pub async fn new(
......@@ -71,6 +79,7 @@ impl KvRouter {
service_name: String,
kv_subject: String,
namespace: Namespace,
kv_block_size: usize,
) -> Result<Arc<Self>> {
let cancellation_token = CancellationToken::new();
let (ep_tx, ep_rx) = tokio::sync::mpsc::channel(128);
......@@ -82,8 +91,8 @@ impl KvRouter {
cancellation_token.clone(),
));
let indexer = KvIndexer::new(cancellation_token.clone());
let scheduler = KvScheduler::start(ep_rx, namespace).await?;
let indexer = KvIndexer::new(cancellation_token.clone(), kv_block_size);
let scheduler = KvScheduler::start(ep_rx, namespace, kv_block_size).await?;
tracing::debug!("subscribing to kv events: {}", kv_subject);
let mut kv_events_rx = nats_client.client().subscribe(kv_subject).await?;
......
......@@ -116,9 +116,9 @@ pub fn compute_block_hash(data: &[u8]) -> LocalBlockHash {
/// ### Returns
///
/// A vector of `LocalBlockHash` representing the computed hashes for each chunk of tokens.
pub fn compute_block_hash_for_seq(tokens: &[u32]) -> Vec<LocalBlockHash> {
pub fn compute_block_hash_for_seq(tokens: &[u32], kv_block_size: usize) -> Vec<LocalBlockHash> {
tokens
.chunks_exact(KV_BLOCK_SIZE) // Split into chunks of KV_BLOCK_SIZE elements
.chunks_exact(kv_block_size) // Split into chunks of kv_block_size elements
.map(|chunk| {
let bytes: Vec<u8> = chunk
.iter()
......@@ -503,6 +503,8 @@ pub struct KvIndexer {
remove_worker_tx: mpsc::Sender<WorkerId>,
/// A handle to the background task managing the KV store.
task: OnceLock<std::thread::JoinHandle<()>>,
/// The size of the KV block this indexer can handle.
kv_block_size: usize,
}
impl KvIndexer {
......@@ -519,6 +521,7 @@ impl KvIndexer {
pub fn new_with_frequency(
token: CancellationToken,
expiration_duration: Option<Duration>,
kv_block_size: usize,
) -> Self {
let (event_tx, event_rx) = mpsc::channel::<RouterEvent>(2048);
let (match_tx, match_rx) = mpsc::channel::<MatchRequest>(128);
......@@ -581,11 +584,12 @@ impl KvIndexer {
match_tx,
remove_worker_tx,
task: once,
kv_block_size,
}
}
pub fn new(token: CancellationToken) -> Self {
Self::new_with_frequency(token, None)
pub fn new(token: CancellationToken, kv_block_size: usize) -> Self {
Self::new_with_frequency(token, None, kv_block_size)
}
/// Get a sender for `RouterEvent`s.
......@@ -633,7 +637,7 @@ impl KvIndexerInterface for KvIndexer {
tokens,
tokens.len()
);
let sequence = compute_block_hash_for_seq(tokens);
let sequence = compute_block_hash_for_seq(tokens, self.kv_block_size);
log::debug!("Computed sequence: {:?}", sequence);
self.find_matches(sequence).await
}
......@@ -665,6 +669,8 @@ pub struct ShardedMatchRequest {
pub struct KvIndexerSharded {
/// A `CancellationToken` for managing shutdown.
cancel: CancellationToken,
/// The size of the KV block this indexer can handle.
kv_block_size: usize,
worker_assignments: HashMap<WorkerId, usize>,
worker_counts: Vec<usize>,
......@@ -690,6 +696,7 @@ impl KvIndexerSharded {
token: CancellationToken,
num_shards: usize,
expiration_duration: Option<Duration>,
kv_block_size: usize,
) -> Self {
let worker_assignments: HashMap<WorkerId, usize> = HashMap::new();
let worker_counts: Vec<usize> = vec![0; num_shards];
......@@ -758,6 +765,7 @@ impl KvIndexerSharded {
Self {
cancel: token,
kv_block_size,
worker_assignments,
worker_counts,
event_tx,
......@@ -767,8 +775,8 @@ impl KvIndexerSharded {
}
}
pub fn new(token: CancellationToken, num_shards: usize) -> Self {
Self::new_with_frequency(token, num_shards, None)
pub fn new(token: CancellationToken, num_shards: usize, kv_block_size: usize) -> Self {
Self::new_with_frequency(token, num_shards, None, kv_block_size)
}
}
......@@ -827,7 +835,7 @@ impl KvIndexerInterface for KvIndexerSharded {
&self,
tokens: &[u32],
) -> Result<OverlapScores, KvRouterError> {
let sequence = compute_block_hash_for_seq(tokens);
let sequence = compute_block_hash_for_seq(tokens, self.kv_block_size);
self.find_matches(sequence).await
}
......@@ -875,6 +883,7 @@ mod tests {
use super::*;
use rstest::rstest;
use rstest_reuse::{self, *};
use tokio::time;
use tokio_util::sync::CancellationToken;
......@@ -1180,64 +1189,67 @@ mod tests {
assert!(result.len() == 2 && result[&worker_0] == 2 && result[&worker_1] == 1);
}
#[test]
fn test_compute_block_hash_for_seq() {
#[rstest]
#[case(11)]
#[case(32)]
#[case(64)]
fn test_compute_block_hash_for_seq(#[case] kv_block_size: usize) {
// create a sequence of exactly kv_block_size elements (one full block)
let sequence = (0..KV_BLOCK_SIZE).map(|i| i as u32).collect::<Vec<u32>>();
let hashes = compute_block_hash_for_seq(&sequence);
let sequence = (0..kv_block_size).map(|i| i as u32).collect::<Vec<u32>>();
let hashes = compute_block_hash_for_seq(&sequence, kv_block_size);
assert_eq!(hashes.len(), 1);
// create a sequence of kv_block_size + 1 elements (one full block plus one leftover token)
let sequence = (0..(KV_BLOCK_SIZE + 1))
let sequence = (0..(kv_block_size + 1))
.map(|i| i as u32)
.collect::<Vec<u32>>();
let hashes = compute_block_hash_for_seq(&sequence);
let hashes = compute_block_hash_for_seq(&sequence, kv_block_size);
assert_eq!(hashes.len(), 1);
// create a sequence of 2 * kv_block_size + 1 elements (two full blocks plus one leftover token)
let sequence = (0..(2 * KV_BLOCK_SIZE + 1))
let sequence = (0..(2 * kv_block_size + 1))
.map(|i| i as u32)
.collect::<Vec<u32>>();
let hashes = compute_block_hash_for_seq(&sequence);
let hashes = compute_block_hash_for_seq(&sequence, kv_block_size);
assert_eq!(hashes.len(), 2);
}
fn make_indexer(token: &CancellationToken, num_shards: usize) -> Box<dyn KvIndexerInterface> {
fn make_indexer(
token: &CancellationToken,
num_shards: usize,
kv_block_size: usize,
) -> Box<dyn KvIndexerInterface> {
if num_shards == 1 {
Box::new(KvIndexer::new(token.clone()))
Box::new(KvIndexer::new(token.clone(), kv_block_size))
} else {
Box::new(KvIndexerSharded::new(token.clone(), num_shards))
Box::new(KvIndexerSharded::new(
token.clone(),
num_shards,
kv_block_size,
))
}
}
#[template]
#[rstest]
#[case(1)]
#[case(2)]
#[case(3)]
#[case(4)]
#[case(5)]
#[case(6)]
#[case(7)]
#[case(8)]
fn indexer_template(
#[values(1, 3, 8)] num_shards: usize,
#[values(11, 32, 64)] kv_block_size: usize,
) {
}
#[tokio::test]
async fn test_kv_indexer_new(#[case] num_shards: usize) {
let token = CancellationToken::new();
let _ = make_indexer(&token, num_shards);
#[apply(indexer_template)]
async fn test_kv_indexer_new(num_shards: usize, kv_block_size: usize) {
let token: CancellationToken = CancellationToken::new();
let _ = make_indexer(&token, num_shards, kv_block_size);
}
#[rstest]
#[case(1)]
#[case(2)]
#[case(3)]
#[case(4)]
#[case(5)]
#[case(6)]
#[case(7)]
#[case(8)]
#[tokio::test]
async fn test_find_matches(#[case] num_shards: usize) {
#[apply(indexer_template)]
async fn test_find_matches(num_shards: usize, kv_block_size: usize) {
let token = CancellationToken::new();
let kv_indexer = make_indexer(&token, num_shards);
let kv_indexer = make_indexer(&token, num_shards, kv_block_size);
let sequence = vec![compute_block_hash(b"test data")];
let scores = kv_indexer.find_matches(sequence).await;
......@@ -1245,19 +1257,11 @@ mod tests {
assert!(scores.unwrap().scores.is_empty());
}
#[rstest]
#[case(1)]
#[case(2)]
#[case(3)]
#[case(4)]
#[case(5)]
#[case(6)]
#[case(7)]
#[case(8)]
#[tokio::test]
async fn test_find_matches_for_request(#[case] num_shards: usize) {
#[apply(indexer_template)]
async fn test_find_matches_for_request(num_shards: usize, kv_block_size: usize) {
let token = CancellationToken::new();
let kv_indexer = make_indexer(&token, num_shards);
let kv_indexer = make_indexer(&token, num_shards, kv_block_size);
let tokens = vec![1, 2, 3, 4];
let scores = kv_indexer.find_matches_for_request(&tokens).await;
......@@ -1265,21 +1269,13 @@ mod tests {
assert!(scores.unwrap().scores.is_empty());
}
#[rstest]
#[case(1)]
#[case(2)]
#[case(3)]
#[case(4)]
#[case(5)]
#[case(6)]
#[case(7)]
#[case(8)]
#[tokio::test]
async fn test_apply_event(#[case] num_shards: usize) {
#[apply(indexer_template)]
async fn test_apply_event(num_shards: usize, kv_block_size: usize) {
let worker_id = 0;
let token = CancellationToken::new();
let mut kv_indexer = make_indexer(&token, num_shards);
let mut kv_indexer = make_indexer(&token, num_shards, kv_block_size);
let event = create_store_event(worker_id, 1, vec![1, 2, 3], None);
kv_indexer.apply_event(event).await;
......@@ -1287,43 +1283,34 @@ mod tests {
// No assertion here, just ensuring it runs without panic
}
#[rstest]
#[case(1)]
#[case(2)]
#[case(3)]
#[case(4)]
#[case(5)]
#[case(6)]
#[case(7)]
#[case(8)]
#[tokio::test]
async fn test_shutdown(#[case] num_shards: usize) {
#[apply(indexer_template)]
async fn test_shutdown(num_shards: usize, kv_block_size: usize) {
let token = CancellationToken::new();
let mut kv_indexer = make_indexer(&token, num_shards);
let mut kv_indexer = make_indexer(&token, num_shards, kv_block_size);
kv_indexer.shutdown();
}
#[rstest]
#[case(1)]
#[case(2)]
#[case(3)]
#[case(4)]
#[case(5)]
#[case(6)]
#[case(7)]
#[case(8)]
#[tokio::test]
async fn test_frequency(#[case] num_shards: usize) {
#[apply(indexer_template)]
async fn test_frequency(num_shards: usize, kv_block_size: usize) {
let mut kv_indexer: Box<dyn KvIndexerInterface>;
let token = CancellationToken::new();
let duration = Some(Duration::from_millis(50));
if num_shards == 1 {
kv_indexer = Box::new(KvIndexer::new_with_frequency(token, duration));
kv_indexer = Box::new(KvIndexer::new_with_frequency(
token,
duration,
kv_block_size,
));
} else {
kv_indexer = Box::new(KvIndexerSharded::new_with_frequency(
token, num_shards, duration,
token,
num_shards,
duration,
kv_block_size,
));
}
......
......@@ -15,13 +15,6 @@
use serde::{Deserialize, Serialize};
// Currently hard-coding the block size to be 64 tokens, and
// assuming the LLM framework aligns with this size.
// The KV publisher and subscriber conveys hash values of the tokens,
// for performance reason, therefore the block size needs to be consistent
// so that the computed hash value is the same on both sizes.
pub const KV_BLOCK_SIZE: usize = 64;
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ForwardPassMetrics {
pub request_active_slots: u64,
......
......@@ -31,12 +31,18 @@ use tracing as log;
pub struct KvEventPublisher {
tx: mpsc::UnboundedSender<KvCacheEvent>,
kv_block_size: usize,
}
impl KvEventPublisher {
pub fn new(drt: DistributedRuntime, backend: Component, worker_id: i64) -> Result<Self> {
pub fn new(
drt: DistributedRuntime,
backend: Component,
worker_id: i64,
kv_block_size: usize,
) -> Result<Self> {
let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>();
let p = KvEventPublisher { tx };
let p = KvEventPublisher { tx, kv_block_size };
start_publish_task(drt, backend, worker_id, rx);
Ok(p)
......@@ -46,6 +52,10 @@ impl KvEventPublisher {
log::debug!("Publish event: {:?}", event);
self.tx.send(event)
}
pub fn kv_block_size(&self) -> usize {
self.kv_block_size
}
}
fn start_publish_task(
......
......@@ -20,7 +20,7 @@ use std::borrow::BorrowMut;
use std::cmp::min;
use crate::kv_router::indexer::OverlapScores;
pub use crate::kv_router::protocols::{ForwardPassMetrics, KV_BLOCK_SIZE};
pub use crate::kv_router::protocols::ForwardPassMetrics;
use crate::kv_router::scoring::ProcessedEndpoints;
use crate::kv_router::KV_HIT_RATE_SUBJECT;
......@@ -112,6 +112,7 @@ impl KvScheduler {
pub async fn start(
endpoints_rx: tokio::sync::mpsc::Receiver<ProcessedEndpoints>,
ns: Namespace,
kv_block_size: usize,
) -> Result<Self, KvSchedulerError> {
let mut endpoints_rx = endpoints_rx;
......@@ -178,7 +179,8 @@ impl KvScheduler {
};
tracing::debug!("selected");
loop {
match select_worker(endpoints.borrow_mut(), &request, &event_tx) {
match select_worker(endpoints.borrow_mut(), &request, &event_tx, kv_block_size)
{
Ok(worker_id) => {
request.respond(worker_id);
continue 'outer;
......@@ -237,6 +239,7 @@ pub fn select_worker(
workers: &mut ProcessedEndpoints,
request: &SchedulingRequest,
event_tx: &tokio::sync::mpsc::UnboundedSender<KVHitRateEvent>,
kv_block_size: usize,
) -> Result<i64, KvSchedulerError> {
// balance mode prioritizes balancing load across workers
let balance_threshold: f64 = 0.1;
......@@ -249,7 +252,7 @@ pub fn select_worker(
// Compute each worker's score
let mut best_index = None;
let mut best_cost = f64::INFINITY;
// [FIXME] REMOVE ONLY FOR TESTING
if workers.endpoints.is_empty() {
return Err(KvSchedulerError::NoEndpoints);
}
......@@ -268,7 +271,7 @@ pub fn select_worker(
// [FIXME] multiple endpoints of the same worker cause out of bound error
let worker_id = workers.worker_ids[i];
let overlap_score = request.overlap.scores.get(&worker_id).map_or(0, |x| *x);
let overlap_score = overlap_score as usize * KV_BLOCK_SIZE;
let overlap_score = overlap_score as usize * kv_block_size;
let new_tokens = request.isl_tokens.saturating_sub(overlap_score);
let normalized_new_tokens = new_tokens as f64 / request.isl_tokens as f64;
......@@ -296,14 +299,14 @@ pub fn select_worker(
}
if let Some(best_index) = best_index {
let total_blocks = min(request.isl_tokens / KV_BLOCK_SIZE, 1);
let total_blocks = min(request.isl_tokens / kv_block_size, 1);
workers.endpoints[best_index].data.request_active_slots += 1;
workers.endpoints[best_index].data.kv_active_blocks += total_blocks as u64;
// Optimization - pass this to a channel for emitting events, async task, etc. to avoid blocking the scheduler
let best_worker_id = workers.endpoints[best_index].worker_id();
let isl_blocks = request.isl_tokens / KV_BLOCK_SIZE;
let isl_blocks = request.isl_tokens / kv_block_size;
let overlap_blocks = request
.overlap
.scores
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment