Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
134d484d
Unverified
Commit
134d484d
authored
Apr 15, 2026
by
Yan Ru Pei
Committed by
GitHub
Apr 15, 2026
Browse files
feat(kv-router): add prompt membership index for scheduler reads (#8175)
Signed-off-by:
PeaBrane
<
yanrpei@gmail.com
>
parent
d0d9c030
Changes
34
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
3478 additions
and
912 deletions
+3478
-912
lib/kv-router/src/sequences/block_tracker.rs
lib/kv-router/src/sequences/block_tracker.rs
+87
-7
lib/kv-router/src/sequences/mod.rs
lib/kv-router/src/sequences/mod.rs
+4
-0
lib/kv-router/src/sequences/multi_worker.rs
lib/kv-router/src/sequences/multi_worker.rs
+955
-426
lib/kv-router/src/sequences/prefill_tracker.rs
lib/kv-router/src/sequences/prefill_tracker.rs
+260
-4
lib/kv-router/src/sequences/prompt_membership_trie.rs
lib/kv-router/src/sequences/prompt_membership_trie.rs
+796
-0
lib/kv-router/src/sequences/prompt_registry.rs
lib/kv-router/src/sequences/prompt_registry.rs
+534
-0
lib/kv-router/src/sequences/request_maps.rs
lib/kv-router/src/sequences/request_maps.rs
+177
-0
lib/kv-router/src/sequences/single.rs
lib/kv-router/src/sequences/single.rs
+385
-412
lib/kv-router/src/sequences/topology.rs
lib/kv-router/src/sequences/topology.rs
+257
-0
lib/llm/src/kv_router/sequence.rs
lib/llm/src/kv_router/sequence.rs
+0
-52
lib/mocker/src/replay/offline/components/router.rs
lib/mocker/src/replay/offline/components/router.rs
+4
-3
lib/mocker/src/scheduler/sglang/core.rs
lib/mocker/src/scheduler/sglang/core.rs
+6
-3
lib/mocker/src/scheduler/sglang/request.rs
lib/mocker/src/scheduler/sglang/request.rs
+2
-1
lib/mocker/src/scheduler/vllm/core.rs
lib/mocker/src/scheduler/vllm/core.rs
+11
-4
No files found.
lib/kv-router/src/sequences/block_tracker.rs
View file @
134d484d
...
@@ -2,35 +2,50 @@
...
@@ -2,35 +2,50 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-License-Identifier: Apache-2.0
use
dynamo_tokens
::
SequenceHash
;
use
dynamo_tokens
::
SequenceHash
;
use
std
::
collections
::
HashMap
;
use
rustc_hash
::
Fx
HashMap
;
use
std
::
sync
::{
Arc
,
Weak
};
use
std
::
sync
::{
Arc
,
Weak
};
#[derive(Debug)]
pub
(
super
)
struct
BlockAcquire
{
pub
(
super
)
rc
:
Arc
<
()
>
,
pub
(
super
)
became_present_on_worker
:
bool
,
}
#[derive(Debug,
Default)]
#[derive(Debug,
Default)]
pub
(
super
)
struct
BlockTracker
{
pub
(
super
)
struct
BlockTracker
{
pub
(
super
)
unique_blocks
:
HashMap
<
SequenceHash
,
Weak
<
()
>>
,
pub
(
super
)
unique_blocks
:
Fx
HashMap
<
SequenceHash
,
Weak
<
()
>>
,
pub
(
super
)
fractional_blocks
:
HashMap
<
SequenceHash
,
f64
>
,
pub
(
super
)
fractional_blocks
:
Fx
HashMap
<
SequenceHash
,
f64
>
,
}
}
impl
BlockTracker
{
impl
BlockTracker
{
pub
(
super
)
fn
touch_block
(
&
mut
self
,
block
:
&
SequenceHash
)
->
Arc
<
()
>
{
pub
(
super
)
fn
touch_block
(
&
mut
self
,
block
:
&
SequenceHash
)
->
BlockAcquire
{
if
let
Some
(
weak
)
=
self
.unique_blocks
.get
(
block
)
if
let
Some
(
weak
)
=
self
.unique_blocks
.get
(
block
)
&&
let
Some
(
rc
)
=
weak
.upgrade
()
&&
let
Some
(
rc
)
=
weak
.upgrade
()
{
{
return
rc
;
return
BlockAcquire
{
rc
,
became_present_on_worker
:
false
,
};
}
}
let
rc
=
Arc
::
new
(());
let
rc
=
Arc
::
new
(());
self
.unique_blocks
.insert
(
*
block
,
Arc
::
downgrade
(
&
rc
));
self
.unique_blocks
.insert
(
*
block
,
Arc
::
downgrade
(
&
rc
));
rc
BlockAcquire
{
rc
,
became_present_on_worker
:
true
,
}
}
}
pub
(
super
)
fn
try_remove_block
(
&
mut
self
,
block
:
&
SequenceHash
)
{
pub
(
super
)
fn
try_remove_block
(
&
mut
self
,
block
:
&
SequenceHash
)
->
bool
{
if
let
Some
(
weak
)
=
self
.unique_blocks
.get
(
block
)
if
let
Some
(
weak
)
=
self
.unique_blocks
.get
(
block
)
&&
weak
.strong_count
()
==
0
&&
weak
.strong_count
()
==
0
{
{
self
.unique_blocks
.remove
(
block
);
self
.unique_blocks
.remove
(
block
);
self
.fractional_blocks
.remove
(
block
);
self
.fractional_blocks
.remove
(
block
);
return
true
;
}
}
false
}
}
pub
(
super
)
fn
active_blocks
(
&
self
)
->
usize
{
pub
(
super
)
fn
active_blocks
(
&
self
)
->
usize
{
...
@@ -43,3 +58,68 @@ impl BlockTracker {
...
@@ -43,3 +58,68 @@ impl BlockTracker {
count
.round
()
as
usize
count
.round
()
as
usize
}
}
}
}
#[cfg(test)]
mod
tests
{
use
super
::
*
;
#[test]
fn
first_touch_and_last_remove_report_presence_transitions
()
{
let
mut
tracker
=
BlockTracker
::
default
();
let
first
=
tracker
.touch_block
(
&
1
);
let
second
=
tracker
.touch_block
(
&
1
);
assert
!
(
first
.became_present_on_worker
);
assert
!
(
!
second
.became_present_on_worker
);
assert_eq!
(
tracker
.active_blocks
(),
1
);
drop
(
first
.rc
);
assert
!
(
!
tracker
.try_remove_block
(
&
1
));
assert_eq!
(
tracker
.active_blocks
(),
1
);
drop
(
second
.rc
);
assert
!
(
tracker
.try_remove_block
(
&
1
));
assert_eq!
(
tracker
.active_blocks
(),
0
);
}
#[test]
fn
fractional_blocks_adjust_active_block_count
()
{
let
mut
tracker
=
BlockTracker
::
default
();
let
first
=
tracker
.touch_block
(
&
1
);
let
second
=
tracker
.touch_block
(
&
2
);
tracker
.fractional_blocks
.insert
(
1
,
0.5
);
tracker
.fractional_blocks
.insert
(
2
,
0.5
);
assert_eq!
(
tracker
.active_blocks
(),
1
);
drop
(
first
.rc
);
assert
!
(
tracker
.try_remove_block
(
&
1
));
assert
!
(
!
tracker
.fractional_blocks
.contains_key
(
&
1
));
assert_eq!
(
tracker
.active_blocks
(),
1
);
drop
(
second
.rc
);
assert
!
(
tracker
.try_remove_block
(
&
2
));
assert
!
(
tracker
.fractional_blocks
.is_empty
());
assert_eq!
(
tracker
.active_blocks
(),
0
);
}
#[test]
fn
shared_block_counts_once_until_last_reference_drops
()
{
let
mut
tracker
=
BlockTracker
::
default
();
let
first
=
tracker
.touch_block
(
&
7
);
let
second
=
tracker
.touch_block
(
&
7
);
let
third
=
tracker
.touch_block
(
&
7
);
assert_eq!
(
tracker
.active_blocks
(),
1
);
drop
(
first
.rc
);
drop
(
second
.rc
);
assert
!
(
!
tracker
.try_remove_block
(
&
7
));
assert_eq!
(
tracker
.active_blocks
(),
1
);
drop
(
third
.rc
);
assert
!
(
tracker
.try_remove_block
(
&
7
));
assert_eq!
(
tracker
.active_blocks
(),
0
);
}
}
lib/kv-router/src/sequences/mod.rs
View file @
134d484d
...
@@ -4,7 +4,11 @@
...
@@ -4,7 +4,11 @@
mod
block_tracker
;
mod
block_tracker
;
pub
mod
multi_worker
;
pub
mod
multi_worker
;
mod
prefill_tracker
;
mod
prefill_tracker
;
mod
prompt_membership_trie
;
mod
prompt_registry
;
mod
request_maps
;
pub
mod
single
;
pub
mod
single
;
mod
topology
;
pub
use
multi_worker
::
*
;
pub
use
multi_worker
::
*
;
pub
use
single
::
*
;
pub
use
single
::
*
;
lib/kv-router/src/sequences/multi_worker.rs
View file @
134d484d
...
@@ -9,17 +9,20 @@
...
@@ -9,17 +9,20 @@
//! transport (e.g., NATS EventPublisher, Prometheus gauges) so that all business logic lives in
//! transport (e.g., NATS EventPublisher, Prometheus gauges) so that all business logic lives in
//! this crate while the runtime glue stays in `lib/llm`.
//! this crate while the runtime glue stays in `lib/llm`.
use
dashmap
::
DashMap
;
use
dynamo_tokens
::
SequenceHash
;
use
dynamo_tokens
::
SequenceHash
;
use
parking_lot
::
RwLock
;
use
parking_lot
::
RwLock
;
use
std
::
collections
::{
HashMap
,
HashSet
};
use
rustc_hash
::
FxHashMap
;
use
std
::
collections
::
HashMap
;
use
std
::
future
::
Future
;
use
std
::
future
::
Future
;
use
std
::
sync
::
Arc
;
use
std
::
sync
::
Arc
;
use
tokio
::
sync
::
watch
;
use
tokio
::
sync
::
watch
;
use
tokio
::
time
::{
Duration
,
Instant
};
use
tokio
::
time
::{
Duration
,
Instant
};
use
tokio_util
::
sync
::
CancellationToken
;
use
tokio_util
::
sync
::
CancellationToken
;
use
super
::
single
::{
ActiveSequences
,
RequestId
};
use
super
::
prompt_registry
::{
PromptRegistry
,
WorkerLoadSnapshot
};
use
super
::
request_maps
::
RequestIndex
;
use
super
::
single
::{
ActiveSequences
,
PromptMembershipDelta
,
RequestId
};
use
super
::
topology
::
WorkerTable
;
use
crate
::
protocols
::{
use
crate
::
protocols
::{
ActiveLoad
,
ActiveSequenceEvent
,
ActiveSequenceEventData
,
OverlapScores
,
PrefillLoadHint
,
ActiveLoad
,
ActiveSequenceEvent
,
ActiveSequenceEventData
,
OverlapScores
,
PrefillLoadHint
,
WorkerWithDpRank
,
WorkerWithDpRank
,
...
@@ -100,35 +103,6 @@ pub struct SequenceRequest {
...
@@ -100,35 +103,6 @@ pub struct SequenceRequest {
pub
lora_name
:
Option
<
String
>
,
pub
lora_name
:
Option
<
String
>
,
}
}
// ---------------------------------------------------------------------------
// WorkerTable
// ---------------------------------------------------------------------------
struct
WorkerTable
{
slots
:
Vec
<
(
WorkerWithDpRank
,
RwLock
<
ActiveSequences
>
)
>
,
index
:
HashMap
<
WorkerWithDpRank
,
usize
>
,
}
impl
WorkerTable
{
fn
new
(
block_size
:
usize
,
dp_range
:
&
HashMap
<
u64
,
(
u32
,
u32
)
>
)
->
Self
{
let
mut
slots
=
Vec
::
new
();
let
mut
index
=
HashMap
::
new
();
for
(
&
worker_id
,
&
(
dp_start
,
dp_size
))
in
dp_range
{
for
dp_rank
in
dp_start
..
dp_start
+
dp_size
{
let
worker
=
WorkerWithDpRank
::
new
(
worker_id
,
dp_rank
);
let
idx
=
slots
.len
();
slots
.push
((
worker
,
RwLock
::
new
(
ActiveSequences
::
new
(
block_size
))));
index
.insert
(
worker
,
idx
);
}
}
Self
{
slots
,
index
}
}
}
// ---------------------------------------------------------------------------
// ActiveSequencesMultiWorker
// ---------------------------------------------------------------------------
/// Multi-worker extension of [`ActiveSequences`] with per-worker `parking_lot::RwLock` for
/// Multi-worker extension of [`ActiveSequences`] with per-worker `parking_lot::RwLock` for
/// fine-grained concurrent access.
/// fine-grained concurrent access.
///
///
...
@@ -140,8 +114,8 @@ impl WorkerTable {
...
@@ -140,8 +114,8 @@ impl WorkerTable {
/// and metrics infrastructure.
/// and metrics infrastructure.
pub
struct
ActiveSequencesMultiWorker
<
P
:
SequencePublisher
>
{
pub
struct
ActiveSequencesMultiWorker
<
P
:
SequencePublisher
>
{
workers
:
RwLock
<
WorkerTable
>
,
workers
:
RwLock
<
WorkerTable
>
,
request_
to_worker
:
DashMap
<
RequestId
,
WorkerWithDpRank
>
,
request_
index
:
RequestIndex
,
request_to_lora
:
DashMap
<
RequestId
,
String
>
,
prompt_registry
:
PromptRegistry
,
block_size
:
usize
,
block_size
:
usize
,
router_id
:
u64
,
router_id
:
u64
,
publisher
:
Arc
<
P
>
,
publisher
:
Arc
<
P
>
,
...
@@ -164,11 +138,13 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
...
@@ -164,11 +138,13 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
)
->
Self
{
)
->
Self
{
assert
!
(
block_size
>
1
,
"block_size must be greater than 1"
);
assert
!
(
block_size
>
1
,
"block_size must be greater than 1"
);
let
(
remote_state_updates
,
_
)
=
watch
::
channel
(());
let
(
remote_state_updates
,
_
)
=
watch
::
channel
(());
let
workers
=
WorkerTable
::
new
(
block_size
,
&
dp_range
);
let
prompt_registry
=
PromptRegistry
::
new
(
workers
.workers
());
Self
{
Self
{
workers
:
RwLock
::
new
(
W
orker
Table
::
new
(
block_size
,
&
dp_range
)
),
workers
:
RwLock
::
new
(
w
orker
s
),
request_
to_worker
:
DashMap
::
new
(),
request_
index
:
RequestIndex
::
default
(),
request_to_lora
:
DashMap
::
new
()
,
prompt_registry
,
block_size
,
block_size
,
router_id
,
router_id
,
publisher
:
Arc
::
new
(
publisher
),
publisher
:
Arc
::
new
(
publisher
),
...
@@ -178,6 +154,59 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
...
@@ -178,6 +154,59 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
}
}
}
}
#[cfg(any(test,
feature
=
"bench"
))]
pub
fn
assert_completely_drained
(
&
self
,
decay_now
:
Instant
)
{
let
active_blocks
=
self
.active_blocks
();
assert
!
(
active_blocks
.values
()
.all
(|
&
count
|
count
==
0
),
"expected all workers to have zero active blocks, got {active_blocks:?}"
,
);
let
active_tokens
=
self
.active_tokens
(
decay_now
);
assert
!
(
active_tokens
.values
()
.all
(|
&
count
|
count
==
0
),
"expected all workers to have zero active tokens, got {active_tokens:?}"
,
);
assert
!
(
self
.request_index
.is_empty
(),
"expected no active request-to-worker mappings, found {}"
,
self
.request_index
.worker_len
(),
);
assert
!
(
self
.get_active_lora_counts
()
.is_empty
(),
"expected no active LoRA counts, found {:?}"
,
self
.get_active_lora_counts
(),
);
assert
!
(
self
.prompt_registry
.is_block_index_empty
(),
"expected reverse block index to be empty after drain"
,
);
}
fn
publish_worker_load_snapshot
(
&
self
,
worker
:
WorkerWithDpRank
,
load
:
WorkerLoadSnapshot
,
decay_now
:
Instant
,
)
{
let
active_blocks
=
load
.active_blocks
;
let
active_tokens
=
load
.active_tokens
(
decay_now
);
self
.publisher
.observe_load
(
&
worker
,
self
.worker_type
,
active_blocks
,
active_tokens
);
let
active_load
=
ActiveLoad
{
worker_id
:
worker
.worker_id
,
dp_rank
:
worker
.dp_rank
,
active_decode_blocks
:
Some
(
active_blocks
as
u64
),
active_prefill_tokens
:
Some
(
active_tokens
as
u64
),
kv_used_blocks
:
None
,
};
self
.publisher
.publish_load
(
active_load
);
}
fn
spawn_publish_event
(
&
self
,
event
:
ActiveSequenceEvent
)
{
fn
spawn_publish_event
(
&
self
,
event
:
ActiveSequenceEvent
)
{
if
!
self
.replica_sync
{
if
!
self
.replica_sync
{
return
;
return
;
...
@@ -257,26 +286,40 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
...
@@ -257,26 +286,40 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
expected_output_tokens
,
expected_output_tokens
,
prefill_load_hint
,
prefill_load_hint
,
}
=>
{
}
=>
{
self
.request_to_worker
self
.ensure_worker_registered
(
event
.worker
);
.insert
(
event
.request_id
.clone
(),
event
.worker
);
if
let
Some
(
ref
lora_name
)
=
event
.lora_name
{
self
.request_to_lora
.insert
(
event
.request_id
.clone
(),
lora_name
.clone
());
}
let
table
=
self
.workers
.read
();
let
table
=
self
.workers
.read
();
if
let
Some
(
&
idx
)
=
table
.index
.get
(
&
event
.worker
)
{
if
let
Some
(
&
idx
)
=
table
.index
.get
(
&
event
.worker
)
{
table
.slots
[
idx
]
.1
.write
()
.add_request_with_prefill_tracking
(
self
.request_index
.set_request
(
event
.request_id
.clone
(),
event
.request_id
.clone
(),
token_sequence
.clone
(),
event
.worker
,
*
isl
,
event
.lora_name
.clone
(),
*
overlap
,
*
expected_output_tokens
,
*
track_prefill_tokens
,
*
prefill_load_hint
,
decay_now
,
);
);
let
(
expired_request_ids
,
load
)
=
{
let
slot
=
&
table
.slots
[
idx
];
let
mut
seq
=
slot
.sequences
.write
();
let
outcome
=
seq
.add_request_with_prefill_tracking
(
event
.request_id
.clone
(),
token_sequence
.clone
(),
*
isl
,
*
overlap
,
*
expected_output_tokens
,
*
track_prefill_tokens
,
*
prefill_load_hint
,
decay_now
,
);
let
load
=
seq
.worker_load_snapshot
();
self
.prompt_registry
.apply_membership_delta_and_load
(
event
.worker
,
&
slot
.trie_lookup
,
outcome
.membership_delta
,
load
,
);
(
outcome
.expired_request_ids
,
load
)
};
drop
(
table
);
self
.request_index
.remove_requests
(
expired_request_ids
.iter
());
self
.publish_worker_load_snapshot
(
event
.worker
,
load
,
decay_now
);
continue
;
}
else
{
}
else
{
tracing
::
warn!
(
tracing
::
warn!
(
"Worker {:?} not found, cannot process AddRequest"
,
"Worker {:?} not found, cannot process AddRequest"
,
...
@@ -285,27 +328,40 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
...
@@ -285,27 +328,40 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
}
}
}
}
ActiveSequenceEventData
::
Free
=>
{
ActiveSequenceEventData
::
Free
=>
{
if
let
Some
((
_
,
worker
))
=
if
let
Some
(
worker
)
=
self
.request_index
.remove_request
(
&
event
.request_id
)
{
self
.request_to_worker
.remove
(
&
event
.request_id
)
{
let
table
=
self
.workers
.read
();
let
table
=
self
.workers
.read
();
if
let
Some
(
&
idx
)
=
table
.index
.get
(
&
worker
)
{
if
let
Some
(
&
idx
)
=
table
.index
.get
(
&
worker
)
{
table
.slots
[
idx
]
.1
.write
()
.free
(
&
event
.request_id
,
decay_now
);
let
load
=
{
let
slot
=
&
table
.slots
[
idx
];
let
mut
seq
=
slot
.sequences
.write
();
let
delta
=
seq
.free
(
&
event
.request_id
,
decay_now
);
let
load
=
seq
.worker_load_snapshot
();
self
.prompt_registry
.apply_membership_delta_and_load
(
worker
,
&
slot
.trie_lookup
,
delta
,
load
,
);
load
};
drop
(
table
);
self
.publish_worker_load_snapshot
(
worker
,
load
,
decay_now
);
remote_capacity_changed
=
true
;
remote_capacity_changed
=
true
;
}
}
}
}
self
.request_to_lora
.remove
(
&
event
.request_id
);
}
}
ActiveSequenceEventData
::
MarkPrefillCompleted
=>
{
ActiveSequenceEventData
::
MarkPrefillCompleted
=>
{
let
worker
=
let
worker
=
self
.request_index
.worker_for
(
&
event
.request_id
);
self
.request_to_worker
.get
(
&
event
.request_id
)
.map
(|
r
|
*
r
);
if
let
Some
(
worker
)
=
worker
{
if
let
Some
(
worker
)
=
worker
{
let
table
=
self
.workers
.read
();
let
table
=
self
.workers
.read
();
if
let
Some
(
&
idx
)
=
table
.index
.get
(
&
worker
)
{
if
let
Some
(
&
idx
)
=
table
.index
.get
(
&
worker
)
{
table
.slots
[
idx
]
{
.1
let
mut
seq
=
table
.slots
[
idx
]
.sequences
.write
();
.write
()
seq
.mark_prefill_completed
(
&
event
.request_id
,
decay_now
);
.mark_prefill_completed
(
&
event
.request_id
,
decay_now
);
let
load
=
seq
.worker_load_snapshot
();
self
.prompt_registry
.replace_worker_load_state
(
worker
,
load
);
}
drop
(
table
);
remote_capacity_changed
=
true
;
remote_capacity_changed
=
true
;
}
}
}
}
...
@@ -337,151 +393,35 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
...
@@ -337,151 +393,35 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
/// Worker removal in External mode will be handled separately via GAIE
/// Worker removal in External mode will be handled separately via GAIE
/// lifecycle events (not yet implemented). TODO (atchernych) once we upgrade to GAIE latest.
/// lifecycle events (not yet implemented). TODO (atchernych) once we upgrade to GAIE latest.
pub
fn
register_external_workers
(
&
self
,
dp_range
:
&
HashMap
<
u64
,
(
u32
,
u32
)
>
)
{
pub
fn
register_external_workers
(
&
self
,
dp_range
:
&
HashMap
<
u64
,
(
u32
,
u32
)
>
)
{
let
mut
table
=
self
.workers
.write
();
let
change
=
{
for
(
&
worker_id
,
&
(
dp_start
,
dp_size
))
in
dp_range
{
let
mut
table
=
self
.workers
.write
();
for
dp_rank
in
dp_start
..
(
dp_start
+
dp_size
)
{
table
.register_external
(
self
.block_size
,
dp_range
)
let
worker
=
WorkerWithDpRank
::
new
(
worker_id
,
dp_rank
);
};
if
!
table
.index
.contains_key
(
&
worker
)
{
tracing
::
debug!
(
"Lazily registering external worker {:?}"
,
worker
);
for
worker
in
&
change
.added
{
let
idx
=
table
.slots
.len
();
tracing
::
debug!
(
"Lazily registering external worker {:?}"
,
worker
);
table
.slots
.push
((
worker
,
RwLock
::
new
(
ActiveSequences
::
new
(
self
.block_size
))));
table
.index
.insert
(
worker
,
idx
);
}
}
}
}
self
.prompt_registry
.apply_topology_change
(
change
);
}
}
/// Update the set of workers, adding and removing as needed.
/// Update the set of workers, adding and removing as needed.
///
///
/// `new_dp_range` maps worker IDs to their data-parallel range (start, size).
/// `new_dp_range` maps worker IDs to their data-parallel range (start, size).
pub
fn
update_workers
(
&
self
,
new_dp_range
:
&
HashMap
<
u64
,
(
u32
,
u32
)
>
)
{
pub
fn
update_workers
(
&
self
,
new_dp_range
:
&
HashMap
<
u64
,
(
u32
,
u32
)
>
)
{
let
mut
table
=
self
.workers
.write
();
let
change
=
{
let
mut
target_workers
:
HashSet
<
WorkerWithDpRank
>
=
HashSet
::
new
();
for
(
&
worker_id
,
&
(
dp_start
,
dp_size
))
in
new_dp_range
{
for
dp_rank
in
dp_start
..
(
dp_start
+
dp_size
)
{
target_workers
.insert
(
WorkerWithDpRank
::
new
(
worker_id
,
dp_rank
));
}
}
// Clean up request mappings for workers being removed.
for
(
worker
,
_
)
in
&
table
.slots
{
if
target_workers
.contains
(
worker
)
{
continue
;
}
tracing
::
warn!
(
"Removing worker {:?}"
,
worker
);
let
requests_to_remove
:
Vec
<
RequestId
>
=
self
.request_to_worker
.iter
()
.filter
(|
entry
|
entry
.value
()
==
worker
)
.map
(|
entry
|
entry
.key
()
.clone
())
.collect
();
self
.request_to_worker
.retain
(|
_
request_id
,
mapped_worker
|
mapped_worker
!=
worker
);
for
request_id
in
requests_to_remove
{
self
.request_to_lora
.remove
(
&
request_id
);
}
}
// Drain old slots, preserving ActiveSequences for retained workers.
let
mut
old
:
HashMap
<
WorkerWithDpRank
,
ActiveSequences
>
=
table
.slots
.drain
(
..
)
.map
(|(
w
,
lock
)|
(
w
,
lock
.into_inner
()))
.collect
();
table
.index
.clear
();
// Rebuild with target workers, reusing state where possible.
for
worker
in
target_workers
{
if
!
old
.contains_key
(
&
worker
)
{
tracing
::
warn!
(
"Adding worker {:?}"
,
worker
);
}
let
idx
=
table
.slots
.len
();
let
seq
=
old
.remove
(
&
worker
)
.unwrap_or_else
(||
ActiveSequences
::
new
(
self
.block_size
));
table
.slots
.push
((
worker
,
RwLock
::
new
(
seq
)));
table
.index
.insert
(
worker
,
idx
);
}
}
fn
add_request_local
(
&
self
,
req
:
SequenceRequest
,
decay_now
:
Instant
,
)
->
Result
<
(),
SequenceError
>
{
let
SequenceRequest
{
request_id
,
token_sequence
,
isl
,
overlap
,
track_prefill_tokens
,
expected_output_tokens
,
prefill_load_hint
,
worker
,
lora_name
,
}
=
req
;
if
!
self
.workers
.read
()
.index
.contains_key
(
&
worker
)
{
// The selector already picked this worker from the discovery watch,
// but the slot tracker hasn't been updated yet. Lazily register it
// so we don't drop tracking for this request.
let
mut
table
=
self
.workers
.write
();
let
mut
table
=
self
.workers
.write
();
if
!
table
.index
.contains_key
(
&
worker
)
{
table
.reconcile
(
self
.block_size
,
new_dp_range
)
tracing
::
debug!
(
?
worker
,
"Lazily registering worker in slot tracker"
);
let
idx
=
table
.slots
.len
();
table
.slots
.push
((
worker
,
RwLock
::
new
(
ActiveSequences
::
new
(
self
.block_size
))));
table
.index
.insert
(
worker
,
idx
);
}
}
if
let
Some
(
existing_worker
)
=
self
.request_to_worker
.get
(
&
request_id
)
{
return
Err
(
SequenceError
::
DuplicateRequest
{
request_id
,
worker
:
*
existing_worker
,
});
}
self
.request_to_worker
.insert
(
request_id
.clone
(),
worker
);
if
let
Some
(
lora
)
=
lora_name
{
self
.request_to_lora
.insert
(
request_id
.clone
(),
lora
);
}
let
removed_requests
=
{
let
table
=
self
.workers
.read
();
let
&
idx
=
table
.index
.get
(
&
worker
)
.ok_or
(
SequenceError
::
WorkerNotFound
{
worker
})
?
;
let
mut
seq
=
table
.slots
[
idx
]
.1
.write
();
seq
.add_request_with_prefill_tracking
(
request_id
,
token_sequence
,
isl
,
overlap
,
expected_output_tokens
,
track_prefill_tokens
,
prefill_load_hint
,
decay_now
,
)
};
};
for
expired_id
in
&
removed_requests
{
for
removed
in
&
change
.removed
{
self
.request_to_worker
.remove
(
expired_id
);
tracing
::
warn!
(
"Removing worker {:?}"
,
removed
.worker
);
self
.request_to_lora
.remove
(
expired_id
);
self
.request_index
.remove_worker_requests
(
removed
.worker
);
}
for
worker
in
&
change
.added
{
tracing
::
warn!
(
"Adding worker {:?}"
,
worker
);
}
}
self
.publish_active_load_for_worker
(
worker
,
decay_now
);
self
.prompt_registry
.apply_topology_change
(
change
);
Ok
(())
}
}
pub
fn
add_request
(
pub
fn
add_request
(
...
@@ -489,7 +429,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
...
@@ -489,7 +429,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
req
:
SequenceRequest
,
req
:
SequenceRequest
,
decay_now
:
Instant
,
decay_now
:
Instant
,
)
->
Result
<
(),
SequenceError
>
{
)
->
Result
<
(),
SequenceError
>
{
self
.spawn_publish_
event
(
ActiveSequenceEvent
{
let
event
=
ActiveSequenceEvent
{
request_id
:
req
.request_id
.clone
(),
request_id
:
req
.request_id
.clone
(),
worker
:
req
.worker
,
worker
:
req
.worker
,
data
:
ActiveSequenceEventData
::
AddRequest
{
data
:
ActiveSequenceEventData
::
AddRequest
{
...
@@ -502,78 +442,12 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
...
@@ -502,78 +442,12 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
},
},
router_id
:
self
.router_id
,
router_id
:
self
.router_id
,
lora_name
:
req
.lora_name
.clone
(),
lora_name
:
req
.lora_name
.clone
(),
});
};
self
.add_request_local
(
req
,
decay_now
)
self
.add_request_local
(
req
,
decay_now
)
?
;
}
self
.spawn_publish_event
(
event
);
/// Send a mutation to the worker assigned to a request, optionally publishing
/// a replica-sync event and cleaning up request mappings afterward.
fn
mutate_request_worker_local
(
&
self
,
request_id
:
&
RequestId
,
decay_now
:
Instant
,
mutate_fn
:
impl
FnOnce
(
&
mut
ActiveSequences
,
&
RequestId
,
Instant
),
remove_mapping
:
bool
,
)
->
Result
<
(),
SequenceError
>
{
let
worker
=
self
.request_to_worker
.get
(
request_id
)
.map
(|
entry
|
*
entry
)
.ok_or_else
(||
SequenceError
::
RequestNotFound
{
request_id
:
request_id
.clone
(),
})
?
;
{
let
table
=
self
.workers
.read
();
let
&
idx
=
table
.index
.get
(
&
worker
)
.ok_or
(
SequenceError
::
WorkerNotFound
{
worker
})
?
;
let
mut
seq
=
table
.slots
[
idx
]
.1
.write
();
mutate_fn
(
&
mut
seq
,
request_id
,
decay_now
);
}
if
remove_mapping
{
self
.request_to_worker
.remove
(
request_id
);
self
.request_to_lora
.remove
(
request_id
);
}
self
.publish_active_load_for_worker
(
worker
,
decay_now
);
Ok
(())
Ok
(())
}
}
fn
mutate_request_worker
(
&
self
,
request_id
:
&
RequestId
,
decay_now
:
Instant
,
event_data
:
ActiveSequenceEventData
,
mutate_fn
:
impl
FnOnce
(
&
mut
ActiveSequences
,
&
RequestId
,
Instant
),
remove_mapping
:
bool
,
)
->
Result
<
(),
SequenceError
>
{
let
worker
=
self
.request_to_worker
.get
(
request_id
)
.map
(|
entry
|
*
entry
)
.ok_or_else
(||
SequenceError
::
RequestNotFound
{
request_id
:
request_id
.clone
(),
})
?
;
let
lora_name
=
self
.request_to_lora
.get
(
request_id
)
.map
(|
entry
|
entry
.value
()
.clone
());
self
.spawn_publish_event
(
ActiveSequenceEvent
{
request_id
:
request_id
.clone
(),
worker
,
data
:
event_data
,
router_id
:
self
.router_id
,
lora_name
,
});
self
.mutate_request_worker_local
(
request_id
,
decay_now
,
mutate_fn
,
remove_mapping
)
}
/// Free all blocks associated with a request.
/// Free all blocks associated with a request.
///
///
/// Note: This operation is idempotent. Calling it multiple times for the same request
/// Note: This operation is idempotent. Calling it multiple times for the same request
...
@@ -583,20 +457,20 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
...
@@ -583,20 +457,20 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
/// [`ActiveSequences::free`], so callers do not need to call
/// [`ActiveSequences::free`], so callers do not need to call
/// [`Self::mark_prefill_completed`] before freeing a completed request.
/// [`Self::mark_prefill_completed`] before freeing a completed request.
pub
fn
free
(
&
self
,
request_id
:
&
RequestId
,
decay_now
:
Instant
)
->
Result
<
(),
SequenceError
>
{
pub
fn
free
(
&
self
,
request_id
:
&
RequestId
,
decay_now
:
Instant
)
->
Result
<
(),
SequenceError
>
{
if
!
self
.request_to_worker
.contains_key
(
request_id
)
{
match
self
.mutate_request_worker_prompt_state
(
tracing
::
debug!
(
"Request {request_id} not found, already freed (idempotent)"
);
return
Ok
(());
}
self
.mutate_request_worker
(
request_id
,
request_id
,
decay_now
,
decay_now
,
ActiveSequenceEventData
::
Free
,
ActiveSequenceEventData
::
Free
,
|
seqs
,
rid
,
decay_now
|
{
|
seqs
,
rid
,
decay_now
|
seqs
.free
(
rid
,
decay_now
),
seqs
.free
(
rid
,
decay_now
);
},
true
,
true
,
)
)
{
Ok
(())
=>
Ok
(()),
Err
(
SequenceError
::
RequestNotFound
{
..
})
=>
{
tracing
::
debug!
(
"Request {request_id} not found, already freed (idempotent)"
);
Ok
(())
}
Err
(
err
)
=>
Err
(
err
),
}
}
}
/// Mark prefill as completed for a request.
/// Mark prefill as completed for a request.
...
@@ -608,14 +482,13 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
...
@@ -608,14 +482,13 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
request_id
:
&
RequestId
,
request_id
:
&
RequestId
,
decay_now
:
Instant
,
decay_now
:
Instant
,
)
->
Result
<
(),
SequenceError
>
{
)
->
Result
<
(),
SequenceError
>
{
self
.mutate_request_worker
(
self
.mutate_request_worker
_load_state
(
request_id
,
request_id
,
decay_now
,
decay_now
,
ActiveSequenceEventData
::
MarkPrefillCompleted
,
ActiveSequenceEventData
::
MarkPrefillCompleted
,
|
seqs
,
rid
,
decay_now
|
{
|
seqs
,
rid
,
decay_now
|
{
seqs
.mark_prefill_completed
(
rid
,
decay_now
);
seqs
.mark_prefill_completed
(
rid
,
decay_now
);
},
},
false
,
)
)
}
}
...
@@ -630,63 +503,37 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
...
@@ -630,63 +503,37 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
request_id
:
&
RequestId
,
request_id
:
&
RequestId
,
decay_fraction
:
Option
<
f64
>
,
decay_fraction
:
Option
<
f64
>
,
)
->
Result
<
(),
SequenceError
>
{
)
->
Result
<
(),
SequenceError
>
{
let
worker
=
self
let
worker
=
self
.request_index
.worker_for
(
request_id
)
.ok_or_else
(||
{
.request_to_worker
SequenceError
::
RequestNotFound
{
.get
(
request_id
)
.map
(|
entry
|
*
entry
)
.ok_or_else
(||
SequenceError
::
RequestNotFound
{
request_id
:
request_id
.clone
(),
})
?
;
let
success
=
{
let
table
=
self
.workers
.read
();
let
&
idx
=
table
.index
.get
(
&
worker
)
.ok_or
(
SequenceError
::
WorkerNotFound
{
worker
})
?
;
let
mut
seq
=
table
.slots
[
idx
]
.1
.write
();
seq
.add_output_block
(
request_id
,
decay_fraction
)
};
if
!
success
{
return
Err
(
SequenceError
::
RequestNotFound
{
request_id
:
request_id
.clone
(),
request_id
:
request_id
.clone
(),
});
}
}
})
?
;
self
.publish_active_load_for_worker
(
worker
,
Instant
::
now
());
Ok
(())
}
/// Read active blocks/tokens from a worker and publish ActiveLoad metrics.
let
load
=
{
fn
publish_active_load_for_worker
(
&
self
,
worker
:
WorkerWithDpRank
,
decay_now
:
Instant
)
{
let
(
active_blocks
,
active_tokens
)
=
{
let
table
=
self
.workers
.read
();
let
table
=
self
.workers
.read
();
let
Some
(
&
idx
)
=
table
.index
.get
(
&
worker
)
else
{
let
Some
(
&
idx
)
=
table
.index
.get
(
&
worker
)
else
{
tracing
::
warn!
(
"Worker {worker:?} not found when publishing ActiveLoad"
);
drop
(
table
);
return
;
return
Err
(
self
.stale_request_not_found
(
request_id
,
worker
,
"add_output_block"
))
;
};
};
let
seq
=
table
.slots
[
idx
]
.1
.read
();
let
mut
seq
=
table
.slots
[
idx
]
.sequences
.write
();
(
seq
.active_blocks
(),
seq
.active_tokens
(
decay_now
))
let
Some
(
_
new_block_hash
)
=
seq
.add_output_block
(
request_id
,
decay_fraction
)
else
{
return
Err
(
SequenceError
::
RequestNotFound
{
request_id
:
request_id
.clone
(),
});
};
let
load
=
seq
.worker_load_snapshot
();
self
.prompt_registry
.replace_worker_load_state
(
worker
,
load
);
load
};
};
self
.publisher
self
.publish_worker_load_snapshot
(
worker
,
load
,
Instant
::
now
());
.observe_load
(
&
worker
,
self
.worker_type
,
active_blocks
,
active_tokens
);
let
active_load
=
ActiveLoad
{
worker_id
:
worker
.worker_id
,
dp_rank
:
worker
.dp_rank
,
active_decode_blocks
:
Some
(
active_blocks
as
u64
),
active_prefill_tokens
:
Some
(
active_tokens
as
u64
),
kv_used_blocks
:
None
,
};
self
.publisher
.publish_load
(
active_load
);
Ok
(())
}
}
/// Get the number of workers.
/// Get the number of workers.
pub
fn
num_workers
(
&
self
)
->
usize
{
#[cfg(test)]
pub
(
crate
)
fn
num_workers
(
&
self
)
->
usize
{
self
.workers
.read
()
.slots
.len
()
self
.workers
.read
()
.slots
.len
()
}
}
...
@@ -699,8 +546,11 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
...
@@ -699,8 +546,11 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
pub
fn
new_blocks
(
&
self
,
token_sequence
:
&
[
SequenceHash
])
->
HashMap
<
WorkerWithDpRank
,
usize
>
{
pub
fn
new_blocks
(
&
self
,
token_sequence
:
&
[
SequenceHash
])
->
HashMap
<
WorkerWithDpRank
,
usize
>
{
let
table
=
self
.workers
.read
();
let
table
=
self
.workers
.read
();
let
mut
results
=
HashMap
::
with_capacity
(
table
.slots
.len
());
let
mut
results
=
HashMap
::
with_capacity
(
table
.slots
.len
());
for
(
worker
,
lock
)
in
&
table
.slots
{
for
slot
in
&
table
.slots
{
results
.insert
(
*
worker
,
lock
.read
()
.new_blocks
(
token_sequence
));
results
.insert
(
slot
.worker
,
slot
.sequences
.read
()
.new_blocks
(
token_sequence
),
);
}
}
results
results
}
}
...
@@ -712,8 +562,11 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
...
@@ -712,8 +562,11 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
)
->
HashMap
<
WorkerWithDpRank
,
usize
>
{
)
->
HashMap
<
WorkerWithDpRank
,
usize
>
{
let
table
=
self
.workers
.read
();
let
table
=
self
.workers
.read
();
let
mut
results
=
HashMap
::
with_capacity
(
table
.slots
.len
());
let
mut
results
=
HashMap
::
with_capacity
(
table
.slots
.len
());
for
(
worker
,
lock
)
in
&
table
.slots
{
for
slot
in
&
table
.slots
{
results
.insert
(
*
worker
,
lock
.read
()
.potential_blocks
(
token_sequence
));
results
.insert
(
slot
.worker
,
slot
.sequences
.read
()
.potential_blocks
(
token_sequence
),
);
}
}
results
results
}
}
...
@@ -726,8 +579,8 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
...
@@ -726,8 +579,8 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
overlaps
:
OverlapScores
,
overlaps
:
OverlapScores
,
decay_now
:
Instant
,
decay_now
:
Instant
,
)
->
(
)
->
(
HashMap
<
WorkerWithDpRank
,
usize
>
,
Fx
HashMap
<
WorkerWithDpRank
,
usize
>
,
HashMap
<
WorkerWithDpRank
,
usize
>
,
Fx
HashMap
<
WorkerWithDpRank
,
usize
>
,
)
{
)
{
self
.potential_blocks_and_tokens_with_prefill_tracking
(
self
.potential_blocks_and_tokens_with_prefill_tracking
(
token_sequence
,
token_sequence
,
...
@@ -746,91 +599,42 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
...
@@ -746,91 +599,42 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
track_prefill_tokens
:
bool
,
track_prefill_tokens
:
bool
,
decay_now
:
Instant
,
decay_now
:
Instant
,
)
->
(
)
->
(
HashMap
<
WorkerWithDpRank
,
usize
>
,
Fx
HashMap
<
WorkerWithDpRank
,
usize
>
,
HashMap
<
WorkerWithDpRank
,
usize
>
,
Fx
HashMap
<
WorkerWithDpRank
,
usize
>
,
)
{
)
{
#[cfg(feature
=
"bench"
)]
self
.prompt_registry
let
start
=
tokio
::
time
::
Instant
::
now
();
.potential_blocks_and_tokens_with_prefill_tracking
(
token_sequence
,
let
table
=
self
.workers
.read
();
isl
,
&
overlaps
,
#[cfg(feature
=
"bench"
)]
track_prefill_tokens
,
let
num_workers
=
table
.slots
.len
();
self
.block_size
,
decay_now
,
let
mut
potential_blocks
=
HashMap
::
with_capacity
(
table
.slots
.len
());
)
let
mut
potential_tokens
=
HashMap
::
with_capacity
(
table
.slots
.len
());
for
(
worker
,
lock
)
in
&
table
.slots
{
let
overlap
=
*
overlaps
.scores
.get
(
worker
)
.unwrap_or
(
&
0
);
let
(
blocks
,
tokens
)
=
lock
.read
()
.potential_blocks_and_tokens_with_prefill_tracking
(
token_sequence
,
isl
,
overlap
,
track_prefill_tokens
,
decay_now
,
);
potential_blocks
.insert
(
*
worker
,
blocks
);
potential_tokens
.insert
(
*
worker
,
tokens
);
}
#[cfg(feature
=
"bench"
)]
{
let
total_elapsed
=
start
.elapsed
();
tracing
::
info!
(
num_workers
,
total_us
=
total_elapsed
.as_micros
()
as
u64
,
"potential_blocks_and_tokens completed"
);
}
(
potential_blocks
,
potential_tokens
)
}
}
/// Query all workers for their current number of active blocks.
/// Query all workers for their current number of active blocks.
pub
fn
active_blocks
(
&
self
)
->
HashMap
<
WorkerWithDpRank
,
usize
>
{
pub
fn
active_blocks
(
&
self
)
->
HashMap
<
WorkerWithDpRank
,
usize
>
{
let
table
=
self
.workers
.read
();
self
.prompt_registry
.active_blocks
()
let
mut
results
=
HashMap
::
with_capacity
(
table
.slots
.len
());
for
(
worker
,
lock
)
in
&
table
.slots
{
results
.insert
(
*
worker
,
lock
.read
()
.active_blocks
());
}
results
}
}
/// Query all workers for their current number of active tokens.
/// Query all workers for their current number of active tokens.
pub
fn
active_tokens
(
&
self
,
decay_now
:
Instant
)
->
HashMap
<
WorkerWithDpRank
,
usize
>
{
pub
fn
active_tokens
(
&
self
,
decay_now
:
Instant
)
->
HashMap
<
WorkerWithDpRank
,
usize
>
{
let
table
=
self
.workers
.read
();
self
.prompt_registry
.active_tokens
(
decay_now
)
let
mut
results
=
HashMap
::
with_capacity
(
table
.slots
.len
());
for
(
worker
,
lock
)
in
&
table
.slots
{
results
.insert
(
*
worker
,
lock
.read
()
.active_tokens
(
decay_now
));
}
results
}
}
/// Return true if any worker satisfies the provided predicate on active token count.
/// Return true if any worker satisfies the provided predicate on active token count.
pub
fn
any_worker_matches_active_tokens
(
pub
fn
any_worker_matches_active_tokens
(
&
self
,
&
self
,
decay_now
:
Instant
,
decay_now
:
Instant
,
mut
predicate
:
impl
FnMut
(
WorkerWithDpRank
,
usize
)
->
bool
,
predicate
:
impl
FnMut
(
WorkerWithDpRank
,
usize
)
->
bool
,
)
->
bool
{
)
->
bool
{
let
table
=
self
.workers
.read
();
self
.prompt_registry
for
(
worker
,
lock
)
in
&
table
.slots
{
.any_worker_matches_active_tokens
(
decay_now
,
predicate
)
if
predicate
(
*
worker
,
lock
.read
()
.active_tokens
(
decay_now
))
{
return
true
;
}
}
false
}
}
pub
fn
get_active_lora_counts
(
&
self
)
->
HashMap
<
String
,
usize
>
{
pub
fn
get_active_lora_counts
(
&
self
)
->
HashMap
<
String
,
usize
>
{
let
mut
counts
:
HashMap
<
String
,
usize
>
=
HashMap
::
new
();
self
.request_index
.active_lora_counts
()
for
entry
in
self
.request_to_lora
.iter
()
{
let
lora_name
=
entry
.value
()
.clone
();
*
counts
.entry
(
lora_name
)
.or_insert
(
0
)
+=
1
;
}
counts
}
}
/// Force expire stale requests across all workers (one-shot).
/// Force expire stale requests across all workers (one-shot).
...
@@ -844,17 +648,24 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
...
@@ -844,17 +648,24 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
let
now
=
Instant
::
now
();
let
now
=
Instant
::
now
();
let
table
=
self
.workers
.read
();
let
table
=
self
.workers
.read
();
let
mut
removed_request_count
=
0
;
let
mut
removed_request_count
=
0
;
for
(
worker
,
lock
)
in
&
table
.slots
{
for
slot
in
&
table
.slots
{
let
removed_requests
=
lock
.write
()
.force_expiry
();
let
mut
seq
=
slot
.sequences
.write
();
if
!
removed_requests
.is_empty
()
{
let
outcome
=
seq
.force_expiry
();
for
expired_id
in
&
removed_requests
{
if
!
outcome
.expired_request_ids
.is_empty
()
{
self
.request_to_worker
.remove
(
expired_id
);
let
load
=
seq
.worker_load_snapshot
();
self
.request_to_lora
.remove
(
expired_id
);
self
.prompt_registry
.apply_membership_delta_and_load
(
removed_request_count
+=
1
;
slot
.worker
,
}
&
slot
.trie_lookup
,
self
.publish_active_load_for_worker
(
*
worker
,
now
);
outcome
.membership_delta
,
load
,
);
removed_request_count
+=
outcome
.expired_request_ids
.len
();
self
.request_index
.remove_requests
(
outcome
.expired_request_ids
.iter
());
self
.publish_worker_load_snapshot
(
slot
.worker
,
load
,
now
);
}
}
}
}
drop
(
table
);
let
duration
=
now
.elapsed
();
let
duration
=
now
.elapsed
();
tracing
::
debug!
(
tracing
::
debug!
(
duration
=
duration
.as_secs_f64
(),
duration
=
duration
.as_secs_f64
(),
...
@@ -890,31 +701,336 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
...
@@ -890,31 +701,336 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
}
}
});
});
}
}
}
#[cfg(test)]
fn
ensure_worker_registered
(
&
self
,
worker
:
WorkerWithDpRank
)
{
mod
tests
{
if
self
.workers
.read
()
.index
.contains_key
(
&
worker
)
{
use
std
::
collections
::
HashMap
;
return
;
use
std
::
time
::
Duration
;
}
use
super
::
*
;
let
mut
table
=
self
.workers
.write
();
use
crate
::
protocols
::{
OverlapScores
,
PrefillLoadHint
};
if
table
.index
.contains_key
(
&
worker
)
{
use
crate
::
test_utils
::
NoopSequencePublisher
;
return
;
}
fn
make_sequences
()
->
ActiveSequencesMultiWorker
<
NoopSequencePublisher
>
{
tracing
::
debug!
(
?
worker
,
"Lazily registering worker in slot tracker"
);
ActiveSequencesMultiWorker
::
new
(
let
change
=
table
.ensure_worker
(
self
.block_size
,
worker
);
NoopSequencePublisher
,
drop
(
table
);
4
,
HashMap
::
from
([(
1_u64
,
(
0_u32
,
1_u32
))]),
self
.prompt_registry
.apply_topology_change
(
change
);
false
,
0
,
"test"
,
)
}
}
#[tokio::test]
fn
add_request_local
(
async
fn
add_request_can_skip_prefill_token_tracking
()
{
&
self
,
let
sequences
=
make_sequences
();
req
:
SequenceRequest
,
decay_now
:
Instant
,
)
->
Result
<
(),
SequenceError
>
{
let
SequenceRequest
{
request_id
,
token_sequence
,
isl
,
overlap
,
track_prefill_tokens
,
expected_output_tokens
,
prefill_load_hint
,
worker
,
lora_name
,
}
=
req
;
self
.ensure_worker_registered
(
worker
);
let
(
expired_request_ids
,
load
)
=
{
let
table
=
self
.workers
.read
();
let
&
idx
=
table
.index
.get
(
&
worker
)
.ok_or
(
SequenceError
::
WorkerNotFound
{
worker
})
?
;
if
let
Err
(
existing_worker
)
=
self
.request_index
.try_insert_request
(
request_id
.clone
(),
worker
,
lora_name
)
{
return
Err
(
SequenceError
::
DuplicateRequest
{
request_id
,
worker
:
existing_worker
,
});
}
let
slot
=
&
table
.slots
[
idx
];
let
mut
seq
=
slot
.sequences
.write
();
let
outcome
=
seq
.add_request_with_prefill_tracking
(
request_id
,
token_sequence
,
isl
,
overlap
,
expected_output_tokens
,
track_prefill_tokens
,
prefill_load_hint
,
decay_now
,
);
let
load
=
seq
.worker_load_snapshot
();
self
.prompt_registry
.apply_membership_delta_and_load
(
worker
,
&
slot
.trie_lookup
,
outcome
.membership_delta
,
load
,
);
(
outcome
.expired_request_ids
,
load
)
};
self
.request_index
.remove_requests
(
expired_request_ids
.iter
());
self
.publish_worker_load_snapshot
(
worker
,
load
,
decay_now
);
Ok
(())
}
fn
stale_request_not_found
(
&
self
,
request_id
:
&
RequestId
,
worker
:
WorkerWithDpRank
,
operation
:
&
'static
str
,
)
->
SequenceError
{
if
self
.request_index
.worker_for
(
request_id
)
==
Some
(
worker
)
{
self
.request_index
.remove_request
(
request_id
);
tracing
::
warn!
(
%
request_id
,
?
worker
,
operation
,
"request index referenced a missing worker slot; removed stale mapping"
);
}
else
{
tracing
::
warn!
(
%
request_id
,
?
worker
,
operation
,
"request worker slot disappeared before the mutation ran"
);
}
SequenceError
::
RequestNotFound
{
request_id
:
request_id
.clone
(),
}
}
fn
mutate_request_worker_prompt_state_local
(
&
self
,
worker
:
WorkerWithDpRank
,
request_id
:
&
RequestId
,
decay_now
:
Instant
,
mutate_fn
:
impl
FnOnce
(
&
mut
ActiveSequences
,
&
RequestId
,
Instant
)
->
PromptMembershipDelta
,
remove_mapping
:
bool
,
)
->
Result
<
(),
SequenceError
>
{
let
load
=
{
let
table
=
self
.workers
.read
();
let
Some
(
&
idx
)
=
table
.index
.get
(
&
worker
)
else
{
drop
(
table
);
return
Err
(
self
.stale_request_not_found
(
request_id
,
worker
,
"free_or_mutate"
));
};
let
slot
=
&
table
.slots
[
idx
];
let
mut
seq
=
slot
.sequences
.write
();
let
delta
=
mutate_fn
(
&
mut
seq
,
request_id
,
decay_now
);
let
load
=
seq
.worker_load_snapshot
();
self
.prompt_registry
.apply_membership_delta_and_load
(
worker
,
&
slot
.trie_lookup
,
delta
,
load
,
);
load
};
if
remove_mapping
{
self
.request_index
.remove_request
(
request_id
);
}
self
.publish_worker_load_snapshot
(
worker
,
load
,
decay_now
);
Ok
(())
}
fn
mutate_request_worker_load_state_local
(
&
self
,
worker
:
WorkerWithDpRank
,
request_id
:
&
RequestId
,
decay_now
:
Instant
,
mutate_fn
:
impl
FnOnce
(
&
mut
ActiveSequences
,
&
RequestId
,
Instant
),
)
->
Result
<
(),
SequenceError
>
{
let
load
=
{
let
table
=
self
.workers
.read
();
let
Some
(
&
idx
)
=
table
.index
.get
(
&
worker
)
else
{
drop
(
table
);
return
Err
(
self
.stale_request_not_found
(
request_id
,
worker
,
"load_only_mutate"
));
};
let
mut
seq
=
table
.slots
[
idx
]
.sequences
.write
();
mutate_fn
(
&
mut
seq
,
request_id
,
decay_now
);
let
load
=
seq
.worker_load_snapshot
();
self
.prompt_registry
.replace_worker_load_state
(
worker
,
load
);
load
};
self
.publish_worker_load_snapshot
(
worker
,
load
,
decay_now
);
Ok
(())
}
fn
mutate_request_worker_prompt_state
(
&
self
,
request_id
:
&
RequestId
,
decay_now
:
Instant
,
event_data
:
ActiveSequenceEventData
,
mutate_fn
:
impl
FnOnce
(
&
mut
ActiveSequences
,
&
RequestId
,
Instant
)
->
PromptMembershipDelta
,
remove_mapping
:
bool
,
)
->
Result
<
(),
SequenceError
>
{
let
worker
=
self
.request_index
.worker_for
(
request_id
)
.ok_or_else
(||
{
SequenceError
::
RequestNotFound
{
request_id
:
request_id
.clone
(),
}
})
?
;
let
lora_name
=
self
.request_index
.lora_for
(
request_id
);
self
.mutate_request_worker_prompt_state_local
(
worker
,
request_id
,
decay_now
,
mutate_fn
,
remove_mapping
,
)
?
;
self
.spawn_publish_event
(
ActiveSequenceEvent
{
request_id
:
request_id
.clone
(),
worker
,
data
:
event_data
,
router_id
:
self
.router_id
,
lora_name
,
});
Ok
(())
}
fn
mutate_request_worker_load_state
(
&
self
,
request_id
:
&
RequestId
,
decay_now
:
Instant
,
event_data
:
ActiveSequenceEventData
,
mutate_fn
:
impl
FnOnce
(
&
mut
ActiveSequences
,
&
RequestId
,
Instant
),
)
->
Result
<
(),
SequenceError
>
{
let
worker
=
self
.request_index
.worker_for
(
request_id
)
.ok_or_else
(||
{
SequenceError
::
RequestNotFound
{
request_id
:
request_id
.clone
(),
}
})
?
;
let
lora_name
=
self
.request_index
.lora_for
(
request_id
);
self
.mutate_request_worker_load_state_local
(
worker
,
request_id
,
decay_now
,
mutate_fn
)
?
;
self
.spawn_publish_event
(
ActiveSequenceEvent
{
request_id
:
request_id
.clone
(),
worker
,
data
:
event_data
,
router_id
:
self
.router_id
,
lora_name
,
});
Ok
(())
}
}
#[cfg(test)]
mod
tests
{
use
std
::
collections
::{
HashMap
,
VecDeque
};
use
std
::
future
::{
self
,
Future
};
use
std
::
time
::
Duration
;
use
rustc_hash
::
FxHashMap
;
use
super
::
*
;
use
crate
::
protocols
::{
ActiveSequenceEvent
,
ActiveSequenceEventData
,
BlockHashOptions
,
OverlapScores
,
PrefillLoadHint
,
compute_block_hash_for_seq
,
compute_seq_hash_for_block
,
};
use
crate
::
test_utils
::
NoopSequencePublisher
;
fn
make_sequences
()
->
ActiveSequencesMultiWorker
<
NoopSequencePublisher
>
{
ActiveSequencesMultiWorker
::
new
(
NoopSequencePublisher
,
4
,
HashMap
::
from
([(
1_u64
,
(
0_u32
,
1_u32
))]),
false
,
0
,
"test"
,
)
}
fn
make_multi_sequences
()
->
ActiveSequencesMultiWorker
<
NoopSequencePublisher
>
{
ActiveSequencesMultiWorker
::
new
(
NoopSequencePublisher
,
4
,
HashMap
::
from
([(
1_u64
,
(
0_u32
,
1_u32
)),
(
2_u64
,
(
0_u32
,
1_u32
))]),
false
,
0
,
"test"
,
)
}
fn
naive_potential_loads
(
sequences
:
&
ActiveSequencesMultiWorker
<
NoopSequencePublisher
>
,
token_sequence
:
Option
<&
[
SequenceHash
]
>
,
isl
:
usize
,
overlaps
:
&
OverlapScores
,
track_prefill_tokens
:
bool
,
decay_now
:
Instant
,
)
->
(
FxHashMap
<
WorkerWithDpRank
,
usize
>
,
FxHashMap
<
WorkerWithDpRank
,
usize
>
,
)
{
let
table
=
sequences
.workers
.read
();
let
mut
potential_blocks
=
FxHashMap
::
default
();
let
mut
potential_tokens
=
FxHashMap
::
default
();
for
slot
in
&
table
.slots
{
let
seq
=
slot
.sequences
.read
();
let
overlap_depth
=
token_sequence
.map_or
(
0
,
|
query
|
{
let
active_hashes
=
seq
.active_prompt_hashes
();
query
.iter
()
.position
(|
hash
|
!
active_hashes
.contains
(
hash
))
.unwrap_or
(
query
.len
())
});
let
new_blocks
=
token_sequence
.map_or
(
0
,
|
query
|
query
.len
()
.saturating_sub
(
overlap_depth
));
let
overlap
=
*
overlaps
.scores
.get
(
&
slot
.worker
)
.unwrap_or
(
&
0
);
let
added_tokens
=
if
track_prefill_tokens
{
seq
.new_tokens
(
isl
,
overlap
)
}
else
{
0
};
potential_blocks
.insert
(
slot
.worker
,
seq
.active_blocks
()
+
new_blocks
);
potential_tokens
.insert
(
slot
.worker
,
seq
.active_tokens
(
decay_now
)
+
added_tokens
);
}
(
potential_blocks
,
potential_tokens
)
}
fn
seq_hashes_for_tokens
(
tokens
:
&
[
u32
],
lora_name
:
Option
<&
str
>
)
->
Vec
<
SequenceHash
>
{
let
block_hashes
=
compute_block_hash_for_seq
(
tokens
,
4
,
BlockHashOptions
{
lora_name
,
..
Default
::
default
()
},
);
compute_seq_hash_for_block
(
&
block_hashes
)
}
struct
VecSubscriber
{
events
:
VecDeque
<
anyhow
::
Result
<
ActiveSequenceEvent
>>
,
}
impl
SequenceSubscriber
for
VecSubscriber
{
fn
next_event
(
&
mut
self
,
)
->
impl
Future
<
Output
=
Option
<
anyhow
::
Result
<
ActiveSequenceEvent
>>>
+
Send
{
future
::
ready
(
self
.events
.pop_front
())
}
}
#[tokio::test]
async
fn
add_request_can_skip_prefill_token_tracking
()
{
let
sequences
=
make_sequences
();
let
worker
=
WorkerWithDpRank
::
new
(
1
,
0
);
let
worker
=
WorkerWithDpRank
::
new
(
1
,
0
);
let
decay_now
=
Instant
::
now
();
let
decay_now
=
Instant
::
now
();
...
@@ -941,6 +1057,419 @@ mod tests {
...
@@ -941,6 +1057,419 @@ mod tests {
);
);
}
}
#[test]
fn
block_membership_index_matches_naive_loads_with_output_blocks_and_prefill_updates
()
{
let
sequences
=
make_multi_sequences
();
let
worker_a
=
WorkerWithDpRank
::
new
(
1
,
0
);
let
worker_b
=
WorkerWithDpRank
::
new
(
2
,
0
);
let
decay_now
=
Instant
::
now
();
sequences
.add_request
(
SequenceRequest
{
request_id
:
"req-a"
.to_string
(),
token_sequence
:
Some
(
vec!
[
1
,
2
,
3
]),
isl
:
12
,
overlap
:
0
,
track_prefill_tokens
:
true
,
expected_output_tokens
:
None
,
prefill_load_hint
:
None
,
worker
:
worker_a
,
lora_name
:
None
,
},
decay_now
,
)
.unwrap
();
sequences
.add_output_block
(
&
"req-a"
.to_string
(),
Some
(
0.5
))
.unwrap
();
sequences
.mark_prefill_completed
(
&
"req-a"
.to_string
(),
decay_now
)
.unwrap
();
sequences
.add_request
(
SequenceRequest
{
request_id
:
"req-b"
.to_string
(),
token_sequence
:
Some
(
vec!
[
1
,
2
,
4
]),
isl
:
12
,
overlap
:
0
,
track_prefill_tokens
:
true
,
expected_output_tokens
:
None
,
prefill_load_hint
:
None
,
worker
:
worker_b
,
lora_name
:
None
,
},
decay_now
,
)
.unwrap
();
let
prompt
=
vec!
[
1
,
2
,
3
,
5
];
let
mut
expected_overlaps
=
OverlapScores
::
new
();
expected_overlaps
.scores
.insert
(
worker_a
,
2
);
expected_overlaps
.scores
.insert
(
worker_b
,
1
);
let
expected
=
naive_potential_loads
(
&
sequences
,
Some
(
&
prompt
),
16
,
&
expected_overlaps
,
true
,
decay_now
,
);
let
mut
actual_overlaps
=
OverlapScores
::
new
();
actual_overlaps
.scores
.insert
(
worker_a
,
2
);
actual_overlaps
.scores
.insert
(
worker_b
,
1
);
let
actual
=
sequences
.potential_blocks_and_tokens_with_prefill_tracking
(
Some
(
&
prompt
),
16
,
actual_overlaps
,
true
,
decay_now
,
);
assert_eq!
(
actual
.0
,
expected
.0
);
assert_eq!
(
actual
.1
,
expected
.1
);
}
#[test]
fn
lora_specific_sequence_hashes_do_not_cross_match
()
{
let
sequences
=
make_multi_sequences
();
let
worker_a
=
WorkerWithDpRank
::
new
(
1
,
0
);
let
worker_b
=
WorkerWithDpRank
::
new
(
2
,
0
);
let
decay_now
=
Instant
::
now
();
let
tokens
=
[
1_u32
,
2
,
3
,
4
,
5
,
6
,
7
,
8
];
let
base_prompt
=
seq_hashes_for_tokens
(
&
tokens
,
None
);
let
lora_prompt
=
seq_hashes_for_tokens
(
&
tokens
,
Some
(
"adapter-a"
));
assert_ne!
(
base_prompt
,
lora_prompt
);
sequences
.add_request
(
SequenceRequest
{
request_id
:
"base"
.to_string
(),
token_sequence
:
Some
(
base_prompt
.clone
()),
isl
:
8
,
overlap
:
0
,
track_prefill_tokens
:
false
,
expected_output_tokens
:
None
,
prefill_load_hint
:
None
,
worker
:
worker_a
,
lora_name
:
None
,
},
decay_now
,
)
.unwrap
();
sequences
.add_request
(
SequenceRequest
{
request_id
:
"lora"
.to_string
(),
token_sequence
:
Some
(
lora_prompt
),
isl
:
8
,
overlap
:
0
,
track_prefill_tokens
:
false
,
expected_output_tokens
:
None
,
prefill_load_hint
:
None
,
worker
:
worker_b
,
lora_name
:
Some
(
"adapter-a"
.to_string
()),
},
decay_now
,
)
.unwrap
();
let
expected
=
naive_potential_loads
(
&
sequences
,
Some
(
&
base_prompt
),
8
,
&
OverlapScores
::
default
(),
false
,
decay_now
,
);
let
actual
=
sequences
.potential_blocks_and_tokens_with_prefill_tracking
(
Some
(
&
base_prompt
),
8
,
OverlapScores
::
default
(),
false
,
decay_now
,
);
assert_eq!
(
actual
.0
,
expected
.0
);
assert_eq!
(
actual
.1
,
expected
.1
);
let
active_blocks
=
sequences
.active_blocks
();
assert_eq!
(
actual
.0
.get
(
&
worker_b
)
.copied
(),
Some
(
active_blocks
[
&
worker_b
]
+
base_prompt
.len
()),
);
}
#[tokio::test(start_paused
=
true
)]
async
fn
force_expiry_clears_block_membership_index
()
{
let
sequences
=
make_multi_sequences
();
let
worker
=
WorkerWithDpRank
::
new
(
1
,
0
);
sequences
.add_request
(
SequenceRequest
{
request_id
:
"req-1"
.to_string
(),
token_sequence
:
Some
(
vec!
[
1
,
2
,
3
]),
isl
:
12
,
overlap
:
0
,
track_prefill_tokens
:
true
,
expected_output_tokens
:
None
,
prefill_load_hint
:
None
,
worker
,
lora_name
:
None
,
},
Instant
::
now
(),
)
.unwrap
();
tokio
::
time
::
advance
(
Duration
::
from_secs
(
331
))
.await
;
sequences
.force_expire_requests_across_all_workers
();
assert
!
(
sequences
.request_index
.is_empty
());
assert
!
(
sequences
.prompt_registry
.is_block_index_empty
());
assert_eq!
(
sequences
.active_blocks
()
.get
(
&
worker
)
.copied
(),
Some
(
0
));
}
#[tokio::test(start_paused
=
true
)]
async
fn
expiry_then_immediate_readd_preserves_block_membership
()
{
let
sequences
=
make_sequences
();
let
worker
=
WorkerWithDpRank
::
new
(
1
,
0
);
sequences
.add_request
(
SequenceRequest
{
request_id
:
"req-1"
.to_string
(),
token_sequence
:
Some
(
vec!
[
1
,
2
,
3
]),
isl
:
12
,
overlap
:
0
,
track_prefill_tokens
:
true
,
expected_output_tokens
:
None
,
prefill_load_hint
:
None
,
worker
,
lora_name
:
None
,
},
Instant
::
now
(),
)
.unwrap
();
tokio
::
time
::
advance
(
Duration
::
from_secs
(
331
))
.await
;
sequences
.add_request
(
SequenceRequest
{
request_id
:
"req-2"
.to_string
(),
token_sequence
:
Some
(
vec!
[
1
,
2
,
3
]),
isl
:
12
,
overlap
:
0
,
track_prefill_tokens
:
true
,
expected_output_tokens
:
None
,
prefill_load_hint
:
None
,
worker
,
lora_name
:
None
,
},
Instant
::
now
(),
)
.unwrap
();
assert
!
(
!
sequences
.prompt_registry
.is_block_index_empty
());
assert_eq!
(
sequences
.active_blocks
()
.get
(
&
worker
)
.copied
(),
Some
(
3
));
let
expected
=
naive_potential_loads
(
&
sequences
,
Some
(
&
[
1
,
2
,
3
]),
12
,
&
OverlapScores
::
default
(),
false
,
Instant
::
now
(),
);
let
actual
=
sequences
.potential_blocks_and_tokens_with_prefill_tracking
(
Some
(
&
[
1
,
2
,
3
]),
12
,
OverlapScores
::
default
(),
false
,
Instant
::
now
(),
);
assert_eq!
(
actual
,
expected
);
}
#[tokio::test]
async
fn
replica_sync_add_and_free_keep_block_membership_consistent
()
{
let
sequences
=
ActiveSequencesMultiWorker
::
new
(
NoopSequencePublisher
,
4
,
HashMap
::
from
([(
1_u64
,
(
0_u32
,
1_u32
))]),
true
,
0
,
"test"
,
);
let
worker
=
WorkerWithDpRank
::
new
(
1
,
0
);
let
subscriber
=
VecSubscriber
{
events
:
VecDeque
::
from
(
vec!
[
Ok
(
ActiveSequenceEvent
{
request_id
:
"req-1"
.to_string
(),
worker
,
data
:
ActiveSequenceEventData
::
AddRequest
{
token_sequence
:
Some
(
vec!
[
1
,
2
,
3
]),
isl
:
12
,
overlap
:
0
,
track_prefill_tokens
:
true
,
expected_output_tokens
:
None
,
prefill_load_hint
:
None
,
},
router_id
:
99
,
lora_name
:
None
,
}),
Ok
(
ActiveSequenceEvent
{
request_id
:
"req-1"
.to_string
(),
worker
,
data
:
ActiveSequenceEventData
::
Free
,
router_id
:
99
,
lora_name
:
None
,
}),
]),
};
sequences
.run_replica_sync
(
subscriber
,
CancellationToken
::
new
())
.await
.unwrap
();
assert
!
(
sequences
.request_index
.is_empty
());
assert
!
(
sequences
.prompt_registry
.is_block_index_empty
());
assert_eq!
(
sequences
.active_blocks
()
.get
(
&
worker
)
.copied
(),
Some
(
0
));
}
#[tokio::test]
async
fn
replica_sync_add_lazily_registers_missing_worker
()
{
let
sequences
=
ActiveSequencesMultiWorker
::
new
(
NoopSequencePublisher
,
4
,
HashMap
::
new
(),
true
,
0
,
"test"
,
);
let
worker
=
WorkerWithDpRank
::
new
(
1
,
0
);
let
subscriber
=
VecSubscriber
{
events
:
VecDeque
::
from
(
vec!
[
Ok
(
ActiveSequenceEvent
{
request_id
:
"req-1"
.to_string
(),
worker
,
data
:
ActiveSequenceEventData
::
AddRequest
{
token_sequence
:
Some
(
vec!
[
1
,
2
,
3
]),
isl
:
12
,
overlap
:
0
,
track_prefill_tokens
:
true
,
expected_output_tokens
:
None
,
prefill_load_hint
:
None
,
},
router_id
:
99
,
lora_name
:
None
,
})]),
};
sequences
.run_replica_sync
(
subscriber
,
CancellationToken
::
new
())
.await
.unwrap
();
assert_eq!
(
sequences
.num_workers
(),
1
);
assert_eq!
(
sequences
.request_index
.worker_for
(
&
"req-1"
.to_string
()),
Some
(
worker
)
);
assert
!
(
!
sequences
.prompt_registry
.is_block_index_empty
());
assert_eq!
(
sequences
.active_blocks
()
.get
(
&
worker
)
.copied
(),
Some
(
3
));
}
#[test]
fn
worker_removal_then_readd_starts_with_empty_registry_state
()
{
let
sequences
=
make_sequences
();
let
worker
=
WorkerWithDpRank
::
new
(
1
,
0
);
let
decay_now
=
Instant
::
now
();
sequences
.add_request
(
SequenceRequest
{
request_id
:
"req-1"
.to_string
(),
token_sequence
:
Some
(
vec!
[
1
,
2
,
3
]),
isl
:
12
,
overlap
:
0
,
track_prefill_tokens
:
false
,
expected_output_tokens
:
None
,
prefill_load_hint
:
None
,
worker
,
lora_name
:
None
,
},
decay_now
,
)
.unwrap
();
sequences
.update_workers
(
&
HashMap
::
new
());
assert
!
(
sequences
.prompt_registry
.is_block_index_empty
());
assert
!
(
sequences
.active_blocks
()
.is_empty
());
assert
!
(
sequences
.request_index
.is_empty
());
sequences
.update_workers
(
&
HashMap
::
from
([(
1_u64
,
(
0_u32
,
1_u32
))]));
assert_eq!
(
sequences
.active_blocks
()
.get
(
&
worker
)
.copied
(),
Some
(
0
));
assert
!
(
sequences
.prompt_registry
.is_block_index_empty
());
}
#[test]
fn
free_is_idempotent_after_request_is_removed
()
{
let
sequences
=
make_sequences
();
let
worker
=
WorkerWithDpRank
::
new
(
1
,
0
);
let
request_id
=
"req-1"
.to_string
();
let
decay_now
=
Instant
::
now
();
sequences
.add_request
(
SequenceRequest
{
request_id
:
request_id
.clone
(),
token_sequence
:
Some
(
vec!
[
1
,
2
,
3
]),
isl
:
12
,
overlap
:
0
,
track_prefill_tokens
:
false
,
expected_output_tokens
:
None
,
prefill_load_hint
:
None
,
worker
,
lora_name
:
None
,
},
decay_now
,
)
.unwrap
();
sequences
.free
(
&
request_id
,
decay_now
)
.unwrap
();
sequences
.free
(
&
request_id
,
decay_now
)
.unwrap
();
assert
!
(
sequences
.request_index
.is_empty
());
assert
!
(
sequences
.prompt_registry
.is_block_index_empty
());
assert_eq!
(
sequences
.active_blocks
()
.get
(
&
worker
)
.copied
(),
Some
(
0
));
}
#[test]
fn
free_cleans_stale_request_mapping_when_worker_slot_is_missing
()
{
let
sequences
=
make_sequences
();
let
worker
=
WorkerWithDpRank
::
new
(
1
,
0
);
let
request_id
=
"stale-request"
.to_string
();
sequences
.request_index
.set_request
(
request_id
.clone
(),
worker
,
Some
(
"adapter"
.to_string
()),
);
{
let
mut
table
=
sequences
.workers
.write
();
*
table
=
WorkerTable
::
new
(
sequences
.block_size
,
&
HashMap
::
new
());
}
sequences
.free
(
&
request_id
,
Instant
::
now
())
.unwrap
();
assert
!
(
sequences
.request_index
.is_empty
());
}
#[test]
#[test]
fn
explicit_decay_time_drives_multi_worker_load_queries_consistently
()
{
fn
explicit_decay_time_drives_multi_worker_load_queries_consistently
()
{
let
sequences
=
make_sequences
();
let
sequences
=
make_sequences
();
...
...
lib/kv-router/src/sequences/prefill_tracker.rs
View file @
134d484d
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
// SPDX-License-Identifier: Apache-2.0
use
std
::
collections
::
VecDeque
;
use
std
::
collections
::
{
HashMap
,
VecDeque
}
;
use
std
::
time
::
Duration
;
use
std
::
time
::
Duration
;
use
tokio
::
time
::
Instant
;
use
tokio
::
time
::
Instant
;
use
super
::
single
::
RequestId
;
use
super
::
single
::
RequestId
;
#[derive(Debug,
Clone,
Copy)]
#[derive(Debug,
Clone,
Copy
,
PartialEq,
Eq
)]
pub
(
super
)
struct
PrefillLoadState
{
pub
(
super
)
struct
PrefillLoadState
{
pub
(
super
)
initial_effective_prefill_tokens
:
usize
,
pub
(
super
)
initial_effective_prefill_tokens
:
usize
,
pub
(
super
)
expected_prefill_duration
:
Option
<
Duration
>
,
pub
(
super
)
expected_prefill_duration
:
Option
<
Duration
>
,
}
}
#[derive(Debug,
Clone,
Copy,
PartialEq,
Eq)]
pub
(
super
)
struct
AnchoredPrefillSnapshot
{
pub
(
super
)
initial_effective_prefill_tokens
:
usize
,
pub
(
super
)
expected_prefill_duration
:
Option
<
Duration
>
,
pub
(
super
)
anchored_since
:
Instant
,
}
#[derive(Debug,
Default,
Clone,
Copy,
PartialEq,
Eq)]
pub
(
super
)
struct
PrefillLoadSnapshot
{
pub
(
super
)
prefill_full_tokens_sum
:
usize
,
pub
(
super
)
anchored_prefill
:
Option
<
AnchoredPrefillSnapshot
>
,
}
impl
PrefillLoadSnapshot
{
pub
(
super
)
fn
active_tokens_at
(
&
self
,
now
:
Instant
)
->
usize
{
let
Some
(
anchored_prefill
)
=
self
.anchored_prefill
else
{
return
0
;
};
let
anchored_full
=
anchored_prefill
.initial_effective_prefill_tokens
;
let
anchored_remaining
=
match
anchored_prefill
.expected_prefill_duration
{
None
=>
anchored_full
,
Some
(
expected_prefill_duration
)
if
expected_prefill_duration
.is_zero
()
=>
0
,
Some
(
expected_prefill_duration
)
=>
{
let
elapsed
=
now
.saturating_duration_since
(
anchored_prefill
.anchored_since
);
let
remaining_fraction
=
(
1.0
-
(
elapsed
.as_secs_f64
()
/
expected_prefill_duration
.as_secs_f64
()))
.clamp
(
0.0
,
1.0
);
((
anchored_full
as
f64
)
*
remaining_fraction
)
.ceil
()
as
usize
}
};
self
.prefill_full_tokens_sum
.checked_sub
(
anchored_full
)
.expect
(
"prefill_full_tokens_sum smaller than anchored load"
)
+
anchored_remaining
}
}
pub
(
super
)
fn
added_prefill_tokens
(
block_size
:
usize
,
isl
:
usize
,
overlap
:
u32
)
->
usize
{
let
cached_tokens
=
(
overlap
as
usize
)
*
block_size
;
isl
.checked_sub
(
cached_tokens
)
.unwrap_or_else
(||
{
tracing
::
error!
(
"prefill_tokens < 0 with ISL {isl} < cached_tokens {cached_tokens} (overlap {overlap} * block_size {block_size}), returning 0"
,
);
0
})
}
#[derive(Debug,
Default)]
#[derive(Debug,
Default)]
pub
(
super
)
struct
PrefillLoadTracker
{
pub
(
super
)
struct
PrefillLoadTracker
{
pub
(
super
)
prefills
:
HashMap
<
RequestId
,
PrefillLoadState
>
,
pub
(
super
)
prefill_order
:
VecDeque
<
RequestId
>
,
pub
(
super
)
prefill_order
:
VecDeque
<
RequestId
>
,
pub
(
super
)
prefill_full_tokens_sum
:
usize
,
pub
(
super
)
prefill_full_tokens_sum
:
usize
,
pub
(
super
)
anchored_prefill
:
Option
<
(
RequestId
,
Instant
)
>
,
pub
(
super
)
anchored_prefill
:
Option
<
(
RequestId
,
Instant
)
>
,
...
@@ -27,6 +76,7 @@ impl PrefillLoadTracker {
...
@@ -27,6 +76,7 @@ impl PrefillLoadTracker {
prefill
:
PrefillLoadState
,
prefill
:
PrefillLoadState
,
decay_now
:
Instant
,
decay_now
:
Instant
,
)
{
)
{
self
.prefills
.insert
(
request_id
.clone
(),
prefill
);
self
.prefill_full_tokens_sum
+=
prefill
.initial_effective_prefill_tokens
;
self
.prefill_full_tokens_sum
+=
prefill
.initial_effective_prefill_tokens
;
let
should_anchor
=
self
.anchored_prefill
.is_none
();
let
should_anchor
=
self
.anchored_prefill
.is_none
();
self
.prefill_order
.push_back
(
request_id
.clone
());
self
.prefill_order
.push_back
(
request_id
.clone
());
...
@@ -38,9 +88,9 @@ impl PrefillLoadTracker {
...
@@ -38,9 +88,9 @@ impl PrefillLoadTracker {
pub
(
super
)
fn
remove
(
pub
(
super
)
fn
remove
(
&
mut
self
,
&
mut
self
,
request_id
:
&
RequestId
,
request_id
:
&
RequestId
,
prefill
:
PrefillLoadState
,
decay_now
:
Instant
,
decay_now
:
Instant
,
)
{
)
->
Option
<
PrefillLoadState
>
{
let
prefill
=
self
.prefills
.remove
(
request_id
)
?
;
self
.prefill_full_tokens_sum
=
self
self
.prefill_full_tokens_sum
=
self
.prefill_full_tokens_sum
.prefill_full_tokens_sum
.checked_sub
(
prefill
.initial_effective_prefill_tokens
)
.checked_sub
(
prefill
.initial_effective_prefill_tokens
)
...
@@ -60,6 +110,7 @@ impl PrefillLoadTracker {
...
@@ -60,6 +110,7 @@ impl PrefillLoadTracker {
{
{
self
.set_anchor_to_front
(
decay_now
);
self
.set_anchor_to_front
(
decay_now
);
}
}
Some
(
prefill
)
}
}
pub
(
super
)
fn
set_anchor_to_front
(
&
mut
self
,
now
:
Instant
)
{
pub
(
super
)
fn
set_anchor_to_front
(
&
mut
self
,
now
:
Instant
)
{
...
@@ -69,4 +120,209 @@ impl PrefillLoadTracker {
...
@@ -69,4 +120,209 @@ impl PrefillLoadTracker {
.cloned
()
.cloned
()
.map
(|
request_id
|
(
request_id
,
now
));
.map
(|
request_id
|
(
request_id
,
now
));
}
}
/// Captures the tracker's current load as an immutable [`PrefillLoadSnapshot`].
///
/// The snapshot carries the running token sum and, when a prefill is anchored,
/// that request's token count, expected duration and anchor instant.
///
/// # Panics
/// Panics if `anchored_prefill` names a request that has no entry in `prefills`.
pub(super) fn snapshot(&self) -> PrefillLoadSnapshot {
    let anchored_prefill = match &self.anchored_prefill {
        None => None,
        Some((anchor_id, since)) => {
            // The anchor must always refer to a tracked request.
            let state = self
                .prefills
                .get(anchor_id)
                .copied()
                .expect("anchored prefill missing request state");
            Some(AnchoredPrefillSnapshot {
                initial_effective_prefill_tokens: state.initial_effective_prefill_tokens,
                expected_prefill_duration: state.expected_prefill_duration,
                anchored_since: *since,
            })
        }
    };

    PrefillLoadSnapshot {
        prefill_full_tokens_sum: self.prefill_full_tokens_sum,
        anchored_prefill,
    }
}
/// Verifies the tracker's internal invariants, panicking on the first violation.
///
/// Checks, in order: `prefill_order` has no duplicates, it names exactly the
/// requests in `prefills`, the cached token sum matches a recomputation, and the
/// anchor is the front of `prefill_order` (or absent when the queue is empty).
#[cfg(any(test, debug_assertions))]
pub(super) fn assert_consistent(&self) {
    use std::collections::HashSet;

    let tracked_ids: HashSet<RequestId> = self.prefills.keys().cloned().collect();
    let queued_ids: HashSet<RequestId> = self.prefill_order.iter().cloned().collect();
    let expected_sum = self
        .prefills
        .values()
        .map(|state| state.initial_effective_prefill_tokens)
        .sum::<usize>();

    // A set smaller than the queue means some id was queued twice.
    assert_eq!(
        queued_ids.len(),
        self.prefill_order.len(),
        "prefill_order contains duplicate request ids",
    );
    assert_eq!(
        queued_ids, tracked_ids,
        "prefill_order must match active prefill requests",
    );
    assert_eq!(
        self.prefill_full_tokens_sum, expected_sum,
        "prefill_full_tokens_sum drifted from tracker state",
    );

    match (self.prefill_order.front(), self.anchored_prefill.as_ref()) {
        (Some(_), None) => {
            panic!("anchored_prefill must exist when prefill_order is non-empty");
        }
        (Some(oldest_id), Some((anchored_id, _))) => {
            assert!(
                self.prefills.contains_key(oldest_id),
                "prefill_order front must point to an active prefill request",
            );
            assert_eq!(
                anchored_id, oldest_id,
                "anchored_prefill must match prefill_order.front()",
            );
        }
        (None, anchor) => {
            assert!(
                anchor.is_none(),
                "anchored_prefill must be absent when no active prefills remain",
            );
        }
    }
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a prefill state with `tokens` effective tokens and an expected
    /// duration of `duration_secs` seconds.
    fn prefill_state(tokens: usize, duration_secs: u64) -> PrefillLoadState {
        let expected_prefill_duration = Some(Duration::from_secs(duration_secs));
        PrefillLoadState {
            initial_effective_prefill_tokens: tokens,
            expected_prefill_duration,
        }
    }

    #[test]
    fn snapshot_without_anchor_reports_zero_active_tokens() {
        let snapshot = PrefillLoadTracker::default().snapshot();

        assert_eq!(snapshot.active_tokens_at(Instant::now()), 0);
    }

    #[test]
    fn snapshot_only_decays_oldest_prefill() {
        let epoch = Instant::now();
        let at = |secs: u64| epoch + Duration::from_secs(secs);
        let mut tracker = PrefillLoadTracker::default();
        let (r1, r2) = ("r1".to_string(), "r2".to_string());

        tracker.insert(&r1, prefill_state(100, 10), at(0));
        tracker.insert(&r2, prefill_state(60, 6), at(2));

        let snapshot = tracker.snapshot();
        assert_eq!(snapshot.active_tokens_at(at(2)), 140);
        assert_eq!(snapshot.active_tokens_at(at(5)), 110);
    }

    #[test]
    fn removing_anchored_prefill_reanchors_front_and_resets_decay() {
        let epoch = Instant::now();
        let at = |secs: u64| epoch + Duration::from_secs(secs);
        let mut tracker = PrefillLoadTracker::default();
        let (r1, r2) = ("r1".to_string(), "r2".to_string());
        let first = prefill_state(100, 10);
        let second = prefill_state(40, 8);

        tracker.insert(&r1, first, epoch);
        tracker.insert(&r2, second, epoch);

        assert_eq!(tracker.remove(&r1, at(3)), Some(first));
        assert_eq!(tracker.prefill_order, VecDeque::from([r2.clone()]));
        assert!(matches!(
            tracker.anchored_prefill.as_ref(),
            Some((request_id, _)) if request_id == &r2
        ));

        let snapshot = tracker.snapshot();
        assert_eq!(snapshot.active_tokens_at(at(3)), 40);
        assert_eq!(snapshot.active_tokens_at(at(5)), 30);
    }

    #[test]
    fn removing_nonfront_prefill_preserves_existing_anchor() {
        let epoch = Instant::now();
        let at = |secs: u64| epoch + Duration::from_secs(secs);
        let mut tracker = PrefillLoadTracker::default();
        let (r1, r2) = ("r1".to_string(), "r2".to_string());
        let first = prefill_state(30, 6);
        let second = prefill_state(20, 4);

        tracker.insert(&r1, first, epoch);
        tracker.insert(&r2, second, epoch);

        assert_eq!(tracker.remove(&r2, at(2)), Some(second));
        assert_eq!(tracker.prefill_order, VecDeque::from([r1.clone()]));
        assert!(matches!(
            tracker.anchored_prefill.as_ref(),
            Some((request_id, anchored_since))
                if request_id == &r1 && *anchored_since == epoch
        ));

        let snapshot = tracker.snapshot();
        assert_eq!(snapshot.active_tokens_at(at(2)), 21);
    }

    #[test]
    fn duplicate_cleanup_is_idempotent() {
        let epoch = Instant::now();
        let mut tracker = PrefillLoadTracker::default();
        let (r1, r2) = ("r1".to_string(), "r2".to_string());
        let first = prefill_state(50, 10);
        let second = prefill_state(30, 10);

        tracker.insert(&r1, first, epoch);
        tracker.insert(&r2, second, epoch);
        tracker.assert_consistent();

        // The second removal of the same id must be a no-op.
        assert_eq!(tracker.remove(&r1, epoch), Some(first));
        assert_eq!(tracker.remove(&r1, epoch), None);
        assert_eq!(tracker.prefill_full_tokens_sum, 30);
        assert_eq!(tracker.prefill_order, VecDeque::from([r2.clone()]));

        assert_eq!(tracker.remove(&r2, epoch), Some(second));
        assert_eq!(tracker.remove(&r2, epoch), None);
        tracker.assert_consistent();
        assert_eq!(tracker.prefill_full_tokens_sum, 0);
        assert!(tracker.prefill_order.is_empty());
        assert!(tracker.prefills.is_empty());
    }
}
}
lib/kv-router/src/sequences/prompt_membership_trie.rs
0 → 100644
View file @
134d484d
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use
std
::
sync
::
Arc
;
use
dynamo_tokens
::
SequenceHash
;
use
parking_lot
::
RwLock
;
use
rustc_hash
::{
FxHashMap
,
FxHashSet
};
use
crate
::
protocols
::
WorkerWithDpRank
;
/// Reference-counted, lock-protected handle to a trie node.
type SharedNode = Arc<RwLock<PromptTrieNode>>;

/// Map from a sequence hash to a trie node handle, kept per worker by callers.
pub(super) type WorkerLookup = FxHashMap<SequenceHash, SharedNode>;
/// One node of the prompt membership trie; its edge is a run of sequence hashes.
#[derive(Debug)]
pub(super) struct PromptTrieNode {
    /// Hashes stored on this node's edge, in order.
    edge: Vec<SequenceHash>,
    /// Position of each edge hash within `edge`.
    edge_index: FxHashMap<SequenceHash, usize>,
    /// Workers covering only a prefix of the edge: positions `< cutoff`.
    worker_cutoffs: FxHashMap<WorkerWithDpRank, usize>,
    /// Workers covering the entire edge.
    full_edge_workers: FxHashSet<WorkerWithDpRank>,
    /// Child nodes keyed by the first hash of the child's edge.
    children: FxHashMap<SequenceHash, SharedNode>,
}
impl PromptTrieNode {
    /// Creates a node with an empty edge, no workers and no children.
    fn new() -> Self {
        Self {
            edge: Default::default(),
            edge_index: Default::default(),
            worker_cutoffs: Default::default(),
            full_edge_workers: Default::default(),
            children: Default::default(),
        }
    }

    /// Returns `true` when at least one worker covers some part of the edge.
    #[cfg(any(test, feature = "bench"))]
    fn has_any_workers(&self) -> bool {
        !(self.full_edge_workers.is_empty() && self.worker_cutoffs.is_empty())
    }

    /// Number of leading edge positions `worker` covers (0 when it covers none).
    fn current_cutoff(&self, worker: WorkerWithDpRank) -> usize {
        if self.full_edge_workers.contains(&worker) {
            return self.edge.len();
        }
        self.worker_cutoffs.get(&worker).map_or(0, |&cutoff| cutoff)
    }

    /// Returns `true` when `worker` covers edge position `pos`.
    fn covers_pos(&self, worker: WorkerWithDpRank, pos: usize) -> bool {
        if self.full_edge_workers.contains(&worker) {
            return true;
        }
        self.worker_cutoffs
            .get(&worker)
            .is_some_and(|&cutoff| pos < cutoff)
    }

    /// Drops every child once no worker covers the full edge.
    fn clear_children_if_unreachable(&mut self) {
        if !self.full_edge_workers.is_empty() {
            return;
        }
        self.children.clear();
    }

    /// Copies the edge hashes from position `cutoff` to the end of the edge.
    fn uncovered_suffix_hashes(&self, cutoff: usize) -> Vec<SequenceHash> {
        debug_assert!(cutoff <= self.edge.len());
        Vec::from(&self.edge[cutoff..])
    }

    /// Removes `worker` from this node entirely, then prunes unreachable children.
    fn drop_worker(&mut self, worker: WorkerWithDpRank) {
        self.full_edge_workers.remove(&worker);
        self.worker_cutoffs.remove(&worker);
        self.clear_children_if_unreachable();
    }

    /// Marks `worker` as covering the whole edge, discarding any partial cutoff.
    fn promote_to_full(&mut self, worker: WorkerWithDpRank) {
        // `insert` returns `true` only when the worker was not already full.
        if self.full_edge_workers.insert(worker) {
            self.worker_cutoffs.remove(&worker);
        }
    }

    /// Truncates `worker`'s coverage so that it ends just before `pos`.
    ///
    /// Returns the hashes the worker no longer covers on this edge. When `pos`
    /// is already outside the worker's coverage nothing changes and only
    /// `removed_hash` is reported.
    fn remove_worker_at_pos(
        &mut self,
        worker: WorkerWithDpRank,
        pos: usize,
        removed_hash: SequenceHash,
    ) -> RemoveOutcome {
        if pos >= self.current_cutoff(worker) {
            return RemoveOutcome {
                stale_hashes: vec![removed_hash],
            };
        }

        let stale_hashes = self.uncovered_suffix_hashes(pos);
        if pos == 0 {
            // Nothing of the edge remains covered.
            self.drop_worker(worker);
        } else {
            self.full_edge_workers.remove(&worker);
            self.worker_cutoffs.insert(worker, pos);
            self.clear_children_if_unreachable();
        }

        RemoveOutcome { stale_hashes }
    }

    /// Children that still have workers or children of their own.
    #[cfg(any(test, feature = "bench"))]
    fn live_children(&self) -> Vec<SharedNode> {
        let mut live = Vec::new();
        for child in self.children.values() {
            let is_live = {
                let guard = child.read();
                guard.has_any_workers() || !guard.children.is_empty()
            };
            if is_live {
                live.push(child.clone());
            }
        }
        live
    }
}
/// Result of truncating a worker's coverage on one node.
struct RemoveOutcome {
    /// Hashes that should be removed from the worker's lookup map.
    stale_hashes: Vec<SequenceHash>,
}
/// Trie recording which workers cover which chains of sequence hashes.
pub(super) struct PromptMembershipTrie {
    /// Root node; it is created with an empty edge.
    root: SharedNode,
}
impl
Default
for
PromptMembershipTrie
{
fn
default
()
->
Self
{
Self
::
new
()
}
}
impl Drop for PromptMembershipTrie {
    /// Tears the trie down with an explicit work list instead of recursion.
    ///
    /// Nodes still referenced elsewhere are left alone; their children are not
    /// detached here.
    fn drop(&mut self) {
        let mut pending: Vec<SharedNode> = self
            .root
            .write()
            .children
            .drain()
            .map(|(_, child)| child)
            .collect();

        while let Some(node) = pending.pop() {
            // Only the last owner may take the node apart.
            let Ok(lock) = Arc::try_unwrap(node) else {
                continue;
            };
            let mut inner = lock.into_inner();
            pending.extend(inner.children.drain().map(|(_, child)| child));
        }
    }
}
impl
PromptMembershipTrie
{
/// Creates an empty trie consisting of a single root node.
pub(super) fn new() -> Self {
    let root = Arc::new(RwLock::new(PromptTrieNode::new()));
    Self { root }
}
/// Searches the descendants of `start` (not `start` itself) for the node whose
/// edge contains `hash`, returning the first one found.
fn find_in_subtree(start: &SharedNode, hash: SequenceHash) -> Option<SharedNode> {
    let mut pending: Vec<SharedNode> = start.read().children.values().cloned().collect();

    while let Some(candidate) = pending.pop() {
        {
            let node = candidate.read();
            if !node.edge_index.contains_key(&hash) {
                // Not here: keep descending.
                pending.extend(node.children.values().cloned());
                continue;
            }
        }
        return Some(candidate);
    }

    None
}
fn
resolve_lookup
(
worker_lookup
:
&
mut
WorkerLookup
,
hash
:
SequenceHash
)
->
Option
<
SharedNode
>
{
let
node
=
worker_lookup
.get
(
&
hash
)
?
.clone
();
let
found
=
{
let
guard
=
node
.read
();
guard
.edge_index
.contains_key
(
&
hash
)
};
if
found
{
return
Some
(
node
);
}
let
resolved
=
Self
::
find_in_subtree
(
&
node
,
hash
)
?
;
worker_lookup
.insert
(
hash
,
resolved
.clone
());
Some
(
resolved
)
}
fn
split_node
(
node
:
&
mut
PromptTrieNode
,
pos
:
usize
)
->
SharedNode
{
debug_assert!
(
pos
>
0
&&
pos
<
node
.edge
.len
());
let
suffix_edge
=
node
.edge
.split_off
(
pos
);
let
suffix_first_hash
=
suffix_edge
[
0
];
let
mut
suffix_edge_index
=
FxHashMap
::
default
();
for
(
i
,
&
hash
)
in
suffix_edge
.iter
()
.enumerate
()
{
suffix_edge_index
.insert
(
hash
,
i
);
}
for
&
hash
in
&
suffix_edge
{
node
.edge_index
.remove
(
&
hash
);
}
let
mut
suffix_full
=
FxHashSet
::
default
();
let
mut
suffix_cutoffs
=
FxHashMap
::
default
();
let
mut
to_promote
=
Vec
::
new
();
for
&
worker
in
&
node
.full_edge_workers
{
suffix_full
.insert
(
worker
);
}
for
(
&
worker
,
&
cutoff
)
in
&
node
.worker_cutoffs
{
if
cutoff
>=
pos
{
to_promote
.push
(
worker
);
let
suffix_cutoff
=
cutoff
-
pos
;
if
suffix_cutoff
>
0
{
suffix_cutoffs
.insert
(
worker
,
suffix_cutoff
);
}
}
}
for
worker
in
to_promote
{
node
.worker_cutoffs
.remove
(
&
worker
);
node
.full_edge_workers
.insert
(
worker
);
}
let
suffix_children
=
std
::
mem
::
take
(
&
mut
node
.children
);
let
suffix
=
Arc
::
new
(
RwLock
::
new
(
PromptTrieNode
{
edge
:
suffix_edge
,
edge_index
:
suffix_edge_index
,
worker_cutoffs
:
suffix_cutoffs
,
full_edge_workers
:
suffix_full
,
children
:
suffix_children
,
}));
node
.children
.insert
(
suffix_first_hash
,
suffix
.clone
());
suffix
}
/// Records that `worker` covers `hashes`, attached after `parent`.
///
/// With `parent == None` the chain is inserted below the root. Otherwise the
/// node holding `parent` is resolved through the worker's lookup map and, if
/// `parent` is not the last hash of that node's edge, the node is split right
/// after it so the chain can attach there.
///
/// The call is a no-op when `hashes` is empty. It logs a warning and returns
/// without inserting when the parent cannot be resolved or when `worker` does
/// not cover the parent position (the stale lookup entry is removed then).
pub(super) fn store_chain(
    &self,
    worker: WorkerWithDpRank,
    lookup: &Arc<RwLock<WorkerLookup>>,
    parent: Option<SequenceHash>,
    hashes: &[SequenceHash],
) {
    if hashes.is_empty() {
        return;
    }

    let mut worker_lookup = lookup.write();
    let parent = match parent {
        None => self.root.clone(),
        Some(parent_hash) => loop {
            let Some(node) = Self::resolve_lookup(&mut worker_lookup, parent_hash) else {
                tracing::warn!(?worker, ?parent_hash, "prompt parent hash not found");
                return;
            };

            {
                let guard = node.read();
                // The node may have been split since it was resolved; retry.
                let Some(&pos) = guard.edge_index.get(&parent_hash) else {
                    continue;
                };
                if !guard.covers_pos(worker, pos) {
                    worker_lookup.remove(&parent_hash);
                    tracing::warn!(
                        ?worker,
                        ?parent_hash,
                        pos,
                        "worker no longer covers prompt parent"
                    );
                    return;
                }
            }

            {
                let mut guard = node.write();
                // Re-read the position under the write lock: the read lock was
                // released above, so the edge may have changed in between.
                let Some(&pos) = guard.edge_index.get(&parent_hash) else {
                    continue;
                };
                // `edge_index` already stores the position, so no linear scan of
                // the edge is needed to find the split point.
                if pos + 1 < guard.edge.len() {
                    Self::split_node(&mut guard, pos + 1);
                    // Re-validate against the freshly split node.
                    continue;
                }
            }

            break node;
        },
    };

    self.insert_hashes_from(worker, &mut worker_lookup, &parent, hashes);
}
fn
insert_hashes_from
(
&
self
,
worker
:
WorkerWithDpRank
,
worker_lookup
:
&
mut
WorkerLookup
,
parent
:
&
SharedNode
,
hashes
:
&
[
SequenceHash
],
)
{
let
mut
current_parent
=
parent
.clone
();
let
mut
remaining
=
hashes
;
let
mut
last_hash
=
None
;
while
!
remaining
.is_empty
()
{
let
first_hash
=
remaining
[
0
];
let
child
=
{
let
mut
parent_guard
=
current_parent
.write
();
if
let
Some
(
last_hash
)
=
last_hash
&&
!
parent_guard
.edge_index
.contains_key
(
&
last_hash
)
{
drop
(
parent_guard
);
if
let
Some
(
resolved
)
=
Self
::
resolve_lookup
(
worker_lookup
,
last_hash
)
{
current_parent
=
resolved
;
}
continue
;
}
match
parent_guard
.children
.get
(
&
first_hash
)
.cloned
()
{
Some
(
existing
)
=>
existing
,
None
=>
{
let
edge
=
remaining
.to_vec
();
let
mut
edge_index
=
FxHashMap
::
default
();
for
(
i
,
&
hash
)
in
edge
.iter
()
.enumerate
()
{
edge_index
.insert
(
hash
,
i
);
}
let
mut
full_edge_workers
=
FxHashSet
::
default
();
full_edge_workers
.insert
(
worker
);
let
new_node
=
Arc
::
new
(
RwLock
::
new
(
PromptTrieNode
{
edge
,
edge_index
,
worker_cutoffs
:
FxHashMap
::
default
(),
full_edge_workers
,
children
:
FxHashMap
::
default
(),
}));
parent_guard
.children
.insert
(
first_hash
,
new_node
.clone
());
drop
(
parent_guard
);
for
&
hash
in
remaining
{
worker_lookup
.insert
(
hash
,
new_node
.clone
());
}
return
;
}
}
};
{
let
mut
child_guard
=
child
.write
();
let
edge_len
=
child_guard
.edge
.len
();
let
mut
match_len
=
0
;
for
(
&
edge_hash
,
&
query_hash
)
in
child_guard
.edge
.iter
()
.zip
(
remaining
.iter
())
{
if
edge_hash
!=
query_hash
{
break
;
}
match_len
+=
1
;
}
debug_assert!
(
match_len
>=
1
);
if
match_len
<
edge_len
{
let
_
suffix
=
Self
::
split_node
(
&
mut
child_guard
,
match_len
);
child_guard
.promote_to_full
(
worker
);
let
tail
=
&
remaining
[
match_len
..
];
if
!
tail
.is_empty
()
{
let
edge
=
tail
.to_vec
();
let
mut
edge_index
=
FxHashMap
::
default
();
for
(
i
,
&
hash
)
in
edge
.iter
()
.enumerate
()
{
edge_index
.insert
(
hash
,
i
);
}
let
mut
full_edge_workers
=
FxHashSet
::
default
();
full_edge_workers
.insert
(
worker
);
let
tail_first_hash
=
tail
[
0
];
let
new_node
=
Arc
::
new
(
RwLock
::
new
(
PromptTrieNode
{
edge
,
edge_index
,
worker_cutoffs
:
FxHashMap
::
default
(),
full_edge_workers
,
children
:
FxHashMap
::
default
(),
}));
child_guard
.children
.insert
(
tail_first_hash
,
new_node
.clone
());
drop
(
child_guard
);
for
&
hash
in
&
remaining
[
..
match_len
]
{
worker_lookup
.insert
(
hash
,
child
.clone
());
}
for
&
hash
in
tail
{
worker_lookup
.insert
(
hash
,
new_node
.clone
());
}
}
else
{
drop
(
child_guard
);
for
&
hash
in
&
remaining
[
..
match_len
]
{
worker_lookup
.insert
(
hash
,
child
.clone
());
}
}
return
;
}
child_guard
.promote_to_full
(
worker
);
drop
(
child_guard
);
for
&
hash
in
&
remaining
[
..
edge_len
]
{
worker_lookup
.insert
(
hash
,
child
.clone
());
}
last_hash
=
Some
(
remaining
[
edge_len
-
1
]);
remaining
=
&
remaining
[
edge_len
..
];
current_parent
=
child
;
}
}
}
pub
(
super
)
fn
remove_chain
(
&
self
,
worker
:
WorkerWithDpRank
,
lookup
:
&
Arc
<
RwLock
<
WorkerLookup
>>
,
hashes
:
&
[
SequenceHash
],
)
{
let
mut
worker_lookup
=
lookup
.write
();
if
worker_lookup
.is_empty
()
{
return
;
}
'outer
:
for
&
hash
in
hashes
{
let
mut
current_node
=
match
Self
::
resolve_lookup
(
&
mut
worker_lookup
,
hash
)
{
Some
(
node
)
=>
node
,
None
=>
continue
,
};
loop
{
let
update
=
{
let
mut
guard
=
current_node
.write
();
guard
.edge_index
.get
(
&
hash
)
.copied
()
.map
(|
pos
|
guard
.remove_worker_at_pos
(
worker
,
pos
,
hash
))
};
match
update
{
Some
(
outcome
)
=>
{
for
stale_hash
in
outcome
.stale_hashes
{
worker_lookup
.remove
(
&
stale_hash
);
}
continue
'outer
;
}
None
=>
match
Self
::
find_in_subtree
(
&
current_node
,
hash
)
{
Some
(
resolved
)
=>
{
worker_lookup
.insert
(
hash
,
resolved
.clone
());
current_node
=
resolved
;
}
None
=>
{
worker_lookup
.remove
(
&
hash
);
continue
'outer
;
}
},
}
}
}
}
pub
(
super
)
fn
remove_worker
(
&
self
,
worker
:
WorkerWithDpRank
,
lookup
:
&
Arc
<
RwLock
<
WorkerLookup
>>
,
)
{
let
mut
worker_lookup
=
lookup
.write
();
if
worker_lookup
.is_empty
()
{
return
;
}
let
hashes
:
Vec
<
_
>
=
worker_lookup
.keys
()
.copied
()
.collect
();
let
mut
nodes
=
Vec
::
new
();
let
mut
seen
=
FxHashSet
::
<
usize
>
::
default
();
for
hash
in
hashes
{
let
Some
(
node
)
=
Self
::
resolve_lookup
(
&
mut
worker_lookup
,
hash
)
else
{
worker_lookup
.remove
(
&
hash
);
continue
;
};
let
ptr
=
Arc
::
as_ptr
(
&
node
)
as
usize
;
if
seen
.insert
(
ptr
)
{
nodes
.push
(
node
);
}
}
worker_lookup
.clear
();
drop
(
worker_lookup
);
for
node
in
nodes
{
let
mut
guard
=
node
.write
();
guard
.drop_worker
(
worker
);
}
}
pub
(
super
)
fn
compute_overlap_depths
(
&
self
,
query
:
Option
<&
[
SequenceHash
]
>
,
)
->
FxHashMap
<
WorkerWithDpRank
,
usize
>
{
let
Some
(
query
)
=
query
else
{
return
FxHashMap
::
default
();
};
if
query
.is_empty
()
{
return
FxHashMap
::
default
();
}
let
mut
matched_depth
=
FxHashMap
::
default
();
let
mut
active
=
FxHashSet
::
default
();
let
mut
active_count
=
0u
size
;
let
mut
query_pos
=
0u
size
;
let
mut
depth
=
0u
size
;
let
mut
first_node
=
true
;
let
mut
next_child
=
{
let
root
=
self
.root
.read
();
root
.children
.get
(
&
query
[
0
])
.cloned
()
};
loop
{
if
query_pos
>=
query
.len
()
{
break
;
}
let
Some
(
child
)
=
next_child
.take
()
else
{
break
;
};
let
edge_len
;
let
edge_match_len
;
{
let
guard
=
child
.read
();
edge_len
=
guard
.edge
.len
();
let
walk_len
=
edge_len
.min
(
query
.len
()
-
query_pos
);
let
mut
match_len
=
1u
size
;
for
i
in
1
..
walk_len
{
if
guard
.edge
[
i
]
!=
query
[
query_pos
+
i
]
{
break
;
}
match_len
+=
1
;
}
edge_match_len
=
match_len
;
let
prev_depth
=
depth
;
if
first_node
{
active
=
guard
.full_edge_workers
.clone
();
active_count
=
active
.len
();
for
(
&
worker
,
&
cutoff
)
in
&
guard
.worker_cutoffs
{
let
contribution
=
cutoff
.min
(
edge_match_len
);
if
contribution
>
0
{
matched_depth
.insert
(
worker
,
contribution
);
}
}
first_node
=
false
;
}
else
if
!
guard
.worker_cutoffs
.is_empty
()
{
active
.retain
(|
worker
|
{
if
guard
.full_edge_workers
.contains
(
worker
)
{
true
}
else
if
let
Some
(
&
cutoff
)
=
guard
.worker_cutoffs
.get
(
worker
)
{
matched_depth
.insert
(
*
worker
,
prev_depth
+
cutoff
.min
(
edge_match_len
));
false
}
else
{
matched_depth
.insert
(
*
worker
,
prev_depth
);
false
}
});
active_count
=
active
.len
();
}
else
{
let
full_count
=
guard
.full_edge_workers
.len
();
if
full_count
!=
active_count
{
active
.retain
(|
worker
|
{
if
guard
.full_edge_workers
.contains
(
worker
)
{
true
}
else
{
matched_depth
.insert
(
*
worker
,
prev_depth
);
false
}
});
active_count
=
active
.len
();
}
}
next_child
=
if
edge_match_len
==
edge_len
&&
active_count
>
0
&&
query_pos
+
edge_match_len
<
query
.len
()
{
guard
.children
.get
(
&
query
[
query_pos
+
edge_match_len
])
.cloned
()
}
else
{
None
};
}
if
active_count
==
0
{
break
;
}
depth
+=
edge_match_len
;
if
edge_match_len
<
edge_len
{
break
;
}
query_pos
+=
edge_match_len
;
}
for
worker
in
active
{
matched_depth
.insert
(
worker
,
depth
);
}
matched_depth
}
/// Test helper: collects, for every worker, the set of edge hashes it covers
/// across all nodes reachable from the root.
#[cfg(test)]
pub(super) fn worker_hashes(&self) -> FxHashMap<WorkerWithDpRank, FxHashSet<SequenceHash>> {
    let mut by_worker: FxHashMap<WorkerWithDpRank, FxHashSet<SequenceHash>> =
        FxHashMap::default();
    let mut pending = vec![Arc::clone(&self.root)];

    while let Some(node) = pending.pop() {
        let guard = node.read();

        // Full workers cover the whole edge; the others only their cutoff prefix.
        let full = guard
            .full_edge_workers
            .iter()
            .map(|&worker| (worker, guard.edge.len()));
        let partial = guard
            .worker_cutoffs
            .iter()
            .map(|(&worker, &cutoff)| (worker, cutoff));

        for (worker, covered) in full.chain(partial) {
            by_worker
                .entry(worker)
                .or_default()
                .extend(guard.edge[..covered].iter().copied());
        }

        pending.extend(guard.children.values().cloned());
    }

    by_worker
}
/// Returns `true` when the root has no live children.
#[cfg(any(test, feature = "bench"))]
pub(super) fn is_empty(&self) -> bool {
    self.root.read().live_children().is_empty()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    type Depths = FxHashMap<WorkerWithDpRank, usize>;

    fn worker(worker_id: u64, dp_rank: u32) -> WorkerWithDpRank {
        WorkerWithDpRank::new(worker_id, dp_rank)
    }

    fn lookup() -> Arc<RwLock<WorkerLookup>> {
        Arc::new(RwLock::new(WorkerLookup::default()))
    }

    /// Expected per-worker depths, built from `(worker, depth)` pairs.
    fn depths<const N: usize>(pairs: [(WorkerWithDpRank, usize); N]) -> Depths {
        pairs.into_iter().collect()
    }

    /// Overlap depths the trie reports for `query`.
    fn overlap(trie: &PromptMembershipTrie, query: &[SequenceHash]) -> Depths {
        trie.compute_overlap_depths(Some(query))
    }

    #[test]
    fn parent_continuation_chains_extend_and_trim() {
        let trie = PromptMembershipTrie::new();
        let w = worker(1, 0);
        let map = lookup();

        trie.store_chain(w, &map, None, &[1, 2, 3]);
        trie.store_chain(w, &map, Some(3), &[4, 5]);
        assert_eq!(overlap(&trie, &[1, 2, 3, 4, 5]), depths([(w, 5)]));

        trie.remove_chain(w, &map, &[4, 5]);
        assert_eq!(overlap(&trie, &[1, 2, 3, 4, 5]), depths([(w, 3)]));
    }

    #[test]
    fn branching_continuations_across_workers_match_expected_depths() {
        let trie = PromptMembershipTrie::new();
        let (wa, wb) = (worker(1, 0), worker(2, 0));
        let (map_a, map_b) = (lookup(), lookup());

        trie.store_chain(wa, &map_a, None, &[1, 2, 3, 4]);
        trie.store_chain(wb, &map_b, None, &[1, 2, 5]);

        assert_eq!(overlap(&trie, &[1, 2, 3, 4]), depths([(wa, 4), (wb, 2)]));
        assert_eq!(overlap(&trie, &[1, 2, 5]), depths([(wa, 2), (wb, 3)]));
    }

    #[test]
    fn partial_suffix_removal_keeps_prefix() {
        let trie = PromptMembershipTrie::new();
        let w = worker(1, 0);
        let map = lookup();

        trie.store_chain(w, &map, None, &[1, 2, 3, 4, 5]);
        trie.remove_chain(w, &map, &[3, 4, 5]);

        assert_eq!(overlap(&trie, &[1, 2, 3, 4, 5]), depths([(w, 2)]));
    }

    #[test]
    fn remove_worker_preserves_other_workers() {
        let trie = PromptMembershipTrie::new();
        let (wa, wb) = (worker(1, 0), worker(2, 0));
        let (map_a, map_b) = (lookup(), lookup());

        trie.store_chain(wa, &map_a, None, &[1, 2, 3]);
        trie.store_chain(wb, &map_b, None, &[1, 2, 3]);
        trie.remove_worker(wa, &map_a);

        assert_eq!(overlap(&trie, &[1, 2, 3]), depths([(wb, 3)]));
    }

    #[test]
    fn multiple_dp_ranks_with_same_worker_id_remain_isolated() {
        let trie = PromptMembershipTrie::new();
        let (rank0, rank1) = (worker(1, 0), worker(1, 1));
        let (map0, map1) = (lookup(), lookup());

        trie.store_chain(rank0, &map0, None, &[1, 2, 3]);
        trie.store_chain(rank1, &map1, None, &[1, 2]);

        assert_eq!(overlap(&trie, &[1, 2, 3]), depths([(rank0, 3), (rank1, 2)]));
    }

    #[test]
    fn clear_worker_state_then_reuse_starts_empty() {
        let trie = PromptMembershipTrie::new();
        let w = worker(1, 0);
        let map = lookup();

        trie.store_chain(w, &map, None, &[1, 2, 3]);
        trie.remove_worker(w, &map);
        assert!(overlap(&trie, &[1, 2, 3]).is_empty());

        trie.store_chain(w, &map, None, &[1, 2]);
        assert_eq!(overlap(&trie, &[1, 2, 3]), depths([(w, 2)]));
    }

    #[test]
    fn redundant_batched_remove_is_idempotent() {
        let trie = PromptMembershipTrie::new();
        let w = worker(1, 0);
        let map = lookup();

        trie.store_chain(w, &map, None, &[1, 2, 3, 4]);
        // Removing the same suffix twice must leave the same state.
        trie.remove_chain(w, &map, &[2, 3, 4]);
        trie.remove_chain(w, &map, &[2, 3, 4]);

        assert_eq!(overlap(&trie, &[1, 2, 3, 4]), depths([(w, 1)]));
    }
}
lib/kv-router/src/sequences/prompt_registry.rs
0 → 100644
View file @
134d484d
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use
dashmap
::
DashMap
;
use
dynamo_tokens
::
SequenceHash
;
#[cfg(test)]
use
rustc_hash
::
FxHashSet
;
use
rustc_hash
::{
FxBuildHasher
,
FxHashMap
};
use
std
::
collections
::
HashMap
;
use
std
::
sync
::
Arc
;
use
tokio
::
time
::
Instant
;
use
super
::
prefill_tracker
::{
PrefillLoadSnapshot
,
added_prefill_tokens
};
use
super
::
prompt_membership_trie
::{
PromptMembershipTrie
,
WorkerLookup
};
use
super
::
single
::
PromptMembershipDelta
;
use
super
::
topology
::
WorkerTopologyChange
;
use
crate
::
protocols
::{
OverlapScores
,
WorkerWithDpRank
};
/// Point-in-time load published for one worker.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub(super) struct WorkerLoadSnapshot {
    /// Active block count reported for the worker.
    pub(super) active_blocks: usize,
    /// Prefill load snapshot for the worker.
    pub(super) prefill: PrefillLoadSnapshot,
}
impl WorkerLoadSnapshot {
    /// Active prefill tokens at `decay_now`, as reported by the prefill snapshot.
    pub(super) fn active_tokens(&self, decay_now: Instant) -> usize {
        let Self { prefill, .. } = self;
        prefill.active_tokens_at(decay_now)
    }
}
/// Shared registry pairing the prompt membership trie with per-worker load snapshots.
pub(super) struct PromptRegistry {
    // WARNING: prompt membership and worker load are only eventually consistent.
    // Each mutation still starts from one worker-local source of truth: we mutate the chosen
    // `ActiveSequences`, derive an exact `PromptMembershipDelta` plus `WorkerLoadSnapshot`, then
    // publish them separately here. The trie and load map converge to the correct final state
    // after the write finishes, but reads can still observe a mixed membership/load state that
    // never existed atomically and make a suboptimal routing choice.
    membership: PromptMembershipTrie,
    /// Latest published load snapshot for each worker.
    loads: DashMap<WorkerWithDpRank, WorkerLoadSnapshot, FxBuildHasher>,
}
impl Default for PromptRegistry {
    /// Builds a registry with a fresh trie and an empty load map.
    fn default() -> Self {
        let membership = PromptMembershipTrie::new();
        let loads = DashMap::with_hasher(FxBuildHasher);
        Self { membership, loads }
    }
}
impl
PromptRegistry
{
/// Creates a registry holding a default load entry for every worker in `workers`.
///
/// The membership trie starts out empty.
pub(super) fn new(workers: impl IntoIterator<Item = WorkerWithDpRank>) -> Self {
    let registry = Self::default();
    workers.into_iter().for_each(|worker| {
        registry.loads.entry(worker).or_default();
    });
    registry
}
/// Overwrites the load snapshot stored for `worker`.
///
/// Prompt membership is not touched; any previous snapshot is discarded.
pub(super) fn replace_worker_load_state(
    &self,
    worker: WorkerWithDpRank,
    load: WorkerLoadSnapshot,
) {
    let _previous = self.loads.insert(worker, load);
}
pub
(
super
)
fn
apply_membership_delta_and_load
(
&
self
,
worker
:
WorkerWithDpRank
,
lookup
:
&
Arc
<
parking_lot
::
RwLock
<
WorkerLookup
>>
,
delta
:
PromptMembershipDelta
,
load
:
WorkerLoadSnapshot
,
)
{
for
remove
in
delta
.removes
{
self
.membership
.remove_chain
(
worker
,
lookup
,
&
remove
.hashes
);
}
for
store
in
delta
.stores
{
self
.membership
.store_chain
(
worker
,
lookup
,
store
.parent
,
&
store
.hashes
);
}
self
.loads
.insert
(
worker
,
load
);
}
pub
(
super
)
fn
apply_topology_change
(
&
self
,
change
:
WorkerTopologyChange
)
{
for
removed
in
change
.removed
{
self
.membership
.remove_worker
(
removed
.worker
,
&
removed
.trie_lookup
);
self
.loads
.remove
(
&
removed
.worker
);
}
for
worker
in
change
.added
{
self
.loads
.entry
(
worker
)
.or_default
();
}
}
/// Projects, for every worker in the load map, the block and token load it would carry
/// if the query were routed to it.
///
/// * Potential blocks = the worker's `active_blocks` plus the query blocks beyond the
///   depth recorded for it in `matched_depth` (missing workers count as depth 0).
/// * Potential tokens = the worker's active tokens at `decay_now` plus, when
///   `track_prefill_tokens` is set, `added_prefill_tokens(block_size, isl, overlap)`
///   with `overlap` taken from `overlaps.scores` (missing workers count as 0).
///
/// Returns `(potential_blocks, potential_tokens)`.
#[expect(clippy::too_many_arguments)]
fn project_loads_from_overlap(
    &self,
    query_len: usize,
    matched_depth: &FxHashMap<WorkerWithDpRank, usize>,
    isl: usize,
    overlaps: &OverlapScores,
    track_prefill_tokens: bool,
    block_size: usize,
    decay_now: Instant,
) -> (
    FxHashMap<WorkerWithDpRank, usize>,
    FxHashMap<WorkerWithDpRank, usize>,
) {
    let worker_count = self.loads.len();
    let mut blocks_by_worker = FxHashMap::with_capacity_and_hasher(worker_count, FxBuildHasher);
    let mut tokens_by_worker = FxHashMap::with_capacity_and_hasher(worker_count, FxBuildHasher);

    for slot in self.loads.iter() {
        let (worker, snapshot) = (*slot.key(), *slot.value());

        let cached_depth = matched_depth.get(&worker).map_or(0, |depth| *depth);
        let uncached_blocks = query_len.saturating_sub(cached_depth);

        let current_tokens = snapshot.active_tokens(decay_now);
        let overlap = overlaps.scores.get(&worker).copied().unwrap_or(0);
        let prefill_tokens = track_prefill_tokens
            .then(|| added_prefill_tokens(block_size, isl, overlap))
            .unwrap_or(0);

        blocks_by_worker.insert(worker, snapshot.active_blocks + uncached_blocks);
        tokens_by_worker.insert(worker, current_tokens + prefill_tokens);
    }

    (blocks_by_worker, tokens_by_worker)
}
/// Computes per-worker potential block and token loads for a prospective request.
///
/// The trie is queried for per-worker overlap depths of `token_sequence`, and the
/// result is combined with the stored load snapshots by `project_loads_from_overlap`.
/// A `None` sequence is treated as a query of length 0.
///
/// Returns `(potential_blocks, potential_tokens)`.
pub(super) fn potential_blocks_and_tokens_with_prefill_tracking(
    &self,
    token_sequence: Option<&[SequenceHash]>,
    isl: usize,
    overlaps: &OverlapScores,
    track_prefill_tokens: bool,
    block_size: usize,
    decay_now: Instant,
) -> (
    FxHashMap<WorkerWithDpRank, usize>,
    FxHashMap<WorkerWithDpRank, usize>,
) {
    let query_len = token_sequence.map_or(0, <[SequenceHash]>::len);
    let depths_by_worker = self.membership.compute_overlap_depths(token_sequence);

    self.project_loads_from_overlap(
        query_len,
        &depths_by_worker,
        isl,
        overlaps,
        track_prefill_tokens,
        block_size,
        decay_now,
    )
}
/// Returns the `active_blocks` value of every worker currently in the load map.
pub(super) fn active_blocks(&self) -> HashMap<WorkerWithDpRank, usize> {
    let mut blocks_by_worker = HashMap::new();
    for slot in self.loads.iter() {
        blocks_by_worker.insert(*slot.key(), slot.value().active_blocks);
    }
    blocks_by_worker
}
/// Returns every worker's active token count evaluated at `decay_now`.
pub(super) fn active_tokens(&self, decay_now: Instant) -> HashMap<WorkerWithDpRank, usize> {
    let mut tokens_by_worker = HashMap::new();
    for slot in self.loads.iter() {
        tokens_by_worker.insert(*slot.key(), slot.value().active_tokens(decay_now));
    }
    tokens_by_worker
}
/// Reports whether `predicate` holds for at least one worker.
///
/// The predicate receives each worker together with its active token count at
/// `decay_now`. Iteration stops at the first worker for which it returns `true`.
pub(super) fn any_worker_matches_active_tokens(
    &self,
    decay_now: Instant,
    mut predicate: impl FnMut(WorkerWithDpRank, usize) -> bool,
) -> bool {
    for slot in self.loads.iter() {
        let tokens = slot.value().active_tokens(decay_now);
        if predicate(*slot.key(), tokens) {
            return true;
        }
    }
    false
}
/// Test-only check that the registry mirrors the expected per-worker state.
///
/// Panics if the load map differs from `expected_loads`, or if the trie's per-worker
/// hashes differ from `expected_blocks`.
#[cfg(test)]
pub(super) fn assert_consistent_with_workers(
    &self,
    expected_loads: &FxHashMap<WorkerWithDpRank, WorkerLoadSnapshot>,
    expected_blocks: &FxHashMap<WorkerWithDpRank, FxHashSet<SequenceHash>>,
) {
    let mut actual_loads: FxHashMap<WorkerWithDpRank, WorkerLoadSnapshot> =
        FxHashMap::default();
    for slot in self.loads.iter() {
        actual_loads.insert(*slot.key(), *slot.value());
    }
    let actual_blocks = self.membership.worker_hashes();

    assert_eq!(
        actual_loads, *expected_loads,
        "prompt registry worker loads drifted from per-worker state",
    );
    assert_eq!(
        actual_blocks, *expected_blocks,
        "prompt registry prompt membership drifted from per-worker state",
    );
}
/// Reports whether the membership trie is empty (test and bench builds only).
#[cfg(any(test, feature = "bench"))]
pub(super) fn is_block_index_empty(&self) -> bool {
    let Self { membership, .. } = self;
    membership.is_empty()
}
}
#[cfg(test)]
mod
tests
{
use
std
::
time
::
Duration
;
use
rustc_hash
::{
FxHashMap
,
FxHashSet
};
use
super
::
*
;
use
crate
::
protocols
::
WorkerWithDpRank
;
use
crate
::
sequences
::
prefill_tracker
::
AnchoredPrefillSnapshot
;
use
crate
::
sequences
::
single
::{
PromptMembershipRemove
,
PromptMembershipStore
};
use
crate
::
sequences
::
topology
::
RemovedWorkerState
;
/// Test shorthand: builds a `WorkerWithDpRank` from a worker id and a DP rank.
fn worker(worker_id: u64, dp_rank: u32) -> WorkerWithDpRank {
    WorkerWithDpRank::new(worker_id, dp_rank)
}
/// Test shorthand: a fresh, default `WorkerLookup` behind a shared `RwLock`.
fn lookup() -> Arc<parking_lot::RwLock<WorkerLookup>> {
    let empty_lookup = WorkerLookup::default();
    Arc::new(parking_lot::RwLock::new(empty_lookup))
}
/// Test shorthand: a delta containing exactly one store of `hashes` under `parent`
/// and no removals.
fn store(parent: Option<SequenceHash>, hashes: &[SequenceHash]) -> PromptMembershipDelta {
    let chain = PromptMembershipStore {
        parent,
        hashes: hashes.to_vec(),
    };
    PromptMembershipDelta {
        stores: vec![chain],
        removes: Vec::new(),
    }
}
fn
worker_load_snapshot
(
active_blocks
:
usize
)
->
WorkerLoadSnapshot
{
WorkerLoadSnapshot
{
active_blocks
,
prefill
:
PrefillLoadSnapshot
::
default
(),
}
}
/// Test shorthand: a load snapshot whose prefill state carries an anchored prefill.
///
/// `anchored_tokens` becomes the anchor's `initial_effective_prefill_tokens`.
fn anchored_load_snapshot(
    active_blocks: usize,
    prefill_full_tokens_sum: usize,
    anchored_tokens: usize,
    expected_prefill_duration: Option<Duration>,
    anchored_since: Instant,
) -> WorkerLoadSnapshot {
    let anchor = AnchoredPrefillSnapshot {
        initial_effective_prefill_tokens: anchored_tokens,
        expected_prefill_duration,
        anchored_since,
    };
    let prefill = PrefillLoadSnapshot {
        prefill_full_tokens_sum,
        anchored_prefill: Some(anchor),
    };
    WorkerLoadSnapshot {
        active_blocks,
        prefill,
    }
}
fn
hash_set
(
hashes
:
&
[
SequenceHash
])
->
FxHashSet
<
SequenceHash
>
{
hashes
.iter
()
.copied
()
.collect
()
}
/// Reference implementation of the registry's load projection, used as a test oracle.
///
/// For each worker in `expected_loads`, the overlap depth is the length of the longest
/// query prefix whose hashes are all in that worker's set in `expected_blocks`.
/// Panics if a query is given and a worker has no entry in `expected_blocks`.
///
/// Returns `(potential_blocks, potential_tokens)`.
#[expect(clippy::too_many_arguments)]
fn naive_potential_loads(
    expected_loads: &FxHashMap<WorkerWithDpRank, WorkerLoadSnapshot>,
    expected_blocks: &FxHashMap<WorkerWithDpRank, FxHashSet<SequenceHash>>,
    token_sequence: Option<&[SequenceHash]>,
    isl: usize,
    overlaps: &OverlapScores,
    track_prefill_tokens: bool,
    block_size: usize,
    decay_now: Instant,
) -> (
    FxHashMap<WorkerWithDpRank, usize>,
    FxHashMap<WorkerWithDpRank, usize>,
) {
    let worker_count = expected_loads.len();
    let mut blocks_by_worker = FxHashMap::with_capacity_and_hasher(worker_count, FxBuildHasher);
    let mut tokens_by_worker = FxHashMap::with_capacity_and_hasher(worker_count, FxBuildHasher);

    for (&worker, snapshot) in expected_loads {
        let uncached_blocks = match token_sequence {
            None => 0,
            Some(query) => {
                let held = expected_blocks
                    .get(&worker)
                    .expect("worker must have a prompt membership entry");
                // Depth = number of leading query hashes the worker already holds.
                let depth = query.iter().take_while(|hash| held.contains(*hash)).count();
                query.len().saturating_sub(depth)
            }
        };

        let overlap = overlaps.scores.get(&worker).copied().unwrap_or(0);
        let prefill_tokens = track_prefill_tokens
            .then(|| added_prefill_tokens(block_size, isl, overlap))
            .unwrap_or(0);

        blocks_by_worker.insert(worker, snapshot.active_blocks + uncached_blocks);
        tokens_by_worker.insert(worker, snapshot.active_tokens(decay_now) + prefill_tokens);
    }

    (blocks_by_worker, tokens_by_worker)
}
/// Storing a hash, removing it, and storing it again must leave the hash present.
#[test]
fn removed_hash_can_be_restored_by_later_store() {
    let target = worker(1, 0);
    let registry = PromptRegistry::new([target]);
    let trie_lookup = lookup();
    let load = worker_load_snapshot(1);

    // Store, then remove, then store the same hash again.
    registry.apply_membership_delta_and_load(
        target,
        &trie_lookup,
        store(None, &[42]),
        worker_load_snapshot(1),
    );
    let removal = PromptMembershipDelta {
        removes: vec![PromptMembershipRemove { hashes: vec![42] }],
        ..Default::default()
    };
    registry.apply_membership_delta_and_load(target, &trie_lookup, removal, load);
    registry.apply_membership_delta_and_load(target, &trie_lookup, store(None, &[42]), load);

    let expected_loads = FxHashMap::from_iter([(target, load)]);
    let expected_blocks = FxHashMap::from_iter([(target, hash_set(&[42]))]);
    registry.assert_consistent_with_workers(&expected_loads, &expected_blocks);
}
/// Three workers hold prefixes of different lengths (96, 64 and 33 blocks) of the same
/// 96-block prompt; the registry's projection must equal the naive oracle's.
#[test]
fn staggered_prefix_overlap_matches_naive_projection() {
    let worker_a = worker(1, 0);
    let worker_b = worker(2, 0);
    let worker_c = worker(3, 0);
    let registry = PromptRegistry::new([worker_a, worker_b, worker_c]);
    let lookup_a = lookup();
    let lookup_b = lookup();
    let lookup_c = lookup();
    let decay_now = Instant::now();
    let full_prompt: Vec<SequenceHash> = (1_u64..=96).collect();
    let mut expected_loads = FxHashMap::default();
    let mut expected_blocks = FxHashMap::default();

    // The `usize` suffix on the first tuple fixes the element type for the whole array.
    for (worker, lookup, prompt_len) in [
        (worker_a, &lookup_a, 96usize),
        (worker_b, &lookup_b, 64),
        (worker_c, &lookup_c, 33),
    ] {
        let blocks = full_prompt[..prompt_len].to_vec();
        let load = worker_load_snapshot(prompt_len);
        registry.apply_membership_delta_and_load(worker, lookup, store(None, &blocks), load);
        expected_loads.insert(worker, load);
        expected_blocks.insert(worker, blocks.into_iter().collect());
    }
    registry.assert_consistent_with_workers(&expected_loads, &expected_blocks);

    // Same query and parameters go to the oracle and to the registry.
    let expected = naive_potential_loads(
        &expected_loads,
        &expected_blocks,
        Some(&full_prompt),
        384,
        &OverlapScores::default(),
        false,
        4,
        decay_now,
    );
    let actual = registry.potential_blocks_and_tokens_with_prefill_tracking(
        Some(&full_prompt),
        384,
        &OverlapScores::default(),
        false,
        4,
        decay_now,
    );
    assert_eq!(actual, expected);
}
/// Replacing only the load snapshot must keep prompt membership intact and must be
/// reflected in both the active-token view and the load projection.
#[test]
fn load_only_update_preserves_prompt_membership_and_active_token_projection() {
    let target = worker(1, 0);
    let registry = PromptRegistry::new([target]);
    let trie_lookup = lookup();
    let now = Instant::now();
    let anchored_since = now.checked_sub(Duration::from_secs(3)).unwrap_or(now);
    let prompt: [SequenceHash; 3] = [1, 2, 3];

    registry.apply_membership_delta_and_load(
        target,
        &trie_lookup,
        store(None, &prompt),
        worker_load_snapshot(3),
    );
    let expected_blocks = FxHashMap::from_iter([(target, hash_set(&prompt))]);

    // Swap in a new load snapshot without sending any membership delta.
    let updated_load =
        anchored_load_snapshot(5, 12, 10, Some(Duration::from_secs(10)), anchored_since);
    registry.replace_worker_load_state(target, updated_load);
    let expected_loads = FxHashMap::from_iter([(target, updated_load)]);

    registry.assert_consistent_with_workers(&expected_loads, &expected_blocks);

    let tokens_now = registry.active_tokens(now);
    assert_eq!(tokens_now.get(&target).copied(), Some(9));

    let (potential_blocks, potential_tokens) = registry
        .potential_blocks_and_tokens_with_prefill_tracking(
            Some(&prompt[..]),
            12,
            &OverlapScores::default(),
            false,
            4,
            now,
        );
    assert_eq!(potential_blocks.get(&target).copied(), Some(5));
    assert_eq!(potential_tokens.get(&target).copied(), Some(9));
}
/// Removing a worker through a topology change must drop both its trie membership and
/// its load entry, leaving the other worker untouched.
#[test]
fn removing_worker_clears_prompt_membership_and_load_state() {
    let departing = worker(1, 0);
    let survivor = worker(2, 0);
    let registry = PromptRegistry::new([departing, survivor]);
    let departing_lookup = lookup();
    let survivor_lookup = lookup();
    let departing_load = worker_load_snapshot(3);
    let survivor_load = worker_load_snapshot(2);

    registry.apply_membership_delta_and_load(
        departing,
        &departing_lookup,
        store(None, &[1, 2, 3]),
        departing_load,
    );
    registry.apply_membership_delta_and_load(
        survivor,
        &survivor_lookup,
        store(None, &[1, 2]),
        survivor_load,
    );

    let removal = WorkerTopologyChange {
        added: Vec::new(),
        removed: vec![RemovedWorkerState {
            worker: departing,
            trie_lookup: Arc::clone(&departing_lookup),
        }],
    };
    registry.apply_topology_change(removal);

    // Only the surviving worker's state is expected to remain.
    let expected_loads = FxHashMap::from_iter([(survivor, survivor_load)]);
    let expected_blocks = FxHashMap::from_iter([(survivor, hash_set(&[1, 2]))]);
    registry.assert_consistent_with_workers(&expected_loads, &expected_blocks);
    assert!(!registry.active_blocks().contains_key(&departing));

    let (potential_blocks, _potential_tokens) = registry
        .potential_blocks_and_tokens_with_prefill_tracking(
            Some(&[1, 2, 3]),
            12,
            &OverlapScores::default(),
            false,
            4,
            Instant::now(),
        );
    assert_eq!(potential_blocks.get(&survivor).copied(), Some(3));
}
/// Two DP ranks of the same worker id must be tracked as independent workers.
#[test]
fn dp_ranks_with_same_worker_id_remain_isolated() {
    let rank_zero = worker(1, 0);
    let rank_one = worker(1, 1);
    let registry = PromptRegistry::new([rank_zero, rank_one]);
    let rank_zero_lookup = lookup();
    let rank_one_lookup = lookup();
    let decay_now = Instant::now();
    let rank_zero_load = worker_load_snapshot(3);
    let rank_one_load = worker_load_snapshot(1);
    let query: [SequenceHash; 3] = [1, 2, 3];

    registry.apply_membership_delta_and_load(
        rank_zero,
        &rank_zero_lookup,
        store(None, &[1, 2, 3]),
        rank_zero_load,
    );
    registry.apply_membership_delta_and_load(
        rank_one,
        &rank_one_lookup,
        store(None, &[1]),
        rank_one_load,
    );

    let expected_loads =
        FxHashMap::from_iter([(rank_zero, rank_zero_load), (rank_one, rank_one_load)]);
    let expected_blocks = FxHashMap::from_iter([
        (rank_zero, hash_set(&[1, 2, 3])),
        (rank_one, hash_set(&[1])),
    ]);
    registry.assert_consistent_with_workers(&expected_loads, &expected_blocks);

    // The registry's projection must agree with the naive oracle for both ranks.
    let expected = naive_potential_loads(
        &expected_loads,
        &expected_blocks,
        Some(&query[..]),
        12,
        &OverlapScores::default(),
        false,
        4,
        decay_now,
    );
    let actual = registry.potential_blocks_and_tokens_with_prefill_tracking(
        Some(&query[..]),
        12,
        &OverlapScores::default(),
        false,
        4,
        decay_now,
    );
    assert_eq!(actual, expected);
}
}
lib/kv-router/src/sequences/request_maps.rs
0 → 100644
View file @
134d484d
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use
dashmap
::{
DashMap
,
mapref
::
entry
::
Entry
};
use
std
::
collections
::
HashMap
;
use
super
::
single
::
RequestId
;
use
crate
::
protocols
::
WorkerWithDpRank
;
/// Concurrent index from request id to the worker serving it and, optionally, to the
/// LoRA adapter name recorded for it.
#[derive(Debug, Default)]
pub(super) struct RequestIndex {
    /// Worker currently recorded for each request.
    request_to_worker: DashMap<RequestId, WorkerWithDpRank>,
    /// LoRA adapter name for each request that was registered with one.
    request_to_lora: DashMap<RequestId, String>,
}
impl
RequestIndex
{
/// Registers `request_id` on `worker` only if the request is not already indexed.
///
/// Returns `Err` with the already-recorded worker when the request exists; in that
/// case neither map is changed. On success, `lora_name` is recorded when present.
pub(super) fn try_insert_request(
    &self,
    request_id: RequestId,
    worker: WorkerWithDpRank,
    lora_name: Option<String>,
) -> Result<(), WorkerWithDpRank> {
    let vacant = match self.request_to_worker.entry(request_id.clone()) {
        Entry::Occupied(existing) => return Err(*existing.get()),
        Entry::Vacant(vacant) => vacant,
    };
    vacant.insert(worker);

    if let Some(adapter) = lora_name {
        self.request_to_lora.insert(request_id, adapter);
    }
    Ok(())
}
/// Unconditionally records `request_id` on `worker`, replacing any earlier mapping.
///
/// The LoRA mapping is overwritten when `lora_name` is `Some` and removed when it is
/// `None`, so no adapter from an earlier registration survives.
pub(super) fn set_request(
    &self,
    request_id: RequestId,
    worker: WorkerWithDpRank,
    lora_name: Option<String>,
) {
    self.request_to_worker.insert(request_id.clone(), worker);

    match lora_name {
        Some(adapter) => {
            self.request_to_lora.insert(request_id, adapter);
        }
        None => {
            self.request_to_lora.remove(&request_id);
        }
    }
}
/// Returns the worker recorded for `request_id`, if any.
pub(super) fn worker_for(&self, request_id: &RequestId) -> Option<WorkerWithDpRank> {
    let slot = self.request_to_worker.get(request_id)?;
    Some(*slot)
}
/// Returns a copy of the LoRA adapter name recorded for `request_id`, if any.
pub(super) fn lora_for(&self, request_id: &RequestId) -> Option<String> {
    let slot = self.request_to_lora.get(request_id)?;
    Some(slot.value().clone())
}
/// Removes `request_id` from both maps.
///
/// Returns the worker that was recorded for it, or `None` if the request was not in
/// the worker map. The LoRA mapping is removed either way.
pub(super) fn remove_request(&self, request_id: &RequestId) -> Option<WorkerWithDpRank> {
    let removed = self.request_to_worker.remove(request_id);
    self.request_to_lora.remove(request_id);
    removed.map(|(_request_id, worker)| worker)
}
/// Removes every request in `request_ids` from both maps; unknown ids are ignored.
pub(super) fn remove_requests<'a>(&self, request_ids: impl IntoIterator<Item = &'a RequestId>) {
    request_ids.into_iter().for_each(|request_id| {
        self.remove_request(request_id);
    });
}
/// Removes every request currently recorded on `worker` and returns their ids.
///
/// The matching ids are collected first and removed afterwards, once the scan over the
/// worker map has finished.
pub(super) fn remove_worker_requests(&self, worker: WorkerWithDpRank) -> Vec<RequestId> {
    let mut owned_ids = Vec::new();
    for slot in self.request_to_worker.iter() {
        if *slot.value() == worker {
            owned_ids.push(slot.key().clone());
        }
    }

    self.remove_requests(&owned_ids);
    owned_ids
}
/// Counts the indexed requests per LoRA adapter name.
///
/// Only requests that have an entry in the LoRA map contribute. The result maps each
/// adapter name to the number of such requests.
pub(super) fn active_lora_counts(&self) -> HashMap<String, usize> {
    let mut counts: HashMap<String, usize> = HashMap::new();
    for entry in self.request_to_lora.iter() {
        let lora_name = entry.value();
        // Look up by borrowed name first so the adapter string is cloned only the first
        // time it is seen, rather than once per request.
        match counts.get_mut(lora_name.as_str()) {
            Some(count) => *count += 1,
            None => {
                counts.insert(lora_name.clone(), 1);
            }
        }
    }
    counts
}
/// Reports whether both the worker map and the LoRA map are empty (test/bench only).
#[cfg(any(test, feature = "bench"))]
pub(super) fn is_empty(&self) -> bool {
    if !self.request_to_worker.is_empty() {
        return false;
    }
    self.request_to_lora.is_empty()
}
/// Returns the number of requests in the worker map (test/bench only).
#[cfg(any(test, feature = "bench"))]
pub(super) fn worker_len(&self) -> usize {
    let Self {
        request_to_worker, ..
    } = self;
    request_to_worker.len()
}
}
#[cfg(test)]
mod
tests
{
use
super
::
*
;
/// A second `try_insert_request` for the same id must fail with the original worker
/// and leave both the worker and the LoRA mapping unchanged.
#[test]
fn duplicate_insert_returns_existing_worker() {
    let index = RequestIndex::default();
    let original_worker = WorkerWithDpRank::new(1, 0);
    let request_id = "req-1".to_string();

    index
        .try_insert_request(
            request_id.clone(),
            original_worker,
            Some("adapter".to_string()),
        )
        .unwrap();

    let duplicate = index.try_insert_request(request_id.clone(), WorkerWithDpRank::new(2, 0), None);
    assert_eq!(duplicate, Err(original_worker));

    assert_eq!(index.worker_for(&request_id), Some(original_worker));
    assert_eq!(index.lora_for(&request_id), Some("adapter".to_string()));
}
/// Removing the same request twice returns the worker once, then `None`, and leaves
/// the index empty.
#[test]
fn remove_request_is_idempotent() {
    let index = RequestIndex::default();
    let assigned = WorkerWithDpRank::new(1, 0);
    let request_id = "req-1".to_string();
    index.set_request(request_id.clone(), assigned, Some("adapter".to_string()));

    let first_removal = index.remove_request(&request_id);
    assert_eq!(first_removal, Some(assigned));

    let second_removal = index.remove_request(&request_id);
    assert_eq!(second_removal, None);

    assert!(index.is_empty());
}
/// Re-registering a request without an adapter must move it to the new worker and
/// drop the adapter recorded by the earlier registration.
#[test]
fn set_request_without_lora_clears_stale_lora_mapping() {
    let index = RequestIndex::default();
    let request_id = "req-1".to_string();
    let first_worker = WorkerWithDpRank::new(1, 0);
    let second_worker = WorkerWithDpRank::new(2, 0);

    index.set_request(
        request_id.clone(),
        first_worker,
        Some("adapter".to_string()),
    );
    index.set_request(request_id.clone(), second_worker, None);

    assert_eq!(index.worker_for(&request_id), Some(second_worker));
    assert_eq!(index.lora_for(&request_id), None);
}
/// Removing a worker's requests must return exactly that worker's ids and must clear
/// their entries from both the worker map and the LoRA map.
#[test]
fn remove_worker_requests_clears_both_maps() {
    let index = RequestIndex::default();
    let worker_a = WorkerWithDpRank::new(1, 0);
    let worker_b = WorkerWithDpRank::new(2, 0);

    let registrations = [
        ("req-a", worker_a, Some("adapter-a")),
        ("req-b", worker_b, Some("adapter-b")),
        ("req-c", worker_a, None),
    ];
    for (request_id, worker, adapter) in registrations {
        index.set_request(request_id.to_string(), worker, adapter.map(str::to_string));
    }

    // Iteration order over the map is unspecified, so sort before comparing.
    let mut removed = index.remove_worker_requests(worker_a);
    removed.sort_unstable();
    assert_eq!(removed, vec!["req-a".to_string(), "req-c".to_string()]);

    assert_eq!(index.worker_for(&"req-b".to_string()), Some(worker_b));
    assert_eq!(
        index.active_lora_counts(),
        HashMap::from([("adapter-b".to_string(), 1)])
    );
}
}
lib/kv-router/src/sequences/single.rs
View file @
134d484d
...
@@ -18,7 +18,6 @@
...
@@ -18,7 +18,6 @@
//! Each block is identified by a hash of its contents, allowing for deduplication when multiple
//! Each block is identified by a hash of its contents, allowing for deduplication when multiple
//! requests share common prefixes (e.g., system prompts, few-shot examples).
//! requests share common prefixes (e.g., system prompts, few-shot examples).
use
derive_getters
::
Getters
;
use
dynamo_tokens
::
SequenceHash
;
use
dynamo_tokens
::
SequenceHash
;
use
std
::
collections
::{
HashMap
,
HashSet
};
use
std
::
collections
::{
HashMap
,
HashSet
};
use
std
::
sync
::
Arc
;
use
std
::
sync
::
Arc
;
...
@@ -26,8 +25,12 @@ use std::time::Duration;
...
@@ -26,8 +25,12 @@ use std::time::Duration;
use
tokio
::
time
::
Instant
;
use
tokio
::
time
::
Instant
;
use
uuid
::
Uuid
;
use
uuid
::
Uuid
;
#[cfg(test)]
use
rustc_hash
::
FxHashSet
;
use
super
::
block_tracker
::
BlockTracker
;
use
super
::
block_tracker
::
BlockTracker
;
use
super
::
prefill_tracker
::{
PrefillLoadState
,
PrefillLoadTracker
};
use
super
::
prefill_tracker
::{
PrefillLoadState
,
PrefillLoadTracker
,
added_prefill_tokens
};
use
super
::
prompt_registry
::
WorkerLoadSnapshot
;
use
crate
::
protocols
::
PrefillLoadHint
;
use
crate
::
protocols
::
PrefillLoadHint
;
/// Duration after which stale requests may be expired (5 minutes).
/// Duration after which stale requests may be expired (5 minutes).
...
@@ -42,28 +45,75 @@ pub type RequestId = String;
...
@@ -42,28 +45,75 @@ pub type RequestId = String;
#[derive(Debug)]
#[derive(Debug)]
pub
(
super
)
struct
RequestState
{
pub
(
super
)
struct
RequestState
{
blocks
:
Vec
<
(
SequenceHash
,
Arc
<
()
>
)
>
,
prompt_blocks
:
Vec
<
(
SequenceHash
,
Arc
<
()
>
)
>
,
output_blocks
:
Vec
<
(
SequenceHash
,
Arc
<
()
>
)
>
,
started_at
:
Instant
,
started_at
:
Instant
,
prefill
:
Option
<
PrefillLoadState
>
,
expected_output_tokens
:
Option
<
u32
>
,
expected_output_tokens
:
Option
<
u32
>
,
}
}
impl
RequestState
{
fn
all_blocks
(
&
self
)
->
impl
Iterator
<
Item
=
&
(
SequenceHash
,
Arc
<
()
>
)
>
{
self
.prompt_blocks
.iter
()
.chain
(
self
.output_blocks
.iter
())
}
}
#[derive(Debug,
Default,
Clone,
PartialEq,
Eq)]
pub
(
super
)
struct
PromptMembershipStore
{
pub
parent
:
Option
<
SequenceHash
>
,
pub
hashes
:
Vec
<
SequenceHash
>
,
}
#[derive(Debug,
Default,
Clone,
PartialEq,
Eq)]
pub
(
super
)
struct
PromptMembershipRemove
{
pub
hashes
:
Vec
<
SequenceHash
>
,
}
#[derive(Debug,
Default,
Clone,
PartialEq,
Eq)]
pub
(
super
)
struct
PromptMembershipDelta
{
pub
stores
:
Vec
<
PromptMembershipStore
>
,
pub
removes
:
Vec
<
PromptMembershipRemove
>
,
}
impl
PromptMembershipDelta
{
fn
extend
(
&
mut
self
,
other
:
Self
)
{
self
.stores
.extend
(
other
.stores
);
self
.removes
.extend
(
other
.removes
);
}
fn
push_store
(
&
mut
self
,
parent
:
Option
<
SequenceHash
>
,
hashes
:
Vec
<
SequenceHash
>
)
{
if
hashes
.is_empty
()
{
return
;
}
self
.stores
.push
(
PromptMembershipStore
{
parent
,
hashes
});
}
fn
push_remove
(
&
mut
self
,
hashes
:
Vec
<
SequenceHash
>
)
{
if
hashes
.is_empty
()
{
return
;
}
self
.removes
.push
(
PromptMembershipRemove
{
hashes
});
}
}
#[derive(Debug,
Default,
Clone,
PartialEq,
Eq)]
pub
(
super
)
struct
SequenceMutationOutcome
{
pub
membership_delta
:
PromptMembershipDelta
,
pub
expired_request_ids
:
HashSet
<
RequestId
>
,
}
/// A multi-request sequence manager that handles multiple active sequences with shared KV cache
/// A multi-request sequence manager that handles multiple active sequences with shared KV cache
#[derive(Debug
,
Getters
)]
#[derive(Debug)]
pub
struct
ActiveSequences
{
pub
struct
ActiveSequences
{
requests
:
HashMap
<
RequestId
,
RequestState
>
,
requests
:
HashMap
<
RequestId
,
RequestState
>
,
prefill
:
PrefillLoadTracker
,
prefill
:
PrefillLoadTracker
,
blocks
:
BlockTracker
,
blocks
:
BlockTracker
,
#[getter(copy)]
block_size
:
usize
,
block_size
:
usize
,
last_expiry_check_time
:
Instant
,
last_expiry_check_time
:
Instant
,
}
}
impl
ActiveSequences
{
impl
ActiveSequences
{
/// Create a new SharedSequenceManager instance
/// Create a new SharedSequenceManager instance
pub
fn
new
(
block_size
:
usize
)
->
Self
{
pub
(
super
)
fn
new
(
block_size
:
usize
)
->
Self
{
assert
!
(
block_size
>
1
,
"block_size must be greater than 1"
);
assert
!
(
block_size
>
1
,
"block_size must be greater than 1"
);
Self
{
Self
{
...
@@ -77,53 +127,13 @@ impl ActiveSequences {
...
@@ -77,53 +127,13 @@ impl ActiveSequences {
#[cfg(any(test,
debug_assertions))]
#[cfg(any(test,
debug_assertions))]
fn
assert_consistent
(
&
self
)
{
fn
assert_consistent
(
&
self
)
{
let
active_prefills
:
HashSet
<
RequestId
>
=
self
self
.prefill
.assert_consistent
();
.requests
let
active_prefills
:
HashSet
<
RequestId
>
=
self
.prefill.prefills
.keys
()
.cloned
()
.collect
();
.iter
()
let
active_requests
:
HashSet
<
RequestId
>
=
self
.requests
.keys
()
.cloned
()
.collect
();
.filter
(|(
_
,
state
)|
state
.prefill
.is_some
())
assert
!
(
.map
(|(
request_id
,
_
)|
request_id
.clone
())
active_prefills
.is_subset
(
&
active_requests
),
.collect
();
"prefill tracker cannot reference missing request state"
,
let
ordered_prefills
:
HashSet
<
RequestId
>
=
self
.prefill.prefill_order
.iter
()
.cloned
()
.collect
();
let
recomputed_prefill_sum
:
usize
=
self
.requests
.values
()
.filter_map
(|
state
|
state
.prefill
)
.map
(|
prefill
|
prefill
.initial_effective_prefill_tokens
)
.sum
();
assert_eq!
(
ordered_prefills
.len
(),
self
.prefill.prefill_order
.len
(),
"prefill_order contains duplicate request ids"
,
);
assert_eq!
(
ordered_prefills
,
active_prefills
,
"prefill_order must match requests with active prefill load"
,
);
);
assert_eq!
(
self
.prefill.prefill_full_tokens_sum
,
recomputed_prefill_sum
,
"prefill_full_tokens_sum drifted from request state"
,
);
if
let
Some
(
oldest_request_id
)
=
self
.prefill.prefill_order
.front
()
{
let
Some
((
anchored_request_id
,
_
))
=
self
.prefill.anchored_prefill
.as_ref
()
else
{
panic!
(
"anchored_prefill must exist when prefill_order is non-empty"
);
};
assert
!
(
self
.requests
.get
(
oldest_request_id
)
.is_some_and
(|
state
|
state
.prefill
.is_some
()),
"prefill_order front must point to an active prefill request"
,
);
assert_eq!
(
anchored_request_id
,
oldest_request_id
,
"anchored_prefill must match prefill_order.front()"
,
);
}
else
{
assert
!
(
self
.prefill.anchored_prefill
.is_none
(),
"anchored_prefill must be absent when no active prefills remain"
,
);
}
assert
!
(
assert
!
(
self
.blocks
self
.blocks
.fractional_blocks
.fractional_blocks
...
@@ -139,85 +149,19 @@ impl ActiveSequences {
...
@@ -139,85 +149,19 @@ impl ActiveSequences {
self
.assert_consistent
();
self
.assert_consistent
();
}
}
pub
fn
active_blocks
(
&
self
)
->
usize
{
pub
(
super
)
fn
active_blocks
(
&
self
)
->
usize
{
self
.blocks
.active_blocks
()
self
.blocks
.active_blocks
()
}
}
fn
insert_prefill_load
(
#[cfg(test)]
&
mut
self
,
pub
(
super
)
fn
active_tokens
(
&
self
,
decay_now
:
Instant
)
->
usize
{
request_id
:
&
RequestId
,
self
.prefill
.snapshot
()
.active_tokens_at
(
decay_now
)
prefill
:
PrefillLoadState
,
decay_now
:
Instant
,
)
{
self
.prefill
.insert
(
request_id
,
prefill
,
decay_now
);
}
fn
remove_prefill_load
(
&
mut
self
,
request_id
:
&
RequestId
,
decay_now
:
Instant
,
)
->
Option
<
PrefillLoadState
>
{
let
prefill
=
{
let
state
=
self
.requests
.get_mut
(
request_id
)
?
;
state
.prefill
.take
()
?
};
self
.prefill
.remove
(
request_id
,
prefill
,
decay_now
);
Some
(
prefill
)
}
fn
active_prefill_tokens_at
(
&
self
,
now
:
Instant
)
->
usize
{
let
Some
((
oldest_request_id
,
oldest_since
))
=
self
.prefill.anchored_prefill
.as_ref
()
else
{
return
0
;
};
let
prefill
=
self
.requests
.get
(
oldest_request_id
)
.and_then
(|
state
|
state
.prefill
)
.expect
(
"prefill_order front missing prefill load"
);
let
oldest_full
=
prefill
.initial_effective_prefill_tokens
;
let
oldest_remaining
=
match
prefill
.expected_prefill_duration
{
None
=>
oldest_full
,
Some
(
expected_prefill_duration
)
if
expected_prefill_duration
.is_zero
()
=>
0
,
Some
(
expected_prefill_duration
)
=>
{
let
elapsed
=
now
.saturating_duration_since
(
*
oldest_since
);
let
remaining_fraction
=
(
1.0
-
(
elapsed
.as_secs_f64
()
/
expected_prefill_duration
.as_secs_f64
()))
.clamp
(
0.0
,
1.0
);
((
oldest_full
as
f64
)
*
remaining_fraction
)
.ceil
()
as
usize
}
};
self
.prefill
.prefill_full_tokens_sum
.checked_sub
(
oldest_full
)
.expect
(
"prefill_full_tokens_sum smaller than oldest load"
)
+
oldest_remaining
}
pub
fn
active_tokens
(
&
self
,
decay_now
:
Instant
)
->
usize
{
self
.active_prefill_tokens_at
(
decay_now
)
}
/// Find all blocks in a request that have only a single strong reference (only used by this request)
/// and insert them into fractional_blocks with the given fraction value.
pub
fn
set_single_ref_blocks_as_fractional
(
&
mut
self
,
request_id
:
&
RequestId
,
fraction
:
f64
)
{
let
Some
(
request_state
)
=
self
.requests
.get
(
request_id
)
else
{
tracing
::
warn!
(
"Request {request_id} not found for set_single_ref_blocks_as_fractional"
);
return
;
};
for
(
hash
,
rc
)
in
&
request_state
.blocks
{
if
Arc
::
strong_count
(
rc
)
==
1
{
self
.blocks.fractional_blocks
.insert
(
*
hash
,
fraction
);
}
}
}
}
/// Add a new request with its initial tokens.
/// Add a new request with its initial tokens.
/// Returns the set of expired request IDs that were removed during cleanup.
/// Returns block membership transitions plus any expired request IDs removed during cleanup.
pub
fn
add_request
(
#[cfg(test)]
pub
(
super
)
fn
add_request
(
&
mut
self
,
&
mut
self
,
request_id
:
RequestId
,
request_id
:
RequestId
,
token_sequence
:
Option
<
Vec
<
SequenceHash
>>
,
token_sequence
:
Option
<
Vec
<
SequenceHash
>>
,
...
@@ -225,7 +169,7 @@ impl ActiveSequences {
...
@@ -225,7 +169,7 @@ impl ActiveSequences {
overlap
:
u32
,
overlap
:
u32
,
expected_output_tokens
:
Option
<
u32
>
,
expected_output_tokens
:
Option
<
u32
>
,
decay_now
:
Instant
,
decay_now
:
Instant
,
)
->
HashSet
<
RequestId
>
{
)
->
SequenceMutationOutcome
{
self
.add_request_with_prefill_tracking
(
self
.add_request_with_prefill_tracking
(
request_id
,
request_id
,
token_sequence
,
token_sequence
,
...
@@ -239,9 +183,9 @@ impl ActiveSequences {
...
@@ -239,9 +183,9 @@ impl ActiveSequences {
}
}
/// Add a new request with optional prompt-token load accounting.
/// Add a new request with optional prompt-token load accounting.
/// Returns
the set of
expired request IDs
that were
removed during cleanup.
/// Returns
block membership transitions plus any
expired request IDs removed during cleanup.
#[allow(clippy::too_many_arguments)]
#[allow(clippy::too_many_arguments)]
pub
fn
add_request_with_prefill_tracking
(
pub
(
super
)
fn
add_request_with_prefill_tracking
(
&
mut
self
,
&
mut
self
,
request_id
:
RequestId
,
request_id
:
RequestId
,
token_sequence
:
Option
<
Vec
<
SequenceHash
>>
,
token_sequence
:
Option
<
Vec
<
SequenceHash
>>
,
...
@@ -251,23 +195,48 @@ impl ActiveSequences {
...
@@ -251,23 +195,48 @@ impl ActiveSequences {
track_prefill_tokens
:
bool
,
track_prefill_tokens
:
bool
,
prefill_load_hint
:
Option
<
PrefillLoadHint
>
,
prefill_load_hint
:
Option
<
PrefillLoadHint
>
,
decay_now
:
Instant
,
decay_now
:
Instant
,
)
->
HashSet
<
RequestId
>
{
)
->
SequenceMutationOutcome
{
if
self
.requests
.contains_key
(
&
request_id
)
{
if
self
.requests
.contains_key
(
&
request_id
)
{
tracing
::
error!
(
"Request {request_id} is already active. Ignoring duplicate add."
);
tracing
::
error!
(
"Request {request_id} is already active. Ignoring duplicate add."
);
return
HashSet
::
new
();
return
SequenceMutationOutcome
::
default
();
}
}
let
removed_requests
=
self
.force_expiry
();
let
mut
outcome
=
self
.force_expiry
();
let
started_at
=
Instant
::
now
();
let
started_at
=
Instant
::
now
();
let
blocks
=
match
token_sequence
{
let
prompt_blocks
=
match
token_sequence
{
Some
(
sequence
)
=>
sequence
Some
(
sequence
)
=>
{
.into_iter
()
let
mut
first_new_prompt_idx
=
None
;
.map
(|
block
|
{
let
prompt_blocks
:
Vec
<
_
>
=
sequence
let
rc
=
self
.blocks
.touch_block
(
&
block
);
.into_iter
()
(
block
,
rc
)
.enumerate
()
})
.map
(|(
idx
,
block
)|
{
.collect
(),
let
acquire
=
self
.blocks
.touch_block
(
&
block
);
if
acquire
.became_present_on_worker
&&
first_new_prompt_idx
.is_none
()
{
first_new_prompt_idx
=
Some
(
idx
);
}
(
block
,
acquire
.rc
)
})
.collect
();
if
let
Some
(
first_new_prompt_idx
)
=
first_new_prompt_idx
{
debug_assert!
(
prompt_blocks
[
first_new_prompt_idx
..
]
.iter
()
.all
(|(
hash
,
_
)|
self
.blocks.unique_blocks
.contains_key
(
hash
))
);
let
parent
=
first_new_prompt_idx
.checked_sub
(
1
)
.map
(|
idx
|
prompt_blocks
[
idx
]
.0
);
let
hashes
=
prompt_blocks
[
first_new_prompt_idx
..
]
.iter
()
.map
(|(
hash
,
_
)|
*
hash
)
.collect
();
outcome
.membership_delta
.push_store
(
parent
,
hashes
);
}
prompt_blocks
}
None
=>
Vec
::
new
(),
None
=>
Vec
::
new
(),
};
};
...
@@ -289,166 +258,212 @@ impl ActiveSequences {
...
@@ -289,166 +258,212 @@ impl ActiveSequences {
self
.requests
.insert
(
self
.requests
.insert
(
request_id
.clone
(),
request_id
.clone
(),
RequestState
{
RequestState
{
blocks
,
prompt_blocks
,
output_blocks
:
Vec
::
new
(),
started_at
,
started_at
,
prefill
,
expected_output_tokens
,
expected_output_tokens
,
},
},
);
);
if
let
Some
(
prefill
)
=
prefill
{
if
let
Some
(
prefill
)
=
prefill
{
self
.
insert_prefill_load
(
&
request_id
,
prefill
,
decay_now
);
self
.
prefill
.insert
(
&
request_id
,
prefill
,
decay_now
);
}
}
self
.validate_state
();
self
.validate_state
();
removed_requests
outcome
}
}
/// Mark prefill as completed for a request, removing it from prompt-load tracking.
/// Mark prefill as completed for a request, removing it from prompt-load tracking.
pub
fn
mark_prefill_completed
(
&
mut
self
,
request_id
:
&
RequestId
,
decay_now
:
Instant
)
{
pub
(
super
)
fn
mark_prefill_completed
(
&
mut
self
,
request_id
:
&
RequestId
,
decay_now
:
Instant
)
{
let
_
=
self
.
remove_prefill_load
(
request_id
,
decay_now
);
let
_
=
self
.
prefill
.remove
(
request_id
,
decay_now
);
self
.validate_state
();
self
.validate_state
();
}
}
pub
fn
new_tokens
(
&
self
,
isl
:
usize
,
overlap
:
u32
)
->
usize
{
let
cached_tokens
=
(
overlap
as
usize
)
*
self
.block_size
;
isl
.checked_sub
(
cached_tokens
)
.unwrap_or_else
(||
{
tracing
::
error!
(
"prefill_tokens < 0 with ISL {isl} < cached_tokens {cached_tokens} (overlap {overlap} * block_size {}), returning 0"
,
self
.block_size
);
0
})
}
/// Projected `(blocks, tokens)` load if the given request were added, with prompt-token
/// accounting enabled.
///
/// Delegates to [`Self::potential_blocks_and_tokens_with_prefill_tracking`] with
/// `track_prefill_tokens` fixed to `true`.
pub fn potential_blocks_and_tokens(
    &self,
    token_sequence: Option<&[SequenceHash]>,
    isl: usize,
    overlap: u32,
    decay_now: Instant,
) -> (usize, usize) {
    let track_prefill_tokens = true;
    self.potential_blocks_and_tokens_with_prefill_tracking(
        token_sequence,
        isl,
        overlap,
        track_prefill_tokens,
        decay_now,
    )
}
/// Projected `(blocks, tokens)` load if the given request were added.
///
/// * Blocks: the active block count, plus the blocks of `token_sequence` (when provided)
///   that [`Self::new_blocks`] reports as not yet tracked.
/// * Tokens: `active_tokens(decay_now)`, plus `new_tokens(isl, overlap)` only when
///   `track_prefill_tokens` is set.
pub fn potential_blocks_and_tokens_with_prefill_tracking(
    &self,
    token_sequence: Option<&[SequenceHash]>,
    isl: usize,
    overlap: u32,
    track_prefill_tokens: bool,
    decay_now: Instant,
) -> (usize, usize) {
    let potential_blocks = match token_sequence {
        Some(hashes) => self.new_blocks(hashes) + self.active_blocks(),
        None => self.active_blocks(),
    };

    let active_tokens = self.active_tokens(decay_now);
    if !track_prefill_tokens {
        return (potential_blocks, active_tokens);
    }

    (potential_blocks, self.new_tokens(isl, overlap) + active_tokens)
}
/// Count the blocks in `token_sequence` that are not yet keyed in `unique_blocks`,
/// i.e. how many new blocks the request would add.
pub fn new_blocks(&self, token_sequence: &[SequenceHash]) -> usize {
    let mut missing = 0;
    for hash in token_sequence {
        if !self.blocks.unique_blocks.contains_key(hash) {
            missing += 1;
        }
    }
    missing
}
/// Total block count that would be in use if `token_sequence` were added: its
/// not-yet-tracked blocks plus the currently active blocks.
pub fn potential_blocks(&self, token_sequence: &[SequenceHash]) -> usize {
    let added = self.new_blocks(token_sequence);
    let current = self.active_blocks();
    added + current
}
/// Free all blocks associated with a request.
/// Free all blocks associated with a request.
///
///
/// This implicitly calls [`Self::mark_prefill_completed`] first, so callers do not need
/// This implicitly calls [`Self::mark_prefill_completed`] first, so callers do not need
/// to invoke both when the request is finishing.
/// to invoke both when the request is finishing.
pub
fn
free
(
&
mut
self
,
request_id
:
&
RequestId
,
decay_now
:
Instant
)
->
usize
{
pub
(
super
)
fn
free
(
self
.mark_prefill_completed
(
request_id
,
decay_now
);
&
mut
self
,
request_id
:
&
RequestId
,
decay_now
:
Instant
,
)
->
PromptMembershipDelta
{
let
_
=
self
.prefill
.remove
(
request_id
,
decay_now
);
let
Some
(
request_state
)
=
self
.requests
.remove
(
request_id
)
else
{
let
Some
(
request_state
)
=
self
.requests
.remove
(
request_id
)
else
{
tracing
::
warn!
(
"Trying to free non-existent request {request_id}"
);
tracing
::
warn!
(
"Trying to free non-existent request {request_id}"
);
return
self
.active_blocks
();
return
PromptMembershipDelta
::
default
();
};
};
let
_
=
request_state
.expected_output_tokens
;
let
_
=
request_state
.expected_output_tokens
;
for
(
block_hash
,
rc
)
in
request_state
.blocks
{
let
mut
membership_delta
=
PromptMembershipDelta
::
default
();
let
mut
first_absent_prompt_idx
=
None
;
let
prompt_hashes
:
Vec
<
_
>
=
request_state
.prompt_blocks
.iter
()
.map
(|(
hash
,
_
)|
*
hash
)
.collect
();
for
(
idx
,
(
block_hash
,
rc
))
in
request_state
.prompt_blocks
.into_iter
()
.enumerate
()
{
drop
(
rc
);
if
self
.blocks
.try_remove_block
(
&
block_hash
)
&&
first_absent_prompt_idx
.is_none
()
{
first_absent_prompt_idx
=
Some
(
idx
);
}
}
if
let
Some
(
first_absent_prompt_idx
)
=
first_absent_prompt_idx
{
let
prompt_remove
=
prompt_hashes
[
first_absent_prompt_idx
..
]
.to_vec
();
membership_delta
.push_remove
(
prompt_remove
);
}
for
(
block_hash
,
rc
)
in
request_state
.output_blocks
{
drop
(
rc
);
drop
(
rc
);
self
.blocks
.try_remove_block
(
&
block_hash
);
self
.blocks
.try_remove_block
(
&
block_hash
);
}
}
self
.validate_state
();
self
.validate_state
();
self
.active_blocks
()
membership_delta
}
}
/// Add an output block with a random hash and optional fractional decay weight.
/// Add an output block with a random hash and optional fractional decay weight.
///
///
/// This is used during generation to track output blocks as they are created.
/// This is used during generation to track output blocks as they are created.
pub
fn
add_output_block
(
pub
(
super
)
fn
add_output_block
(
&
mut
self
,
&
mut
self
,
request_id
:
&
RequestId
,
request_id
:
&
RequestId
,
decay_fraction
:
Option
<
f64
>
,
decay_fraction
:
Option
<
f64
>
,
)
->
bool
{
)
->
Option
<
SequenceHash
>
{
if
!
self
.requests
.contains_key
(
request_id
)
{
if
!
self
.requests
.contains_key
(
request_id
)
{
tracing
::
warn!
(
"Request {request_id} not found for add_output_block"
);
tracing
::
warn!
(
"Request {request_id} not found for add_output_block"
);
return
fals
e
;
return
Non
e
;
}
}
// TODO: Output blocks still use random hashes, so indexing them mainly simplifies
// generic block bookkeeping and usually adds little real reuse signal.
let
random_hash
:
SequenceHash
=
Uuid
::
new_v4
()
.as_u64_pair
()
.0
;
let
random_hash
:
SequenceHash
=
Uuid
::
new_v4
()
.as_u64_pair
()
.0
;
let
rc
=
self
.blocks
.touch_block
(
&
random_hash
);
let
acquire
=
self
.blocks
.touch_block
(
&
random_hash
);
self
.requests
self
.requests
.get_mut
(
request_id
)
.get_mut
(
request_id
)
.expect
(
"request existence was checked above"
)
.expect
(
"request existence was checked above"
)
.blocks
.
output_
blocks
.push
((
random_hash
,
rc
));
.push
((
random_hash
,
acquire
.
rc
));
if
let
Some
(
frac
)
=
decay_fraction
{
if
let
Some
(
frac
)
=
decay_fraction
{
self
.set_single_ref_blocks_as_fractional
(
request_id
,
frac
);
self
.set_single_ref_blocks_as_fractional
(
request_id
,
frac
);
}
}
self
.validate_state
();
self
.validate_state
();
true
acquire
.became_present_on_worker
.then_some
(
random_hash
)
}
/// Prompt tokens a request with input length `isl` and `overlap` cached blocks would add.
///
/// Thin wrapper that forwards this tracker's block size to `added_prefill_tokens`.
pub(super) fn new_tokens(&self, isl: usize, overlap: u32) -> usize {
    let block_size = self.block_size;
    added_prefill_tokens(block_size, isl, overlap)
}
/// Test-only projection of the `(blocks, tokens)` load if the given request were added.
///
/// * Blocks: the active block count, plus the blocks of `token_sequence` (when provided)
///   that are not yet tracked.
/// * Tokens: `active_tokens(decay_now)`, plus `new_tokens(isl, overlap)` only when
///   `track_prefill_tokens` is set.
#[cfg(test)]
fn potential_blocks_and_tokens_with_prefill_tracking(
    &self,
    token_sequence: Option<&[SequenceHash]>,
    isl: usize,
    overlap: u32,
    track_prefill_tokens: bool,
    decay_now: Instant,
) -> (usize, usize) {
    let potential_blocks = match token_sequence {
        Some(hashes) => self.new_blocks(hashes) + self.active_blocks(),
        None => self.active_blocks(),
    };

    let active_tokens = self.active_tokens(decay_now);
    if !track_prefill_tokens {
        return (potential_blocks, active_tokens);
    }

    (potential_blocks, self.new_tokens(isl, overlap) + active_tokens)
}
/// Count the blocks in `token_sequence` that are not yet keyed in `unique_blocks`,
/// i.e. how many new blocks the request would add.
pub(super) fn new_blocks(&self, token_sequence: &[SequenceHash]) -> usize {
    let mut missing = 0;
    for hash in token_sequence {
        if !self.blocks.unique_blocks.contains_key(hash) {
            missing += 1;
        }
    }
    missing
}
/// Total block count that would be in use if `token_sequence` were added: its
/// not-yet-tracked blocks plus the currently active blocks.
pub(super) fn potential_blocks(&self, token_sequence: &[SequenceHash]) -> usize {
    let added = self.new_blocks(token_sequence);
    let current = self.active_blocks();
    added + current
}
}
/// Force expiry of stale requests if the timer has elapsed.
/// Force expiry of stale requests if the timer has elapsed.
/// Returns the set of expired request IDs that were removed.
/// Returns
block membership transitions plus
the set of expired request IDs that were removed.
pub
fn
force_expiry
(
&
mut
self
)
->
HashSet
<
RequestId
>
{
pub
(
super
)
fn
force_expiry
(
&
mut
self
)
->
SequenceMutationOutcome
{
let
now
=
Instant
::
now
();
let
now
=
Instant
::
now
();
if
now
<
self
.last_expiry_check_time
+
CHECK_EXPIRY_FREQUENCY
{
if
now
<
self
.last_expiry_check_time
+
CHECK_EXPIRY_FREQUENCY
{
return
HashSet
::
new
();
return
SequenceMutationOutcome
::
default
();
}
}
self
.last_expiry_check_time
=
now
;
self
.last_expiry_check_time
=
now
;
let
expired_requests_time
=
now
-
EXPIRY_DURATION
;
let
expired_requests_time
=
now
-
EXPIRY_DURATION
;
let
expired_requests
:
HashSet
<
RequestId
>
=
self
let
expired_request
_id
s
:
HashSet
<
RequestId
>
=
self
.requests
.requests
.iter
()
.iter
()
.filter
(|(
_
,
state
)|
state
.started_at
<
expired_requests_time
)
.filter
(|(
_
,
state
)|
state
.started_at
<
expired_requests_time
)
.map
(|(
request_id
,
_
)|
request_id
.clone
())
.map
(|(
request_id
,
_
)|
request_id
.clone
())
.collect
();
.collect
();
for
request_id
in
&
expired_requests
{
let
mut
outcome
=
SequenceMutationOutcome
{
expired_request_ids
,
..
Default
::
default
()
};
for
request_id
in
&
outcome
.expired_request_ids
{
tracing
::
warn!
(
"Expiring stale request: {}"
,
request_id
);
tracing
::
warn!
(
"Expiring stale request: {}"
,
request_id
);
self
.free
(
request_id
,
now
);
outcome
.membership_delta
.extend
(
self
.free
(
request_id
,
now
)
)
;
}
}
self
.validate_state
();
self
.validate_state
();
expired_requests
outcome
}
/// Record `fraction` in `fractional_blocks` for every block of the request whose `Arc`
/// strong count is exactly one (i.e. no other holder shares the block).
///
/// Logs a warning and leaves state untouched when `request_id` is unknown.
fn set_single_ref_blocks_as_fractional(&mut self, request_id: &RequestId, fraction: f64) {
    let request_state = match self.requests.get(request_id) {
        Some(state) => state,
        None => {
            tracing::warn!(
                "Request {request_id} not found for set_single_ref_blocks_as_fractional"
            );
            return;
        }
    };

    for (hash, rc) in request_state.all_blocks() {
        // Blocks with additional strong references are left untouched.
        if Arc::strong_count(rc) != 1 {
            continue;
        }
        self.blocks.fractional_blocks.insert(*hash, fraction);
    }
}
/// Capture this tracker's current load: the active block count together with a snapshot
/// of the prefill tracker.
pub(super) fn worker_load_snapshot(&self) -> WorkerLoadSnapshot {
    let active_blocks = self.active_blocks();
    let prefill = self.prefill.snapshot();
    WorkerLoadSnapshot {
        active_blocks,
        prefill,
    }
}
/// Test helper: the set of every hash currently keyed in `unique_blocks`.
#[cfg(test)]
pub(super) fn active_block_hashes(&self) -> FxHashSet<SequenceHash> {
    let mut hashes = FxHashSet::default();
    hashes.extend(self.blocks.unique_blocks.keys().copied());
    hashes
}
/// Test helper: the union of the prompt-block hashes held by all tracked requests.
#[cfg(test)]
pub(super) fn active_prompt_hashes(&self) -> FxHashSet<SequenceHash> {
    let mut hashes = FxHashSet::default();
    for state in self.requests.values() {
        for (hash, _) in &state.prompt_blocks {
            hashes.insert(*hash);
        }
    }
    hashes
}
}
}
}
...
@@ -464,6 +479,119 @@ mod tests {
...
@@ -464,6 +479,119 @@ mod tests {
}
}
}
}
/// A block is reported as stored only when the first request brings it in, and as removed
/// only when the last request holding it is freed.
#[test]
fn test_prompt_membership_delta_only_reports_first_add_and_last_remove() {
    let mut tracker = ActiveSequences::new(4);
    let now = Instant::now();
    let r1 = String::from("r1");
    let r2 = String::from("r2");

    // r1 introduces blocks 1 and 2 from the root.
    let first = tracker.add_request_with_prefill_tracking(
        r1.clone(),
        Some(vec![1, 2]),
        8,
        0,
        None,
        true,
        None,
        now,
    );
    let expected_first = PromptMembershipDelta {
        stores: vec![PromptMembershipStore {
            parent: None,
            hashes: vec![1, 2],
        }],
        removes: Vec::new(),
    };
    assert_eq!(first.membership_delta, expected_first);
    assert!(first.expired_request_ids.is_empty());

    // r2 shares blocks 1 and 2, so only block 3 (child of 2) is new.
    let second = tracker.add_request_with_prefill_tracking(
        r2.clone(),
        Some(vec![1, 2, 3]),
        12,
        0,
        None,
        true,
        None,
        now,
    );
    let expected_second = PromptMembershipDelta {
        stores: vec![PromptMembershipStore {
            parent: Some(2),
            hashes: vec![3],
        }],
        removes: Vec::new(),
    };
    assert_eq!(second.membership_delta, expected_second);

    // Freeing r1 removes nothing because r2 still holds its blocks.
    let first_free = tracker.free(&r1, now);
    assert!(first_free.removes.is_empty());
    assert!(first_free.stores.is_empty());

    // Freeing r2 drops the last references to all three blocks.
    let second_free = tracker.free(&r2, now);
    assert!(second_free.stores.is_empty());
    assert_eq!(
        second_free.removes,
        vec![PromptMembershipRemove {
            hashes: vec![1, 2, 3],
        }]
    );
}
/// Output blocks show up in the generic active-block set, while the prompt membership
/// delta reported on free only lists the prompt blocks.
#[test]
fn test_generic_block_membership_includes_output_blocks() {
    let mut tracker = ActiveSequences::new(4);
    let now = Instant::now();
    let r1 = String::from("r1");

    let outcome = tracker.add_request_with_prefill_tracking(
        r1.clone(),
        Some(vec![1, 2, 3]),
        12,
        0,
        None,
        true,
        None,
        now,
    );
    assert_eq!(
        outcome.membership_delta.stores,
        vec![PromptMembershipStore {
            parent: None,
            hashes: vec![1, 2, 3],
        }]
    );
    assert_eq!(
        tracker.active_block_hashes(),
        [1, 2, 3].into_iter().collect()
    );

    // The freshly generated output block joins the active set.
    let output_hash = tracker
        .add_output_block(&r1, Some(0.5))
        .expect("request exists");
    assert_eq!(
        tracker.active_block_hashes(),
        [1, 2, 3, output_hash].into_iter().collect()
    );

    // Completing prefill clears token load but leaves block membership unchanged.
    tracker.mark_prefill_completed(&r1, now);
    assert_eq!(tracker.active_tokens(now), 0);
    assert_eq!(
        tracker.active_block_hashes(),
        [1, 2, 3, output_hash].into_iter().collect()
    );

    let free_delta = tracker.free(&r1, now);
    assert_eq!(
        free_delta.removes,
        vec![PromptMembershipRemove {
            hashes: vec![1, 2, 3],
        }]
    );
}
#[test]
#[test]
fn
test_active_sequences_shared_blocks
()
{
fn
test_active_sequences_shared_blocks
()
{
let
block_size
=
4
;
let
block_size
=
4
;
...
@@ -532,13 +660,21 @@ mod tests {
...
@@ -532,13 +660,21 @@ mod tests {
);
);
assert_eq!
(
seq_manager
.active_blocks
(),
3
);
assert_eq!
(
seq_manager
.active_blocks
(),
3
);
assert
!
(
seq_manager
.add_output_block
(
&
"r1"
.to_string
(),
Some
(
0.5
)));
assert
!
(
seq_manager
.add_output_block
(
&
"r1"
.to_string
(),
Some
(
0.5
))
.is_some
()
);
assert_eq!
(
seq_manager
.active_blocks
(),
2
);
assert_eq!
(
seq_manager
.active_blocks
(),
2
);
seq_manager
.add_request
(
"r2"
.to_string
(),
Some
(
vec!
[
1
,
2
]),
8
,
0
,
None
,
decay_now
);
seq_manager
.add_request
(
"r2"
.to_string
(),
Some
(
vec!
[
1
,
2
]),
8
,
0
,
None
,
decay_now
);
assert_eq!
(
seq_manager
.active_blocks
(),
2
);
assert_eq!
(
seq_manager
.active_blocks
(),
2
);
assert
!
(
seq_manager
.add_output_block
(
&
"r1"
.to_string
(),
Some
(
0.0
)));
assert
!
(
seq_manager
.add_output_block
(
&
"r1"
.to_string
(),
Some
(
0.0
))
.is_some
()
);
assert_eq!
(
seq_manager
.active_blocks
(),
1
);
assert_eq!
(
seq_manager
.active_blocks
(),
1
);
seq_manager
.free
(
&
"r2"
.to_string
(),
decay_now
);
seq_manager
.free
(
&
"r2"
.to_string
(),
decay_now
);
...
@@ -628,181 +764,6 @@ mod tests {
...
@@ -628,181 +764,6 @@ mod tests {
assert_eq!
(
tokens
,
0
);
assert_eq!
(
tokens
,
0
);
}
}
/// With two tracked prefills, the expected totals are 140 tokens at t=2s and 110 at t=5s.
#[test]
fn test_prefill_decay_only_applies_to_oldest_request() {
    let mut tracker = ActiveSequences::new(4);
    let epoch = Instant::now();
    let at = |secs: u64| epoch + Duration::from_secs(secs);

    tracker.add_request_with_prefill_tracking(
        String::from("r1"),
        Some(vec![1]),
        100,
        0,
        None,
        true,
        Some(prefill_hint(100, 10)),
        epoch,
    );
    tracker.add_request_with_prefill_tracking(
        String::from("r2"),
        Some(vec![2]),
        60,
        0,
        None,
        true,
        Some(prefill_hint(60, 6)),
        at(2),
    );

    assert_eq!(tracker.active_tokens(at(2)), 140);

    let decayed = tracker.active_tokens(at(5));
    assert_eq!(decayed, 110);
    assert!(decayed <= 160);
    assert!(decayed >= 60);
}
/// After the first prefill completes, the next queued request becomes the anchored
/// prefill and the token total keeps dropping.
#[test]
fn test_prefill_decay_hands_off_to_next_oldest_request() {
    let mut tracker = ActiveSequences::new(4);
    let epoch = Instant::now();
    let at = |secs: u64| epoch + Duration::from_secs(secs);
    let r1 = String::from("r1");
    let r2 = String::from("r2");

    tracker.add_request_with_prefill_tracking(
        r1.clone(),
        Some(vec![1]),
        100,
        0,
        None,
        true,
        Some(prefill_hint(100, 10)),
        epoch,
    );
    tracker.add_request_with_prefill_tracking(
        r2.clone(),
        Some(vec![2]),
        40,
        0,
        None,
        true,
        Some(prefill_hint(40, 8)),
        epoch,
    );

    assert_eq!(tracker.active_tokens(at(3)), 110);

    tracker.mark_prefill_completed(&r1, at(3));
    assert_eq!(tracker.active_tokens(at(3)), 40);
    assert_eq!(
        tracker.prefill.prefill_order,
        VecDeque::from(vec![r2.clone()])
    );
    assert!(
        tracker
            .prefill
            .anchored_prefill
            .as_ref()
            .is_some_and(|(request_id, _)| request_id == "r2")
    );

    assert_eq!(tracker.active_tokens(at(5)), 30);
}
/// Once the older prefill completes at t=8s, the remaining request reports its full
/// 80 tokens at that instant and 60 tokens two seconds later.
#[test]
fn test_prefill_decay_resets_when_request_becomes_oldest() {
    let mut tracker = ActiveSequences::new(4);
    let epoch = Instant::now();
    let at = |secs: u64| epoch + Duration::from_secs(secs);
    let r1 = String::from("r1");

    tracker.add_request_with_prefill_tracking(
        r1.clone(),
        Some(vec![1]),
        100,
        0,
        None,
        true,
        Some(prefill_hint(100, 10)),
        epoch,
    );
    tracker.add_request_with_prefill_tracking(
        String::from("r2"),
        Some(vec![2]),
        80,
        0,
        None,
        true,
        Some(prefill_hint(80, 8)),
        at(4),
    );

    assert_eq!(tracker.active_tokens(at(8)), 100);

    tracker.mark_prefill_completed(&r1, at(8));
    assert_eq!(tracker.active_tokens(at(8)), 80);
    assert_eq!(tracker.active_tokens(at(10)), 60);
}
/// Removing the front prefill anchors the next request and leaves its 20 tokens counted
/// at the removal instant.
#[test]
fn test_prefill_front_removal_reanchors_queue_front() {
    let mut tracker = ActiveSequences::new(4);
    let epoch = Instant::now();
    let removal_time = epoch + Duration::from_secs(2);
    let r1 = String::from("r1");

    tracker.add_request_with_prefill_tracking(
        r1.clone(),
        Some(vec![1]),
        30,
        0,
        None,
        true,
        Some(prefill_hint(30, 6)),
        epoch,
    );
    tracker.add_request_with_prefill_tracking(
        String::from("r2"),
        Some(vec![2]),
        20,
        0,
        None,
        true,
        Some(prefill_hint(20, 4)),
        epoch,
    );

    tracker.mark_prefill_completed(&r1, removal_time);

    assert!(
        tracker
            .prefill
            .anchored_prefill
            .as_ref()
            .is_some_and(|(request_id, _)| request_id == "r2")
    );
    assert_eq!(tracker.active_tokens(removal_time), 20);
}
#[test]
#[test]
fn
test_prefill_queue_and_sum_invariants_survive_idempotent_cleanup
()
{
fn
test_prefill_queue_and_sum_invariants_survive_idempotent_cleanup
()
{
let
mut
seq_manager
=
ActiveSequences
::
new
(
4
);
let
mut
seq_manager
=
ActiveSequences
::
new
(
4
);
...
@@ -882,18 +843,27 @@ mod tests {
...
@@ -882,18 +843,27 @@ mod tests {
tokio
::
time
::
advance
(
Duration
::
from_secs
(
20
))
.await
;
tokio
::
time
::
advance
(
Duration
::
from_secs
(
20
))
.await
;
let
expired
=
seq_manager
.force_expiry
();
let
expired
=
seq_manager
.force_expiry
();
assert
!
(
expired
.is_empty
(),
"no check before CHECK_EXPIRY_FREQUENCY"
);
assert
!
(
expired
.expired_request_ids
.is_empty
(),
"no check before CHECK_EXPIRY_FREQUENCY"
);
assert_eq!
(
seq_manager
.active_blocks
(),
4
);
assert_eq!
(
seq_manager
.active_blocks
(),
4
);
tokio
::
time
::
advance
(
Duration
::
from_secs
(
11
))
.await
;
tokio
::
time
::
advance
(
Duration
::
from_secs
(
11
))
.await
;
let
expired
=
seq_manager
.force_expiry
();
let
expired
=
seq_manager
.force_expiry
();
assert
!
(
expired
.is_empty
(),
"requests not old enough to expire"
);
assert
!
(
expired
.expired_request_ids
.is_empty
(),
"requests not old enough to expire"
);
assert_eq!
(
seq_manager
.active_blocks
(),
4
);
assert_eq!
(
seq_manager
.active_blocks
(),
4
);
seq_manager
.assert_consistent
();
seq_manager
.assert_consistent
();
tokio
::
time
::
advance
(
Duration
::
from_secs
(
270
))
.await
;
tokio
::
time
::
advance
(
Duration
::
from_secs
(
270
))
.await
;
let
expired
=
seq_manager
.force_expiry
();
let
expired
=
seq_manager
.force_expiry
();
assert_eq!
(
expired
,
HashSet
::
from
([
"r1"
.to_string
(),
"r2"
.to_string
()]));
assert_eq!
(
expired
.expired_request_ids
,
HashSet
::
from
([
"r1"
.to_string
(),
"r2"
.to_string
()])
);
assert_eq!
(
seq_manager
.active_blocks
(),
0
);
assert_eq!
(
seq_manager
.active_blocks
(),
0
);
assert_eq!
(
seq_manager
.active_tokens
(
Instant
::
now
()),
0
);
assert_eq!
(
seq_manager
.active_tokens
(
Instant
::
now
()),
0
);
seq_manager
.assert_consistent
();
seq_manager
.assert_consistent
();
...
@@ -901,7 +871,7 @@ mod tests {
...
@@ -901,7 +871,7 @@ mod tests {
tokio
::
time
::
advance
(
Duration
::
from_secs
(
31
))
.await
;
tokio
::
time
::
advance
(
Duration
::
from_secs
(
31
))
.await
;
let
expired
=
let
expired
=
seq_manager
.add_request
(
"r3"
.to_string
(),
Some
(
vec!
[
5
]),
4
,
0
,
None
,
Instant
::
now
());
seq_manager
.add_request
(
"r3"
.to_string
(),
Some
(
vec!
[
5
]),
4
,
0
,
None
,
Instant
::
now
());
assert
!
(
expired
.is_empty
());
assert
!
(
expired
.
expired_request_ids
.
is_empty
());
assert_eq!
(
seq_manager
.active_blocks
(),
1
);
assert_eq!
(
seq_manager
.active_blocks
(),
1
);
assert_eq!
(
seq_manager
.active_tokens
(
Instant
::
now
()),
4
);
assert_eq!
(
seq_manager
.active_tokens
(
Instant
::
now
()),
4
);
seq_manager
.assert_consistent
();
seq_manager
.assert_consistent
();
...
@@ -936,7 +906,10 @@ mod tests {
...
@@ -936,7 +906,10 @@ mod tests {
tokio
::
time
::
advance
(
Duration
::
from_secs
(
60
))
.await
;
tokio
::
time
::
advance
(
Duration
::
from_secs
(
60
))
.await
;
let
expired
=
seq_manager
.force_expiry
();
let
expired
=
seq_manager
.force_expiry
();
assert_eq!
(
expired
,
HashSet
::
from
([
"r1"
.to_string
()]));
assert_eq!
(
expired
.expired_request_ids
,
HashSet
::
from
([
"r1"
.to_string
()])
);
assert_eq!
(
seq_manager
.active_tokens
(
Instant
::
now
()),
30
);
assert_eq!
(
seq_manager
.active_tokens
(
Instant
::
now
()),
30
);
assert
!
(
assert
!
(
seq_manager
seq_manager
...
...
lib/kv-router/src/sequences/topology.rs
0 → 100644
View file @
134d484d
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use
std
::
sync
::
Arc
;
use
parking_lot
::
RwLock
;
use
rustc_hash
::{
FxHashMap
,
FxHashSet
};
use
std
::
collections
::
HashMap
;
use
super
::
prompt_membership_trie
::
WorkerLookup
;
use
super
::
single
::
ActiveSequences
;
use
crate
::
protocols
::
WorkerWithDpRank
;
/// What is handed back for a worker that was dropped from the table: its identity plus
/// the shared handle to its lookup.
#[derive(Clone)]
pub(super) struct RemovedWorkerState {
    /// Identity (worker id + dp rank) of the removed worker.
    pub(super) worker: WorkerWithDpRank,
    /// Shared lookup handle taken from the removed worker's slot.
    pub(super) trie_lookup: Arc<RwLock<WorkerLookup>>,
}
/// Manual `Debug` that prints only the worker identity; the lookup handle is elided and
/// the output is marked non-exhaustive.
impl std::fmt::Debug for RemovedWorkerState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("RemovedWorkerState");
        out.field("worker", &self.worker);
        out.finish_non_exhaustive()
    }
}
/// Delta produced by a worker-table mutation.
#[derive(Debug, Default, Clone)]
pub(super) struct WorkerTopologyChange {
    /// Workers that received a fresh slot during the mutation.
    pub(super) added: Vec<WorkerWithDpRank>,
    /// Workers whose slots were dropped, together with their lookup handles.
    pub(super) removed: Vec<RemovedWorkerState>,
}
/// Per-worker state stored in a [`WorkerTable`].
pub(super) struct WorkerSlot {
    /// Identity (worker id + dp rank) this slot belongs to.
    pub(super) worker: WorkerWithDpRank,
    /// Sequence tracker for this worker, behind a read/write lock.
    pub(super) sequences: RwLock<ActiveSequences>,
    /// Shared lookup handle for this worker; cloned out when the worker is removed.
    pub(super) trie_lookup: Arc<RwLock<WorkerLookup>>,
}
impl WorkerSlot {
    /// Build an empty slot for `worker`: a fresh `ActiveSequences` with the given block
    /// size and a default `WorkerLookup`.
    fn new(worker: WorkerWithDpRank, block_size: usize) -> Self {
        let sequences = RwLock::new(ActiveSequences::new(block_size));
        let trie_lookup = Arc::new(RwLock::new(WorkerLookup::default()));
        Self {
            worker,
            sequences,
            trie_lookup,
        }
    }
}
/// Dense table of worker slots with a hash index from worker identity to slot position.
pub(super) struct WorkerTable {
    /// Slots in insertion order; positions are what `index` points at.
    pub(super) slots: Vec<WorkerSlot>,
    /// Maps each worker to the position of its slot in `slots`.
    pub(super) index: FxHashMap<WorkerWithDpRank, usize>,
}
impl WorkerTable {
    /// Build a table with one fresh slot per worker expanded from `dp_range`
    /// (`worker_id -> (dp_start, dp_size)`).
    pub(super) fn new(block_size: usize, dp_range: &HashMap<u64, (u32, u32)>) -> Self {
        let mut table = Self {
            slots: Vec::new(),
            index: FxHashMap::default(),
        };
        for worker in workers_from_dp_range(dp_range) {
            table.install_slot(WorkerSlot::new(worker, block_size));
        }
        table
    }

    /// Iterate over the workers currently in the table, in slot order.
    pub(super) fn workers(&self) -> impl Iterator<Item = WorkerWithDpRank> + '_ {
        self.slots.iter().map(|entry| entry.worker)
    }

    /// Add fresh slots for every worker in `dp_range` that is not yet known.
    ///
    /// Existing workers are left untouched and nothing is ever removed; the returned
    /// change lists only the newly added workers.
    pub(super) fn register_external(
        &mut self,
        block_size: usize,
        dp_range: &HashMap<u64, (u32, u32)>,
    ) -> WorkerTopologyChange {
        let mut change = WorkerTopologyChange::default();
        for worker in workers_from_dp_range(dp_range) {
            if !self.index.contains_key(&worker) {
                self.install_slot(WorkerSlot::new(worker, block_size));
                change.added.push(worker);
            }
        }
        change
    }

    /// Rebuild the table so it holds exactly the workers described by `new_dp_range`.
    ///
    /// Slots of workers that survive are reused as-is (their state is preserved), unknown
    /// workers get fresh slots, and slots of workers absent from the target set are
    /// returned in `removed` along with their lookup handles.
    pub(super) fn reconcile(
        &mut self,
        block_size: usize,
        new_dp_range: &HashMap<u64, (u32, u32)>,
    ) -> WorkerTopologyChange {
        let target_workers: FxHashSet<WorkerWithDpRank> =
            workers_from_dp_range(new_dp_range).into_iter().collect();

        // Move every existing slot out so the table can be rebuilt from scratch.
        let mut previous: FxHashMap<WorkerWithDpRank, WorkerSlot> = self
            .slots
            .drain(..)
            .map(|entry| (entry.worker, entry))
            .collect();
        self.index.clear();

        let mut added = Vec::new();
        for worker in target_workers {
            let slot = match previous.remove(&worker) {
                Some(existing) => existing,
                None => {
                    added.push(worker);
                    WorkerSlot::new(worker, block_size)
                }
            };
            self.install_slot(slot);
        }

        // Whatever was not claimed above no longer belongs to the topology.
        let removed = previous
            .into_values()
            .map(|entry| RemovedWorkerState {
                worker: entry.worker,
                trie_lookup: entry.trie_lookup,
            })
            .collect();

        WorkerTopologyChange { added, removed }
    }

    /// Make sure `worker` has a slot, creating a fresh one if needed.
    ///
    /// Returns an empty change when the worker already exists, otherwise a change whose
    /// `added` contains just this worker.
    pub(super) fn ensure_worker(
        &mut self,
        block_size: usize,
        worker: WorkerWithDpRank,
    ) -> WorkerTopologyChange {
        if self.index.contains_key(&worker) {
            WorkerTopologyChange::default()
        } else {
            self.install_slot(WorkerSlot::new(worker, block_size));
            WorkerTopologyChange {
                added: vec![worker],
                removed: Vec::new(),
            }
        }
    }

    /// Append `slot` to `slots` and point `index` at its position.
    fn install_slot(&mut self, slot: WorkerSlot) {
        let worker = slot.worker;
        let idx = self.slots.len();
        self.slots.push(slot);
        self.index.insert(worker, idx);
    }
}
/// Expand a `worker_id -> (dp_start, dp_size)` map into one [`WorkerWithDpRank`] per rank
/// in `dp_start..dp_start + dp_size`.
///
/// Workers are emitted in the map's iteration order, ranks ascending within each worker.
/// The end of the rank range saturates at `u32::MAX` instead of overflowing, so an
/// out-of-range `(dp_start, dp_size)` pair neither panics (debug builds) nor wraps around
/// to an empty range (release builds).
fn workers_from_dp_range(dp_range: &HashMap<u64, (u32, u32)>) -> Vec<WorkerWithDpRank> {
    dp_range
        .iter()
        .flat_map(|(&worker_id, &(dp_start, dp_size))| {
            // `dp_start + dp_size` can exceed `u32::MAX`; clamp rather than overflow.
            let dp_end = dp_start.saturating_add(dp_size);
            (dp_start..dp_end).map(move |dp_rank| WorkerWithDpRank::new(worker_id, dp_rank))
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use tokio::time::Instant;

    use super::*;

    /// Shorthand for building a worker identity in assertions.
    fn worker(worker_id: u64, dp_rank: u32) -> WorkerWithDpRank {
        WorkerWithDpRank::new(worker_id, dp_rank)
    }

    /// `new` creates one slot and one index entry per expanded dp rank.
    #[test]
    fn new_expands_dp_ranges_into_slots_and_index() {
        let dp_range = HashMap::from([(7, (2, 3)), (9, (0, 1))]);
        let table = WorkerTable::new(4, &dp_range);

        let actual: FxHashSet<_> = table.workers().collect();
        assert_eq!(
            actual,
            FxHashSet::from_iter([worker(7, 2), worker(7, 3), worker(7, 4), worker(9, 0)])
        );
        assert_eq!(table.index.len(), 4);
        assert_eq!(table.slots.len(), 4);
        assert!(actual.iter().all(|w| table.index.contains_key(w)));
    }

    /// `register_external` adds only workers that are not already present.
    #[test]
    fn register_external_only_adds_missing_workers() {
        let initial = HashMap::from([(1, (0, 1))]);
        let mut table = WorkerTable::new(4, &initial);

        let extended = HashMap::from([(1, (0, 2)), (2, (0, 1))]);
        let change = table.register_external(4, &extended);

        let added = change.added.into_iter().collect::<FxHashSet<_>>();
        assert_eq!(added, FxHashSet::from_iter([worker(1, 1), worker(2, 0)]));
        assert!(change.removed.is_empty());
        assert_eq!(table.index.len(), 3);
    }

    /// A second `ensure_worker` call for the same worker reports no change.
    #[test]
    fn ensure_worker_is_idempotent() {
        let initial = HashMap::from([(1, (0, 1))]);
        let mut table = WorkerTable::new(4, &initial);
        let target = worker(2, 0);

        let first = table.ensure_worker(4, target);
        let second = table.ensure_worker(4, target);

        assert_eq!(first.added, vec![target]);
        assert!(first.removed.is_empty());
        assert!(second.added.is_empty());
        assert!(second.removed.is_empty());
        assert_eq!(table.index.len(), 2);
    }

    /// `reconcile` keeps the state of surviving workers and reports added/removed ones.
    #[test]
    fn reconcile_preserves_existing_worker_state_and_reports_delta() {
        let initial = HashMap::from([(1, (0, 1)), (2, (0, 1))]);
        let mut table = WorkerTable::new(4, &initial);
        let existing = worker(1, 0);
        let removed = worker(2, 0);
        let added = worker(3, 0);

        // Give the surviving worker some state so preservation is observable.
        {
            let slot_idx = table.index[&existing];
            let mut sequences = table.slots[slot_idx].sequences.write();
            let outcome = sequences.add_request(
                "req-1".to_string(),
                Some(vec![1, 2, 3]),
                12,
                0,
                None,
                Instant::now(),
            );
            assert_eq!(outcome.membership_delta.stores[0].hashes, vec![1, 2, 3]);
        }

        let target = HashMap::from([(1, (0, 1)), (3, (0, 1))]);
        let change = table.reconcile(4, &target);

        assert_eq!(change.added, vec![added]);
        let removed_workers = change
            .removed
            .iter()
            .map(|state| state.worker)
            .collect::<Vec<_>>();
        assert_eq!(removed_workers, vec![removed]);

        assert!(table.index.contains_key(&existing));
        assert!(table.index.contains_key(&added));
        assert!(!table.index.contains_key(&removed));

        let active_blocks_of = |w: &WorkerWithDpRank| {
            let slot_idx = table.index[w];
            table.slots[slot_idx].sequences.read().active_blocks()
        };
        assert_eq!(active_blocks_of(&existing), 3);
        assert_eq!(active_blocks_of(&added), 0);
    }
}
lib/llm/src/kv_router/sequence.rs
View file @
134d484d
...
@@ -145,58 +145,6 @@ mod tests {
...
@@ -145,58 +145,6 @@ mod tests {
use
dynamo_runtime
::{
DistributedRuntime
,
Runtime
};
use
dynamo_runtime
::{
DistributedRuntime
,
Runtime
};
use
tokio
::
time
::
Instant
;
use
tokio
::
time
::
Instant
;
#[test]
/// Walks three requests with overlapping block lists through `add_request`
/// and `free`, checking `active_blocks` and `active_tokens` after every step.
fn test_active_sequences_shared_blocks() {
    let block_size = 4;
    let mut tracker = ActiveSequences::new(block_size);
    let now = Instant::now();

    // request_1 alone: blocks [1, 2, 3].
    tracker.add_request(
        String::from("request_1"),
        Some(vec![1, 2, 3]),
        12,
        0,
        None,
        now,
    );
    assert_eq!(tracker.active_blocks(), 3);
    assert_eq!(tracker.active_tokens(now), 12);

    // request_2 brings a block not seen before: [4].
    tracker.add_request(String::from("request_2"), Some(vec![4]), 4, 0, None, now);
    assert_eq!(tracker.active_blocks(), 4);
    assert_eq!(tracker.active_tokens(now), 16);

    // request_3 lists only blocks already present, so neither count moves.
    tracker.add_request(
        String::from("request_3"),
        Some(vec![1, 2, 3, 4]),
        16,
        4,
        None,
        now,
    );
    assert_eq!(tracker.active_blocks(), 4);
    assert_eq!(tracker.active_tokens(now), 16);

    // Freeing request_2 leaves the block count at 4 and drops tokens to 12.
    tracker.free(&String::from("request_2"), now);
    assert_eq!(tracker.active_blocks(), 4);
    assert_eq!(tracker.active_tokens(now), 12);

    // Freeing request_3 drops the block count to 3; tokens stay at 12.
    tracker.free(&String::from("request_3"), now);
    assert_eq!(tracker.active_blocks(), 3);
    assert_eq!(tracker.active_tokens(now), 12);

    // Freeing the last request empties the tracker.
    tracker.free(&String::from("request_1"), now);
    assert_eq!(tracker.active_blocks(), 0);
    assert_eq!(tracker.active_tokens(now), 0);
}
#[tokio::test]
#[tokio::test]
#[ignore]
#[ignore]
async
fn
test_multi_worker_cross_instance_sync
()
->
Result
<
()
>
{
async
fn
test_multi_worker_cross_instance_sync
()
->
Result
<
()
>
{
...
...
lib/mocker/src/replay/offline/components/router.rs
View file @
134d484d
...
@@ -19,6 +19,7 @@ use dynamo_kv_router::{
...
@@ -19,6 +19,7 @@ use dynamo_kv_router::{
SchedulingPolicy
,
SchedulingRequest
,
SequenceRequest
,
WorkerSelector
,
SchedulingPolicy
,
SchedulingRequest
,
SequenceRequest
,
WorkerSelector
,
};
};
use
dynamo_tokens
::
SequenceHash
;
use
dynamo_tokens
::
SequenceHash
;
use
rustc_hash
::
FxHashMap
;
use
tokio
::
time
::
Instant
;
use
tokio
::
time
::
Instant
;
use
uuid
::
Uuid
;
use
uuid
::
Uuid
;
...
@@ -124,8 +125,8 @@ impl PendingRequest {
...
@@ -124,8 +125,8 @@ impl PendingRequest {
fn
scheduling_request
(
fn
scheduling_request
(
&
self
,
&
self
,
decode_blocks
:
HashMap
<
WorkerWithDpRank
,
usize
>
,
decode_blocks
:
Fx
HashMap
<
WorkerWithDpRank
,
usize
>
,
prefill_tokens
:
HashMap
<
WorkerWithDpRank
,
usize
>
,
prefill_tokens
:
Fx
HashMap
<
WorkerWithDpRank
,
usize
>
,
)
->
SchedulingRequest
{
)
->
SchedulingRequest
{
SchedulingRequest
{
SchedulingRequest
{
maybe_request_id
:
Some
(
self
.request_id
()),
maybe_request_id
:
Some
(
self
.request_id
()),
...
@@ -408,7 +409,7 @@ impl OfflineReplayRouter {
...
@@ -408,7 +409,7 @@ impl OfflineReplayRouter {
let
arrival_offset
=
Duration
::
from_secs_f64
((
now_ms
.max
(
0.0
))
/
1000.0
);
let
arrival_offset
=
Duration
::
from_secs_f64
((
now_ms
.max
(
0.0
))
/
1000.0
);
self
.policy
.enqueue_key
(
self
.policy
.enqueue_key
(
arrival_offset
,
arrival_offset
,
&
request
.scheduling_request
(
HashMap
::
new
(),
HashMap
::
new
()),
&
request
.scheduling_request
(
Fx
HashMap
::
default
(),
Fx
HashMap
::
default
()),
)
)
}
}
...
...
lib/mocker/src/scheduler/sglang/core.rs
View file @
134d484d
...
@@ -272,12 +272,15 @@ fn simulate_prefill_duration(
...
@@ -272,12 +272,15 @@ fn simulate_prefill_duration(
}
}
fn
debug_assert_sglang_scheduler_state
(
fn
debug_assert_sglang_scheduler_state
(
waiting
:
&
VecDeque
<
SglangRequest
>
,
_
waiting
:
&
VecDeque
<
SglangRequest
>
,
running
:
&
[
SglangRequest
],
_
running
:
&
[
SglangRequest
],
block_size
:
usize
,
_
block_size
:
usize
,
)
{
)
{
#[cfg(debug_assertions)]
#[cfg(debug_assertions)]
{
{
let
waiting
=
_
waiting
;
let
running
=
_
running
;
let
block_size
=
_
block_size
;
let
mut
seen
=
std
::
collections
::
HashSet
::
new
();
let
mut
seen
=
std
::
collections
::
HashSet
::
new
();
for
req
in
waiting
{
for
req
in
waiting
{
debug_assert!
(
debug_assert!
(
...
...
lib/mocker/src/scheduler/sglang/request.rs
View file @
134d484d
...
@@ -87,9 +87,10 @@ impl SglangRequest {
...
@@ -87,9 +87,10 @@ impl SglangRequest {
self
.materialized_tokens
+=
1
;
self
.materialized_tokens
+=
1
;
}
}
pub
(
super
)
fn
debug_assert_invariants
(
&
self
,
block_size
:
usize
)
{
pub
(
super
)
fn
debug_assert_invariants
(
&
self
,
_
block_size
:
usize
)
{
#[cfg(debug_assertions)]
#[cfg(debug_assertions)]
{
{
let
block_size
=
_
block_size
;
let
sequence_len
=
self
.current_sequence_len
();
let
sequence_len
=
self
.current_sequence_len
();
debug_assert!
(
debug_assert!
(
self
.cached_tokens
<=
self
.materialized_tokens
,
self
.cached_tokens
<=
self
.materialized_tokens
,
...
...
lib/mocker/src/scheduler/vllm/core.rs
View file @
134d484d
...
@@ -759,9 +759,11 @@ fn request_sequence_len(requests: &FxHashMap<Uuid, VllmRequestState>, uuid: Uuid
...
@@ -759,9 +759,11 @@ fn request_sequence_len(requests: &FxHashMap<Uuid, VllmRequestState>, uuid: Uuid
.unwrap_or_default
()
.unwrap_or_default
()
}
}
fn
debug_assert_vllm_request_invariants
(
uuid
:
Uuid
,
request
:
&
VllmRequestState
)
{
fn
debug_assert_vllm_request_invariants
(
_
u
uid
:
Uuid
,
_
request
:
&
VllmRequestState
)
{
#[cfg(debug_assertions)]
#[cfg(debug_assertions)]
{
{
let
uuid
=
_u
uid
;
let
request
=
_
request
;
let
seq_len
=
request
.sequence
.len
();
let
seq_len
=
request
.sequence
.len
();
let
allocated
=
request
.sequence
.num_allocated_tokens
();
let
allocated
=
request
.sequence
.num_allocated_tokens
();
debug_assert!
(
debug_assert!
(
...
@@ -776,9 +778,11 @@ fn debug_assert_vllm_request_invariants(uuid: Uuid, request: &VllmRequestState)
...
@@ -776,9 +778,11 @@ fn debug_assert_vllm_request_invariants(uuid: Uuid, request: &VllmRequestState)
}
}
}
}
fn
debug_assert_vllm_request_progress
(
uuid
:
Uuid
,
request
:
&
VllmRequestState
)
{
fn
debug_assert_vllm_request_progress
(
_
u
uid
:
Uuid
,
_
request
:
&
VllmRequestState
)
{
#[cfg(debug_assertions)]
#[cfg(debug_assertions)]
{
{
let
uuid
=
_u
uid
;
let
request
=
_
request
;
debug_assert_vllm_request_invariants
(
uuid
,
request
);
debug_assert_vllm_request_invariants
(
uuid
,
request
);
let
allocated
=
request
.sequence
.num_allocated_tokens
();
let
allocated
=
request
.sequence
.num_allocated_tokens
();
debug_assert!
(
debug_assert!
(
...
@@ -789,9 +793,11 @@ fn debug_assert_vllm_request_progress(uuid: Uuid, request: &VllmRequestState) {
...
@@ -789,9 +793,11 @@ fn debug_assert_vllm_request_progress(uuid: Uuid, request: &VllmRequestState) {
}
}
}
}
fn
debug_assert_vllm_ready_to_decode
(
requests
:
&
FxHashMap
<
Uuid
,
VllmRequestState
>
,
uuid
:
Uuid
)
{
fn
debug_assert_vllm_ready_to_decode
(
_
requests
:
&
FxHashMap
<
Uuid
,
VllmRequestState
>
,
_
u
uid
:
Uuid
)
{
#[cfg(debug_assertions)]
#[cfg(debug_assertions)]
{
{
let
requests
=
_
requests
;
let
uuid
=
_u
uid
;
let
Some
(
request
)
=
requests
.get
(
&
uuid
)
else
{
let
Some
(
request
)
=
requests
.get
(
&
uuid
)
else
{
return
;
return
;
};
};
...
@@ -807,9 +813,10 @@ fn debug_assert_vllm_ready_to_decode(requests: &FxHashMap<Uuid, VllmRequestState
...
@@ -807,9 +813,10 @@ fn debug_assert_vllm_ready_to_decode(requests: &FxHashMap<Uuid, VllmRequestState
}
}
}
}
fn
debug_assert_vllm_scheduler_state
(
state
:
&
SchedulerState
)
{
fn
debug_assert_vllm_scheduler_state
(
_
state
:
&
SchedulerState
)
{
#[cfg(debug_assertions)]
#[cfg(debug_assertions)]
{
{
let
state
=
_
state
;
let
mut
seen
=
std
::
collections
::
HashSet
::
new
();
let
mut
seen
=
std
::
collections
::
HashSet
::
new
();
for
uuid
in
&
state
.waiting_members
{
for
uuid
in
&
state
.waiting_members
{
debug_assert!
(
debug_assert!
(
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment