Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
f978f4d1
Unverified
Commit
f978f4d1
authored
Oct 15, 2025
by
Yan Ru Pei
Committed by
GitHub
Oct 16, 2025
Browse files
feat: dp rank routing (#3597)
Signed-off-by:
PeaBrane
<
yanrpei@gmail.com
>
parent
29f5b822
Changes
31
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
183 additions
and
88 deletions
+183
-88
lib/llm/src/local_model/runtime_config.rs
lib/llm/src/local_model/runtime_config.rs
+24
-1
lib/llm/src/mocker/engine.rs
lib/llm/src/mocker/engine.rs
+6
-15
lib/llm/src/mocker/protocols.rs
lib/llm/src/mocker/protocols.rs
+1
-1
lib/llm/src/mocker/scheduler.rs
lib/llm/src/mocker/scheduler.rs
+10
-10
lib/llm/src/protocols/common/preprocessor.rs
lib/llm/src/protocols/common/preprocessor.rs
+5
-0
lib/llm/tests/block_manager.rs
lib/llm/tests/block_manager.rs
+6
-0
lib/runtime/src/utils/worker_monitor.rs
lib/runtime/src/utils/worker_monitor.rs
+54
-26
tests/router/test_router_e2e_with_mockers.py
tests/router/test_router_e2e_with_mockers.py
+74
-32
tests/serve/test_sglang.py
tests/serve/test_sglang.py
+1
-1
tests/serve/test_trtllm.py
tests/serve/test_trtllm.py
+1
-1
tests/serve/test_vllm.py
tests/serve/test_vllm.py
+1
-1
No files found.
lib/llm/src/local_model/runtime_config.rs
View file @
f978f4d1
...
@@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize, de::DeserializeOwned};
...
@@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize, de::DeserializeOwned};
use
crate
::
protocols
::
tensor
;
use
crate
::
protocols
::
tensor
;
#[derive(Debug,
Default,
Clone,
Serialize,
Deserialize,
Eq,
PartialEq)]
#[derive(Debug,
Clone,
Serialize,
Deserialize,
Eq,
PartialEq)]
pub
struct
ModelRuntimeConfig
{
pub
struct
ModelRuntimeConfig
{
pub
total_kv_blocks
:
Option
<
u64
>
,
pub
total_kv_blocks
:
Option
<
u64
>
,
...
@@ -19,6 +19,10 @@ pub struct ModelRuntimeConfig {
...
@@ -19,6 +19,10 @@ pub struct ModelRuntimeConfig {
pub
reasoning_parser
:
Option
<
String
>
,
pub
reasoning_parser
:
Option
<
String
>
,
/// Total number of data parallel ranks for this worker (1 if DP not enabled)
#[serde(default
=
"default_data_parallel_size"
)]
pub
data_parallel_size
:
u32
,
/// Mapping of engine-specific runtime configs
/// Mapping of engine-specific runtime configs
#[serde(default,
skip_serializing_if
=
"HashMap::is_empty"
)]
#[serde(default,
skip_serializing_if
=
"HashMap::is_empty"
)]
pub
runtime_data
:
HashMap
<
String
,
serde_json
::
Value
>
,
pub
runtime_data
:
HashMap
<
String
,
serde_json
::
Value
>
,
...
@@ -34,6 +38,25 @@ pub struct ModelRuntimeConfig {
...
@@ -34,6 +38,25 @@ pub struct ModelRuntimeConfig {
pub
tensor_model_config
:
Option
<
tensor
::
TensorModelConfig
>
,
pub
tensor_model_config
:
Option
<
tensor
::
TensorModelConfig
>
,
}
}
const
fn
default_data_parallel_size
()
->
u32
{
1
}
impl
Default
for
ModelRuntimeConfig
{
fn
default
()
->
Self
{
Self
{
total_kv_blocks
:
None
,
max_num_seqs
:
None
,
max_num_batched_tokens
:
None
,
tool_call_parser
:
None
,
reasoning_parser
:
None
,
data_parallel_size
:
default_data_parallel_size
(),
runtime_data
:
HashMap
::
new
(),
tensor_model_config
:
None
,
}
}
}
impl
ModelRuntimeConfig
{
impl
ModelRuntimeConfig
{
pub
fn
new
()
->
Self
{
pub
fn
new
()
->
Self
{
Self
::
default
()
Self
::
default
()
...
...
lib/llm/src/mocker/engine.rs
View file @
f978f4d1
...
@@ -124,7 +124,7 @@ impl MockVllmEngine {
...
@@ -124,7 +124,7 @@ impl MockVllmEngine {
let
scheduler
=
Scheduler
::
new
(
let
scheduler
=
Scheduler
::
new
(
args
.clone
(),
args
.clone
(),
Some
(
dp_rank
)
,
dp_rank
,
Some
(
output_tx
),
Some
(
output_tx
),
Some
(
kv_events_tx
),
// Pass the KV events sender to scheduler
Some
(
kv_events_tx
),
// Pass the KV events sender to scheduler
Some
(
cancel_token
.clone
()),
Some
(
cancel_token
.clone
()),
...
@@ -283,6 +283,7 @@ impl MockVllmEngine {
...
@@ -283,6 +283,7 @@ impl MockVllmEngine {
let
event
=
KvCacheEvent
{
let
event
=
KvCacheEvent
{
event_id
:
Uuid
::
new_v4
()
.as_u128
()
as
u64
,
event_id
:
Uuid
::
new_v4
()
.as_u128
()
as
u64
,
data
:
event_data
,
data
:
event_data
,
dp_rank
,
};
};
// Publish the event
// Publish the event
...
@@ -316,18 +317,8 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
...
@@ -316,18 +317,8 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
)
->
Result
<
ManyOut
<
LLMEngineOutput
>
,
Error
>
{
)
->
Result
<
ManyOut
<
LLMEngineOutput
>
,
Error
>
{
let
(
request
,
ctx
)
=
input
.into_parts
();
let
(
request
,
ctx
)
=
input
.into_parts
();
// Extract dp_rank from annotations if present
// Extract dp_rank from request field (defaults to 0 if not set)
let
dp_rank
=
request
let
dp_rank
=
request
.dp_rank
.unwrap_or
(
0
);
.annotations
.iter
()
.find_map
(|
ann
|
{
if
ann
.starts_with
(
"dp_rank:"
)
{
ann
.strip_prefix
(
"dp_rank:"
)
.and_then
(|
s
|
s
.parse
()
.ok
())
}
else
{
None
}
})
.unwrap_or
(
0
);
// Validate dp_rank
// Validate dp_rank
if
dp_rank
>=
self
.engine_args.dp_size
{
if
dp_rank
>=
self
.engine_args.dp_size
{
...
@@ -348,7 +339,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
...
@@ -348,7 +339,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
.expect
(
"max_output_tokens must be specified for mocker"
)
.expect
(
"max_output_tokens must be specified for mocker"
)
as
usize
,
as
usize
,
uuid
:
Some
(
request_uuid
),
uuid
:
Some
(
request_uuid
),
dp_rank
:
Some
(
dp_rank
)
,
dp_rank
,
};
};
let
(
request_tx
,
mut
request_rx
)
=
mpsc
::
unbounded_channel
::
<
OutputSignal
>
();
let
(
request_tx
,
mut
request_rx
)
=
mpsc
::
unbounded_channel
::
<
OutputSignal
>
();
...
@@ -512,7 +503,7 @@ pub async fn make_mocker_engine(
...
@@ -512,7 +503,7 @@ pub async fn make_mocker_engine(
args
:
MockEngineArgs
,
args
:
MockEngineArgs
,
)
->
Result
<
crate
::
backend
::
ExecutionContext
,
Error
>
{
)
->
Result
<
crate
::
backend
::
ExecutionContext
,
Error
>
{
// Create the mocker engine
// Create the mocker engine
tracing
::
debug
!
(
"Creating mocker engine with config: {args:?}"
);
tracing
::
info
!
(
"Creating mocker engine with config: {args:?}"
);
let
annotated_engine
=
let
annotated_engine
=
AnnotatedMockEngine
::
new
(
MockVllmEngine
::
new
(
args
),
distributed_runtime
,
endpoint_id
);
AnnotatedMockEngine
::
new
(
MockVllmEngine
::
new
(
args
),
distributed_runtime
,
endpoint_id
);
...
...
lib/llm/src/mocker/protocols.rs
View file @
f978f4d1
...
@@ -37,7 +37,7 @@ pub struct DirectRequest {
...
@@ -37,7 +37,7 @@ pub struct DirectRequest {
pub
tokens
:
Vec
<
Token
>
,
pub
tokens
:
Vec
<
Token
>
,
pub
max_output_tokens
:
usize
,
pub
max_output_tokens
:
usize
,
pub
uuid
:
Option
<
Uuid
>
,
pub
uuid
:
Option
<
Uuid
>
,
pub
dp_rank
:
Option
<
u32
>
,
pub
dp_rank
:
u32
,
}
}
/// Represents the cost of prefilling content in the cache
/// Represents the cost of prefilling content in the cache
...
...
lib/llm/src/mocker/scheduler.rs
View file @
f978f4d1
...
@@ -248,7 +248,7 @@ impl Scheduler {
...
@@ -248,7 +248,7 @@ impl Scheduler {
/// Create a new Scheduler with the given parameters
/// Create a new Scheduler with the given parameters
pub
fn
new
(
pub
fn
new
(
args
:
MockEngineArgs
,
args
:
MockEngineArgs
,
dp_rank
:
Option
<
u32
>
,
dp_rank
:
u32
,
output_tx
:
Option
<
mpsc
::
UnboundedSender
<
OutputSignal
>>
,
output_tx
:
Option
<
mpsc
::
UnboundedSender
<
OutputSignal
>>
,
kv_events_tx
:
Option
<
mpsc
::
UnboundedSender
<
KvCacheEventData
>>
,
kv_events_tx
:
Option
<
mpsc
::
UnboundedSender
<
KvCacheEventData
>>
,
cancellation_token
:
Option
<
CancellationToken
>
,
cancellation_token
:
Option
<
CancellationToken
>
,
...
@@ -280,7 +280,7 @@ impl Scheduler {
...
@@ -280,7 +280,7 @@ impl Scheduler {
// Create channel for request handling
// Create channel for request handling
let
(
request_tx
,
mut
request_rx
)
=
mpsc
::
unbounded_channel
::
<
DirectRequest
>
();
let
(
request_tx
,
mut
request_rx
)
=
mpsc
::
unbounded_channel
::
<
DirectRequest
>
();
let
mut
initial_metrics
=
ForwardPassMetrics
::
default
();
let
mut
initial_metrics
=
ForwardPassMetrics
::
default
();
initial_metrics
.worker_stats.data_parallel_rank
=
dp_rank
;
initial_metrics
.worker_stats.data_parallel_rank
=
Some
(
dp_rank
)
;
let
(
metrics_tx
,
metrics_rx
)
=
let
(
metrics_tx
,
metrics_rx
)
=
tokio
::
sync
::
watch
::
channel
::
<
ForwardPassMetrics
>
(
initial_metrics
);
tokio
::
sync
::
watch
::
channel
::
<
ForwardPassMetrics
>
(
initial_metrics
);
...
@@ -573,7 +573,7 @@ fn get_fwd_pass_metrics(
...
@@ -573,7 +573,7 @@ fn get_fwd_pass_metrics(
state
:
&
SchedulerState
,
state
:
&
SchedulerState
,
kv_manager
:
&
KvManager
,
kv_manager
:
&
KvManager
,
hit_rates
:
&
VecDeque
<
f32
>
,
hit_rates
:
&
VecDeque
<
f32
>
,
dp_rank
:
Option
<
u32
>
,
dp_rank
:
u32
,
)
->
ForwardPassMetrics
{
)
->
ForwardPassMetrics
{
// Get state metrics
// Get state metrics
let
request_active_slots
=
state
.decode
.len
()
as
u64
;
let
request_active_slots
=
state
.decode
.len
()
as
u64
;
...
@@ -597,7 +597,7 @@ fn get_fwd_pass_metrics(
...
@@ -597,7 +597,7 @@ fn get_fwd_pass_metrics(
};
};
let
worker_stats
=
WorkerStats
{
let
worker_stats
=
WorkerStats
{
data_parallel_rank
:
dp_rank
,
data_parallel_rank
:
Some
(
dp_rank
)
,
request_active_slots
,
request_active_slots
,
request_total_slots
:
1024
,
// vllm max_num_seqs for gpu >= 70 vram, otherwise 256, fallback is 128
request_total_slots
:
1024
,
// vllm max_num_seqs for gpu >= 70 vram, otherwise 256, fallback is 128
num_requests_waiting
,
num_requests_waiting
,
...
@@ -728,7 +728,7 @@ mod tests {
...
@@ -728,7 +728,7 @@ mod tests {
.unwrap
();
.unwrap
();
// Create scheduler with new args struct
// Create scheduler with new args struct
let
scheduler
=
Scheduler
::
new
(
args
,
None
,
Some
(
output_tx
),
None
,
None
);
let
scheduler
=
Scheduler
::
new
(
args
,
0
,
Some
(
output_tx
),
None
,
None
);
// Create shared tokens for caching case
// Create shared tokens for caching case
let
shared_tokens
=
if
use_shared_tokens
{
let
shared_tokens
=
if
use_shared_tokens
{
...
@@ -759,7 +759,7 @@ mod tests {
...
@@ -759,7 +759,7 @@ mod tests {
tokens
:
input_tokens
,
tokens
:
input_tokens
,
max_output_tokens
,
max_output_tokens
,
uuid
:
None
,
uuid
:
None
,
dp_rank
:
None
,
dp_rank
:
0
,
};
};
scheduler
.receive
(
request
)
.await
;
scheduler
.receive
(
request
)
.await
;
}
}
...
@@ -853,7 +853,7 @@ mod tests {
...
@@ -853,7 +853,7 @@ mod tests {
.unwrap
();
.unwrap
();
// Create scheduler
// Create scheduler
let
scheduler
=
Scheduler
::
new
(
args
,
None
,
Some
(
output_tx
),
None
,
None
);
let
scheduler
=
Scheduler
::
new
(
args
,
0
,
Some
(
output_tx
),
None
,
None
);
// Create identical tokens for all requests
// Create identical tokens for all requests
let
identical_tokens
:
Vec
<
u32
>
=
(
0
..
token_length
)
.map
(|
i
|
i
as
u32
)
.collect
();
let
identical_tokens
:
Vec
<
u32
>
=
(
0
..
token_length
)
.map
(|
i
|
i
as
u32
)
.collect
();
...
@@ -864,7 +864,7 @@ mod tests {
...
@@ -864,7 +864,7 @@ mod tests {
tokens
:
identical_tokens
.clone
(),
tokens
:
identical_tokens
.clone
(),
max_output_tokens
,
max_output_tokens
,
uuid
:
None
,
uuid
:
None
,
dp_rank
:
None
,
dp_rank
:
0
,
};
};
scheduler
.receive
(
request
)
.await
;
scheduler
.receive
(
request
)
.await
;
// Sleep for 0.1 second after each request
// Sleep for 0.1 second after each request
...
@@ -950,7 +950,7 @@ mod tests {
...
@@ -950,7 +950,7 @@ mod tests {
.unwrap
();
.unwrap
();
// Create scheduler
// Create scheduler
let
scheduler
=
Scheduler
::
new
(
args
,
None
,
Some
(
output_tx
),
None
,
None
);
let
scheduler
=
Scheduler
::
new
(
args
,
0
,
Some
(
output_tx
),
None
,
None
);
// Create request with 256 tokens
// Create request with 256 tokens
let
tokens
:
Vec
<
u32
>
=
(
0
..
input_tokens
)
.map
(|
i
|
i
as
u32
)
.collect
();
let
tokens
:
Vec
<
u32
>
=
(
0
..
input_tokens
)
.map
(|
i
|
i
as
u32
)
.collect
();
...
@@ -958,7 +958,7 @@ mod tests {
...
@@ -958,7 +958,7 @@ mod tests {
tokens
,
tokens
,
max_output_tokens
,
max_output_tokens
,
uuid
:
None
,
uuid
:
None
,
dp_rank
:
None
,
dp_rank
:
0
,
};
};
scheduler
.receive
(
request
)
.await
;
scheduler
.receive
(
request
)
.await
;
...
...
lib/llm/src/protocols/common/preprocessor.rs
View file @
f978f4d1
...
@@ -61,6 +61,11 @@ pub struct PreprocessedRequest {
...
@@ -61,6 +61,11 @@ pub struct PreprocessedRequest {
#[serde(default,
skip_serializing_if
=
"Option::is_none"
)]
#[serde(default,
skip_serializing_if
=
"Option::is_none"
)]
pub
disaggregated_params
:
Option
<
serde_json
::
Value
>
,
pub
disaggregated_params
:
Option
<
serde_json
::
Value
>
,
/// Data parallel rank for the request (used with data parallelism)
#[builder(default)]
#[serde(default,
skip_serializing_if
=
"Option::is_none"
)]
pub
dp_rank
:
Option
<
u32
>
,
/// Additional arguments for extensibility
/// Additional arguments for extensibility
#[builder(default)]
#[builder(default)]
#[serde(default,
skip_serializing_if
=
"Option::is_none"
)]
#[serde(default,
skip_serializing_if
=
"Option::is_none"
)]
...
...
lib/llm/tests/block_manager.rs
View file @
f978f4d1
...
@@ -294,6 +294,7 @@ pub mod llm_kvbm {
...
@@ -294,6 +294,7 @@ pub mod llm_kvbm {
let
event
=
KvCacheEvent
{
let
event
=
KvCacheEvent
{
data
,
data
,
event_id
:
event_id_counter
,
event_id
:
event_id_counter
,
dp_rank
:
0
,
};
};
let
router_event
=
RouterEvent
::
new
(
worker_identifier
as
i64
,
event
);
let
router_event
=
RouterEvent
::
new
(
worker_identifier
as
i64
,
event
);
event_id_counter
+=
1
;
event_id_counter
+=
1
;
...
@@ -313,6 +314,7 @@ pub mod llm_kvbm {
...
@@ -313,6 +314,7 @@ pub mod llm_kvbm {
block_hashes
:
vec!
[
ExternalSequenceBlockHash
(
sequence_hash
)],
block_hashes
:
vec!
[
ExternalSequenceBlockHash
(
sequence_hash
)],
}),
}),
event_id
:
event_id_counter
,
event_id
:
event_id_counter
,
dp_rank
:
0
,
};
};
let
router_event
=
RouterEvent
::
new
(
worker_identifier
as
i64
,
event
);
let
router_event
=
RouterEvent
::
new
(
worker_identifier
as
i64
,
event
);
event_id_counter
+=
1
;
event_id_counter
+=
1
;
...
@@ -573,6 +575,7 @@ mod tests {
...
@@ -573,6 +575,7 @@ mod tests {
}],
}],
parent_hash
:
None
,
parent_hash
:
None
,
}),
}),
dp_rank
:
0
,
},
},
);
);
...
@@ -587,6 +590,7 @@ mod tests {
...
@@ -587,6 +590,7 @@ mod tests {
}],
}],
parent_hash
:
None
,
parent_hash
:
None
,
}),
}),
dp_rank
:
0
,
},
},
);
);
...
@@ -630,6 +634,7 @@ mod tests {
...
@@ -630,6 +634,7 @@ mod tests {
}],
}],
parent_hash
:
None
,
parent_hash
:
None
,
}),
}),
dp_rank
:
0
,
},
},
);
);
...
@@ -678,6 +683,7 @@ mod tests {
...
@@ -678,6 +683,7 @@ mod tests {
data
:
KvCacheEventData
::
Removed
(
KvCacheRemoveData
{
data
:
KvCacheEventData
::
Removed
(
KvCacheRemoveData
{
block_hashes
:
vec!
[
ExternalSequenceBlockHash
(
4
)],
block_hashes
:
vec!
[
ExternalSequenceBlockHash
(
4
)],
}),
}),
dp_rank
:
0
,
},
},
);
);
...
...
lib/runtime/src/utils/worker_monitor.rs
View file @
f978f4d1
...
@@ -26,34 +26,59 @@ struct LoadEvent {
...
@@ -26,34 +26,59 @@ struct LoadEvent {
#[derive(serde::Deserialize)]
#[derive(serde::Deserialize)]
struct
ForwardPassMetrics
{
struct
ForwardPassMetrics
{
worker_stats
:
WorkerStats
,
kv_stats
:
KvStats
,
kv_stats
:
KvStats
,
}
}
#[derive(serde::Deserialize)]
struct
WorkerStats
{
data_parallel_rank
:
Option
<
u32
>
,
}
#[derive(serde::Deserialize)]
#[derive(serde::Deserialize)]
struct
KvStats
{
struct
KvStats
{
kv_active_blocks
:
u64
,
kv_active_blocks
:
u64
,
}
}
#[derive(serde::Deserialize)]
#[derive(serde::Deserialize
,
Clone
)]
struct
RuntimeConfig
{
struct
RuntimeConfig
{
total_kv_blocks
:
Option
<
u64
>
,
total_kv_blocks
:
Option
<
u64
>
,
data_parallel_size
:
u32
,
}
}
/// Worker load monitoring state
/// Worker load monitoring state
per dp_rank
#[derive(Clone,
Debug)]
#[derive(Clone,
Debug
,
Default
)]
pub
struct
WorkerLoadState
{
pub
struct
WorkerLoadState
{
pub
kv_active_blocks
:
Option
<
u64
>
,
pub
kv_active_blocks
:
HashMap
<
u32
,
u64
>
,
pub
kv_total_blocks
:
Option
<
u64
>
,
pub
kv_total_blocks
:
HashMap
<
u32
,
u64
>
,
}
}
impl
WorkerLoadState
{
impl
WorkerLoadState
{
/// Returns true if ALL dp_ranks (that have data in both maps) exceed the threshold
pub
fn
is_busy
(
&
self
,
threshold
:
f64
)
->
bool
{
pub
fn
is_busy
(
&
self
,
threshold
:
f64
)
->
bool
{
match
(
self
.kv_active_blocks
,
self
.kv_total_blocks
)
{
// Get all dp_ranks that exist in both active and total blocks
(
Some
(
active
),
Some
(
total
))
if
total
>
0
=>
{
let
common_dp_ranks
:
Vec
<
_
>
=
self
(
active
as
f64
)
>
(
threshold
*
total
as
f64
)
.kv_active_blocks
.keys
()
.filter
(|
dp_rank
|
self
.kv_total_blocks
.contains_key
(
dp_rank
))
.collect
();
// If no common dp_ranks, not busy
if
common_dp_ranks
.is_empty
()
{
return
false
;
}
}
_
=>
false
,
// Check if ALL common dp_ranks exceed threshold
common_dp_ranks
.iter
()
.all
(|
&&
dp_rank
|
{
if
let
(
Some
(
&
active
),
Some
(
&
total
))
=
(
self
.kv_active_blocks
.get
(
&
dp_rank
),
self
.kv_total_blocks
.get
(
&
dp_rank
),
)
{
total
>
0
&&
(
active
as
f64
)
>
(
threshold
*
total
as
f64
)
}
else
{
false
}
}
})
}
}
}
}
...
@@ -97,9 +122,10 @@ impl WorkerMonitor {
...
@@ -97,9 +122,10 @@ impl WorkerMonitor {
"v1/mdc/"
,
// should be model_card::ROOT_PREFIX but wrong crate
"v1/mdc/"
,
// should be model_card::ROOT_PREFIX but wrong crate
key_extractors
::
lease_id
,
key_extractors
::
lease_id
,
|
card
:
serde_json
::
Value
|
{
|
card
:
serde_json
::
Value
|
{
card
.get
(
"runtime_config"
)
let
runtime_config
:
Option
<
RuntimeConfig
>
=
card
.and_then
(|
rc
|
rc
.get
(
"total_kv_blocks"
))
.get
(
"runtime_config"
)
.and_then
(|
t_kv
|
t_kv
.as_u64
())
.and_then
(|
rc
|
serde_json
::
from_value
(
rc
.clone
())
.ok
());
runtime_config
},
},
component
.drt
()
.child_token
(),
component
.drt
()
.child_token
(),
)
)
...
@@ -132,13 +158,17 @@ impl WorkerMonitor {
...
@@ -132,13 +158,17 @@ impl WorkerMonitor {
let
mut
states
=
worker_load_states
.write
()
.unwrap
();
let
mut
states
=
worker_load_states
.write
()
.unwrap
();
states
.retain
(|
lease_id
,
_
|
runtime_configs
.contains_key
(
lease_id
));
states
.retain
(|
lease_id
,
_
|
runtime_configs
.contains_key
(
lease_id
));
// Update worker load states with total blocks
// Update worker load states with total blocks for all dp_ranks
for
(
lease_id
,
total_blocks
)
in
runtime_configs
.iter
()
{
for
(
lease_id
,
runtime_config
)
in
runtime_configs
.iter
()
{
let
state
=
states
.entry
(
*
lease_id
)
.or_insert
(
WorkerLoadState
{
let
state
=
states
.entry
(
*
lease_id
)
.or_default
();
kv_active_blocks
:
None
,
kv_total_blocks
:
None
,
// Populate total_blocks for all dp_ranks (they share the same total)
});
// data_parallel_size defaults to 1 via serde in ModelRuntimeConfig
state
.kv_total_blocks
=
Some
(
*
total_blocks
);
if
let
Some
(
total_blocks
)
=
runtime_config
.total_kv_blocks
{
for
dp_rank
in
0
..
runtime_config
.data_parallel_size
{
state
.kv_total_blocks
.insert
(
dp_rank
,
total_blocks
);
}
}
}
}
}
}
...
@@ -152,14 +182,12 @@ impl WorkerMonitor {
...
@@ -152,14 +182,12 @@ impl WorkerMonitor {
if
let
Ok
(
load_event
)
=
serde_json
::
from_slice
::
<
LoadEvent
>
(
&
event
.payload
)
{
if
let
Ok
(
load_event
)
=
serde_json
::
from_slice
::
<
LoadEvent
>
(
&
event
.payload
)
{
let
worker_id
=
load_event
.worker_id
;
let
worker_id
=
load_event
.worker_id
;
let
active_blocks
=
load_event
.data.kv_stats.kv_active_blocks
;
let
active_blocks
=
load_event
.data.kv_stats.kv_active_blocks
;
let
dp_rank
=
load_event
.data.worker_stats.data_parallel_rank
.unwrap_or
(
0
);
// Update worker load state
// Update worker load state
per dp_rank
let
mut
states
=
worker_load_states
.write
()
.unwrap
();
let
mut
states
=
worker_load_states
.write
()
.unwrap
();
let
state
=
states
.entry
(
worker_id
)
.or_insert
(
WorkerLoadState
{
let
state
=
states
.entry
(
worker_id
)
.or_default
();
kv_active_blocks
:
None
,
state
.kv_active_blocks
.insert
(
dp_rank
,
active_blocks
);
kv_total_blocks
:
None
,
});
state
.kv_active_blocks
=
Some
(
active_blocks
);
drop
(
states
);
drop
(
states
);
// Recalculate all busy instances and update
// Recalculate all busy instances and update
...
...
tests/router/test_router_e2e_with_mockers.py
View file @
f978f4d1
...
@@ -298,6 +298,7 @@ async def send_request_via_python_kv_router(
...
@@ -298,6 +298,7 @@ async def send_request_via_python_kv_router(
worker_id
:
Optional
[
worker_id
:
Optional
[
int
int
]
=
None
,
# If None, Router will select the best available worker
]
=
None
,
# If None, Router will select the best available worker
dp_rank
:
Optional
[
int
]
=
None
,
# Data parallel rank (defaults to 0)
):
):
"""Send a request to the specified mocker instance.
"""Send a request to the specified mocker instance.
Returns True if mockers respond, otherwise raises or returns False.
Returns True if mockers respond, otherwise raises or returns False.
...
@@ -324,6 +325,7 @@ async def send_request_via_python_kv_router(
...
@@ -324,6 +325,7 @@ async def send_request_via_python_kv_router(
output_options
=
output_options
,
output_options
=
output_options
,
router_config_override
=
router_config_override
,
router_config_override
=
router_config_override
,
worker_id
=
worker_id
,
worker_id
=
worker_id
,
dp_rank
=
dp_rank
,
)
)
if
stream
is
not
None
:
if
stream
is
not
None
:
...
@@ -1314,33 +1316,38 @@ def test_query_instance_id_returns_worker_and_tokens(
...
@@ -1314,33 +1316,38 @@ def test_query_instance_id_returns_worker_and_tokens(
@
pytest
.
mark
.
pre_merge
@
pytest
.
mark
.
pre_merge
@
pytest
.
mark
.
model
(
MODEL_NAME
)
@
pytest
.
mark
.
model
(
MODEL_NAME
)
def
test_router_decisions
(
request
,
runtime_services
,
predownload_tokenizers
):
def
test_router_decisions
(
request
,
runtime_services
,
predownload_tokenizers
):
"""Validate KV cache prefix reuse by sending progressive requests with overlapping prefixes.
"""Validate KV cache prefix reuse
and dp_rank routing
by sending progressive requests with overlapping prefixes.
Flow:
Flow:
- Start two mocker workers
sharing a namespace
.
- Start two mocker workers
, each with dp_size=4 (8 total dp ranks)
.
- Wait for workers to be ready.
- Wait for workers to be ready.
- Send 4 progressive requests, each extending the previous tokens:
- Send 4 progressive requests, each extending the previous tokens:
* Request 1: BLOCK_SIZE random tokens
* Request 1: BLOCK_SIZE random tokens
(forced to specific worker_id and dp_rank=1)
* Request 2: Request 1 tokens + BLOCK_SIZE new random tokens
* Request 2: Request 1 tokens + BLOCK_SIZE new random tokens
(naturally routed)
* Request 3: Request 2 tokens + BLOCK_SIZE new random tokens
* Request 3: Request 2 tokens + BLOCK_SIZE new random tokens
(naturally routed)
* Request 4: Request 3 tokens + BLOCK_SIZE new random tokens
* Request 4: Request 3 tokens + BLOCK_SIZE new random tokens
(naturally routed)
- Dump events from router and verify:
- Dump events from router and verify:
* All but one worker should have no events (one worker handles all due to prefix reuse)
* All but one (worker_id, dp_rank) should have no events (due to prefix reuse)
* The worker with events should have exactly 4 events (one per request)
* The (worker_id, dp_rank) with events should have exactly 4 events (one per request)
* All events should be on the forced (worker_id, dp_rank=1) (verifying forced routing and prefix reuse)
"""
"""
# runtime_services starts etcd and nats
# runtime_services starts etcd and nats
logger
.
info
(
"Starting test router prefix reuse and KV events synchronization"
)
logger
.
info
(
"Starting test router prefix reuse and KV events synchronization"
)
# Create mocker args dictionary
# Create mocker args dictionary with dp_size=4
mocker_args
=
{
"speedup_ratio"
:
SPEEDUP_RATIO
,
"block_size"
:
BLOCK_SIZE
}
mocker_args
=
{
"speedup_ratio"
:
SPEEDUP_RATIO
,
"block_size"
:
BLOCK_SIZE
,
"dp_size"
:
4
,
}
try
:
try
:
# Start mocker instances with the new CLI interface
# Start 2 mocker instances, each with dp_size=4 (8 total dp ranks)
logger
.
info
(
f
"Starting
{
NUM_MOCKERS
}
mocker instances"
)
logger
.
info
(
mockers
=
MockerProcess
(
"Starting 2 mocker instances with dp_size=4 each (8 total dp ranks)"
request
,
mocker_args
=
mocker_args
,
num_mockers
=
NUM_MOCKERS
)
)
mockers
=
MockerProcess
(
request
,
mocker_args
=
mocker_args
,
num_mockers
=
2
)
logger
.
info
(
f
"All mockers using endpoint:
{
mockers
.
endpoint
}
"
)
logger
.
info
(
f
"All mockers using endpoint:
{
mockers
.
endpoint
}
"
)
# Initialize mockers
# Initialize mockers
mockers
.
__enter__
()
mockers
.
__enter__
()
...
@@ -1363,9 +1370,19 @@ def test_router_decisions(request, runtime_services, predownload_tokenizers):
...
@@ -1363,9 +1370,19 @@ def test_router_decisions(request, runtime_services, predownload_tokenizers):
# Use async to manage the test flow
# Use async to manage the test flow
async
def
test_sync
():
async
def
test_sync
():
# Wait for workers to be ready and get their instance IDs
# Wait for workers to be ready and get their instance IDs
mocker_worker_ids
=
await
wait_for_mockers_ready
(
endpoint
,
kv_push_router
)
mocker_worker_ids
=
await
wait_for_mockers_ready
(
endpoint
,
kv_push_router
,
expected_num_workers
=
2
)
logger
.
info
(
f
"Workers ready:
{
mocker_worker_ids
}
"
)
logger
.
info
(
f
"Workers ready:
{
mocker_worker_ids
}
"
)
# Use the first worker_id for forced routing
forced_worker_id
=
mocker_worker_ids
[
0
]
forced_dp_rank
=
1
logger
.
info
(
f
"Will force first request to worker_id=
{
forced_worker_id
}
, dp_rank=
{
forced_dp_rank
}
"
)
# Send 4 progressive requests with overlapping prefixes
# Send 4 progressive requests with overlapping prefixes
cumulative_tokens
=
[]
cumulative_tokens
=
[]
...
@@ -1374,9 +1391,14 @@ def test_router_decisions(request, runtime_services, predownload_tokenizers):
...
@@ -1374,9 +1391,14 @@ def test_router_decisions(request, runtime_services, predownload_tokenizers):
new_tokens
=
[
random
.
randint
(
1
,
10000
)
for
_
in
range
(
BLOCK_SIZE
)]
new_tokens
=
[
random
.
randint
(
1
,
10000
)
for
_
in
range
(
BLOCK_SIZE
)]
cumulative_tokens
.
extend
(
new_tokens
)
cumulative_tokens
.
extend
(
new_tokens
)
# Force first request to specific worker_id and dp_rank=1, let subsequent requests follow naturally
worker_id_override
=
forced_worker_id
if
i
==
0
else
None
dp_rank_override
=
forced_dp_rank
if
i
==
0
else
None
logger
.
info
(
logger
.
info
(
f
"Sending request
{
i
+
1
}
/4 with
{
len
(
cumulative_tokens
)
}
tokens "
f
"Sending request
{
i
+
1
}
/4 with
{
len
(
cumulative_tokens
)
}
tokens "
f
"(added
{
len
(
new_tokens
)
}
new tokens)"
f
"(added
{
len
(
new_tokens
)
}
new tokens)"
f
"
{
f
' - FORCING worker_id=
{
worker_id_override
}
,
dp_rank
=
{
dp_rank_override
}
' if worker_id_override is not None else ''
}
"
)
)
await
send_request_via_python_kv_router
(
await
send_request_via_python_kv_router
(
...
@@ -1388,6 +1410,8 @@ def test_router_decisions(request, runtime_services, predownload_tokenizers):
...
@@ -1388,6 +1410,8 @@ def test_router_decisions(request, runtime_services, predownload_tokenizers):
"ignore_eos"
:
True
,
# Don't stop on EOS token
"ignore_eos"
:
True
,
# Don't stop on EOS token
"max_tokens"
:
2
,
# Generate exactly 2 tokens
"max_tokens"
:
2
,
# Generate exactly 2 tokens
},
},
worker_id
=
worker_id_override
,
dp_rank
=
dp_rank_override
,
)
)
# Wait a bit between requests
# Wait a bit between requests
...
@@ -1398,46 +1422,64 @@ def test_router_decisions(request, runtime_services, predownload_tokenizers):
...
@@ -1398,46 +1422,64 @@ def test_router_decisions(request, runtime_services, predownload_tokenizers):
# Dump events from the router
# Dump events from the router
events_json
=
await
kv_push_router
.
dump_events
()
events_json
=
await
kv_push_router
.
dump_events
()
return
events_json
return
events_json
,
forced_worker_id
,
forced_dp_rank
# Run the async test
# Run the async test
events_json
=
asyncio
.
run
(
test_sync
())
events_json
,
expected_worker_id
,
expected_dp_rank
=
asyncio
.
run
(
test_sync
())
# Parse events and count by worker
# Parse events and count by
(
worker
_id, dp_rank)
events
=
json
.
loads
(
events_json
)
events
=
json
.
loads
(
events_json
)
events_by_worker
:
dict
[
int
,
list
[
Any
]]
=
{}
events_by_worker
_dp
:
dict
[
tuple
[
int
,
int
]
,
list
[
Any
]]
=
{}
for
event
in
events
:
for
event
in
events
:
worker_id
=
event
.
get
(
"worker_id"
)
worker_id
=
event
.
get
(
"worker_id"
)
if
worker_id
not
in
events_by_worker
:
# Extract dp_rank from the event's KvCacheEvent
events_by_worker
[
worker_id
]
=
[]
dp_rank
=
event
.
get
(
"event"
,
{}).
get
(
"dp_rank"
,
0
)
events_by_worker
[
worker_id
].
append
(
event
)
key
=
(
worker_id
,
dp_rank
)
if
key
not
in
events_by_worker_dp
:
events_by_worker_dp
[
key
]
=
[]
events_by_worker_dp
[
key
].
append
(
event
)
logger
.
info
(
logger
.
info
(
f
"Events by worker:
{
[(
wid
,
len
(
evts
))
for
wid
,
evts
in
events_by_worker
.
items
()]
}
"
f
"Events by
(
worker
_id, dp_rank)
:
{
[(
key
,
len
(
evts
))
for
key
,
evts
in
events_by_worker
_dp
.
items
()]
}
"
)
)
# Verify: All but one worker should have no events
# Verify: All but one
(
worker
_id, dp_rank)
should have no events
workers_with_events
=
[
workers_with_events
=
[
wid
for
wid
,
evts
in
events_by_worker
.
items
()
if
len
(
evts
)
>
0
key
for
key
,
evts
in
events_by_worker
_dp
.
items
()
if
len
(
evts
)
>
0
]
]
assert
len
(
workers_with_events
)
==
1
,
(
assert
len
(
workers_with_events
)
==
1
,
(
f
"Expected exactly 1 worker to have events (due to prefix reuse), "
f
"Expected exactly 1
(
worker
_id, dp_rank)
to have events (due to prefix reuse), "
f
"but found
{
len
(
workers_with_events
)
}
workers
with events:
{
workers_with_events
}
"
f
"but found
{
len
(
workers_with_events
)
}
with events:
{
workers_with_events
}
"
)
)
# Verify: The worker with events should have exactly 4 events
# Verify: The
(
worker
_id, dp_rank)
with events should have exactly 4 events
active_worker
=
workers_with_events
[
0
]
active_worker
_dp
=
workers_with_events
[
0
]
num_events
=
len
(
events_by_worker
[
active_worker
])
num_events
=
len
(
events_by_worker
_dp
[
active_worker
_dp
])
assert
num_events
==
4
,
(
assert
num_events
==
4
,
(
f
"Expected worker
{
active_worker
}
to have exactly 4 events, "
f
"Expected
(
worker
_id, dp_rank)
{
active_worker
_dp
}
to have exactly 4 events, "
f
"but found
{
num_events
}
events"
f
"but found
{
num_events
}
events"
)
)
# Verify: Both worker_id and dp_rank should match the forced values
active_worker_id
=
active_worker_dp
[
0
]
active_dp_rank
=
active_worker_dp
[
1
]
assert
active_worker_id
==
expected_worker_id
,
(
f
"Expected all events to have worker_id=
{
expected_worker_id
}
(forced in first request), "
f
"but found worker_id=
{
active_worker_id
}
"
)
assert
active_dp_rank
==
expected_dp_rank
,
(
f
"Expected all events to have dp_rank=
{
expected_dp_rank
}
(forced in first request), "
f
"but found dp_rank=
{
active_dp_rank
}
"
)
logger
.
info
(
logger
.
info
(
f
"Successfully verified: Worker
{
active_worker
}
handled all 4 requests with prefix reuse. "
f
"Successfully verified: Worker
{
active_worker_id
}
dp_rank
{
active_dp_rank
}
handled all 4 requests with prefix reuse. "
f
"All events correctly routed to worker_id=
{
expected_worker_id
}
, dp_rank=
{
expected_dp_rank
}
as expected. "
f
"KV events synchronized correctly."
f
"KV events synchronized correctly."
)
)
...
...
tests/serve/test_sglang.py
View file @
f978f4d1
...
@@ -69,7 +69,7 @@ sglang_configs = {
...
@@ -69,7 +69,7 @@ sglang_configs = {
expected_log
=
[
expected_log
=
[
r
"ZMQ listener .* received batch with \d+ events \(seq=\d+\)"
,
r
"ZMQ listener .* received batch with \d+ events \(seq=\d+\)"
,
r
"Event processor for worker_id \d+ processing event: Stored\("
,
r
"Event processor for worker_id \d+ processing event: Stored\("
,
r
"Selected worker:
\d+
, logit: "
,
r
"Selected worker:
worker_id=\d+ dp_rank=.*?
, logit: "
,
]
]
)
)
],
],
...
...
tests/serve/test_trtllm.py
View file @
f978f4d1
...
@@ -60,7 +60,7 @@ trtllm_configs = {
...
@@ -60,7 +60,7 @@ trtllm_configs = {
chat_payload_default
(
chat_payload_default
(
expected_log
=
[
expected_log
=
[
r
"Event processor for worker_id \d+ processing event: Stored\("
,
r
"Event processor for worker_id \d+ processing event: Stored\("
,
r
"Selected worker:
\d+
, logit: "
,
r
"Selected worker:
worker_id=\d+ dp_rank=.*?
, logit: "
,
]
]
)
)
],
],
...
...
tests/serve/test_vllm.py
View file @
f978f4d1
...
@@ -53,7 +53,7 @@ vllm_configs = {
...
@@ -53,7 +53,7 @@ vllm_configs = {
expected_log
=
[
expected_log
=
[
r
"ZMQ listener .* received batch with \d+ events \(seq=\d+\)"
,
r
"ZMQ listener .* received batch with \d+ events \(seq=\d+\)"
,
r
"Event processor for worker_id \d+ processing event: Stored\("
,
r
"Event processor for worker_id \d+ processing event: Stored\("
,
r
"Selected worker:
\d+
, logit: "
,
r
"Selected worker:
worker_id=\d+ dp_rank=.*?
, logit: "
,
]
]
)
)
],
],
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment