Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
aaf283bb
Unverified
Commit
aaf283bb
authored
Jun 30, 2025
by
jthomson04
Committed by
GitHub
Jun 30, 2025
Browse files
feat: Approximate KV Routing (#1636)
parent
9cd9993d
Changes
12
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
806 additions
and
5 deletions
+806
-5
examples/llm/components/kv_router.py
examples/llm/components/kv_router.py
+36
-2
examples/llm/components/processor.py
examples/llm/components/processor.py
+11
-2
examples/llm/components/worker.py
examples/llm/components/worker.py
+1
-1
examples/llm/utils/protocol.py
examples/llm/utils/protocol.py
+1
-0
examples/llm/utils/vllm.py
examples/llm/utils/vllm.py
+2
-0
lib/bindings/python/rust/lib.rs
lib/bindings/python/rust/lib.rs
+1
-0
lib/bindings/python/rust/llm/kv.rs
lib/bindings/python/rust/llm/kv.rs
+58
-0
lib/bindings/python/src/dynamo/_core.pyi
lib/bindings/python/src/dynamo/_core.pyi
+29
-0
lib/bindings/python/src/dynamo/llm/__init__.py
lib/bindings/python/src/dynamo/llm/__init__.py
+1
-0
lib/bindings/python/tests/test_kv_bindings.py
lib/bindings/python/tests/test_kv_bindings.py
+25
-0
lib/llm/src/kv_router.rs
lib/llm/src/kv_router.rs
+1
-0
lib/llm/src/kv_router/approx.rs
lib/llm/src/kv_router/approx.rs
+640
-0
No files found.
examples/llm/components/kv_router.py
View file @
aaf283bb
...
@@ -26,7 +26,13 @@ from utils.check_worker import check_required_workers
...
@@ -26,7 +26,13 @@ from utils.check_worker import check_required_workers
from
utils.protocol
import
LocalBlockHashes
from
utils.protocol
import
LocalBlockHashes
from
utils.vllm
import
RouterType
from
utils.vllm
import
RouterType
from
dynamo.llm
import
AggregatedMetrics
,
KvIndexer
,
KvMetricsAggregator
,
OverlapScores
from
dynamo.llm
import
(
AggregatedMetrics
,
ApproxKvIndexer
,
KvIndexer
,
KvMetricsAggregator
,
OverlapScores
,
)
from
dynamo.sdk
import
async_on_start
,
depends
,
dynamo_context
,
endpoint
,
service
from
dynamo.sdk
import
async_on_start
,
depends
,
dynamo_context
,
endpoint
,
service
from
dynamo.sdk.lib.config
import
ServiceConfig
from
dynamo.sdk.lib.config
import
ServiceConfig
...
@@ -153,6 +159,10 @@ class Router:
...
@@ -153,6 +159,10 @@ class Router:
await
kv_listener
.
create_service
()
await
kv_listener
.
create_service
()
if
self
.
router_type
==
RouterType
.
KV
:
if
self
.
router_type
==
RouterType
.
KV
:
self
.
indexer
=
KvIndexer
(
kv_listener
,
self
.
args
.
block_size
)
self
.
indexer
=
KvIndexer
(
kv_listener
,
self
.
args
.
block_size
)
elif
self
.
router_type
==
RouterType
.
APPROX_KV
:
# For now, hardcode the TTL to 2 minutes.
self
.
indexer
=
ApproxKvIndexer
(
kv_listener
,
self
.
args
.
block_size
,
120.0
)
self
.
metrics_aggregator
=
KvMetricsAggregator
(
kv_listener
)
self
.
metrics_aggregator
=
KvMetricsAggregator
(
kv_listener
)
self
.
active_blocks_dict
=
{}
self
.
active_blocks_dict
=
{}
...
@@ -352,7 +362,10 @@ class Router:
...
@@ -352,7 +362,10 @@ class Router:
# Existing KV routing logic
# Existing KV routing logic
try
:
try
:
scores
=
await
self
.
indexer
.
find_matches
(
request
.
hashes
)
if
self
.
router_type
==
RouterType
.
APPROX_KV
:
scores
=
await
self
.
indexer
.
find_matches_for_request
(
request
.
tokens
)
else
:
scores
=
await
self
.
indexer
.
find_matches
(
request
.
hashes
)
except
Exception
as
e
:
except
Exception
as
e
:
scores
=
{}
scores
=
{}
logger
.
exception
(
f
"Error finding matches:
{
e
}
.
{
fallback_msg
}
"
)
logger
.
exception
(
f
"Error finding matches:
{
e
}
.
{
fallback_msg
}
"
)
...
@@ -363,9 +376,30 @@ class Router:
...
@@ -363,9 +376,30 @@ class Router:
scores
,
metrics
,
request
.
num_tokens
scores
,
metrics
,
request
.
num_tokens
)
)
if
self
.
router_type
==
RouterType
.
APPROX_KV
:
# For the approx kv router, we need to know what worker we route to.
# We can't defer to the engine client to select a random worker.
# Because of this, we need to select a worker here.
if
not
worker_id
:
all_workers
=
self
.
workers_client
.
instance_ids
()
worker_id
=
random
.
choice
(
all_workers
)
await
self
.
log_router_decision
(
request
.
tokens
,
worker_id
)
if
worker_id
:
if
worker_id
:
logger
.
info
(
logger
.
info
(
f
"Scheduling to worker_id:
{
worker_id
}
with estimated prefix hit rate:
{
prefix_hit_rate
}
"
f
"Scheduling to worker_id:
{
worker_id
}
with estimated prefix hit rate:
{
prefix_hit_rate
}
"
)
)
yield
worker_id
,
prefix_hit_rate
yield
worker_id
,
prefix_hit_rate
async
def
log_router_decision
(
self
,
tokens
:
list
[
int
],
worker_id
:
str
):
if
self
.
router_type
==
RouterType
.
APPROX_KV
:
try
:
await
self
.
indexer
.
process_routing_decision_for_request
(
tokens
,
worker_id
)
except
Exception
as
e
:
logger
.
exception
(
f
"Error processing routing decision:
{
e
}
.
{
fallback_msg
}
"
)
examples/llm/components/processor.py
View file @
aaf283bb
...
@@ -102,7 +102,11 @@ class Processor(ProcessMixIn):
...
@@ -102,7 +102,11 @@ class Processor(ProcessMixIn):
.
client
()
.
client
()
)
)
self
.
use_router
=
self
.
engine_args
.
router
in
(
RouterType
.
KV
,
RouterType
.
KV_LOAD
)
self
.
use_router
=
self
.
engine_args
.
router
in
(
RouterType
.
KV
,
RouterType
.
KV_LOAD
,
RouterType
.
APPROX_KV
,
)
if
self
.
use_router
:
if
self
.
use_router
:
router_ns
,
router_name
=
Router
.
dynamo_address
()
# type: ignore
router_ns
,
router_name
=
Router
.
dynamo_address
()
# type: ignore
self
.
router_client
=
(
self
.
router_client
=
(
...
@@ -238,7 +242,11 @@ class Processor(ProcessMixIn):
...
@@ -238,7 +242,11 @@ class Processor(ProcessMixIn):
# TODO: queue request at processor when engines are full
# TODO: queue request at processor when engines are full
router_mode
=
(
await
self
.
etcd_kv_cache
.
get
(
"router"
)).
decode
()
router_mode
=
(
await
self
.
etcd_kv_cache
.
get
(
"router"
)).
decode
()
self
.
use_router
=
router_mode
in
(
RouterType
.
KV
,
RouterType
.
KV_LOAD
)
self
.
use_router
=
router_mode
in
(
RouterType
.
KV
,
RouterType
.
KV_LOAD
,
RouterType
.
APPROX_KV
,
)
prefix_hit_rate
=
0.0
# Default value
prefix_hit_rate
=
0.0
# Default value
if
self
.
use_router
:
if
self
.
use_router
:
...
@@ -248,6 +256,7 @@ class Processor(ProcessMixIn):
...
@@ -248,6 +256,7 @@ class Processor(ProcessMixIn):
hashes
=
compute_block_hash_for_seq_py
(
hashes
=
compute_block_hash_for_seq_py
(
token_ids
,
self
.
engine_args
.
block_size
token_ids
,
self
.
engine_args
.
block_size
),
),
tokens
=
token_ids
,
num_tokens
=
len
(
token_ids
),
num_tokens
=
len
(
token_ids
),
).
model_dump_json
()
).
model_dump_json
()
)
)
...
...
examples/llm/components/worker.py
View file @
aaf283bb
...
@@ -75,7 +75,7 @@ class VllmWorker:
...
@@ -75,7 +75,7 @@ class VllmWorker:
logger
.
info
(
"Pipeline parallel size is not supported yet, setting to 1"
)
logger
.
info
(
"Pipeline parallel size is not supported yet, setting to 1"
)
self
.
engine_args
.
pipeline_parallel_size
=
1
self
.
engine_args
.
pipeline_parallel_size
=
1
if
self
.
engine_args
.
router
==
RouterType
.
KV
:
if
self
.
engine_args
.
router
in
(
RouterType
.
KV
,
RouterType
.
APPROX_KV
)
:
if
not
self
.
engine_args
.
enable_prefix_caching
:
if
not
self
.
engine_args
.
enable_prefix_caching
:
logger
.
info
(
logger
.
info
(
"When using KV router, prefix caching must be enabled, setting to True"
"When using KV router, prefix caching must be enabled, setting to True"
...
...
examples/llm/utils/protocol.py
View file @
aaf283bb
...
@@ -38,6 +38,7 @@ class Tokens(BaseModel):
...
@@ -38,6 +38,7 @@ class Tokens(BaseModel):
class
LocalBlockHashes
(
BaseModel
):
class
LocalBlockHashes
(
BaseModel
):
hashes
:
list
[
int
]
hashes
:
list
[
int
]
tokens
:
list
[
int
]
num_tokens
:
int
num_tokens
:
int
...
...
examples/llm/utils/vllm.py
View file @
aaf283bb
...
@@ -25,6 +25,7 @@ class RouterType:
...
@@ -25,6 +25,7 @@ class RouterType:
ROUND_ROBIN
=
"round-robin"
ROUND_ROBIN
=
"round-robin"
KV
=
"kv"
KV
=
"kv"
KV_LOAD
=
"kv-load"
KV_LOAD
=
"kv-load"
APPROX_KV
=
"approx-kv"
def
parse_vllm_args
(
service_name
,
prefix
)
->
AsyncEngineArgs
:
def
parse_vllm_args
(
service_name
,
prefix
)
->
AsyncEngineArgs
:
...
@@ -39,6 +40,7 @@ def parse_vllm_args(service_name, prefix) -> AsyncEngineArgs:
...
@@ -39,6 +40,7 @@ def parse_vllm_args(service_name, prefix) -> AsyncEngineArgs:
RouterType
.
ROUND_ROBIN
,
RouterType
.
ROUND_ROBIN
,
RouterType
.
KV
,
RouterType
.
KV
,
RouterType
.
KV_LOAD
,
RouterType
.
KV_LOAD
,
RouterType
.
APPROX_KV
,
],
],
default
=
RouterType
.
RANDOM
,
default
=
RouterType
.
RANDOM
,
help
=
"Router type to use for scheduling requests to workers"
,
help
=
"Router type to use for scheduling requests to workers"
,
...
...
lib/bindings/python/rust/lib.rs
View file @
aaf283bb
...
@@ -59,6 +59,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
...
@@ -59,6 +59,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m
.add_class
::
<
llm
::
backend
::
Backend
>
()
?
;
m
.add_class
::
<
llm
::
backend
::
Backend
>
()
?
;
m
.add_class
::
<
llm
::
kv
::
OverlapScores
>
()
?
;
m
.add_class
::
<
llm
::
kv
::
OverlapScores
>
()
?
;
m
.add_class
::
<
llm
::
kv
::
KvIndexer
>
()
?
;
m
.add_class
::
<
llm
::
kv
::
KvIndexer
>
()
?
;
m
.add_class
::
<
llm
::
kv
::
ApproxKvIndexer
>
()
?
;
m
.add_class
::
<
llm
::
kv
::
EndpointKvMetrics
>
()
?
;
m
.add_class
::
<
llm
::
kv
::
EndpointKvMetrics
>
()
?
;
m
.add_class
::
<
llm
::
kv
::
AggregatedMetrics
>
()
?
;
m
.add_class
::
<
llm
::
kv
::
AggregatedMetrics
>
()
?
;
m
.add_class
::
<
llm
::
kv
::
KvMetricsAggregator
>
()
?
;
m
.add_class
::
<
llm
::
kv
::
KvMetricsAggregator
>
()
?
;
...
...
lib/bindings/python/rust/llm/kv.rs
View file @
aaf283bb
...
@@ -521,6 +521,64 @@ impl KvIndexer {
...
@@ -521,6 +521,64 @@ impl KvIndexer {
}
}
}
}
/// Bindings for the approximate KV indexer. We need to exactly match the regular KV Indexer
/// interface, so that the router can switch between the two.
#[pyclass]
pub
(
crate
)
struct
ApproxKvIndexer
{
inner
:
Arc
<
llm_rs
::
kv_router
::
approx
::
ApproxKvIndexer
>
,
}
#[pymethods]
impl
ApproxKvIndexer
{
#[new]
fn
new
(
component
:
Component
,
kv_block_size
:
usize
,
ttl_secs
:
f64
)
->
PyResult
<
Self
>
{
let
ttl
=
tokio
::
time
::
Duration
::
from_secs_f64
(
ttl_secs
);
let
inner
=
Arc
::
new
(
llm_rs
::
kv_router
::
approx
::
ApproxKvIndexer
::
new
(
component
.inner
.drt
()
.runtime
()
.child_token
(),
kv_block_size
,
ttl
,
));
Ok
(
Self
{
inner
})
}
fn
block_size
(
&
self
)
->
usize
{
self
.inner
.block_size
()
}
fn
find_matches_for_request
<
'p
>
(
&
self
,
py
:
Python
<
'p
>
,
token_ids
:
Vec
<
u32
>
,
)
->
PyResult
<
Bound
<
'p
,
PyAny
>>
{
let
indexer
=
self
.inner
.clone
();
pyo3_async_runtimes
::
tokio
::
future_into_py
(
py
,
async
move
{
let
rs_overlap_scores
=
indexer
.find_matches_for_request
(
token_ids
.as_slice
())
.await
.map_err
(
to_pyerr
)
?
;
Ok
(
OverlapScores
{
inner
:
rs_overlap_scores
,
})
})
}
fn
process_routing_decision_for_request
<
'p
>
(
&
self
,
py
:
Python
<
'p
>
,
tokens
:
Vec
<
u32
>
,
worker_id
:
i64
,
)
->
PyResult
<
Bound
<
'p
,
PyAny
>>
{
let
indexer
=
self
.inner
.clone
();
pyo3_async_runtimes
::
tokio
::
future_into_py
(
py
,
async
move
{
indexer
.process_routing_decision_for_request
(
tokens
.as_slice
(),
worker_id
)
.await
.map_err
(
to_pyerr
)
?
;
Ok
(())
})
}
}
#[pyclass]
#[pyclass]
#[derive(Clone)]
#[derive(Clone)]
pub
(
crate
)
struct
EndpointKvMetrics
{
pub
(
crate
)
struct
EndpointKvMetrics
{
...
...
lib/bindings/python/src/dynamo/_core.pyi
View file @
aaf283bb
...
@@ -553,6 +553,35 @@ class KvIndexer:
...
@@ -553,6 +553,35 @@ class KvIndexer:
"""
"""
...
...
class ApproxKvIndexer:
"""
A KV Indexer that doesn't use KV cache events. It instead relies solely on the input tokens.
"""
def __init__(self, component: Component, kv_block_size: int, ttl_secs: float) -> None:
"""
Create a `ApproxKvIndexer` object
"""
...
def find_matches_for_request(self, token_ids: List[int], lora_id: int) -> OverlapScores:
"""
Return the overlapping scores of workers for the given token ids.
"""
...
def block_size(self) -> int:
"""
Return the block size of the ApproxKvIndexer.
"""
...
def process_routing_decision_for_request(self, tokens: List[int], lora_id: int, worker_id: int) -> None:
"""
Notify the indexer that a token sequence has been sent to a specific worker.
"""
...
class KvRecorder:
class KvRecorder:
"""
"""
A recorder for KV Router events.
A recorder for KV Router events.
...
...
lib/bindings/python/src/dynamo/llm/__init__.py
View file @
aaf283bb
...
@@ -22,6 +22,7 @@ try:
...
@@ -22,6 +22,7 @@ try:
except
ImportError
:
except
ImportError
:
pass
# BlockManager is not enabled by default
pass
# BlockManager is not enabled by default
from
dynamo._core
import
ApproxKvIndexer
as
ApproxKvIndexer
from
dynamo._core
import
DisaggregatedRouter
as
DisaggregatedRouter
from
dynamo._core
import
DisaggregatedRouter
as
DisaggregatedRouter
from
dynamo._core
import
HttpAsyncEngine
as
HttpAsyncEngine
from
dynamo._core
import
HttpAsyncEngine
as
HttpAsyncEngine
from
dynamo._core
import
HttpError
as
HttpError
from
dynamo._core
import
HttpError
as
HttpError
...
...
lib/bindings/python/tests/test_kv_bindings.py
View file @
aaf283bb
...
@@ -25,6 +25,7 @@ from typing import List
...
@@ -25,6 +25,7 @@ from typing import List
import
pytest
import
pytest
from
dynamo.llm
import
(
from
dynamo.llm
import
(
ApproxKvIndexer
,
KvEventPublisher
,
KvEventPublisher
,
KvIndexer
,
KvIndexer
,
KvMetricsAggregator
,
KvMetricsAggregator
,
...
@@ -150,6 +151,30 @@ async def test_event_handler(distributed_runtime):
...
@@ -150,6 +151,30 @@ async def test_event_handler(distributed_runtime):
assert
not
scores
.
scores
assert
not
scores
.
scores
async
def
test_approx_kv_indexer
(
distributed_runtime
):
kv_block_size
=
32
namespace
=
"kv_test"
component
=
"approx_kv"
kv_listener
=
distributed_runtime
.
namespace
(
namespace
).
component
(
component
)
await
kv_listener
.
create_service
()
indexer
=
ApproxKvIndexer
(
kv_listener
,
kv_block_size
,
30.0
)
tokens
=
[
0
]
*
(
kv_block_size
*
2
)
scores
=
await
indexer
.
find_matches_for_request
(
tokens
)
assert
not
scores
.
scores
worker_id
=
0
await
indexer
.
process_routing_decision_for_request
(
tokens
,
worker_id
)
scores
=
await
indexer
.
find_matches_for_request
(
tokens
)
assert
scores
.
scores
assert
worker_id
in
scores
.
scores
assert
scores
.
scores
[
worker_id
]
==
2
class
EventPublisher
:
class
EventPublisher
:
def
__init__
(
self
,
component
:
Component
,
worker_id
:
int
,
kv_block_size
:
int
):
def
__init__
(
self
,
component
:
Component
,
worker_id
:
int
,
kv_block_size
:
int
):
self
.
publisher
=
KvEventPublisher
(
component
,
worker_id
,
kv_block_size
)
self
.
publisher
=
KvEventPublisher
(
component
,
worker_id
,
kv_block_size
)
...
...
lib/llm/src/kv_router.rs
View file @
aaf283bb
...
@@ -15,6 +15,7 @@ use dynamo_runtime::{
...
@@ -15,6 +15,7 @@ use dynamo_runtime::{
};
};
use
futures
::
stream
::{
self
,
StreamExt
};
use
futures
::
stream
::{
self
,
StreamExt
};
pub
mod
approx
;
pub
mod
indexer
;
pub
mod
indexer
;
pub
mod
metrics_aggregator
;
pub
mod
metrics_aggregator
;
pub
mod
protocols
;
pub
mod
protocols
;
...
...
lib/llm/src/kv_router/approx.rs
0 → 100644
View file @
aaf283bb
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment