Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
aaf283bb
"...ssh:/git@developer.sourcefind.cn:2222/OpenDAS/dynamo.git" did not exist on "cf5f65f7dfe87c88ac33cfadf3cd17b8ad96c8e3"
Unverified
Commit
aaf283bb
authored
Jun 30, 2025
by
jthomson04
Committed by
GitHub
Jun 30, 2025
Browse files
feat: Approximate KV Routing (#1636)
parent
9cd9993d
Changes
12
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
806 additions
and
5 deletions
+806
-5
examples/llm/components/kv_router.py
examples/llm/components/kv_router.py
+36
-2
examples/llm/components/processor.py
examples/llm/components/processor.py
+11
-2
examples/llm/components/worker.py
examples/llm/components/worker.py
+1
-1
examples/llm/utils/protocol.py
examples/llm/utils/protocol.py
+1
-0
examples/llm/utils/vllm.py
examples/llm/utils/vllm.py
+2
-0
lib/bindings/python/rust/lib.rs
lib/bindings/python/rust/lib.rs
+1
-0
lib/bindings/python/rust/llm/kv.rs
lib/bindings/python/rust/llm/kv.rs
+58
-0
lib/bindings/python/src/dynamo/_core.pyi
lib/bindings/python/src/dynamo/_core.pyi
+29
-0
lib/bindings/python/src/dynamo/llm/__init__.py
lib/bindings/python/src/dynamo/llm/__init__.py
+1
-0
lib/bindings/python/tests/test_kv_bindings.py
lib/bindings/python/tests/test_kv_bindings.py
+25
-0
lib/llm/src/kv_router.rs
lib/llm/src/kv_router.rs
+1
-0
lib/llm/src/kv_router/approx.rs
lib/llm/src/kv_router/approx.rs
+640
-0
No files found.
examples/llm/components/kv_router.py
View file @
aaf283bb
...
@@ -26,7 +26,13 @@ from utils.check_worker import check_required_workers
...
@@ -26,7 +26,13 @@ from utils.check_worker import check_required_workers
from
utils.protocol
import
LocalBlockHashes
from
utils.protocol
import
LocalBlockHashes
from
utils.vllm
import
RouterType
from
utils.vllm
import
RouterType
from
dynamo.llm
import
AggregatedMetrics
,
KvIndexer
,
KvMetricsAggregator
,
OverlapScores
from
dynamo.llm
import
(
AggregatedMetrics
,
ApproxKvIndexer
,
KvIndexer
,
KvMetricsAggregator
,
OverlapScores
,
)
from
dynamo.sdk
import
async_on_start
,
depends
,
dynamo_context
,
endpoint
,
service
from
dynamo.sdk
import
async_on_start
,
depends
,
dynamo_context
,
endpoint
,
service
from
dynamo.sdk.lib.config
import
ServiceConfig
from
dynamo.sdk.lib.config
import
ServiceConfig
...
@@ -153,6 +159,10 @@ class Router:
...
@@ -153,6 +159,10 @@ class Router:
await
kv_listener
.
create_service
()
await
kv_listener
.
create_service
()
if
self
.
router_type
==
RouterType
.
KV
:
if
self
.
router_type
==
RouterType
.
KV
:
self
.
indexer
=
KvIndexer
(
kv_listener
,
self
.
args
.
block_size
)
self
.
indexer
=
KvIndexer
(
kv_listener
,
self
.
args
.
block_size
)
elif
self
.
router_type
==
RouterType
.
APPROX_KV
:
# For now, hardcode the TTL to 2 minutes.
self
.
indexer
=
ApproxKvIndexer
(
kv_listener
,
self
.
args
.
block_size
,
120.0
)
self
.
metrics_aggregator
=
KvMetricsAggregator
(
kv_listener
)
self
.
metrics_aggregator
=
KvMetricsAggregator
(
kv_listener
)
self
.
active_blocks_dict
=
{}
self
.
active_blocks_dict
=
{}
...
@@ -352,7 +362,10 @@ class Router:
...
@@ -352,7 +362,10 @@ class Router:
# Existing KV routing logic
# Existing KV routing logic
try
:
try
:
scores
=
await
self
.
indexer
.
find_matches
(
request
.
hashes
)
if
self
.
router_type
==
RouterType
.
APPROX_KV
:
scores
=
await
self
.
indexer
.
find_matches_for_request
(
request
.
tokens
)
else
:
scores
=
await
self
.
indexer
.
find_matches
(
request
.
hashes
)
except
Exception
as
e
:
except
Exception
as
e
:
scores
=
{}
scores
=
{}
logger
.
exception
(
f
"Error finding matches:
{
e
}
.
{
fallback_msg
}
"
)
logger
.
exception
(
f
"Error finding matches:
{
e
}
.
{
fallback_msg
}
"
)
...
@@ -363,9 +376,30 @@ class Router:
...
@@ -363,9 +376,30 @@ class Router:
scores
,
metrics
,
request
.
num_tokens
scores
,
metrics
,
request
.
num_tokens
)
)
if
self
.
router_type
==
RouterType
.
APPROX_KV
:
# For the approx kv router, we need to know what worker we route to.
# We can't defer to the engine client to select a random worker.
# Because of this, we need to select a worker here.
if
not
worker_id
:
all_workers
=
self
.
workers_client
.
instance_ids
()
worker_id
=
random
.
choice
(
all_workers
)
await
self
.
log_router_decision
(
request
.
tokens
,
worker_id
)
if
worker_id
:
if
worker_id
:
logger
.
info
(
logger
.
info
(
f
"Scheduling to worker_id:
{
worker_id
}
with estimated prefix hit rate:
{
prefix_hit_rate
}
"
f
"Scheduling to worker_id:
{
worker_id
}
with estimated prefix hit rate:
{
prefix_hit_rate
}
"
)
)
yield
worker_id
,
prefix_hit_rate
yield
worker_id
,
prefix_hit_rate
async
def
log_router_decision
(
self
,
tokens
:
list
[
int
],
worker_id
:
str
):
if
self
.
router_type
==
RouterType
.
APPROX_KV
:
try
:
await
self
.
indexer
.
process_routing_decision_for_request
(
tokens
,
worker_id
)
except
Exception
as
e
:
logger
.
exception
(
f
"Error processing routing decision:
{
e
}
.
{
fallback_msg
}
"
)
examples/llm/components/processor.py
View file @
aaf283bb
...
@@ -102,7 +102,11 @@ class Processor(ProcessMixIn):
...
@@ -102,7 +102,11 @@ class Processor(ProcessMixIn):
.
client
()
.
client
()
)
)
self
.
use_router
=
self
.
engine_args
.
router
in
(
RouterType
.
KV
,
RouterType
.
KV_LOAD
)
self
.
use_router
=
self
.
engine_args
.
router
in
(
RouterType
.
KV
,
RouterType
.
KV_LOAD
,
RouterType
.
APPROX_KV
,
)
if
self
.
use_router
:
if
self
.
use_router
:
router_ns
,
router_name
=
Router
.
dynamo_address
()
# type: ignore
router_ns
,
router_name
=
Router
.
dynamo_address
()
# type: ignore
self
.
router_client
=
(
self
.
router_client
=
(
...
@@ -238,7 +242,11 @@ class Processor(ProcessMixIn):
...
@@ -238,7 +242,11 @@ class Processor(ProcessMixIn):
# TODO: queue request at processor when engines are full
# TODO: queue request at processor when engines are full
router_mode
=
(
await
self
.
etcd_kv_cache
.
get
(
"router"
)).
decode
()
router_mode
=
(
await
self
.
etcd_kv_cache
.
get
(
"router"
)).
decode
()
self
.
use_router
=
router_mode
in
(
RouterType
.
KV
,
RouterType
.
KV_LOAD
)
self
.
use_router
=
router_mode
in
(
RouterType
.
KV
,
RouterType
.
KV_LOAD
,
RouterType
.
APPROX_KV
,
)
prefix_hit_rate
=
0.0
# Default value
prefix_hit_rate
=
0.0
# Default value
if
self
.
use_router
:
if
self
.
use_router
:
...
@@ -248,6 +256,7 @@ class Processor(ProcessMixIn):
...
@@ -248,6 +256,7 @@ class Processor(ProcessMixIn):
hashes
=
compute_block_hash_for_seq_py
(
hashes
=
compute_block_hash_for_seq_py
(
token_ids
,
self
.
engine_args
.
block_size
token_ids
,
self
.
engine_args
.
block_size
),
),
tokens
=
token_ids
,
num_tokens
=
len
(
token_ids
),
num_tokens
=
len
(
token_ids
),
).
model_dump_json
()
).
model_dump_json
()
)
)
...
...
examples/llm/components/worker.py
View file @
aaf283bb
...
@@ -75,7 +75,7 @@ class VllmWorker:
...
@@ -75,7 +75,7 @@ class VllmWorker:
logger
.
info
(
"Pipeline parallel size is not supported yet, setting to 1"
)
logger
.
info
(
"Pipeline parallel size is not supported yet, setting to 1"
)
self
.
engine_args
.
pipeline_parallel_size
=
1
self
.
engine_args
.
pipeline_parallel_size
=
1
if
self
.
engine_args
.
router
==
RouterType
.
KV
:
if
self
.
engine_args
.
router
in
(
RouterType
.
KV
,
RouterType
.
APPROX_KV
)
:
if
not
self
.
engine_args
.
enable_prefix_caching
:
if
not
self
.
engine_args
.
enable_prefix_caching
:
logger
.
info
(
logger
.
info
(
"When using KV router, prefix caching must be enabled, setting to True"
"When using KV router, prefix caching must be enabled, setting to True"
...
...
examples/llm/utils/protocol.py
View file @
aaf283bb
...
@@ -38,6 +38,7 @@ class Tokens(BaseModel):
...
@@ -38,6 +38,7 @@ class Tokens(BaseModel):
class
LocalBlockHashes
(
BaseModel
):
class
LocalBlockHashes
(
BaseModel
):
hashes
:
list
[
int
]
hashes
:
list
[
int
]
tokens
:
list
[
int
]
num_tokens
:
int
num_tokens
:
int
...
...
examples/llm/utils/vllm.py
View file @
aaf283bb
...
@@ -25,6 +25,7 @@ class RouterType:
...
@@ -25,6 +25,7 @@ class RouterType:
ROUND_ROBIN
=
"round-robin"
ROUND_ROBIN
=
"round-robin"
KV
=
"kv"
KV
=
"kv"
KV_LOAD
=
"kv-load"
KV_LOAD
=
"kv-load"
APPROX_KV
=
"approx-kv"
def
parse_vllm_args
(
service_name
,
prefix
)
->
AsyncEngineArgs
:
def
parse_vllm_args
(
service_name
,
prefix
)
->
AsyncEngineArgs
:
...
@@ -39,6 +40,7 @@ def parse_vllm_args(service_name, prefix) -> AsyncEngineArgs:
...
@@ -39,6 +40,7 @@ def parse_vllm_args(service_name, prefix) -> AsyncEngineArgs:
RouterType
.
ROUND_ROBIN
,
RouterType
.
ROUND_ROBIN
,
RouterType
.
KV
,
RouterType
.
KV
,
RouterType
.
KV_LOAD
,
RouterType
.
KV_LOAD
,
RouterType
.
APPROX_KV
,
],
],
default
=
RouterType
.
RANDOM
,
default
=
RouterType
.
RANDOM
,
help
=
"Router type to use for scheduling requests to workers"
,
help
=
"Router type to use for scheduling requests to workers"
,
...
...
lib/bindings/python/rust/lib.rs
View file @
aaf283bb
...
@@ -59,6 +59,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
...
@@ -59,6 +59,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m
.add_class
::
<
llm
::
backend
::
Backend
>
()
?
;
m
.add_class
::
<
llm
::
backend
::
Backend
>
()
?
;
m
.add_class
::
<
llm
::
kv
::
OverlapScores
>
()
?
;
m
.add_class
::
<
llm
::
kv
::
OverlapScores
>
()
?
;
m
.add_class
::
<
llm
::
kv
::
KvIndexer
>
()
?
;
m
.add_class
::
<
llm
::
kv
::
KvIndexer
>
()
?
;
m
.add_class
::
<
llm
::
kv
::
ApproxKvIndexer
>
()
?
;
m
.add_class
::
<
llm
::
kv
::
EndpointKvMetrics
>
()
?
;
m
.add_class
::
<
llm
::
kv
::
EndpointKvMetrics
>
()
?
;
m
.add_class
::
<
llm
::
kv
::
AggregatedMetrics
>
()
?
;
m
.add_class
::
<
llm
::
kv
::
AggregatedMetrics
>
()
?
;
m
.add_class
::
<
llm
::
kv
::
KvMetricsAggregator
>
()
?
;
m
.add_class
::
<
llm
::
kv
::
KvMetricsAggregator
>
()
?
;
...
...
lib/bindings/python/rust/llm/kv.rs
View file @
aaf283bb
...
@@ -521,6 +521,64 @@ impl KvIndexer {
...
@@ -521,6 +521,64 @@ impl KvIndexer {
}
}
}
}
/// Bindings for the approximate KV indexer. We need to exactly match the regular KV Indexer
/// interface, so that the router can switch between the two.
#[pyclass]
pub
(
crate
)
struct
ApproxKvIndexer
{
inner
:
Arc
<
llm_rs
::
kv_router
::
approx
::
ApproxKvIndexer
>
,
}
#[pymethods]
impl
ApproxKvIndexer
{
#[new]
fn
new
(
component
:
Component
,
kv_block_size
:
usize
,
ttl_secs
:
f64
)
->
PyResult
<
Self
>
{
let
ttl
=
tokio
::
time
::
Duration
::
from_secs_f64
(
ttl_secs
);
let
inner
=
Arc
::
new
(
llm_rs
::
kv_router
::
approx
::
ApproxKvIndexer
::
new
(
component
.inner
.drt
()
.runtime
()
.child_token
(),
kv_block_size
,
ttl
,
));
Ok
(
Self
{
inner
})
}
fn
block_size
(
&
self
)
->
usize
{
self
.inner
.block_size
()
}
fn
find_matches_for_request
<
'p
>
(
&
self
,
py
:
Python
<
'p
>
,
token_ids
:
Vec
<
u32
>
,
)
->
PyResult
<
Bound
<
'p
,
PyAny
>>
{
let
indexer
=
self
.inner
.clone
();
pyo3_async_runtimes
::
tokio
::
future_into_py
(
py
,
async
move
{
let
rs_overlap_scores
=
indexer
.find_matches_for_request
(
token_ids
.as_slice
())
.await
.map_err
(
to_pyerr
)
?
;
Ok
(
OverlapScores
{
inner
:
rs_overlap_scores
,
})
})
}
fn
process_routing_decision_for_request
<
'p
>
(
&
self
,
py
:
Python
<
'p
>
,
tokens
:
Vec
<
u32
>
,
worker_id
:
i64
,
)
->
PyResult
<
Bound
<
'p
,
PyAny
>>
{
let
indexer
=
self
.inner
.clone
();
pyo3_async_runtimes
::
tokio
::
future_into_py
(
py
,
async
move
{
indexer
.process_routing_decision_for_request
(
tokens
.as_slice
(),
worker_id
)
.await
.map_err
(
to_pyerr
)
?
;
Ok
(())
})
}
}
#[pyclass]
#[pyclass]
#[derive(Clone)]
#[derive(Clone)]
pub
(
crate
)
struct
EndpointKvMetrics
{
pub
(
crate
)
struct
EndpointKvMetrics
{
...
...
lib/bindings/python/src/dynamo/_core.pyi
View file @
aaf283bb
...
@@ -553,6 +553,35 @@ class KvIndexer:
...
@@ -553,6 +553,35 @@ class KvIndexer:
"""
"""
...
...
class ApproxKvIndexer:
"""
A KV Indexer that doesn't use KV cache events. It instead relies solely on the input tokens.
"""
def __init__(self, component: Component, kv_block_size: int, ttl_secs: float) -> None:
"""
Create a `ApproxKvIndexer` object
"""
...
def find_matches_for_request(self, token_ids: List[int], lora_id: int) -> OverlapScores:
"""
Return the overlapping scores of workers for the given token ids.
"""
...
def block_size(self) -> int:
"""
Return the block size of the ApproxKvIndexer.
"""
...
def process_routing_decision_for_request(self, tokens: List[int], lora_id: int, worker_id: int) -> None:
"""
Notify the indexer that a token sequence has been sent to a specific worker.
"""
...
class KvRecorder:
class KvRecorder:
"""
"""
A recorder for KV Router events.
A recorder for KV Router events.
...
...
lib/bindings/python/src/dynamo/llm/__init__.py
View file @
aaf283bb
...
@@ -22,6 +22,7 @@ try:
...
@@ -22,6 +22,7 @@ try:
except
ImportError
:
except
ImportError
:
pass
# BlockManager is not enabled by default
pass
# BlockManager is not enabled by default
from
dynamo._core
import
ApproxKvIndexer
as
ApproxKvIndexer
from
dynamo._core
import
DisaggregatedRouter
as
DisaggregatedRouter
from
dynamo._core
import
DisaggregatedRouter
as
DisaggregatedRouter
from
dynamo._core
import
HttpAsyncEngine
as
HttpAsyncEngine
from
dynamo._core
import
HttpAsyncEngine
as
HttpAsyncEngine
from
dynamo._core
import
HttpError
as
HttpError
from
dynamo._core
import
HttpError
as
HttpError
...
...
lib/bindings/python/tests/test_kv_bindings.py
View file @
aaf283bb
...
@@ -25,6 +25,7 @@ from typing import List
...
@@ -25,6 +25,7 @@ from typing import List
import
pytest
import
pytest
from
dynamo.llm
import
(
from
dynamo.llm
import
(
ApproxKvIndexer
,
KvEventPublisher
,
KvEventPublisher
,
KvIndexer
,
KvIndexer
,
KvMetricsAggregator
,
KvMetricsAggregator
,
...
@@ -150,6 +151,30 @@ async def test_event_handler(distributed_runtime):
...
@@ -150,6 +151,30 @@ async def test_event_handler(distributed_runtime):
assert
not
scores
.
scores
assert
not
scores
.
scores
async
def
test_approx_kv_indexer
(
distributed_runtime
):
kv_block_size
=
32
namespace
=
"kv_test"
component
=
"approx_kv"
kv_listener
=
distributed_runtime
.
namespace
(
namespace
).
component
(
component
)
await
kv_listener
.
create_service
()
indexer
=
ApproxKvIndexer
(
kv_listener
,
kv_block_size
,
30.0
)
tokens
=
[
0
]
*
(
kv_block_size
*
2
)
scores
=
await
indexer
.
find_matches_for_request
(
tokens
)
assert
not
scores
.
scores
worker_id
=
0
await
indexer
.
process_routing_decision_for_request
(
tokens
,
worker_id
)
scores
=
await
indexer
.
find_matches_for_request
(
tokens
)
assert
scores
.
scores
assert
worker_id
in
scores
.
scores
assert
scores
.
scores
[
worker_id
]
==
2
class
EventPublisher
:
class
EventPublisher
:
def
__init__
(
self
,
component
:
Component
,
worker_id
:
int
,
kv_block_size
:
int
):
def
__init__
(
self
,
component
:
Component
,
worker_id
:
int
,
kv_block_size
:
int
):
self
.
publisher
=
KvEventPublisher
(
component
,
worker_id
,
kv_block_size
)
self
.
publisher
=
KvEventPublisher
(
component
,
worker_id
,
kv_block_size
)
...
...
lib/llm/src/kv_router.rs
View file @
aaf283bb
...
@@ -15,6 +15,7 @@ use dynamo_runtime::{
...
@@ -15,6 +15,7 @@ use dynamo_runtime::{
};
};
use
futures
::
stream
::{
self
,
StreamExt
};
use
futures
::
stream
::{
self
,
StreamExt
};
pub
mod
approx
;
pub
mod
indexer
;
pub
mod
indexer
;
pub
mod
metrics_aggregator
;
pub
mod
metrics_aggregator
;
pub
mod
protocols
;
pub
mod
protocols
;
...
...
lib/llm/src/kv_router/approx.rs
0 → 100644
View file @
aaf283bb
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment