Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
30610e73
Unverified
Commit
30610e73
authored
Oct 03, 2025
by
Yan Ru Pei
Committed by
GitHub
Oct 03, 2025
Browse files
feat: use KvPushRouter for prefill router (#3401)
Signed-off-by:
PeaBrane
<
yanrpei@gmail.com
>
parent
c48f49a4
Changes
13
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
164 additions
and
440 deletions
+164
-440
components/src/dynamo/router/__main__.py
components/src/dynamo/router/__main__.py
+63
-107
components/src/dynamo/vllm/handlers.py
components/src/dynamo/vllm/handlers.py
+72
-76
components/src/dynamo/vllm/main.py
components/src/dynamo/vllm/main.py
+1
-9
components/src/dynamo/vllm/protocol.py
components/src/dynamo/vllm/protocol.py
+0
-34
lib/bindings/python/rust/lib.rs
lib/bindings/python/rust/lib.rs
+0
-1
lib/bindings/python/rust/llm/kv.rs
lib/bindings/python/rust/llm/kv.rs
+12
-115
lib/bindings/python/src/dynamo/_core.pyi
lib/bindings/python/src/dynamo/_core.pyi
+0
-97
lib/bindings/python/src/dynamo/llm/__init__.py
lib/bindings/python/src/dynamo/llm/__init__.py
+0
-1
lib/engines/llamacpp/src/lib.rs
lib/engines/llamacpp/src/lib.rs
+1
-0
lib/llm/src/migration.rs
lib/llm/src/migration.rs
+1
-0
lib/llm/src/mocker/engine.rs
lib/llm/src/mocker/engine.rs
+1
-0
lib/llm/src/protocols/common/llm_backend.rs
lib/llm/src/protocols/common/llm_backend.rs
+8
-0
lib/llm/src/protocols/common/preprocessor.rs
lib/llm/src/protocols/common/preprocessor.rs
+5
-0
No files found.
components/src/dynamo/router/__main__.py
View file @
30610e73
...
...
@@ -19,7 +19,7 @@ from typing import Optional
import
uvloop
from
dynamo.llm
import
KvRouter
,
KvRouterConfig
from
dynamo.llm
import
Kv
Push
Router
,
KvRouterConfig
from
dynamo.runtime
import
Client
,
DistributedRuntime
,
dynamo_worker
from
dynamo.runtime.logging
import
configure_dynamo_logging
...
...
@@ -41,7 +41,7 @@ class StandaloneRouterHandler:
self
.
worker_endpoint_path
=
worker_endpoint_path
self
.
block_size
=
block_size
self
.
kv_router_config
=
kv_router_config
self
.
kv_router
:
Optional
[
KvRouter
]
=
None
self
.
kv_
push_
router
:
Optional
[
Kv
Push
Router
]
=
None
self
.
worker_client
:
Optional
[
Client
]
=
None
async
def
initialize
(
self
):
...
...
@@ -65,121 +65,76 @@ class StandaloneRouterHandler:
self
.
worker_client
=
await
worker_endpoint
.
client
()
# Create KvRouter with specified configuration
self
.
kv_router
=
KvRouter
(
# Create Kv
Push
Router with specified configuration
self
.
kv_
push_
router
=
Kv
Push
Router
(
endpoint
=
worker_endpoint
,
block_size
=
self
.
block_size
,
kv_router_config
=
self
.
kv_router_config
,
)
except
Exception
as
e
:
logger
.
error
(
f
"Failed to initialize KvRouter:
{
e
}
"
)
logger
.
error
(
f
"Failed to initialize Kv
Push
Router:
{
e
}
"
)
raise
async
def
find_best_worker
(
self
,
request
):
async
def
generate
(
self
,
request
):
"""
Find the best worker based on KV cache sta
te.
Generate tokens using the KV-aware rou
te
r
.
This endpoint is called by clients to determine which worker
should handle a request.
This endpoint routes the request to the best worker and streams back results.
Wraps the request into PreprocessedRequest format and wraps worker responses
into LLMEngineOutput format.
"""
if
self
.
kv_router
is
None
:
# Fallback to round-robin if router not initialized
logger
.
warning
(
"KvRouter not initialized, falling back to round-robin"
)
yield
{
"status"
:
"fallback"
,
"message"
:
"Router not initialized"
,
if
self
.
kv_push_router
is
None
:
logger
.
error
(
"KvPushRouter not initialized - cannot process request"
)
raise
RuntimeError
(
"Router not initialized"
)
# Wrap incoming request into PreprocessedRequest format for KvPushRouter
# The request should already have most fields, but we ensure it has the structure
preprocessed_request
=
{
"model"
:
request
.
get
(
"model"
,
"unknown"
),
"token_ids"
:
request
[
"token_ids"
],
"stop_conditions"
:
request
.
get
(
"stop_conditions"
,
{}),
"sampling_options"
:
request
.
get
(
"sampling_options"
,
{}),
"output_options"
:
request
.
get
(
"output_options"
,
{}),
"eos_token_ids"
:
request
.
get
(
"eos_token_ids"
,
[]),
"annotations"
:
request
.
get
(
"annotations"
,
[]),
"extra_args"
:
request
.
get
(
"extra_args"
,
{}),
}
return
try
:
# Get current workers
if
self
.
worker_client
is
None
:
yield
{
"status"
:
"error"
,
"message"
:
"Worker client not initialized"
,
}
return
instance_ids
=
self
.
worker_client
.
instance_ids
()
if
not
instance_ids
:
yield
{
"status"
:
"error"
,
"message"
:
"No workers available"
,
}
return
logger
.
debug
(
f
"Routing request with
{
len
(
instance_ids
)
}
available workers"
)
# Validate required fields
if
"token_ids"
not
in
request
:
raise
ValueError
(
"Missing required field 'token_ids' in request"
)
if
"request_id"
not
in
request
:
raise
ValueError
(
"Missing required field 'request_id' in request"
)
token_ids
=
request
[
"token_ids"
]
request_id
=
request
[
"request_id"
]
# Use KvRouter to find the best worker with state updates
best_worker_id
,
overlap_blocks
=
await
self
.
kv_router
.
find_best_match
(
request_id
=
request_id
,
tokens
=
token_ids
,
update_states
=
True
,
# Always update states for routing
)
logger
.
debug
(
f
"Selected worker
{
best_worker_id
}
with
{
overlap_blocks
}
overlap blocks for request
{
request_id
}
"
)
yield
{
"worker_id"
:
best_worker_id
,
"overlap_blocks"
:
overlap_blocks
,
}
except
Exception
as
e
:
logger
.
error
(
f
"Error finding best worker:
{
e
}
"
)
yield
{
"status"
:
"error"
,
"message"
:
str
(
e
),
# Route and process through KvPushRouter
async
for
worker_output
in
await
self
.
kv_push_router
.
generate_from_request
(
preprocessed_request
):
# Wrap worker output into LLMEngineOutput format
# Worker should return dict with at minimum kv_transfer_params in extra_args
llm_engine_output
=
{
"token_ids"
:
worker_output
.
get
(
"token_ids"
,
[]),
"tokens"
:
worker_output
.
get
(
"tokens"
),
"text"
:
worker_output
.
get
(
"text"
),
"cum_log_probs"
:
worker_output
.
get
(
"cum_log_probs"
),
"log_probs"
:
worker_output
.
get
(
"log_probs"
),
"top_logprobs"
:
worker_output
.
get
(
"top_logprobs"
),
"finish_reason"
:
worker_output
.
get
(
"finish_reason"
),
"index"
:
worker_output
.
get
(
"index"
),
"extra_args"
:
worker_output
.
get
(
"extra_args"
),
}
yield
llm_engine_output
async
def
free
(
self
,
request
):
async
def
best_worker_id
(
self
,
token_ids
,
router_config_override
=
None
):
"""
Fre
e
r
es
ources associated with a request
.
Get th
e
b
es
t worker ID for a given set of tokens without actually routing
.
This endpoint is called when a request is completed to clean up
router state.
This method returns the worker ID that would be selected based on KV cache
overlap, but does NOT actually route the request or update router states.
It's useful for debugging, monitoring, or implementing custom routing logic.
"""
if
self
.
kv_router
is
None
:
logger
.
warning
(
"KvRouter not initialized"
)
yield
{
"status"
:
"error"
,
"message"
:
"Router not initialized"
,
}
return
try
:
if
"request_id"
not
in
request
:
raise
ValueError
(
"Missing required field 'request_id' in request"
)
request_id
=
request
[
"request_id"
]
if
self
.
kv_push_router
is
None
:
logger
.
error
(
"KvPushRouter not initialized - cannot get best worker"
)
raise
RuntimeError
(
"Router not initialized"
)
# Free the request from the router
await
self
.
kv_router
.
free
(
request_id
=
request_id
)
logger
.
debug
(
f
"Freed resources for request
{
request_id
}
"
)
yield
{
"status"
:
"success"
,
"message"
:
f
"Request
{
request_id
}
freed successfully"
,
}
except
Exception
as
e
:
logger
.
error
(
f
"Error freeing request:
{
e
}
"
)
yield
{
"status"
:
"error"
,
"message"
:
str
(
e
),
}
return
await
self
.
kv_push_router
.
best_worker_id
(
token_ids
,
router_config_override
)
def
parse_args
():
...
...
@@ -308,20 +263,21 @@ async def worker(runtime: DistributedRuntime):
await
handler
.
initialize
()
# Expose endpoints
find_best_worker
_endpoint
=
component
.
endpoint
(
"
find_best_worker
"
)
free
_endpoint
=
component
.
endpoint
(
"
free
"
)
generate
_endpoint
=
component
.
endpoint
(
"
generate
"
)
best_worker
_endpoint
=
component
.
endpoint
(
"
best_worker_id
"
)
logger
.
debug
(
"Starting to serve
find_best_worker and free
endpoints..."
)
logger
.
debug
(
"Starting to serve endpoints..."
)
# Serve both endpoints concurrently
try
:
await
asyncio
.
gather
(
find_best_worker
_endpoint
.
serve_endpoint
(
handler
.
find_best_worker
,
generate
_endpoint
.
serve_endpoint
(
handler
.
generate
,
graceful_shutdown
=
True
,
metrics_labels
=
[(
"service"
,
"router"
)],
),
free
_endpoint
.
serve_endpoint
(
handler
.
free
,
best_worker
_endpoint
.
serve_endpoint
(
handler
.
best_worker_id
,
graceful_shutdown
=
True
,
metrics_labels
=
[(
"service"
,
"router"
)],
),
...
...
components/src/dynamo/vllm/handlers.py
View file @
30610e73
...
...
@@ -8,7 +8,7 @@ import uuid
from
abc
import
ABC
,
abstractmethod
from
contextlib
import
asynccontextmanager
from
copy
import
deepcopy
from
typing
import
AsyncGenerator
from
typing
import
Any
,
AsyncGenerator
,
Dict
import
msgspec
from
vllm.inputs
import
TokensPrompt
...
...
@@ -18,7 +18,6 @@ from vllm.v1.engine.exceptions import EngineDeadError
from
dynamo.runtime.logging
import
configure_dynamo_logging
from
.engine_monitor
import
VllmEngineMonitor
from
.protocol
import
MyRequestOutput
configure_dynamo_logging
()
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -126,27 +125,34 @@ class DecodeWorkerHandler(BaseWorkerHandler):
default_sampling_params
,
prefill_worker_client
=
None
,
prefill_router_client
=
None
,
prefill_router_free_client
=
None
,
):
super
().
__init__
(
runtime
,
component
,
engine
,
default_sampling_params
)
self
.
prefill_worker_client
=
prefill_worker_client
self
.
prefill_router_client
=
prefill_router_client
self
.
prefill_router_free_client
=
prefill_router_free_client
self
.
can_prefill
=
0
self
.
_prefill_check_task
=
None
if
self
.
prefill_worker_client
is
not
None
:
if
self
.
prefill_worker_client
or
self
.
prefill_router_client
:
self
.
_prefill_check_task
=
asyncio
.
create_task
(
self
.
_prefill_check_loop
())
async
def
_prefill_check_loop
(
self
):
"""Background task that checks prefill worker availability every 5 seconds."""
"""Background task that checks prefill
router/
worker availability every 5 seconds."""
while
True
:
try
:
if
self
.
prefill_worker_client
is
not
None
:
self
.
can_prefill
=
len
(
self
.
prefill_worker_client
.
instance_ids
())
logger
.
debug
(
f
"Current Prefill Workers:
{
self
.
can_prefill
}
"
)
else
:
self
.
can_prefill
=
0
router_count
=
(
len
(
self
.
prefill_router_client
.
instance_ids
())
if
self
.
prefill_router_client
is
not
None
else
0
)
worker_count
=
(
len
(
self
.
prefill_worker_client
.
instance_ids
())
if
self
.
prefill_worker_client
is
not
None
else
0
)
self
.
can_prefill
=
max
(
router_count
,
worker_count
)
logger
.
debug
(
f
"Prefill availability - Routers:
{
router_count
}
, Workers:
{
worker_count
}
"
)
except
asyncio
.
CancelledError
:
logger
.
warning
(
"Prefill check loop cancelled."
)
raise
...
...
@@ -178,15 +184,10 @@ class DecodeWorkerHandler(BaseWorkerHandler):
if
value
is
not
None
and
hasattr
(
sampling_params
,
key
):
setattr
(
sampling_params
,
key
,
value
)
# TODO: Change to prefill queue
# TODO: (PeaBrane) eventually, do not use a router_client and a free_client directly.
# This is least intrusive for now, but quite error prone. Should consider (major) refactoring
# TODO: (PeaBrane) longer term, decode workers should not handle prefill routing at all.
# Prefill routing logic should be integrated directly into the frontend service potentially.
# Use prefill router or worker if available
if
self
.
can_prefill
:
# Create
a copy for prefill with specific
modifications
# Create
prefill sampling params with
modifications
prefill_sampling_params
=
deepcopy
(
sampling_params
)
if
prefill_sampling_params
.
extra_args
is
None
:
prefill_sampling_params
.
extra_args
=
{}
prefill_sampling_params
.
extra_args
[
"kv_transfer_params"
]
=
{
...
...
@@ -195,68 +196,55 @@ class DecodeWorkerHandler(BaseWorkerHandler):
prefill_sampling_params
.
max_tokens
=
1
prefill_sampling_params
.
min_tokens
=
1
prefill_request
=
{
"token_ids"
:
request
[
"token_ids"
],
try
:
# Send request with sampling_params and request_id in extra_args
prefill_request
=
request
.
copy
()
# TODO (PeaBrane): this smells a bit bad as not we have two nestings
# of extra_args (an inner one again in sampling_params)
prefill_request
[
"extra_args"
]
=
{
"sampling_params"
:
msgspec
.
to_builtins
(
prefill_sampling_params
),
"request_id"
:
request_id
,
}
used_prefill_router
=
False
try
:
prefill_worker_id
=
None
# Try router first if available, fallback to worker
if
(
self
.
prefill_router_client
is
not
None
and
self
.
prefill_router_client
.
instance_ids
()
):
used_prefill_router
=
True
best_worker_response
=
await
anext
(
await
self
.
prefill_router_client
.
generate
(
{
"token_ids"
:
request
[
"token_ids"
],
"request_id"
:
request_id
,
}
)
)
prefill_worker_id
=
best_worker_response
.
data
().
get
(
"worker_id"
)
if
prefill_worker_id
is
not
None
:
# Call router's generate endpoint which returns LLMEngineOutput
prefill_response
=
await
anext
(
await
self
.
prefill_
work
er_client
.
direct
(
prefill_request
,
prefill_worker_id
,
context
=
context
await
self
.
prefill_
rout
er_client
.
generate
(
prefill_request
,
context
=
context
)
)
else
:
elif
self
.
prefill_worker_client
is
not
None
:
# Fallback to direct worker with same format
prefill_response
=
await
anext
(
await
self
.
prefill_worker_client
.
round_robin
(
prefill_request
,
context
=
context
)
)
else
:
raise
ValueError
(
"No prefill router or worker available"
)
except
Exception
as
e
:
if
context
.
is_stopped
()
or
context
.
is_killed
():
logger
.
debug
(
f
"Aborted Remote Prefill Request ID:
{
request_id
}
"
)
return
raise
e
finally
:
if
used_prefill_router
:
await
anext
(
await
self
.
prefill_router_free_client
.
generate
(
{
"request_id"
:
request_id
}
)
)
logger
.
debug
(
f
"Freed router state for request
{
request_id
}
"
)
prefill_output
=
prefill_response
.
data
()
prefill_response
=
MyRequestOutput
.
model_validate_json
(
prefill_response
.
data
()
# Extract kv_transfer_params from response
kv_transfer_params
=
prefill_output
.
get
(
"extra_args"
,
{}).
get
(
"kv_transfer_params"
)
# Modify original sampling_params for decode
if
kv_transfer_params
:
if
sampling_params
.
extra_args
is
None
:
sampling_params
.
extra_args
=
{}
sampling_params
.
extra_args
[
"kv_transfer_params"
]
=
prefill_response
.
kv_transfer_params
]
=
kv_transfer_params
except
Exception
as
e
:
if
context
.
is_stopped
()
or
context
.
is_killed
():
logger
.
debug
(
f
"Aborted Remote Prefill Request ID:
{
request_id
}
"
)
return
logger
.
warning
(
f
"Prefill error:
{
e
}
, falling back to local prefill"
)
async
with
self
.
_abort_monitor
(
context
,
request_id
):
try
:
...
...
@@ -276,11 +264,17 @@ class PrefillWorkerHandler(BaseWorkerHandler):
super
().
__init__
(
runtime
,
component
,
engine
,
default_sampling_params
)
async
def
generate
(
self
,
request
,
context
):
request_id
=
request
[
"request_id"
]
# Extract from PreprocessedRequest format - request_id and sampling_params from extra_args
extra_args
=
request
.
get
(
"extra_args"
,
{})
request_id
=
extra_args
.
get
(
"request_id"
,
str
(
uuid
.
uuid4
().
hex
))
logger
.
debug
(
f
"New Prefill Request ID:
{
request_id
}
"
)
prompt
=
TokensPrompt
(
prompt_token_ids
=
request
[
"token_ids"
])
sampling_params
=
msgspec
.
convert
(
request
[
"sampling_params"
],
SamplingParams
)
token_ids
=
request
[
"token_ids"
]
prompt
=
TokensPrompt
(
prompt_token_ids
=
token_ids
)
# Get sampling_params from extra_args
sampling_params_dict
=
extra_args
.
get
(
"sampling_params"
,
{})
sampling_params
=
msgspec
.
convert
(
sampling_params_dict
,
SamplingParams
)
async
with
self
.
_abort_monitor
(
context
,
request_id
,
is_prefill
=
True
):
try
:
...
...
@@ -291,20 +285,22 @@ class PrefillWorkerHandler(BaseWorkerHandler):
self
.
runtime
.
shutdown
()
os
.
_exit
(
1
)
# Generate only 1 token in prefill
try
:
async
for
res
in
gen
:
logger
.
debug
(
f
"kv transfer params:
{
res
.
kv_transfer_params
}
"
)
yield
MyRequestOutput
(
request_id
=
res
.
request_id
,
prompt
=
res
.
prompt
,
prompt_token_ids
=
res
.
prompt_token_ids
,
prompt_logprobs
=
res
.
prompt_logprobs
,
outputs
=
res
.
outputs
,
finished
=
res
.
finished
,
metrics
=
res
.
metrics
,
kv_transfer_params
=
res
.
kv_transfer_params
,
).
model_dump_json
()
token_ids
=
res
.
outputs
[
0
].
token_ids
if
res
.
outputs
else
[]
output
:
Dict
[
str
,
Any
]
=
{
"token_ids"
:
list
(
token_ids
),
"extra_args"
:
(
{
"kv_transfer_params"
:
res
.
kv_transfer_params
}
if
res
.
kv_transfer_params
else
{}
),
}
yield
output
except
asyncio
.
CancelledError
:
# raise the error because we cannot migrate prefill requests
raise
GeneratorExit
(
...
...
components/src/dynamo/vllm/main.py
View file @
30610e73
...
...
@@ -227,14 +227,7 @@ async def init(runtime: DistributedRuntime, config: Config):
prefill_router_client
=
(
await
runtime
.
namespace
(
config
.
namespace
)
.
component
(
"router"
)
# Standalone router for prefill workers
.
endpoint
(
"find_best_worker"
)
.
client
()
)
prefill_router_free_client
=
(
await
runtime
.
namespace
(
config
.
namespace
)
.
component
(
"router"
)
# Standalone router for prefill workers
.
endpoint
(
"free"
)
.
endpoint
(
"generate"
)
.
client
()
)
...
...
@@ -268,7 +261,6 @@ async def init(runtime: DistributedRuntime, config: Config):
default_sampling_params
,
prefill_worker_client
,
prefill_router_client
,
prefill_router_free_client
,
)
# Set up KV event publisher for prefix caching if enabled
...
...
components/src/dynamo/vllm/protocol.py
deleted
100644 → 0
View file @
c48f49a4
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Any
,
List
,
Optional
from
pydantic
import
BaseModel
,
ConfigDict
from
vllm.outputs
import
CompletionOutput
from
vllm.sequence
import
PromptLogprobs
,
RequestMetrics
class
MyRequestOutput
(
BaseModel
):
"""
RequestOutput from vLLM is not serializable by default
https://github.com/vllm-project/vllm/blob/a4c402a756fa3213caf9d2cde0e4ceb2d57727f2/vllm/outputs.py#L85
This class is used to serialize the RequestOutput and any recursively defined types
We can do this because PromptLogprobs, RequestMetrics, and CompletionOutput are all serializable dataclasses
"""
model_config
=
ConfigDict
(
arbitrary_types_allowed
=
True
)
request_id
:
str
prompt
:
Optional
[
str
]
=
None
prompt_token_ids
:
Optional
[
List
[
int
]]
=
None
prompt_logprobs
:
Optional
[
PromptLogprobs
]
=
None
outputs
:
List
[
CompletionOutput
]
finished
:
bool
metrics
:
Optional
[
RequestMetrics
]
=
None
# lora_request: Optional[LoRARequest] = None
# encoder_prompt: Optional[str] = None
# encoder_prompt_token_ids: Optional[List[int]] = None
# num_cached_tokens: Optional[int] = None
# multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
kv_transfer_params
:
Optional
[
dict
[
str
,
Any
]]
=
None
lib/bindings/python/rust/lib.rs
View file @
30610e73
...
...
@@ -174,7 +174,6 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m
.add_class
::
<
llm
::
kv
::
WorkerStats
>
()
?
;
m
.add_class
::
<
llm
::
kv
::
KvStats
>
()
?
;
m
.add_class
::
<
llm
::
kv
::
SpecDecodeStats
>
()
?
;
m
.add_class
::
<
llm
::
kv
::
KvRouter
>
()
?
;
m
.add_class
::
<
llm
::
kv
::
KvPushRouter
>
()
?
;
m
.add_class
::
<
llm
::
kv
::
KvPushRouterStream
>
()
?
;
m
.add_class
::
<
RouterMode
>
()
?
;
...
...
lib/bindings/python/rust/llm/kv.rs
View file @
30610e73
...
...
@@ -866,118 +866,6 @@ async fn create_kv_router_from_endpoint(
Ok
(
kv_router
)
}
#[pyclass]
pub
(
crate
)
struct
KvRouter
{
inner
:
Arc
<
llm_rs
::
kv_router
::
KvRouter
>
,
}
#[pymethods]
impl
KvRouter
{
#[new]
#[pyo3(signature
=
(endpoint,
block_size,
kv_router_config=None))]
fn
new
(
endpoint
:
&
Endpoint
,
block_size
:
usize
,
kv_router_config
:
Option
<&
super
::
entrypoint
::
KvRouterConfig
>
,
)
->
PyResult
<
Self
>
{
let
runtime
=
pyo3_async_runtimes
::
tokio
::
get_runtime
();
runtime
.block_on
(
async
move
{
let
kv_router
=
create_kv_router_from_endpoint
(
endpoint
,
block_size
,
kv_router_config
.map
(|
c
|
c
.inner
()),
)
.await
?
;
Ok
(
Self
{
inner
:
kv_router
})
})
}
#[pyo3(signature
=
(request_id,
tokens,
update_states=
false
,
router_config_override=None))]
fn
find_best_match
<
'p
>
(
&
self
,
py
:
Python
<
'p
>
,
request_id
:
String
,
tokens
:
Vec
<
u32
>
,
update_states
:
bool
,
router_config_override
:
Option
<
PyObject
>
,
)
->
PyResult
<
Bound
<
'p
,
PyAny
>>
{
let
router_config_override
=
if
let
Some
(
obj
)
=
router_config_override
{
Python
::
with_gil
(|
py
|
{
let
override_config
:
llm_rs
::
kv_router
::
RouterConfigOverride
=
depythonize
(
obj
.bind
(
py
))
.map_err
(
to_pyerr
)
?
;
Ok
::
<
_
,
PyErr
>
(
Some
(
override_config
))
})
?
}
else
{
None
};
let
inner
=
self
.inner
.clone
();
pyo3_async_runtimes
::
tokio
::
future_into_py
(
py
,
async
move
{
let
(
worker_id
,
overlap_blocks
)
=
inner
.find_best_match
(
Some
(
&
request_id
),
&
tokens
,
router_config_override
.as_ref
(),
update_states
,
)
.await
.map_err
(
to_pyerr
)
?
;
Ok
((
worker_id
,
overlap_blocks
))
})
}
fn
add_request
<
'p
>
(
&
self
,
py
:
Python
<
'p
>
,
request_id
:
String
,
tokens
:
Vec
<
u32
>
,
overlap_blocks
:
u32
,
worker_id
:
i64
,
)
->
PyResult
<
Bound
<
'p
,
PyAny
>>
{
let
inner
=
self
.inner
.clone
();
pyo3_async_runtimes
::
tokio
::
future_into_py
(
py
,
async
move
{
inner
.add_request
(
request_id
,
&
tokens
,
overlap_blocks
,
worker_id
)
.await
;
Ok
(())
})
}
fn
mark_prefill_completed
<
'p
>
(
&
self
,
py
:
Python
<
'p
>
,
request_id
:
String
,
)
->
PyResult
<
Bound
<
'p
,
PyAny
>>
{
let
inner
=
self
.inner
.clone
();
pyo3_async_runtimes
::
tokio
::
future_into_py
(
py
,
async
move
{
inner
.mark_prefill_completed
(
&
request_id
)
.await
.map_err
(
to_pyerr
)
?
;
Ok
(())
})
}
fn
free
<
'p
>
(
&
self
,
py
:
Python
<
'p
>
,
request_id
:
String
)
->
PyResult
<
Bound
<
'p
,
PyAny
>>
{
let
inner
=
self
.inner
.clone
();
pyo3_async_runtimes
::
tokio
::
future_into_py
(
py
,
async
move
{
inner
.free
(
&
request_id
)
.await
.map_err
(
to_pyerr
)
?
;
Ok
(())
})
}
#[getter]
fn
block_size
(
&
self
)
->
PyResult
<
u32
>
{
Ok
(
self
.inner
.block_size
())
}
}
#[pyclass]
pub
(
crate
)
struct
KvPushRouter
{
inner
:
Arc
<
llm_rs
::
kv_router
::
KvPushRouter
>
,
...
...
@@ -1072,7 +960,7 @@ impl KvPushRouter {
}
#[allow(clippy::too_many_arguments)]
#[pyo3(signature
=
(token_ids,
model,
stop_conditions=None,
sampling_options=None,
output_options=None,
router_config_override=None,
worker_id=None))]
#[pyo3(signature
=
(token_ids,
model,
stop_conditions=None,
sampling_options=None,
output_options=None,
router_config_override=None,
worker_id=None
,
extra_args=None
))]
fn
generate
<
'p
>
(
&
self
,
py
:
Python
<
'p
>
,
...
...
@@ -1083,9 +971,10 @@ impl KvPushRouter {
output_options
:
Option
<
PyObject
>
,
router_config_override
:
Option
<
PyObject
>
,
worker_id
:
Option
<
i64
>
,
extra_args
:
Option
<
PyObject
>
,
)
->
PyResult
<
Bound
<
'p
,
PyAny
>>
{
// Depythonize the options with defaults
let
(
stop_conditions
,
sampling_options
,
output_options
,
router_config_override
)
=
let
(
stop_conditions
,
sampling_options
,
output_options
,
router_config_override
,
extra_args
)
=
Python
::
with_gil
(|
py
|
{
let
stop_conditions
:
StopConditions
=
if
let
Some
(
obj
)
=
stop_conditions
{
depythonize
(
obj
.bind
(
py
))
.map_err
(
to_pyerr
)
?
...
...
@@ -1112,11 +1001,18 @@ impl KvPushRouter {
None
};
let
extra_args
:
Option
<
serde_json
::
Value
>
=
if
let
Some
(
obj
)
=
extra_args
{
Some
(
depythonize
(
obj
.bind
(
py
))
.map_err
(
to_pyerr
)
?
)
}
else
{
None
};
Ok
::
<
_
,
PyErr
>
((
stop_conditions
,
sampling_options
,
output_options
,
router_config_override
,
extra_args
,
))
})
?
;
...
...
@@ -1129,7 +1025,8 @@ impl KvPushRouter {
.stop_conditions
(
stop_conditions
)
.sampling_options
(
sampling_options
)
.output_options
(
output_options
)
.router_config_override
(
router_config_override
);
.router_config_override
(
router_config_override
)
.extra_args
(
extra_args
);
// Set backend_instance_id if worker_id is provided
if
let
Some
(
worker_id
)
=
worker_id
{
...
...
lib/bindings/python/src/dynamo/_core.pyi
View file @
30610e73
...
...
@@ -1129,103 +1129,6 @@ class ZmqKvEventListener:
"""
...
class KvRouter:
"""
A KV Router that decides which worker to use based on KV cache overlap.
This router tracks request states and manages KV cache distribution across workers.
"""
def __init__(
self,
endpoint: Endpoint,
block_size: int,
kv_router_config: Optional[KvRouterConfig] = None,
consumer_uuid: Optional[str] = None,
) -> None:
"""
Create a new KvRouter instance.
Args:
endpoint: The endpoint to associate with this router
block_size: The KV cache block size
kv_router_config: Optional configuration for the KV router
consumer_uuid: Optional unique identifier for this router instance.
If not provided, a UUID will be generated.
"""
...
async def find_best_match(
self,
request_id: str,
tokens: List[int],
*,
update_states: bool = False,
router_config_override: Optional[JsonLike] = None,
) -> Tuple[int, int]:
"""
Find the best matching worker for the given tokens.
Args:
request_id: Unique identifier for the request used for tracking
tokens: List of token IDs to find matches for
update_states: Whether to update router states for this request (default: False)
router_config_override: Optional router configuration override with fields:
- overlap_score_weight: Optional weight for overlap score
- router_temperature: Optional temperature for worker selection
Returns:
A tuple of (worker_id, overlap_blocks) where:
- worker_id: The ID of the best matching worker
- overlap_blocks: The number of overlapping blocks found
"""
...
async def add_request(
self,
request_id: str,
tokens: List[int],
overlap_blocks: int,
worker_id: int,
) -> None:
"""
Add a request to the router's tracking system.
Args:
request_id: Unique identifier for the request
tokens: List of token IDs for the request
overlap_blocks: Number of overlapping blocks found
worker_id: ID of the worker handling this request
"""
...
async def mark_prefill_completed(self, request_id: str) -> None:
"""
Mark that prefill has been completed for a request.
Args:
request_id: The request ID to mark as prefill completed
"""
...
async def free(self, request_id: str) -> None:
"""
Free resources associated with a request.
Args:
request_id: The request ID to free
"""
...
@property
def block_size(self) -> int:
"""
Get the KV cache block size.
Returns:
The block size in tokens
"""
...
class KvPushRouter:
"""
A KV-aware push router that performs intelligent routing based on KV cache overlap.
...
...
lib/bindings/python/src/dynamo/llm/__init__.py
View file @
30610e73
...
...
@@ -26,7 +26,6 @@ from dynamo._core import KvIndexer as KvIndexer
from
dynamo._core
import
KvMetricsAggregator
as
KvMetricsAggregator
from
dynamo._core
import
KvPushRouter
as
KvPushRouter
from
dynamo._core
import
KvRecorder
as
KvRecorder
from
dynamo._core
import
KvRouter
as
KvRouter
from
dynamo._core
import
KvRouterConfig
as
KvRouterConfig
from
dynamo._core
import
KvStats
as
KvStats
from
dynamo._core
import
ModelInput
as
ModelInput
...
...
lib/engines/llamacpp/src/lib.rs
View file @
30610e73
...
...
@@ -271,6 +271,7 @@ fn run_request(
top_logprobs
:
None
,
finish_reason
:
None
,
index
:
None
,
extra_args
:
None
,
};
work_request
.response_channel
...
...
lib/llm/src/migration.rs
View file @
30610e73
...
...
@@ -210,6 +210,7 @@ mod tests {
top_logprobs
:
None
,
finish_reason
:
None
,
index
:
None
,
extra_args
:
None
,
})
}
...
...
lib/llm/src/mocker/engine.rs
View file @
30610e73
...
...
@@ -392,6 +392,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
top_logprobs
:
None
,
finish_reason
:
None
,
index
:
None
,
extra_args
:
None
,
};
if
signal
.completed
&&
token_count
<
max_tokens
{
...
...
lib/llm/src/protocols/common/llm_backend.rs
View file @
30610e73
...
...
@@ -84,6 +84,10 @@ pub struct LLMEngineOutput {
// Index field for batch requests to match OpenAI format
pub
index
:
Option
<
u32
>
,
/// Additional arguments for extensibility
#[serde(default,
skip_serializing_if
=
"Option::is_none"
)]
pub
extra_args
:
Option
<
serde_json
::
Value
>
,
}
impl
LLMEngineOutput
{
...
...
@@ -97,6 +101,7 @@ impl LLMEngineOutput {
top_logprobs
:
None
,
finish_reason
:
Some
(
FinishReason
::
Cancelled
),
index
:
None
,
extra_args
:
None
,
}
}
...
...
@@ -110,6 +115,7 @@ impl LLMEngineOutput {
finish_reason
:
Some
(
FinishReason
::
Stop
),
top_logprobs
:
None
,
index
:
None
,
extra_args
:
None
,
}
}
...
...
@@ -123,6 +129,7 @@ impl LLMEngineOutput {
top_logprobs
:
None
,
finish_reason
:
Some
(
FinishReason
::
Length
),
index
:
None
,
extra_args
:
None
,
}
}
...
...
@@ -136,6 +143,7 @@ impl LLMEngineOutput {
top_logprobs
:
None
,
finish_reason
:
Some
(
FinishReason
::
Error
(
err_msg
)),
index
:
None
,
extra_args
:
None
,
}
}
}
...
...
lib/llm/src/protocols/common/preprocessor.rs
View file @
30610e73
...
...
@@ -59,6 +59,11 @@ pub struct PreprocessedRequest {
/// Router configuration overrides for this specific request
#[builder(default)]
pub
router_config_override
:
Option
<
RouterConfigOverride
>
,
/// Additional arguments for extensibility
#[builder(default)]
#[serde(default,
skip_serializing_if
=
"Option::is_none"
)]
pub
extra_args
:
Option
<
serde_json
::
Value
>
,
}
impl
PreprocessedRequest
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment