Unverified Commit aaf283bb authored by jthomson04's avatar jthomson04 Committed by GitHub
Browse files

feat: Approximate KV Routing (#1636)

parent 9cd9993d
......@@ -26,7 +26,13 @@ from utils.check_worker import check_required_workers
from utils.protocol import LocalBlockHashes
from utils.vllm import RouterType
from dynamo.llm import AggregatedMetrics, KvIndexer, KvMetricsAggregator, OverlapScores
from dynamo.llm import (
AggregatedMetrics,
ApproxKvIndexer,
KvIndexer,
KvMetricsAggregator,
OverlapScores,
)
from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, service
from dynamo.sdk.lib.config import ServiceConfig
......@@ -153,6 +159,10 @@ class Router:
await kv_listener.create_service()
if self.router_type == RouterType.KV:
self.indexer = KvIndexer(kv_listener, self.args.block_size)
elif self.router_type == RouterType.APPROX_KV:
# For now, hardcode the TTL to 2 minutes.
self.indexer = ApproxKvIndexer(kv_listener, self.args.block_size, 120.0)
self.metrics_aggregator = KvMetricsAggregator(kv_listener)
self.active_blocks_dict = {}
......@@ -352,7 +362,10 @@ class Router:
# Existing KV routing logic
try:
scores = await self.indexer.find_matches(request.hashes)
if self.router_type == RouterType.APPROX_KV:
scores = await self.indexer.find_matches_for_request(request.tokens)
else:
scores = await self.indexer.find_matches(request.hashes)
except Exception as e:
scores = {}
logger.exception(f"Error finding matches: {e}. {fallback_msg}")
......@@ -363,9 +376,30 @@ class Router:
scores, metrics, request.num_tokens
)
if self.router_type == RouterType.APPROX_KV:
# For the approx kv router, we need to know what worker we route to.
# We can't defer to the engine client to select a random worker.
# Because of this, we need to select a worker here.
if not worker_id:
all_workers = self.workers_client.instance_ids()
worker_id = random.choice(all_workers)
await self.log_router_decision(request.tokens, worker_id)
if worker_id:
logger.info(
f"Scheduling to worker_id: {worker_id} with estimated prefix hit rate: {prefix_hit_rate}"
)
yield worker_id, prefix_hit_rate
async def log_router_decision(self, tokens: list[int], worker_id: str):
if self.router_type == RouterType.APPROX_KV:
try:
await self.indexer.process_routing_decision_for_request(
tokens, worker_id
)
except Exception as e:
logger.exception(
f"Error processing routing decision: {e}. {fallback_msg}"
)
......@@ -102,7 +102,11 @@ class Processor(ProcessMixIn):
.client()
)
self.use_router = self.engine_args.router in (RouterType.KV, RouterType.KV_LOAD)
self.use_router = self.engine_args.router in (
RouterType.KV,
RouterType.KV_LOAD,
RouterType.APPROX_KV,
)
if self.use_router:
router_ns, router_name = Router.dynamo_address() # type: ignore
self.router_client = (
......@@ -238,7 +242,11 @@ class Processor(ProcessMixIn):
# TODO: queue request at processor when engines are full
router_mode = (await self.etcd_kv_cache.get("router")).decode()
self.use_router = router_mode in (RouterType.KV, RouterType.KV_LOAD)
self.use_router = router_mode in (
RouterType.KV,
RouterType.KV_LOAD,
RouterType.APPROX_KV,
)
prefix_hit_rate = 0.0 # Default value
if self.use_router:
......@@ -248,6 +256,7 @@ class Processor(ProcessMixIn):
hashes=compute_block_hash_for_seq_py(
token_ids, self.engine_args.block_size
),
tokens=token_ids,
num_tokens=len(token_ids),
).model_dump_json()
)
......
......@@ -75,7 +75,7 @@ class VllmWorker:
logger.info("Pipeline parallel size is not supported yet, setting to 1")
self.engine_args.pipeline_parallel_size = 1
if self.engine_args.router == RouterType.KV:
if self.engine_args.router in (RouterType.KV, RouterType.APPROX_KV):
if not self.engine_args.enable_prefix_caching:
logger.info(
"When using KV router, prefix caching must be enabled, setting to True"
......
......@@ -38,6 +38,7 @@ class Tokens(BaseModel):
class LocalBlockHashes(BaseModel):
hashes: list[int]
tokens: list[int]
num_tokens: int
......
......@@ -25,6 +25,7 @@ class RouterType:
ROUND_ROBIN = "round-robin"
KV = "kv"
KV_LOAD = "kv-load"
APPROX_KV = "approx-kv"
def parse_vllm_args(service_name, prefix) -> AsyncEngineArgs:
......@@ -39,6 +40,7 @@ def parse_vllm_args(service_name, prefix) -> AsyncEngineArgs:
RouterType.ROUND_ROBIN,
RouterType.KV,
RouterType.KV_LOAD,
RouterType.APPROX_KV,
],
default=RouterType.RANDOM,
help="Router type to use for scheduling requests to workers",
......
......@@ -59,6 +59,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<llm::backend::Backend>()?;
m.add_class::<llm::kv::OverlapScores>()?;
m.add_class::<llm::kv::KvIndexer>()?;
m.add_class::<llm::kv::ApproxKvIndexer>()?;
m.add_class::<llm::kv::EndpointKvMetrics>()?;
m.add_class::<llm::kv::AggregatedMetrics>()?;
m.add_class::<llm::kv::KvMetricsAggregator>()?;
......
......@@ -521,6 +521,64 @@ impl KvIndexer {
}
}
/// Bindings for the approximate KV indexer. We need to exactly match the regular KV Indexer
/// interface, so that the router can switch between the two.
#[pyclass]
pub(crate) struct ApproxKvIndexer {
inner: Arc<llm_rs::kv_router::approx::ApproxKvIndexer>,
}
#[pymethods]
impl ApproxKvIndexer {
#[new]
fn new(component: Component, kv_block_size: usize, ttl_secs: f64) -> PyResult<Self> {
let ttl = tokio::time::Duration::from_secs_f64(ttl_secs);
let inner = Arc::new(llm_rs::kv_router::approx::ApproxKvIndexer::new(
component.inner.drt().runtime().child_token(),
kv_block_size,
ttl,
));
Ok(Self { inner })
}
fn block_size(&self) -> usize {
self.inner.block_size()
}
fn find_matches_for_request<'p>(
&self,
py: Python<'p>,
token_ids: Vec<u32>,
) -> PyResult<Bound<'p, PyAny>> {
let indexer = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let rs_overlap_scores = indexer
.find_matches_for_request(token_ids.as_slice())
.await
.map_err(to_pyerr)?;
Ok(OverlapScores {
inner: rs_overlap_scores,
})
})
}
fn process_routing_decision_for_request<'p>(
&self,
py: Python<'p>,
tokens: Vec<u32>,
worker_id: i64,
) -> PyResult<Bound<'p, PyAny>> {
let indexer = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
indexer
.process_routing_decision_for_request(tokens.as_slice(), worker_id)
.await
.map_err(to_pyerr)?;
Ok(())
})
}
}
#[pyclass]
#[derive(Clone)]
pub(crate) struct EndpointKvMetrics {
......
......@@ -553,6 +553,35 @@ class KvIndexer:
"""
...
class ApproxKvIndexer:
"""
A KV Indexer that doesn't use KV cache events. It instead relies solely on the input tokens.
"""
def __init__(self, component: Component, kv_block_size: int, ttl_secs: float) -> None:
"""
Create a `ApproxKvIndexer` object
"""
...
def find_matches_for_request(self, token_ids: List[int], lora_id: int) -> OverlapScores:
"""
Return the overlapping scores of workers for the given token ids.
"""
...
def block_size(self) -> int:
"""
Return the block size of the ApproxKvIndexer.
"""
...
def process_routing_decision_for_request(self, tokens: List[int], lora_id: int, worker_id: int) -> None:
"""
Notify the indexer that a token sequence has been sent to a specific worker.
"""
...
class KvRecorder:
"""
A recorder for KV Router events.
......
......@@ -22,6 +22,7 @@ try:
except ImportError:
pass # BlockManager is not enabled by default
from dynamo._core import ApproxKvIndexer as ApproxKvIndexer
from dynamo._core import DisaggregatedRouter as DisaggregatedRouter
from dynamo._core import HttpAsyncEngine as HttpAsyncEngine
from dynamo._core import HttpError as HttpError
......
......@@ -25,6 +25,7 @@ from typing import List
import pytest
from dynamo.llm import (
ApproxKvIndexer,
KvEventPublisher,
KvIndexer,
KvMetricsAggregator,
......@@ -150,6 +151,30 @@ async def test_event_handler(distributed_runtime):
assert not scores.scores
async def test_approx_kv_indexer(distributed_runtime):
kv_block_size = 32
namespace = "kv_test"
component = "approx_kv"
kv_listener = distributed_runtime.namespace(namespace).component(component)
await kv_listener.create_service()
indexer = ApproxKvIndexer(kv_listener, kv_block_size, 30.0)
tokens = [0] * (kv_block_size * 2)
scores = await indexer.find_matches_for_request(tokens)
assert not scores.scores
worker_id = 0
await indexer.process_routing_decision_for_request(tokens, worker_id)
scores = await indexer.find_matches_for_request(tokens)
assert scores.scores
assert worker_id in scores.scores
assert scores.scores[worker_id] == 2
class EventPublisher:
def __init__(self, component: Component, worker_id: int, kv_block_size: int):
self.publisher = KvEventPublisher(component, worker_id, kv_block_size)
......
......@@ -15,6 +15,7 @@ use dynamo_runtime::{
};
use futures::stream::{self, StreamExt};
pub mod approx;
pub mod indexer;
pub mod metrics_aggregator;
pub mod protocols;
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment