Unverified Commit aaf283bb authored by jthomson04's avatar jthomson04 Committed by GitHub
Browse files

feat: Approximate KV Routing (#1636)

parent 9cd9993d
...@@ -26,7 +26,13 @@ from utils.check_worker import check_required_workers ...@@ -26,7 +26,13 @@ from utils.check_worker import check_required_workers
from utils.protocol import LocalBlockHashes from utils.protocol import LocalBlockHashes
from utils.vllm import RouterType from utils.vllm import RouterType
from dynamo.llm import AggregatedMetrics, KvIndexer, KvMetricsAggregator, OverlapScores from dynamo.llm import (
AggregatedMetrics,
ApproxKvIndexer,
KvIndexer,
KvMetricsAggregator,
OverlapScores,
)
from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, service from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, service
from dynamo.sdk.lib.config import ServiceConfig from dynamo.sdk.lib.config import ServiceConfig
...@@ -153,6 +159,10 @@ class Router: ...@@ -153,6 +159,10 @@ class Router:
await kv_listener.create_service() await kv_listener.create_service()
if self.router_type == RouterType.KV: if self.router_type == RouterType.KV:
self.indexer = KvIndexer(kv_listener, self.args.block_size) self.indexer = KvIndexer(kv_listener, self.args.block_size)
elif self.router_type == RouterType.APPROX_KV:
# For now, hardcode the TTL to 2 minutes.
self.indexer = ApproxKvIndexer(kv_listener, self.args.block_size, 120.0)
self.metrics_aggregator = KvMetricsAggregator(kv_listener) self.metrics_aggregator = KvMetricsAggregator(kv_listener)
self.active_blocks_dict = {} self.active_blocks_dict = {}
...@@ -352,6 +362,9 @@ class Router: ...@@ -352,6 +362,9 @@ class Router:
# Existing KV routing logic # Existing KV routing logic
try: try:
if self.router_type == RouterType.APPROX_KV:
scores = await self.indexer.find_matches_for_request(request.tokens)
else:
scores = await self.indexer.find_matches(request.hashes) scores = await self.indexer.find_matches(request.hashes)
except Exception as e: except Exception as e:
scores = {} scores = {}
...@@ -363,9 +376,30 @@ class Router: ...@@ -363,9 +376,30 @@ class Router:
scores, metrics, request.num_tokens scores, metrics, request.num_tokens
) )
if self.router_type == RouterType.APPROX_KV:
# For the approx kv router, we need to know what worker we route to.
# We can't defer to the engine client to select a random worker.
# Because of this, we need to select a worker here.
if not worker_id:
all_workers = self.workers_client.instance_ids()
worker_id = random.choice(all_workers)
await self.log_router_decision(request.tokens, worker_id)
if worker_id: if worker_id:
logger.info( logger.info(
f"Scheduling to worker_id: {worker_id} with estimated prefix hit rate: {prefix_hit_rate}" f"Scheduling to worker_id: {worker_id} with estimated prefix hit rate: {prefix_hit_rate}"
) )
yield worker_id, prefix_hit_rate yield worker_id, prefix_hit_rate
async def log_router_decision(self, tokens: list[int], worker_id: str):
if self.router_type == RouterType.APPROX_KV:
try:
await self.indexer.process_routing_decision_for_request(
tokens, worker_id
)
except Exception as e:
logger.exception(
f"Error processing routing decision: {e}. {fallback_msg}"
)
...@@ -102,7 +102,11 @@ class Processor(ProcessMixIn): ...@@ -102,7 +102,11 @@ class Processor(ProcessMixIn):
.client() .client()
) )
self.use_router = self.engine_args.router in (RouterType.KV, RouterType.KV_LOAD) self.use_router = self.engine_args.router in (
RouterType.KV,
RouterType.KV_LOAD,
RouterType.APPROX_KV,
)
if self.use_router: if self.use_router:
router_ns, router_name = Router.dynamo_address() # type: ignore router_ns, router_name = Router.dynamo_address() # type: ignore
self.router_client = ( self.router_client = (
...@@ -238,7 +242,11 @@ class Processor(ProcessMixIn): ...@@ -238,7 +242,11 @@ class Processor(ProcessMixIn):
# TODO: queue request at processor when engines are full # TODO: queue request at processor when engines are full
router_mode = (await self.etcd_kv_cache.get("router")).decode() router_mode = (await self.etcd_kv_cache.get("router")).decode()
self.use_router = router_mode in (RouterType.KV, RouterType.KV_LOAD) self.use_router = router_mode in (
RouterType.KV,
RouterType.KV_LOAD,
RouterType.APPROX_KV,
)
prefix_hit_rate = 0.0 # Default value prefix_hit_rate = 0.0 # Default value
if self.use_router: if self.use_router:
...@@ -248,6 +256,7 @@ class Processor(ProcessMixIn): ...@@ -248,6 +256,7 @@ class Processor(ProcessMixIn):
hashes=compute_block_hash_for_seq_py( hashes=compute_block_hash_for_seq_py(
token_ids, self.engine_args.block_size token_ids, self.engine_args.block_size
), ),
tokens=token_ids,
num_tokens=len(token_ids), num_tokens=len(token_ids),
).model_dump_json() ).model_dump_json()
) )
......
...@@ -75,7 +75,7 @@ class VllmWorker: ...@@ -75,7 +75,7 @@ class VllmWorker:
logger.info("Pipeline parallel size is not supported yet, setting to 1") logger.info("Pipeline parallel size is not supported yet, setting to 1")
self.engine_args.pipeline_parallel_size = 1 self.engine_args.pipeline_parallel_size = 1
if self.engine_args.router == RouterType.KV: if self.engine_args.router in (RouterType.KV, RouterType.APPROX_KV):
if not self.engine_args.enable_prefix_caching: if not self.engine_args.enable_prefix_caching:
logger.info( logger.info(
"When using KV router, prefix caching must be enabled, setting to True" "When using KV router, prefix caching must be enabled, setting to True"
......
...@@ -38,6 +38,7 @@ class Tokens(BaseModel): ...@@ -38,6 +38,7 @@ class Tokens(BaseModel):
class LocalBlockHashes(BaseModel): class LocalBlockHashes(BaseModel):
hashes: list[int] hashes: list[int]
tokens: list[int]
num_tokens: int num_tokens: int
......
...@@ -25,6 +25,7 @@ class RouterType: ...@@ -25,6 +25,7 @@ class RouterType:
ROUND_ROBIN = "round-robin" ROUND_ROBIN = "round-robin"
KV = "kv" KV = "kv"
KV_LOAD = "kv-load" KV_LOAD = "kv-load"
APPROX_KV = "approx-kv"
def parse_vllm_args(service_name, prefix) -> AsyncEngineArgs: def parse_vllm_args(service_name, prefix) -> AsyncEngineArgs:
...@@ -39,6 +40,7 @@ def parse_vllm_args(service_name, prefix) -> AsyncEngineArgs: ...@@ -39,6 +40,7 @@ def parse_vllm_args(service_name, prefix) -> AsyncEngineArgs:
RouterType.ROUND_ROBIN, RouterType.ROUND_ROBIN,
RouterType.KV, RouterType.KV,
RouterType.KV_LOAD, RouterType.KV_LOAD,
RouterType.APPROX_KV,
], ],
default=RouterType.RANDOM, default=RouterType.RANDOM,
help="Router type to use for scheduling requests to workers", help="Router type to use for scheduling requests to workers",
......
...@@ -59,6 +59,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { ...@@ -59,6 +59,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<llm::backend::Backend>()?; m.add_class::<llm::backend::Backend>()?;
m.add_class::<llm::kv::OverlapScores>()?; m.add_class::<llm::kv::OverlapScores>()?;
m.add_class::<llm::kv::KvIndexer>()?; m.add_class::<llm::kv::KvIndexer>()?;
m.add_class::<llm::kv::ApproxKvIndexer>()?;
m.add_class::<llm::kv::EndpointKvMetrics>()?; m.add_class::<llm::kv::EndpointKvMetrics>()?;
m.add_class::<llm::kv::AggregatedMetrics>()?; m.add_class::<llm::kv::AggregatedMetrics>()?;
m.add_class::<llm::kv::KvMetricsAggregator>()?; m.add_class::<llm::kv::KvMetricsAggregator>()?;
......
...@@ -521,6 +521,64 @@ impl KvIndexer { ...@@ -521,6 +521,64 @@ impl KvIndexer {
} }
} }
/// Bindings for the approximate KV indexer. We need to exactly match the regular KV Indexer
/// interface, so that the router can switch between the two.
#[pyclass]
pub(crate) struct ApproxKvIndexer {
inner: Arc<llm_rs::kv_router::approx::ApproxKvIndexer>,
}
#[pymethods]
impl ApproxKvIndexer {
#[new]
fn new(component: Component, kv_block_size: usize, ttl_secs: f64) -> PyResult<Self> {
let ttl = tokio::time::Duration::from_secs_f64(ttl_secs);
let inner = Arc::new(llm_rs::kv_router::approx::ApproxKvIndexer::new(
component.inner.drt().runtime().child_token(),
kv_block_size,
ttl,
));
Ok(Self { inner })
}
fn block_size(&self) -> usize {
self.inner.block_size()
}
fn find_matches_for_request<'p>(
&self,
py: Python<'p>,
token_ids: Vec<u32>,
) -> PyResult<Bound<'p, PyAny>> {
let indexer = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let rs_overlap_scores = indexer
.find_matches_for_request(token_ids.as_slice())
.await
.map_err(to_pyerr)?;
Ok(OverlapScores {
inner: rs_overlap_scores,
})
})
}
fn process_routing_decision_for_request<'p>(
&self,
py: Python<'p>,
tokens: Vec<u32>,
worker_id: i64,
) -> PyResult<Bound<'p, PyAny>> {
let indexer = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
indexer
.process_routing_decision_for_request(tokens.as_slice(), worker_id)
.await
.map_err(to_pyerr)?;
Ok(())
})
}
}
#[pyclass] #[pyclass]
#[derive(Clone)] #[derive(Clone)]
pub(crate) struct EndpointKvMetrics { pub(crate) struct EndpointKvMetrics {
......
...@@ -553,6 +553,35 @@ class KvIndexer: ...@@ -553,6 +553,35 @@ class KvIndexer:
""" """
... ...
class ApproxKvIndexer:
"""
A KV Indexer that doesn't use KV cache events. It instead relies solely on the input tokens.
"""
def __init__(self, component: Component, kv_block_size: int, ttl_secs: float) -> None:
"""
Create a `ApproxKvIndexer` object
"""
...
def find_matches_for_request(self, token_ids: List[int], lora_id: int) -> OverlapScores:
"""
Return the overlapping scores of workers for the given token ids.
"""
...
def block_size(self) -> int:
"""
Return the block size of the ApproxKvIndexer.
"""
...
def process_routing_decision_for_request(self, tokens: List[int], lora_id: int, worker_id: int) -> None:
"""
Notify the indexer that a token sequence has been sent to a specific worker.
"""
...
class KvRecorder: class KvRecorder:
""" """
A recorder for KV Router events. A recorder for KV Router events.
......
...@@ -22,6 +22,7 @@ try: ...@@ -22,6 +22,7 @@ try:
except ImportError: except ImportError:
pass # BlockManager is not enabled by default pass # BlockManager is not enabled by default
from dynamo._core import ApproxKvIndexer as ApproxKvIndexer
from dynamo._core import DisaggregatedRouter as DisaggregatedRouter from dynamo._core import DisaggregatedRouter as DisaggregatedRouter
from dynamo._core import HttpAsyncEngine as HttpAsyncEngine from dynamo._core import HttpAsyncEngine as HttpAsyncEngine
from dynamo._core import HttpError as HttpError from dynamo._core import HttpError as HttpError
......
...@@ -25,6 +25,7 @@ from typing import List ...@@ -25,6 +25,7 @@ from typing import List
import pytest import pytest
from dynamo.llm import ( from dynamo.llm import (
ApproxKvIndexer,
KvEventPublisher, KvEventPublisher,
KvIndexer, KvIndexer,
KvMetricsAggregator, KvMetricsAggregator,
...@@ -150,6 +151,30 @@ async def test_event_handler(distributed_runtime): ...@@ -150,6 +151,30 @@ async def test_event_handler(distributed_runtime):
assert not scores.scores assert not scores.scores
async def test_approx_kv_indexer(distributed_runtime):
kv_block_size = 32
namespace = "kv_test"
component = "approx_kv"
kv_listener = distributed_runtime.namespace(namespace).component(component)
await kv_listener.create_service()
indexer = ApproxKvIndexer(kv_listener, kv_block_size, 30.0)
tokens = [0] * (kv_block_size * 2)
scores = await indexer.find_matches_for_request(tokens)
assert not scores.scores
worker_id = 0
await indexer.process_routing_decision_for_request(tokens, worker_id)
scores = await indexer.find_matches_for_request(tokens)
assert scores.scores
assert worker_id in scores.scores
assert scores.scores[worker_id] == 2
class EventPublisher: class EventPublisher:
def __init__(self, component: Component, worker_id: int, kv_block_size: int): def __init__(self, component: Component, worker_id: int, kv_block_size: int):
self.publisher = KvEventPublisher(component, worker_id, kv_block_size) self.publisher = KvEventPublisher(component, worker_id, kv_block_size)
......
...@@ -15,6 +15,7 @@ use dynamo_runtime::{ ...@@ -15,6 +15,7 @@ use dynamo_runtime::{
}; };
use futures::stream::{self, StreamExt}; use futures::stream::{self, StreamExt};
pub mod approx;
pub mod indexer; pub mod indexer;
pub mod metrics_aggregator; pub mod metrics_aggregator;
pub mod protocols; pub mod protocols;
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment