Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
e159e53f
Commit
e159e53f
authored
Mar 05, 2025
by
GuanLuo
Committed by
GitHub
Mar 05, 2025
Browse files
feat: expose KV routing components for easier router customization (#15)
parent
ea78a424
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
442 additions
and
5 deletions
+442
-5
examples/python_rs/llm/vllm/kv_router/router.py
examples/python_rs/llm/vllm/kv_router/router.py
+71
-4
lib/bindings/python/rust/lib.rs
lib/bindings/python/rust/lib.rs
+5
-0
lib/bindings/python/rust/llm/kv.rs
lib/bindings/python/rust/llm/kv.rs
+161
-0
lib/bindings/python/src/dynemo/_core.pyi
lib/bindings/python/src/dynemo/_core.pyi
+52
-0
lib/bindings/python/src/dynemo/llm/__init__.py
lib/bindings/python/src/dynemo/llm/__init__.py
+2
-0
lib/llm/src/kv_router.rs
lib/llm/src/kv_router.rs
+1
-0
lib/llm/src/kv_router/metrics_aggregator.rs
lib/llm/src/kv_router/metrics_aggregator.rs
+149
-0
lib/llm/src/kv_router/scoring.rs
lib/llm/src/kv_router/scoring.rs
+1
-1
No files found.
examples/python_rs/llm/vllm/kv_router/router.py
View file @
e159e53f
...
...
@@ -23,7 +23,7 @@ import uvloop
from
common.protocol
import
Tokens
from
vllm.logger
import
logger
as
vllm_logger
from
dynemo.llm
import
KvRouter
from
dynemo.llm
import
KvIndexer
,
KvMetricsAggregator
,
KvRouter
from
dynemo.runtime
import
DistributedRuntime
,
dynemo_endpoint
,
dynemo_worker
WorkerId
=
str
...
...
@@ -78,6 +78,60 @@ class Router:
)
class
CustomRouter
:
"""
Request handler for the generate endpoint
"""
def
__init__
(
self
,
indexer
:
KvIndexer
,
metrics_aggregator
:
KvMetricsAggregator
,
):
self
.
indexer
=
indexer
self
.
metrics_aggregator
=
metrics_aggregator
def
_cost_function
(
self
,
scores
,
metrics
):
# naive cost function for demonstration purposes
current_best
=
(
""
,
0
)
for
worker_id
,
score
in
scores
.
scores
.
items
():
if
score
>
current_best
[
1
]:
current_best
=
(
worker_id
,
score
)
for
endpoint
in
metrics
.
endpoints
:
if
endpoint
.
worker_id
==
current_best
[
0
]:
print
(
f
"Metrics of endpoint:
{
endpoint
.
worker_id
}
"
)
print
(
f
"request slot usage:
{
endpoint
.
request_active_slots
}
/
{
endpoint
.
request_total_slots
}
"
)
print
(
f
"KV block usage:
{
endpoint
.
kv_active_blocks
}
/
{
endpoint
.
kv_total_blocks
}
"
)
return
current_best
[
0
]
@
dynemo_endpoint
(
Tokens
,
WorkerId
)
async
def
generate
(
self
,
request
)
->
AsyncIterator
[
WorkerId
]:
lora_id
=
0
worker_id
=
""
try
:
scores
=
await
self
.
indexer
.
find_matches_for_request
(
request
.
tokens
,
lora_id
)
metrics
=
await
self
.
metrics_aggregator
.
get_metrics
()
worker_id
=
self
.
_cost_function
(
scores
,
metrics
)
# [NOTE][TODO] Now that the scheduler may return more error messages,
# now we are catching all exceptions and logging them. Should have
# catch specific router exceptions once we have dedicated types.
except
Exception
as
e
:
vllm_logger
.
info
(
f
"
{
e
}
"
)
worker_id
=
""
vllm_logger
.
exception
(
f
"Error during worker selection:
{
e
}
"
)
vllm_logger
.
info
(
f
"Scheduling to worker_id:
{
worker_id
}
"
)
yield
str
(
worker_id
)
@
dynemo_worker
()
async
def
worker
(
runtime
:
DistributedRuntime
,
args
:
Namespace
):
"""
...
...
@@ -116,10 +170,17 @@ async def worker(runtime: DistributedRuntime, args: Namespace):
router_component
=
runtime
.
namespace
(
"dynemo"
).
component
(
"router"
)
await
router_component
.
create_service
()
router
=
KvRouter
(
runtime
,
kv_listener
)
endpoint
=
router_component
.
endpoint
(
"generate"
)
await
endpoint
.
serve_endpoint
(
Router
(
router
,
args
.
routing_strategy
).
generate
)
if
args
.
custom_router
:
indexer
=
KvIndexer
(
kv_listener
)
metrics_aggregator
=
KvMetricsAggregator
(
kv_listener
)
await
endpoint
.
serve_endpoint
(
CustomRouter
(
indexer
,
metrics_aggregator
).
generate
)
else
:
router
=
KvRouter
(
runtime
,
kv_listener
)
await
endpoint
.
serve_endpoint
(
Router
(
router
,
args
.
routing_strategy
).
generate
)
if
__name__
==
"__main__"
:
...
...
@@ -147,6 +208,12 @@ if __name__ == "__main__":
default
=
"deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
,
help
=
"Model that is being served"
,
)
parser
.
add_argument
(
"--custom-router"
,
type
=
bool
,
default
=
False
,
help
=
"Whether to use custom router or not"
,
)
args
=
parser
.
parse_args
()
asyncio
.
run
(
worker
(
args
))
lib/bindings/python/rust/lib.rs
View file @
e159e53f
...
...
@@ -70,6 +70,11 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m
.add_class
::
<
llm
::
model_card
::
ModelDeploymentCard
>
()
?
;
m
.add_class
::
<
llm
::
preprocessor
::
OAIChatPreprocessor
>
()
?
;
m
.add_class
::
<
llm
::
backend
::
Backend
>
()
?
;
m
.add_class
::
<
llm
::
kv
::
OverlapScores
>
()
?
;
m
.add_class
::
<
llm
::
kv
::
KvIndexer
>
()
?
;
m
.add_class
::
<
llm
::
kv
::
EndpointKvMetrics
>
()
?
;
m
.add_class
::
<
llm
::
kv
::
AggregatedMetrics
>
()
?
;
m
.add_class
::
<
llm
::
kv
::
KvMetricsAggregator
>
()
?
;
engine
::
add_to_module
(
m
)
?
;
...
...
lib/bindings/python/rust/llm/kv.rs
View file @
e159e53f
...
...
@@ -13,7 +13,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use
std
::
collections
::
HashMap
;
use
super
::
*
;
use
llm_rs
::
kv_router
::
indexer
::
KvIndexerInterface
;
use
tracing
;
#[pyclass]
pub
(
crate
)
struct
KvRouter
{
...
...
@@ -106,3 +110,160 @@ impl KvMetricsPublisher {
.map_err
(
to_pyerr
)
}
}
#[pyclass]
#[derive(Clone)]
pub
(
crate
)
struct
OverlapScores
{
inner
:
llm_rs
::
kv_router
::
indexer
::
OverlapScores
,
}
#[pymethods]
impl
OverlapScores
{
#[getter]
fn
scores
(
&
self
)
->
HashMap
<
llm_rs
::
kv_router
::
indexer
::
WorkerId
,
u32
>
{
self
.inner.scores
.clone
()
}
#[getter]
fn
frequencies
(
&
self
)
->
Vec
<
usize
>
{
self
.inner.frequencies
.clone
()
}
}
#[pyclass]
pub
(
crate
)
struct
KvIndexer
{
inner
:
Arc
<
llm_rs
::
kv_router
::
indexer
::
KvIndexer
>
,
}
#[pymethods]
impl
KvIndexer
{
#[new]
fn
new
(
component
:
Component
)
->
PyResult
<
Self
>
{
let
runtime
=
pyo3_async_runtimes
::
tokio
::
get_runtime
();
runtime
.block_on
(
async
{
let
kv_subject
=
component
.inner
.event_subject
(
llm_rs
::
kv_router
::
KV_EVENT_SUBJECT
);
let
inner
:
Arc
<
llm_rs
::
kv_router
::
indexer
::
KvIndexer
>
=
llm_rs
::
kv_router
::
indexer
::
KvIndexer
::
new
(
component
.inner
.drt
()
.runtime
()
.child_token
(),
)
.into
();
let
mut
kv_events_rx
=
component
.inner
.drt
()
.nats_client
()
.client
()
.subscribe
(
kv_subject
)
.await
.map_err
(
to_pyerr
)
?
;
let
kv_events_tx
=
inner
.event_sender
();
// [FIXME] this is the added functionality to the indexer to subscribe to kv events,
// should have been made to a trait and implemented here? i.e. AsyncEngine style
tokio
::
spawn
(
async
move
{
while
let
Some
(
event
)
=
kv_events_rx
.next
()
.await
{
let
event
:
llm_rs
::
kv_router
::
indexer
::
RouterEvent
=
serde_json
::
from_slice
(
&
event
.payload
)
.unwrap
();
tracing
::
debug!
(
"received kv event: {:?}"
,
event
);
if
let
Err
(
e
)
=
kv_events_tx
.send
(
event
)
.await
{
tracing
::
trace!
(
"failed to send kv event to indexer; shutting down: {:?}"
,
e
);
}
}
});
Ok
(
Self
{
inner
})
})
}
fn
find_matches_for_request
<
'p
>
(
&
self
,
py
:
Python
<
'p
>
,
token_ids
:
Vec
<
u32
>
,
_
lora_id
:
u64
,
)
->
PyResult
<
Bound
<
'p
,
PyAny
>>
{
let
indexer
=
self
.inner
.clone
();
pyo3_async_runtimes
::
tokio
::
future_into_py
(
py
,
async
move
{
let
rs_overlap_scores
=
indexer
.find_matches_for_request
(
token_ids
.as_slice
())
.await
.map_err
(
to_pyerr
)
?
;
Ok
(
OverlapScores
{
inner
:
rs_overlap_scores
,
})
})
}
}
#[pyclass]
#[derive(Clone)]
pub
(
crate
)
struct
EndpointKvMetrics
{
#[pyo3(get,
set)]
pub
worker_id
:
i64
,
#[pyo3(get,
set)]
pub
request_active_slots
:
u64
,
#[pyo3(get,
set)]
pub
request_total_slots
:
u64
,
#[pyo3(get,
set)]
pub
kv_active_blocks
:
u64
,
#[pyo3(get,
set)]
pub
kv_total_blocks
:
u64
,
}
#[pyclass]
#[derive(Clone)]
pub
(
crate
)
struct
AggregatedMetrics
{
#[pyo3(get,
set)]
pub
endpoints
:
Vec
<
EndpointKvMetrics
>
,
#[pyo3(get,
set)]
pub
load_avg
:
f64
,
#[pyo3(get,
set)]
pub
load_std
:
f64
,
}
#[pyclass]
pub
(
crate
)
struct
KvMetricsAggregator
{
inner
:
Arc
<
llm_rs
::
kv_router
::
metrics_aggregator
::
KvMetricsAggregator
>
,
}
#[pymethods]
impl
KvMetricsAggregator
{
#[new]
fn
new
(
component
:
Component
)
->
PyResult
<
Self
>
{
let
runtime
=
pyo3_async_runtimes
::
tokio
::
get_runtime
();
runtime
.block_on
(
async
{
let
inner
=
llm_rs
::
kv_router
::
metrics_aggregator
::
KvMetricsAggregator
::
new
(
component
.inner
.clone
(),
component
.inner
.drt
()
.runtime
()
.child_token
(),
)
.await
;
Ok
(
Self
{
inner
:
inner
.into
(),
})
})
}
fn
get_metrics
<
'p
>
(
&
self
,
py
:
Python
<
'p
>
)
->
PyResult
<
Bound
<
'p
,
PyAny
>>
{
let
endpoints
=
self
.inner
.get_endpoints
();
let
endpoint_kv_metrics
=
endpoints
.endpoints
.iter
()
.map
(|
x
|
EndpointKvMetrics
{
worker_id
:
x
.worker_id
(),
request_active_slots
:
x
.data.request_active_slots
,
request_total_slots
:
x
.data.request_total_slots
,
kv_active_blocks
:
x
.data.kv_active_blocks
,
kv_total_blocks
:
x
.data.kv_total_blocks
,
})
.collect
();
pyo3_async_runtimes
::
tokio
::
future_into_py
(
py
,
async
move
{
Ok
(
AggregatedMetrics
{
endpoints
:
endpoint_kv_metrics
,
load_avg
:
endpoints
.load_avg
,
load_std
:
endpoints
.load_std
,
})
})
}
}
lib/bindings/python/src/dynemo/_core.pyi
View file @
e159e53f
...
...
@@ -233,3 +233,55 @@ class Backend:
Start the backend engine and requests to the downstream LLM engine
"""
...
class OverlapScores:
"""
A collection of prefix matching scores of workers for a given token ids.
'scores' is a map of worker id to the score which is the number of matching blocks.
"""
...
class KvIndexer:
"""
A KV Indexer that tracks KV Events emitted by workers. Events include add_block and remove_block.
"""
...
def __init__(self, component: Component) -> None:
"""
Create a `KvIndexer` object
"""
def find_matches_for_request(self, token_ids: List[int], lora_id: int) -> OverlapScores:
"""
Return the overlapping scores of workers for the given token ids.
"""
...
class AggregatedMetrics:
"""
A collection of metrics of the endpoints
"""
...
class KvMetricsAggregator:
"""
A metrics aggregator will collect KV metrics of the endpoints.
"""
...
def __init__(self, component: Component) -> None:
"""
Create a `KvMetricsAggregator` object
"""
def get_metrics(self) -> AggregatedMetrics:
"""
Return the aggregated metrics of the endpoints.
"""
...
lib/bindings/python/src/dynemo/llm/__init__.py
View file @
e159e53f
...
...
@@ -13,5 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
dynemo._core
import
KvIndexer
as
KvIndexer
from
dynemo._core
import
KvMetricsAggregator
as
KvMetricsAggregator
from
dynemo._core
import
KvMetricsPublisher
as
KvMetricsPublisher
from
dynemo._core
import
KvRouter
as
KvRouter
lib/llm/src/kv_router.rs
View file @
e159e53f
...
...
@@ -21,6 +21,7 @@ use tokio_util::sync::CancellationToken;
use
tracing
;
pub
mod
indexer
;
pub
mod
metrics_aggregator
;
pub
mod
protocols
;
pub
mod
publisher
;
pub
mod
scheduler
;
...
...
lib/llm/src/kv_router/metrics_aggregator.rs
0 → 100644
View file @
e159e53f
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use
std
::
sync
::{
Arc
,
Mutex
};
pub
use
crate
::
kv_router
::
protocols
::
ForwardPassMetrics
;
use
crate
::
kv_router
::
scheduler
::{
Endpoint
,
Service
};
use
crate
::
kv_router
::
ProcessedEndpoints
;
use
dynemo_runtime
::
component
::
Component
;
use
std
::
time
::
Duration
;
use
tokio_util
::
sync
::
CancellationToken
;
pub
struct
KvMetricsAggregator
{
pub
service_name
:
String
,
pub
endpoints
:
Arc
<
Mutex
<
ProcessedEndpoints
>>
,
}
impl
KvMetricsAggregator
{
pub
async
fn
new
(
component
:
Component
,
cancellation_token
:
CancellationToken
)
->
Self
{
let
(
ep_tx
,
mut
ep_rx
)
=
tokio
::
sync
::
mpsc
::
channel
(
128
);
tokio
::
spawn
(
collect_endpoints
(
component
.drt
()
.nats_client
()
.clone
(),
component
.service_name
(),
ep_tx
,
cancellation_token
.clone
(),
));
tracing
::
trace!
(
"awaiting the start of the background endpoint subscriber"
);
let
endpoints
=
Arc
::
new
(
Mutex
::
new
(
ProcessedEndpoints
::
default
()));
let
endpoints_clone
=
endpoints
.clone
();
tokio
::
spawn
(
async
move
{
tracing
::
debug!
(
"scheduler background task started"
);
loop
{
match
ep_rx
.recv
()
.await
{
Some
(
endpoints
)
=>
match
endpoints_clone
.lock
()
{
Ok
(
mut
shared_endpoint
)
=>
{
*
shared_endpoint
=
endpoints
;
}
Err
(
e
)
=>
{
tracing
::
error!
(
"Failed to acquire lock on endpoints: {:?}"
,
e
);
}
},
None
=>
{
tracing
::
warn!
(
"endpoint subscriber shutdown"
);
break
;
}
};
}
tracing
::
trace!
(
"background endpoint subscriber shutting down"
);
});
Self
{
service_name
:
component
.service_name
(),
endpoints
,
}
}
pub
fn
get_endpoints
(
&
self
)
->
ProcessedEndpoints
{
match
self
.endpoints
.lock
()
{
Ok
(
endpoints
)
=>
endpoints
.clone
(),
Err
(
e
)
=>
{
tracing
::
error!
(
"Failed to acquire lock on endpoints: {:?}"
,
e
);
ProcessedEndpoints
::
default
()
}
}
}
}
async
fn
collect_endpoints
(
nats_client
:
dynemo_runtime
::
transports
::
nats
::
Client
,
service_name
:
String
,
ep_tx
:
tokio
::
sync
::
mpsc
::
Sender
<
ProcessedEndpoints
>
,
cancel
:
CancellationToken
,
)
{
loop
{
tokio
::
select!
{
_
=
cancel
.cancelled
()
=>
{
tracing
::
debug!
(
"cancellation token triggered"
);
break
;
}
_
=
tokio
::
time
::
sleep
(
Duration
::
from_secs
(
1
))
=>
{
tracing
::
trace!
(
"collecting endpoints for service: {}"
,
service_name
);
}
}
let
values
=
match
nats_client
.get_endpoints
(
&
service_name
,
Duration
::
from_secs
(
1
))
.await
{
Ok
(
v
)
=>
v
,
Err
(
e
)
=>
{
tracing
::
warn!
(
"Failed to retrieve endpoints for {}: {:?}"
,
service_name
,
e
);
continue
;
}
};
tracing
::
debug!
(
"values: {:?}"
,
values
);
let
services
:
Vec
<
Service
>
=
values
.into_iter
()
.filter
(|
v
|
!
v
.is_empty
())
.filter_map
(|
v
|
match
serde_json
::
from_slice
::
<
Service
>
(
&
v
)
{
Ok
(
service
)
=>
Some
(
service
),
Err
(
e
)
=>
{
tracing
::
warn!
(
"For value: {:?}
\n
Failed to parse service: {:?}"
,
v
,
e
);
None
}
})
.collect
();
tracing
::
debug!
(
"services: {:?}"
,
services
);
let
endpoints
:
Vec
<
Endpoint
>
=
services
.into_iter
()
.flat_map
(|
s
|
s
.endpoints
)
.filter
(|
s
|
s
.data
.is_some
())
.map
(|
s
|
Endpoint
{
name
:
s
.name
,
subject
:
s
.subject
,
data
:
s
.data
.unwrap
(),
})
.collect
();
tracing
::
debug!
(
"endpoints: {:?}"
,
endpoints
);
tracing
::
trace!
(
"found {} endpoints for service: {}"
,
endpoints
.len
(),
service_name
);
let
processed
=
ProcessedEndpoints
::
new
(
endpoints
);
if
ep_tx
.send
(
processed
)
.await
.is_err
()
{
tracing
::
trace!
(
"failed to send processed endpoints; shutting down"
);
break
;
}
}
}
lib/llm/src/kv_router/scoring.rs
View file @
e159e53f
...
...
@@ -20,7 +20,7 @@ use std::collections::HashSet;
use
crate
::
kv_router
::
scheduler
::
Endpoint
;
#[derive(Debug,
Default,
Serialize,
Deserialize)]
#[derive(Debug,
Default,
Serialize,
Deserialize
,
Clone
)]
pub
struct
ProcessedEndpoints
{
pub
endpoints
:
Vec
<
Endpoint
>
,
pub
worker_ids
:
Vec
<
i64
>
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment