Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
02b1c58a
Unverified
Commit
02b1c58a
authored
Mar 25, 2026
by
Yan Ru Pei
Committed by
GitHub
Mar 25, 2026
Browse files
feat(mocker): add offline disagg replay (#7617)
Signed-off-by:
PeaBrane
<
yanrpei@gmail.com
>
parent
4b8826b3
Changes
68
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2355 additions
and
1319 deletions
+2355
-1319
lib/kv-router/src/scheduling/policy.rs
lib/kv-router/src/scheduling/policy.rs
+1
-0
lib/kv-router/src/scheduling/queue.rs
lib/kv-router/src/scheduling/queue.rs
+38
-5
lib/kv-router/src/scheduling/types.rs
lib/kv-router/src/scheduling/types.rs
+1
-0
lib/kv-router/src/sequences/multi_worker.rs
lib/kv-router/src/sequences/multi_worker.rs
+74
-5
lib/kv-router/src/sequences/single.rs
lib/kv-router/src/sequences/single.rs
+87
-3
lib/llm/src/kv_router.rs
lib/llm/src/kv_router.rs
+19
-197
lib/llm/src/kv_router/indexer.rs
lib/llm/src/kv_router/indexer.rs
+233
-0
lib/llm/src/kv_router/prefill_router/activation.rs
lib/llm/src/kv_router/prefill_router/activation.rs
+188
-0
lib/llm/src/kv_router/prefill_router/execution.rs
lib/llm/src/kv_router/prefill_router/execution.rs
+322
-0
lib/llm/src/kv_router/prefill_router/inner.rs
lib/llm/src/kv_router/prefill_router/inner.rs
+55
-0
lib/llm/src/kv_router/prefill_router/mod.rs
lib/llm/src/kv_router/prefill_router/mod.rs
+259
-0
lib/llm/src/kv_router/prefill_router/types.rs
lib/llm/src/kv_router/prefill_router/types.rs
+56
-0
lib/llm/src/kv_router/publisher/event_processor.rs
lib/llm/src/kv_router/publisher/event_processor.rs
+326
-0
lib/llm/src/kv_router/publisher/mod.rs
lib/llm/src/kv_router/publisher/mod.rs
+396
-0
lib/llm/src/kv_router/publisher/tests.rs
lib/llm/src/kv_router/publisher/tests.rs
+14
-1043
lib/llm/src/kv_router/publisher/worker_metrics.rs
lib/llm/src/kv_router/publisher/worker_metrics.rs
+116
-0
lib/llm/src/kv_router/publisher/zmq_listener.rs
lib/llm/src/kv_router/publisher/zmq_listener.rs
+161
-0
lib/llm/src/kv_router/remote_indexer.rs
lib/llm/src/kv_router/remote_indexer.rs
+0
-65
lib/llm/src/kv_router/scheduler.rs
lib/llm/src/kv_router/scheduler.rs
+3
-1
lib/llm/src/kv_router/sequence.rs
lib/llm/src/kv_router/sequence.rs
+6
-0
No files found.
lib/kv-router/src/scheduling/policy.rs
View file @
02b1c58a
...
...
@@ -135,6 +135,7 @@ mod tests {
overlaps
,
decode_blocks
:
HashMap
::
new
(),
prefill_tokens
:
HashMap
::
new
(),
track_prefill_tokens
:
true
,
router_config_override
:
None
,
update_states
:
false
,
lora_name
:
None
,
...
...
lib/kv-router/src/scheduling/queue.rs
View file @
02b1c58a
...
...
@@ -191,11 +191,14 @@ impl<
/// Run the full scheduling pipeline for a single request:
/// compute potential load -> select worker -> respond -> book via add_request.
async
fn
schedule
(
&
self
,
mut
request
:
SchedulingRequest
)
{
let
(
decode_blocks
,
prefill_tokens
)
=
self
.slots
.potential_blocks_and_tokens
(
request
.token_seq
.as_deref
(),
request
.isl_tokens
,
request
.overlaps
.clone
(),
);
let
(
decode_blocks
,
prefill_tokens
)
=
self
.slots
.potential_blocks_and_tokens_with_prefill_tracking
(
request
.token_seq
.as_deref
(),
request
.isl_tokens
,
request
.overlaps
.clone
(),
request
.track_prefill_tokens
,
);
request
.decode_blocks
=
decode_blocks
;
request
.prefill_tokens
=
prefill_tokens
;
...
...
@@ -235,6 +238,7 @@ impl<
token_sequence
:
request
.token_seq
,
isl
:
request
.isl_tokens
,
overlap
:
selection
.overlap_blocks
,
track_prefill_tokens
:
request
.track_prefill_tokens
,
expected_output_tokens
:
request
.expected_output_tokens
,
worker
:
selection
.worker
,
lora_name
:
request
.lora_name
.clone
(),
...
...
@@ -376,6 +380,7 @@ mod tests {
overlaps
:
OverlapScores
::
default
(),
decode_blocks
:
HashMap
::
new
(),
prefill_tokens
:
HashMap
::
new
(),
track_prefill_tokens
:
true
,
router_config_override
:
None
,
update_states
:
true
,
lora_name
:
None
,
...
...
@@ -695,6 +700,7 @@ mod tests {
overlaps
:
OverlapScores
::
default
(),
decode_blocks
:
HashMap
::
new
(),
prefill_tokens
:
HashMap
::
new
(),
track_prefill_tokens
:
true
,
router_config_override
:
None
,
update_states
:
true
,
lora_name
:
None
,
...
...
@@ -719,4 +725,31 @@ mod tests {
.unwrap
();
slots
.free
(
&
"filter-0"
.to_string
())
.await
.unwrap
();
}
#[tokio::test(flavor
=
"multi_thread"
)]
async
fn
test_queue_busy_check_ignores_untracked_prefill_tokens
()
{
let
(
queue
,
slots
)
=
make_queue
(
1
,
16
,
256
,
Some
(
0.0
));
let
(
mut
req1
,
rx1
)
=
make_request
(
"req-1"
,
256
);
req1
.track_prefill_tokens
=
false
;
queue
.enqueue
(
req1
)
.await
;
let
_
resp1
=
rx1
.await
.unwrap
()
.unwrap
();
assert_eq!
(
slots
.active_tokens
()
.get
(
&
WorkerWithDpRank
::
new
(
0
,
0
))
.copied
(),
Some
(
0
)
);
let
(
req2
,
rx2
)
=
make_request
(
"req-2"
,
256
);
queue
.enqueue
(
req2
)
.await
;
let
_
resp2
=
rx2
.await
.unwrap
()
.unwrap
();
assert_eq!
(
queue
.pending_count
(),
0
);
let
_
=
slots
.mark_prefill_completed
(
&
"req-1"
.to_string
())
.await
;
let
_
=
slots
.free
(
&
"req-1"
.to_string
())
.await
;
let
_
=
slots
.mark_prefill_completed
(
&
"req-2"
.to_string
())
.await
;
let
_
=
slots
.free
(
&
"req-2"
.to_string
())
.await
;
}
}
lib/kv-router/src/scheduling/types.rs
View file @
02b1c58a
...
...
@@ -42,6 +42,7 @@ pub struct SchedulingRequest {
pub
overlaps
:
OverlapScores
,
pub
decode_blocks
:
HashMap
<
WorkerWithDpRank
,
usize
>
,
pub
prefill_tokens
:
HashMap
<
WorkerWithDpRank
,
usize
>
,
pub
track_prefill_tokens
:
bool
,
pub
router_config_override
:
Option
<
RouterConfigOverride
>
,
pub
update_states
:
bool
,
pub
lora_name
:
Option
<
String
>
,
...
...
lib/kv-router/src/sequences/multi_worker.rs
View file @
02b1c58a
...
...
@@ -97,6 +97,7 @@ pub struct SequenceRequest {
pub
token_sequence
:
Option
<
Vec
<
SequenceHash
>>
,
pub
isl
:
usize
,
pub
overlap
:
u32
,
pub
track_prefill_tokens
:
bool
,
pub
expected_output_tokens
:
Option
<
u32
>
,
pub
worker
:
WorkerWithDpRank
,
pub
lora_name
:
Option
<
String
>
,
...
...
@@ -221,6 +222,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
token_sequence
,
isl
,
overlap
,
track_prefill_tokens
,
expected_output_tokens
,
}
=>
{
self
.request_to_worker
...
...
@@ -233,12 +235,13 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
let
table
=
self
.workers
.read
();
if
let
Some
(
&
idx
)
=
table
.index
.get
(
&
event
.worker
)
{
table
.slots
[
idx
]
.1
.write
()
.add_request
(
table
.slots
[
idx
]
.1
.write
()
.add_request
_with_prefill_tracking
(
event
.request_id
.clone
(),
token_sequence
.clone
(),
*
isl
,
*
overlap
,
*
expected_output_tokens
,
*
track_prefill_tokens
,
);
}
else
{
tracing
::
warn!
(
...
...
@@ -380,6 +383,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
token_sequence
,
isl
,
overlap
,
track_prefill_tokens
,
expected_output_tokens
,
worker
,
lora_name
,
...
...
@@ -409,12 +413,13 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
.get
(
&
worker
)
.ok_or
(
SequenceError
::
WorkerNotFound
{
worker
})
?
;
let
mut
seq
=
table
.slots
[
idx
]
.1
.write
();
seq
.add_request
(
seq
.add_request
_with_prefill_tracking
(
request_id
,
token_sequence
,
isl
,
overlap
,
expected_output_tokens
,
track_prefill_tokens
,
)
};
...
...
@@ -437,6 +442,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
token_sequence
:
req
.token_sequence
.clone
(),
isl
:
req
.isl
,
overlap
:
req
.overlap
,
track_prefill_tokens
:
req
.track_prefill_tokens
,
expected_output_tokens
:
req
.expected_output_tokens
,
},
router_id
:
self
.router_id
,
...
...
@@ -527,6 +533,10 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
///
/// Note: This operation is idempotent. Calling it multiple times for the same request
/// will log a warning but not return an error (double free is allowed).
///
/// This also performs the underlying prefill-complete cleanup via
/// [`ActiveSequences::free`], so callers do not need to call
/// [`Self::mark_prefill_completed`] before freeing a completed request.
pub
async
fn
free
(
&
self
,
request_id
:
&
RequestId
)
->
Result
<
(),
SequenceError
>
{
if
!
self
.request_to_worker
.contains_key
(
request_id
)
{
tracing
::
debug!
(
"Request {request_id} not found, already freed (idempotent)"
);
...
...
@@ -696,6 +706,19 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
)
->
(
HashMap
<
WorkerWithDpRank
,
usize
>
,
HashMap
<
WorkerWithDpRank
,
usize
>
,
)
{
self
.potential_blocks_and_tokens_with_prefill_tracking
(
token_sequence
,
isl
,
overlaps
,
true
)
}
pub
fn
potential_blocks_and_tokens_with_prefill_tracking
(
&
self
,
token_sequence
:
Option
<&
[
SequenceHash
]
>
,
isl
:
usize
,
overlaps
:
OverlapScores
,
track_prefill_tokens
:
bool
,
)
->
(
HashMap
<
WorkerWithDpRank
,
usize
>
,
HashMap
<
WorkerWithDpRank
,
usize
>
,
)
{
#[cfg(feature
=
"bench"
)]
let
start
=
tokio
::
time
::
Instant
::
now
();
...
...
@@ -711,9 +734,14 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
for
(
worker
,
lock
)
in
&
table
.slots
{
let
overlap
=
*
overlaps
.scores
.get
(
worker
)
.unwrap_or
(
&
0
);
let
(
blocks
,
tokens
)
=
lock
.read
()
.potential_blocks_and_tokens
(
token_sequence
,
isl
,
overlap
);
let
(
blocks
,
tokens
)
=
lock
.read
()
.potential_blocks_and_tokens_with_prefill_tracking
(
token_sequence
,
isl
,
overlap
,
track_prefill_tokens
,
);
potential_blocks
.insert
(
*
worker
,
blocks
);
potential_tokens
.insert
(
*
worker
,
tokens
);
}
...
...
@@ -832,3 +860,44 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
});
}
}
#[cfg(test)]
mod
tests
{
use
std
::
collections
::
HashMap
;
use
super
::
*
;
use
crate
::
test_utils
::
NoopSequencePublisher
;
fn
make_sequences
()
->
ActiveSequencesMultiWorker
<
NoopSequencePublisher
>
{
ActiveSequencesMultiWorker
::
new
(
NoopSequencePublisher
,
4
,
HashMap
::
from
([(
1_u64
,
(
0_u32
,
1_u32
))]),
false
,
0
,
"test"
,
)
}
#[tokio::test]
async
fn
add_request_can_skip_prefill_token_tracking
()
{
let
sequences
=
make_sequences
();
let
worker
=
WorkerWithDpRank
::
new
(
1
,
0
);
sequences
.add_request
(
SequenceRequest
{
request_id
:
"req-1"
.to_string
(),
token_sequence
:
Some
(
vec!
[
1
,
2
,
3
]),
isl
:
12
,
overlap
:
0
,
track_prefill_tokens
:
false
,
expected_output_tokens
:
None
,
worker
,
lora_name
:
None
,
})
.await
.unwrap
();
assert_eq!
(
sequences
.active_tokens
()
.get
(
&
worker
)
.copied
(),
Some
(
0
));
}
}
lib/kv-router/src/sequences/single.rs
View file @
02b1c58a
...
...
@@ -143,6 +143,27 @@ impl ActiveSequences {
isl
:
usize
,
overlap
:
u32
,
expected_output_tokens
:
Option
<
u32
>
,
)
->
HashSet
<
RequestId
>
{
self
.add_request_with_prefill_tracking
(
request_id
,
token_sequence
,
isl
,
overlap
,
expected_output_tokens
,
true
,
)
}
/// Add a new request with optional prompt-token load accounting.
/// Returns the set of expired request IDs that were removed during cleanup.
pub
fn
add_request_with_prefill_tracking
(
&
mut
self
,
request_id
:
RequestId
,
token_sequence
:
Option
<
Vec
<
SequenceHash
>>
,
isl
:
usize
,
overlap
:
u32
,
expected_output_tokens
:
Option
<
u32
>
,
track_prefill_tokens
:
bool
,
)
->
HashSet
<
RequestId
>
{
// Check for double-add and log error, returning early
if
self
.active_seqs
.contains_key
(
&
request_id
)
{
...
...
@@ -153,7 +174,11 @@ impl ActiveSequences {
// Lazily check and clean up expired requests, capturing removed IDs
let
removed_requests
=
self
.force_expiry
();
let
prefill_tokens
=
self
.new_tokens
(
isl
,
overlap
);
let
prefill_tokens
=
if
track_prefill_tokens
{
self
.new_tokens
(
isl
,
overlap
)
}
else
{
0
};
self
.prefill_tokens
.insert
(
request_id
.clone
(),
prefill_tokens
);
self
.active_tokens
+=
prefill_tokens
;
...
...
@@ -208,13 +233,27 @@ impl ActiveSequences {
token_sequence
:
Option
<&
[
SequenceHash
]
>
,
isl
:
usize
,
overlap
:
u32
,
)
->
(
usize
,
usize
)
{
self
.potential_blocks_and_tokens_with_prefill_tracking
(
token_sequence
,
isl
,
overlap
,
true
)
}
pub
fn
potential_blocks_and_tokens_with_prefill_tracking
(
&
self
,
token_sequence
:
Option
<&
[
SequenceHash
]
>
,
isl
:
usize
,
overlap
:
u32
,
track_prefill_tokens
:
bool
,
)
->
(
usize
,
usize
)
{
let
potential_blocks
=
if
let
Some
(
token_seq
)
=
token_sequence
{
self
.new_blocks
(
token_seq
)
+
self
.active_blocks
()
}
else
{
self
.active_blocks
()
};
let
potential_tokens
=
self
.new_tokens
(
isl
,
overlap
)
+
self
.active_tokens
;
let
potential_tokens
=
if
track_prefill_tokens
{
self
.new_tokens
(
isl
,
overlap
)
+
self
.active_tokens
}
else
{
self
.active_tokens
};
(
potential_blocks
,
potential_tokens
)
}
...
...
@@ -232,7 +271,10 @@ impl ActiveSequences {
self
.new_blocks
(
token_sequence
)
+
self
.active_blocks
()
}
/// Free all blocks associated with a request
/// Free all blocks associated with a request.
///
/// This implicitly calls [`Self::mark_prefill_completed`] first, so callers do not need
/// to invoke both when the request is finishing.
pub
fn
free
(
&
mut
self
,
request_id
:
&
RequestId
)
->
usize
{
self
.mark_prefill_completed
(
request_id
);
...
...
@@ -424,6 +466,48 @@ mod tests {
assert_eq!
(
seq_manager
.active_tokens
(),
0
);
}
#[test]
fn
test_add_request_without_prefill_tracking_keeps_active_tokens_zero
()
{
let
mut
seq_manager
=
ActiveSequences
::
new
(
4
);
seq_manager
.add_request_with_prefill_tracking
(
"r1"
.to_string
(),
Some
(
vec!
[
1
,
2
,
3
]),
12
,
0
,
None
,
false
,
);
assert_eq!
(
seq_manager
.active_tokens
(),
0
);
seq_manager
.mark_prefill_completed
(
&
"r1"
.to_string
());
assert_eq!
(
seq_manager
.active_tokens
(),
0
);
seq_manager
.free
(
&
"r1"
.to_string
());
assert_eq!
(
seq_manager
.active_blocks
(),
0
);
}
#[test]
fn
test_potential_blocks_and_tokens_without_prefill_tracking_ignores_prompt_load
()
{
let
mut
seq_manager
=
ActiveSequences
::
new
(
4
);
seq_manager
.add_request_with_prefill_tracking
(
"r1"
.to_string
(),
Some
(
vec!
[
1
,
2
,
3
]),
12
,
0
,
None
,
false
,
);
let
(
blocks
,
tokens
)
=
seq_manager
.potential_blocks_and_tokens_with_prefill_tracking
(
Some
(
&
[
1
,
2
,
3
,
4
]),
16
,
0
,
false
,
);
assert_eq!
(
blocks
,
4
);
assert_eq!
(
tokens
,
0
);
}
#[tokio::test(start_paused
=
true
)]
async
fn
test_force_expiry
()
{
let
block_size
=
4
;
...
...
lib/llm/src/kv_router.rs
View file @
02b1c58a
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use
std
::
sync
::
Arc
;
use
std
::
time
::{
Duration
,
Instant
};
use
std
::
time
::
Instant
;
use
anyhow
::
Result
;
use
dynamo_kv_router
::{
ConcurrentRadixTree
,
ThreadPoolIndexer
,
approx
::
PruneConfig
,
config
::{
KvRouterConfig
,
RouterConfigOverride
},
indexer
::
{
GetWorkersRequest
,
KvIndexer
,
KvIndexerInterface
,
KvIndexerMetrics
,
KvRouterError
}
,
indexer
::
KvRouterError
,
protocols
::
KV_EVENT_SUBJECT
,
protocols
::{
BlockExtraInfo
,
BlockHashOptions
,
DpRank
,
LocalBlockHash
,
OverlapScores
,
RouterEvent
,
RouterRequest
,
RouterResponse
,
TokensWithHashes
,
WorkerId
,
WorkerWithDpRank
,
compute_block_hash_for_seq
,
BlockExtraInfo
,
BlockHashOptions
,
DpRank
,
RouterEvent
,
RouterRequest
,
RouterResponse
,
TokensWithHashes
,
WorkerId
,
WorkerWithDpRank
,
compute_block_hash_for_seq
,
},
};
use
dynamo_runtime
::{
...
...
@@ -29,30 +25,29 @@ use dynamo_runtime::{
traits
::
DistributedRuntimeProvider
,
};
use
futures
::
stream
;
use
tokio
::
sync
::
oneshot
;
use
tracing
::
Instrument
;
use
validator
::
Validate
;
pub
mod
cache_control
;
pub
mod
indexer
;
mod
jetstream
;
pub
mod
metrics
;
pub
mod
prefill_router
;
pub
mod
publisher
;
pub
mod
push_router
;
pub
mod
remote_indexer
;
pub
mod
scheduler
;
pub
mod
sequence
;
pub
mod
subscriber
;
pub
mod
worker_query
;
pub
use
cache_control
::{
CacheControlClient
,
spawn_pin_prefix
};
pub
use
indexer
::
Indexer
;
pub
use
prefill_router
::
PrefillRouter
;
pub
use
push_router
::{
DirectRoutingRouter
,
KvPushRouter
};
use
crate
::{
discovery
::
RuntimeConfigWatch
,
kv_router
::{
remote_indexer
::
RemoteIndexer
,
scheduler
::{
DefaultWorkerSelector
,
KvScheduler
,
PotentialLoad
},
sequence
::{
SequenceError
,
SequenceRequest
},
},
...
...
@@ -108,188 +103,6 @@ pub fn router_discovery_query(namespace: String, component: String) -> Discovery
}
}
#[derive(Clone)]
pub
enum
Indexer
{
/// Single-threaded radix tree with channel-based event processing.
/// Supports TTL-based expiration and size-based pruning.
/// Has the ability to persist and snapshot states.
KvIndexer
(
KvIndexer
),
/// Concurrent radix tree with a thread pool for event processing.
/// Uses sticky worker routing for per-worker event serialization.
/// Does not support TTL/pruning.
Concurrent
(
Arc
<
ThreadPoolIndexer
<
ConcurrentRadixTree
>>
),
/// Forwards queries to a standalone KV indexer service via the request plane.
/// The standalone indexer manages its own radix tree and event subscription.
Remote
(
Arc
<
RemoteIndexer
>
),
/// Used when we do not wish to use the indexer at all (e.g., when overlap_score_weight is 0).
/// Note: This will cause KV events to accumulate in JetStream as we do not regularly purge them.
None
,
}
impl
Indexer
{
pub
async
fn
new
(
component
:
&
dynamo_runtime
::
component
::
Component
,
kv_router_config
:
&
KvRouterConfig
,
block_size
:
u32
,
model_name
:
Option
<
String
>
,
)
->
Result
<
Self
>
{
if
kv_router_config
.overlap_score_weight
==
0.0
{
return
Ok
(
Indexer
::
None
);
}
// Remote indexer: forward queries to a standalone KV indexer service.
if
let
Some
(
ref
indexer_component_name
)
=
kv_router_config
.remote_indexer_component
{
let
model_name
=
model_name
.ok_or_else
(||
{
anyhow
::
anyhow!
(
"model_name is required when remote_indexer_component is configured"
)
})
?
;
tracing
::
info!
(
remote_indexer_component
=
%
indexer_component_name
,
model_name
,
"Using remote KV indexer"
);
let
remote
=
RemoteIndexer
::
new
(
component
,
indexer_component_name
,
model_name
)
.await
?
;
return
Ok
(
Indexer
::
Remote
(
Arc
::
new
(
remote
)));
}
// Approximate mode (--no-kv-events): always use single-threaded KvIndexer
// with TTL/pruning regardless of event_threads, since updates come from
// routing decisions only, not live KV events from workers.
if
!
kv_router_config
.use_kv_events
{
let
kv_indexer_metrics
=
KvIndexerMetrics
::
from_component
(
component
);
let
cancellation_token
=
component
.drt
()
.primary_token
();
let
prune_config
=
Some
(
PruneConfig
{
ttl
:
Duration
::
from_secs_f64
(
kv_router_config
.router_ttl_secs
),
max_tree_size
:
kv_router_config
.router_max_tree_size
,
prune_target_ratio
:
kv_router_config
.router_prune_target_ratio
,
});
return
Ok
(
Indexer
::
KvIndexer
(
KvIndexer
::
new_with_frequency
(
cancellation_token
,
None
,
block_size
,
kv_indexer_metrics
,
prune_config
,
)));
}
if
kv_router_config
.router_event_threads
>
1
{
return
Ok
(
Indexer
::
Concurrent
(
Arc
::
new
(
ThreadPoolIndexer
::
new
(
ConcurrentRadixTree
::
new
(),
kv_router_config
.router_event_threads
as
usize
,
block_size
,
))));
}
let
kv_indexer_metrics
=
KvIndexerMetrics
::
from_component
(
component
);
let
cancellation_token
=
component
.drt
()
.primary_token
();
Ok
(
Indexer
::
KvIndexer
(
KvIndexer
::
new_with_frequency
(
cancellation_token
,
None
,
// expiration_duration for frequency tracking
block_size
,
kv_indexer_metrics
,
None
,
)))
}
pub
(
crate
)
async
fn
find_matches
(
&
self
,
sequence
:
Vec
<
LocalBlockHash
>
,
)
->
Result
<
OverlapScores
,
KvRouterError
>
{
match
self
{
Indexer
::
KvIndexer
(
indexer
)
=>
indexer
.find_matches
(
sequence
)
.await
,
Indexer
::
Concurrent
(
tpi
)
=>
tpi
.find_matches
(
sequence
)
.await
,
Indexer
::
Remote
(
remote
)
=>
remote
.find_matches
(
sequence
)
.await
.map_err
(|
e
|
{
tracing
::
warn!
(
error
=
%
e
,
"Remote indexer query failed"
);
KvRouterError
::
IndexerOffline
}),
Indexer
::
None
=>
Ok
(
OverlapScores
::
new
()),
}
}
pub
(
crate
)
async
fn
dump_events
(
&
self
)
->
Result
<
Vec
<
RouterEvent
>
,
KvRouterError
>
{
match
self
{
Indexer
::
KvIndexer
(
indexer
)
=>
indexer
.dump_events
()
.await
,
Indexer
::
Concurrent
(
tpi
)
=>
tpi
.dump_events
()
.await
,
Indexer
::
Remote
(
_
)
=>
Ok
(
Vec
::
new
()),
Indexer
::
None
=>
{
panic!
(
"Cannot dump events: indexer does not exist (is overlap_score_weight set to 0?)"
);
}
}
}
pub
(
crate
)
async
fn
process_routing_decision_for_request
(
&
self
,
tokens_with_hashes
:
&
mut
TokensWithHashes
,
worker
:
WorkerWithDpRank
,
)
->
Result
<
(),
KvRouterError
>
{
match
self
{
Indexer
::
KvIndexer
(
indexer
)
=>
{
indexer
.process_routing_decision_for_request
(
tokens_with_hashes
,
worker
)
.await
}
Indexer
::
Concurrent
(
tpi
)
=>
{
tpi
.process_routing_decision_for_request
(
tokens_with_hashes
,
worker
)
.await
}
Indexer
::
Remote
(
_
)
=>
Ok
(()),
Indexer
::
None
=>
Ok
(()),
}
}
pub
(
crate
)
async
fn
apply_event
(
&
self
,
event
:
RouterEvent
)
{
match
self
{
Indexer
::
KvIndexer
(
indexer
)
=>
{
if
let
Err
(
e
)
=
indexer
.event_sender
()
.send
(
event
)
.await
{
tracing
::
warn!
(
"Failed to send event to indexer: {e}"
);
}
}
Indexer
::
Concurrent
(
tpi
)
=>
tpi
.apply_event
(
event
)
.await
,
Indexer
::
Remote
(
_
)
=>
{}
// standalone indexer gets events directly
Indexer
::
None
=>
{}
}
}
pub
(
crate
)
async
fn
remove_worker
(
&
self
,
worker_id
:
WorkerId
)
{
match
self
{
Indexer
::
KvIndexer
(
indexer
)
=>
{
if
let
Err
(
e
)
=
indexer
.remove_worker_sender
()
.send
(
worker_id
)
.await
{
tracing
::
warn!
(
"Failed to send worker removal for {worker_id}: {e}"
);
}
}
Indexer
::
Concurrent
(
tpi
)
=>
{
KvIndexerInterface
::
remove_worker
(
tpi
.as_ref
(),
worker_id
)
.await
;
}
Indexer
::
Remote
(
_
)
=>
{}
// standalone indexer manages its own workers
Indexer
::
None
=>
{}
}
}
pub
(
crate
)
async
fn
get_workers
(
&
self
)
->
Vec
<
WorkerId
>
{
match
self
{
Indexer
::
KvIndexer
(
indexer
)
=>
{
let
(
resp_tx
,
resp_rx
)
=
oneshot
::
channel
();
let
req
=
GetWorkersRequest
{
resp
:
resp_tx
};
if
let
Err
(
e
)
=
indexer
.get_workers_sender
()
.send
(
req
)
.await
{
tracing
::
warn!
(
"Failed to send get_workers request: {e}"
);
return
Vec
::
new
();
}
resp_rx
.await
.unwrap_or_default
()
}
Indexer
::
Concurrent
(
tpi
)
=>
tpi
.backend
()
.get_workers
(),
Indexer
::
Remote
(
_
)
=>
Vec
::
new
(),
Indexer
::
None
=>
Vec
::
new
(),
}
}
}
/// A KvRouter only decides which worker you should use. It doesn't send you there.
/// TODO: Rename this to indicate it only selects a worker, it does not route.
pub
struct
KvRouter
<
Sel
=
DefaultWorkerSelector
>
...
...
@@ -529,6 +342,9 @@ where
hash_options
,
None
,
);
let
track_prefill_tokens
=
self
.kv_router_config
.track_prefill_tokens
(
router_config_override
);
if
let
Err
(
e
)
=
self
.scheduler
...
...
@@ -537,6 +353,7 @@ where
token_sequence
:
maybe_seq_hashes
,
isl
:
isl_tokens
,
overlap
:
overlap_blocks
,
track_prefill_tokens
,
expected_output_tokens
,
worker
,
lora_name
,
...
...
@@ -623,12 +440,17 @@ where
hash_options
,
Some
(
&
block_hashes
),
);
let
track_prefill_tokens
=
self
.kv_router_config
.track_prefill_tokens
(
router_config_override
);
let
overlap_scores
=
self
.indexer
.find_matches
(
block_hashes
)
.await
?
;
Ok
(
self
.scheduler
.get_potential_loads
(
maybe_seq_hashes
,
isl_tokens
,
overlap_scores
))
Ok
(
self
.scheduler
.get_potential_loads
(
maybe_seq_hashes
,
isl_tokens
,
overlap_scores
,
track_prefill_tokens
,
))
}
/// Dump all events from the indexer
...
...
lib/llm/src/kv_router/indexer.rs
0 → 100644
View file @
02b1c58a
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use
std
::
sync
::
Arc
;
use
std
::
time
::
Duration
;
use
anyhow
::
Result
;
use
futures
::
StreamExt
;
use
dynamo_kv_router
::{
ConcurrentRadixTree
,
ThreadPoolIndexer
,
approx
::
PruneConfig
,
config
::
KvRouterConfig
,
indexer
::{
IndexerQueryRequest
,
IndexerQueryResponse
,
KV_INDEXER_QUERY_ENDPOINT
,
KvIndexer
,
KvIndexerInterface
,
KvIndexerMetrics
,
KvRouterError
,
},
protocols
::{
LocalBlockHash
,
OverlapScores
,
RouterEvent
,
TokensWithHashes
,
WorkerId
,
WorkerWithDpRank
,
},
};
use
dynamo_runtime
::{
component
::
Component
,
pipeline
::{
ManyOut
,
RouterMode
,
SingleIn
,
network
::
egress
::
push_router
::
PushRouter
},
traits
::
DistributedRuntimeProvider
,
};
use
tokio
::
sync
::
oneshot
;
pub
struct
RemoteIndexer
{
router
:
PushRouter
<
IndexerQueryRequest
,
IndexerQueryResponse
>
,
model_name
:
String
,
namespace
:
String
,
}
impl
RemoteIndexer
{
async
fn
new
(
component
:
&
Component
,
indexer_component_name
:
&
str
,
model_name
:
String
,
)
->
Result
<
Self
>
{
let
namespace
=
component
.namespace
()
.name
();
let
indexer_ns
=
component
.namespace
();
let
indexer_component
=
indexer_ns
.component
(
indexer_component_name
)
?
;
let
endpoint
=
indexer_component
.endpoint
(
KV_INDEXER_QUERY_ENDPOINT
);
let
client
=
endpoint
.client
()
.await
?
;
let
router
=
PushRouter
::
from_client_no_fault_detection
(
client
,
RouterMode
::
RoundRobin
)
.await
?
;
Ok
(
Self
{
router
,
model_name
,
namespace
,
})
}
async
fn
find_matches
(
&
self
,
block_hashes
:
Vec
<
LocalBlockHash
>
)
->
Result
<
OverlapScores
>
{
let
request
=
IndexerQueryRequest
{
model_name
:
self
.model_name
.clone
(),
namespace
:
self
.namespace
.clone
(),
block_hashes
,
};
let
mut
stream
:
ManyOut
<
IndexerQueryResponse
>
=
self
.router
.round_robin
(
SingleIn
::
new
(
request
))
.await
?
;
match
stream
.next
()
.await
{
Some
(
IndexerQueryResponse
::
Scores
(
scores
))
=>
Ok
(
scores
.into
()),
Some
(
IndexerQueryResponse
::
Error
(
msg
))
=>
{
Err
(
anyhow
::
anyhow!
(
"Remote indexer error: {}"
,
msg
))
}
None
=>
Err
(
anyhow
::
anyhow!
(
"Remote indexer returned empty response"
)),
}
}
}
#[derive(Clone)]
pub
enum
Indexer
{
KvIndexer
(
KvIndexer
),
Concurrent
(
Arc
<
ThreadPoolIndexer
<
ConcurrentRadixTree
>>
),
Remote
(
Arc
<
RemoteIndexer
>
),
None
,
}
impl
Indexer
{
pub
async
fn
new
(
component
:
&
Component
,
kv_router_config
:
&
KvRouterConfig
,
block_size
:
u32
,
model_name
:
Option
<
String
>
,
)
->
Result
<
Self
>
{
if
kv_router_config
.overlap_score_weight
==
0.0
{
return
Ok
(
Self
::
None
);
}
if
let
Some
(
ref
indexer_component_name
)
=
kv_router_config
.remote_indexer_component
{
let
model_name
=
model_name
.ok_or_else
(||
{
anyhow
::
anyhow!
(
"model_name is required when remote_indexer_component is configured"
)
})
?
;
tracing
::
info!
(
remote_indexer_component
=
%
indexer_component_name
,
model_name
,
"Using remote KV indexer"
);
let
remote
=
RemoteIndexer
::
new
(
component
,
indexer_component_name
,
model_name
)
.await
?
;
return
Ok
(
Self
::
Remote
(
Arc
::
new
(
remote
)));
}
if
!
kv_router_config
.use_kv_events
{
let
kv_indexer_metrics
=
KvIndexerMetrics
::
from_component
(
component
);
let
cancellation_token
=
component
.drt
()
.primary_token
();
let
prune_config
=
Some
(
PruneConfig
{
ttl
:
Duration
::
from_secs_f64
(
kv_router_config
.router_ttl_secs
),
max_tree_size
:
kv_router_config
.router_max_tree_size
,
prune_target_ratio
:
kv_router_config
.router_prune_target_ratio
,
});
return
Ok
(
Self
::
KvIndexer
(
KvIndexer
::
new_with_frequency
(
cancellation_token
,
None
,
block_size
,
kv_indexer_metrics
,
prune_config
,
)));
}
if
kv_router_config
.router_event_threads
>
1
{
return
Ok
(
Self
::
Concurrent
(
Arc
::
new
(
ThreadPoolIndexer
::
new
(
ConcurrentRadixTree
::
new
(),
kv_router_config
.router_event_threads
as
usize
,
block_size
,
))));
}
let
kv_indexer_metrics
=
KvIndexerMetrics
::
from_component
(
component
);
let
cancellation_token
=
component
.drt
()
.primary_token
();
Ok
(
Self
::
KvIndexer
(
KvIndexer
::
new_with_frequency
(
cancellation_token
,
None
,
block_size
,
kv_indexer_metrics
,
None
,
)))
}
pub
(
crate
)
async
fn
find_matches
(
&
self
,
sequence
:
Vec
<
LocalBlockHash
>
,
)
->
Result
<
OverlapScores
,
KvRouterError
>
{
match
self
{
Self
::
KvIndexer
(
indexer
)
=>
indexer
.find_matches
(
sequence
)
.await
,
Self
::
Concurrent
(
tpi
)
=>
tpi
.find_matches
(
sequence
)
.await
,
Self
::
Remote
(
remote
)
=>
remote
.find_matches
(
sequence
)
.await
.map_err
(|
e
|
{
tracing
::
warn!
(
error
=
%
e
,
"Remote indexer query failed"
);
KvRouterError
::
IndexerOffline
}),
Self
::
None
=>
Ok
(
OverlapScores
::
new
()),
}
}
pub
(
crate
)
async
fn
dump_events
(
&
self
)
->
Result
<
Vec
<
RouterEvent
>
,
KvRouterError
>
{
match
self
{
Self
::
KvIndexer
(
indexer
)
=>
indexer
.dump_events
()
.await
,
Self
::
Concurrent
(
tpi
)
=>
tpi
.dump_events
()
.await
,
Self
::
Remote
(
_
)
=>
Ok
(
Vec
::
new
()),
Self
::
None
=>
{
panic!
(
"Cannot dump events: indexer does not exist (is overlap_score_weight set to 0?)"
);
}
}
}
pub
(
crate
)
async
fn
process_routing_decision_for_request
(
&
self
,
tokens_with_hashes
:
&
mut
TokensWithHashes
,
worker
:
WorkerWithDpRank
,
)
->
Result
<
(),
KvRouterError
>
{
match
self
{
Self
::
KvIndexer
(
indexer
)
=>
{
indexer
.process_routing_decision_for_request
(
tokens_with_hashes
,
worker
)
.await
}
Self
::
Concurrent
(
tpi
)
=>
{
tpi
.process_routing_decision_for_request
(
tokens_with_hashes
,
worker
)
.await
}
Self
::
Remote
(
_
)
|
Self
::
None
=>
Ok
(()),
}
}
pub
(
crate
)
async
fn
apply_event
(
&
self
,
event
:
RouterEvent
)
{
match
self
{
Self
::
KvIndexer
(
indexer
)
=>
{
if
let
Err
(
e
)
=
indexer
.event_sender
()
.send
(
event
)
.await
{
tracing
::
warn!
(
"Failed to send event to indexer: {e}"
);
}
}
Self
::
Concurrent
(
tpi
)
=>
tpi
.apply_event
(
event
)
.await
,
Self
::
Remote
(
_
)
|
Self
::
None
=>
{}
}
}
pub
(
crate
)
async
fn
remove_worker
(
&
self
,
worker_id
:
WorkerId
)
{
match
self
{
Self
::
KvIndexer
(
indexer
)
=>
{
if
let
Err
(
e
)
=
indexer
.remove_worker_sender
()
.send
(
worker_id
)
.await
{
tracing
::
warn!
(
"Failed to send worker removal for {worker_id}: {e}"
);
}
}
Self
::
Concurrent
(
tpi
)
=>
{
KvIndexerInterface
::
remove_worker
(
tpi
.as_ref
(),
worker_id
)
.await
;
}
Self
::
Remote
(
_
)
|
Self
::
None
=>
{}
}
}
pub
(
crate
)
async
fn
get_workers
(
&
self
)
->
Vec
<
WorkerId
>
{
match
self
{
Self
::
KvIndexer
(
indexer
)
=>
{
let
(
resp_tx
,
resp_rx
)
=
oneshot
::
channel
();
let
req
=
dynamo_kv_router
::
indexer
::
GetWorkersRequest
{
resp
:
resp_tx
};
if
let
Err
(
e
)
=
indexer
.get_workers_sender
()
.send
(
req
)
.await
{
tracing
::
warn!
(
"Failed to send get_workers request: {e}"
);
return
Vec
::
new
();
}
resp_rx
.await
.unwrap_or_default
()
}
Self
::
Concurrent
(
tpi
)
=>
tpi
.backend
()
.get_workers
(),
Self
::
Remote
(
_
)
|
Self
::
None
=>
Vec
::
new
(),
}
}
}
lib/llm/src/kv_router/prefill_router/activation.rs
0 → 100644
View file @
02b1c58a
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use
std
::
sync
::
Arc
;
use
anyhow
::
Result
;
use
tokio
::
sync
::
oneshot
;
use
dynamo_kv_router
::
config
::
KvRouterConfig
;
use
dynamo_runtime
::{
component
::{
Client
,
Endpoint
},
pipeline
::{
PushRouter
,
RouterMode
},
protocols
::
annotated
::
Annotated
,
};
use
super
::{
InnerPrefillRouter
,
PrefillRouter
};
use
crate
::{
discovery
::
ModelManager
,
kv_router
::
KvPushRouter
,
protocols
::
common
::{
llm_backend
::{
LLMEngineOutput
,
PreprocessedRequest
},
timing
::
WORKER_TYPE_PREFILL
,
},
};
impl
PrefillRouter
{
/// Create a disabled prefill router that will never activate (passthrough only)
pub
fn
disabled
(
model_manager
:
Arc
<
ModelManager
>
,
router_mode
:
RouterMode
,
enforce_disagg
:
bool
,
)
->
Arc
<
Self
>
{
Arc
::
new
(
Self
{
prefill_router
:
std
::
sync
::
OnceLock
::
new
(),
model_manager
,
endpoint_id
:
std
::
sync
::
OnceLock
::
new
(),
cancel_token
:
tokio_util
::
sync
::
CancellationToken
::
new
(),
router_mode
,
enforce_disagg
,
model_name
:
String
::
new
(),
// Not used for disabled router
namespace
:
String
::
new
(),
// Not used for disabled router
is_eagle
:
false
,
})
}
#[expect(clippy::too_many_arguments)]
pub
fn
new
(
activation_rx
:
oneshot
::
Receiver
<
Endpoint
>
,
model_manager
:
Arc
<
ModelManager
>
,
router_mode
:
RouterMode
,
kv_cache_block_size
:
u32
,
kv_router_config
:
Option
<
KvRouterConfig
>
,
enforce_disagg
:
bool
,
model_name
:
String
,
namespace
:
String
,
is_eagle
:
bool
,
)
->
Arc
<
Self
>
{
let
prefill_router
=
std
::
sync
::
OnceLock
::
new
();
let
cancel_token
=
tokio_util
::
sync
::
CancellationToken
::
new
();
let
router
=
Arc
::
new
(
Self
{
prefill_router
,
model_manager
:
model_manager
.clone
(),
endpoint_id
:
std
::
sync
::
OnceLock
::
new
(),
cancel_token
:
cancel_token
.clone
(),
router_mode
,
enforce_disagg
,
model_name
,
namespace
,
is_eagle
,
});
// Spawn background task to wait for activation
let
router_clone
=
router
.clone
();
tokio
::
spawn
(
async
move
{
tokio
::
select!
{
result
=
activation_rx
=>
{
let
Ok
(
endpoint
)
=
result
else
{
tracing
::
debug!
(
"Prefill router activation channel closed without receiving endpoint"
);
return
;
};
if
let
Err
(
e
)
=
router_clone
.activate
(
endpoint
,
model_manager
,
kv_cache_block_size
,
kv_router_config
,
)
.await
{
tracing
::
error!
(
error
=
%
e
,
"Failed to activate prefill router"
);
}
}
_
=
cancel_token
.cancelled
()
=>
{
tracing
::
debug!
(
"Prefill router activation cancelled"
);
}
}
});
router
}
/// Activate the prefill router with the provided endpoint
async
fn
activate
(
&
self
,
endpoint
:
Endpoint
,
model_manager
:
Arc
<
ModelManager
>
,
kv_cache_block_size
:
u32
,
kv_router_config
:
Option
<
KvRouterConfig
>
,
)
->
Result
<
()
>
{
tracing
::
info!
(
router_mode
=
?
self
.router_mode
,
"Activating prefill router"
);
// Store endpoint_id for later use in resolve_prefill_worker
let
_
=
self
.endpoint_id
.set
(
endpoint
.id
());
// Start runtime config watcher for this endpoint (needed for get_disaggregated_endpoint)
// This must be done before creating the router so bootstrap info is available
model_manager
.get_or_create_runtime_config_watcher
(
&
endpoint
)
.await
?
;
let
inner_router
=
if
self
.router_mode
.is_kv_routing
()
{
// Create KV chooser using the endpoint (this is a prefill router)
let
kv_chooser
=
model_manager
.kv_chooser_for
(
&
endpoint
,
kv_cache_block_size
,
kv_router_config
,
WORKER_TYPE_PREFILL
,
Some
(
self
.model_name
.clone
()),
self
.is_eagle
,
)
.await
?
;
// Extract client from kv_chooser to ensure shared state
let
client
=
kv_chooser
.client
()
.clone
();
self
.register_prefill_client
(
model_manager
.as_ref
(),
&
client
);
// Build the PushRouter for prefill with KV mode using the shared client
let
push_router
=
PushRouter
::
<
PreprocessedRequest
,
Annotated
<
LLMEngineOutput
>>
::
from_client_with_threshold
(
client
,
RouterMode
::
KV
,
None
,
// busy_threshold
None
,
// worker_monitor
)
.await
?
;
// Wrap it in KvPushRouter
InnerPrefillRouter
::
KvRouter
(
Arc
::
new
(
KvPushRouter
::
new
(
push_router
,
kv_chooser
)))
}
else
{
// Create client for simple router
let
client
=
endpoint
.client
()
.await
?
;
self
.register_prefill_client
(
model_manager
.as_ref
(),
&
client
);
// Create simple push router with the frontend's router mode
// Note: Per-worker metrics (active_prefill_tokens, active_decode_blocks) are only
// available in KV routing mode where the router has actual bookkeeping.
let
push_router
=
PushRouter
::
<
PreprocessedRequest
,
Annotated
<
LLMEngineOutput
>>
::
from_client_with_threshold
(
client
,
self
.router_mode
,
None
,
// busy_threshold
None
,
// worker_monitor
)
.await
?
;
InnerPrefillRouter
::
SimpleRouter
(
Arc
::
new
(
push_router
))
};
// Set the router (ignore error if already set)
let
_
=
self
.prefill_router
.set
(
inner_router
);
tracing
::
info!
(
router_mode
=
?
self
.router_mode
,
"Prefill router activated successfully"
);
Ok
(())
}
fn
register_prefill_client
(
&
self
,
model_manager
:
&
ModelManager
,
client
:
&
Client
)
{
if
let
Some
(
monitor
)
=
model_manager
.get_worker_monitor_for_namespace
(
&
self
.model_name
,
&
self
.namespace
)
{
monitor
.set_prefill_client
(
client
.clone
());
}
}
}
lib/llm/src/kv_router/prefill_router.rs
→
lib/llm/src/kv_router/prefill_router
/execution
.rs
View file @
02b1c58a
...
...
@@ -2,301 +2,41 @@
// SPDX-License-Identifier: Apache-2.0
use
std
::
collections
::
HashSet
;
use
std
::
sync
::
{
Arc
,
OnceLock
}
;
use
std
::
sync
::
Arc
;
use
anyhow
::
Result
;
use
futures
::
StreamExt
;
use
tokio
::
sync
::{
OwnedSemaphorePermit
,
oneshot
};
use
tokio_util
::
sync
::
CancellationToken
;
use
tokio
::
sync
::
OwnedSemaphorePermit
;
use
tracing
::
Instrument
;
use
dynamo_kv_router
::{
config
::{
KvRouterConfig
,
RouterConfigOverride
},
protocols
::{
BlockExtraInfo
,
WorkerId
},
};
use
dynamo_kv_router
::
protocols
::{
BlockExtraInfo
,
WorkerId
};
use
dynamo_runtime
::{
component
::
Endpoint
,
pipeline
::{
AsyncEngine
,
AsyncEngineContextProvider
,
Context
,
ManyOut
,
Operator
,
PushRouter
,
RouterMode
,
ServerStreamingEngine
,
SingleIn
,
async_trait
,
},
protocols
::{
EndpointId
,
annotated
::
Annotated
,
maybe_error
::
MaybeError
},
engine
::
AsyncEngineContext
,
pipeline
::{
AsyncEngineContextProvider
,
Context
,
SingleIn
},
protocols
::
maybe_error
::
MaybeError
,
};
use
crate
::{
discovery
::
ModelManager
,
kv_router
::
KvPushRouter
,
protocols
::
common
::
llm_backend
::{
LLMEngineOutput
,
PreprocessedRequest
},
protocols
::
common
::
preprocessor
::{
BootstrapInfo
,
PrefillResult
},
protocols
::
common
::
timing
::{
RequestPhase
,
RequestTracker
,
WORKER_TYPE_PREFILL
},
use
super
::{
InnerPrefillRouter
,
PrefillError
,
PrefillResolveDecision
,
PrefillRouter
};
use
crate
::
protocols
::
common
::{
llm_backend
::
PreprocessedRequest
,
preprocessor
::{
BootstrapInfo
,
PrefillResult
},
};
/// Errors that can occur during prefill routing
#[derive(Debug,
thiserror::Error)]
pub
enum
PrefillError
{
/// Prefill router has not been activated yet
#[error(
"Prefill router not yet activated"
)]
NotActivated
,
/// TODO: Separate prefill worker error from prefill router error
/// Error during prefill execution
#[error(
"Prefill execution failed: {0}"
)]
PrefillError
(
String
,
#[source]
Option
<
Box
<
dyn
std
::
error
::
Error
+
Send
+
Sync
+
'static
>>
,
),
/// Disaggregated params not found in prefill response
#[error(
"No disaggregated params in prefill response: {0}"
)]
NoDisaggregatedParams
(
String
),
}
/// Result of the prefill phase in `generate()`.
enum
PrefillOutcome
{
/// Bootstrap optimization: prefill spawned in background, bootstrap info ready
Bootstrap
(
BootstrapInfo
),
/// Synchronous prefill completed with result
Completed
(
PrefillResult
),
}
/// The inner router used by PrefillRouter
#[derive(Clone)]
enum
InnerPrefillRouter
{
/// KV-aware routing using KvPushRouter
KvRouter
(
Arc
<
KvPushRouter
>
),
/// Simple routing (RoundRobin, Random, Direct)
/// Note: Per-worker metrics (active_prefill_tokens, active_decode_blocks) are only
/// available in KV routing mode where the router has actual bookkeeping.
SimpleRouter
(
Arc
<
PushRouter
<
PreprocessedRequest
,
Annotated
<
LLMEngineOutput
>>>
),
}
impl
InnerPrefillRouter
{
/// Generate with optional direct routing to specific worker.
/// For KvRouter, target_worker is ignored since prefill_worker_id is already set on the request.
/// For SimpleRouter, target_worker triggers direct routing via router.direct().
async
fn
generate_to_worker
(
&
self
,
request
:
SingleIn
<
PreprocessedRequest
>
,
target_worker
:
Option
<
u64
>
,
)
->
Result
<
ManyOut
<
Annotated
<
LLMEngineOutput
>>>
{
match
(
self
,
target_worker
)
{
// KvRouter: prefill_worker_id already set on request, KvPushRouter::select_worker uses it
(
InnerPrefillRouter
::
KvRouter
(
router
),
_
)
=>
router
.generate
(
request
)
.await
,
(
InnerPrefillRouter
::
SimpleRouter
(
router
),
Some
(
worker_id
))
=>
{
router
.direct
(
request
,
worker_id
)
.await
}
(
InnerPrefillRouter
::
SimpleRouter
(
router
),
None
)
=>
router
.generate
(
request
)
.await
,
}
}
/// Select next worker (for non-KV modes only)
fn
select_next_worker
(
&
self
)
->
Option
<
u64
>
{
match
self
{
InnerPrefillRouter
::
SimpleRouter
(
router
)
=>
router
.select_next_worker
(),
InnerPrefillRouter
::
KvRouter
(
_
)
=>
None
,
}
}
}
/// PrefillRouter is a forward-only operator that sits between Migration and the decode router.
/// It optionally calls a prefill worker before routing to decode, extracting disaggregated_params
/// from the prefill response and injecting them into the decode request.
///
/// Modes:
/// - Query-only: `query_instance_id` annotation present → returns worker IDs without execution
/// - Pre-routed: `prefill_worker_id`/`decode_worker_id` set → routes to specified workers
/// - Normal: Worker IDs determined by router based on KV cache state
pub
struct
PrefillRouter
{
prefill_router
:
OnceLock
<
InnerPrefillRouter
>
,
model_manager
:
Arc
<
ModelManager
>
,
endpoint_id
:
OnceLock
<
EndpointId
>
,
cancel_token
:
CancellationToken
,
router_mode
:
RouterMode
,
enforce_disagg
:
bool
,
/// Model name used to look up the worker monitor for prefill client registration
model_name
:
String
,
/// Namespace used to look up the correct WorkerSet's worker monitor
namespace
:
String
,
is_eagle
:
bool
,
}
impl
PrefillRouter
{
/// Create a disabled prefill router that will never activate (passthrough only)
pub
fn
disabled
(
model_manager
:
Arc
<
ModelManager
>
,
router_mode
:
RouterMode
,
enforce_disagg
:
bool
,
)
->
Arc
<
Self
>
{
Arc
::
new
(
Self
{
prefill_router
:
OnceLock
::
new
(),
model_manager
,
endpoint_id
:
OnceLock
::
new
(),
cancel_token
:
CancellationToken
::
new
(),
router_mode
,
enforce_disagg
,
model_name
:
String
::
new
(),
// Not used for disabled router
namespace
:
String
::
new
(),
// Not used for disabled router
is_eagle
:
false
,
})
}
#[expect(clippy::too_many_arguments)]
pub
fn
new
(
activation_rx
:
oneshot
::
Receiver
<
Endpoint
>
,
model_manager
:
Arc
<
ModelManager
>
,
router_mode
:
RouterMode
,
kv_cache_block_size
:
u32
,
kv_router_config
:
Option
<
KvRouterConfig
>
,
enforce_disagg
:
bool
,
model_name
:
String
,
namespace
:
String
,
is_eagle
:
bool
,
)
->
Arc
<
Self
>
{
let
prefill_router
=
OnceLock
::
new
();
let
cancel_token
=
CancellationToken
::
new
();
let
router
=
Arc
::
new
(
Self
{
prefill_router
,
model_manager
:
model_manager
.clone
(),
endpoint_id
:
OnceLock
::
new
(),
cancel_token
:
cancel_token
.clone
(),
router_mode
,
enforce_disagg
,
model_name
,
namespace
,
is_eagle
,
});
// Spawn background task to wait for activation
let
router_clone
=
router
.clone
();
tokio
::
spawn
(
async
move
{
tokio
::
select!
{
result
=
activation_rx
=>
{
let
Ok
(
endpoint
)
=
result
else
{
tracing
::
debug!
(
"Prefill router activation channel closed without receiving endpoint"
);
return
;
};
if
let
Err
(
e
)
=
router_clone
.activate
(
endpoint
,
model_manager
,
kv_cache_block_size
,
kv_router_config
,
)
.await
{
tracing
::
error!
(
error
=
%
e
,
"Failed to activate prefill router"
);
}
}
_
=
cancel_token
.cancelled
()
=>
{
tracing
::
debug!
(
"Prefill router activation cancelled"
);
}
}
});
router
}
/// Activate the prefill router with the provided endpoint
async
fn
activate
(
&
self
,
endpoint
:
Endpoint
,
model_manager
:
Arc
<
ModelManager
>
,
kv_cache_block_size
:
u32
,
kv_router_config
:
Option
<
KvRouterConfig
>
,
)
->
Result
<
()
>
{
tracing
::
info!
(
router_mode
=
?
self
.router_mode
,
"Activating prefill router"
);
// Store endpoint_id for later use in resolve_prefill_worker
let
_
=
self
.endpoint_id
.set
(
endpoint
.id
());
// Start runtime config watcher for this endpoint (needed for get_disaggregated_endpoint)
// This must be done before creating the router so bootstrap info is available
model_manager
.get_or_create_runtime_config_watcher
(
&
endpoint
)
.await
?
;
let
inner_router
=
if
self
.router_mode
.is_kv_routing
()
{
// Create KV chooser using the endpoint (this is a prefill router)
let
kv_chooser
=
model_manager
.kv_chooser_for
(
&
endpoint
,
kv_cache_block_size
,
kv_router_config
,
WORKER_TYPE_PREFILL
,
Some
(
self
.model_name
.clone
()),
self
.is_eagle
,
)
.await
?
;
// Extract client from kv_chooser to ensure shared state
let
client
=
kv_chooser
.client
()
.clone
();
// Register prefill client with worker monitor for TTFT metric cleanup in disaggregated mode
if
let
Some
(
monitor
)
=
model_manager
.get_worker_monitor_for_namespace
(
&
self
.model_name
,
&
self
.namespace
)
{
monitor
.set_prefill_client
(
client
.clone
());
}
// Build the PushRouter for prefill with KV mode using the shared client
let
push_router
=
PushRouter
::
<
PreprocessedRequest
,
Annotated
<
LLMEngineOutput
>>
::
from_client_with_threshold
(
client
,
RouterMode
::
KV
,
None
,
// busy_threshold
None
,
// worker_monitor
)
.await
?
;
// Wrap it in KvPushRouter
InnerPrefillRouter
::
KvRouter
(
Arc
::
new
(
KvPushRouter
::
new
(
push_router
,
kv_chooser
)))
}
else
{
// Create client for simple router
let
client
=
endpoint
.client
()
.await
?
;
// Register prefill client with worker monitor for TTFT metric cleanup in disaggregated mode
if
let
Some
(
monitor
)
=
model_manager
.get_worker_monitor_for_namespace
(
&
self
.model_name
,
&
self
.namespace
)
{
monitor
.set_prefill_client
(
client
.clone
());
}
// Create simple push router with the frontend's router mode
// Note: Per-worker metrics (active_prefill_tokens, active_decode_blocks) are only
// available in KV routing mode where the router has actual bookkeeping.
let
push_router
=
PushRouter
::
<
PreprocessedRequest
,
Annotated
<
LLMEngineOutput
>>
::
from_client_with_threshold
(
client
,
self
.router_mode
,
None
,
// busy_threshold
None
,
// worker_monitor
)
.await
?
;
InnerPrefillRouter
::
SimpleRouter
(
Arc
::
new
(
push_router
))
};
// Set the router (ignore error if already set)
let
_
=
self
.prefill_router
.set
(
inner_router
);
tracing
::
info!
(
router_mode
=
?
self
.router_mode
,
"Prefill router activated successfully"
);
Ok
(())
}
/// Select a prefill worker and resolve its bootstrap connection info.
/// If preselected_worker is provided (GAIE Stage 2), use it directly.
/// Otherwise, query for the best worker (KV mode) or select next worker (non-KV modes).
async
fn
resolve_prefill_worker
(
pub
(
super
)
async
fn
resolve_prefill_worker
(
&
self
,
req
:
&
PreprocessedRequest
,
preselected_worker
:
Option
<
u64
>
,
)
->
Option
<
(
u64
,
u32
,
BootstrapInfo
)
>
{
let
endpoint_id
=
self
.endpoint_id
.get
()
?
;
self
.prefill_router
.get
()
?
;
)
->
PrefillResolveDecision
{
let
Some
(
endpoint_id
)
=
self
.endpoint_id
.get
()
else
{
return
PrefillResolveDecision
::
NotActivated
;
};
if
self
.prefill_router
.get
()
.is_none
()
{
return
PrefillResolveDecision
::
NotActivated
;
}
// Worker selection
let
(
worker_id
,
dp_rank
)
=
if
let
Some
(
id
)
=
preselected_worker
{
...
...
@@ -333,16 +73,23 @@ impl PrefillRouter {
.await
{
Ok
((
worker_id
,
dp_rank
))
=>
(
worker_id
,
dp_rank
),
Err
(
_
)
=>
return
Non
e
,
Err
(
_
)
=>
return
PrefillResolveDecision
::
Unavailabl
e
,
}
};
// Get bootstrap info from ModelManager (works for ANY mode)
let
endpoint
=
self
let
Some
(
endpoint
)
=
self
.model_manager
.get_disaggregated_endpoint
(
endpoint_id
,
worker_id
)
?
;
let
host
=
endpoint
.bootstrap_host
?
;
let
port
=
endpoint
.bootstrap_port
?
;
.get_disaggregated_endpoint
(
endpoint_id
,
worker_id
)
else
{
return
PrefillResolveDecision
::
NoBootstrapEndpoint
;
};
let
Some
(
host
)
=
endpoint
.bootstrap_host
else
{
return
PrefillResolveDecision
::
NoBootstrapEndpoint
;
};
let
Some
(
port
)
=
endpoint
.bootstrap_port
else
{
return
PrefillResolveDecision
::
NoBootstrapEndpoint
;
};
let
bootstrap_room
:
u64
=
rand
::
random_range
(
0
..=
i64
::
MAX
.cast_unsigned
());
...
...
@@ -356,31 +103,31 @@ impl PrefillRouter {
"Built bootstrap_info upfront before prefill"
);
Some
((
PrefillResolveDecision
::
Resolved
{
worker_id
,
dp_rank
,
BootstrapInfo
{
bootstrap_info
:
BootstrapInfo
{
bootstrap_host
:
host
,
bootstrap_port
:
port
,
bootstrap_room
,
},
))
}
}
/// Execute prefill with the given router and extract structured result.
///
/// Uses direct routing to target_worker when specified (for non-KV modes with bootstrap optimization).
///
/// If `phase_permit` is provided, it is dropped
after the first output is received
,
/// allowing subsequent `set_phase` calls to proceed. This
is used in the bootstrap
///
optimization path to ensure
`record_worker_full`
completes
before the phase change
s
.
/// If `phase_
transition_
permit` is provided, it is dropped
immediately after routing completes
,
/// allowing subsequent `set_phase` calls to proceed. This
preserves the current synchronization:
///
the prefill route must finish
`record_worker_full` before the phase
can
change
to Decode
.
///
/// Returns (PrefillResult, Option<(worker_id, dp_rank)>).
async
fn
execute_prefill
(
pub
(
super
)
async
fn
execute_prefill
(
router
:
Option
<
InnerPrefillRouter
>
,
request
:
SingleIn
<
PreprocessedRequest
>
,
target_worker
:
Option
<
u64
>
,
phase_permit
:
Option
<
OwnedSemaphorePermit
>
,
phase_
transition_
permit
:
Option
<
OwnedSemaphorePermit
>
,
)
->
Result
<
(
PrefillResult
,
Option
<
(
u64
,
u32
)
>
),
PrefillError
>
{
let
router
=
router
.ok_or
(
PrefillError
::
NotActivated
)
?
;
let
mut
prefill_response
=
router
...
...
@@ -393,9 +140,9 @@ impl PrefillRouter {
)
})
?
;
//
Drop phase permit
now
-
routing
is
complete
,
record_worker_full
was called in select_worker
.
//
This unblocks set_phase(Decode) in the main task
without waiting for prefill output.
drop
(
phase_permit
);
//
Release the phase barrier
now
that
routing complete
d and
record_worker_full
already ran
.
//
Decode may proceed
without waiting for prefill output
streaming to finish
.
drop
(
phase_
transition_
permit
);
let
Some
(
first_output
)
=
prefill_response
.next
()
.await
else
{
return
Err
(
PrefillError
::
PrefillError
(
...
...
@@ -468,13 +215,13 @@ impl PrefillRouter {
///
/// Uses direct routing to target_worker when specified (for non-KV modes with bootstrap optimization).
///
/// The `phase_permit` is passed to the spawned task and
dropp
ed after
the first output,
/// allowing the main task's `set_phase(Decode)` to proceed.
fn
spawn_prefill_task
(
/// The `phase_
transition_
permit` is passed to the spawned task and
releas
ed after
routing
///
completes,
allowing the main task's `set_phase(Decode)` to proceed.
pub
(
super
)
fn
spawn_prefill_task
(
&
self
,
prefill_request
:
SingleIn
<
PreprocessedRequest
>
,
target_worker
:
Option
<
u64
>
,
phase_permit
:
OwnedSemaphorePermit
,
phase_
transition_
permit
:
OwnedSemaphorePermit
,
)
{
let
router
=
self
.prefill_router
.get
()
.cloned
();
// Capture current span to propagate trace context to the spawned task
...
...
@@ -486,7 +233,7 @@ impl PrefillRouter {
router
,
prefill_request
,
target_worker
,
Some
(
phase_permit
),
Some
(
phase_
transition_
permit
),
)
.await
{
...
...
@@ -507,13 +254,6 @@ impl PrefillRouter {
///
/// This is the shared worker selection logic used by both `resolve_prefill_worker`
/// and `query_route`.
/// Register externally-provided workers in the prefill router's slot tracker.
pub
fn
register_workers
(
&
self
,
worker_ids
:
&
HashSet
<
WorkerId
>
)
{
if
let
Some
(
InnerPrefillRouter
::
KvRouter
(
r
))
=
self
.prefill_router
.get
()
{
r
.chooser
.register_workers
(
worker_ids
);
}
}
pub
async
fn
query_prefill_worker
(
&
self
,
token_ids
:
&
[
u32
],
...
...
@@ -553,194 +293,30 @@ impl PrefillRouter {
r
.peek_next_worker
()
}
.ok_or_else
(||
anyhow
::
anyhow!
(
"No workers available for prefill"
))
?
;
Ok
((
worker_id
,
u32
::
MAX
))
Ok
((
worker_id
,
0
))
}
}
}
/// Register externally-provided workers in the prefill router's slot tracker.
pub
fn
register_workers
(
&
self
,
worker_ids
:
&
HashSet
<
WorkerId
>
)
{
if
let
Some
(
InnerPrefillRouter
::
KvRouter
(
r
))
=
self
.prefill_router
.get
()
{
r
.chooser
.register_workers
(
worker_ids
);
}
}
/// Check if disaggregated mode is currently active (prefill router activated)
pub
fn
is_activated
(
&
self
)
->
bool
{
self
.prefill_router
.get
()
.is_some
()
}
}
impl
Drop
for
PrefillRouter
{
fn
drop
(
&
mut
self
)
{
tracing
::
debug!
(
"Dropping PrefillRouter, cancelling background activation task"
);
self
.cancel_token
.cancel
();
}
}
#[async_trait]
impl
Operator
<
SingleIn
<
PreprocessedRequest
>
,
ManyOut
<
Annotated
<
LLMEngineOutput
>>
,
SingleIn
<
PreprocessedRequest
>
,
ManyOut
<
Annotated
<
LLMEngineOutput
>>
,
>
for
PrefillRouter
{
async
fn
generate
(
&
self
,
request
:
SingleIn
<
PreprocessedRequest
>
,
next
:
ServerStreamingEngine
<
PreprocessedRequest
,
Annotated
<
LLMEngineOutput
>>
,
)
->
Result
<
ManyOut
<
Annotated
<
LLMEngineOutput
>>>
{
// Extract request data while preserving context
let
(
mut
req
,
context
)
=
request
.into_parts
();
let
request_id
=
context
.id
()
.to_string
();
let
engine_ctx
=
context
.context
();
// Save original max_tokens for decode
let
original_max_tokens
=
req
.stop_conditions.max_tokens
;
// If prefill router is not activated (no prefill workers discovered),
// this is aggregated mode — route directly to decode.
// With --enforce-disagg, fail instead of falling back.
if
self
.prefill_router
.get
()
.is_none
()
{
if
self
.enforce_disagg
{
return
Err
(
anyhow
::
anyhow!
(
PrefillError
::
NotActivated
));
}
return
next
.generate
(
context
.map
(|
_
|
req
))
.await
;
}
// Ensure tracker exists for routing decisions in disaggregated mode.
// Create one if not provided by the upstream DeltaGenerator.
if
req
.tracker
.is_none
()
{
req
.tracker
=
Some
(
Arc
::
new
(
RequestTracker
::
new
()));
}
let
tracker
=
req
.tracker
.as_ref
()
.unwrap
();
let
prefill_phase_permit
=
tracker
.set_phase
(
RequestPhase
::
Prefill
)
.await
;
// Prepare prefill request with max_tokens = 1 (clone after tracker is set)
let
mut
prefill_req
=
req
.clone
();
prefill_req
.stop_conditions.max_tokens
=
Some
(
1
);
// Try to resolve prefill worker upfront: if we can get bootstrap info early,
// spawn prefill in background and proceed to decode immediately.
let
preselected_worker
=
prefill_req
.routing
.as_ref
()
.and_then
(|
r
|
r
.prefill_worker_id
);
if
self
.router_mode
.is_direct_routing
()
&&
preselected_worker
.is_none
()
{
return
Err
(
anyhow
::
anyhow!
(
"Prefill worker ID required in Direct routing mode but none found in request.
\
Expected prefill_worker_id to be set via x-prefill-instance-id header by external router (e.g., EPP)."
));
}
let
prefill_result
=
async
{
if
let
Some
((
worker_id
,
dp_rank
,
bootstrap_info
))
=
self
.resolve_prefill_worker
(
&
prefill_req
,
preselected_worker
)
.await
{
// Bootstrap optimization path: spawn prefill in background
// We successfully used the peeked worker, so we must now advance the router state
// to ensure the next request gets a different worker.
if
!
self
.router_mode
.is_kv_routing
()
&&
let
Some
(
router
)
=
self
.prefill_router
.get
()
{
router
.select_next_worker
();
}
let
routing
=
prefill_req
.routing_mut
();
routing
.prefill_worker_id
=
Some
(
worker_id
);
routing
.dp_rank
=
Some
(
dp_rank
);
prefill_req
.bootstrap_info
=
Some
(
bootstrap_info
.clone
());
let
prefill_context
=
Context
::
with_id
(
prefill_req
,
request_id
.clone
());
engine_ctx
.link_child
(
prefill_context
.context
());
// Pass phase permit to spawned task - it drops after first output (record_worker_full complete)
// This allows set_phase(Decode) below to proceed only after prefill routing is done
self
.spawn_prefill_task
(
prefill_context
,
Some
(
worker_id
),
prefill_phase_permit
);
Ok
(
PrefillOutcome
::
Bootstrap
(
bootstrap_info
))
}
else
{
// Original prefill path: wait for prefill to complete
tracing
::
debug!
(
"Using original prefill path"
);
// Drop the phase permit - we wait for completion
// so there's no race with set_phase(Decode) below
drop
(
prefill_phase_permit
);
let
prefill_context
=
Context
::
with_id
(
prefill_req
,
request_id
.clone
());
engine_ctx
.link_child
(
prefill_context
.context
());
// In Direct mode, pass preselected_worker so execute_prefill uses
// router.direct() instead of router.generate() (which bails in Direct mode).
let
(
result
,
_
worker_info
)
=
Self
::
execute_prefill
(
self
.prefill_router
.get
()
.cloned
(),
prefill_context
,
preselected_worker
,
None
,
)
.await
?
;
Ok
(
PrefillOutcome
::
Completed
(
result
))
}
}
.await
;
// Abort if cancelled during prefill
if
engine_ctx
.is_stopped
()
||
engine_ctx
.is_killed
()
{
tracing
::
debug!
(
"Abort entering decode after context is stopped or killed"
);
return
Err
(
anyhow
::
anyhow!
(
"Context id {} is stopped or killed"
,
engine_ctx
.id
()
));
}
// Handle prefill result
match
prefill_result
{
Ok
(
outcome
)
=>
{
tracing
::
debug!
(
"Prefill completed, proceeding to decode"
);
// Set phase to Decode for the decode request.
// In bootstrap path, this blocks until the spawned prefill task drops its permit
// (after first output / record_worker_full completes), ensuring correct phase for routing.
if
let
Some
(
ref
tracker
)
=
req
.tracker
{
let
_
decode_permit
=
tracker
.set_phase
(
RequestPhase
::
Decode
)
.await
;
// Permit is dropped immediately - decode proceeds, no need to hold it
}
let
mut
decode_req
=
req
;
match
outcome
{
PrefillOutcome
::
Bootstrap
(
info
)
=>
{
decode_req
.bootstrap_info
=
Some
(
info
);
}
PrefillOutcome
::
Completed
(
result
)
=>
{
decode_req
.prefill_result
=
Some
(
result
);
}
}
// Restore original max_tokens for decode
decode_req
.stop_conditions.max_tokens
=
original_max_tokens
;
// Set router_config_override for decode:
// - overlap_score_weight = 0 (no KV cache overlap scoring for decode)
// - assume_kv_reuse = false (generate random hashes since decode workers
// may already have blocks cached from prefill transfer)
let
existing_override
=
decode_req
.router_config_override
.take
();
decode_req
.router_config_override
=
Some
(
RouterConfigOverride
{
overlap_score_weight
:
Some
(
0.0
),
assume_kv_reuse
:
Some
(
false
),
..
existing_override
.unwrap_or_default
()
});
// Map the modified request through with preserved context
let
decode_request
=
context
.map
(|
_
|
decode_req
);
next
.generate
(
decode_request
)
.await
}
Err
(
PrefillError
::
NotActivated
)
=>
{
tracing
::
error!
(
"Prefill router not activated, failing request"
);
Err
(
anyhow
::
anyhow!
(
PrefillError
::
NotActivated
))
}
Err
(
e
)
=>
{
tracing
::
error!
(
error
=
%
e
,
"Remote prefill failed, failing request"
);
Err
(
anyhow
::
anyhow!
(
e
))
}
}
}
pub
(
super
)
fn
link_child_context
<
T
:
Send
+
Sync
+
'static
>
(
engine_ctx
:
&
Arc
<
dyn
AsyncEngineContext
>
,
request
:
T
,
request_id
:
&
str
,
)
->
Context
<
T
>
{
let
child_context
=
Context
::
with_id
(
request
,
request_id
.to_string
());
engine_ctx
.link_child
(
child_context
.context
());
child_context
}
lib/llm/src/kv_router/prefill_router/inner.rs
0 → 100644
View file @
02b1c58a
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use
std
::
sync
::
Arc
;
use
anyhow
::
Result
;
use
dynamo_runtime
::{
pipeline
::{
AsyncEngine
,
ManyOut
,
PushRouter
,
SingleIn
},
protocols
::
annotated
::
Annotated
,
};
use
crate
::{
kv_router
::
KvPushRouter
,
protocols
::
common
::
llm_backend
::{
LLMEngineOutput
,
PreprocessedRequest
},
};
/// The inner router used by PrefillRouter
#[derive(Clone)]
pub
(
super
)
enum
InnerPrefillRouter
{
/// KV-aware routing using KvPushRouter
KvRouter
(
Arc
<
KvPushRouter
>
),
/// Simple routing (RoundRobin, Random, Direct)
/// Note: Per-worker metrics (active_prefill_tokens, active_decode_blocks) are only
/// available in KV routing mode where the router has actual bookkeeping.
SimpleRouter
(
Arc
<
PushRouter
<
PreprocessedRequest
,
Annotated
<
LLMEngineOutput
>>>
),
}
impl
InnerPrefillRouter
{
/// Generate with optional direct routing to specific worker.
/// For KvRouter, target_worker is ignored since prefill_worker_id is already set on the request.
/// For SimpleRouter, target_worker triggers direct routing via router.direct().
pub
(
super
)
async
fn
generate_to_worker
(
&
self
,
request
:
SingleIn
<
PreprocessedRequest
>
,
target_worker
:
Option
<
u64
>
,
)
->
Result
<
ManyOut
<
Annotated
<
LLMEngineOutput
>>>
{
match
(
self
,
target_worker
)
{
// KvRouter: prefill_worker_id already set on request, KvPushRouter::select_worker uses it
(
InnerPrefillRouter
::
KvRouter
(
router
),
_
)
=>
router
.generate
(
request
)
.await
,
(
InnerPrefillRouter
::
SimpleRouter
(
router
),
Some
(
worker_id
))
=>
{
router
.direct
(
request
,
worker_id
)
.await
}
(
InnerPrefillRouter
::
SimpleRouter
(
router
),
None
)
=>
router
.generate
(
request
)
.await
,
}
}
/// Select next worker (for non-KV modes only)
pub
(
super
)
fn
select_next_worker
(
&
self
)
->
Option
<
u64
>
{
match
self
{
InnerPrefillRouter
::
SimpleRouter
(
router
)
=>
router
.select_next_worker
(),
InnerPrefillRouter
::
KvRouter
(
_
)
=>
None
,
}
}
}
lib/llm/src/kv_router/prefill_router/mod.rs
0 → 100644
View file @
02b1c58a
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use
std
::
sync
::{
Arc
,
OnceLock
};
use
anyhow
::
Result
;
use
tokio_util
::
sync
::
CancellationToken
;
use
dynamo_runtime
::{
pipeline
::{
AsyncEngineContextProvider
,
ManyOut
,
Operator
,
RouterMode
,
ServerStreamingEngine
,
SingleIn
,
async_trait
,
},
protocols
::{
EndpointId
,
annotated
::
Annotated
},
};
use
crate
::{
discovery
::
ModelManager
,
protocols
::
common
::{
llm_backend
::{
LLMEngineOutput
,
PreprocessedRequest
},
timing
::{
RequestPhase
,
RequestTracker
},
},
};
mod
activation
;
mod
execution
;
mod
inner
;
mod
types
;
use
execution
::
link_child_context
;
use
inner
::
InnerPrefillRouter
;
pub
use
types
::
PrefillError
;
use
types
::{
PrefillOutcome
,
PrefillResolveDecision
,
build_decode_router_override
};
/// PrefillRouter is a forward-only operator that sits between Migration and the decode router.
/// It optionally calls a prefill worker before routing to decode, extracting disaggregated_params
/// from the prefill response and injecting them into the decode request.
///
/// Modes:
/// - Query-only: `query_instance_id` annotation present → returns worker IDs without execution
/// - Pre-routed: `prefill_worker_id`/`decode_worker_id` set → routes to specified workers
/// - Normal: Worker IDs determined by router based on KV cache state
pub
struct
PrefillRouter
{
prefill_router
:
OnceLock
<
InnerPrefillRouter
>
,
model_manager
:
Arc
<
ModelManager
>
,
endpoint_id
:
OnceLock
<
EndpointId
>
,
cancel_token
:
CancellationToken
,
router_mode
:
RouterMode
,
enforce_disagg
:
bool
,
/// Model name used to look up the worker monitor for prefill client registration
model_name
:
String
,
/// Namespace used to look up the correct WorkerSet's worker monitor
namespace
:
String
,
is_eagle
:
bool
,
}
impl
Drop
for
PrefillRouter
{
fn
drop
(
&
mut
self
)
{
tracing
::
debug!
(
"Dropping PrefillRouter, cancelling background activation task"
);
self
.cancel_token
.cancel
();
}
}
#[async_trait]
impl
Operator
<
SingleIn
<
PreprocessedRequest
>
,
ManyOut
<
Annotated
<
LLMEngineOutput
>>
,
SingleIn
<
PreprocessedRequest
>
,
ManyOut
<
Annotated
<
LLMEngineOutput
>>
,
>
for
PrefillRouter
{
async
fn
generate
(
&
self
,
request
:
SingleIn
<
PreprocessedRequest
>
,
next
:
ServerStreamingEngine
<
PreprocessedRequest
,
Annotated
<
LLMEngineOutput
>>
,
)
->
Result
<
ManyOut
<
Annotated
<
LLMEngineOutput
>>>
{
// Extract request data while preserving context
let
(
mut
req
,
context
)
=
request
.into_parts
();
let
request_id
=
context
.id
()
.to_string
();
let
engine_ctx
=
context
.context
();
// Save original max_tokens for decode
let
original_max_tokens
=
req
.stop_conditions.max_tokens
;
// If prefill router is not activated (no prefill workers discovered),
// this is aggregated mode — route directly to decode.
// With --enforce-disagg, fail instead of falling back.
if
self
.prefill_router
.get
()
.is_none
()
{
if
self
.enforce_disagg
{
return
Err
(
anyhow
::
anyhow!
(
PrefillError
::
NotActivated
));
}
return
next
.generate
(
context
.map
(|
_
|
req
))
.await
;
}
// Ensure tracker exists for routing decisions in disaggregated mode.
// Create one if not provided by the upstream DeltaGenerator.
if
req
.tracker
.is_none
()
{
req
.tracker
=
Some
(
Arc
::
new
(
RequestTracker
::
new
()));
}
let
tracker
=
req
.tracker
.as_ref
()
.unwrap
();
let
prefill_phase_barrier
=
tracker
.set_phase
(
RequestPhase
::
Prefill
)
.await
;
// Prepare prefill request with max_tokens = 1 (clone after tracker is set)
let
mut
prefill_req
=
req
.clone
();
prefill_req
.stop_conditions.max_tokens
=
Some
(
1
);
// Try to resolve prefill worker upfront: if we can get bootstrap info early,
// spawn prefill in background and proceed to decode immediately.
let
preselected_worker
=
prefill_req
.routing
.as_ref
()
.and_then
(|
r
|
r
.prefill_worker_id
);
if
self
.router_mode
.is_direct_routing
()
&&
preselected_worker
.is_none
()
{
return
Err
(
anyhow
::
anyhow!
(
"Prefill worker ID required in Direct routing mode but none found in request.
\
Expected prefill_worker_id to be set via x-prefill-instance-id header by external router (e.g., EPP)."
));
}
let
prefill_result
=
match
self
.resolve_prefill_worker
(
&
prefill_req
,
preselected_worker
)
.await
{
PrefillResolveDecision
::
Resolved
{
worker_id
,
dp_rank
,
bootstrap_info
,
}
=>
{
// Bootstrap optimization path: spawn prefill in background
// We successfully used the peeked worker, so we must now advance the router state
// to ensure the next request gets a different worker.
if
!
self
.router_mode
.is_kv_routing
()
&&
let
Some
(
router
)
=
self
.prefill_router
.get
()
{
router
.select_next_worker
();
}
let
routing
=
prefill_req
.routing_mut
();
routing
.prefill_worker_id
=
Some
(
worker_id
);
routing
.dp_rank
=
Some
(
dp_rank
);
prefill_req
.bootstrap_info
=
Some
(
bootstrap_info
.clone
());
let
prefill_context
=
link_child_context
(
&
engine_ctx
,
prefill_req
,
request_id
.as_str
());
// Pass the phase barrier to the spawned task. It is released after routing
// completes so `record_worker_full` finishes before phase changes to Decode.
self
.spawn_prefill_task
(
prefill_context
,
Some
(
worker_id
),
prefill_phase_barrier
);
Ok
(
PrefillOutcome
::
Bootstrap
(
bootstrap_info
))
}
PrefillResolveDecision
::
Unavailable
|
PrefillResolveDecision
::
NotActivated
|
PrefillResolveDecision
::
NoBootstrapEndpoint
=>
{
// Original prefill path: wait for prefill to complete
tracing
::
debug!
(
"Using original prefill path"
);
// Drop the phase barrier because we wait for prefill completion in this task,
// so there is no race with set_phase(Decode) below.
drop
(
prefill_phase_barrier
);
let
prefill_context
=
link_child_context
(
&
engine_ctx
,
prefill_req
,
request_id
.as_str
());
// In Direct mode, pass preselected_worker so execute_prefill uses
// router.direct() instead of router.generate() (which bails in Direct mode).
let
(
result
,
_
worker_info
)
=
Self
::
execute_prefill
(
self
.prefill_router
.get
()
.cloned
(),
prefill_context
,
preselected_worker
,
None
,
)
.await
?
;
Ok
(
PrefillOutcome
::
Completed
(
result
))
}
};
// Abort if cancelled during prefill
if
engine_ctx
.is_stopped
()
||
engine_ctx
.is_killed
()
{
tracing
::
debug!
(
"Abort entering decode after context is stopped or killed"
);
return
Err
(
anyhow
::
anyhow!
(
"Context id {} is stopped or killed"
,
engine_ctx
.id
()
));
}
// Handle prefill result
match
prefill_result
{
Ok
(
outcome
)
=>
{
tracing
::
debug!
(
"Prefill completed, proceeding to decode"
);
// Set phase to Decode for the decode request.
// In bootstrap path, this blocks until the spawned prefill task releases its
// phase barrier after routing completes, ensuring correct worker attribution.
if
let
Some
(
ref
tracker
)
=
req
.tracker
{
let
_
decode_permit
=
tracker
.set_phase
(
RequestPhase
::
Decode
)
.await
;
// Permit is dropped immediately - decode proceeds, no need to hold it
}
let
mut
decode_req
=
req
;
match
outcome
{
PrefillOutcome
::
Bootstrap
(
info
)
=>
{
decode_req
.bootstrap_info
=
Some
(
info
);
}
PrefillOutcome
::
Completed
(
result
)
=>
{
decode_req
.prefill_result
=
Some
(
result
);
}
}
// Restore original max_tokens for decode
decode_req
.stop_conditions.max_tokens
=
original_max_tokens
;
// Set router_config_override for decode:
// - overlap_score_weight = 0 (no KV cache overlap scoring for decode)
// - assume_kv_reuse = false (generate random hashes since decode workers
// may already have blocks cached from prefill transfer)
// - track_prefill_tokens = false (decode router should ignore prompt-side load)
let
existing_override
=
decode_req
.router_config_override
.take
();
decode_req
.router_config_override
=
Some
(
build_decode_router_override
(
existing_override
));
// Map the modified request through with preserved context
let
decode_request
=
context
.map
(|
_
|
decode_req
);
next
.generate
(
decode_request
)
.await
}
Err
(
PrefillError
::
NotActivated
)
=>
{
tracing
::
error!
(
"Prefill router not activated, failing request"
);
Err
(
anyhow
::
anyhow!
(
PrefillError
::
NotActivated
))
}
Err
(
e
)
=>
{
tracing
::
error!
(
error
=
%
e
,
"Remote prefill failed, failing request"
);
Err
(
anyhow
::
anyhow!
(
e
))
}
}
}
}
#[cfg(test)]
mod
tests
{
use
super
::
*
;
use
dynamo_kv_router
::
config
::
RouterConfigOverride
;
#[test]
fn
decode_router_override_disables_overlap_and_prefill_tracking
()
{
let
override_config
=
build_decode_router_override
(
Some
(
RouterConfigOverride
{
router_temperature
:
Some
(
0.7
),
..
Default
::
default
()
}));
assert_eq!
(
override_config
.overlap_score_weight
,
Some
(
0.0
));
assert_eq!
(
override_config
.assume_kv_reuse
,
Some
(
false
));
assert_eq!
(
override_config
.track_prefill_tokens
,
Some
(
false
));
assert_eq!
(
override_config
.router_temperature
,
Some
(
0.7
));
}
}
lib/llm/src/kv_router/prefill_router/types.rs
0 → 100644
View file @
02b1c58a
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use
dynamo_kv_router
::
config
::
RouterConfigOverride
;
use
crate
::
protocols
::
common
::
preprocessor
::{
BootstrapInfo
,
PrefillResult
};
/// Errors that can occur during prefill routing
#[derive(Debug,
thiserror::Error)]
pub
enum
PrefillError
{
/// Prefill router has not been activated yet
#[error(
"Prefill router not yet activated"
)]
NotActivated
,
/// TODO: Separate prefill worker error from prefill router error
/// Error during prefill execution
#[error(
"Prefill execution failed: {0}"
)]
PrefillError
(
String
,
#[source]
Option
<
Box
<
dyn
std
::
error
::
Error
+
Send
+
Sync
+
'static
>>
,
),
/// Disaggregated params not found in prefill response
#[error(
"No disaggregated params in prefill response: {0}"
)]
NoDisaggregatedParams
(
String
),
}
/// Result of the prefill phase in `generate()`.
pub
(
super
)
enum
PrefillOutcome
{
/// Bootstrap optimization: prefill spawned in background, bootstrap info ready
Bootstrap
(
BootstrapInfo
),
/// Synchronous prefill completed with result
Completed
(
PrefillResult
),
}
pub
(
super
)
enum
PrefillResolveDecision
{
Resolved
{
worker_id
:
u64
,
dp_rank
:
u32
,
bootstrap_info
:
BootstrapInfo
,
},
Unavailable
,
NotActivated
,
NoBootstrapEndpoint
,
}
pub
(
super
)
fn
build_decode_router_override
(
existing_override
:
Option
<
RouterConfigOverride
>
,
)
->
RouterConfigOverride
{
RouterConfigOverride
{
overlap_score_weight
:
Some
(
0.0
),
assume_kv_reuse
:
Some
(
false
),
track_prefill_tokens
:
Some
(
false
),
..
existing_override
.unwrap_or_default
()
}
}
lib/llm/src/kv_router/publisher/event_processor.rs
0 → 100644
View file @
02b1c58a
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use
std
::
future
::
Future
;
use
std
::
sync
::
Arc
;
use
std
::
time
::{
Duration
,
Instant
};
use
anyhow
::
Result
;
use
tokio
::
sync
::
mpsc
;
use
tokio_util
::
sync
::
CancellationToken
;
use
dynamo_kv_router
::
RouterEventSink
;
use
dynamo_kv_router
::
indexer
::
LocalKvIndexer
;
use
dynamo_kv_router
::
protocols
::
*
;
use
dynamo_runtime
::
transports
::
event_plane
::
EventPublisher
;
use
dynamo_runtime
::
transports
::
nats
::
NatsQueue
;
use
crate
::
kv_router
::
KV_EVENT_SUBJECT
;
use
super
::{
DEFAULT_MAX_BATCH_BLOCKS
,
kv_publisher_metrics
};
/// Accumulator for in-flight KV cache events that will be merged into a single
/// [`RouterEvent`] before being forwarded to the event sink.
#[derive(Debug)]
pub
(
super
)
struct
BatchingState
{
pub
(
super
)
pending_removed
:
Option
<
KvCacheRemoveData
>
,
pub
(
super
)
pending_stored
:
Option
<
KvCacheStoreData
>
,
pub
(
super
)
next_publish_id
:
u64
,
pub
(
super
)
last_dp_rank
:
u32
,
pub
(
super
)
last_flush_time
:
Instant
,
}
impl
BatchingState
{
pub
(
super
)
fn
new
()
->
Self
{
Self
{
pending_removed
:
None
,
pending_stored
:
None
,
next_publish_id
:
1
,
last_dp_rank
:
0
,
last_flush_time
:
Instant
::
now
(),
}
}
pub
(
super
)
fn
has_pending
(
&
self
)
->
bool
{
self
.pending_removed
.is_some
()
||
self
.pending_stored
.is_some
()
}
pub
(
super
)
fn
pending_block_count
(
&
self
)
->
usize
{
self
.pending_removed
.as_ref
()
.map
(|
r
|
r
.block_hashes
.len
())
.unwrap_or
(
0
)
+
self
.pending_stored
.as_ref
()
.map
(|
s
|
s
.blocks
.len
())
.unwrap_or
(
0
)
}
pub
(
super
)
fn
record_flush_time
(
&
mut
self
)
{
self
.last_flush_time
=
Instant
::
now
();
}
pub
(
super
)
fn
remaining_timeout
(
&
self
,
timeout_ms
:
u64
)
->
Duration
{
let
timeout
=
Duration
::
from_millis
(
timeout_ms
);
let
elapsed
=
self
.last_flush_time
.elapsed
();
if
elapsed
>=
timeout
{
Duration
::
ZERO
}
else
{
timeout
-
elapsed
}
}
pub
(
super
)
fn
is_timeout_elapsed
(
&
self
,
timeout_ms
:
u64
)
->
bool
{
self
.remaining_timeout
(
timeout_ms
)
==
Duration
::
ZERO
}
async
fn
flush
<
P
:
RouterEventSink
+
Send
+
Sync
+
'static
>
(
&
mut
self
,
publisher
:
&
P
,
local_indexer
:
&
Option
<
Arc
<
LocalKvIndexer
>>
,
worker_id
:
u64
,
)
{
if
!
self
.has_pending
()
{
return
;
}
let
id
=
self
.next_publish_id
;
let
dp_rank
=
self
.last_dp_rank
;
if
let
Some
(
data
)
=
self
.pending_removed
.take
()
{
emit
(
publisher
,
local_indexer
,
worker_id
,
KvCacheEvent
{
event_id
:
id
,
data
:
KvCacheEventData
::
Removed
(
data
),
dp_rank
,
},
)
.await
;
}
if
let
Some
(
data
)
=
self
.pending_stored
.take
()
{
emit
(
publisher
,
local_indexer
,
worker_id
,
KvCacheEvent
{
event_id
:
id
,
data
:
KvCacheEventData
::
Stored
(
data
),
dp_rank
,
},
)
.await
;
}
self
.next_publish_id
+=
1
;
self
.record_flush_time
();
}
}
pub
(
super
)
struct
EventPlanePublisher
(
pub
(
super
)
EventPublisher
);
impl
RouterEventSink
for
EventPlanePublisher
{
fn
publish_event
(
&
self
,
event
:
&
RouterEvent
)
->
impl
Future
<
Output
=
Result
<
()
>>
+
Send
{
self
.0
.publish
(
event
)
}
}
pub
(
super
)
struct
JetStreamPublisher
(
pub
(
super
)
NatsQueue
);
impl
RouterEventSink
for
JetStreamPublisher
{
fn
publish_event
(
&
self
,
event
:
&
RouterEvent
)
->
impl
Future
<
Output
=
Result
<
()
>>
+
Send
{
NatsQueue
::
publish_event
(
&
self
.0
,
KV_EVENT_SUBJECT
,
event
)
}
}
async
fn
emit
<
P
:
RouterEventSink
>
(
publisher
:
&
P
,
local_indexer
:
&
Option
<
Arc
<
LocalKvIndexer
>>
,
worker_id
:
u64
,
event
:
KvCacheEvent
,
)
{
let
router_event
=
RouterEvent
::
new
(
worker_id
,
event
);
if
let
Some
(
indexer
)
=
local_indexer
&&
let
Err
(
e
)
=
indexer
.apply_event_with_buffer
(
router_event
.clone
())
.await
{
tracing
::
warn!
(
worker_id
,
error
=
%
e
,
"Failed to apply event to local indexer"
);
}
if
let
Err
(
e
)
=
publisher
.publish_event
(
&
router_event
)
.await
{
tracing
::
error!
(
worker_id
,
error
=
%
e
,
"Failed to publish event"
);
}
}
pub
(
super
)
async
fn
run_event_processor_loop
<
P
:
RouterEventSink
+
Send
+
Sync
+
'static
>
(
publisher
:
P
,
worker_id
:
u64
,
cancellation_token
:
CancellationToken
,
mut
rx
:
mpsc
::
UnboundedReceiver
<
PlacementEvent
>
,
local_indexer
:
Option
<
Arc
<
LocalKvIndexer
>>
,
timeout_ms
:
Option
<
u64
>
,
max_batch_blocks
:
usize
,
)
{
let
mut
batching_state
=
BatchingState
::
new
();
let
mut
last_raw_input_id
:
Option
<
u64
>
=
None
;
loop
{
tokio
::
select!
{
_
=
cancellation_token
.cancelled
()
=>
{
tracing
::
info!
(
"KV Event source received cancellation signal"
);
batching_state
.flush
(
&
publisher
,
&
local_indexer
,
worker_id
)
.await
;
break
;
}
event
=
rx
.recv
()
=>
{
let
Some
(
placement_event
)
=
event
else
{
tracing
::
debug!
(
"Event processor channel closed."
);
batching_state
.flush
(
&
publisher
,
&
local_indexer
,
worker_id
)
.await
;
break
;
};
let
raw_event_id
=
placement_event
.event.event_id
;
if
let
Some
(
last_id
)
=
last_raw_input_id
&&
raw_event_id
>
last_id
+
1
{
let
gap
=
raw_event_id
-
last_id
-
1
;
tracing
::
warn!
(
worker_id
,
last_raw_input_id
=
last_id
,
raw_event_id
,
gap
,
"Input event gap detected: raw events dropped before batching"
);
if
let
Some
(
metrics
)
=
kv_publisher_metrics
()
{
metrics
.increment_engines_dropped_events
(
worker_id
,
gap
);
}
else
{
tracing
::
warn!
(
worker_id
,
gap
,
"Failed to record dropped events metric: metrics not initialized"
);
}
}
last_raw_input_id
=
Some
(
raw_event_id
);
if
!
placement_event
.placement
.is_local_gpu
()
{
tracing
::
trace!
(
worker_id
,
?
placement_event
.placement
,
event_id
=
placement_event
.event.event_id
,
"Skipping non-local-GPU placement event"
);
continue
;
}
let
event
=
placement_event
.event
;
tracing
::
trace!
(
"Event processor for worker_id {} processing event: {:?}"
,
worker_id
,
event
.data
);
let
dp_rank_changed
=
batching_state
.has_pending
()
&&
event
.dp_rank
!=
batching_state
.last_dp_rank
;
match
event
.data
{
KvCacheEventData
::
Removed
(
data
)
=>
{
if
batching_state
.pending_stored
.is_some
()
||
dp_rank_changed
{
batching_state
.flush
(
&
publisher
,
&
local_indexer
,
worker_id
)
.await
;
}
match
&
mut
batching_state
.pending_removed
{
Some
(
pending
)
=>
pending
.block_hashes
.extend
(
data
.block_hashes
),
None
=>
{
batching_state
.pending_removed
=
Some
(
data
);
}
}
}
KvCacheEventData
::
Stored
(
data
)
=>
{
let
should_flush
=
dp_rank_changed
||
batching_state
.pending_removed
.is_some
()
||
batching_state
.pending_stored
.as_ref
()
.is_some_and
(|
p
|
{
data
.parent_hash
!=
p
.blocks
.last
()
.map
(|
b
|
b
.block_hash
)
});
if
should_flush
{
batching_state
.flush
(
&
publisher
,
&
local_indexer
,
worker_id
)
.await
;
}
match
&
mut
batching_state
.pending_stored
{
Some
(
pending
)
=>
pending
.blocks
.extend
(
data
.blocks
),
None
=>
{
batching_state
.pending_stored
=
Some
(
data
);
}
}
}
KvCacheEventData
::
Cleared
=>
{
batching_state
.flush
(
&
publisher
,
&
local_indexer
,
worker_id
)
.await
;
emit
(
&
publisher
,
&
local_indexer
,
worker_id
,
KvCacheEvent
{
event_id
:
batching_state
.next_publish_id
,
data
:
KvCacheEventData
::
Cleared
,
dp_rank
:
event
.dp_rank
,
},
)
.await
;
batching_state
.next_publish_id
+=
1
;
}
}
batching_state
.last_dp_rank
=
event
.dp_rank
;
if
batching_state
.has_pending
()
&&
(
timeout_ms
.is_none_or
(|
ms
|
batching_state
.is_timeout_elapsed
(
ms
))
||
batching_state
.pending_block_count
()
>
max_batch_blocks
)
{
batching_state
.flush
(
&
publisher
,
&
local_indexer
,
worker_id
)
.await
;
}
}
_
=
tokio
::
time
::
sleep
(
timeout_ms
.map
(|
ms
|
batching_state
.remaining_timeout
(
ms
))
.unwrap_or
(
Duration
::
from_secs
(
3600
))
),
if
timeout_ms
.is_some
()
&&
batching_state
.has_pending
()
=>
{
batching_state
.flush
(
&
publisher
,
&
local_indexer
,
worker_id
)
.await
;
}
}
}
}
pub
(
super
)
async
fn
start_event_processor
<
P
:
RouterEventSink
+
Send
+
Sync
+
'static
>
(
publisher
:
P
,
worker_id
:
u64
,
cancellation_token
:
CancellationToken
,
rx
:
mpsc
::
UnboundedReceiver
<
PlacementEvent
>
,
local_indexer
:
Option
<
Arc
<
LocalKvIndexer
>>
,
batching_timeout_ms
:
Option
<
u64
>
,
)
{
run_event_processor_loop
(
publisher
,
worker_id
,
cancellation_token
,
rx
,
local_indexer
,
batching_timeout_ms
,
DEFAULT_MAX_BATCH_BLOCKS
,
)
.await
}
pub
(
super
)
async
fn
start_event_processor_jetstream
(
publisher
:
NatsQueue
,
worker_id
:
u64
,
cancellation_token
:
CancellationToken
,
rx
:
mpsc
::
UnboundedReceiver
<
PlacementEvent
>
,
local_indexer
:
Option
<
Arc
<
LocalKvIndexer
>>
,
batching_timeout_ms
:
Option
<
u64
>
,
)
{
run_event_processor_loop
(
JetStreamPublisher
(
publisher
),
worker_id
,
cancellation_token
,
rx
,
local_indexer
,
batching_timeout_ms
,
DEFAULT_MAX_BATCH_BLOCKS
,
)
.await
}
lib/llm/src/kv_router/publisher/mod.rs
0 → 100644
View file @
02b1c58a
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use
std
::
sync
::
Arc
;
use
std
::
sync
::
OnceLock
;
use
std
::
sync
::
atomic
::{
AtomicU64
,
Ordering
};
use
anyhow
::
Result
;
use
tokio
::
sync
::
mpsc
;
use
tokio_util
::
sync
::
CancellationToken
;
use
dynamo_kv_router
::
indexer
::{
KvIndexerMetrics
,
LocalKvIndexer
};
use
dynamo_kv_router
::
protocols
::
*
;
pub
use
dynamo_kv_router
::
zmq_wire
::
create_stored_blocks
;
#[cfg(test)]
use
dynamo_kv_router
::
zmq_wire
::
*
;
use
dynamo_runtime
::
config
::
environment_names
::
nats
as
env_nats
;
use
dynamo_runtime
::
metrics
::
MetricsHierarchy
;
use
dynamo_runtime
::
metrics
::
prometheus_names
::
kv_publisher
;
use
dynamo_runtime
::
traits
::
DistributedRuntimeProvider
;
use
dynamo_runtime
::{
component
::
Component
,
transports
::
nats
::{
NatsQueue
,
Slug
},
};
use
crate
::
kv_router
::{
KV_EVENT_SUBJECT
,
WORKER_KV_INDEXER_BUFFER_SIZE
,
worker_query
::
start_worker_kv_query_endpoint
,
};
mod
event_processor
;
#[cfg(test)]
mod
tests
;
mod
worker_metrics
;
mod
zmq_listener
;
#[cfg(test)]
use
event_processor
::{
BatchingState
,
run_event_processor_loop
};
use
event_processor
::{
EventPlanePublisher
,
start_event_processor
,
start_event_processor_jetstream
,
};
pub
use
worker_metrics
::
WorkerMetricsPublisher
;
use
zmq_listener
::
start_zmq_listener
;
#[cfg(test)]
use
zmq_listener
::{
INITIAL_BACKOFF_MS
,
MAX_BACKOFF_EXPONENT
,
MAX_BACKOFF_MS
,
MAX_CONSECUTIVE_ERRORS
,
calculate_backoff_ms
,
};
const
MAX_BATCHING_TIMEOUT_MS
:
u64
=
15_000
;
pub
const
DEFAULT_BATCHING_TIMEOUT_MS
:
Option
<
u64
>
=
None
;
const
DEFAULT_MAX_BATCH_BLOCKS
:
usize
=
128
;
/// Helper function to create a KV stream name from a component and subject.
///
/// Generates a slugified stream name in the format:
/// `namespace-{namespace}-component-{component}-{subject}`
fn
create_kv_stream_name
(
component
:
&
Component
,
subject
:
&
str
)
->
String
{
Slug
::
slugify
(
&
format!
(
"namespace.{}.component.{}.{}"
,
component
.namespace
()
.name
(),
component
.name
(),
subject
))
.to_string
()
.replace
(
"_"
,
"-"
)
}
/// Metrics for the KV publisher, created via the MetricsHierarchy API.
/// This provides automatic `dynamo_namespace`, `dynamo_component`, and other
/// hierarchy labels for free.
pub
(
super
)
struct
KvPublisherMetrics
{
/// Total number of raw events dropped by engines before reaching publisher
pub
engines_dropped_events_total
:
prometheus
::
IntCounterVec
,
}
static
KV_PUBLISHER_METRICS
:
OnceLock
<
Arc
<
KvPublisherMetrics
>>
=
OnceLock
::
new
();
impl
KvPublisherMetrics
{
/// Create from a Component, memoized in a static OnceLock.
/// Uses the MetricsHierarchy API which auto-prepends `dynamo_component_`,
/// injects hierarchy labels, and registers with the DRT `MetricsRegistry`.
pub
fn
from_component
(
component
:
&
Component
)
->
Arc
<
Self
>
{
KV_PUBLISHER_METRICS
.get_or_init
(||
{
let
metrics
=
component
.metrics
();
match
metrics
.create_intcountervec
(
kv_publisher
::
ENGINES_DROPPED_EVENTS_TOTAL
,
"Total number of raw events dropped by engines before reaching publisher (detected via event_id gaps)"
,
&
[
"worker_id"
],
&
[],
)
{
Ok
(
engines_dropped_events_total
)
=>
{
Arc
::
new
(
Self
{
engines_dropped_events_total
})
}
Err
(
e
)
=>
{
tracing
::
warn!
(
"Failed to create kv_publisher metrics from component: {}. Using unregistered metrics as fallback."
,
e
);
Arc
::
new
(
Self
::
new_unregistered
())
}
}
})
.clone
()
}
/// Creates unregistered metrics for use when the MetricsRegistry is not available.
/// This is used as a fallback when metric creation fails.
pub
fn
new_unregistered
()
->
Self
{
Self
{
engines_dropped_events_total
:
prometheus
::
IntCounterVec
::
new
(
prometheus
::
Opts
::
new
(
kv_publisher
::
ENGINES_DROPPED_EVENTS_TOTAL
,
"Total number of raw events dropped by engines before reaching publisher (detected via event_id gaps)"
,
),
&
[
"worker_id"
],
)
.expect
(
"failed to create engines_dropped_events_total counter"
),
}
}
/// Increment the engines dropped events counter by the given amount.
pub
fn
increment_engines_dropped_events
(
&
self
,
worker_id
:
u64
,
count
:
u64
)
{
self
.engines_dropped_events_total
.with_label_values
(
&
[
&
worker_id
.to_string
()])
.inc_by
(
count
);
}
}
fn
kv_publisher_metrics
()
->
Option
<
Arc
<
KvPublisherMetrics
>>
{
KV_PUBLISHER_METRICS
.get
()
.cloned
()
}
/// Configure the source of KV events.
/// Currently, only ZMQ is supported.
pub
enum
KvEventSourceConfig
{
Zmq
{
endpoint
:
String
,
topic
:
String
},
}
enum
KvEventSource
{
Zmq
{
zmq_handle
:
tokio
::
task
::
JoinHandle
<
()
>
,
},
}
impl
KvEventSource
{
fn
start
(
component
:
Component
,
worker_id
:
WorkerId
,
kv_block_size
:
u32
,
source_config
:
KvEventSourceConfig
,
cancellation_token
:
CancellationToken
,
tx
:
mpsc
::
UnboundedSender
<
PlacementEvent
>
,
next_event_id
:
Arc
<
AtomicU64
>
,
)
->
Result
<
Self
>
{
match
source_config
{
KvEventSourceConfig
::
Zmq
{
endpoint
,
topic
}
=>
{
let
zmq_handle
=
component
.drt
()
.runtime
()
.secondary
()
.spawn
(
start_zmq_listener
(
endpoint
,
topic
,
worker_id
,
tx
,
cancellation_token
.clone
(),
kv_block_size
,
next_event_id
,
));
Ok
(
KvEventSource
::
Zmq
{
zmq_handle
})
}
}
}
fn
shutdown
(
&
self
)
{
match
self
{
KvEventSource
::
Zmq
{
zmq_handle
}
=>
{
zmq_handle
.abort
();
}
}
}
}
/// A publisher of KV events.
pub
struct
KvEventPublisher
{
/// The size of the KV block.
kv_block_size
:
u32
,
/// The source of KV events.
/// Can be `None` if all events provided through [`KvEventPublisher::publish`].
source
:
Option
<
KvEventSource
>
,
/// The cancellation token.
cancellation_token
:
CancellationToken
,
/// The ID of the local worker emitting placement events.
worker_id
:
WorkerId
,
/// The channel to send events to.
tx
:
mpsc
::
UnboundedSender
<
PlacementEvent
>
,
/// Internal monotonic event ID counter. Shared with the ZMQ listener if present.
next_event_id
:
Arc
<
AtomicU64
>
,
}
impl
KvEventPublisher
{
pub
fn
new
(
component
:
Component
,
kv_block_size
:
u32
,
source_config
:
Option
<
KvEventSourceConfig
>
,
)
->
Result
<
Self
>
{
Self
::
new_with_local_indexer
(
component
,
kv_block_size
,
source_config
,
false
,
0
,
DEFAULT_BATCHING_TIMEOUT_MS
,
)
}
pub
fn
new_with_local_indexer
(
component
:
Component
,
kv_block_size
:
u32
,
source_config
:
Option
<
KvEventSourceConfig
>
,
enable_local_indexer
:
bool
,
dp_rank
:
DpRank
,
batching_timeout_ms
:
Option
<
u64
>
,
)
->
Result
<
Self
>
{
let
cancellation_token
=
CancellationToken
::
new
();
let
batching_timeout_ms
=
batching_timeout_ms
.filter
(|
&
ms
|
{
if
ms
>
MAX_BATCHING_TIMEOUT_MS
{
tracing
::
warn!
(
requested_ms
=
ms
,
max_ms
=
MAX_BATCHING_TIMEOUT_MS
,
"batching_timeout_ms too high, capping to 15s"
);
}
ms
>
0
})
.map
(|
ms
|
ms
.min
(
MAX_BATCHING_TIMEOUT_MS
));
let
(
tx
,
rx
)
=
mpsc
::
unbounded_channel
::
<
PlacementEvent
>
();
let
worker_id
=
component
.drt
()
.connection_id
();
KvPublisherMetrics
::
from_component
(
&
component
);
let
component_name
=
component
.name
();
tracing
::
info!
(
"Initializing KvEventPublisher for worker {worker_id} in component {component_name}"
);
if
enable_local_indexer
{
tracing
::
info!
(
"LocalKvIndexer enabled for worker {worker_id} in component {component_name}"
);
}
let
next_event_id
=
Arc
::
new
(
AtomicU64
::
new
(
0
));
let
mut
source
=
None
;
if
let
Some
(
config
)
=
source_config
{
source
=
Some
(
KvEventSource
::
start
(
component
.clone
(),
worker_id
,
kv_block_size
,
config
,
cancellation_token
.clone
(),
tx
.clone
(),
next_event_id
.clone
(),
)
?
);
}
let
local_indexer
=
if
enable_local_indexer
{
let
metrics
=
Arc
::
new
(
KvIndexerMetrics
::
new_unregistered
());
Some
(
Arc
::
new
(
LocalKvIndexer
::
new
(
cancellation_token
.clone
(),
kv_block_size
,
metrics
,
WORKER_KV_INDEXER_BUFFER_SIZE
,
)))
}
else
{
None
};
let
_
local_indexer_query_handle
=
local_indexer
.as_ref
()
.map
(|
local_indexer_ref
|
{
let
component
=
component
.clone
();
let
local_indexer
=
local_indexer_ref
.clone
();
component
.drt
()
.runtime
()
.secondary
()
.spawn
(
start_worker_kv_query_endpoint
(
component
,
worker_id
,
dp_rank
,
local_indexer
,
))
});
let
cancellation_token_clone
=
cancellation_token
.clone
();
let
local_indexer_clone
=
local_indexer
.clone
();
if
enable_local_indexer
{
tracing
::
info!
(
"Using event plane for KV event publishing (local_indexer mode)"
);
let
component_clone
=
component
.clone
();
component
.drt
()
.runtime
()
.secondary
()
.spawn
(
async
move
{
let
event_publisher
=
match
dynamo_runtime
::
transports
::
event_plane
::
EventPublisher
::
for_component
(
&
component_clone
,
KV_EVENT_SUBJECT
,
)
.await
{
Ok
(
publisher
)
=>
publisher
,
Err
(
e
)
=>
{
tracing
::
error!
(
"Failed to create event publisher: {}"
,
e
);
return
;
}
};
start_event_processor
(
EventPlanePublisher
(
event_publisher
),
worker_id
,
cancellation_token_clone
,
rx
,
local_indexer_clone
,
batching_timeout_ms
,
)
.await
});
}
else
{
let
stream_name
=
create_kv_stream_name
(
&
component
,
KV_EVENT_SUBJECT
);
let
nats_server
=
std
::
env
::
var
(
env_nats
::
NATS_SERVER
)
.unwrap_or_else
(|
_
|
"nats://localhost:4222"
.to_string
());
let
mut
nats_queue
=
NatsQueue
::
new_without_consumer
(
stream_name
,
nats_server
,
std
::
time
::
Duration
::
from_secs
(
60
),
);
component
.drt
()
.runtime
()
.secondary
()
.spawn
(
async
move
{
if
let
Err
(
e
)
=
nats_queue
.connect
()
.await
{
tracing
::
error!
(
"Failed to connect NatsQueue: {e}"
);
return
;
}
start_event_processor_jetstream
(
nats_queue
,
worker_id
,
cancellation_token_clone
,
rx
,
local_indexer_clone
,
batching_timeout_ms
,
)
.await
});
}
Ok
(
Self
{
kv_block_size
,
source
,
cancellation_token
,
worker_id
,
tx
,
next_event_id
,
})
}
pub
fn
publish
(
&
self
,
event
:
KvCacheEvent
)
->
Result
<
(),
mpsc
::
error
::
SendError
<
KvCacheEvent
>>
{
let
placement_event
=
PlacementEvent
::
local_gpu
(
self
.worker_id
,
event
);
match
self
.tx
.send
(
placement_event
)
{
Ok
(())
=>
Ok
(()),
Err
(
err
)
=>
Err
(
mpsc
::
error
::
SendError
(
err
.0
.event
)),
}
}
pub
fn
next_event_id
(
&
self
)
->
u64
{
self
.next_event_id
.fetch_add
(
1
,
Ordering
::
SeqCst
)
}
pub
fn
kv_block_size
(
&
self
)
->
u32
{
self
.kv_block_size
}
pub
fn
shutdown
(
&
mut
self
)
{
if
!
self
.cancellation_token
.is_cancelled
()
{
self
.cancellation_token
.cancel
();
}
if
let
Some
(
source
)
=
self
.source
.take
()
{
source
.shutdown
();
}
}
}
impl
Drop
for
KvEventPublisher
{
fn
drop
(
&
mut
self
)
{
self
.shutdown
();
}
}
lib/llm/src/kv_router/publisher.rs
→
lib/llm/src/kv_router/publisher
/tests
.rs
View file @
02b1c58a
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use
super
::
*
;
#[allow(unused_imports)]
use
bytes
::
Bytes
;
#[allow(unused_imports)]
use
dynamo_kv_router
::
RouterEventSink
;
#[allow(unused_imports)]
use
rmp_serde
as
rmps
;
#[allow(unused_imports)]
use
std
::
future
::
Future
;
use
std
::
sync
::
Arc
;
#[allow(unused_imports)]
use
std
::
sync
::
atomic
::{
AtomicU32
,
AtomicU64
,
Ordering
};
use
std
::
time
::{
Duration
,
Instant
};
use
anyhow
::
Result
;
use
rmp_serde
as
rmps
;
use
tokio
::
sync
::
mpsc
;
use
tokio_util
::
sync
::
CancellationToken
;
use
zeromq
::{
Socket
,
SocketRecv
,
SubSocket
};
use
dynamo_runtime
::
metrics
::
MetricsHierarchy
;
use
dynamo_runtime
::
traits
::
DistributedRuntimeProvider
;
use
dynamo_runtime
::
transports
::
event_plane
::
EventPublisher
;
use
dynamo_runtime
::{
component
::{
Component
,
Namespace
},
transports
::
nats
::{
NatsQueue
,
Slug
},
};
/// Helper function to create a KV stream name from a component and subject.
///
/// Generates a slugified stream name in the format:
/// `namespace-{namespace}-component-{component}-{subject}`
fn
create_kv_stream_name
(
component
:
&
Component
,
subject
:
&
str
)
->
String
{
Slug
::
slugify
(
&
format!
(
"namespace.{}.component.{}.{}"
,
component
.namespace
()
.name
(),
component
.name
(),
subject
))
.to_string
()
.replace
(
"_"
,
"-"
)
}
use
dynamo_kv_router
::
indexer
::{
KvIndexerMetrics
,
LocalKvIndexer
};
use
dynamo_kv_router
::
protocols
::
*
;
pub
use
dynamo_kv_router
::
zmq_wire
::
create_stored_blocks
;
use
dynamo_kv_router
::
zmq_wire
::
*
;
use
crate
::
kv_router
::{
KV_EVENT_SUBJECT
,
KV_METRICS_SUBJECT
,
WORKER_KV_INDEXER_BUFFER_SIZE
,
worker_query
::
start_worker_kv_query_endpoint
,
};
use
dynamo_runtime
::
config
::
environment_names
::
nats
as
env_nats
;
// Error handling configuration for ZMQ operations
const
INITIAL_BACKOFF_MS
:
u64
=
10
;
const
MAX_BACKOFF_MS
:
u64
=
5000
;
const
MAX_CONSECUTIVE_ERRORS
:
u32
=
10
;
const
MAX_BACKOFF_EXPONENT
:
u32
=
8
;
// Cap at 2^8 = 256x multiplier to prevent overflow
// Batching configuration
const
MAX_BATCHING_TIMEOUT_MS
:
u64
=
15_000
;
// 15 seconds, prevents misconfiguration
pub
const
DEFAULT_BATCHING_TIMEOUT_MS
:
Option
<
u64
>
=
None
;
// disabled by default
const
DEFAULT_MAX_BATCH_BLOCKS
:
usize
=
128
;
// Max blocks to batch before flushing
// ---------------------------------------------------------------------------
// Engines dropped events metric
// ---------------------------------------------------------------------------
use
std
::
sync
::
OnceLock
;
use
dynamo_runtime
::
metrics
::
prometheus_names
::
kv_publisher
;
/// Metrics for the KV publisher, created via the MetricsHierarchy API.
/// This provides automatic `dynamo_namespace`, `dynamo_component`, and other
/// hierarchy labels for free.
pub
struct
KvPublisherMetrics
{
/// Total number of raw events dropped by engines before reaching publisher
pub
engines_dropped_events_total
:
prometheus
::
IntCounterVec
,
}
static
KV_PUBLISHER_METRICS
:
OnceLock
<
Arc
<
KvPublisherMetrics
>>
=
OnceLock
::
new
();
impl
KvPublisherMetrics
{
/// Create from a Component, memoized in a static OnceLock.
/// Uses the MetricsHierarchy API which auto-prepends `dynamo_component_`,
/// injects hierarchy labels, and registers with the DRT `MetricsRegistry`.
pub
fn
from_component
(
component
:
&
Component
)
->
Arc
<
Self
>
{
KV_PUBLISHER_METRICS
.get_or_init
(||
{
let
metrics
=
component
.metrics
();
match
metrics
.create_intcountervec
(
kv_publisher
::
ENGINES_DROPPED_EVENTS_TOTAL
,
"Total number of raw events dropped by engines before reaching publisher (detected via event_id gaps)"
,
&
[
"worker_id"
],
&
[],
)
{
Ok
(
engines_dropped_events_total
)
=>
{
Arc
::
new
(
Self
{
engines_dropped_events_total
})
}
Err
(
e
)
=>
{
tracing
::
warn!
(
"Failed to create kv_publisher metrics from component: {}. Using unregistered metrics as fallback."
,
e
);
Arc
::
new
(
Self
::
new_unregistered
())
}
}
})
.clone
()
}
/// Creates unregistered metrics for use when the MetricsRegistry is not available.
/// This is used as a fallback when metric creation fails.
pub
fn
new_unregistered
()
->
Self
{
Self
{
engines_dropped_events_total
:
prometheus
::
IntCounterVec
::
new
(
prometheus
::
Opts
::
new
(
kv_publisher
::
ENGINES_DROPPED_EVENTS_TOTAL
,
"Total number of raw events dropped by engines before reaching publisher (detected via event_id gaps)"
,
),
&
[
"worker_id"
],
)
.expect
(
"failed to create engines_dropped_events_total counter"
),
}
}
/// Increment the engines dropped events counter by the given amount.
pub
fn
increment_engines_dropped_events
(
&
self
,
worker_id
:
u64
,
count
:
u64
)
{
self
.engines_dropped_events_total
.with_label_values
(
&
[
&
worker_id
.to_string
()])
.inc_by
(
count
);
}
}
/// Get the KV publisher metrics if initialized.
fn
kv_publisher_metrics
()
->
Option
<
Arc
<
KvPublisherMetrics
>>
{
KV_PUBLISHER_METRICS
.get
()
.cloned
()
}
// -------------------------------------------------------------------------
// Batching State -----------------------------------------------------------
// -------------------------------------------------------------------------
/// Accumulator for in-flight KV cache events that will be merged into a single
/// [`RouterEvent`] before being forwarded to the event sink.
#[derive(Debug)]
struct
BatchingState
{
/// Block hashes accumulating for the next Removed event.
pending_removed
:
Option
<
KvCacheRemoveData
>
,
/// Blocks accumulating for the next Stored event.
pending_stored
:
Option
<
KvCacheStoreData
>
,
/// Monotonic published-batch counter. Increments by 1 per flush so downstream
/// consumers always see consecutive event IDs, regardless of how many raw source
/// events were merged into the batch.
next_publish_id
:
u64
,
/// dp_rank of the events in the current pending batch.
/// A change signals that the batch must be flushed before accumulating further.
last_dp_rank
:
u32
,
/// When we last flushed (or initialized). Used to detect stale pending data:
/// if a new event arrives after a long idle period (exceeding timeout),
/// we flush immediately for lower latency on sparse important events.
last_flush_time
:
Instant
,
}
impl
BatchingState
{
fn
new
()
->
Self
{
Self
{
pending_removed
:
None
,
pending_stored
:
None
,
next_publish_id
:
1
,
last_dp_rank
:
0
,
last_flush_time
:
Instant
::
now
(),
}
}
fn
has_pending
(
&
self
)
->
bool
{
self
.pending_removed
.is_some
()
||
self
.pending_stored
.is_some
()
}
fn
pending_block_count
(
&
self
)
->
usize
{
self
.pending_removed
.as_ref
()
.map
(|
r
|
r
.block_hashes
.len
())
.unwrap_or
(
0
)
+
self
.pending_stored
.as_ref
()
.map
(|
s
|
s
.blocks
.len
())
.unwrap_or
(
0
)
}
/// Records that a flush just happened. Called after every flush to track
/// idle periods for stale-data detection.
fn
record_flush_time
(
&
mut
self
)
{
self
.last_flush_time
=
Instant
::
now
();
}
/// Returns the time remaining in the current batch window (zero if already elapsed).
fn
remaining_timeout
(
&
self
,
timeout_ms
:
u64
)
->
Duration
{
let
timeout
=
Duration
::
from_millis
(
timeout_ms
);
let
elapsed
=
self
.last_flush_time
.elapsed
();
if
elapsed
>=
timeout
{
Duration
::
ZERO
}
else
{
timeout
-
elapsed
}
}
/// Returns `true` when the batch window has elapsed (or `timeout_ms` is zero).
fn
is_timeout_elapsed
(
&
self
,
timeout_ms
:
u64
)
->
bool
{
self
.remaining_timeout
(
timeout_ms
)
==
Duration
::
ZERO
}
}
// -------------------------------------------------------------------------
// KV Event Publishers -----------------------------------------------------
// -------------------------------------------------------------------------
/// Configure the source of KV events.
/// Currently, only ZMQ is supported.
pub
enum
KvEventSourceConfig
{
Zmq
{
endpoint
:
String
,
topic
:
String
},
}
/// The source of KV events.
enum
KvEventSource
{
Zmq
{
zmq_handle
:
tokio
::
task
::
JoinHandle
<
()
>
,
},
}
impl
KvEventSource
{
/// Start the event source from a [`KvEventSourceConfig`].
fn
start
(
component
:
Component
,
worker_id
:
WorkerId
,
kv_block_size
:
u32
,
source_config
:
KvEventSourceConfig
,
cancellation_token
:
CancellationToken
,
tx
:
mpsc
::
UnboundedSender
<
PlacementEvent
>
,
next_event_id
:
Arc
<
AtomicU64
>
,
)
->
Result
<
Self
>
{
match
source_config
{
KvEventSourceConfig
::
Zmq
{
endpoint
,
topic
}
=>
{
let
zmq_handle
=
component
.drt
()
.runtime
()
.secondary
()
.spawn
(
start_zmq_listener
(
endpoint
,
topic
,
worker_id
,
tx
,
cancellation_token
.clone
(),
kv_block_size
,
next_event_id
,
));
Ok
(
KvEventSource
::
Zmq
{
zmq_handle
})
}
}
}
fn
shutdown
(
&
self
)
{
match
self
{
KvEventSource
::
Zmq
{
zmq_handle
}
=>
{
zmq_handle
.abort
();
}
}
}
}
/// A publisher of KV events.
pub
struct
KvEventPublisher
{
/// The size of the KV block.
kv_block_size
:
u32
,
/// The source of KV events.
/// Can be `None` if all events provided through [`KvEventPublisher::publish`].
source
:
Option
<
KvEventSource
>
,
/// The cancellation token.
cancellation_token
:
CancellationToken
,
/// The ID of the local worker emitting placement events.
worker_id
:
WorkerId
,
/// The channel to send events to.
tx
:
mpsc
::
UnboundedSender
<
PlacementEvent
>
,
/// Internal monotonic event ID counter - ensures each event gets a unique, incrementing ID.
/// Shared with the ZMQ listener (if any) to maintain consistency.
next_event_id
:
Arc
<
AtomicU64
>
,
}
impl
KvEventPublisher
{
pub
fn
new
(
component
:
Component
,
kv_block_size
:
u32
,
source_config
:
Option
<
KvEventSourceConfig
>
,
)
->
Result
<
Self
>
{
Self
::
new_with_local_indexer
(
component
,
kv_block_size
,
source_config
,
false
,
0
,
DEFAULT_BATCHING_TIMEOUT_MS
,
)
}
pub
fn
new_with_local_indexer
(
component
:
Component
,
kv_block_size
:
u32
,
source_config
:
Option
<
KvEventSourceConfig
>
,
enable_local_indexer
:
bool
,
dp_rank
:
DpRank
,
batching_timeout_ms
:
Option
<
u64
>
,
)
->
Result
<
Self
>
{
let
cancellation_token
=
CancellationToken
::
new
();
// None = disabled (flush every event); Some(0) normalised to None; Some(ms) = opt-in.
// Cap at MAX_BATCHING_TIMEOUT_MS to prevent misconfiguration.
let
batching_timeout_ms
=
batching_timeout_ms
.filter
(|
&
ms
|
{
if
ms
>
MAX_BATCHING_TIMEOUT_MS
{
tracing
::
warn!
(
requested_ms
=
ms
,
max_ms
=
MAX_BATCHING_TIMEOUT_MS
,
"batching_timeout_ms too high, capping to 15s"
);
}
// if ms is 0, treat as disabled (None)
ms
>
0
})
.map
(|
ms
|
ms
.min
(
MAX_BATCHING_TIMEOUT_MS
));
let
(
tx
,
rx
)
=
mpsc
::
unbounded_channel
::
<
PlacementEvent
>
();
// Infer worker_id from component's connection
let
worker_id
=
component
.drt
()
.connection_id
();
// Initialize the KV publisher metrics via MetricsHierarchy API
// This provides automatic hierarchy labels (dynamo_namespace, dynamo_component, etc.)
KvPublisherMetrics
::
from_component
(
&
component
);
let
component_name
=
component
.name
();
tracing
::
info!
(
"Initializing KvEventPublisher for worker {worker_id} in component {component_name}"
);
if
enable_local_indexer
{
tracing
::
info!
(
"LocalKvIndexer enabled for worker {worker_id} in component {component_name}"
);
}
// Internal monotonic event ID counter - shared with ZMQ listener if any
let
next_event_id
=
Arc
::
new
(
AtomicU64
::
new
(
0
));
// Create our event source (if any)
let
mut
source
=
None
;
if
let
Some
(
config
)
=
source_config
{
source
=
Some
(
KvEventSource
::
start
(
component
.clone
(),
worker_id
,
kv_block_size
,
config
,
cancellation_token
.clone
(),
tx
.clone
(),
next_event_id
.clone
(),
)
?
);
}
// Create local indexer if requested
let
local_indexer
=
if
enable_local_indexer
{
let
metrics
=
Arc
::
new
(
KvIndexerMetrics
::
new_unregistered
());
Some
(
Arc
::
new
(
LocalKvIndexer
::
new
(
cancellation_token
.clone
(),
kv_block_size
,
metrics
,
WORKER_KV_INDEXER_BUFFER_SIZE
,
)))
}
else
{
None
};
// Spawn runtime for router->local indexer comm if requested
let
_
local_indexer_query_handle
=
local_indexer
.as_ref
()
.map
(|
local_indexer_ref
|
{
let
component
=
component
.clone
();
let
local_indexer
=
local_indexer_ref
.clone
();
component
.drt
()
.runtime
()
.secondary
()
.spawn
(
start_worker_kv_query_endpoint
(
component
,
worker_id
,
dp_rank
,
local_indexer
,
))
});
let
cancellation_token_clone
=
cancellation_token
.clone
();
let
local_indexer_clone
=
local_indexer
.clone
();
if
enable_local_indexer
{
// When local indexer is enabled, use the event plane directly.
// EventPublisher handles transport selection (ZMQ or NATS) based on environment.
// Durability is provided by the local indexer's event buffer.
tracing
::
info!
(
"Using event plane for KV event publishing (local_indexer mode)"
);
let
component_clone
=
component
.clone
();
component
.drt
()
.runtime
()
.secondary
()
.spawn
(
async
move
{
let
event_publisher
=
match
EventPublisher
::
for_component
(
&
component_clone
,
KV_EVENT_SUBJECT
)
.await
{
Ok
(
publisher
)
=>
publisher
,
Err
(
e
)
=>
{
tracing
::
error!
(
"Failed to create event publisher: {}"
,
e
);
return
;
}
};
start_event_processor
(
EventPlanePublisher
(
event_publisher
),
worker_id
,
cancellation_token_clone
,
rx
,
local_indexer_clone
,
batching_timeout_ms
,
)
.await
});
}
else
{
// When local indexer is disabled, use JetStream (NatsQueue) for durability.
let
stream_name
=
create_kv_stream_name
(
&
component
,
KV_EVENT_SUBJECT
);
let
nats_server
=
std
::
env
::
var
(
env_nats
::
NATS_SERVER
)
.unwrap_or_else
(|
_
|
"nats://localhost:4222"
.to_string
());
let
mut
nats_queue
=
NatsQueue
::
new_without_consumer
(
stream_name
,
nats_server
,
std
::
time
::
Duration
::
from_secs
(
60
),
// 1 minute timeout
);
component
.drt
()
.runtime
()
.secondary
()
.spawn
(
async
move
{
if
let
Err
(
e
)
=
nats_queue
.connect
()
.await
{
tracing
::
error!
(
"Failed to connect NatsQueue: {e}"
);
return
;
}
start_event_processor_jetstream
(
JetStreamPublisher
(
nats_queue
),
worker_id
,
cancellation_token_clone
,
rx
,
local_indexer_clone
,
batching_timeout_ms
,
)
.await
});
}
Ok
(
Self
{
kv_block_size
,
source
,
cancellation_token
,
worker_id
,
tx
,
next_event_id
,
})
}
pub
fn
publish
(
&
self
,
event
:
KvCacheEvent
)
->
Result
<
(),
mpsc
::
error
::
SendError
<
KvCacheEvent
>>
{
let
placement_event
=
PlacementEvent
::
local_gpu
(
self
.worker_id
,
event
);
match
self
.tx
.send
(
placement_event
)
{
Ok
(())
=>
Ok
(()),
Err
(
err
)
=>
Err
(
mpsc
::
error
::
SendError
(
err
.0
.event
)),
}
}
/// Get and increment the next event ID atomically.
/// Use this to assign monotonically increasing event IDs to events before publishing.
pub
fn
next_event_id
(
&
self
)
->
u64
{
self
.next_event_id
.fetch_add
(
1
,
Ordering
::
SeqCst
)
}
pub
fn
kv_block_size
(
&
self
)
->
u32
{
self
.kv_block_size
}
pub
fn
shutdown
(
&
mut
self
)
{
if
!
self
.cancellation_token
.is_cancelled
()
{
self
.cancellation_token
.cancel
();
}
if
let
Some
(
source
)
=
self
.source
.take
()
{
source
.shutdown
();
}
}
}
impl
Drop
for
KvEventPublisher
{
fn
drop
(
&
mut
self
)
{
self
.shutdown
();
}
}
use
dynamo_kv_router
::
RouterEventSink
;
struct
EventPlanePublisher
(
EventPublisher
);
impl
RouterEventSink
for
EventPlanePublisher
{
fn
publish_event
(
&
self
,
event
:
&
RouterEvent
)
->
impl
Future
<
Output
=
Result
<
()
>>
+
Send
{
self
.0
.publish
(
event
)
}
}
struct
JetStreamPublisher
(
NatsQueue
);
impl
RouterEventSink
for
JetStreamPublisher
{
fn
publish_event
(
&
self
,
event
:
&
RouterEvent
)
->
impl
Future
<
Output
=
Result
<
()
>>
+
Send
{
NatsQueue
::
publish_event
(
&
self
.0
,
KV_EVENT_SUBJECT
,
event
)
}
}
/// Publishes a single [`KvCacheEvent`] to the event sink and, when present, the local indexer.
/// Errors are logged and swallowed so the caller loop can continue uninterrupted.
async
fn
emit
<
P
:
RouterEventSink
>
(
publisher
:
&
P
,
local_indexer
:
&
Option
<
Arc
<
LocalKvIndexer
>>
,
worker_id
:
u64
,
event
:
KvCacheEvent
,
)
{
let
router_event
=
RouterEvent
::
new
(
worker_id
,
event
);
if
let
Some
(
indexer
)
=
local_indexer
&&
let
Err
(
e
)
=
indexer
.apply_event_with_buffer
(
router_event
.clone
())
.await
{
tracing
::
warn!
(
worker_id
,
error
=
%
e
,
"Failed to apply event to local indexer"
);
}
if
let
Err
(
e
)
=
publisher
.publish_event
(
&
router_event
)
.await
{
tracing
::
error!
(
worker_id
,
error
=
%
e
,
"Failed to publish event"
);
}
}
impl
BatchingState
{
/// Publishes any pending batch as a single [`RouterEvent`] and advances the monotonic
/// batch ID. No-ops when nothing is pending, so callers may call unconditionally.
async
fn
flush
<
P
:
RouterEventSink
+
Send
+
Sync
+
'static
>
(
&
mut
self
,
publisher
:
&
P
,
local_indexer
:
&
Option
<
Arc
<
LocalKvIndexer
>>
,
worker_id
:
u64
,
)
{
if
!
self
.has_pending
()
{
return
;
}
let
id
=
self
.next_publish_id
;
let
dp_rank
=
self
.last_dp_rank
;
if
let
Some
(
data
)
=
self
.pending_removed
.take
()
{
emit
(
publisher
,
local_indexer
,
worker_id
,
KvCacheEvent
{
event_id
:
id
,
data
:
KvCacheEventData
::
Removed
(
data
),
dp_rank
,
},
)
.await
;
}
if
let
Some
(
data
)
=
self
.pending_stored
.take
()
{
emit
(
publisher
,
local_indexer
,
worker_id
,
KvCacheEvent
{
event_id
:
id
,
data
:
KvCacheEventData
::
Stored
(
data
),
dp_rank
,
},
)
.await
;
}
// Consecutive batch IDs (1, 2, 3, …) keep downstream gap-detection happy.
self
.next_publish_id
+=
1
;
// Record when we flushed for stale-data detection on next event.
self
.record_flush_time
();
}
}
/// Batching loop: accumulates Removed/Stored events and flushes them as a single
/// [`RouterEvent`] when any of the following conditions are met:
/// - Event type switches (Removed ↔ Stored)
/// - `dp_rank` changes between consecutive events
/// - A `Stored` event's `parent_hash` breaks the sequential chain
/// - The batch window expires (`Some(timeout_ms)`; `None` = disabled, flush every event)
/// - Channel is closed or a cancellation signal is received
async
fn
run_event_processor_loop
<
P
:
RouterEventSink
+
Send
+
Sync
+
'static
>
(
publisher
:
P
,
worker_id
:
u64
,
cancellation_token
:
CancellationToken
,
mut
rx
:
mpsc
::
UnboundedReceiver
<
PlacementEvent
>
,
local_indexer
:
Option
<
Arc
<
LocalKvIndexer
>>
,
timeout_ms
:
Option
<
u64
>
,
max_batch_blocks
:
usize
,
)
{
let
mut
batching_state
=
BatchingState
::
new
();
// Track last raw input event_id for gap detection (dropped events before batching).
// The raw event_id is a globally monotonic counter assigned by the ZMQ listener,
// so any gap here means events were silently dropped (e.g. send error on the channel).
let
mut
last_raw_input_id
:
Option
<
u64
>
=
None
;
loop
{
tokio
::
select!
{
_
=
cancellation_token
.cancelled
()
=>
{
tracing
::
info!
(
"KV Event source received cancellation signal"
);
batching_state
.flush
(
&
publisher
,
&
local_indexer
,
worker_id
)
.await
;
break
;
}
event
=
rx
.recv
()
=>
{
let
Some
(
placement_event
)
=
event
else
{
tracing
::
debug!
(
"Event processor channel closed."
);
batching_state
.flush
(
&
publisher
,
&
local_indexer
,
worker_id
)
.await
;
break
;
};
// Warn if the raw input event_id is not consecutive — events were dropped
// (e.g. channel send error) before they reached the batching layer.
let
raw_event_id
=
placement_event
.event.event_id
;
if
let
Some
(
last_id
)
=
last_raw_input_id
&&
raw_event_id
>
last_id
+
1
{
let
gap
=
raw_event_id
-
last_id
-
1
;
tracing
::
warn!
(
worker_id
,
last_raw_input_id
=
last_id
,
raw_event_id
,
gap
,
"Input event gap detected: raw events dropped before batching"
);
// Increment Prometheus counter for dropped events (if initialized)
if
let
Some
(
metrics
)
=
kv_publisher_metrics
()
{
metrics
.increment_engines_dropped_events
(
worker_id
,
gap
);
}
else
{
tracing
::
warn!
(
worker_id
,
gap
,
"Failed to record dropped events metric: metrics not initialized"
);
}
}
last_raw_input_id
=
Some
(
raw_event_id
);
if
!
placement_event
.placement
.is_local_gpu
()
{
tracing
::
trace!
(
worker_id
,
?
placement_event
.placement
,
event_id
=
placement_event
.event.event_id
,
"Skipping non-local-GPU placement event"
);
continue
;
}
let
event
=
placement_event
.event
;
tracing
::
trace!
(
"Event processor for worker_id {} processing event: {:?}"
,
worker_id
,
event
.data
);
let
dp_rank_changed
=
batching_state
.has_pending
()
&&
event
.dp_rank
!=
batching_state
.last_dp_rank
;
match
event
.data
{
KvCacheEventData
::
Removed
(
data
)
=>
{
if
batching_state
.pending_stored
.is_some
()
||
dp_rank_changed
{
batching_state
.flush
(
&
publisher
,
&
local_indexer
,
worker_id
)
.await
;
}
match
&
mut
batching_state
.pending_removed
{
Some
(
pending
)
=>
pending
.block_hashes
.extend
(
data
.block_hashes
),
None
=>
{
batching_state
.pending_removed
=
Some
(
data
);
}
}
}
KvCacheEventData
::
Stored
(
data
)
=>
{
// Flush if: type switch, dp_rank change, or the chain is broken
// (new event's parent_hash doesn't continue from the last stored block).
let
should_flush
=
dp_rank_changed
||
batching_state
.pending_removed
.is_some
()
||
batching_state
.pending_stored
.as_ref
()
.is_some_and
(|
p
|
{
data
.parent_hash
!=
p
.blocks
.last
()
.map
(|
b
|
b
.block_hash
)
});
if
should_flush
{
batching_state
.flush
(
&
publisher
,
&
local_indexer
,
worker_id
)
.await
;
}
match
&
mut
batching_state
.pending_stored
{
// Only extend blocks; parent_hash stays fixed from the first event.
Some
(
pending
)
=>
pending
.blocks
.extend
(
data
.blocks
),
None
=>
{
batching_state
.pending_stored
=
Some
(
data
);
}
}
}
KvCacheEventData
::
Cleared
=>
{
batching_state
.flush
(
&
publisher
,
&
local_indexer
,
worker_id
)
.await
;
emit
(
&
publisher
,
&
local_indexer
,
worker_id
,
KvCacheEvent
{
event_id
:
batching_state
.next_publish_id
,
data
:
KvCacheEventData
::
Cleared
,
dp_rank
:
event
.dp_rank
,
})
.await
;
batching_state
.next_publish_id
+=
1
;
}
}
// Track dp_rank after the match so in-flight flushes use the old value.
batching_state
.last_dp_rank
=
event
.dp_rank
;
// Flush after every event when disabled (None), or when the window has elapsed,
// or when the batch exceeds the max block count.
// The sleep arm only arms when batching is enabled; this covers the disabled path.
if
batching_state
.has_pending
()
&&
(
timeout_ms
.is_none_or
(|
ms
|
batching_state
.is_timeout_elapsed
(
ms
))
||
batching_state
.pending_block_count
()
>
max_batch_blocks
)
{
batching_state
.flush
(
&
publisher
,
&
local_indexer
,
worker_id
)
.await
;
}
}
// if has some pending and has timeout, and no new events come in, then flush when timeout elapsed to prevent stale events
_
=
tokio
::
time
::
sleep
(
timeout_ms
.map
(|
ms
|
batching_state
.remaining_timeout
(
ms
))
.unwrap_or
(
Duration
::
from_secs
(
3600
))
),
if
timeout_ms
.is_some
()
&&
batching_state
.has_pending
()
=>
{
batching_state
.flush
(
&
publisher
,
&
local_indexer
,
worker_id
)
.await
;
}
}
}
}
/// Batched event processor for ephemeral transports (NATS Core / ZMQ).
async
fn
start_event_processor
<
P
:
RouterEventSink
+
Send
+
Sync
+
'static
>
(
publisher
:
P
,
worker_id
:
u64
,
cancellation_token
:
CancellationToken
,
rx
:
mpsc
::
UnboundedReceiver
<
PlacementEvent
>
,
local_indexer
:
Option
<
Arc
<
LocalKvIndexer
>>
,
batching_timeout_ms
:
Option
<
u64
>
,
)
{
run_event_processor_loop
(
publisher
,
worker_id
,
cancellation_token
,
rx
,
local_indexer
,
batching_timeout_ms
,
DEFAULT_MAX_BATCH_BLOCKS
,
)
.await
}
/// Batched event processor using JetStream (durable).
async
fn
start_event_processor_jetstream
<
P
:
RouterEventSink
+
Send
+
Sync
+
'static
>
(
publisher
:
P
,
worker_id
:
u64
,
cancellation_token
:
CancellationToken
,
rx
:
mpsc
::
UnboundedReceiver
<
PlacementEvent
>
,
local_indexer
:
Option
<
Arc
<
LocalKvIndexer
>>
,
batching_timeout_ms
:
Option
<
u64
>
,
)
{
run_event_processor_loop
(
publisher
,
worker_id
,
cancellation_token
,
rx
,
local_indexer
,
batching_timeout_ms
,
DEFAULT_MAX_BATCH_BLOCKS
,
)
.await
}
/// Calculate exponential backoff duration based on consecutive error count
fn
calculate_backoff_ms
(
consecutive_errors
:
u32
)
->
u64
{
std
::
cmp
::
min
(
INITIAL_BACKOFF_MS
*
2_u64
.pow
(
consecutive_errors
.min
(
MAX_BACKOFF_EXPONENT
)),
MAX_BACKOFF_MS
,
)
}
pub
async
fn
start_zmq_listener
(
zmq_endpoint
:
String
,
zmq_topic
:
String
,
worker_id
:
WorkerId
,
tx
:
mpsc
::
UnboundedSender
<
PlacementEvent
>
,
cancellation_token
:
CancellationToken
,
kv_block_size
:
u32
,
next_event_id
:
Arc
<
AtomicU64
>
,
)
{
tracing
::
debug!
(
"KVEventPublisher connecting to ZMQ endpoint {} (topic '{}')"
,
zmq_endpoint
,
zmq_topic
);
let
warning_count
=
Arc
::
new
(
AtomicU32
::
new
(
0
));
let
mut
socket
=
SubSocket
::
new
();
// Subscribe to the requested topic (empty string == all topics)
if
let
Err
(
e
)
=
socket
.subscribe
(
&
zmq_topic
)
.await
{
tracing
::
error!
(
"Failed to subscribe on ZMQ socket: {}"
,
e
);
return
;
}
// Connect to the ZMQ endpoint. SGLang binds locally, Dynamo connects.
// In multi-node setups, each node runs dynamo.sglang alongside local SGLang ranks,
// so ZMQ connections are always local. NATS handles cross-node event distribution.
if
let
Err
(
e
)
=
socket
.connect
(
&
zmq_endpoint
)
.await
{
tracing
::
error!
(
"Failed to connect ZMQ SUB socket to {zmq_endpoint}: {e}"
);
return
;
}
let
mut
consecutive_errors
=
0u32
;
#[expect(unused_assignments)]
let
mut
exit_reason
=
"unknown"
;
let
mut
messages_processed
=
0u64
;
'main
:
loop
{
tokio
::
select!
{
biased
;
// Check for cancellation
_
=
cancellation_token
.cancelled
()
=>
{
tracing
::
debug!
(
"ZMQ listener received cancellation signal"
);
exit_reason
=
"cancellation token cancelled"
;
break
'main
;
}
// Receive message
msg_result
=
socket
.recv
()
=>
{
let
Ok
(
msg
)
=
msg_result
else
{
let
e
=
msg_result
.unwrap_err
();
consecutive_errors
+=
1
;
if
consecutive_errors
>=
MAX_CONSECUTIVE_ERRORS
{
tracing
::
error!
(
error
=%
e
,
consecutive_errors
=%
consecutive_errors
,
"Too many consecutive ZMQ errors, terminating listener"
);
exit_reason
=
"too many consecutive errors"
;
break
'main
;
}
// Simple exponential backoff with max exponent to prevent overflow
let
backoff_ms
=
calculate_backoff_ms
(
consecutive_errors
);
tracing
::
warn!
(
error
=%
e
,
consecutive_errors
=%
consecutive_errors
,
backoff_ms
=%
backoff_ms
,
"Error reading from ZMQ socket, applying exponential backoff"
);
tokio
::
time
::
sleep
(
Duration
::
from_millis
(
backoff_ms
))
.await
;
continue
;
};
// Reset error count on successful message
consecutive_errors
=
0
;
// We expect multipart frames: [topic, seq, payload]
let
mut
frames
:
Vec
<
Vec
<
u8
>>
=
msg
.into_vec
()
.into_iter
()
.map
(|
frame
|
frame
.to_vec
())
.collect
();
if
frames
.len
()
!=
3
{
tracing
::
warn!
(
"Received unexpected ZMQ frame count: expected 3, actual {}"
,
frames
.len
());
continue
;
}
// Extract the payload and sequence number.
let
payload
=
frames
.pop
()
.unwrap
();
let
seq_bytes
=
frames
.pop
()
.unwrap
();
if
seq_bytes
.len
()
!=
8
{
tracing
::
warn!
(
"Invalid sequence number byte length: expected 8, actual {}"
,
seq_bytes
.len
());
continue
;
}
// Note: We extract the engine's sequence number for logging but use our own
// internal monotonic counter for event_id to ensure per-dp_rank monotonicity
let
engine_seq
=
u64
::
from_be_bytes
(
seq_bytes
.try_into
()
.unwrap
());
// Decode our batch of events.
let
batch_result
=
rmps
::
from_slice
::
<
KvEventBatch
>
(
&
payload
);
let
Ok
(
batch
)
=
batch_result
else
{
let
e
=
batch_result
.unwrap_err
();
tracing
::
warn!
(
"Failed to decode KVEventBatch msgpack: {e}"
);
continue
;
};
tracing
::
trace!
(
"ZMQ listener on {} received batch with {} events (engine_seq={}, dp_rank={})"
,
zmq_endpoint
,
batch
.events
.len
(),
engine_seq
,
batch
.data_parallel_rank
.unwrap_or
(
0
)
);
let
dp_rank
=
batch
.data_parallel_rank
.unwrap_or
(
0
)
.cast_unsigned
();
for
raw_event
in
batch
.events
.into_iter
()
{
// Use shared monotonic event_id counter instead of engine's sequence number
let
event_id
=
next_event_id
.fetch_add
(
1
,
Ordering
::
SeqCst
);
let
worker
=
WorkerWithDpRank
::
new
(
worker_id
,
dp_rank
);
let
event
=
convert_event
(
raw_event
,
event_id
,
kv_block_size
,
worker
,
&
warning_count
,
);
if
tx
.send
(
event
)
.is_err
()
{
tracing
::
warn!
(
"Failed to send message to channel - receiver dropped"
);
exit_reason
=
"channel receiver dropped"
;
break
'main
;
}
messages_processed
+=
1
;
}
}
}
}
tracing
::
debug!
(
"ZMQ listener exiting, reason: {}, messages processed: {}"
,
exit_reason
,
messages_processed
);
}
// -------------------------------------------------------------------------
// Metrics Publishers ------------------------------------------------------
// -------------------------------------------------------------------------
/// Metrics data passed through the channel for NATS publishing
#[derive(Debug,
Clone,
Default,
PartialEq)]
struct
WorkerMetrics
{
dp_rank
:
DpRank
,
active_decode_blocks
:
u64
,
}
pub
struct
WorkerMetricsPublisher
{
tx
:
tokio
::
sync
::
watch
::
Sender
<
WorkerMetrics
>
,
rx
:
tokio
::
sync
::
watch
::
Receiver
<
WorkerMetrics
>
,
}
impl
WorkerMetricsPublisher
{
pub
fn
new
()
->
Result
<
Self
>
{
let
(
tx
,
rx
)
=
tokio
::
sync
::
watch
::
channel
(
WorkerMetrics
::
default
());
Ok
(
WorkerMetricsPublisher
{
tx
,
rx
})
}
/// Publish worker metrics for load monitoring.
///
/// # Arguments
/// * `dp_rank` - Data parallel rank of the worker (None defaults to 0)
/// * `active_decode_blocks` - Number of active KV cache blocks
pub
fn
publish
(
&
self
,
dp_rank
:
Option
<
DpRank
>
,
active_decode_blocks
:
u64
)
->
Result
<
()
>
{
let
metrics
=
WorkerMetrics
{
dp_rank
:
dp_rank
.unwrap_or
(
0
),
active_decode_blocks
,
};
tracing
::
trace!
(
"Publish metrics: dp_rank={}, active_decode_blocks={}"
,
metrics
.dp_rank
,
metrics
.active_decode_blocks
);
self
.tx
.send
(
metrics
)
.map_err
(|
_
|
anyhow
::
anyhow!
(
"metrics channel closed"
))
}
pub
async
fn
create_endpoint
(
&
self
,
component
:
Component
)
->
Result
<
()
>
{
let
worker_id
=
component
.drt
()
.connection_id
();
self
.start_nats_metrics_publishing
(
component
.namespace
()
.clone
(),
worker_id
);
Ok
(())
}
/// Starts a background task to publish metrics over NATS
///
/// This task monitors metric changes (specifically active_decode_blocks)
/// and publishes stable metrics to NATS after they've been unchanged for 1ms.
fn
start_nats_metrics_publishing
(
&
self
,
namespace
:
Namespace
,
worker_id
:
u64
)
{
let
nats_rx
=
self
.rx
.clone
();
tokio
::
spawn
(
async
move
{
let
event_publisher
=
match
EventPublisher
::
for_namespace
(
&
namespace
,
KV_METRICS_SUBJECT
)
.await
{
Ok
(
publisher
)
=>
publisher
,
Err
(
e
)
=>
{
tracing
::
error!
(
"Failed to create metrics publisher: {}"
,
e
);
return
;
}
};
let
mut
rx
=
nats_rx
;
let
mut
last_metrics
:
Option
<
WorkerMetrics
>
=
None
;
let
mut
pending_publish
:
Option
<
WorkerMetrics
>
=
None
;
let
mut
publish_timer
=
Box
::
pin
(
tokio
::
time
::
sleep
(
tokio
::
time
::
Duration
::
from_secs
(
0
)));
publish_timer
.as_mut
()
.reset
(
tokio
::
time
::
Instant
::
now
());
// Complete immediately
loop
{
tokio
::
select!
{
// Handle metrics changes
result
=
rx
.changed
()
=>
{
if
result
.is_err
()
{
tracing
::
debug!
(
"Metrics publisher sender dropped, stopping NATS background task"
);
break
;
}
let
metrics
=
rx
.borrow_and_update
()
.clone
();
// Check if metrics have changed
let
has_changed
=
last_metrics
.as_ref
()
!=
Some
(
&
metrics
);
// If metrics changed, schedule a publish
if
has_changed
{
pending_publish
=
Some
(
metrics
.clone
());
last_metrics
=
Some
(
metrics
);
// Start the 1ms timer
publish_timer
.as_mut
()
.reset
(
tokio
::
time
::
Instant
::
now
()
+
tokio
::
time
::
Duration
::
from_millis
(
1
)
);
}
}
// Timer expired - publish if we have pending metrics
_
=
&
mut
publish_timer
=>
{
if
let
Some
(
metrics
)
=
pending_publish
.take
()
{
let
active_load
=
ActiveLoad
{
worker_id
,
dp_rank
:
metrics
.dp_rank
,
active_decode_blocks
:
Some
(
metrics
.active_decode_blocks
),
active_prefill_tokens
:
None
,
};
if
let
Err
(
e
)
=
event_publisher
.publish
(
&
active_load
)
.await
{
tracing
::
warn!
(
"Failed to publish metrics: {}"
,
e
);
}
}
// Reset timer to pending state to avoid tight loop
// It will be reset to 1ms when metrics actually change
publish_timer
.as_mut
()
.reset
(
tokio
::
time
::
Instant
::
now
()
+
tokio
::
time
::
Duration
::
from_secs
(
3600
)
);
}
}
}
});
}
}
// -------------------------------------------------------------------------
// Testing -----------------------------------------------------------------
// -------------------------------------------------------------------------
use
std
::
time
::
Duration
;
#[allow(unused_imports)]
use
zeromq
::{
PubSocket
,
Socket
,
SocketSend
,
ZmqMessage
};
#[cfg(test)]
mod
test_event_processing
{
...
...
@@ -1459,9 +430,8 @@ mod test_event_processing {
#[cfg(test)]
mod
tests_startup_helpers
{
use
super
::
*
;
use
crate
::
kv_router
::
KvIndexer
;
use
bytes
::
Bytes
;
use
dynamo_kv_router
::
indexer
::{
GetWorkersRequest
,
KvIndexerInterface
};
use
dynamo_kv_router
::
indexer
::{
GetWorkersRequest
,
KvIndexer
,
KvIndexerInterface
};
use
dynamo_kv_router
::
protocols
::{
ExternalSequenceBlockHash
,
LocalBlockHash
};
use
std
::
sync
::{
Arc
,
Mutex
};
use
zeromq
::{
PubSocket
,
Socket
,
SocketSend
,
ZmqMessage
};
...
...
@@ -2202,6 +1172,7 @@ mod test_exponential_backoff {
#[cfg(all(test,
feature
=
"integration"
))]
mod
test_integration_publisher
{
use
super
::
*
;
use
crate
::
kv_router
::
KV_METRICS_SUBJECT
;
use
dynamo_kv_router
::
protocols
::
ActiveLoad
;
use
dynamo_runtime
::
distributed_test_utils
::
create_test_drt_async
;
use
dynamo_runtime
::
transports
::
event_plane
::
EventSubscriber
;
...
...
lib/llm/src/kv_router/publisher/worker_metrics.rs
0 → 100644
View file @
02b1c58a
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use
anyhow
::
Result
;
use
dynamo_kv_router
::
protocols
::{
ActiveLoad
,
DpRank
};
use
dynamo_runtime
::
component
::{
Component
,
Namespace
};
use
dynamo_runtime
::
traits
::
DistributedRuntimeProvider
;
use
dynamo_runtime
::
transports
::
event_plane
::
EventPublisher
;
use
crate
::
kv_router
::
KV_METRICS_SUBJECT
;
#[derive(Debug,
Clone,
Default,
PartialEq)]
struct
WorkerMetrics
{
dp_rank
:
DpRank
,
active_decode_blocks
:
u64
,
}
pub
struct
WorkerMetricsPublisher
{
tx
:
tokio
::
sync
::
watch
::
Sender
<
WorkerMetrics
>
,
rx
:
tokio
::
sync
::
watch
::
Receiver
<
WorkerMetrics
>
,
}
impl
WorkerMetricsPublisher
{
pub
fn
new
()
->
Result
<
Self
>
{
let
(
tx
,
rx
)
=
tokio
::
sync
::
watch
::
channel
(
WorkerMetrics
::
default
());
Ok
(
Self
{
tx
,
rx
})
}
pub
fn
publish
(
&
self
,
dp_rank
:
Option
<
DpRank
>
,
active_decode_blocks
:
u64
)
->
Result
<
()
>
{
let
metrics
=
WorkerMetrics
{
dp_rank
:
dp_rank
.unwrap_or
(
0
),
active_decode_blocks
,
};
tracing
::
trace!
(
"Publish metrics: dp_rank={}, active_decode_blocks={}"
,
metrics
.dp_rank
,
metrics
.active_decode_blocks
);
self
.tx
.send
(
metrics
)
.map_err
(|
_
|
anyhow
::
anyhow!
(
"metrics channel closed"
))
}
pub
async
fn
create_endpoint
(
&
self
,
component
:
Component
)
->
Result
<
()
>
{
let
worker_id
=
component
.drt
()
.connection_id
();
self
.start_nats_metrics_publishing
(
component
.namespace
()
.clone
(),
worker_id
);
Ok
(())
}
pub
(
super
)
fn
start_nats_metrics_publishing
(
&
self
,
namespace
:
Namespace
,
worker_id
:
u64
)
{
let
nats_rx
=
self
.rx
.clone
();
tokio
::
spawn
(
async
move
{
let
event_publisher
=
match
EventPublisher
::
for_namespace
(
&
namespace
,
KV_METRICS_SUBJECT
)
.await
{
Ok
(
publisher
)
=>
publisher
,
Err
(
e
)
=>
{
tracing
::
error!
(
"Failed to create metrics publisher: {}"
,
e
);
return
;
}
};
let
mut
rx
=
nats_rx
;
let
mut
last_metrics
:
Option
<
WorkerMetrics
>
=
None
;
let
mut
pending_publish
:
Option
<
WorkerMetrics
>
=
None
;
let
mut
publish_timer
=
Box
::
pin
(
tokio
::
time
::
sleep
(
tokio
::
time
::
Duration
::
from_secs
(
0
)));
publish_timer
.as_mut
()
.reset
(
tokio
::
time
::
Instant
::
now
());
loop
{
tokio
::
select!
{
result
=
rx
.changed
()
=>
{
if
result
.is_err
()
{
tracing
::
debug!
(
"Metrics publisher sender dropped, stopping NATS background task"
);
break
;
}
let
metrics
=
rx
.borrow_and_update
()
.clone
();
let
has_changed
=
last_metrics
.as_ref
()
!=
Some
(
&
metrics
);
if
has_changed
{
pending_publish
=
Some
(
metrics
.clone
());
last_metrics
=
Some
(
metrics
);
publish_timer
.as_mut
()
.reset
(
tokio
::
time
::
Instant
::
now
()
+
tokio
::
time
::
Duration
::
from_millis
(
1
)
);
}
}
_
=
&
mut
publish_timer
=>
{
if
let
Some
(
metrics
)
=
pending_publish
.take
()
{
let
active_load
=
ActiveLoad
{
worker_id
,
dp_rank
:
metrics
.dp_rank
,
active_decode_blocks
:
Some
(
metrics
.active_decode_blocks
),
active_prefill_tokens
:
None
,
};
if
let
Err
(
e
)
=
event_publisher
.publish
(
&
active_load
)
.await
{
tracing
::
warn!
(
"Failed to publish metrics: {}"
,
e
);
}
}
publish_timer
.as_mut
()
.reset
(
tokio
::
time
::
Instant
::
now
()
+
tokio
::
time
::
Duration
::
from_secs
(
3600
)
);
}
}
}
});
}
}
lib/llm/src/kv_router/publisher/zmq_listener.rs
0 → 100644
View file @
02b1c58a
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use
std
::
sync
::
Arc
;
use
std
::
sync
::
atomic
::{
AtomicU32
,
AtomicU64
,
Ordering
};
use
std
::
time
::
Duration
;
use
rmp_serde
as
rmps
;
use
tokio
::
sync
::
mpsc
;
use
tokio_util
::
sync
::
CancellationToken
;
use
zeromq
::{
Socket
,
SocketRecv
,
SubSocket
};
use
dynamo_kv_router
::
protocols
::
*
;
use
dynamo_kv_router
::
zmq_wire
::
*
;
pub
(
super
)
const
INITIAL_BACKOFF_MS
:
u64
=
10
;
pub
(
super
)
const
MAX_BACKOFF_MS
:
u64
=
5000
;
pub
(
super
)
const
MAX_CONSECUTIVE_ERRORS
:
u32
=
10
;
pub
(
super
)
const
MAX_BACKOFF_EXPONENT
:
u32
=
8
;
pub
(
super
)
fn
calculate_backoff_ms
(
consecutive_errors
:
u32
)
->
u64
{
std
::
cmp
::
min
(
INITIAL_BACKOFF_MS
*
2_u64
.pow
(
consecutive_errors
.min
(
MAX_BACKOFF_EXPONENT
)),
MAX_BACKOFF_MS
,
)
}
pub
(
super
)
async
fn
start_zmq_listener
(
zmq_endpoint
:
String
,
zmq_topic
:
String
,
worker_id
:
WorkerId
,
tx
:
mpsc
::
UnboundedSender
<
PlacementEvent
>
,
cancellation_token
:
CancellationToken
,
kv_block_size
:
u32
,
next_event_id
:
Arc
<
AtomicU64
>
,
)
{
tracing
::
debug!
(
"KVEventPublisher connecting to ZMQ endpoint {} (topic '{}')"
,
zmq_endpoint
,
zmq_topic
);
let
warning_count
=
Arc
::
new
(
AtomicU32
::
new
(
0
));
let
mut
socket
=
SubSocket
::
new
();
if
let
Err
(
e
)
=
socket
.subscribe
(
&
zmq_topic
)
.await
{
tracing
::
error!
(
"Failed to subscribe on ZMQ socket: {}"
,
e
);
return
;
}
if
let
Err
(
e
)
=
socket
.connect
(
&
zmq_endpoint
)
.await
{
tracing
::
error!
(
"Failed to connect ZMQ SUB socket to {zmq_endpoint}: {e}"
);
return
;
}
let
mut
consecutive_errors
=
0u32
;
#[expect(unused_assignments)]
let
mut
exit_reason
=
"unknown"
;
let
mut
messages_processed
=
0u64
;
'main
:
loop
{
tokio
::
select!
{
biased
;
_
=
cancellation_token
.cancelled
()
=>
{
tracing
::
debug!
(
"ZMQ listener received cancellation signal"
);
exit_reason
=
"cancellation token cancelled"
;
break
'main
;
}
msg_result
=
socket
.recv
()
=>
{
let
Ok
(
msg
)
=
msg_result
else
{
let
e
=
msg_result
.unwrap_err
();
consecutive_errors
+=
1
;
if
consecutive_errors
>=
MAX_CONSECUTIVE_ERRORS
{
tracing
::
error!
(
error
=%
e
,
consecutive_errors
=%
consecutive_errors
,
"Too many consecutive ZMQ errors, terminating listener"
);
exit_reason
=
"too many consecutive errors"
;
break
'main
;
}
let
backoff_ms
=
calculate_backoff_ms
(
consecutive_errors
);
tracing
::
warn!
(
error
=%
e
,
consecutive_errors
=%
consecutive_errors
,
backoff_ms
=%
backoff_ms
,
"Error reading from ZMQ socket, applying exponential backoff"
);
tokio
::
time
::
sleep
(
Duration
::
from_millis
(
backoff_ms
))
.await
;
continue
;
};
consecutive_errors
=
0
;
let
mut
frames
:
Vec
<
Vec
<
u8
>>
=
msg
.into_vec
()
.into_iter
()
.map
(|
frame
|
frame
.to_vec
())
.collect
();
if
frames
.len
()
!=
3
{
tracing
::
warn!
(
"Received unexpected ZMQ frame count: expected 3, actual {}"
,
frames
.len
()
);
continue
;
}
let
payload
=
frames
.pop
()
.unwrap
();
let
seq_bytes
=
frames
.pop
()
.unwrap
();
if
seq_bytes
.len
()
!=
8
{
tracing
::
warn!
(
"Invalid sequence number byte length: expected 8, actual {}"
,
seq_bytes
.len
()
);
continue
;
}
let
engine_seq
=
u64
::
from_be_bytes
(
seq_bytes
.try_into
()
.unwrap
());
let
batch_result
=
rmps
::
from_slice
::
<
KvEventBatch
>
(
&
payload
);
let
Ok
(
batch
)
=
batch_result
else
{
let
e
=
batch_result
.unwrap_err
();
tracing
::
warn!
(
"Failed to decode KVEventBatch msgpack: {e}"
);
continue
;
};
tracing
::
trace!
(
"ZMQ listener on {} received batch with {} events (engine_seq={}, dp_rank={})"
,
zmq_endpoint
,
batch
.events
.len
(),
engine_seq
,
batch
.data_parallel_rank
.unwrap_or
(
0
)
);
let
dp_rank
=
batch
.data_parallel_rank
.unwrap_or
(
0
)
.cast_unsigned
();
for
raw_event
in
batch
.events
{
let
event_id
=
next_event_id
.fetch_add
(
1
,
Ordering
::
SeqCst
);
let
worker
=
WorkerWithDpRank
::
new
(
worker_id
,
dp_rank
);
let
event
=
convert_event
(
raw_event
,
event_id
,
kv_block_size
,
worker
,
&
warning_count
);
if
tx
.send
(
event
)
.is_err
()
{
tracing
::
warn!
(
"Failed to send message to channel - receiver dropped"
);
exit_reason
=
"channel receiver dropped"
;
break
'main
;
}
messages_processed
+=
1
;
}
}
}
}
tracing
::
debug!
(
"ZMQ listener exiting, reason: {}, messages processed: {}"
,
exit_reason
,
messages_processed
);
}
lib/llm/src/kv_router/remote_indexer.rs
deleted
100644 → 0
View file @
4b8826b3
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use
anyhow
::
Result
;
use
futures
::
StreamExt
;
use
dynamo_runtime
::{
component
::
Component
,
pipeline
::{
ManyOut
,
RouterMode
,
SingleIn
,
network
::
egress
::
push_router
::
PushRouter
},
};
use
dynamo_kv_router
::{
indexer
::{
IndexerQueryRequest
,
IndexerQueryResponse
,
KV_INDEXER_QUERY_ENDPOINT
},
protocols
::{
LocalBlockHash
,
OverlapScores
},
};
/// A remote indexer that queries a standalone KV indexer via the request plane.
///
/// Used by the frontend when `remote_indexer_component` is configured. Instead of
/// maintaining a local radix tree, this forwards `find_matches` queries to the
/// standalone indexer service over the Dynamo request plane.
pub
struct
RemoteIndexer
{
router
:
PushRouter
<
IndexerQueryRequest
,
IndexerQueryResponse
>
,
model_name
:
String
,
namespace
:
String
,
}
impl
RemoteIndexer
{
pub
async
fn
new
(
component
:
&
Component
,
indexer_component_name
:
&
str
,
model_name
:
String
,
)
->
Result
<
Self
>
{
let
namespace
=
component
.namespace
()
.name
();
let
indexer_ns
=
component
.namespace
();
let
indexer_component
=
indexer_ns
.component
(
indexer_component_name
)
?
;
let
endpoint
=
indexer_component
.endpoint
(
KV_INDEXER_QUERY_ENDPOINT
);
let
client
=
endpoint
.client
()
.await
?
;
let
router
=
PushRouter
::
from_client_no_fault_detection
(
client
,
RouterMode
::
RoundRobin
)
.await
?
;
Ok
(
Self
{
router
,
model_name
,
namespace
,
})
}
pub
async
fn
find_matches
(
&
self
,
block_hashes
:
Vec
<
LocalBlockHash
>
)
->
Result
<
OverlapScores
>
{
let
request
=
IndexerQueryRequest
{
model_name
:
self
.model_name
.clone
(),
namespace
:
self
.namespace
.clone
(),
block_hashes
,
};
let
mut
stream
:
ManyOut
<
IndexerQueryResponse
>
=
self
.router
.round_robin
(
SingleIn
::
new
(
request
))
.await
?
;
match
stream
.next
()
.await
{
Some
(
IndexerQueryResponse
::
Scores
(
scores
))
=>
Ok
(
scores
.into
()),
Some
(
IndexerQueryResponse
::
Error
(
msg
))
=>
{
Err
(
anyhow
::
anyhow!
(
"Remote indexer error: {}"
,
msg
))
}
None
=>
Err
(
anyhow
::
anyhow!
(
"Remote indexer returned empty response"
)),
}
}
}
lib/llm/src/kv_router/scheduler.rs
View file @
02b1c58a
...
...
@@ -81,6 +81,7 @@ where
block_size
,
selector
,
policy
,
kv_router_config
.router_track_prefill_tokens
,
component
.drt
()
.child_token
(),
worker_type
,
watch_worker_configs
,
...
...
@@ -180,9 +181,10 @@ where
token_seq
:
Option
<
Vec
<
SequenceHash
>>
,
isl_tokens
:
usize
,
overlaps
:
OverlapScores
,
track_prefill_tokens
:
bool
,
)
->
Vec
<
PotentialLoad
>
{
self
.inner
.get_potential_loads
(
token_seq
,
isl_tokens
,
overlaps
)
.get_potential_loads
(
token_seq
,
isl_tokens
,
overlaps
,
track_prefill_tokens
)
}
pub
fn
get_active_lora_counts
(
&
self
)
->
HashMap
<
String
,
usize
>
{
...
...
lib/llm/src/kv_router/sequence.rs
View file @
02b1c58a
...
...
@@ -223,6 +223,7 @@ mod tests {
token_sequence
:
Some
(
vec!
[
0
,
1
,
2
]),
isl
:
12
,
overlap
:
0
,
track_prefill_tokens
:
true
,
expected_output_tokens
:
None
,
worker
:
WorkerWithDpRank
::
new
(
0
,
0
),
lora_name
:
None
,
...
...
@@ -235,6 +236,7 @@ mod tests {
token_sequence
:
Some
(
vec!
[
3
,
4
]),
isl
:
8
,
overlap
:
0
,
track_prefill_tokens
:
true
,
expected_output_tokens
:
None
,
worker
:
WorkerWithDpRank
::
new
(
0
,
1
),
lora_name
:
None
,
...
...
@@ -247,6 +249,7 @@ mod tests {
token_sequence
:
Some
(
vec!
[
0
,
1
,
2
,
3
]),
isl
:
16
,
overlap
:
0
,
track_prefill_tokens
:
true
,
expected_output_tokens
:
None
,
worker
:
WorkerWithDpRank
::
new
(
1
,
0
),
lora_name
:
None
,
...
...
@@ -373,6 +376,7 @@ mod tests {
token_sequence
:
None
,
isl
:
12
,
overlap
:
0
,
track_prefill_tokens
:
true
,
expected_output_tokens
:
None
,
worker
:
WorkerWithDpRank
::
from_worker_id
(
0
),
lora_name
:
None
,
...
...
@@ -385,6 +389,7 @@ mod tests {
token_sequence
:
None
,
isl
:
8
,
overlap
:
0
,
track_prefill_tokens
:
true
,
expected_output_tokens
:
None
,
worker
:
WorkerWithDpRank
::
from_worker_id
(
1
),
lora_name
:
None
,
...
...
@@ -397,6 +402,7 @@ mod tests {
token_sequence
:
None
,
isl
:
16
,
overlap
:
0
,
track_prefill_tokens
:
true
,
expected_output_tokens
:
None
,
worker
:
WorkerWithDpRank
::
from_worker_id
(
2
),
lora_name
:
None
,
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment