Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
efa89448
Unverified
Commit
efa89448
authored
Feb 24, 2026
by
Yan Ru Pei
Committed by
GitHub
Feb 24, 2026
Browse files
chore: de-async scheduler read paths and unblock decode output tracking (#6510)
Signed-off-by:
PeaBrane
<
yanrpei@gmail.com
>
parent
6fab12be
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
156 additions
and
504 deletions
+156
-504
lib/llm/src/kv_router.rs
lib/llm/src/kv_router.rs
+3
-6
lib/llm/src/kv_router/push_router.rs
lib/llm/src/kv_router/push_router.rs
+0
-1
lib/llm/src/kv_router/queue.rs
lib/llm/src/kv_router/queue.rs
+9
-12
lib/llm/src/kv_router/scheduler.rs
lib/llm/src/kv_router/scheduler.rs
+3
-5
lib/llm/src/kv_router/sequence.rs
lib/llm/src/kv_router/sequence.rs
+141
-480
No files found.
lib/llm/src/kv_router.rs
View file @
efa89448
...
@@ -494,14 +494,12 @@ impl KvRouter {
...
@@ -494,14 +494,12 @@ impl KvRouter {
self
.scheduler
.worker_type
()
self
.scheduler
.worker_type
()
}
}
pub
async
fn
add_output_block
(
pub
fn
add_output_block
(
&
self
,
&
self
,
request_id
:
&
str
,
request_id
:
&
str
,
decay_fraction
:
Option
<
f64
>
,
decay_fraction
:
Option
<
f64
>
,
)
->
Result
<
(),
SequenceError
>
{
)
->
Result
<
(),
SequenceError
>
{
self
.scheduler
self
.scheduler
.add_output_block
(
request_id
,
decay_fraction
)
.add_output_block
(
request_id
,
decay_fraction
)
.await
}
}
pub
fn
block_size
(
&
self
)
->
u32
{
pub
fn
block_size
(
&
self
)
->
u32
{
...
@@ -541,8 +539,7 @@ impl KvRouter {
...
@@ -541,8 +539,7 @@ impl KvRouter {
Ok
(
self
Ok
(
self
.scheduler
.scheduler
.get_potential_loads
(
maybe_seq_hashes
,
isl_tokens
,
overlap_scores
)
.get_potential_loads
(
maybe_seq_hashes
,
isl_tokens
,
overlap_scores
))
.await
)
}
}
/// Dump all events from the indexer
/// Dump all events from the indexer
...
...
lib/llm/src/kv_router/push_router.rs
View file @
efa89448
...
@@ -108,7 +108,6 @@ impl RequestGuard {
...
@@ -108,7 +108,6 @@ impl RequestGuard {
if
let
Err
(
e
)
=
self
if
let
Err
(
e
)
=
self
.chooser
.chooser
.add_output_block
(
&
self
.context_id
,
decay_fraction
)
.add_output_block
(
&
self
.context_id
,
decay_fraction
)
.await
{
{
tracing
::
warn!
(
tracing
::
warn!
(
"Failed to add output block for request {}: {e}"
,
"Failed to add output block for request {}: {e}"
,
...
...
lib/llm/src/kv_router/queue.rs
View file @
efa89448
...
@@ -103,7 +103,7 @@ impl SchedulerQueue {
...
@@ -103,7 +103,7 @@ impl SchedulerQueue {
return
;
return
;
};
};
if
self
.all_workers_busy
(
threshold
)
.await
{
if
self
.all_workers_busy
(
threshold
)
{
tracing
::
debug!
(
"all workers busy, queueing request"
);
tracing
::
debug!
(
"all workers busy, queueing request"
);
let
entry
=
self
.make_entry
(
request
);
let
entry
=
self
.make_entry
(
request
);
self
.pending
.lock
()
.await
.push
(
entry
);
self
.pending
.lock
()
.await
.push
(
entry
);
...
@@ -121,7 +121,7 @@ impl SchedulerQueue {
...
@@ -121,7 +121,7 @@ impl SchedulerQueue {
};
};
loop
{
loop
{
if
self
.all_workers_busy
(
threshold
)
.await
{
if
self
.all_workers_busy
(
threshold
)
{
break
;
break
;
}
}
let
Some
(
entry
)
=
self
.pending
.lock
()
.await
.pop
()
else
{
let
Some
(
entry
)
=
self
.pending
.lock
()
.await
.pop
()
else
{
...
@@ -135,14 +135,11 @@ impl SchedulerQueue {
...
@@ -135,14 +135,11 @@ impl SchedulerQueue {
/// Run the full scheduling pipeline for a single request:
/// Run the full scheduling pipeline for a single request:
/// compute potential load → select worker → respond → book via add_request.
/// compute potential load → select worker → respond → book via add_request.
async
fn
schedule
(
&
self
,
mut
request
:
SchedulingRequest
)
{
async
fn
schedule
(
&
self
,
mut
request
:
SchedulingRequest
)
{
let
(
decode_blocks
,
prefill_tokens
)
=
self
let
(
decode_blocks
,
prefill_tokens
)
=
self
.slots
.potential_blocks_and_tokens
(
.slots
request
.token_seq
.clone
(),
.potential_blocks_and_tokens
(
request
.isl_tokens
,
request
.token_seq
.clone
(),
request
.overlaps
.clone
(),
request
.isl_tokens
,
);
request
.overlaps
.clone
(),
)
.await
;
request
.decode_blocks
=
decode_blocks
;
request
.decode_blocks
=
decode_blocks
;
request
.prefill_tokens
=
prefill_tokens
;
request
.prefill_tokens
=
prefill_tokens
;
...
@@ -194,8 +191,8 @@ impl SchedulerQueue {
...
@@ -194,8 +191,8 @@ impl SchedulerQueue {
/// Check if all workers are busy based on threshold.
/// Check if all workers are busy based on threshold.
/// Returns true only if ALL workers exceed the threshold (no worker has capacity).
/// Returns true only if ALL workers exceed the threshold (no worker has capacity).
async
fn
all_workers_busy
(
&
self
,
threshold
:
f64
)
->
bool
{
fn
all_workers_busy
(
&
self
,
threshold
:
f64
)
->
bool
{
let
active_tokens
=
self
.slots
.active_tokens
()
.await
;
let
active_tokens
=
self
.slots
.active_tokens
();
let
configs
=
self
.workers_with_configs
.borrow
();
let
configs
=
self
.workers_with_configs
.borrow
();
for
(
&
worker_id
,
config
)
in
configs
.iter
()
{
for
(
&
worker_id
,
config
)
in
configs
.iter
()
{
...
...
lib/llm/src/kv_router/scheduler.rs
View file @
efa89448
...
@@ -272,17 +272,16 @@ impl KvScheduler {
...
@@ -272,17 +272,16 @@ impl KvScheduler {
self
.slots
.worker_type
()
self
.slots
.worker_type
()
}
}
pub
async
fn
add_output_block
(
pub
fn
add_output_block
(
&
self
,
&
self
,
request_id
:
&
str
,
request_id
:
&
str
,
decay_fraction
:
Option
<
f64
>
,
decay_fraction
:
Option
<
f64
>
,
)
->
Result
<
(),
SequenceError
>
{
)
->
Result
<
(),
SequenceError
>
{
self
.slots
self
.slots
.add_output_block
(
&
request_id
.to_string
(),
decay_fraction
)
.add_output_block
(
&
request_id
.to_string
(),
decay_fraction
)
.await
}
}
pub
async
fn
get_potential_loads
(
pub
fn
get_potential_loads
(
&
self
,
&
self
,
token_seq
:
Option
<
Vec
<
SequenceHash
>>
,
token_seq
:
Option
<
Vec
<
SequenceHash
>>
,
isl_tokens
:
usize
,
isl_tokens
:
usize
,
...
@@ -290,8 +289,7 @@ impl KvScheduler {
...
@@ -290,8 +289,7 @@ impl KvScheduler {
)
->
Vec
<
PotentialLoad
>
{
)
->
Vec
<
PotentialLoad
>
{
let
(
decode_blocks
,
prefill_tokens
)
=
self
let
(
decode_blocks
,
prefill_tokens
)
=
self
.slots
.slots
.potential_blocks_and_tokens
(
token_seq
,
isl_tokens
,
overlaps
)
.potential_blocks_and_tokens
(
token_seq
,
isl_tokens
,
overlaps
);
.await
;
// Get all unique WorkerWithDpRank from both hashmaps
// Get all unique WorkerWithDpRank from both hashmaps
let
mut
workers
:
HashSet
<
WorkerWithDpRank
>
=
HashSet
::
new
();
let
mut
workers
:
HashSet
<
WorkerWithDpRank
>
=
HashSet
::
new
();
...
...
lib/llm/src/kv_router/sequence.rs
View file @
efa89448
...
@@ -9,12 +9,11 @@
...
@@ -9,12 +9,11 @@
//!
//!
//! # Key Components
//! # Key Components
//!
//!
//! - [`ActiveSequences`]:
Single-threaded
sequence manager that tracks active requests and their
//! - [`ActiveSequences`]:
Per-worker
sequence manager that tracks active requests and their
//! token sequences, managing shared KV cache blocks efficiently.
//! token sequences, managing shared KV cache blocks efficiently.
//!
//!
//! - [`ActiveSequencesMultiWorker`]: Multi-threaded extension that distributes sequence management
//! - [`ActiveSequencesMultiWorker`]: Multi-worker extension that stores per-worker
//! across multiple worker threads, enabling parallel processing of requests while maintaining
//! `ActiveSequences` in a shared `DashMap` for lock-free concurrent access.
//! consistency.
//!
//!
//! # Architecture
//! # Architecture
//!
//!
...
@@ -31,7 +30,6 @@ use dynamo_runtime::traits::DistributedRuntimeProvider;
...
@@ -31,7 +30,6 @@ use dynamo_runtime::traits::DistributedRuntimeProvider;
use
dynamo_runtime
::
transports
::
event_plane
::{
EventPublisher
,
EventSubscriber
};
use
dynamo_runtime
::
transports
::
event_plane
::{
EventPublisher
,
EventSubscriber
};
use
dynamo_tokens
::
SequenceHash
;
use
dynamo_tokens
::
SequenceHash
;
use
std
::
collections
::{
HashMap
,
HashSet
};
use
std
::
collections
::{
HashMap
,
HashSet
};
use
std
::
rc
::{
Rc
,
Weak
};
use
std
::
sync
::
Arc
;
use
std
::
sync
::
Arc
;
use
std
::
time
::
Duration
;
use
std
::
time
::
Duration
;
use
tokio
::
time
::
Instant
;
use
tokio
::
time
::
Instant
;
...
@@ -62,9 +60,6 @@ pub enum SequenceError {
...
@@ -62,9 +60,6 @@ pub enum SequenceError {
#[error(
"Failed to publish event: {0}"
)]
#[error(
"Failed to publish event: {0}"
)]
PublishFailed
(
#[from]
anyhow
::
Error
),
PublishFailed
(
#[from]
anyhow
::
Error
),
#[error(
"Failed to send command to worker: channel closed"
)]
WorkerChannelClosed
,
}
}
/// Duration after which stale requests are forcibly expired (5 minutes)
/// Duration after which stale requests are forcibly expired (5 minutes)
...
@@ -87,14 +82,14 @@ pub struct SequenceRequest {
...
@@ -87,14 +82,14 @@ pub struct SequenceRequest {
/// A multi-request sequence manager that handles multiple active sequences with shared KV cache
/// A multi-request sequence manager that handles multiple active sequences with shared KV cache
#[derive(Debug,
Getters)]
#[derive(Debug,
Getters)]
pub
struct
ActiveSequences
{
pub
struct
ActiveSequences
{
active_seqs
:
HashMap
<
RequestId
,
Vec
<
(
SequenceHash
,
R
c
<
()
>
)
>>
,
active_seqs
:
HashMap
<
RequestId
,
Vec
<
(
SequenceHash
,
Ar
c
<
()
>
)
>>
,
prefill_tokens
:
HashMap
<
RequestId
,
usize
>
,
prefill_tokens
:
HashMap
<
RequestId
,
usize
>
,
/// Expected output tokens per request (used for resource estimation)
/// Expected output tokens per request (used for resource estimation)
expected_output_tokens
:
HashMap
<
RequestId
,
u32
>
,
expected_output_tokens
:
HashMap
<
RequestId
,
u32
>
,
unique_blocks
:
HashMap
<
SequenceHash
,
Weak
<
()
>>
,
unique_blocks
:
HashMap
<
SequenceHash
,
std
::
sync
::
Weak
<
()
>>
,
/// Fractional block counts for blocks that are partially cached
/// Fractional block counts for blocks that are partially cached
/// When a block is in both unique_blocks and fractional_blocks,
/// When a block is in both unique_blocks and fractional_blocks,
...
@@ -133,15 +128,15 @@ impl ActiveSequences {
...
@@ -133,15 +128,15 @@ impl ActiveSequences {
}
}
}
}
fn
touch_block
(
&
mut
self
,
block
:
&
SequenceHash
)
->
R
c
<
()
>
{
fn
touch_block
(
&
mut
self
,
block
:
&
SequenceHash
)
->
Ar
c
<
()
>
{
if
let
Some
(
weak
)
=
self
.unique_blocks
.get
(
block
)
if
let
Some
(
weak
)
=
self
.unique_blocks
.get
(
block
)
&&
let
Some
(
rc
)
=
weak
.upgrade
()
&&
let
Some
(
rc
)
=
weak
.upgrade
()
{
{
return
rc
;
return
rc
;
}
}
let
rc
=
R
c
::
new
(());
let
rc
=
Ar
c
::
new
(());
self
.unique_blocks
.insert
(
*
block
,
R
c
::
downgrade
(
&
rc
));
self
.unique_blocks
.insert
(
*
block
,
Ar
c
::
downgrade
(
&
rc
));
rc
rc
}
}
...
@@ -177,7 +172,7 @@ impl ActiveSequences {
...
@@ -177,7 +172,7 @@ impl ActiveSequences {
for
(
hash
,
rc
)
in
blocks
{
for
(
hash
,
rc
)
in
blocks
{
// A block with strong_count == 1 means only this request holds a reference
// A block with strong_count == 1 means only this request holds a reference
if
R
c
::
strong_count
(
rc
)
==
1
{
if
Ar
c
::
strong_count
(
rc
)
==
1
{
self
.fractional_blocks
.insert
(
*
hash
,
fraction
);
self
.fractional_blocks
.insert
(
*
hash
,
fraction
);
}
}
}
}
...
@@ -214,7 +209,7 @@ impl ActiveSequences {
...
@@ -214,7 +209,7 @@ impl ActiveSequences {
}
}
if
let
Some
(
sequence
)
=
token_sequence
{
if
let
Some
(
sequence
)
=
token_sequence
{
let
sequence_with_refs
:
Vec
<
(
SequenceHash
,
R
c
<
()
>
)
>
=
sequence
let
sequence_with_refs
:
Vec
<
(
SequenceHash
,
Ar
c
<
()
>
)
>
=
sequence
.iter
()
.iter
()
.map
(|
block
|
(
*
block
,
self
.touch_block
(
block
)))
.map
(|
block
|
(
*
block
,
self
.touch_block
(
block
)))
.collect
();
.collect
();
...
@@ -370,64 +365,16 @@ impl ActiveSequences {
...
@@ -370,64 +365,16 @@ impl ActiveSequences {
}
}
}
}
enum
UpdateSequences
{
/// Multi-worker extension of ActiveSequences using shared DashMap for lock-free concurrent access
AddRequest
{
request_id
:
RequestId
,
token_sequence
:
Option
<
Vec
<
SequenceHash
>>
,
isl
:
usize
,
overlap
:
u32
,
expected_output_tokens
:
Option
<
u32
>
,
resp_tx
:
tokio
::
sync
::
oneshot
::
Sender
<
HashSet
<
RequestId
>>
,
},
Free
{
request_id
:
RequestId
,
},
MarkPrefillCompleted
{
request_id
:
RequestId
,
},
AddOutputBlock
{
request_id
:
RequestId
,
decay_fraction
:
Option
<
f64
>
,
resp_tx
:
tokio
::
sync
::
oneshot
::
Sender
<
bool
>
,
},
NewBlocks
{
token_sequence
:
Arc
<
Vec
<
SequenceHash
>>
,
resp_tx
:
tokio
::
sync
::
oneshot
::
Sender
<
usize
>
,
},
PotentialBlocks
{
token_sequence
:
Arc
<
Vec
<
SequenceHash
>>
,
resp_tx
:
tokio
::
sync
::
oneshot
::
Sender
<
usize
>
,
},
PotentialBlocksAndTokens
{
token_sequence
:
Option
<
Arc
<
Vec
<
SequenceHash
>>>
,
isl
:
usize
,
overlap
:
u32
,
resp_tx
:
tokio
::
sync
::
oneshot
::
Sender
<
(
usize
,
usize
)
>
,
},
ActiveBlocks
{
resp_tx
:
tokio
::
sync
::
oneshot
::
Sender
<
usize
>
,
},
ActiveTokens
{
resp_tx
:
tokio
::
sync
::
oneshot
::
Sender
<
usize
>
,
},
Shutdown
,
}
/// Multi-worker extension of ActiveSequences that distributes requests across multiple threads
pub
struct
ActiveSequencesMultiWorker
{
pub
struct
ActiveSequencesMultiWorker
{
send
ers
:
Arc
<
DashMap
<
WorkerWithDpRank
,
tokio
::
sync
::
mpsc
::
UnboundedSender
<
Updat
eSequences
>>
>
,
work
ers
:
Arc
<
DashMap
<
WorkerWithDpRank
,
Activ
eSequences
>>
,
request_to_worker
:
Arc
<
DashMap
<
RequestId
,
WorkerWithDpRank
>>
,
request_to_worker
:
Arc
<
DashMap
<
RequestId
,
WorkerWithDpRank
>>
,
request_to_lora
:
Arc
<
DashMap
<
RequestId
,
String
>>
,
request_to_lora
:
Arc
<
DashMap
<
RequestId
,
String
>>
,
handles
:
Arc
<
DashMap
<
WorkerWithDpRank
,
std
::
thread
::
JoinHandle
<
()
>>>
,
block_size
:
usize
,
block_size
:
usize
,
component
:
Component
,
router_id
:
u64
,
router_id
:
u64
,
/// Publisher for sequence events
event_publisher
:
EventPublisher
,
event_publisher
:
EventPublisher
,
/// Publisher for metrics (namespace-scoped)
metrics_publisher
:
Arc
<
EventPublisher
>
,
metrics_publisher
:
EventPublisher
,
replica_sync
:
bool
,
replica_sync
:
bool
,
/// Worker type for Prometheus metrics labeling ("prefill" or "decode")
worker_type
:
&
'static
str
,
worker_type
:
&
'static
str
,
}
}
...
@@ -442,37 +389,30 @@ impl ActiveSequencesMultiWorker {
...
@@ -442,37 +389,30 @@ impl ActiveSequencesMultiWorker {
)
->
Result
<
Self
>
{
)
->
Result
<
Self
>
{
assert
!
(
block_size
>
1
,
"block_size must be greater than 1"
);
assert
!
(
block_size
>
1
,
"block_size must be greater than 1"
);
let
senders
=
Arc
::
new
(
DashMap
::
new
());
let
workers
=
Arc
::
new
(
DashMap
::
new
());
let
handles
=
Arc
::
new
(
DashMap
::
new
());
let
request_to_worker
=
Arc
::
new
(
DashMap
::
new
());
let
request_to_worker
=
Arc
::
new
(
DashMap
::
new
());
let
request_to_lora
=
Arc
::
new
(
DashMap
::
new
());
let
request_to_lora
=
Arc
::
new
(
DashMap
::
new
());
// Expand workers by their dp_rank
for
(
worker_id
,
config
)
in
workers_with_configs
{
for
(
worker_id
,
config
)
in
workers_with_configs
{
let
dp_size
=
config
.data_parallel_size
;
let
dp_size
=
config
.data_parallel_size
;
for
dp_rank
in
0
..
dp_size
{
for
dp_rank
in
0
..
dp_size
{
let
worker
=
WorkerWithDpRank
::
new
(
worker_id
,
dp_rank
);
let
worker
=
WorkerWithDpRank
::
new
(
worker_id
,
dp_rank
);
// Create a child cancellation token from the component's runtime
workers
.insert
(
worker
,
ActiveSequences
::
new
(
block_size
));
let
cancel_token
=
component
.drt
()
.runtime
()
.child_token
();
let
(
sender
,
handle
)
=
Self
::
start_worker
(
block_size
,
cancel_token
);
senders
.insert
(
worker
,
sender
);
handles
.insert
(
worker
,
handle
);
}
}
}
}
let
event_publisher
=
let
event_publisher
=
EventPublisher
::
for_component
(
&
component
,
ACTIVE_SEQUENCES_SUBJECT
)
.await
?
;
EventPublisher
::
for_component
(
&
component
,
ACTIVE_SEQUENCES_SUBJECT
)
.await
?
;
let
metrics_publisher
=
let
metrics_publisher
=
Arc
::
new
(
EventPublisher
::
for_namespace
(
component
.namespace
(),
KV_METRICS_SUBJECT
)
.await
?
;
EventPublisher
::
for_namespace
(
component
.namespace
(),
KV_METRICS_SUBJECT
)
.await
?
,
);
let
multi_worker
=
Self
{
let
multi_worker
=
Self
{
send
ers
:
send
ers
.clone
(),
work
ers
:
work
ers
.clone
(),
request_to_worker
:
request_to_worker
.clone
(),
request_to_worker
:
request_to_worker
.clone
(),
request_to_lora
:
request_to_lora
.clone
(),
request_to_lora
:
request_to_lora
.clone
(),
handles
,
block_size
,
block_size
,
component
:
component
.clone
(),
event_publisher
,
event_publisher
,
metrics_publisher
,
metrics_publisher
,
router_id
,
router_id
,
...
@@ -480,9 +420,8 @@ impl ActiveSequencesMultiWorker {
...
@@ -480,9 +420,8 @@ impl ActiveSequencesMultiWorker {
worker_type
,
worker_type
,
};
};
// Start the subscription loop only if replica_sync is enabled
if
replica_sync
{
if
replica_sync
{
let
send
ers_clone
=
send
ers
.clone
();
let
work
ers_clone
=
work
ers
.clone
();
let
request_to_worker_clone
=
request_to_worker
.clone
();
let
request_to_worker_clone
=
request_to_worker
.clone
();
let
request_to_lora_clone
=
request_to_lora
.clone
();
let
request_to_lora_clone
=
request_to_lora
.clone
();
let
component_clone
=
component
.clone
();
let
component_clone
=
component
.clone
();
...
@@ -490,9 +429,8 @@ impl ActiveSequencesMultiWorker {
...
@@ -490,9 +429,8 @@ impl ActiveSequencesMultiWorker {
let
cancel_token
=
component
.drt
()
.runtime
()
.child_token
();
let
cancel_token
=
component
.drt
()
.runtime
()
.child_token
();
tokio
::
spawn
(
async
move
{
tokio
::
spawn
(
async
move
{
// NATS subscription loop
if
let
Err
(
e
)
=
Self
::
subscribe_to_events
(
if
let
Err
(
e
)
=
Self
::
subscribe_to_events
(
send
ers_clone
,
work
ers_clone
,
request_to_worker_clone
,
request_to_worker_clone
,
request_to_lora_clone
,
request_to_lora_clone
,
component_clone
,
component_clone
,
...
@@ -509,120 +447,9 @@ impl ActiveSequencesMultiWorker {
...
@@ -509,120 +447,9 @@ impl ActiveSequencesMultiWorker {
Ok
(
multi_worker
)
Ok
(
multi_worker
)
}
}
/// Helper method to start a worker task
fn
start_worker
(
block_size
:
usize
,
cancel_token
:
CancellationToken
,
)
->
(
tokio
::
sync
::
mpsc
::
UnboundedSender
<
UpdateSequences
>
,
std
::
thread
::
JoinHandle
<
()
>
,
)
{
let
(
request_tx
,
request_rx
)
=
tokio
::
sync
::
mpsc
::
unbounded_channel
();
let
handle
=
std
::
thread
::
spawn
(
move
||
{
// Create a single-threaded tokio runtime
let
runtime
=
tokio
::
runtime
::
Builder
::
new_current_thread
()
.enable_all
()
.build
()
.unwrap
();
runtime
.block_on
(
async
move
{
let
mut
active_sequences
=
ActiveSequences
::
new
(
block_size
);
let
mut
request_rx
=
request_rx
;
loop
{
tokio
::
select!
{
command
=
request_rx
.recv
()
=>
{
let
Some
(
command
)
=
command
else
{
break
;
};
match
command
{
UpdateSequences
::
AddRequest
{
request_id
,
token_sequence
,
isl
,
overlap
,
expected_output_tokens
,
resp_tx
,
}
=>
{
let
removed
=
active_sequences
.add_request
(
request_id
,
token_sequence
,
isl
,
overlap
,
expected_output_tokens
);
let
_
=
resp_tx
.send
(
removed
);
}
UpdateSequences
::
Free
{
request_id
}
=>
{
active_sequences
.free
(
&
request_id
);
}
UpdateSequences
::
MarkPrefillCompleted
{
request_id
}
=>
{
active_sequences
.mark_prefill_completed
(
&
request_id
);
}
UpdateSequences
::
AddOutputBlock
{
request_id
,
decay_fraction
,
resp_tx
,
}
=>
{
let
success
=
active_sequences
.add_output_block
(
&
request_id
,
decay_fraction
);
let
_
=
resp_tx
.send
(
success
);
}
UpdateSequences
::
NewBlocks
{
token_sequence
,
resp_tx
,
}
=>
{
let
new_blocks
=
active_sequences
.new_blocks
(
&
token_sequence
);
let
_
=
resp_tx
.send
(
new_blocks
);
}
UpdateSequences
::
PotentialBlocks
{
token_sequence
,
resp_tx
,
}
=>
{
let
potential_blocks
=
active_sequences
.potential_blocks
(
&
token_sequence
);
let
_
=
resp_tx
.send
(
potential_blocks
);
}
UpdateSequences
::
PotentialBlocksAndTokens
{
token_sequence
,
isl
,
overlap
,
resp_tx
,
}
=>
{
let
potential_tokens
=
active_sequences
.potential_blocks_and_tokens
(
token_sequence
.as_ref
()
.map
(|
v
|
v
.as_slice
()),
isl
,
overlap
,
);
let
_
=
resp_tx
.send
(
potential_tokens
);
}
UpdateSequences
::
ActiveBlocks
{
resp_tx
}
=>
{
let
active_blocks
=
active_sequences
.active_blocks
();
let
_
=
resp_tx
.send
(
active_blocks
);
}
UpdateSequences
::
ActiveTokens
{
resp_tx
}
=>
{
let
active_tokens
=
active_sequences
.active_tokens
();
let
_
=
resp_tx
.send
(
active_tokens
);
}
UpdateSequences
::
Shutdown
=>
{
break
;
}
}
}
// Handle cancellation
_
=
cancel_token
.cancelled
()
=>
{
tracing
::
debug!
(
"Worker task cancelled"
);
break
;
}
}
}
});
tracing
::
debug!
(
"ActiveSequences worker task completed"
);
});
(
request_tx
,
handle
)
}
/// Background task to subscribe to active sequence events and update all workers
/// Background task to subscribe to active sequence events and update all workers
async
fn
subscribe_to_events
(
async
fn
subscribe_to_events
(
senders
:
Arc
<
workers
:
Arc
<
DashMap
<
WorkerWithDpRank
,
ActiveSequences
>>
,
DashMap
<
WorkerWithDpRank
,
tokio
::
sync
::
mpsc
::
UnboundedSender
<
UpdateSequences
>>
,
>
,
request_to_worker
:
Arc
<
DashMap
<
RequestId
,
WorkerWithDpRank
>>
,
request_to_worker
:
Arc
<
DashMap
<
RequestId
,
WorkerWithDpRank
>>
,
request_to_lora
:
Arc
<
DashMap
<
RequestId
,
String
>>
,
request_to_lora
:
Arc
<
DashMap
<
RequestId
,
String
>>
,
component
:
Component
,
component
:
Component
,
...
@@ -635,10 +462,8 @@ impl ActiveSequencesMultiWorker {
...
@@ -635,10 +462,8 @@ impl ActiveSequencesMultiWorker {
loop
{
loop
{
tokio
::
select!
{
tokio
::
select!
{
// Handle incoming events
result
=
subscriber
.next
()
=>
{
result
=
subscriber
.next
()
=>
{
let
Some
(
result
)
=
result
else
{
let
Some
(
result
)
=
result
else
{
// Stream ended
break
;
break
;
};
};
...
@@ -650,7 +475,6 @@ impl ActiveSequencesMultiWorker {
...
@@ -650,7 +475,6 @@ impl ActiveSequencesMultiWorker {
continue
;
continue
;
};
};
// Skip events emitted by itself
if
event
.router_id
==
router_id
{
if
event
.router_id
==
router_id
{
continue
;
continue
;
}
}
...
@@ -664,22 +488,18 @@ impl ActiveSequencesMultiWorker {
...
@@ -664,22 +488,18 @@ impl ActiveSequencesMultiWorker {
}
=>
{
}
=>
{
request_to_worker
.insert
(
event
.request_id
.clone
(),
event
.worker
);
request_to_worker
.insert
(
event
.request_id
.clone
(),
event
.worker
);
// Store lora_name mapping if present
if
let
Some
(
ref
lora_name
)
=
event
.lora_name
{
if
let
Some
(
ref
lora_name
)
=
event
.lora_name
{
request_to_lora
.insert
(
event
.request_id
.clone
(),
lora_name
.clone
());
request_to_lora
.insert
(
event
.request_id
.clone
(),
lora_name
.clone
());
}
}
if
let
Some
(
sender
)
=
senders
.get
(
&
event
.worker
)
{
if
let
Some
(
mut
entry
)
=
workers
.get_mut
(
&
event
.worker
)
{
// For replicated events, we create a dummy response channel since we don't need to handle expired requests
entry
.add_request
(
let
(
resp_tx
,
_
)
=
tokio
::
sync
::
oneshot
::
channel
();
event
.request_id
.clone
(),
let
_
=
sender
.send
(
UpdateSequences
::
AddRequest
{
token_sequence
.clone
(),
request_id
:
event
.request_id
.clone
(),
*
isl
,
token_sequence
:
token_sequence
.clone
(),
*
overlap
,
isl
:
*
isl
,
*
expected_output_tokens
,
overlap
:
*
overlap
,
);
expected_output_tokens
:
*
expected_output_tokens
,
resp_tx
,
});
}
else
{
}
else
{
tracing
::
warn!
(
tracing
::
warn!
(
"Worker {:?} not found, cannot process AddRequest"
,
"Worker {:?} not found, cannot process AddRequest"
,
...
@@ -689,27 +509,21 @@ impl ActiveSequencesMultiWorker {
...
@@ -689,27 +509,21 @@ impl ActiveSequencesMultiWorker {
}
}
ActiveSequenceEventData
::
Free
=>
{
ActiveSequenceEventData
::
Free
=>
{
if
let
Some
((
_
,
worker
))
=
request_to_worker
.remove
(
&
event
.request_id
)
if
let
Some
((
_
,
worker
))
=
request_to_worker
.remove
(
&
event
.request_id
)
&&
let
Some
(
sender
)
=
send
ers
.get
(
&
worker
)
&&
let
Some
(
mut
entry
)
=
work
ers
.get
_mut
(
&
worker
)
{
{
let
_
=
sender
.send
(
UpdateSequences
::
Free
{
entry
.free
(
&
event
.request_id
);
request_id
:
event
.request_id
.clone
(),
});
}
}
// Clean up lora_name mapping
request_to_lora
.remove
(
&
event
.request_id
);
request_to_lora
.remove
(
&
event
.request_id
);
}
}
ActiveSequenceEventData
::
MarkPrefillCompleted
=>
{
ActiveSequenceEventData
::
MarkPrefillCompleted
=>
{
if
let
Some
(
worker
)
=
request_to_worker
.get
(
&
event
.request_id
)
if
let
Some
(
worker
)
=
request_to_worker
.get
(
&
event
.request_id
)
&&
let
Some
(
sender
)
=
send
ers
.get
(
&*
worker
)
&&
let
Some
(
mut
entry
)
=
work
ers
.get
_mut
(
&*
worker
)
{
{
let
_
=
sender
.send
(
UpdateSequences
::
MarkPrefillCompleted
{
entry
.mark_prefill_completed
(
&
event
.request_id
);
request_id
:
event
.request_id
.clone
(),
});
}
}
}
}
}
}
}
}
// Handle cancellation
_
=
cancel_token
.cancelled
()
=>
{
_
=
cancel_token
.cancelled
()
=>
{
tracing
::
debug!
(
"Subscription task cancelled"
);
tracing
::
debug!
(
"Subscription task cancelled"
);
break
;
break
;
...
@@ -723,9 +537,8 @@ impl ActiveSequencesMultiWorker {
...
@@ -723,9 +537,8 @@ impl ActiveSequencesMultiWorker {
/// Update the set of workers, adding and removing as needed
/// Update the set of workers, adding and removing as needed
pub
fn
update_workers
(
&
self
,
new_workers_with_configs
:
HashMap
<
u64
,
ModelRuntimeConfig
>
)
{
pub
fn
update_workers
(
&
self
,
new_workers_with_configs
:
HashMap
<
u64
,
ModelRuntimeConfig
>
)
{
let
current_workers
:
HashSet
<
WorkerWithDpRank
>
=
let
current_workers
:
HashSet
<
WorkerWithDpRank
>
=
self
.
send
ers
.iter
()
.map
(|
entry
|
*
entry
.key
())
.collect
();
self
.
work
ers
.iter
()
.map
(|
entry
|
*
entry
.key
())
.collect
();
// Expand new workers by their dp_rank
let
mut
new_workers
:
HashSet
<
WorkerWithDpRank
>
=
HashSet
::
new
();
let
mut
new_workers
:
HashSet
<
WorkerWithDpRank
>
=
HashSet
::
new
();
for
(
worker_id
,
config
)
in
&
new_workers_with_configs
{
for
(
worker_id
,
config
)
in
&
new_workers_with_configs
{
let
dp_size
=
config
.data_parallel_size
;
let
dp_size
=
config
.data_parallel_size
;
...
@@ -740,17 +553,11 @@ impl ActiveSequencesMultiWorker {
...
@@ -740,17 +553,11 @@ impl ActiveSequencesMultiWorker {
let
workers_to_add
:
Vec
<
WorkerWithDpRank
>
=
let
workers_to_add
:
Vec
<
WorkerWithDpRank
>
=
new_workers
.difference
(
&
current_workers
)
.copied
()
.collect
();
new_workers
.difference
(
&
current_workers
)
.copied
()
.collect
();
// Remove workers (this will naturally remove all dp ranks for a worker_id)
for
worker
in
&
workers_to_remove
{
for
worker
in
&
workers_to_remove
{
tracing
::
warn!
(
"Removing worker {:?}"
,
worker
);
tracing
::
warn!
(
"Removing worker {:?}"
,
worker
);
// Send shutdown command to the worker
self
.workers
.remove
(
worker
);
if
let
Some
((
_
,
sender
))
=
self
.senders
.remove
(
worker
)
{
let
_
=
sender
.send
(
UpdateSequences
::
Shutdown
);
}
self
.handles
.remove
(
worker
);
// Collect request_ids to remove from request_to_lora
let
requests_to_remove
:
Vec
<
RequestId
>
=
self
let
requests_to_remove
:
Vec
<
RequestId
>
=
self
.request_to_worker
.request_to_worker
.iter
()
.iter
()
...
@@ -758,26 +565,18 @@ impl ActiveSequencesMultiWorker {
...
@@ -758,26 +565,18 @@ impl ActiveSequencesMultiWorker {
.map
(|
entry
|
entry
.key
()
.clone
())
.map
(|
entry
|
entry
.key
()
.clone
())
.collect
();
.collect
();
// Clean up request_to_worker mappings for this worker
self
.request_to_worker
self
.request_to_worker
.retain
(|
_
request_id
,
mapped_worker
|
mapped_worker
!=
worker
);
.retain
(|
_
request_id
,
mapped_worker
|
mapped_worker
!=
worker
);
// Clean up request_to_lora mappings for removed requests
for
request_id
in
requests_to_remove
{
for
request_id
in
requests_to_remove
{
self
.request_to_lora
.remove
(
&
request_id
);
self
.request_to_lora
.remove
(
&
request_id
);
}
}
}
}
// Add new workers
for
worker
in
&
workers_to_add
{
for
worker
in
&
workers_to_add
{
tracing
::
warn!
(
"Adding worker {:?}"
,
worker
);
tracing
::
warn!
(
"Adding worker {:?}"
,
worker
);
self
.workers
let
(
sender
,
handle
)
=
Self
::
start_worker
(
.insert
(
*
worker
,
ActiveSequences
::
new
(
self
.block_size
));
self
.block_size
,
self
.component
.drt
()
.runtime
()
.child_token
(),
);
self
.senders
.insert
(
*
worker
,
sender
);
self
.handles
.insert
(
*
worker
,
handle
);
}
}
}
}
...
@@ -792,15 +591,9 @@ impl ActiveSequencesMultiWorker {
...
@@ -792,15 +591,9 @@ impl ActiveSequencesMultiWorker {
lora_name
,
lora_name
,
}
=
req
;
}
=
req
;
// Clone the sender upfront so we don't hold the DashMap Ref across
if
!
self
.workers
.contains_key
(
&
worker
)
{
// the .await points below. Also eliminates the TOCTOU between
return
Err
(
SequenceError
::
WorkerNotFound
{
worker
});
// contains_key and a later get().unwrap().
}
let
sender
=
self
.senders
.get
(
&
worker
)
.ok_or
(
SequenceError
::
WorkerNotFound
{
worker
})
?
.value
()
.clone
();
if
let
Some
(
existing_worker
)
=
self
.request_to_worker
.get
(
&
request_id
)
{
if
let
Some
(
existing_worker
)
=
self
.request_to_worker
.get
(
&
request_id
)
{
return
Err
(
SequenceError
::
DuplicateRequest
{
return
Err
(
SequenceError
::
DuplicateRequest
{
...
@@ -809,8 +602,6 @@ impl ActiveSequencesMultiWorker {
...
@@ -809,8 +602,6 @@ impl ActiveSequencesMultiWorker {
});
});
}
}
let
(
resp_tx
,
resp_rx
)
=
tokio
::
sync
::
oneshot
::
channel
();
if
self
.replica_sync
{
if
self
.replica_sync
{
let
event
=
ActiveSequenceEvent
{
let
event
=
ActiveSequenceEvent
{
request_id
:
request_id
.clone
(),
request_id
:
request_id
.clone
(),
...
@@ -833,38 +624,37 @@ impl ActiveSequencesMultiWorker {
...
@@ -833,38 +624,37 @@ impl ActiveSequencesMultiWorker {
self
.request_to_lora
.insert
(
request_id
.clone
(),
lora
);
self
.request_to_lora
.insert
(
request_id
.clone
(),
lora
);
}
}
sender
let
removed_requests
=
{
.send
(
UpdateSequences
::
AddRequest
{
let
mut
entry
=
self
.workers
.get_mut
(
&
worker
)
.ok_or
(
SequenceError
::
WorkerNotFound
{
worker
})
?
;
entry
.add_request
(
request_id
,
request_id
,
token_sequence
,
token_sequence
,
isl
,
isl
,
overlap
,
overlap
,
expected_output_tokens
,
expected_output_tokens
,
resp_tx
,
)
})
};
.map_err
(|
_
|
SequenceError
::
WorkerChannelClosed
)
?
;
let
removed_requests
=
resp_rx
.await
.map_err
(|
_
|
SequenceError
::
WorkerChannelClosed
)
?
;
for
expired_id
in
&
removed_requests
{
for
expired_id
in
&
removed_requests
{
self
.request_to_worker
.remove
(
expired_id
);
self
.request_to_worker
.remove
(
expired_id
);
self
.request_to_lora
.remove
(
expired_id
);
self
.request_to_lora
.remove
(
expired_id
);
}
}
self
.publish_active_load_for_worker
(
worker
)
.await
;
self
.publish_active_load_for_worker
(
worker
);
Ok
(())
Ok
(())
}
}
/// Send a
command
to the worker assigned to a request, optionally publishing
/// Send a
mutation
to the worker assigned to a request, optionally publishing
/// a replica-sync event and cleaning up request mappings afterward.
/// a replica-sync event and cleaning up request mappings afterward.
async
fn
send_to
_request_worker
(
async
fn
mutate
_request_worker
(
&
self
,
&
self
,
request_id
:
&
RequestId
,
request_id
:
&
RequestId
,
event_data
:
ActiveSequenceEventData
,
event_data
:
ActiveSequenceEventData
,
command
_fn
:
impl
FnOnce
(
RequestId
)
->
UpdateSequences
,
mutate
_fn
:
impl
FnOnce
(
&
mut
ActiveSequences
,
&
RequestId
)
,
remove_mapping
:
bool
,
remove_mapping
:
bool
,
)
->
Result
<
(),
SequenceError
>
{
)
->
Result
<
(),
SequenceError
>
{
let
worker
=
self
let
worker
=
self
...
@@ -875,13 +665,6 @@ impl ActiveSequencesMultiWorker {
...
@@ -875,13 +665,6 @@ impl ActiveSequencesMultiWorker {
request_id
:
request_id
.clone
(),
request_id
:
request_id
.clone
(),
})
?
;
})
?
;
let
sender
=
self
.senders
.get
(
&
worker
)
.ok_or
(
SequenceError
::
WorkerNotFound
{
worker
})
?
.value
()
.clone
();
if
self
.replica_sync
{
if
self
.replica_sync
{
let
lora_name
=
self
let
lora_name
=
self
.request_to_lora
.request_to_lora
...
@@ -898,16 +681,20 @@ impl ActiveSequencesMultiWorker {
...
@@ -898,16 +681,20 @@ impl ActiveSequencesMultiWorker {
self
.event_publisher
.publish
(
&
event
)
.await
?
;
self
.event_publisher
.publish
(
&
event
)
.await
?
;
}
}
sender
{
.send
(
command_fn
(
request_id
.clone
()))
let
mut
entry
=
self
.map_err
(|
_
|
SequenceError
::
WorkerChannelClosed
)
?
;
.workers
.get_mut
(
&
worker
)
.ok_or
(
SequenceError
::
WorkerNotFound
{
worker
})
?
;
mutate_fn
(
&
mut
entry
,
request_id
);
}
if
remove_mapping
{
if
remove_mapping
{
self
.request_to_worker
.remove
(
request_id
);
self
.request_to_worker
.remove
(
request_id
);
self
.request_to_lora
.remove
(
request_id
);
self
.request_to_lora
.remove
(
request_id
);
}
}
self
.publish_active_load_for_worker
(
worker
)
.await
;
self
.publish_active_load_for_worker
(
worker
);
Ok
(())
Ok
(())
}
}
...
@@ -922,10 +709,12 @@ impl ActiveSequencesMultiWorker {
...
@@ -922,10 +709,12 @@ impl ActiveSequencesMultiWorker {
return
Ok
(());
return
Ok
(());
}
}
self
.
send_to
_request_worker
(
self
.
mutate
_request_worker
(
request_id
,
request_id
,
ActiveSequenceEventData
::
Free
,
ActiveSequenceEventData
::
Free
,
|
rid
|
UpdateSequences
::
Free
{
request_id
:
rid
},
|
seqs
,
rid
|
{
seqs
.free
(
rid
);
},
true
,
true
,
)
)
.await
.await
...
@@ -939,10 +728,12 @@ impl ActiveSequencesMultiWorker {
...
@@ -939,10 +728,12 @@ impl ActiveSequencesMultiWorker {
&
self
,
&
self
,
request_id
:
&
RequestId
,
request_id
:
&
RequestId
,
)
->
Result
<
(),
SequenceError
>
{
)
->
Result
<
(),
SequenceError
>
{
self
.
send_to
_request_worker
(
self
.
mutate
_request_worker
(
request_id
,
request_id
,
ActiveSequenceEventData
::
MarkPrefillCompleted
,
ActiveSequenceEventData
::
MarkPrefillCompleted
,
|
rid
|
UpdateSequences
::
MarkPrefillCompleted
{
request_id
:
rid
},
|
seqs
,
rid
|
{
seqs
.mark_prefill_completed
(
rid
);
},
false
,
false
,
)
)
.await
.await
...
@@ -952,7 +743,9 @@ impl ActiveSequencesMultiWorker {
...
@@ -952,7 +743,9 @@ impl ActiveSequencesMultiWorker {
///
///
/// This is used during generation to track output blocks as they are created.
/// This is used during generation to track output blocks as they are created.
/// The decay_fraction represents how "temporary" the block is based on generation progress.
/// The decay_fraction represents how "temporary" the block is based on generation progress.
pub
async
fn
add_output_block
(
// TODO: output blocks are not replicated via replica_sync — add an
// ActiveSequenceEventData variant if cross-instance accuracy matters.
pub
fn
add_output_block
(
&
self
,
&
self
,
request_id
:
&
RequestId
,
request_id
:
&
RequestId
,
decay_fraction
:
Option
<
f64
>
,
decay_fraction
:
Option
<
f64
>
,
...
@@ -965,30 +758,13 @@ impl ActiveSequencesMultiWorker {
...
@@ -965,30 +758,13 @@ impl ActiveSequencesMultiWorker {
request_id
:
request_id
.clone
(),
request_id
:
request_id
.clone
(),
})
?
;
})
?
;
// Clone sender upfront to avoid TOCTOU between contains_key and get().unwrap()
let
success
=
{
let
sender
=
self
let
mut
entry
=
self
.senders
.workers
.get
(
&
worker
)
.get_mut
(
&
worker
)
.ok_or
(
SequenceError
::
WorkerNotFound
{
worker
})
?
.ok_or
(
SequenceError
::
WorkerNotFound
{
worker
})
?
;
.value
()
entry
.add_output_block
(
request_id
,
decay_fraction
)
.clone
();
};
// Create response channel
let
(
resp_tx
,
resp_rx
)
=
tokio
::
sync
::
oneshot
::
channel
();
// Send command to worker
sender
.send
(
UpdateSequences
::
AddOutputBlock
{
request_id
:
request_id
.clone
(),
decay_fraction
,
resp_tx
,
})
.map_err
(|
_
|
SequenceError
::
WorkerChannelClosed
)
?
;
// Wait for response
let
success
=
resp_rx
.await
.map_err
(|
_
|
SequenceError
::
WorkerChannelClosed
)
?
;
if
!
success
{
if
!
success
{
return
Err
(
SequenceError
::
RequestNotFound
{
return
Err
(
SequenceError
::
RequestNotFound
{
...
@@ -996,56 +772,22 @@ impl ActiveSequencesMultiWorker {
...
@@ -996,56 +772,22 @@ impl ActiveSequencesMultiWorker {
});
});
}
}
// Publish ActiveLoad metrics for this worker
self
.publish_active_load_for_worker
(
worker
);
self
.publish_active_load_for_worker
(
worker
)
.await
;
Ok
(())
Ok
(())
}
}
/// Helper method to query a single worker for active blocks/tokens and publish ActiveLoad
/// Read active blocks/tokens from a worker and publish ActiveLoad metrics.
async
fn
publish_active_load_for_worker
(
&
self
,
worker
:
WorkerWithDpRank
)
{
/// The NATS publish is spawned as a background task to avoid blocking the caller.
// Clone the sender and drop the DashMap Ref immediately.
fn
publish_active_load_for_worker
(
&
self
,
worker
:
WorkerWithDpRank
)
{
// Holding a Ref across .await points can deadlock: if the task yields
let
(
active_blocks
,
active_tokens
)
=
{
// and update_workers() needs a write lock on the same shard, the
let
Some
(
entry
)
=
self
.workers
.get
(
&
worker
)
else
{
// runtime thread blocks forever.
let
sender
=
{
let
Some
(
entry
)
=
self
.senders
.get
(
&
worker
)
else
{
tracing
::
warn!
(
"Worker {worker:?} not found when publishing ActiveLoad"
);
tracing
::
warn!
(
"Worker {worker:?} not found when publishing ActiveLoad"
);
return
;
return
;
};
};
entry
.
value
()
.clone
()
(
entry
.
active_blocks
(),
entry
.active_tokens
()
)
};
};
// Query active blocks
let
(
blocks_tx
,
blocks_rx
)
=
tokio
::
sync
::
oneshot
::
channel
();
if
sender
.send
(
UpdateSequences
::
ActiveBlocks
{
resp_tx
:
blocks_tx
})
.is_err
()
{
tracing
::
warn!
(
"Failed to send ActiveBlocks query to worker {worker:?}"
);
return
;
}
// Query active tokens
let
(
tokens_tx
,
tokens_rx
)
=
tokio
::
sync
::
oneshot
::
channel
();
if
sender
.send
(
UpdateSequences
::
ActiveTokens
{
resp_tx
:
tokens_tx
})
.is_err
()
{
tracing
::
warn!
(
"Failed to send ActiveTokens query to worker {worker:?}"
);
return
;
}
// Await both responses
let
(
active_blocks
,
active_tokens
)
=
match
tokio
::
join!
(
blocks_rx
,
tokens_rx
)
{
(
Ok
(
blocks
),
Ok
(
tokens
))
=>
(
blocks
,
tokens
),
_
=>
{
tracing
::
warn!
(
"Failed to receive active blocks/tokens from worker {worker:?}"
);
return
;
}
};
// Update Prometheus gauges directly (router's own bookkeeping)
WORKER_LOAD_METRICS
.observe
(
WORKER_LOAD_METRICS
.observe
(
worker
.worker_id
,
worker
.worker_id
,
worker
.dp_rank
,
worker
.dp_rank
,
...
@@ -1054,7 +796,6 @@ impl ActiveSequencesMultiWorker {
...
@@ -1054,7 +796,6 @@ impl ActiveSequencesMultiWorker {
active_tokens
,
active_tokens
,
);
);
// Also publish ActiveLoad to NATS for other subscribers (if NATS is available)
let
active_load
=
ActiveLoad
{
let
active_load
=
ActiveLoad
{
worker_id
:
worker
.worker_id
,
worker_id
:
worker
.worker_id
,
dp_rank
:
worker
.dp_rank
,
dp_rank
:
worker
.dp_rank
,
...
@@ -1062,15 +803,19 @@ impl ActiveSequencesMultiWorker {
...
@@ -1062,15 +803,19 @@ impl ActiveSequencesMultiWorker {
active_prefill_tokens
:
Some
(
active_tokens
as
u64
),
active_prefill_tokens
:
Some
(
active_tokens
as
u64
),
};
};
if
let
Err
(
e
)
=
self
.metrics_publisher
.publish
(
&
active_load
)
.await
{
let
publisher
=
self
.metrics_publisher
.clone
();
// This is expected if NATS is not available - the local gauge update above already succeeded
tokio
::
spawn
(
async
move
{
tracing
::
trace!
(
"Failed to publish ActiveLoad to NATS for worker {worker:?}: {e:?}"
);
if
let
Err
(
e
)
=
publisher
.publish
(
&
active_load
)
.await
{
}
tracing
::
trace!
(
"Failed to publish ActiveLoad to NATS for worker {worker:?}: {e:?}"
);
}
});
}
}
/// Get the number of workers
/// Get the number of workers
pub
fn
num_workers
(
&
self
)
->
usize
{
pub
fn
num_workers
(
&
self
)
->
usize
{
self
.
send
ers
.len
()
self
.
work
ers
.len
()
}
}
/// Get the worker type for this router ("prefill" or "decode").
/// Get the worker type for this router ("prefill" or "decode").
...
@@ -1079,80 +824,35 @@ impl ActiveSequencesMultiWorker {
...
@@ -1079,80 +824,35 @@ impl ActiveSequencesMultiWorker {
self
.worker_type
self
.worker_type
}
}
/// Generic method to query all workers with a given command
async
fn
query_workers
<
T
:
Send
+
'static
>
(
&
self
,
token_sequence
:
Option
<
Vec
<
SequenceHash
>>
,
command_fn
:
impl
Fn
(
Option
<
Arc
<
Vec
<
SequenceHash
>>>
,
tokio
::
sync
::
oneshot
::
Sender
<
T
>
,
)
->
UpdateSequences
,
)
->
HashMap
<
WorkerWithDpRank
,
T
>
{
let
mut
results
=
HashMap
::
new
();
let
token_sequence_shared
=
token_sequence
.map
(
Arc
::
new
);
let
mut
receivers
=
Vec
::
new
();
// Send queries to all workers in parallel
for
entry
in
self
.senders
.iter
()
{
let
worker
=
*
entry
.key
();
let
sender
=
entry
.value
();
let
(
resp_tx
,
resp_rx
)
=
tokio
::
sync
::
oneshot
::
channel
();
receivers
.push
((
worker
,
resp_rx
));
if
let
Err
(
e
)
=
sender
.send
(
command_fn
(
token_sequence_shared
.clone
(),
resp_tx
))
{
tracing
::
error!
(
"Failed to send command to worker {:?}: {}"
,
worker
,
e
);
}
}
// Collect results from all workers
for
(
worker
,
receiver
)
in
receivers
{
match
tokio
::
time
::
timeout
(
tokio
::
time
::
Duration
::
from_secs
(
1
),
receiver
)
.await
{
Ok
(
Ok
(
result
))
=>
{
results
.insert
(
worker
,
result
);
}
Ok
(
Err
(
_
))
=>
{
tracing
::
error!
(
"Worker {:?} dropped response channel"
,
worker
);
}
Err
(
_
)
=>
{
tracing
::
error!
(
"Timeout waiting for response from worker {:?}"
,
worker
);
}
}
}
results
}
/// Query all workers for the number of new blocks that would be added by a token sequence
/// Query all workers for the number of new blocks that would be added by a token sequence
pub
async
fn
new_blocks
(
pub
fn
new_blocks
(
&
self
,
&
self
,
token_sequence
:
Vec
<
SequenceHash
>
,
token_sequence
:
Vec
<
SequenceHash
>
,
)
->
HashMap
<
WorkerWithDpRank
,
usize
>
{
)
->
HashMap
<
WorkerWithDpRank
,
usize
>
{
self
.query_workers
(
Some
(
token_sequence
),
|
ts
,
resp_tx
|
match
ts
{
let
mut
results
=
HashMap
::
with_capacity
(
self
.workers
.len
());
Some
(
ts
)
=>
UpdateSequences
::
NewBlocks
{
for
entry
in
self
.workers
.iter
()
{
token_sequence
:
ts
,
results
.insert
(
*
entry
.key
(),
entry
.value
()
.new_blocks
(
&
token_sequence
));
resp_tx
,
}
},
results
None
=>
unreachable!
(
"token_sequence should always be Some for new_blocks"
),
})
.await
}
}
/// Query all workers for the total number of blocks (new + active) that would be used by a token sequence
/// Query all workers for the total number of blocks (new + active) that would be used by a token sequence
pub
async
fn
potential_blocks
(
pub
fn
potential_blocks
(
&
self
,
&
self
,
token_sequence
:
Vec
<
SequenceHash
>
,
token_sequence
:
Vec
<
SequenceHash
>
,
)
->
HashMap
<
WorkerWithDpRank
,
usize
>
{
)
->
HashMap
<
WorkerWithDpRank
,
usize
>
{
self
.query_workers
(
Some
(
token_sequence
),
|
ts
,
resp_tx
|
match
ts
{
let
mut
results
=
HashMap
::
with_capacity
(
self
.workers
.len
());
Some
(
ts
)
=>
UpdateSequences
::
PotentialBlocks
{
for
entry
in
self
.workers
.iter
()
{
token_sequence
:
ts
,
results
.insert
(
resp_tx
,
*
entry
.key
()
,
}
,
entry
.value
()
.potential_blocks
(
&
token_sequence
)
,
None
=>
unreachable!
(
"token_sequence should always be Some for potential_blocks"
),
);
}
)
}
.await
results
}
}
/// Query all workers for the potential
tokens (new + active) that would be used by a token sequence with overlap
/// Query all workers for the potential
blocks and tokens
pub
async
fn
potential_blocks_and_tokens
(
pub
fn
potential_blocks_and_tokens
(
&
self
,
&
self
,
token_sequence
:
Option
<
Vec
<
SequenceHash
>>
,
token_sequence
:
Option
<
Vec
<
SequenceHash
>>
,
isl
:
usize
,
isl
:
usize
,
...
@@ -1164,56 +864,21 @@ impl ActiveSequencesMultiWorker {
...
@@ -1164,56 +864,21 @@ impl ActiveSequencesMultiWorker {
#[cfg(feature
=
"bench"
)]
#[cfg(feature
=
"bench"
)]
let
start
=
Instant
::
now
();
let
start
=
Instant
::
now
();
#[cfg(feature
=
"bench"
)]
#[cfg(feature
=
"bench"
)]
let
num_workers
=
self
.
send
ers
.len
();
let
num_workers
=
self
.
work
ers
.len
();
let
mut
potential_blocks
=
HashMap
::
new
();
let
mut
potential_blocks
=
HashMap
::
with_capacity
(
self
.workers
.len
());
let
mut
potential_tokens
=
HashMap
::
new
();
let
mut
potential_tokens
=
HashMap
::
with_capacity
(
self
.workers
.len
());
let
token_sequence_shared
=
token_sequence
.map
(
Arc
::
new
);
let
mut
receivers
=
Vec
::
new
();
// Iterate through all workers, not just those with overlap
for
entry
in
self
.workers
.iter
()
{
// This ensures we properly account for active tokens/blocks on all workers
let
worker
=
*
entry
.key
();
for
sender_entry
in
self
.senders
.iter
()
{
let
worker
=
*
sender_entry
.key
();
let
sender
=
sender_entry
.value
();
// Get overlap for this worker (defaults to 0 if not in overlaps)
let
overlap
=
*
overlaps
.scores
.get
(
&
worker
)
.unwrap_or
(
&
0
);
let
overlap
=
*
overlaps
.scores
.get
(
&
worker
)
.unwrap_or
(
&
0
);
let
(
resp_tx
,
resp_rx
)
=
tokio
::
sync
::
oneshot
::
channel
();
let
(
blocks
,
tokens
)
=
receivers
.push
((
worker
,
resp_rx
));
entry
.value
()
if
let
Err
(
e
)
=
sender
.send
(
UpdateSequences
::
PotentialBlocksAndTokens
{
.potential_blocks_and_tokens
(
token_sequence
.as_deref
(),
isl
,
overlap
);
token_sequence
:
token_sequence_shared
.clone
(),
potential_blocks
.insert
(
worker
,
blocks
);
isl
,
potential_tokens
.insert
(
worker
,
tokens
);
overlap
,
resp_tx
,
})
{
tracing
::
error!
(
"Failed to send potential_tokens command to worker {:?}: {}"
,
worker
,
e
);
}
}
#[cfg(feature
=
"bench"
)]
let
send_elapsed
=
start
.elapsed
();
// Collect results from all workers
for
(
worker
,
receiver
)
in
receivers
{
match
tokio
::
time
::
timeout
(
tokio
::
time
::
Duration
::
from_secs
(
1
),
receiver
)
.await
{
Ok
(
Ok
((
blocks
,
tokens
)))
=>
{
potential_blocks
.insert
(
worker
,
blocks
);
potential_tokens
.insert
(
worker
,
tokens
);
}
Ok
(
Err
(
_
))
=>
{
tracing
::
error!
(
"Worker {:?} dropped response channel"
,
worker
);
}
Err
(
_
)
=>
{
tracing
::
error!
(
"Timeout waiting for response from worker {:?}"
,
worker
);
}
}
}
}
#[cfg(feature
=
"bench"
)]
#[cfg(feature
=
"bench"
)]
...
@@ -1221,7 +886,6 @@ impl ActiveSequencesMultiWorker {
...
@@ -1221,7 +886,6 @@ impl ActiveSequencesMultiWorker {
let
total_elapsed
=
start
.elapsed
();
let
total_elapsed
=
start
.elapsed
();
tracing
::
info!
(
tracing
::
info!
(
num_workers
,
num_workers
,
send_us
=
send_elapsed
.as_micros
()
as
u64
,
total_us
=
total_elapsed
.as_micros
()
as
u64
,
total_us
=
total_elapsed
.as_micros
()
as
u64
,
"potential_blocks_and_tokens completed"
"potential_blocks_and_tokens completed"
);
);
...
@@ -1231,15 +895,21 @@ impl ActiveSequencesMultiWorker {
...
@@ -1231,15 +895,21 @@ impl ActiveSequencesMultiWorker {
}
}
/// Query all workers for their current number of active blocks
/// Query all workers for their current number of active blocks
pub
async
fn
active_blocks
(
&
self
)
->
HashMap
<
WorkerWithDpRank
,
usize
>
{
pub
fn
active_blocks
(
&
self
)
->
HashMap
<
WorkerWithDpRank
,
usize
>
{
self
.query_workers
(
None
,
|
_
,
resp_tx
|
UpdateSequences
::
ActiveBlocks
{
resp_tx
})
let
mut
results
=
HashMap
::
with_capacity
(
self
.workers
.len
());
.await
for
entry
in
self
.workers
.iter
()
{
results
.insert
(
*
entry
.key
(),
entry
.value
()
.active_blocks
());
}
results
}
}
/// Query all workers for their current number of active tokens
/// Query all workers for their current number of active tokens
pub
async
fn
active_tokens
(
&
self
)
->
HashMap
<
WorkerWithDpRank
,
usize
>
{
pub
fn
active_tokens
(
&
self
)
->
HashMap
<
WorkerWithDpRank
,
usize
>
{
self
.query_workers
(
None
,
|
_
,
resp_tx
|
UpdateSequences
::
ActiveTokens
{
resp_tx
})
let
mut
results
=
HashMap
::
with_capacity
(
self
.workers
.len
());
.await
for
entry
in
self
.workers
.iter
()
{
results
.insert
(
*
entry
.key
(),
entry
.value
()
.active_tokens
());
}
results
}
}
pub
fn
get_active_lora_counts
(
&
self
)
->
HashMap
<
String
,
usize
>
{
pub
fn
get_active_lora_counts
(
&
self
)
->
HashMap
<
String
,
usize
>
{
...
@@ -1252,15 +922,6 @@ impl ActiveSequencesMultiWorker {
...
@@ -1252,15 +922,6 @@ impl ActiveSequencesMultiWorker {
}
}
}
}
impl
Drop
for
ActiveSequencesMultiWorker
{
fn
drop
(
&
mut
self
)
{
// Send shutdown to all workers
for
entry
in
self
.senders
.iter
()
{
let
_
=
entry
.value
()
.send
(
UpdateSequences
::
Shutdown
);
}
}
}
#[cfg(test)]
#[cfg(test)]
mod
tests
{
mod
tests
{
use
super
::
*
;
use
super
::
*
;
...
@@ -1400,8 +1061,8 @@ mod tests {
...
@@ -1400,8 +1061,8 @@ mod tests {
tokio
::
time
::
sleep
(
tokio
::
time
::
Duration
::
from_millis
(
300
))
.await
;
tokio
::
time
::
sleep
(
tokio
::
time
::
Duration
::
from_millis
(
300
))
.await
;
// Query seq_manager_1 to verify it sees all requests including request_2 from seq_manager_2
// Query seq_manager_1 to verify it sees all requests including request_2 from seq_manager_2
let
blocks_phase1
=
seq_manager_1
.active_blocks
()
.await
;
let
blocks_phase1
=
seq_manager_1
.active_blocks
();
let
tokens_phase1
=
seq_manager_1
.active_tokens
()
.await
;
let
tokens_phase1
=
seq_manager_1
.active_tokens
();
// Verify that seq_manager_1 sees all requests including request_2 from seq_manager_2
// Verify that seq_manager_1 sees all requests including request_2 from seq_manager_2
// We now have:
// We now have:
...
@@ -1450,8 +1111,8 @@ mod tests {
...
@@ -1450,8 +1111,8 @@ mod tests {
tokio
::
time
::
sleep
(
tokio
::
time
::
Duration
::
from_millis
(
300
))
.await
;
tokio
::
time
::
sleep
(
tokio
::
time
::
Duration
::
from_millis
(
300
))
.await
;
// Query seq_manager_2 to verify everything is empty
// Query seq_manager_2 to verify everything is empty
let
blocks_phase2
=
seq_manager_2
.active_blocks
()
.await
;
let
blocks_phase2
=
seq_manager_2
.active_blocks
();
let
tokens_phase2
=
seq_manager_2
.active_tokens
()
.await
;
let
tokens_phase2
=
seq_manager_2
.active_tokens
();
// Verify phase 2 results - everything should be empty for all 3 workers
// Verify phase 2 results - everything should be empty for all 3 workers
let
all_workers
=
vec!
[
let
all_workers
=
vec!
[
...
@@ -1579,7 +1240,7 @@ mod tests {
...
@@ -1579,7 +1240,7 @@ mod tests {
tokio
::
time
::
sleep
(
tokio
::
time
::
Duration
::
from_millis
(
300
))
.await
;
tokio
::
time
::
sleep
(
tokio
::
time
::
Duration
::
from_millis
(
300
))
.await
;
// Query seq_manager_1 to verify it sees all requests including request_2 from seq_manager_2
// Query seq_manager_1 to verify it sees all requests including request_2 from seq_manager_2
let
tokens_phase1
=
seq_manager_1
.active_tokens
()
.await
;
let
tokens_phase1
=
seq_manager_1
.active_tokens
();
// Verify that seq_manager_1 sees all requests including request_2 from thread 2
// Verify that seq_manager_1 sees all requests including request_2 from thread 2
let
worker_0
=
WorkerWithDpRank
::
from_worker_id
(
0
);
let
worker_0
=
WorkerWithDpRank
::
from_worker_id
(
0
);
...
@@ -1621,7 +1282,7 @@ mod tests {
...
@@ -1621,7 +1282,7 @@ mod tests {
tokio
::
time
::
sleep
(
tokio
::
time
::
Duration
::
from_millis
(
300
))
.await
;
tokio
::
time
::
sleep
(
tokio
::
time
::
Duration
::
from_millis
(
300
))
.await
;
// Query seq_manager_2 to verify everything is empty
// Query seq_manager_2 to verify everything is empty
let
tokens_phase2
=
seq_manager_2
.active_tokens
()
.await
;
let
tokens_phase2
=
seq_manager_2
.active_tokens
();
// Verify phase 2 results - everything should be empty
// Verify phase 2 results - everything should be empty
for
worker_id
in
0
..=
2
{
for
worker_id
in
0
..=
2
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment