Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
f978f4d1
Unverified
Commit
f978f4d1
authored
Oct 15, 2025
by
Yan Ru Pei
Committed by
GitHub
Oct 16, 2025
Browse files
feat: dp rank routing (#3597)
Signed-off-by:
PeaBrane
<
yanrpei@gmail.com
>
parent
29f5b822
Changes
31
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
183 additions
and
88 deletions
+183
-88
lib/llm/src/local_model/runtime_config.rs
lib/llm/src/local_model/runtime_config.rs
+24
-1
lib/llm/src/mocker/engine.rs
lib/llm/src/mocker/engine.rs
+6
-15
lib/llm/src/mocker/protocols.rs
lib/llm/src/mocker/protocols.rs
+1
-1
lib/llm/src/mocker/scheduler.rs
lib/llm/src/mocker/scheduler.rs
+10
-10
lib/llm/src/protocols/common/preprocessor.rs
lib/llm/src/protocols/common/preprocessor.rs
+5
-0
lib/llm/tests/block_manager.rs
lib/llm/tests/block_manager.rs
+6
-0
lib/runtime/src/utils/worker_monitor.rs
lib/runtime/src/utils/worker_monitor.rs
+54
-26
tests/router/test_router_e2e_with_mockers.py
tests/router/test_router_e2e_with_mockers.py
+74
-32
tests/serve/test_sglang.py
tests/serve/test_sglang.py
+1
-1
tests/serve/test_trtllm.py
tests/serve/test_trtllm.py
+1
-1
tests/serve/test_vllm.py
tests/serve/test_vllm.py
+1
-1
No files found.
lib/llm/src/local_model/runtime_config.rs
View file @
f978f4d1
...
...
@@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize, de::DeserializeOwned};
use
crate
::
protocols
::
tensor
;
#[derive(Debug,
Default,
Clone,
Serialize,
Deserialize,
Eq,
PartialEq)]
#[derive(Debug,
Clone,
Serialize,
Deserialize,
Eq,
PartialEq)]
pub
struct
ModelRuntimeConfig
{
pub
total_kv_blocks
:
Option
<
u64
>
,
...
...
@@ -19,6 +19,10 @@ pub struct ModelRuntimeConfig {
pub
reasoning_parser
:
Option
<
String
>
,
/// Total number of data parallel ranks for this worker (1 if DP not enabled)
#[serde(default
=
"default_data_parallel_size"
)]
pub
data_parallel_size
:
u32
,
/// Mapping of engine-specific runtime configs
#[serde(default,
skip_serializing_if
=
"HashMap::is_empty"
)]
pub
runtime_data
:
HashMap
<
String
,
serde_json
::
Value
>
,
...
...
@@ -34,6 +38,25 @@ pub struct ModelRuntimeConfig {
pub
tensor_model_config
:
Option
<
tensor
::
TensorModelConfig
>
,
}
const
fn
default_data_parallel_size
()
->
u32
{
1
}
impl
Default
for
ModelRuntimeConfig
{
fn
default
()
->
Self
{
Self
{
total_kv_blocks
:
None
,
max_num_seqs
:
None
,
max_num_batched_tokens
:
None
,
tool_call_parser
:
None
,
reasoning_parser
:
None
,
data_parallel_size
:
default_data_parallel_size
(),
runtime_data
:
HashMap
::
new
(),
tensor_model_config
:
None
,
}
}
}
impl
ModelRuntimeConfig
{
pub
fn
new
()
->
Self
{
Self
::
default
()
...
...
lib/llm/src/mocker/engine.rs
View file @
f978f4d1
...
...
@@ -124,7 +124,7 @@ impl MockVllmEngine {
let
scheduler
=
Scheduler
::
new
(
args
.clone
(),
Some
(
dp_rank
)
,
dp_rank
,
Some
(
output_tx
),
Some
(
kv_events_tx
),
// Pass the KV events sender to scheduler
Some
(
cancel_token
.clone
()),
...
...
@@ -283,6 +283,7 @@ impl MockVllmEngine {
let
event
=
KvCacheEvent
{
event_id
:
Uuid
::
new_v4
()
.as_u128
()
as
u64
,
data
:
event_data
,
dp_rank
,
};
// Publish the event
...
...
@@ -316,18 +317,8 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
)
->
Result
<
ManyOut
<
LLMEngineOutput
>
,
Error
>
{
let
(
request
,
ctx
)
=
input
.into_parts
();
// Extract dp_rank from annotations if present
let
dp_rank
=
request
.annotations
.iter
()
.find_map
(|
ann
|
{
if
ann
.starts_with
(
"dp_rank:"
)
{
ann
.strip_prefix
(
"dp_rank:"
)
.and_then
(|
s
|
s
.parse
()
.ok
())
}
else
{
None
}
})
.unwrap_or
(
0
);
// Extract dp_rank from request field (defaults to 0 if not set)
let
dp_rank
=
request
.dp_rank
.unwrap_or
(
0
);
// Validate dp_rank
if
dp_rank
>=
self
.engine_args.dp_size
{
...
...
@@ -348,7 +339,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
.expect
(
"max_output_tokens must be specified for mocker"
)
as
usize
,
uuid
:
Some
(
request_uuid
),
dp_rank
:
Some
(
dp_rank
)
,
dp_rank
,
};
let
(
request_tx
,
mut
request_rx
)
=
mpsc
::
unbounded_channel
::
<
OutputSignal
>
();
...
...
@@ -512,7 +503,7 @@ pub async fn make_mocker_engine(
args
:
MockEngineArgs
,
)
->
Result
<
crate
::
backend
::
ExecutionContext
,
Error
>
{
// Create the mocker engine
tracing
::
debug
!
(
"Creating mocker engine with config: {args:?}"
);
tracing
::
info
!
(
"Creating mocker engine with config: {args:?}"
);
let
annotated_engine
=
AnnotatedMockEngine
::
new
(
MockVllmEngine
::
new
(
args
),
distributed_runtime
,
endpoint_id
);
...
...
lib/llm/src/mocker/protocols.rs
View file @
f978f4d1
...
...
@@ -37,7 +37,7 @@ pub struct DirectRequest {
pub
tokens
:
Vec
<
Token
>
,
pub
max_output_tokens
:
usize
,
pub
uuid
:
Option
<
Uuid
>
,
pub
dp_rank
:
Option
<
u32
>
,
pub
dp_rank
:
u32
,
}
/// Represents the cost of prefilling content in the cache
...
...
lib/llm/src/mocker/scheduler.rs
View file @
f978f4d1
...
...
@@ -248,7 +248,7 @@ impl Scheduler {
/// Create a new Scheduler with the given parameters
pub
fn
new
(
args
:
MockEngineArgs
,
dp_rank
:
Option
<
u32
>
,
dp_rank
:
u32
,
output_tx
:
Option
<
mpsc
::
UnboundedSender
<
OutputSignal
>>
,
kv_events_tx
:
Option
<
mpsc
::
UnboundedSender
<
KvCacheEventData
>>
,
cancellation_token
:
Option
<
CancellationToken
>
,
...
...
@@ -280,7 +280,7 @@ impl Scheduler {
// Create channel for request handling
let
(
request_tx
,
mut
request_rx
)
=
mpsc
::
unbounded_channel
::
<
DirectRequest
>
();
let
mut
initial_metrics
=
ForwardPassMetrics
::
default
();
initial_metrics
.worker_stats.data_parallel_rank
=
dp_rank
;
initial_metrics
.worker_stats.data_parallel_rank
=
Some
(
dp_rank
)
;
let
(
metrics_tx
,
metrics_rx
)
=
tokio
::
sync
::
watch
::
channel
::
<
ForwardPassMetrics
>
(
initial_metrics
);
...
...
@@ -573,7 +573,7 @@ fn get_fwd_pass_metrics(
state
:
&
SchedulerState
,
kv_manager
:
&
KvManager
,
hit_rates
:
&
VecDeque
<
f32
>
,
dp_rank
:
Option
<
u32
>
,
dp_rank
:
u32
,
)
->
ForwardPassMetrics
{
// Get state metrics
let
request_active_slots
=
state
.decode
.len
()
as
u64
;
...
...
@@ -597,7 +597,7 @@ fn get_fwd_pass_metrics(
};
let
worker_stats
=
WorkerStats
{
data_parallel_rank
:
dp_rank
,
data_parallel_rank
:
Some
(
dp_rank
)
,
request_active_slots
,
request_total_slots
:
1024
,
// vllm max_num_seqs for gpu >= 70 vram, otherwise 256, fallback is 128
num_requests_waiting
,
...
...
@@ -728,7 +728,7 @@ mod tests {
.unwrap
();
// Create scheduler with new args struct
let
scheduler
=
Scheduler
::
new
(
args
,
None
,
Some
(
output_tx
),
None
,
None
);
let
scheduler
=
Scheduler
::
new
(
args
,
0
,
Some
(
output_tx
),
None
,
None
);
// Create shared tokens for caching case
let
shared_tokens
=
if
use_shared_tokens
{
...
...
@@ -759,7 +759,7 @@ mod tests {
tokens
:
input_tokens
,
max_output_tokens
,
uuid
:
None
,
dp_rank
:
None
,
dp_rank
:
0
,
};
scheduler
.receive
(
request
)
.await
;
}
...
...
@@ -853,7 +853,7 @@ mod tests {
.unwrap
();
// Create scheduler
let
scheduler
=
Scheduler
::
new
(
args
,
None
,
Some
(
output_tx
),
None
,
None
);
let
scheduler
=
Scheduler
::
new
(
args
,
0
,
Some
(
output_tx
),
None
,
None
);
// Create identical tokens for all requests
let
identical_tokens
:
Vec
<
u32
>
=
(
0
..
token_length
)
.map
(|
i
|
i
as
u32
)
.collect
();
...
...
@@ -864,7 +864,7 @@ mod tests {
tokens
:
identical_tokens
.clone
(),
max_output_tokens
,
uuid
:
None
,
dp_rank
:
None
,
dp_rank
:
0
,
};
scheduler
.receive
(
request
)
.await
;
// Sleep for 0.1 second after each request
...
...
@@ -950,7 +950,7 @@ mod tests {
.unwrap
();
// Create scheduler
let
scheduler
=
Scheduler
::
new
(
args
,
None
,
Some
(
output_tx
),
None
,
None
);
let
scheduler
=
Scheduler
::
new
(
args
,
0
,
Some
(
output_tx
),
None
,
None
);
// Create request with 256 tokens
let
tokens
:
Vec
<
u32
>
=
(
0
..
input_tokens
)
.map
(|
i
|
i
as
u32
)
.collect
();
...
...
@@ -958,7 +958,7 @@ mod tests {
tokens
,
max_output_tokens
,
uuid
:
None
,
dp_rank
:
None
,
dp_rank
:
0
,
};
scheduler
.receive
(
request
)
.await
;
...
...
lib/llm/src/protocols/common/preprocessor.rs
View file @
f978f4d1
...
...
@@ -61,6 +61,11 @@ pub struct PreprocessedRequest {
#[serde(default,
skip_serializing_if
=
"Option::is_none"
)]
pub
disaggregated_params
:
Option
<
serde_json
::
Value
>
,
/// Data parallel rank for the request (used with data parallelism)
#[builder(default)]
#[serde(default,
skip_serializing_if
=
"Option::is_none"
)]
pub
dp_rank
:
Option
<
u32
>
,
/// Additional arguments for extensibility
#[builder(default)]
#[serde(default,
skip_serializing_if
=
"Option::is_none"
)]
...
...
lib/llm/tests/block_manager.rs
View file @
f978f4d1
...
...
@@ -294,6 +294,7 @@ pub mod llm_kvbm {
let
event
=
KvCacheEvent
{
data
,
event_id
:
event_id_counter
,
dp_rank
:
0
,
};
let
router_event
=
RouterEvent
::
new
(
worker_identifier
as
i64
,
event
);
event_id_counter
+=
1
;
...
...
@@ -313,6 +314,7 @@ pub mod llm_kvbm {
block_hashes
:
vec!
[
ExternalSequenceBlockHash
(
sequence_hash
)],
}),
event_id
:
event_id_counter
,
dp_rank
:
0
,
};
let
router_event
=
RouterEvent
::
new
(
worker_identifier
as
i64
,
event
);
event_id_counter
+=
1
;
...
...
@@ -573,6 +575,7 @@ mod tests {
}],
parent_hash
:
None
,
}),
dp_rank
:
0
,
},
);
...
...
@@ -587,6 +590,7 @@ mod tests {
}],
parent_hash
:
None
,
}),
dp_rank
:
0
,
},
);
...
...
@@ -630,6 +634,7 @@ mod tests {
}],
parent_hash
:
None
,
}),
dp_rank
:
0
,
},
);
...
...
@@ -678,6 +683,7 @@ mod tests {
data
:
KvCacheEventData
::
Removed
(
KvCacheRemoveData
{
block_hashes
:
vec!
[
ExternalSequenceBlockHash
(
4
)],
}),
dp_rank
:
0
,
},
);
...
...
lib/runtime/src/utils/worker_monitor.rs
View file @
f978f4d1
...
...
@@ -26,34 +26,59 @@ struct LoadEvent {
#[derive(serde::Deserialize)]
struct
ForwardPassMetrics
{
worker_stats
:
WorkerStats
,
kv_stats
:
KvStats
,
}
#[derive(serde::Deserialize)]
struct
WorkerStats
{
data_parallel_rank
:
Option
<
u32
>
,
}
#[derive(serde::Deserialize)]
struct
KvStats
{
kv_active_blocks
:
u64
,
}
#[derive(serde::Deserialize)]
#[derive(serde::Deserialize
,
Clone
)]
struct
RuntimeConfig
{
total_kv_blocks
:
Option
<
u64
>
,
data_parallel_size
:
u32
,
}
/// Worker load monitoring state
#[derive(Clone,
Debug)]
/// Worker load monitoring state
per dp_rank
#[derive(Clone,
Debug
,
Default
)]
pub
struct
WorkerLoadState
{
pub
kv_active_blocks
:
Option
<
u64
>
,
pub
kv_total_blocks
:
Option
<
u64
>
,
pub
kv_active_blocks
:
HashMap
<
u32
,
u64
>
,
pub
kv_total_blocks
:
HashMap
<
u32
,
u64
>
,
}
impl
WorkerLoadState
{
/// Returns true if ALL dp_ranks (that have data in both maps) exceed the threshold
pub
fn
is_busy
(
&
self
,
threshold
:
f64
)
->
bool
{
match
(
self
.kv_active_blocks
,
self
.kv_total_blocks
)
{
(
Some
(
active
),
Some
(
total
))
if
total
>
0
=>
{
(
active
as
f64
)
>
(
threshold
*
total
as
f64
)
}
_
=>
false
,
// Get all dp_ranks that exist in both active and total blocks
let
common_dp_ranks
:
Vec
<
_
>
=
self
.kv_active_blocks
.keys
()
.filter
(|
dp_rank
|
self
.kv_total_blocks
.contains_key
(
dp_rank
))
.collect
();
// If no common dp_ranks, not busy
if
common_dp_ranks
.is_empty
()
{
return
false
;
}
// Check if ALL common dp_ranks exceed threshold
common_dp_ranks
.iter
()
.all
(|
&&
dp_rank
|
{
if
let
(
Some
(
&
active
),
Some
(
&
total
))
=
(
self
.kv_active_blocks
.get
(
&
dp_rank
),
self
.kv_total_blocks
.get
(
&
dp_rank
),
)
{
total
>
0
&&
(
active
as
f64
)
>
(
threshold
*
total
as
f64
)
}
else
{
false
}
})
}
}
...
...
@@ -97,9 +122,10 @@ impl WorkerMonitor {
"v1/mdc/"
,
// should be model_card::ROOT_PREFIX but wrong crate
key_extractors
::
lease_id
,
|
card
:
serde_json
::
Value
|
{
card
.get
(
"runtime_config"
)
.and_then
(|
rc
|
rc
.get
(
"total_kv_blocks"
))
.and_then
(|
t_kv
|
t_kv
.as_u64
())
let
runtime_config
:
Option
<
RuntimeConfig
>
=
card
.get
(
"runtime_config"
)
.and_then
(|
rc
|
serde_json
::
from_value
(
rc
.clone
())
.ok
());
runtime_config
},
component
.drt
()
.child_token
(),
)
...
...
@@ -132,13 +158,17 @@ impl WorkerMonitor {
let
mut
states
=
worker_load_states
.write
()
.unwrap
();
states
.retain
(|
lease_id
,
_
|
runtime_configs
.contains_key
(
lease_id
));
// Update worker load states with total blocks
for
(
lease_id
,
total_blocks
)
in
runtime_configs
.iter
()
{
let
state
=
states
.entry
(
*
lease_id
)
.or_insert
(
WorkerLoadState
{
kv_active_blocks
:
None
,
kv_total_blocks
:
None
,
});
state
.kv_total_blocks
=
Some
(
*
total_blocks
);
// Update worker load states with total blocks for all dp_ranks
for
(
lease_id
,
runtime_config
)
in
runtime_configs
.iter
()
{
let
state
=
states
.entry
(
*
lease_id
)
.or_default
();
// Populate total_blocks for all dp_ranks (they share the same total)
// data_parallel_size defaults to 1 via serde in ModelRuntimeConfig
if
let
Some
(
total_blocks
)
=
runtime_config
.total_kv_blocks
{
for
dp_rank
in
0
..
runtime_config
.data_parallel_size
{
state
.kv_total_blocks
.insert
(
dp_rank
,
total_blocks
);
}
}
}
}
...
...
@@ -152,14 +182,12 @@ impl WorkerMonitor {
if
let
Ok
(
load_event
)
=
serde_json
::
from_slice
::
<
LoadEvent
>
(
&
event
.payload
)
{
let
worker_id
=
load_event
.worker_id
;
let
active_blocks
=
load_event
.data.kv_stats.kv_active_blocks
;
let
dp_rank
=
load_event
.data.worker_stats.data_parallel_rank
.unwrap_or
(
0
);
// Update worker load state
// Update worker load state
per dp_rank
let
mut
states
=
worker_load_states
.write
()
.unwrap
();
let
state
=
states
.entry
(
worker_id
)
.or_insert
(
WorkerLoadState
{
kv_active_blocks
:
None
,
kv_total_blocks
:
None
,
});
state
.kv_active_blocks
=
Some
(
active_blocks
);
let
state
=
states
.entry
(
worker_id
)
.or_default
();
state
.kv_active_blocks
.insert
(
dp_rank
,
active_blocks
);
drop
(
states
);
// Recalculate all busy instances and update
...
...
tests/router/test_router_e2e_with_mockers.py
View file @
f978f4d1
...
...
@@ -298,6 +298,7 @@ async def send_request_via_python_kv_router(
worker_id
:
Optional
[
int
]
=
None
,
# If None, Router will select the best available worker
dp_rank
:
Optional
[
int
]
=
None
,
# Data parallel rank (defaults to 0)
):
"""Send a request to the specified mocker instance.
Returns True if mockers respond, otherwise raises or returns False.
...
...
@@ -324,6 +325,7 @@ async def send_request_via_python_kv_router(
output_options
=
output_options
,
router_config_override
=
router_config_override
,
worker_id
=
worker_id
,
dp_rank
=
dp_rank
,
)
if
stream
is
not
None
:
...
...
@@ -1314,33 +1316,38 @@ def test_query_instance_id_returns_worker_and_tokens(
@
pytest
.
mark
.
pre_merge
@
pytest
.
mark
.
model
(
MODEL_NAME
)
def
test_router_decisions
(
request
,
runtime_services
,
predownload_tokenizers
):
"""Validate KV cache prefix reuse by sending progressive requests with overlapping prefixes.
"""Validate KV cache prefix reuse
and dp_rank routing
by sending progressive requests with overlapping prefixes.
Flow:
- Start two mocker workers
sharing a namespace
.
- Start two mocker workers
, each with dp_size=4 (8 total dp ranks)
.
- Wait for workers to be ready.
- Send 4 progressive requests, each extending the previous tokens:
* Request 1: BLOCK_SIZE random tokens
* Request 2: Request 1 tokens + BLOCK_SIZE new random tokens
* Request 3: Request 2 tokens + BLOCK_SIZE new random tokens
* Request 4: Request 3 tokens + BLOCK_SIZE new random tokens
* Request 1: BLOCK_SIZE random tokens
(forced to specific worker_id and dp_rank=1)
* Request 2: Request 1 tokens + BLOCK_SIZE new random tokens
(naturally routed)
* Request 3: Request 2 tokens + BLOCK_SIZE new random tokens
(naturally routed)
* Request 4: Request 3 tokens + BLOCK_SIZE new random tokens
(naturally routed)
- Dump events from router and verify:
* All but one worker should have no events (one worker handles all due to prefix reuse)
* The worker with events should have exactly 4 events (one per request)
* All but one (worker_id, dp_rank) should have no events (due to prefix reuse)
* The (worker_id, dp_rank) with events should have exactly 4 events (one per request)
* All events should be on the forced (worker_id, dp_rank=1) (verifying forced routing and prefix reuse)
"""
# runtime_services starts etcd and nats
logger
.
info
(
"Starting test router prefix reuse and KV events synchronization"
)
# Create mocker args dictionary
mocker_args
=
{
"speedup_ratio"
:
SPEEDUP_RATIO
,
"block_size"
:
BLOCK_SIZE
}
# Create mocker args dictionary with dp_size=4
mocker_args
=
{
"speedup_ratio"
:
SPEEDUP_RATIO
,
"block_size"
:
BLOCK_SIZE
,
"dp_size"
:
4
,
}
try
:
# Start mocker instances with the new CLI interface
logger
.
info
(
f
"Starting
{
NUM_MOCKERS
}
mocker instances"
)
mockers
=
MockerProcess
(
request
,
mocker_args
=
mocker_args
,
num_mockers
=
NUM_MOCKERS
# Start 2 mocker instances, each with dp_size=4 (8 total dp ranks)
logger
.
info
(
"Starting 2 mocker instances with dp_size=4 each (8 total dp ranks)"
)
mockers
=
MockerProcess
(
request
,
mocker_args
=
mocker_args
,
num_mockers
=
2
)
logger
.
info
(
f
"All mockers using endpoint:
{
mockers
.
endpoint
}
"
)
# Initialize mockers
mockers
.
__enter__
()
...
...
@@ -1363,9 +1370,19 @@ def test_router_decisions(request, runtime_services, predownload_tokenizers):
# Use async to manage the test flow
async
def
test_sync
():
# Wait for workers to be ready and get their instance IDs
mocker_worker_ids
=
await
wait_for_mockers_ready
(
endpoint
,
kv_push_router
)
mocker_worker_ids
=
await
wait_for_mockers_ready
(
endpoint
,
kv_push_router
,
expected_num_workers
=
2
)
logger
.
info
(
f
"Workers ready:
{
mocker_worker_ids
}
"
)
# Use the first worker_id for forced routing
forced_worker_id
=
mocker_worker_ids
[
0
]
forced_dp_rank
=
1
logger
.
info
(
f
"Will force first request to worker_id=
{
forced_worker_id
}
, dp_rank=
{
forced_dp_rank
}
"
)
# Send 4 progressive requests with overlapping prefixes
cumulative_tokens
=
[]
...
...
@@ -1374,9 +1391,14 @@ def test_router_decisions(request, runtime_services, predownload_tokenizers):
new_tokens
=
[
random
.
randint
(
1
,
10000
)
for
_
in
range
(
BLOCK_SIZE
)]
cumulative_tokens
.
extend
(
new_tokens
)
# Force first request to specific worker_id and dp_rank=1, let subsequent requests follow naturally
worker_id_override
=
forced_worker_id
if
i
==
0
else
None
dp_rank_override
=
forced_dp_rank
if
i
==
0
else
None
logger
.
info
(
f
"Sending request
{
i
+
1
}
/4 with
{
len
(
cumulative_tokens
)
}
tokens "
f
"(added
{
len
(
new_tokens
)
}
new tokens)"
f
"
{
f
' - FORCING worker_id=
{
worker_id_override
}
,
dp_rank
=
{
dp_rank_override
}
' if worker_id_override is not None else ''
}
"
)
await
send_request_via_python_kv_router
(
...
...
@@ -1388,6 +1410,8 @@ def test_router_decisions(request, runtime_services, predownload_tokenizers):
"ignore_eos"
:
True
,
# Don't stop on EOS token
"max_tokens"
:
2
,
# Generate exactly 2 tokens
},
worker_id
=
worker_id_override
,
dp_rank
=
dp_rank_override
,
)
# Wait a bit between requests
...
...
@@ -1398,46 +1422,64 @@ def test_router_decisions(request, runtime_services, predownload_tokenizers):
# Dump events from the router
events_json
=
await
kv_push_router
.
dump_events
()
return
events_json
return
events_json
,
forced_worker_id
,
forced_dp_rank
# Run the async test
events_json
=
asyncio
.
run
(
test_sync
())
events_json
,
expected_worker_id
,
expected_dp_rank
=
asyncio
.
run
(
test_sync
())
# Parse events and count by worker
# Parse events and count by
(
worker
_id, dp_rank)
events
=
json
.
loads
(
events_json
)
events_by_worker
:
dict
[
int
,
list
[
Any
]]
=
{}
events_by_worker
_dp
:
dict
[
tuple
[
int
,
int
]
,
list
[
Any
]]
=
{}
for
event
in
events
:
worker_id
=
event
.
get
(
"worker_id"
)
if
worker_id
not
in
events_by_worker
:
events_by_worker
[
worker_id
]
=
[]
events_by_worker
[
worker_id
].
append
(
event
)
# Extract dp_rank from the event's KvCacheEvent
dp_rank
=
event
.
get
(
"event"
,
{}).
get
(
"dp_rank"
,
0
)
key
=
(
worker_id
,
dp_rank
)
if
key
not
in
events_by_worker_dp
:
events_by_worker_dp
[
key
]
=
[]
events_by_worker_dp
[
key
].
append
(
event
)
logger
.
info
(
f
"Events by worker:
{
[(
wid
,
len
(
evts
))
for
wid
,
evts
in
events_by_worker
.
items
()]
}
"
f
"Events by
(
worker
_id, dp_rank)
:
{
[(
key
,
len
(
evts
))
for
key
,
evts
in
events_by_worker
_dp
.
items
()]
}
"
)
# Verify: All but one worker should have no events
# Verify: All but one
(
worker
_id, dp_rank)
should have no events
workers_with_events
=
[
wid
for
wid
,
evts
in
events_by_worker
.
items
()
if
len
(
evts
)
>
0
key
for
key
,
evts
in
events_by_worker
_dp
.
items
()
if
len
(
evts
)
>
0
]
assert
len
(
workers_with_events
)
==
1
,
(
f
"Expected exactly 1 worker to have events (due to prefix reuse), "
f
"but found
{
len
(
workers_with_events
)
}
workers
with events:
{
workers_with_events
}
"
f
"Expected exactly 1
(
worker
_id, dp_rank)
to have events (due to prefix reuse), "
f
"but found
{
len
(
workers_with_events
)
}
with events:
{
workers_with_events
}
"
)
# Verify: The worker with events should have exactly 4 events
active_worker
=
workers_with_events
[
0
]
num_events
=
len
(
events_by_worker
[
active_worker
])
# Verify: The
(
worker
_id, dp_rank)
with events should have exactly 4 events
active_worker
_dp
=
workers_with_events
[
0
]
num_events
=
len
(
events_by_worker
_dp
[
active_worker
_dp
])
assert
num_events
==
4
,
(
f
"Expected worker
{
active_worker
}
to have exactly 4 events, "
f
"Expected
(
worker
_id, dp_rank)
{
active_worker
_dp
}
to have exactly 4 events, "
f
"but found
{
num_events
}
events"
)
# Verify: Both worker_id and dp_rank should match the forced values
active_worker_id
=
active_worker_dp
[
0
]
active_dp_rank
=
active_worker_dp
[
1
]
assert
active_worker_id
==
expected_worker_id
,
(
f
"Expected all events to have worker_id=
{
expected_worker_id
}
(forced in first request), "
f
"but found worker_id=
{
active_worker_id
}
"
)
assert
active_dp_rank
==
expected_dp_rank
,
(
f
"Expected all events to have dp_rank=
{
expected_dp_rank
}
(forced in first request), "
f
"but found dp_rank=
{
active_dp_rank
}
"
)
logger
.
info
(
f
"Successfully verified: Worker
{
active_worker
}
handled all 4 requests with prefix reuse. "
f
"Successfully verified: Worker
{
active_worker_id
}
dp_rank
{
active_dp_rank
}
handled all 4 requests with prefix reuse. "
f
"All events correctly routed to worker_id=
{
expected_worker_id
}
, dp_rank=
{
expected_dp_rank
}
as expected. "
f
"KV events synchronized correctly."
)
...
...
tests/serve/test_sglang.py
View file @
f978f4d1
...
...
@@ -69,7 +69,7 @@ sglang_configs = {
expected_log
=
[
r
"ZMQ listener .* received batch with \d+ events \(seq=\d+\)"
,
r
"Event processor for worker_id \d+ processing event: Stored\("
,
r
"Selected worker:
\d+
, logit: "
,
r
"Selected worker:
worker_id=\d+ dp_rank=.*?
, logit: "
,
]
)
],
...
...
tests/serve/test_trtllm.py
View file @
f978f4d1
...
...
@@ -60,7 +60,7 @@ trtllm_configs = {
chat_payload_default
(
expected_log
=
[
r
"Event processor for worker_id \d+ processing event: Stored\("
,
r
"Selected worker:
\d+
, logit: "
,
r
"Selected worker:
worker_id=\d+ dp_rank=.*?
, logit: "
,
]
)
],
...
...
tests/serve/test_vllm.py
View file @
f978f4d1
...
...
@@ -53,7 +53,7 @@ vllm_configs = {
expected_log
=
[
r
"ZMQ listener .* received batch with \d+ events \(seq=\d+\)"
,
r
"Event processor for worker_id \d+ processing event: Stored\("
,
r
"Selected worker:
\d+
, logit: "
,
r
"Selected worker:
worker_id=\d+ dp_rank=.*?
, logit: "
,
]
)
],
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment