Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
66231cf0
Unverified
Commit
66231cf0
authored
Jul 31, 2025
by
Yan Ru Pei
Committed by
GitHub
Jul 31, 2025
Browse files
feat: reduce / revert routing overheads, do not consider output tokens (#2182)
parent
dbd33df6
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
252 additions
and
404 deletions
+252
-404
lib/llm/src/kv_router.rs
lib/llm/src/kv_router.rs
+19
-52
lib/llm/src/kv_router/approx.rs
lib/llm/src/kv_router/approx.rs
+22
-10
lib/llm/src/kv_router/indexer.rs
lib/llm/src/kv_router/indexer.rs
+35
-0
lib/llm/src/kv_router/scheduler.rs
lib/llm/src/kv_router/scheduler.rs
+7
-10
lib/llm/src/kv_router/sequence.rs
lib/llm/src/kv_router/sequence.rs
+169
-227
lib/llm/src/kv_router/worker.rs
lib/llm/src/kv_router/worker.rs
+0
-105
No files found.
lib/llm/src/kv_router.rs
View file @
66231cf0
...
...
@@ -31,8 +31,8 @@ use crate::{
kv_router
::{
approx
::
ApproxKvIndexer
,
indexer
::{
compute_block_hash_for_seq
,
KvIndexer
,
KvIndexerInterface
,
KvRouterError
,
OverlapScores
,
RouterEvent
,
compute_block_hash_for_seq
,
compute_seq_hash_for_block
,
KvIndexer
,
KvIndexerInterface
,
KvRouterError
,
OverlapScores
,
RouterEvent
,
},
// metrics_aggregator::EndpointCollector,
protocols
::{
LocalBlockHash
,
RouterRequest
,
RouterResponse
,
WorkerSelectionResult
},
...
...
@@ -71,7 +71,8 @@ pub struct KvRouterConfig {
pub
use_kv_events
:
bool
,
// note: this is not actually used for now
// TODO: this is not actually used for now
// Would need this (along with total kv blocks) to trigger AllWorkersBusy error for e.g. rate-limiting
pub
max_num_batched_tokens
:
u32
,
}
...
...
@@ -231,25 +232,25 @@ impl KvRouter {
let
_
guard
=
self
.find_best_match_mutex
.lock
()
.await
;
let
isl_tokens
=
tokens
.len
();
let
block_size
=
self
.block_size
;
let
local_block_hashes
=
compute_block_hash_for_seq
(
tokens
,
self
.block_size
);
let
overlap_scores
=
self
.indexer
.find_matches
(
local_block_hashes
)
.await
?
;
let
block_hashes
=
compute_block_hash_for_seq
(
tokens
,
self
.block_size
);
let
seq_hashes
=
compute_seq_hash_for_block
(
&
block_hashes
);
let
overlap_scores
=
self
.indexer
.find_matches
(
block_hashes
.clone
())
.await
?
;
let
best_worker_id
=
self
.scheduler
.schedule
(
context_id
.to_string
(),
isl_tokens
,
block_size
,
tokens
,
seq_hashes
.clone
(),
overlap_scores
.clone
(),
)
.await
?
;
if
let
Indexer
::
ApproxKvIndexer
(
ref
indexer
)
=
self
.indexer
{
indexer
.process_routing_decision
_for_request
(
tokens
,
best_worker_id
)
.process_routing_decision
(
best_worker_id
,
block_hashes
,
seq_hashes
)
.await
.unwrap
();
};
...
...
@@ -262,9 +263,9 @@ impl KvRouter {
Ok
((
best_worker_id
,
overlap_amount
))
}
///
Push tokens to a specific request's sequence
pub
async
fn
push
(
&
self
,
request_id
:
&
String
,
tokens
:
&
[
u32
]
)
{
self
.scheduler
.
push
(
request_id
,
tokens
)
.await
/// Mark prefill as completed for a request (forwards to the scheduler)
pub
async
fn
mark_prefill_completed
(
&
self
,
request_id
:
&
String
)
{
self
.scheduler
.
mark_prefill_completed
(
request_id
)
.await
}
/// Free all blocks associated with a request
...
...
@@ -331,7 +332,6 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
let
stream_context
=
request
.context
()
.clone
();
// Update the request with the estimated prefix hit blocks
let
(
mut
backend_input
,
context
)
=
request
.into_parts
();
let
isl
=
backend_input
.token_ids
.len
();
backend_input
.estimated_prefix_hit_num_blocks
=
Some
(
overlap_amount
);
let
updated_request
=
context
.map
(|
_
|
backend_input
);
...
...
@@ -345,55 +345,22 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
let
stream
=
stream
::
iter
(
vec!
[
response
]);
return
Ok
(
ResponseStream
::
new
(
Box
::
pin
(
stream
),
stream_context
));
}
// Get the response stream from the worker
let
mut
response_stream
=
self
.inner
.direct
(
updated_request
,
instance_id
)
.await
?
;
// Wrap the stream to track tokens
let
mut
response_stream
=
self
.inner
.direct
(
updated_request
,
instance_id
)
.await
?
;
let
stream_context
=
response_stream
.context
();
let
chooser
=
self
.chooser
.clone
();
let
request_id
=
context_id
.clone
();
let
block_size
=
chooser
.block_size
()
as
usize
;
let
wrapped_stream
=
Box
::
pin
(
async_stream
::
stream!
{
let
mut
accumulated_tokens
=
Vec
::
new
();
let
mut
total_output_length
=
0u
size
;
let
mut
last_block_index
=
(
isl
.saturating_sub
(
1
))
/
block_size
;
let
mut
first_push_done
=
false
;
while
let
Some
(
item
)
=
response_stream
.next
()
.await
{
// Track tokens if they exist in the response
let
Some
(
ref
output
)
=
item
.data
else
{
yield
item
;
continue
;
};
if
output
.token_ids
.is_empty
()
{
yield
item
;
continue
;
}
// Add tokens to accumulator
accumulated_tokens
.extend_from_slice
(
&
output
.token_ids
);
total_output_length
+=
output
.token_ids
.len
();
// Always push for the first generated token (to mark prefill done)
// or when we've moved to a new block
let
current_block_index
=
(
isl
+
total_output_length
)
.saturating_sub
(
1
)
/
block_size
;
let
should_push
=
(
!
first_push_done
&&
total_output_length
>=
1
)
||
(
first_push_done
&&
current_block_index
>
last_block_index
);
if
should_push
{
chooser
.push
(
&
request_id
,
&
accumulated_tokens
)
.await
;
accumulated_tokens
.clear
();
last_block_index
=
current_block_index
;
if
!
first_push_done
{
first_push_done
=
true
;
}
if
let
Some
(
first_item
)
=
response_stream
.next
()
.await
{
chooser
.mark_prefill_completed
(
&
context_id
)
.await
;
yield
first_item
;
}
while
let
Some
(
item
)
=
response_stream
.next
()
.await
{
yield
item
;
}
chooser
.free
(
&
reques
t_id
)
.await
;
chooser
.free
(
&
contex
t_id
)
.await
;
});
Ok
(
ResponseStream
::
new
(
wrapped_stream
,
stream_context
))
}
...
...
lib/llm/src/kv_router/approx.rs
View file @
66231cf0
...
...
@@ -23,7 +23,7 @@ use tokio::sync::{mpsc, oneshot};
use
tokio
::
time
::{
Duration
,
Instant
};
use
tokio_util
::
sync
::
CancellationToken
;
use
crate
::
tokens
::
TokenBlockSequence
;
use
crate
::
tokens
::
{
SequenceHash
,
TokenBlockSequence
}
;
use
crate
::
kv_router
::
indexer
::{
compute_block_hash_for_seq
,
DumpRequest
,
KvIndexerInterface
,
KvRouterError
,
OverlapScores
,
...
...
@@ -295,6 +295,26 @@ impl ApproxKvIndexer {
self
.kv_block_size
}
/// Record a routing decision whose block and sequence hashes are already computed.
///
/// Packs the arguments into a `RouterResult` and sends it over `self.route_tx`.
///
/// ### Arguments
///
/// * `worker_id` - The worker the request was routed to.
/// * `local_hashes` - Per-block local hashes of the request.
/// * `sequence_hashes` - Sequence hashes of the request's blocks.
///
/// ### Errors
///
/// Returns `KvRouterError::IndexerDroppedRequest` when the send on `route_tx` fails.
pub async fn process_routing_decision(
    &self,
    worker_id: WorkerId,
    local_hashes: Vec<LocalBlockHash>,
    sequence_hashes: Vec<SequenceHash>,
) -> Result<(), KvRouterError> {
    let decision = RouterResult {
        worker_id,
        local_hashes,
        sequence_hashes,
    };

    // Any send failure is reported uniformly; the underlying error carries no extra detail we use.
    match self.route_tx.send(decision).await {
        Ok(_) => Ok(()),
        Err(_) => Err(KvRouterError::IndexerDroppedRequest),
    }
}
/// Wrapper function that computes hashes from tokens and calls the core function
pub
async
fn
process_routing_decision_for_request
(
&
self
,
tokens
:
&
[
u32
],
...
...
@@ -309,16 +329,8 @@ impl ApproxKvIndexer {
.map
(|
b
|
b
.sequence_hash
())
.collect
::
<
Vec
<
_
>>
();
self
.route_tx
.send
(
RouterResult
{
worker_id
,
local_hashes
,
sequence_hashes
,
})
self
.process_routing_decision
(
worker_id
,
local_hashes
,
sequence_hashes
)
.await
.map_err
(|
_
|
KvRouterError
::
IndexerDroppedRequest
)
?
;
Ok
(())
}
}
...
...
lib/llm/src/kv_router/indexer.rs
View file @
66231cf0
...
...
@@ -63,6 +63,7 @@ use xxhash_rust::xxh3;
pub
const
XXH3_SEED
:
u64
=
1337
;
use
crate
::
kv_router
::
protocols
::
*
;
use
crate
::
tokens
::
SequenceHash
;
/// Errors that can occur in the KV Router.
#[derive(Debug,
thiserror::Error)]
...
...
@@ -133,6 +134,40 @@ pub fn compute_block_hash_for_seq(tokens: &[u32], kv_block_size: u32) -> Vec<Loc
.collect
()
}
/// Compute rolling sequence hashes for a slice of block hashes.
///
/// The chain is built left to right:
/// - the first block's sequence hash is its own block hash;
/// - every later block's sequence hash is `compute_hash` over the little-endian bytes of
///   the previous sequence hash followed by the little-endian bytes of the current block hash.
///
/// (Per the original note, this is intended to mirror the behavior in tokens.rs.)
///
/// ### Arguments
///
/// * `block_hashes` - The `LocalBlockHash` values of the blocks, in sequence order.
///
/// ### Returns
///
/// One `SequenceHash` per input block, in the same order. An empty input yields an empty vector.
pub fn compute_seq_hash_for_block(block_hashes: &[LocalBlockHash]) -> Vec<SequenceHash> {
    let mut chained: Vec<SequenceHash> = Vec::with_capacity(block_hashes.len());

    for block in block_hashes {
        let next = match chained.last() {
            // Nothing to chain onto yet: the first block is hashed by itself.
            None => block.0,
            Some(&parent) => {
                // Parent bytes first, then the current block's bytes.
                let bytes: Vec<u8> = [parent.to_le_bytes(), block.0.to_le_bytes()].concat();
                compute_hash(&bytes)
            }
        };
        chained.push(next);
    }

    chained
}
/// A [`KvCacheEvent`] on a specific LLM worker denoted by [`WorkerId`].
#[derive(Debug,
Clone,
Serialize,
Deserialize)]
pub
struct
RouterEvent
{
...
...
lib/llm/src/kv_router/scheduler.rs
View file @
66231cf0
...
...
@@ -29,7 +29,7 @@ use crate::kv_router::protocols::LoadMetrics;
use
crate
::
kv_router
::
sequence
::
ActiveSequencesMultiWorker
;
use
crate
::
kv_router
::
KvRouterConfig
;
use
crate
::
kv_router
::
KV_HIT_RATE_SUBJECT
;
use
crate
::
tokens
::
TokenBlock
Sequence
;
use
crate
::
tokens
::
Sequence
Hash
;
use
dynamo_runtime
::
component
::
Instance
;
#[derive(Debug,
Clone,
Serialize,
Deserialize)]
...
...
@@ -217,15 +217,13 @@ impl KvScheduler {
&
self
,
request_id
:
String
,
isl_tokens
:
usize
,
block_size
:
u32
,
tokens
:
&
[
u32
],
token_seq
:
Vec
<
SequenceHash
>
,
overlaps
:
OverlapScores
,
)
->
Result
<
i64
,
KvSchedulerError
>
{
let
mut
sequences
=
self
.sequences
.lock
()
.await
;
let
token_sequence
=
TokenBlockSequence
::
from_slice
(
tokens
,
block_size
,
None
);
let
(
potential_blocks
,
potential_tokens
)
=
sequences
.potential_blocks_and_tokens
(
token_seq
uence
,
overlaps
.clone
());
sequences
.potential_blocks_and_tokens
(
token_seq
.clone
(),
isl_tokens
,
overlaps
.clone
());
let
(
resp_tx
,
resp_rx
)
=
tokio
::
sync
::
oneshot
::
channel
();
let
request
=
SchedulingRequest
{
...
...
@@ -247,10 +245,10 @@ impl KvScheduler {
sequences
.update_workers
(
new_worker_ids
);
}
let
token_sequence
=
TokenBlockSequence
::
from_slice
(
tokens
,
block_size
,
None
);
sequences
.add_request
(
request_id
,
token_sequence
,
token_seq
,
isl_tokens
,
response
.overlap_blocks
,
response
.best_worker_id
,
);
...
...
@@ -258,10 +256,9 @@ impl KvScheduler {
Ok
(
response
.best_worker_id
)
}
/// Push tokens to a specific request's sequence
pub
async
fn
push
(
&
self
,
request_id
:
&
String
,
tokens
:
&
[
u32
])
{
pub
async
fn
mark_prefill_completed
(
&
self
,
request_id
:
&
String
)
{
let
mut
sequences
=
self
.sequences
.lock
()
.await
;
sequences
.
push
(
request_id
,
tokens
)
sequences
.
mark_prefill_completed
(
request_id
)
}
/// Free all blocks associated with a request
...
...
lib/llm/src/kv_router/sequence.rs
View file @
66231cf0
...
...
@@ -36,50 +36,24 @@
use
crate
::
kv_router
::
indexer
::
OverlapScores
;
use
crate
::
kv_router
::
indexer
::
WorkerId
;
use
crate
::
tokens
::
blocks
::
UniqueBlock
;
use
crate
::
tokens
::
TokenBlockSequence
;
use
crate
::
tokens
::
SequenceHash
;
use
derive_getters
::
Getters
;
use
std
::
collections
::{
HashMap
,
HashSet
};
use
std
::
sync
::{
mpsc
,
Arc
};
use
std
::
thread
;
use
std
::
time
::
Duration
;
use
uuid
;
// TODO: use the common request_id if it exists in the repo
pub
type
RequestId
=
String
;
/// Create unique blocks from a TokenBlockSequence
fn
create_unique_blocks_from_sequence
(
tokens
:
&
TokenBlockSequence
,
uuid
:
Option
<
uuid
::
Uuid
>
,
block_size
:
usize
,
)
->
Vec
<
UniqueBlock
>
{
let
mut
unique_blocks
:
Vec
<
UniqueBlock
>
=
tokens
.blocks
()
.iter
()
.map
(|
block
|
UniqueBlock
::
FullBlock
(
block
.sequence_hash
()))
.collect
();
// Only push the partial block if tokens count isn't a multiple of block_size
if
tokens
.total_tokens
()
%
block_size
!=
0
{
unique_blocks
.push
(
match
uuid
{
Some
(
uuid
)
=>
UniqueBlock
::
PartialBlock
(
uuid
),
None
=>
UniqueBlock
::
default
(),
});
}
unique_blocks
}
/// A multi-request sequence manager that handles multiple active sequences with shared KV cache
#[derive(Debug,
Getters)]
pub
struct
ActiveSequences
{
active_seqs
:
HashMap
<
RequestId
,
TokenBlockSequence
>
,
partial_blocks
:
HashMap
<
RequestId
,
UniqueBlock
>
,
active_seqs
:
HashMap
<
RequestId
,
Vec
<
SequenceHash
>>
,
prefill_tokens
:
HashMap
<
RequestId
,
usize
>
,
unique_blocks
:
HashMap
<
UniqueBlock
,
HashSet
<
RequestId
>>
,
unique_blocks
:
HashMap
<
SequenceHash
,
HashSet
<
RequestId
>>
,
#[getter(copy)]
block_size
:
usize
,
...
...
@@ -99,7 +73,6 @@ impl ActiveSequences {
Self
{
active_seqs
:
HashMap
::
new
(),
partial_blocks
:
HashMap
::
new
(),
prefill_tokens
:
HashMap
::
new
(),
unique_blocks
:
HashMap
::
new
(),
block_size
,
...
...
@@ -108,24 +81,20 @@ impl ActiveSequences {
}
}
fn
add_block
(
&
mut
self
,
request_id
:
RequestId
,
block
:
&
UniqueBlock
)
{
fn
add_block
(
&
mut
self
,
request_id
:
RequestId
,
block
:
&
SequenceHash
)
{
let
is_new_block
=
!
self
.unique_blocks
.contains_key
(
block
);
self
.unique_blocks
.entry
(
block
.clone
()
)
.entry
(
*
block
)
.or_default
()
.insert
(
request_id
.clone
());
if
is_new_block
{
self
.active_blocks
+=
1
;
}
if
matches!
(
block
,
UniqueBlock
::
PartialBlock
(
_
))
{
self
.partial_blocks
.insert
(
request_id
,
block
.clone
());
};
}
fn
remove_block
(
&
mut
self
,
request_id
:
&
RequestId
,
block
:
&
UniqueBlock
)
{
fn
remove_block
(
&
mut
self
,
request_id
:
&
RequestId
,
block
:
&
SequenceHash
)
{
let
Some
(
request_ids
)
=
self
.unique_blocks
.get_mut
(
block
)
else
{
panic!
(
"Cannot remove a block that does not exist."
)
};
...
...
@@ -142,17 +111,16 @@ impl ActiveSequences {
pub
fn
add_request
(
&
mut
self
,
request_id
:
RequestId
,
token_sequence
:
TokenBlockSequence
,
token_sequence
:
Vec
<
SequenceHash
>
,
isl
:
usize
,
overlap
:
u32
,
)
->
usize
{
let
prefill_tokens
=
self
.new_tokens
(
&
token_sequence
,
overlap
);
let
prefill_tokens
=
self
.new_tokens
(
isl
,
overlap
);
self
.prefill_tokens
.insert
(
request_id
.clone
(),
prefill_tokens
);
self
.active_tokens
+=
prefill_tokens
;
let
blocks
=
create_unique_blocks_from_sequence
(
&
token_sequence
,
None
,
self
.block_size
);
for
block
in
&
blocks
{
for
block
in
&
token_sequence
{
self
.add_block
(
request_id
.clone
(),
block
);
}
...
...
@@ -161,30 +129,35 @@ impl ActiveSequences {
self
.active_blocks
}
pub
fn
new_tokens
(
&
self
,
token_sequence
:
&
TokenBlockSequence
,
overlap
:
u32
)
->
usize
{
let
input_tokens
=
token_sequence
.total_tokens
();
input_tokens
.checked_sub
((
overlap
as
usize
)
*
self
.block_size
)
.unwrap_or_else
(||
{
panic!
(
"prefill_tokens < 0 with overlap {overlap} and ISL {input_tokens}"
)
})
/// Mark prefill as completed for a request.
///
/// Removes the request's entry from `prefill_tokens` and subtracts that count, minus one,
/// from `active_tokens`. Does nothing when the request has no `prefill_tokens` entry.
///
/// ### Panics
///
/// Panics with "active_tokens underflow" if the amount to release exceeds `active_tokens`.
pub fn mark_prefill_completed(&mut self, request_id: &RequestId) {
    let Some(prefill) = self.prefill_tokens.remove(request_id) else {
        return;
    };

    // Keep 1 token for decoding
    let released = prefill.saturating_sub(1);

    self.active_tokens = self
        .active_tokens
        .checked_sub(released)
        .expect("active_tokens underflow");
}
/// Number of input tokens left after discounting the overlapping blocks.
///
/// Computes `isl - overlap * block_size`.
///
/// ### Arguments
///
/// * `isl` - Input sequence length, in tokens.
/// * `overlap` - Number of overlapping blocks.
///
/// ### Panics
///
/// Panics when `overlap * block_size` is larger than `isl`.
pub fn new_tokens(&self, isl: usize, overlap: u32) -> usize {
    let overlapped_tokens = (overlap as usize) * self.block_size;

    match isl.checked_sub(overlapped_tokens) {
        Some(remaining) => remaining,
        None => panic!("prefill_tokens < 0 with overlap {overlap} and ISL {isl}"),
    }
}
pub
fn
potential_blocks_and_tokens
(
&
self
,
token_sequence
:
&
TokenBlockSequence
,
token_sequence
:
&
[
SequenceHash
],
isl
:
usize
,
overlap
:
u32
,
)
->
(
usize
,
usize
)
{
let
potential_blocks
=
self
.new_blocks
(
token_sequence
)
+
self
.active_blocks
;
let
potential_tokens
=
self
.new_tokens
(
token_sequence
,
overlap
)
+
self
.active_tokens
;
let
potential_tokens
=
self
.new_tokens
(
isl
,
overlap
)
+
self
.active_tokens
;
(
potential_blocks
,
potential_tokens
)
}
/// Match a request against existing blocks and return the number of new blocks that would be added
pub
fn
new_blocks
(
&
self
,
token_sequence
:
&
TokenBlockSequence
)
->
usize
{
let
blocks
=
create_unique_blocks_from_sequence
(
token_sequence
,
None
,
self
.block_size
);
blocks
pub
fn
new_blocks
(
&
self
,
token_sequence
:
&
[
SequenceHash
])
->
usize
{
token_sequence
.iter
()
.filter
(|
block
|
!
self
.unique_blocks
.contains_key
(
block
))
.count
()
...
...
@@ -192,7 +165,7 @@ impl ActiveSequences {
/// Return the total number of blocks that would be used if the token sequence was added
/// This is the sum of new blocks that would be added plus the current active blocks
pub
fn
potential_blocks
(
&
self
,
token_sequence
:
&
TokenBlock
Sequence
)
->
usize
{
pub
fn
potential_blocks
(
&
self
,
token_sequence
:
&
[
Sequence
Hash
]
)
->
usize
{
self
.new_blocks
(
token_sequence
)
+
self
.active_blocks
}
...
...
@@ -209,110 +182,49 @@ impl ActiveSequences {
return
0
;
};
let
blocks
=
create_unique_blocks_from_sequence
(
token_seq
,
None
,
self
.block_size
);
for
block
in
blocks
{
if
matches!
(
block
,
UniqueBlock
::
FullBlock
(
_
))
{
self
.remove_block
(
request_id
,
&
block
);
}
}
if
let
Some
(
partial_block
)
=
self
.partial_blocks
.remove
(
request_id
)
{
self
.remove_block
(
request_id
,
&
partial_block
);
for
block
in
token_seq
.clone
()
{
self
.remove_block
(
request_id
,
&
block
)
}
self
.active_seqs
.remove
(
request_id
)
.unwrap
();
self
.active_blocks
}
/// Push tokens to a specific request's sequence
pub
fn
push
(
&
mut
self
,
request_id
:
&
RequestId
,
tokens
:
&
[
u32
])
->
usize
{
if
let
Some
(
prefill_tokens
)
=
self
.prefill_tokens
.get
(
request_id
)
.cloned
()
{
self
.prefill_tokens
.remove
(
request_id
);
// decoding has one active token
self
.active_tokens
=
self
.active_tokens
.checked_sub
(
prefill_tokens
)
.expect
(
"active_tokens < 0"
)
+
1
;
};
// Collect operations to perform after releasing the borrow
let
mut
blocks_to_remove
=
Vec
::
new
();
let
mut
blocks_to_add
=
Vec
::
new
();
{
let
token_seq
=
self
.active_seqs
.get_mut
(
request_id
)
.expect
(
"Request ID not found for token push"
);
for
&
token
in
tokens
{
token_seq
.append
(
token
)
.expect
(
"Token push failed."
);
// Guard: skip if we didn't cross a block boundary
if
token_seq
.total_tokens
()
%
self
.block_size
!=
1
{
continue
;
}
let
last_seq_hash
=
token_seq
.last_complete_block
()
.map
(|
block
|
block
.sequence_hash
());
// Queue operations for later
if
let
Some
(
partial_block
)
=
self
.partial_blocks
.get
(
request_id
)
.cloned
()
{
blocks_to_remove
.push
(
partial_block
);
}
if
let
Some
(
full_block
)
=
last_seq_hash
{
blocks_to_add
.push
(
UniqueBlock
::
FullBlock
(
full_block
));
}
blocks_to_add
.push
(
UniqueBlock
::
default
());
}
}
// token_seq borrow is dropped here
// Now perform all the queued operations
for
block
in
blocks_to_remove
{
self
.remove_block
(
request_id
,
&
block
);
}
for
block
in
blocks_to_add
{
self
.add_block
(
request_id
.clone
(),
&
block
);
}
self
.active_blocks
}
}
#[derive(Debug)]
enum
UpdateSequences
{
AddRequest
{
request_id
:
RequestId
,
token_sequence
:
TokenBlockSequence
,
token_sequence
:
Vec
<
SequenceHash
>
,
isl
:
usize
,
overlap
:
u32
,
},
Free
{
request_id
:
RequestId
,
},
Push
{
MarkPrefillCompleted
{
request_id
:
RequestId
,
tokens
:
Vec
<
u32
>
,
// Changed from token: u32
},
NewBlocks
{
token_sequence
:
Arc
<
TokenBlock
Sequence
>
,
token_sequence
:
Arc
<
Vec
<
Sequence
Hash
>
>
,
resp_tx
:
mpsc
::
SyncSender
<
usize
>
,
},
PotentialBlocks
{
token_sequence
:
Arc
<
TokenBlock
Sequence
>
,
token_sequence
:
Arc
<
Vec
<
Sequence
Hash
>
>
,
resp_tx
:
mpsc
::
SyncSender
<
usize
>
,
},
PotentialBlocksAndTokens
{
token_sequence
:
Arc
<
TokenBlockSequence
>
,
token_sequence
:
Arc
<
Vec
<
SequenceHash
>>
,
isl
:
usize
,
overlap
:
u32
,
resp_tx
:
mpsc
::
SyncSender
<
(
usize
,
usize
)
>
,
},
ActiveBlocks
{
resp_tx
:
mpsc
::
SyncSender
<
usize
>
,
},
ActiveTokens
{
resp_tx
:
mpsc
::
SyncSender
<
usize
>
,
},
Shutdown
,
}
...
...
@@ -357,15 +269,16 @@ impl ActiveSequencesMultiWorker {
UpdateSequences
::
AddRequest
{
request_id
,
token_sequence
,
isl
,
overlap
,
}
=>
{
active_sequences
.add_request
(
request_id
,
token_sequence
,
overlap
);
active_sequences
.add_request
(
request_id
,
token_sequence
,
isl
,
overlap
);
}
UpdateSequences
::
Free
{
request_id
}
=>
{
active_sequences
.free
(
&
request_id
);
}
UpdateSequences
::
Push
{
request_id
,
tokens
}
=>
{
active_sequences
.
push
(
&
request_id
,
&
tokens
);
// Changed to pass tokens slice
UpdateSequences
::
MarkPrefillCompleted
{
request_id
}
=>
{
active_sequences
.
mark_prefill_completed
(
&
request_id
);
}
UpdateSequences
::
NewBlocks
{
token_sequence
,
...
...
@@ -383,17 +296,25 @@ impl ActiveSequencesMultiWorker {
}
UpdateSequences
::
PotentialBlocksAndTokens
{
token_sequence
,
isl
,
overlap
,
resp_tx
,
}
=>
{
let
potential_tokens
=
active_sequences
.potential_blocks_and_tokens
(
&
token_sequence
,
overlap
);
let
potential_tokens
=
active_sequences
.potential_blocks_and_tokens
(
&
token_sequence
,
isl
,
overlap
,
);
let
_
=
resp_tx
.send
(
potential_tokens
);
}
UpdateSequences
::
ActiveBlocks
{
resp_tx
}
=>
{
let
active_blocks
=
active_sequences
.active_blocks
();
let
_
=
resp_tx
.send
(
active_blocks
);
}
UpdateSequences
::
ActiveTokens
{
resp_tx
}
=>
{
let
active_tokens
=
active_sequences
.active_tokens
();
let
_
=
resp_tx
.send
(
active_tokens
);
}
UpdateSequences
::
Shutdown
=>
{
break
;
}
...
...
@@ -443,7 +364,8 @@ impl ActiveSequencesMultiWorker {
pub
fn
add_request
(
&
mut
self
,
request_id
:
RequestId
,
token_sequence
:
TokenBlockSequence
,
token_sequence
:
Vec
<
SequenceHash
>
,
isl
:
usize
,
overlap
:
u32
,
worker_id
:
WorkerId
,
)
{
...
...
@@ -457,6 +379,7 @@ impl ActiveSequencesMultiWorker {
.send
(
UpdateSequences
::
AddRequest
{
request_id
,
token_sequence
,
isl
,
overlap
,
})
.expect
(
"Failed to send add_request command to worker"
);
...
...
@@ -478,18 +401,19 @@ impl ActiveSequencesMultiWorker {
self
.request_to_worker
.remove
(
request_id
);
}
pub
fn
push
(
&
mut
self
,
request_id
:
&
RequestId
,
tokens
:
&
[
u32
])
{
/// Mark prefill as completed for a request
pub
fn
mark_prefill_completed
(
&
mut
self
,
request_id
:
&
RequestId
)
{
let
worker_id
=
self
.request_to_worker
.get
(
request_id
)
.copied
()
.expect
(
"Request ID not found in request_to_worker mapping"
);
self
.senders
[
&
worker_id
]
.send
(
UpdateSequences
::
Push
{
.send
(
UpdateSequences
::
MarkPrefillCompleted
{
request_id
:
request_id
.clone
(),
tokens
:
tokens
.to_vec
(),
// Convert to Vec
})
.expect
(
"Failed to send
push
command to worker"
);
.expect
(
"Failed to send
mark_prefill_completed
command to worker"
);
}
/// Get the number of workers
...
...
@@ -500,8 +424,8 @@ impl ActiveSequencesMultiWorker {
/// Generic method to query all workers with a given command
fn
query_workers
(
&
self
,
token_sequence
:
Option
<
TokenBlock
Sequence
>
,
command_fn
:
impl
Fn
(
Option
<
Arc
<
TokenBlock
Sequence
>>
,
mpsc
::
SyncSender
<
usize
>
)
->
UpdateSequences
,
token_sequence
:
Option
<
Vec
<
Sequence
Hash
>
>
,
command_fn
:
impl
Fn
(
Option
<
Arc
<
Vec
<
Sequence
Hash
>
>>
,
mpsc
::
SyncSender
<
usize
>
)
->
UpdateSequences
,
)
->
HashMap
<
WorkerId
,
usize
>
{
let
mut
results
=
HashMap
::
new
();
let
token_sequence_shared
=
token_sequence
.map
(
Arc
::
new
);
...
...
@@ -528,7 +452,7 @@ impl ActiveSequencesMultiWorker {
}
/// Query all workers for the number of new blocks that would be added by a token sequence
pub
fn
new_blocks
(
&
self
,
token_sequence
:
TokenBlock
Sequence
)
->
HashMap
<
WorkerId
,
usize
>
{
pub
fn
new_blocks
(
&
self
,
token_sequence
:
Vec
<
Sequence
Hash
>
)
->
HashMap
<
WorkerId
,
usize
>
{
self
.query_workers
(
Some
(
token_sequence
),
|
ts
,
resp_tx
|
match
ts
{
Some
(
ts
)
=>
UpdateSequences
::
NewBlocks
{
token_sequence
:
ts
,
...
...
@@ -539,7 +463,7 @@ impl ActiveSequencesMultiWorker {
}
/// Query all workers for the total number of blocks (new + active) that would be used by a token sequence
pub
fn
potential_blocks
(
&
self
,
token_sequence
:
TokenBlock
Sequence
)
->
HashMap
<
WorkerId
,
usize
>
{
pub
fn
potential_blocks
(
&
self
,
token_sequence
:
Vec
<
Sequence
Hash
>
)
->
HashMap
<
WorkerId
,
usize
>
{
self
.query_workers
(
Some
(
token_sequence
),
|
ts
,
resp_tx
|
match
ts
{
Some
(
ts
)
=>
UpdateSequences
::
PotentialBlocks
{
token_sequence
:
ts
,
...
...
@@ -552,7 +476,8 @@ impl ActiveSequencesMultiWorker {
/// Query all workers for the potential tokens (new + active) that would be used by a token sequence with overlap
pub
fn
potential_blocks_and_tokens
(
&
self
,
token_sequence
:
TokenBlockSequence
,
token_sequence
:
Vec
<
SequenceHash
>
,
isl
:
usize
,
overlaps
:
OverlapScores
,
)
->
(
HashMap
<
WorkerId
,
usize
>
,
HashMap
<
WorkerId
,
usize
>
)
{
let
mut
potential_blocks
=
HashMap
::
new
();
...
...
@@ -568,6 +493,7 @@ impl ActiveSequencesMultiWorker {
sender
.send
(
UpdateSequences
::
PotentialBlocksAndTokens
{
token_sequence
:
token_sequence_shared
.clone
(),
isl
,
overlap
:
overlaps
.scores
.get
(
worker_id
)
.copied
()
.unwrap_or
(
0
),
resp_tx
,
})
...
...
@@ -590,6 +516,11 @@ impl ActiveSequencesMultiWorker {
pub
fn
active_blocks
(
&
self
)
->
HashMap
<
WorkerId
,
usize
>
{
self
.query_workers
(
None
,
|
_
,
resp_tx
|
UpdateSequences
::
ActiveBlocks
{
resp_tx
})
}
/// Query all workers for their current number of active tokens
pub
fn
active_tokens
(
&
self
)
->
HashMap
<
WorkerId
,
usize
>
{
self
.query_workers
(
None
,
|
_
,
resp_tx
|
UpdateSequences
::
ActiveTokens
{
resp_tx
})
}
}
impl
Drop
for
ActiveSequencesMultiWorker
{
...
...
@@ -609,91 +540,102 @@ impl Drop for ActiveSequencesMultiWorker {
#[cfg(test)]
mod
tests
{
use
super
::
*
;
use
crate
::
tokens
::
Tokens
;
#[test]
fn
test_shared_sequence_manager_operations
()
{
let
block_size
=
4
;
let
mut
manager
=
ActiveSequences
::
new
(
block_size
);
let
to_sequence
=
|
tokens
:
Vec
<
u32
>
|
Tokens
::
from
(
tokens
)
.into_sequence
(
block_size
as
u32
,
None
);
// Step 1: Add request 0 with tokens [0, 1, 2], then push 3 and 4
manager
.add_request
(
"0"
.to_string
(),
to_sequence
(
vec!
[
0
,
1
,
2
]),
0
);
manager
.push
(
&
"0"
.to_string
(),
&
[
3
,
4
]);
// Push both tokens at once
assert_eq!
(
manager
.active_tokens
(),
1
);
assert_eq!
(
manager
.active_blocks
(),
2
);
assert_eq!
(
manager
.partial_blocks
.len
(),
1
);
// Step 2: Add request 1 with tokens [0, 1, 2, 3, 4, 5, 6]
manager
.add_request
(
"1"
.to_string
(),
to_sequence
(
vec!
[
0
,
1
,
2
,
3
,
4
,
5
,
6
]),
1
);
assert_eq!
(
manager
.active_tokens
(),
1
+
3
);
assert_eq!
(
manager
.active_blocks
(),
3
);
// Check that only one key is FullBlock with both requests sharing it
let
mut
full_block_count
=
0
;
let
mut
shared_block_requests
=
None
;
for
(
block
,
requests
)
in
&
manager
.unique_blocks
{
if
let
UniqueBlock
::
FullBlock
(
_
)
=
block
{
full_block_count
+=
1
;
if
requests
.len
()
==
2
{
shared_block_requests
=
Some
(
requests
.clone
());
}
}
}
assert_eq!
(
full_block_count
,
1
);
assert
!
(
shared_block_requests
.is_some
());
let
shared_requests
=
shared_block_requests
.unwrap
();
assert
!
(
shared_requests
.contains
(
"0"
));
assert
!
(
shared_requests
.contains
(
"1"
));
let
new_blocks
=
manager
.new_blocks
(
&
to_sequence
(
vec!
[
0
,
1
,
2
,
3
,
4
,
5
]));
assert_eq!
(
new_blocks
,
1
);
// Step 3: Free request 1
manager
.free
(
&
"1"
.to_string
());
assert_eq!
(
manager
.active_blocks
(),
2
);
// Step 4: Free request 0
manager
.free
(
&
"0"
.to_string
());
assert_eq!
(
manager
.active_tokens
(),
0
);
assert_eq!
(
manager
.active_blocks
(),
0
);
assert_eq!
(
manager
.unique_blocks
.len
(),
0
);
assert_eq!
(
manager
.partial_blocks
.len
(),
0
);
assert_eq!
(
manager
.active_seqs
.len
(),
0
);
}
#[test]
fn
test_active_sequences_multi_worker
()
{
let
block_size
=
4
;
fn
test_multi_worker_block_sharing
()
{
// Create multi-worker sequence manager with 3 workers
let
block_size
=
4
;
// arbitrary block size
let
worker_ids
=
vec!
[
0
,
1
,
2
];
let
mut
manager
=
ActiveSequencesMultiWorker
::
new
(
block_size
,
worker_ids
);
let
to_sequence
=
|
tokens
:
Vec
<
u32
>
|
Tokens
::
from
(
tokens
)
.into_sequence
(
block_size
as
u32
,
None
);
// Send request [0, 1, 2, 3] to worker 0
manager
.add_request
(
"req0"
.to_string
(),
to_sequence
(
vec!
[
0
,
1
,
2
,
3
]),
0
,
0
);
// Send request [0, 1, 2] to worker 1, then push 3 and 4
manager
.add_request
(
"req1"
.to_string
(),
to_sequence
(
vec!
[
0
,
1
,
2
]),
0
,
1
);
manager
.push
(
&
"req1"
.to_string
(),
&
[
3
,
4
]);
// Push both tokens at once
// Send request [0, 1, 2] to worker 2
manager
.add_request
(
"req2"
.to_string
(),
to_sequence
(
vec!
[
0
,
1
,
2
]),
0
,
2
);
// Check new_blocks on tokens [0, 1, 2, 3, 4]
let
new_blocks_map
=
manager
.new_blocks
(
to_sequence
(
vec!
[
0
,
1
,
2
,
3
,
4
]));
assert_eq!
(
new_blocks_map
[
&
0
],
1
);
// Worker 0 would have 1 new block
assert_eq!
(
new_blocks_map
[
&
1
],
1
);
// Worker 1 would have 1 new block
assert_eq!
(
new_blocks_map
[
&
2
],
2
);
// Worker 2 would have 2 new blocks
manager
.update_workers
(
vec!
[
0
,
1
]);
manager
.update_workers
(
vec!
[
0
,
1
,
3
]);
let
new_blocks_map
=
manager
.new_blocks
(
to_sequence
(
vec!
[
0
,
1
,
2
,
3
,
4
]));
assert_eq!
(
new_blocks_map
.len
(),
3
);
assert_eq!
(
new_blocks_map
[
&
3
],
2
);
let
mut
seq_manager
=
ActiveSequencesMultiWorker
::
new
(
block_size
,
worker_ids
);
// Add requests to each worker
// Worker 0: sequence [0, 1, 2]
seq_manager
.add_request
(
"request_0"
.to_string
(),
vec!
[
0
,
1
,
2
],
12
,
// ISL (3 blocks * 4 block_size)
0
,
// no overlap
0
,
// worker_id
);
// Worker 1: sequence [3, 4]
seq_manager
.add_request
(
"request_1"
.to_string
(),
vec!
[
3
,
4
],
8
,
// ISL (2 blocks * 4 block_size)
0
,
// no overlap
1
,
// worker_id
);
// Worker 2: sequence [0, 1, 2, 3]
seq_manager
.add_request
(
"request_2"
.to_string
(),
vec!
[
0
,
1
,
2
,
3
],
16
,
// ISL (4 blocks * 4 block_size)
0
,
// no overlap
2
,
// worker_id
);
// Verify active tokens after adding requests
let
tokens_after_add
=
seq_manager
.active_tokens
();
assert_eq!
(
tokens_after_add
[
&
0
],
12
,
"Worker 0 should have 12 active tokens"
);
assert_eq!
(
tokens_after_add
[
&
1
],
8
,
"Worker 1 should have 8 active tokens"
);
assert_eq!
(
tokens_after_add
[
&
2
],
16
,
"Worker 2 should have 16 active tokens"
);
// Test potential blocks for sequence [0, 1]
let
potential_blocks
=
seq_manager
.potential_blocks
(
vec!
[
0
,
1
]);
// Worker 0 should return 3 (already has blocks 0, 1, 2, so no new blocks needed for [0, 1])
assert_eq!
(
potential_blocks
[
&
0
],
3
,
"Worker 0 should have 3 potential blocks"
);
// Worker 1 should return 4 (has blocks 3, 4, would need to add blocks 0, 1)
assert_eq!
(
potential_blocks
[
&
1
],
4
,
"Worker 1 should have 4 potential blocks"
);
// Worker 2 should return 4 (already has blocks 0, 1, 2, 3, so no new blocks needed for [0, 1])
assert_eq!
(
potential_blocks
[
&
2
],
4
,
"Worker 2 should have 4 potential blocks"
);
// Free all original requests
seq_manager
.free
(
&
"request_0"
.to_string
());
seq_manager
.free
(
&
"request_1"
.to_string
());
seq_manager
.free
(
&
"request_2"
.to_string
());
// Verify active blocks are zero for all workers
let
active_blocks
=
seq_manager
.active_blocks
();
assert_eq!
(
active_blocks
[
&
0
],
0
,
"Worker 0 should have 0 active blocks"
);
assert_eq!
(
active_blocks
[
&
1
],
0
,
"Worker 1 should have 0 active blocks"
);
assert_eq!
(
active_blocks
[
&
2
],
0
,
"Worker 2 should have 0 active blocks"
);
// Verify active tokens are zero for all workers
let
final_tokens
=
seq_manager
.active_tokens
();
assert_eq!
(
final_tokens
[
&
0
],
0
,
"Worker 0 should have 0 active tokens after freeing all"
);
assert_eq!
(
final_tokens
[
&
1
],
0
,
"Worker 1 should have 0 active tokens after freeing all"
);
assert_eq!
(
final_tokens
[
&
2
],
0
,
"Worker 2 should have 0 active tokens after freeing all"
);
}
}
lib/llm/src/kv_router/worker.rs
deleted
100644 → 0
View file @
dbd33df6
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use
std
::
sync
::
Arc
;
pub
use
crate
::
kv_router
::
protocols
::
ForwardPassMetrics
;
use
anyhow
::
Result
;
use
derive_builder
::
Builder
;
use
dynamo_runtime
::
pipeline
::
network
::{
ingress
::
push_endpoint
::
PushEndpoint
,
PushWorkHandler
,
};
use
dynamo_runtime
::
transports
::
nats
::{
self
,
ServiceExt
};
use
tokio
::
sync
::
watch
;
use
tokio_util
::
sync
::
CancellationToken
;
use
tracing
as
log
;
/// Everything `KvRoutedIngress::start` needs to bring up a NATS service with a single
/// push endpoint; constructed through the derived builder.
#[derive(Builder)]
pub struct KvRoutedIngress {
    /// Name passed to the NATS service on start; also used as the service group name.
    #[builder(setter(into))]
    pub service_name: String,

    /// Name of the endpoint created in the service group; also attached to log events.
    #[builder(setter(into))]
    pub worker_id: String,

    /// NATS client used to build the service.
    pub nats: nats::Client,

    /// Handler handed to the `PushEndpoint` builder.
    pub service_handler: Arc<dyn PushWorkHandler>,

    /// Watch channel whose current value is serialized to JSON by the stats handler.
    pub metrics_rx: watch::Receiver<Arc<ForwardPassMetrics>>,

    /// Cancellation token handed to the `PushEndpoint` builder.
    pub cancellation_token: CancellationToken,
}
/// Version of this crate, read from the `CARGO_PKG_VERSION` environment variable at compile time.
pub
const
VERSION
:
&
str
=
env!
(
"CARGO_PKG_VERSION"
);
impl KvRoutedIngress {
    /// Returns a default-initialized `KvRoutedIngressBuilder`.
    pub fn builder() -> KvRoutedIngressBuilder {
        KvRoutedIngressBuilder::default()
    }

    /// Starts the NATS service and its push endpoint, consuming `self`.
    ///
    /// Steps, in order:
    /// 1. Start a NATS service named `service_name` at `VERSION`, with a stats handler that
    ///    returns the current `metrics_rx` value serialized as JSON.
    /// 2. Create a group named `service_name` and an endpoint named `worker_id` inside it.
    /// 3. Build a `PushEndpoint` from `service_handler` and `cancellation_token`, then start
    ///    it on that endpoint and await the result.
    ///
    /// # Errors
    ///
    /// Returns an error if the service, the endpoint, or the push endpoint builder fails,
    /// or whatever error `PushEndpoint::start` yields.
    ///
    /// # Panics
    ///
    /// The stats handler panics if the metrics value cannot be serialized to JSON.
    pub async fn start(self) -> Result<()> {
        let worker_id = self.worker_id;

        log::trace!(
            worker_id,
            "Starting nats service: {}:{}",
            self.service_name,
            VERSION
        );

        // The stats handler is a `move` closure, so it gets its own receiver and id copy.
        let mut metrics_rx = self.metrics_rx;
        let stats_worker_id = worker_id.clone();

        let on_stats = move |name, stats| {
            log::debug!(
                worker_id = stats_worker_id.as_str(),
                "[IN worker?] Stats for service {}: {:?}",
                name,
                stats
            );
            let snapshot = metrics_rx.borrow_and_update().clone();
            serde_json::to_value(&*snapshot).unwrap()
        };

        let service = self
            .nats
            .client()
            .service_builder()
            .description("A handy min max service")
            .stats_handler(on_stats)
            .start(self.service_name.as_str(), VERSION)
            .await
            .map_err(|e| anyhow::anyhow!("Failed to start service: {e}"))?;

        let group = service.group(self.service_name.as_str());

        log::trace!(worker_id, "Starting endpoint: {}", worker_id);

        // creates an endpoint for the service
        let endpoint = group
            .endpoint(worker_id.clone())
            .await
            .map_err(|e| anyhow::anyhow!("Failed to start endpoint: {e}"))?;

        let ingress = PushEndpoint::builder()
            .service_handler(self.service_handler)
            .cancellation_token(self.cancellation_token)
            .build()
            .map_err(|e| anyhow::anyhow!("Failed to build push endpoint: {e}"))?;

        ingress.start(endpoint).await
    }
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment