Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
df91fce2
Unverified
Commit
df91fce2
authored
Jul 14, 2025
by
Yan Ru Pei
Committed by
GitHub
Jul 14, 2025
Browse files
feat: prefill aware routing (#1895)
parent
ad8ad66b
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
169 additions
and
58 deletions
+169
-58
components/metrics/src/bin/mock_worker.rs
components/metrics/src/bin/mock_worker.rs
+1
-1
components/metrics/src/main.rs
components/metrics/src/main.rs
+1
-1
docs/guides/dynamo_run.md
docs/guides/dynamo_run.md
+1
-1
launch/dynamo-run/src/flags.rs
launch/dynamo-run/src/flags.rs
+2
-2
lib/llm/src/kv_router.rs
lib/llm/src/kv_router.rs
+11
-3
lib/llm/src/kv_router/protocols.rs
lib/llm/src/kv_router/protocols.rs
+1
-1
lib/llm/src/kv_router/scheduler.rs
lib/llm/src/kv_router/scheduler.rs
+39
-42
lib/llm/src/kv_router/sequence.rs
lib/llm/src/kv_router/sequence.rs
+113
-7
No files found.
components/metrics/src/bin/mock_worker.rs
View file @
df91fce2
...
@@ -93,7 +93,7 @@ async fn mock_event_publisher(namespace: Namespace) {
...
@@ -93,7 +93,7 @@ async fn mock_event_publisher(namespace: Namespace) {
let
event
=
KVHitRateEvent
{
let
event
=
KVHitRateEvent
{
worker_id
,
worker_id
,
isl_blocks
,
isl_blocks
,
overlap_blocks
,
overlap_blocks
:
overlap_blocks
as
u32
,
};
};
if
let
Err
(
e
)
=
namespace
.publish
(
KV_HIT_RATE_SUBJECT
,
&
event
)
.await
{
if
let
Err
(
e
)
=
namespace
.publish
(
KV_HIT_RATE_SUBJECT
,
&
event
)
.await
{
...
...
components/metrics/src/main.rs
View file @
df91fce2
...
@@ -199,7 +199,7 @@ async fn app(runtime: Runtime) -> Result<()> {
...
@@ -199,7 +199,7 @@ async fn app(runtime: Runtime) -> Result<()> {
&
config_clone
,
&
config_clone
,
event
.worker_id
,
event
.worker_id
,
event
.isl_blocks
,
event
.isl_blocks
,
event
.overlap_blocks
,
event
.overlap_blocks
as
usize
,
);
);
}
}
Err
(
e
)
=>
{
Err
(
e
)
=>
{
...
...
docs/guides/dynamo_run.md
View file @
df91fce2
...
@@ -8,7 +8,7 @@ It supports these engines: mistralrs, llamacpp, sglang, vllm, and tensorrt-llm.
...
@@ -8,7 +8,7 @@ It supports these engines: mistralrs, llamacpp, sglang, vllm, and tensorrt-llm.
Usage:
Usage:
```
```
dynamo-run in=[http|text|dyn://<path>|batch:<folder>] out=echo_core|echo_full|mistralrs|llamacpp|sglang|vllm|dyn [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--model-config <hf-repo>] [--tensor-parallel-size=1] [--context-length=N] [--num-nodes=1] [--node-rank=0] [--leader-addr=127.0.0.1:9876] [--base-gpu-id=0] [--extra-engine-args=args.json] [--router-mode random|round-robin|kv] [--kv-overlap-score-weight=1.0] [--router-temperature=0.
5
] [--use-kv-events=true] [--verbosity (-v|-vv)]
dynamo-run in=[http|text|dyn://<path>|batch:<folder>] out=echo_core|echo_full|mistralrs|llamacpp|sglang|vllm|dyn [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--model-config <hf-repo>] [--tensor-parallel-size=1] [--context-length=N] [--num-nodes=1] [--node-rank=0] [--leader-addr=127.0.0.1:9876] [--base-gpu-id=0] [--extra-engine-args=args.json] [--router-mode random|round-robin|kv] [--kv-overlap-score-weight=1.0] [--router-temperature=0.
0
] [--use-kv-events=true] [--verbosity (-v|-vv)]
```
```
Example:
`dynamo run Qwen/Qwen3-0.6B`
Example:
`dynamo run Qwen/Qwen3-0.6B`
...
...
launch/dynamo-run/src/flags.rs
View file @
df91fce2
...
@@ -118,13 +118,13 @@ pub struct Flags {
...
@@ -118,13 +118,13 @@ pub struct Flags {
pub
max_num_batched_tokens
:
Option
<
u32
>
,
pub
max_num_batched_tokens
:
Option
<
u32
>
,
/// KV Router: Weight for overlap score in worker selection.
/// KV Router: Weight for overlap score in worker selection.
/// Higher values prioritize KV cache reuse. Default:
2
.0
/// Higher values prioritize KV cache reuse. Default:
1
.0
#[arg(long)]
#[arg(long)]
pub
kv_overlap_score_weight
:
Option
<
f64
>
,
pub
kv_overlap_score_weight
:
Option
<
f64
>
,
/// KV Router: Temperature for worker sampling via softmax.
/// KV Router: Temperature for worker sampling via softmax.
/// Higher values promote more randomness, and 0 fallbacks to deterministic.
/// Higher values promote more randomness, and 0 fallbacks to deterministic.
/// Default: 0.
5
/// Default: 0.
0
#[arg(long)]
#[arg(long)]
pub
router_temperature
:
Option
<
f64
>
,
pub
router_temperature
:
Option
<
f64
>
,
...
...
lib/llm/src/kv_router.rs
View file @
df91fce2
...
@@ -78,7 +78,7 @@ impl Default for KvRouterConfig {
...
@@ -78,7 +78,7 @@ impl Default for KvRouterConfig {
fn
default
()
->
Self
{
fn
default
()
->
Self
{
Self
{
Self
{
overlap_score_weight
:
1.0
,
overlap_score_weight
:
1.0
,
router_temperature
:
0.
5
,
router_temperature
:
0.
0
,
use_kv_events
:
true
,
use_kv_events
:
true
,
max_num_batched_tokens
:
8192
,
max_num_batched_tokens
:
8192
,
}
}
...
@@ -337,6 +337,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
...
@@ -337,6 +337,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
let
mut
accumulated_tokens
=
Vec
::
new
();
let
mut
accumulated_tokens
=
Vec
::
new
();
let
mut
total_output_length
=
0u
size
;
let
mut
total_output_length
=
0u
size
;
let
mut
last_block_index
=
(
isl
.saturating_sub
(
1
))
/
block_size
;
let
mut
last_block_index
=
(
isl
.saturating_sub
(
1
))
/
block_size
;
let
mut
first_push_done
=
false
;
while
let
Some
(
item
)
=
response_stream
.next
()
.await
{
while
let
Some
(
item
)
=
response_stream
.next
()
.await
{
// Track tokens if they exist in the response
// Track tokens if they exist in the response
...
@@ -353,12 +354,19 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
...
@@ -353,12 +354,19 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
accumulated_tokens
.extend_from_slice
(
&
output
.token_ids
);
accumulated_tokens
.extend_from_slice
(
&
output
.token_ids
);
total_output_length
+=
output
.token_ids
.len
();
total_output_length
+=
output
.token_ids
.len
();
// Check if we've moved to a new block
// Always push for the first generated token (to mark prefill done)
// or when we've moved to a new block
let
current_block_index
=
(
isl
+
total_output_length
)
.saturating_sub
(
1
)
/
block_size
;
let
current_block_index
=
(
isl
+
total_output_length
)
.saturating_sub
(
1
)
/
block_size
;
if
current_block_index
>
last_block_index
{
let
should_push
=
(
!
first_push_done
&&
total_output_length
>=
1
)
||
(
first_push_done
&&
current_block_index
>
last_block_index
);
if
should_push
{
chooser
.push
(
&
request_id
,
&
accumulated_tokens
)
.await
;
chooser
.push
(
&
request_id
,
&
accumulated_tokens
)
.await
;
accumulated_tokens
.clear
();
accumulated_tokens
.clear
();
last_block_index
=
current_block_index
;
last_block_index
=
current_block_index
;
if
!
first_push_done
{
first_push_done
=
true
;
}
}
}
yield
item
;
yield
item
;
...
...
lib/llm/src/kv_router/protocols.rs
View file @
df91fce2
...
@@ -36,7 +36,7 @@ pub struct WorkerSelectionResult {
...
@@ -36,7 +36,7 @@ pub struct WorkerSelectionResult {
/// The number of blocks that the selected worker may already have cached.
/// The number of blocks that the selected worker may already have cached.
/// This is not a guarantee, but an estimate.
/// This is not a guarantee, but an estimate.
pub
overlap_blocks
:
u
size
,
pub
overlap_blocks
:
u
32
,
}
}
#[derive(Debug,
Clone,
Serialize,
Deserialize,
Default,
PartialEq)]
#[derive(Debug,
Clone,
Serialize,
Deserialize,
Default,
PartialEq)]
...
...
lib/llm/src/kv_router/scheduler.rs
View file @
df91fce2
...
@@ -25,7 +25,6 @@ use tokio::sync::Mutex;
...
@@ -25,7 +25,6 @@ use tokio::sync::Mutex;
use
super
::
protocols
::
WorkerSelectionResult
;
use
super
::
protocols
::
WorkerSelectionResult
;
use
super
::
WorkerSelector
;
use
super
::
WorkerSelector
;
use
crate
::
kv_router
::
indexer
::
OverlapScores
;
use
crate
::
kv_router
::
indexer
::
OverlapScores
;
use
crate
::
kv_router
::
indexer
::
WorkerId
;
use
crate
::
kv_router
::
protocols
::
LoadMetrics
;
use
crate
::
kv_router
::
protocols
::
LoadMetrics
;
use
crate
::
kv_router
::
scoring
::
ProcessedEndpoints
;
use
crate
::
kv_router
::
scoring
::
ProcessedEndpoints
;
use
crate
::
kv_router
::
sequence
::
ActiveSequencesMultiWorker
;
use
crate
::
kv_router
::
sequence
::
ActiveSequencesMultiWorker
;
...
@@ -37,7 +36,7 @@ use crate::tokens::TokenBlockSequence;
...
@@ -37,7 +36,7 @@ use crate::tokens::TokenBlockSequence;
pub
struct
KVHitRateEvent
{
pub
struct
KVHitRateEvent
{
pub
worker_id
:
i64
,
pub
worker_id
:
i64
,
pub
isl_blocks
:
usize
,
pub
isl_blocks
:
usize
,
pub
overlap_blocks
:
u
size
,
pub
overlap_blocks
:
u
32
,
}
}
#[derive(Debug,
thiserror::Error)]
#[derive(Debug,
thiserror::Error)]
...
@@ -79,13 +78,15 @@ impl Endpoint {
...
@@ -79,13 +78,15 @@ impl Endpoint {
#[derive(Debug)]
#[derive(Debug)]
pub
struct
SchedulingResponse
{
pub
struct
SchedulingResponse
{
pub
best_worker_id
:
i64
,
pub
best_worker_id
:
i64
,
pub
overlap_blocks
:
u32
,
// Add this field
pub
endpoints_changed
:
Option
<
Vec
<
i64
>>
,
pub
endpoints_changed
:
Option
<
Vec
<
i64
>>
,
}
}
pub
struct
SchedulingRequest
{
pub
struct
SchedulingRequest
{
pub
isl_tokens
:
usize
,
pub
isl_tokens
:
usize
,
pub
overlap
:
OverlapScores
,
pub
overlap
s
:
OverlapScores
,
pub
potential_blocks
:
HashMap
<
i64
,
usize
>
,
pub
potential_blocks
:
HashMap
<
i64
,
usize
>
,
pub
potential_tokens
:
HashMap
<
i64
,
usize
>
,
resp_tx
:
tokio
::
sync
::
oneshot
::
Sender
<
SchedulingResponse
>
,
resp_tx
:
tokio
::
sync
::
oneshot
::
Sender
<
SchedulingResponse
>
,
}
}
...
@@ -174,6 +175,7 @@ impl KvScheduler {
...
@@ -174,6 +175,7 @@ impl KvScheduler {
let
response
=
SchedulingResponse
{
let
response
=
SchedulingResponse
{
best_worker_id
:
selection
.worker_id
,
best_worker_id
:
selection
.worker_id
,
overlap_blocks
:
selection
.overlap_blocks
,
endpoints_changed
:
pending_endpoint_update
.take
(),
endpoints_changed
:
pending_endpoint_update
.take
(),
};
};
request
.respond
(
response
);
request
.respond
(
response
);
...
@@ -207,18 +209,20 @@ impl KvScheduler {
...
@@ -207,18 +209,20 @@ impl KvScheduler {
isl_tokens
:
usize
,
isl_tokens
:
usize
,
block_size
:
u32
,
block_size
:
u32
,
tokens
:
&
[
u32
],
tokens
:
&
[
u32
],
overlap
:
OverlapScores
,
overlap
s
:
OverlapScores
,
)
->
Result
<
i64
,
KvSchedulerError
>
{
)
->
Result
<
i64
,
KvSchedulerError
>
{
let
mut
sequences
=
self
.sequences
.lock
()
.await
;
let
mut
sequences
=
self
.sequences
.lock
()
.await
;
let
token_sequence
=
TokenBlockSequence
::
from_slice
(
tokens
,
block_size
,
None
);
let
token_sequence
=
TokenBlockSequence
::
from_slice
(
tokens
,
block_size
,
None
);
let
potential_blocks
=
sequences
.potential_blocks
(
token_sequence
);
let
(
potential_blocks
,
potential_tokens
)
=
sequences
.potential_blocks_and_tokens
(
token_sequence
,
overlaps
.clone
());
let
(
resp_tx
,
resp_rx
)
=
tokio
::
sync
::
oneshot
::
channel
();
let
(
resp_tx
,
resp_rx
)
=
tokio
::
sync
::
oneshot
::
channel
();
let
request
=
SchedulingRequest
{
let
request
=
SchedulingRequest
{
isl_tokens
,
isl_tokens
,
overlap
,
overlap
s
,
potential_blocks
,
potential_blocks
,
potential_tokens
,
resp_tx
,
resp_tx
,
};
};
self
.request_tx
self
.request_tx
...
@@ -234,31 +238,16 @@ impl KvScheduler {
...
@@ -234,31 +238,16 @@ impl KvScheduler {
}
}
let
token_sequence
=
TokenBlockSequence
::
from_slice
(
tokens
,
block_size
,
None
);
let
token_sequence
=
TokenBlockSequence
::
from_slice
(
tokens
,
block_size
,
None
);
sequences
.add_request
(
request_id
,
token_sequence
,
response
.best_worker_id
);
sequences
.add_request
(
request_id
,
token_sequence
,
response
.overlap_blocks
,
response
.best_worker_id
,
);
Ok
(
response
.best_worker_id
)
Ok
(
response
.best_worker_id
)
}
}
/// Find the potential blocks for each worker if the sequence were routed there
pub
async
fn
potential_blocks
(
&
self
,
token_sequence
:
TokenBlockSequence
,
)
->
HashMap
<
i64
,
usize
>
{
let
sequences
=
self
.sequences
.lock
()
.await
;
sequences
.potential_blocks
(
token_sequence
)
}
/// Add a new request with its initial tokens to a specific worker
pub
async
fn
add_request
(
&
self
,
request_id
:
String
,
token_sequence
:
TokenBlockSequence
,
worker_id
:
WorkerId
,
)
{
let
mut
sequences
=
self
.sequences
.lock
()
.await
;
sequences
.add_request
(
request_id
,
token_sequence
,
worker_id
)
}
/// Push tokens to a specific request's sequence
/// Push tokens to a specific request's sequence
pub
async
fn
push
(
&
self
,
request_id
:
&
String
,
tokens
:
&
[
u32
])
{
pub
async
fn
push
(
&
self
,
request_id
:
&
String
,
tokens
:
&
[
u32
])
{
let
mut
sequences
=
self
.sequences
.lock
()
.await
;
let
mut
sequences
=
self
.sequences
.lock
()
.await
;
...
@@ -370,34 +359,47 @@ impl WorkerSelector for DefaultWorkerSelector {
...
@@ -370,34 +359,47 @@ impl WorkerSelector for DefaultWorkerSelector {
return
Err
(
KvSchedulerError
::
NoEndpoints
);
return
Err
(
KvSchedulerError
::
NoEndpoints
);
}
}
let
request_blocks
=
request
.isl_tokens
.div_ceil
(
block_size
as
usize
);
let
isl
=
request
.isl_tokens
;
let
request_blocks
=
isl
.div_ceil
(
block_size
as
usize
);
let
overlaps
=
&
request
.overlaps.scores
;
// active blocks for decoding
let
potential_active_blocks
=
&
request
.potential_blocks
;
let
potential_active_blocks
=
&
request
.potential_blocks
;
// active tokens in the batch (processed by the linear layers), mostly prefill tokens
let
potential_active_tokens
=
&
request
.potential_tokens
;
let
mut
worker_logits
=
HashMap
::
new
();
let
mut
worker_logits
=
HashMap
::
new
();
let
mut
max_logit
=
f64
::
NEG_INFINITY
;
let
mut
max_logit
=
f64
::
NEG_INFINITY
;
// Calculate logits for each worker
// Calculate logits for each worker
for
(
worker_id
,
_
)
in
workers
.endpoints
.iter
()
{
for
(
worker_id
,
_
)
in
workers
.endpoints
.iter
()
{
let
cached_blocks
=
request
.overlap.scores
.get
(
worker_id
)
.copied
()
.unwrap_or
(
0
)
as
f64
;
// this is the number of tokens each worker would have if the request were scheduled there
let
prefill_blocks
=
request_blocks
as
f64
-
cached_blocks
;
let
potential_tokens
=
*
potential_active_tokens
.get
(
worker_id
)
.unwrap_or_else
(||
{
tracing
::
warn!
(
"assuming {isl} tokens for {worker_id}, as the endpoint does not exist yet"
);
&
isl
})
as
f64
;
// this is the number of blocks each worker would have if the request were scheduled there
// this is the number of blocks each worker would have if the request were scheduled there
let
potential_blocks
=
*
potential_active_blocks
.get
(
worker_id
)
.unwrap_or_else
(||
let
potential_blocks
=
*
potential_active_blocks
.get
(
worker_id
)
.unwrap_or_else
(||
{
tracing
::
warn!
(
"assuming
0
decoding blocks for {worker_id}, as the
load metrics
endpoint does not exist yet"
);
{
tracing
::
warn!
(
"assuming
{request_blocks}
decoding blocks for {worker_id}, as the endpoint does not exist yet"
);
&
0
&
request_blocks
})
as
f64
;
})
as
f64
;
let
potential_prefill_blocks
=
potential_tokens
/
(
block_size
as
f64
);
// Calculate logit (lower is better)
// Calculate logit (lower is better)
let
logit
=
let
logit
=
self
.kv_router_config.overlap_score_weight
*
potential_prefill_blocks
self
.kv_router_config.overlap_score_weight
*
prefill_blocks
+
potential_blocks
;
+
potential_blocks
;
max_logit
=
max_logit
.max
(
logit
);
max_logit
=
max_logit
.max
(
logit
);
worker_logits
.insert
(
*
worker_id
,
logit
);
worker_logits
.insert
(
*
worker_id
,
logit
);
tracing
::
info!
(
tracing
::
info!
(
"Formula for {worker_id}: {logit:.3} = {:.1} * {prefill_blocks:.3} + {potential_blocks:.3} (cached_blocks: {
cached_blocks
})"
,
"Formula for {worker_id}: {logit:.3} = {:.1} * {
potential_
prefill_blocks:.3} + {potential_blocks:.3} (cached_blocks: {})"
,
self
.kv_router_config.overlap_score_weight
,
self
.kv_router_config.overlap_score_weight
,
cached_blocks
=
cached_blocks
overlaps
.get
(
worker_id
)
.unwrap_or
(
&
0
),
);
);
}
}
...
@@ -412,12 +414,7 @@ impl WorkerSelector for DefaultWorkerSelector {
...
@@ -412,12 +414,7 @@ impl WorkerSelector for DefaultWorkerSelector {
let
temperature
=
self
.kv_router_config.router_temperature
;
let
temperature
=
self
.kv_router_config.router_temperature
;
let
best_worker_id
=
softmax_sample
(
&
worker_logits
,
temperature
);
let
best_worker_id
=
softmax_sample
(
&
worker_logits
,
temperature
);
let
overlap_blocks
=
request
let
overlap_blocks
=
overlaps
.get
(
&
best_worker_id
)
.copied
()
.unwrap_or
(
0
);
.overlap
.scores
.get
(
&
best_worker_id
)
.copied
()
.unwrap_or
(
0
)
as
usize
;
let
best_logit
=
worker_logits
[
&
best_worker_id
];
let
best_logit
=
worker_logits
[
&
best_worker_id
];
tracing
::
info!
(
tracing
::
info!
(
...
...
lib/llm/src/kv_router/sequence.rs
View file @
df91fce2
...
@@ -34,6 +34,7 @@
...
@@ -34,6 +34,7 @@
//! Each block is identified by a hash of its contents, allowing for deduplication when multiple
//! Each block is identified by a hash of its contents, allowing for deduplication when multiple
//! requests share common prefixes (e.g., system prompts, few-shot examples).
//! requests share common prefixes (e.g., system prompts, few-shot examples).
use
crate
::
kv_router
::
indexer
::
OverlapScores
;
use
crate
::
kv_router
::
indexer
::
WorkerId
;
use
crate
::
kv_router
::
indexer
::
WorkerId
;
use
crate
::
tokens
::
blocks
::
UniqueBlock
;
use
crate
::
tokens
::
blocks
::
UniqueBlock
;
use
crate
::
tokens
::
TokenBlockSequence
;
use
crate
::
tokens
::
TokenBlockSequence
;
...
@@ -76,6 +77,8 @@ pub struct ActiveSequences {
...
@@ -76,6 +77,8 @@ pub struct ActiveSequences {
partial_blocks
:
HashMap
<
RequestId
,
UniqueBlock
>
,
partial_blocks
:
HashMap
<
RequestId
,
UniqueBlock
>
,
prefill_tokens
:
HashMap
<
RequestId
,
usize
>
,
unique_blocks
:
HashMap
<
UniqueBlock
,
HashSet
<
RequestId
>>
,
unique_blocks
:
HashMap
<
UniqueBlock
,
HashSet
<
RequestId
>>
,
#[getter(copy)]
#[getter(copy)]
...
@@ -83,6 +86,9 @@ pub struct ActiveSequences {
...
@@ -83,6 +86,9 @@ pub struct ActiveSequences {
#[getter(copy)]
#[getter(copy)]
active_blocks
:
usize
,
active_blocks
:
usize
,
#[getter(copy)]
active_tokens
:
usize
,
}
}
impl
ActiveSequences
{
impl
ActiveSequences
{
...
@@ -94,9 +100,11 @@ impl ActiveSequences {
...
@@ -94,9 +100,11 @@ impl ActiveSequences {
Self
{
Self
{
active_seqs
:
HashMap
::
new
(),
active_seqs
:
HashMap
::
new
(),
partial_blocks
:
HashMap
::
new
(),
partial_blocks
:
HashMap
::
new
(),
prefill_tokens
:
HashMap
::
new
(),
unique_blocks
:
HashMap
::
new
(),
unique_blocks
:
HashMap
::
new
(),
block_size
,
block_size
,
active_blocks
:
0
,
active_blocks
:
0
,
active_tokens
:
0
,
}
}
}
}
...
@@ -135,7 +143,13 @@ impl ActiveSequences {
...
@@ -135,7 +143,13 @@ impl ActiveSequences {
&
mut
self
,
&
mut
self
,
request_id
:
RequestId
,
request_id
:
RequestId
,
token_sequence
:
TokenBlockSequence
,
token_sequence
:
TokenBlockSequence
,
overlap
:
u32
,
)
->
usize
{
)
->
usize
{
let
prefill_tokens
=
self
.new_tokens
(
&
token_sequence
,
overlap
);
self
.prefill_tokens
.insert
(
request_id
.clone
(),
prefill_tokens
);
self
.active_tokens
+=
prefill_tokens
;
let
blocks
=
create_unique_blocks_from_sequence
(
&
token_sequence
,
None
,
self
.block_size
);
let
blocks
=
create_unique_blocks_from_sequence
(
&
token_sequence
,
None
,
self
.block_size
);
for
block
in
&
blocks
{
for
block
in
&
blocks
{
...
@@ -147,6 +161,25 @@ impl ActiveSequences {
...
@@ -147,6 +161,25 @@ impl ActiveSequences {
self
.active_blocks
self
.active_blocks
}
}
pub
fn
new_tokens
(
&
self
,
token_sequence
:
&
TokenBlockSequence
,
overlap
:
u32
)
->
usize
{
let
input_tokens
=
token_sequence
.total_tokens
();
input_tokens
.checked_sub
((
overlap
as
usize
)
*
self
.block_size
)
.unwrap_or_else
(||
{
panic!
(
"prefill_tokens < 0 with overlap {overlap} and ISL {input_tokens}"
)
})
}
pub
fn
potential_blocks_and_tokens
(
&
self
,
token_sequence
:
&
TokenBlockSequence
,
overlap
:
u32
,
)
->
(
usize
,
usize
)
{
let
potential_blocks
=
self
.new_blocks
(
token_sequence
)
+
self
.active_blocks
;
let
potential_tokens
=
self
.new_tokens
(
token_sequence
,
overlap
)
+
self
.active_tokens
;
(
potential_blocks
,
potential_tokens
)
}
/// Match a request against existing blocks and return the number of new blocks that would be added
/// Match a request against existing blocks and return the number of new blocks that would be added
pub
fn
new_blocks
(
&
self
,
token_sequence
:
&
TokenBlockSequence
)
->
usize
{
pub
fn
new_blocks
(
&
self
,
token_sequence
:
&
TokenBlockSequence
)
->
usize
{
let
blocks
=
create_unique_blocks_from_sequence
(
token_sequence
,
None
,
self
.block_size
);
let
blocks
=
create_unique_blocks_from_sequence
(
token_sequence
,
None
,
self
.block_size
);
...
@@ -165,6 +198,12 @@ impl ActiveSequences {
...
@@ -165,6 +198,12 @@ impl ActiveSequences {
/// Free all blocks associated with a request
/// Free all blocks associated with a request
pub
fn
free
(
&
mut
self
,
request_id
:
&
RequestId
)
->
usize
{
pub
fn
free
(
&
mut
self
,
request_id
:
&
RequestId
)
->
usize
{
// decoding has one active token
self
.active_tokens
=
self
.active_tokens
.checked_sub
(
self
.prefill_tokens
.remove
(
request_id
)
.unwrap_or
(
1
))
.expect
(
"active_tokens < 0"
);
let
Some
(
token_seq
)
=
self
.active_seqs
.get
(
request_id
)
else
{
let
Some
(
token_seq
)
=
self
.active_seqs
.get
(
request_id
)
else
{
tracing
::
warn!
(
"Trying to free free non-existent request {request_id}"
);
tracing
::
warn!
(
"Trying to free free non-existent request {request_id}"
);
return
0
;
return
0
;
...
@@ -187,6 +226,16 @@ impl ActiveSequences {
...
@@ -187,6 +226,16 @@ impl ActiveSequences {
/// Push tokens to a specific request's sequence
/// Push tokens to a specific request's sequence
pub
fn
push
(
&
mut
self
,
request_id
:
&
RequestId
,
tokens
:
&
[
u32
])
->
usize
{
pub
fn
push
(
&
mut
self
,
request_id
:
&
RequestId
,
tokens
:
&
[
u32
])
->
usize
{
if
let
Some
(
prefill_tokens
)
=
self
.prefill_tokens
.get
(
request_id
)
.cloned
()
{
self
.prefill_tokens
.remove
(
request_id
);
// decoding has one active token
self
.active_tokens
=
self
.active_tokens
.checked_sub
(
prefill_tokens
)
.expect
(
"active_tokens < 0"
)
+
1
;
};
// Collect operations to perform after releasing the borrow
// Collect operations to perform after releasing the borrow
let
mut
blocks_to_remove
=
Vec
::
new
();
let
mut
blocks_to_remove
=
Vec
::
new
();
let
mut
blocks_to_add
=
Vec
::
new
();
let
mut
blocks_to_add
=
Vec
::
new
();
...
@@ -239,6 +288,7 @@ enum UpdateSequences {
...
@@ -239,6 +288,7 @@ enum UpdateSequences {
AddRequest
{
AddRequest
{
request_id
:
RequestId
,
request_id
:
RequestId
,
token_sequence
:
TokenBlockSequence
,
token_sequence
:
TokenBlockSequence
,
overlap
:
u32
,
},
},
Free
{
Free
{
request_id
:
RequestId
,
request_id
:
RequestId
,
...
@@ -255,6 +305,11 @@ enum UpdateSequences {
...
@@ -255,6 +305,11 @@ enum UpdateSequences {
token_sequence
:
Arc
<
TokenBlockSequence
>
,
token_sequence
:
Arc
<
TokenBlockSequence
>
,
resp_tx
:
mpsc
::
SyncSender
<
usize
>
,
resp_tx
:
mpsc
::
SyncSender
<
usize
>
,
},
},
PotentialBlocksAndTokens
{
token_sequence
:
Arc
<
TokenBlockSequence
>
,
overlap
:
u32
,
resp_tx
:
mpsc
::
SyncSender
<
(
usize
,
usize
)
>
,
},
ActiveBlocks
{
ActiveBlocks
{
resp_tx
:
mpsc
::
SyncSender
<
usize
>
,
resp_tx
:
mpsc
::
SyncSender
<
usize
>
,
},
},
...
@@ -302,8 +357,9 @@ impl ActiveSequencesMultiWorker {
...
@@ -302,8 +357,9 @@ impl ActiveSequencesMultiWorker {
UpdateSequences
::
AddRequest
{
UpdateSequences
::
AddRequest
{
request_id
,
request_id
,
token_sequence
,
token_sequence
,
overlap
,
}
=>
{
}
=>
{
active_sequences
.add_request
(
request_id
,
token_sequence
);
active_sequences
.add_request
(
request_id
,
token_sequence
,
overlap
);
}
}
UpdateSequences
::
Free
{
request_id
}
=>
{
UpdateSequences
::
Free
{
request_id
}
=>
{
active_sequences
.free
(
&
request_id
);
active_sequences
.free
(
&
request_id
);
...
@@ -325,6 +381,15 @@ impl ActiveSequencesMultiWorker {
...
@@ -325,6 +381,15 @@ impl ActiveSequencesMultiWorker {
let
potential_blocks
=
active_sequences
.potential_blocks
(
&
token_sequence
);
let
potential_blocks
=
active_sequences
.potential_blocks
(
&
token_sequence
);
let
_
=
resp_tx
.send
(
potential_blocks
);
let
_
=
resp_tx
.send
(
potential_blocks
);
}
}
UpdateSequences
::
PotentialBlocksAndTokens
{
token_sequence
,
overlap
,
resp_tx
,
}
=>
{
let
potential_tokens
=
active_sequences
.potential_blocks_and_tokens
(
&
token_sequence
,
overlap
);
let
_
=
resp_tx
.send
(
potential_tokens
);
}
UpdateSequences
::
ActiveBlocks
{
resp_tx
}
=>
{
UpdateSequences
::
ActiveBlocks
{
resp_tx
}
=>
{
let
active_blocks
=
active_sequences
.active_blocks
();
let
active_blocks
=
active_sequences
.active_blocks
();
let
_
=
resp_tx
.send
(
active_blocks
);
let
_
=
resp_tx
.send
(
active_blocks
);
...
@@ -379,6 +444,7 @@ impl ActiveSequencesMultiWorker {
...
@@ -379,6 +444,7 @@ impl ActiveSequencesMultiWorker {
&
mut
self
,
&
mut
self
,
request_id
:
RequestId
,
request_id
:
RequestId
,
token_sequence
:
TokenBlockSequence
,
token_sequence
:
TokenBlockSequence
,
overlap
:
u32
,
worker_id
:
WorkerId
,
worker_id
:
WorkerId
,
)
{
)
{
if
!
self
.senders
.contains_key
(
&
worker_id
)
{
if
!
self
.senders
.contains_key
(
&
worker_id
)
{
...
@@ -391,6 +457,7 @@ impl ActiveSequencesMultiWorker {
...
@@ -391,6 +457,7 @@ impl ActiveSequencesMultiWorker {
.send
(
UpdateSequences
::
AddRequest
{
.send
(
UpdateSequences
::
AddRequest
{
request_id
,
request_id
,
token_sequence
,
token_sequence
,
overlap
,
})
})
.expect
(
"Failed to send add_request command to worker"
);
.expect
(
"Failed to send add_request command to worker"
);
}
}
...
@@ -482,6 +549,43 @@ impl ActiveSequencesMultiWorker {
...
@@ -482,6 +549,43 @@ impl ActiveSequencesMultiWorker {
})
})
}
}
/// Query all workers for the potential tokens (new + active) that would be used by a token sequence with overlap
pub
fn
potential_blocks_and_tokens
(
&
self
,
token_sequence
:
TokenBlockSequence
,
overlaps
:
OverlapScores
,
)
->
(
HashMap
<
WorkerId
,
usize
>
,
HashMap
<
WorkerId
,
usize
>
)
{
let
mut
potential_blocks
=
HashMap
::
new
();
let
mut
potential_tokens
=
HashMap
::
new
();
let
token_sequence_shared
=
Arc
::
new
(
token_sequence
);
let
mut
receivers
=
Vec
::
new
();
// Send queries to all workers in parallel
for
(
worker_id
,
sender
)
in
&
self
.senders
{
let
(
resp_tx
,
resp_rx
)
=
mpsc
::
sync_channel
(
0
);
receivers
.push
((
worker_id
,
resp_rx
));
sender
.send
(
UpdateSequences
::
PotentialBlocksAndTokens
{
token_sequence
:
token_sequence_shared
.clone
(),
overlap
:
overlaps
.scores
.get
(
worker_id
)
.copied
()
.unwrap_or
(
0
),
resp_tx
,
})
.expect
(
"Failed to send potential_tokens command to worker"
);
}
// Collect results from all workers
for
(
worker_id
,
receiver
)
in
receivers
{
let
(
blocks
,
tokens
)
=
receiver
.recv_timeout
(
Duration
::
from_secs
(
1
))
.expect
(
"Failed to receive response from worker"
);
potential_blocks
.insert
(
*
worker_id
,
blocks
);
potential_tokens
.insert
(
*
worker_id
,
tokens
);
}
(
potential_blocks
,
potential_tokens
)
}
/// Query all workers for their current number of active blocks
/// Query all workers for their current number of active blocks
pub
fn
active_blocks
(
&
self
)
->
HashMap
<
WorkerId
,
usize
>
{
pub
fn
active_blocks
(
&
self
)
->
HashMap
<
WorkerId
,
usize
>
{
self
.query_workers
(
None
,
|
_
,
resp_tx
|
UpdateSequences
::
ActiveBlocks
{
resp_tx
})
self
.query_workers
(
None
,
|
_
,
resp_tx
|
UpdateSequences
::
ActiveBlocks
{
resp_tx
})
...
@@ -515,14 +619,15 @@ mod tests {
...
@@ -515,14 +619,15 @@ mod tests {
|
tokens
:
Vec
<
u32
>
|
Tokens
::
from
(
tokens
)
.into_sequence
(
block_size
as
u32
,
None
);
|
tokens
:
Vec
<
u32
>
|
Tokens
::
from
(
tokens
)
.into_sequence
(
block_size
as
u32
,
None
);
// Step 1: Add request 0 with tokens [0, 1, 2], then push 3 and 4
// Step 1: Add request 0 with tokens [0, 1, 2], then push 3 and 4
manager
.add_request
(
"0"
.to_string
(),
to_sequence
(
vec!
[
0
,
1
,
2
]));
manager
.add_request
(
"0"
.to_string
(),
to_sequence
(
vec!
[
0
,
1
,
2
])
,
0
);
manager
.push
(
&
"0"
.to_string
(),
&
[
3
,
4
]);
// Push both tokens at once
manager
.push
(
&
"0"
.to_string
(),
&
[
3
,
4
]);
// Push both tokens at once
assert_eq!
(
manager
.active_tokens
(),
1
);
assert_eq!
(
manager
.active_blocks
(),
2
);
assert_eq!
(
manager
.active_blocks
(),
2
);
assert_eq!
(
manager
.partial_blocks
.len
(),
1
);
assert_eq!
(
manager
.partial_blocks
.len
(),
1
);
// Step 2: Add request 1 with tokens [0, 1, 2, 3, 4, 5, 6]
// Step 2: Add request 1 with tokens [0, 1, 2, 3, 4, 5, 6]
manager
.add_request
(
"1"
.to_string
(),
to_sequence
(
vec!
[
0
,
1
,
2
,
3
,
4
,
5
,
6
]));
manager
.add_request
(
"1"
.to_string
(),
to_sequence
(
vec!
[
0
,
1
,
2
,
3
,
4
,
5
,
6
]),
1
);
assert_eq!
(
manager
.active_tokens
(),
1
+
3
);
assert_eq!
(
manager
.active_blocks
(),
3
);
assert_eq!
(
manager
.active_blocks
(),
3
);
// Check that only one key is FullBlock with both requests sharing it
// Check that only one key is FullBlock with both requests sharing it
...
@@ -551,6 +656,7 @@ mod tests {
...
@@ -551,6 +656,7 @@ mod tests {
// Step 4: Free request 0
// Step 4: Free request 0
manager
.free
(
&
"0"
.to_string
());
manager
.free
(
&
"0"
.to_string
());
assert_eq!
(
manager
.active_tokens
(),
0
);
assert_eq!
(
manager
.active_blocks
(),
0
);
assert_eq!
(
manager
.active_blocks
(),
0
);
assert_eq!
(
manager
.unique_blocks
.len
(),
0
);
assert_eq!
(
manager
.unique_blocks
.len
(),
0
);
assert_eq!
(
manager
.partial_blocks
.len
(),
0
);
assert_eq!
(
manager
.partial_blocks
.len
(),
0
);
...
@@ -566,14 +672,14 @@ mod tests {
...
@@ -566,14 +672,14 @@ mod tests {
|
tokens
:
Vec
<
u32
>
|
Tokens
::
from
(
tokens
)
.into_sequence
(
block_size
as
u32
,
None
);
|
tokens
:
Vec
<
u32
>
|
Tokens
::
from
(
tokens
)
.into_sequence
(
block_size
as
u32
,
None
);
// Send request [0, 1, 2, 3] to worker 0
// Send request [0, 1, 2, 3] to worker 0
manager
.add_request
(
"req0"
.to_string
(),
to_sequence
(
vec!
[
0
,
1
,
2
,
3
]),
0
);
manager
.add_request
(
"req0"
.to_string
(),
to_sequence
(
vec!
[
0
,
1
,
2
,
3
]),
0
,
0
);
// Send request [0, 1, 2] to worker 1, then push 3 and 4
// Send request [0, 1, 2] to worker 1, then push 3 and 4
manager
.add_request
(
"req1"
.to_string
(),
to_sequence
(
vec!
[
0
,
1
,
2
]),
1
);
manager
.add_request
(
"req1"
.to_string
(),
to_sequence
(
vec!
[
0
,
1
,
2
]),
0
,
1
);
manager
.push
(
&
"req1"
.to_string
(),
&
[
3
,
4
]);
// Push both tokens at once
manager
.push
(
&
"req1"
.to_string
(),
&
[
3
,
4
]);
// Push both tokens at once
// Send request [0, 1, 2] to worker 2
// Send request [0, 1, 2] to worker 2
manager
.add_request
(
"req2"
.to_string
(),
to_sequence
(
vec!
[
0
,
1
,
2
]),
2
);
manager
.add_request
(
"req2"
.to_string
(),
to_sequence
(
vec!
[
0
,
1
,
2
]),
0
,
2
);
// Check new_blocks on tokens [0, 1, 2, 3, 4]
// Check new_blocks on tokens [0, 1, 2, 3, 4]
let
new_blocks_map
=
manager
.new_blocks
(
to_sequence
(
vec!
[
0
,
1
,
2
,
3
,
4
]));
let
new_blocks_map
=
manager
.new_blocks
(
to_sequence
(
vec!
[
0
,
1
,
2
,
3
,
4
]));
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment